Make the rendering of Comfy's implementation identical to Chroma's workflow. #7965

Open · wants to merge 6 commits into base: master
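In short: this PR stores the resolved CLIPType on the CLIP object (clip_type_enum), routes CLIPType.CHROMA loads to the Mochi T5 text encoder target instead of the PixArt one, and drops the T5 attention_mask key from the conditioning output whenever the CLIP was loaded as CHROMA, so the conditioning matches what Chroma's reference workflow produces. A minimal sketch of that last behaviour, using a stand-in enum rather than the real comfy.sd.CLIPType:

from enum import Enum, auto

class CLIPType(Enum):  # stand-in for comfy.sd.CLIPType, for illustration only
    STABLE_DIFFUSION = auto()
    CHROMA = auto()

def finalize_cond(out: dict, clip_type_enum) -> dict:
    # Hypothetical helper mirroring the diff below: CHROMA-typed CLIP objects
    # remove 'attention_mask' from the conditioning dict before it is returned.
    if clip_type_enum == CLIPType.CHROMA:
        out.pop("attention_mask", None)
    return out

cond = {"cond": "tensor", "pooled_output": "tensor", "attention_mask": "mask"}
assert "attention_mask" not in finalize_cond(cond, CLIPType.CHROMA)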
comfy/sd.py: 45 changes (33 additions & 12 deletions)
@@ -84,9 +84,12 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):


class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}, clip_type_enum=None): # MODIFIED: Added clip_type_enum
if no_init:
return

self.clip_type_enum = clip_type_enum

params = target.params.copy()
clip = target.clip
tokenizer = target.tokenizer
@@ -131,6 +134,7 @@ def clone(self):
n.tokenizer_options = self.tokenizer_options.copy()
n.use_clip_schedule = self.use_clip_schedule
n.apply_hooks_to_conds = self.apply_hooks_to_conds
n.clip_type_enum = self.clip_type_enum
return n

def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
@@ -159,12 +163,13 @@ def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict
all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
all_hooks = self.patcher.forced_hooks
if all_hooks is None or not self.use_clip_schedule:
# if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
# if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
return_pooled = "unprojected" if unprojected else True
pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
cond = pooled_dict.pop("cond")
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
all_cond_pooled.append([cond, pooled_dict])
else:
scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
@@ -198,8 +203,17 @@ def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]

pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
if len(o) > 2 and isinstance(o[2], dict):
pooled_dict.update(o[2])

if hasattr(self, 'clip_type_enum') and self.clip_type_enum == CLIPType.CHROMA:
if 'attention_mask' in pooled_dict:
logging.debug(f"CLIP type {self.clip_type_enum.name} (scheduled path): Removing 'attention_mask' from conditioning output.")
pooled_dict.pop('attention_mask', None)

pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
@@ -227,10 +241,15 @@ def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
cond, pooled = o[:2]
if return_dict:
out = {"cond": cond, "pooled_output": pooled}
if len(o) > 2:
if len(o) > 2 and isinstance(o[2], dict):
for k in o[2]:
out[k] = o[2][k]
self.add_hooks_to_dict(out)

if hasattr(self, 'clip_type_enum') and self.clip_type_enum == CLIPType.CHROMA:
if 'attention_mask' in out:
logging.debug(f"CLIP type {self.clip_type_enum.name} (non-scheduled path): Removing 'attention_mask' from conditioning output.")
out.pop('attention_mask', None)
return out

if return_pooled:
@@ -261,6 +280,7 @@ def load_model(self):
def get_key_patches(self):
return self.patcher.get_key_patches()


class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
@@ -788,8 +808,9 @@ class EmptyClass:
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "")
else:
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
# Ensure "text_projection" exists and is a tensor before trying to transpose
if "text_projection" in clip_data[i] and isinstance(clip_data[i]["text_projection"], torch.Tensor):
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1)

tokenizer_data = {}
clip_target = EmptyClass()
@@ -819,7 +840,7 @@ class EmptyClass:
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
elif clip_type == CLIPType.PIXART:
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
elif clip_type == CLIPType.WAN:
@@ -830,7 +851,7 @@ class EmptyClass:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else: #CLIPType.MOCHI
else: #CLIPType.MOCHI or CLIPType.CHROMA
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
elif te_model == TEModel.T5_XXL_OLD:
@@ -851,14 +872,14 @@ class EmptyClass:
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
# Detect
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2:
@@ -876,7 +897,6 @@ class EmptyClass:
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
elif clip_type == CLIPType.HIDREAM:
# Detect
hidream_dualclip_classes = []
for hidream_te in clip_data:
te_model = detect_te_model(hidream_te)
@@ -886,8 +906,8 @@ class EmptyClass:
clip_g = TEModel.CLIP_G in hidream_dualclip_classes
t5 = TEModel.T5_XXL in hidream_dualclip_classes
llama = TEModel.LLAMA3_8 in hidream_dualclip_classes

# Initialize t5xxl_detect and llama_detect kwargs if needed

t5_kwargs = t5xxl_detect(clip_data) if t5 else {}
llama_kwargs = llama_detect(clip_data) if llama else {}

@@ -908,7 +928,8 @@ class EmptyClass:
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options, clip_type_enum=clip_type)

for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
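For reference, a hedged usage sketch of how the change surfaces to callers. The checkpoint paths and the exact load_clip signature below are assumptions for illustration; only the clip_type keyword, tokenize, and the encode_from_tokens(..., return_dict=True) call come from comfy/sd.py.

import comfy.sd
from comfy.sd import CLIPType

# Paths are placeholders; point them at a local t5xxl checkpoint and embeddings folder.
clip = comfy.sd.load_clip(["models/text_encoders/t5xxl_fp16.safetensors"],
                          embedding_directory="models/embeddings",
                          clip_type=CLIPType.CHROMA)

tokens = clip.tokenize("a photo of a cat")
out = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
# With this PR, a CHROMA-typed CLIP omits 'attention_mask' from the returned dict,
# and the underlying target is the Mochi T5 encoder rather than the PixArt one.
print("attention_mask" in out)  # expected: False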