diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index 7249e033186f..e82fdd7ccbfb 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -200,6 +200,7 @@ def get_unweighted_text_embeddings(
     text_input: torch.Tensor,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
+    clip_skip: Optional[int] = None,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -215,7 +216,12 @@ def get_unweighted_text_embeddings(
             # cover the head and the tail by the starting and the ending tokens
             text_input_chunk[:, 0] = text_input[0, 0]
             text_input_chunk[:, -1] = text_input[0, -1]
-            text_embedding = pipe.text_encoder(text_input_chunk)[0]
+            if clip_skip is None:
+                text_embedding = pipe.text_encoder(text_input_chunk)[0]
+            else:
+                text_embedding = pipe.text_encoder(text_input_chunk, output_hidden_states=True)
+                text_embedding = text_embedding[-1][-(clip_skip + 1)]
+                text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
 
             if no_boseos_middle:
                 if i == 0:
@@ -229,9 +235,14 @@ def get_unweighted_text_embeddings(
                     text_embedding = text_embedding[:, 1:-1]
 
             text_embeddings.append(text_embedding)
-        text_embeddings = torch.concat(text_embeddings, axis=1)
+        text_embeddings = torch.concat(text_embeddings, dim=1)
     else:
-        text_embeddings = pipe.text_encoder(text_input)[0]
+        if clip_skip is None:
+            text_embeddings = pipe.text_encoder(text_input)[0]
+        else:
+            text_embeddings = pipe.text_encoder(text_input, output_hidden_states=True)
+            text_embeddings = text_embeddings[-1][-(clip_skip + 1)]
+            text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
     return text_embeddings
 
 
@@ -243,6 +254,7 @@ def get_weighted_text_embeddings(
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
+    clip_skip: Optional[int] = None,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -268,6 +280,7 @@ def get_weighted_text_embeddings(
             Skip the parsing of brackets.
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
+        clip_skip (`int`, *optional*, defaults to `None`)
     """
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
@@ -339,6 +352,7 @@ def get_weighted_text_embeddings(
         prompt_tokens,
         pipe.tokenizer.model_max_length,
         no_boseos_middle=no_boseos_middle,
+        clip_skip=clip_skip,
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
     if uncond_prompt is not None:
@@ -347,6 +361,7 @@ def get_weighted_text_embeddings(
             uncond_tokens,
             pipe.tokenizer.model_max_length,
             no_boseos_middle=no_boseos_middle,
+            clip_skip=clip_skip,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
 
@@ -650,6 +665,7 @@ def _encode_prompt(
         max_embeddings_multiples=3,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -698,6 +714,7 @@ def _encode_prompt(
                 prompt=prompt,
                 uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                 max_embeddings_multiples=max_embeddings_multiples,
+                clip_skip=clip_skip
             )
             if prompt_embeds is None:
                 prompt_embeds = prompt_embeds1
@@ -888,6 +905,7 @@ def __call__(
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -968,6 +986,9 @@ def __call__(
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
 
         Returns:
             `None` if cancelled by `is_cancelled_callback`,
@@ -1010,6 +1031,7 @@ def __call__(
             max_embeddings_multiples,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            clip_skip=clip_skip
         )
         dtype = prompt_embeds.dtype
 
@@ -1142,6 +1164,7 @@ def text2img(
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         Function for text-to-image generation.
@@ -1204,6 +1227,9 @@ def text2img(
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
 
         Returns:
             `None` if cancelled by `is_cancelled_callback`,
@@ -1233,6 +1259,7 @@ def text2img(
             is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             cross_attention_kwargs=cross_attention_kwargs,
+            clip_skip=clip_skip,
         )
 
     def img2img(
@@ -1255,6 +1282,7 @@ def img2img(
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         Function for image-to-image generation.
@@ -1345,6 +1373,7 @@ def img2img(
             is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             cross_attention_kwargs=cross_attention_kwargs,
+            clip_skip=clip_skip,
         )
 
     def inpaint(
@@ -1369,6 +1398,7 @@ def inpaint(
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        clip_skip: Optional[int] = None,
     ):
         r"""
         Function for inpaint.
@@ -1439,6 +1469,9 @@ def inpaint(
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
 
         Returns:
             `None` if cancelled by `is_cancelled_callback`,
@@ -1468,4 +1501,5 @@ def inpaint(
             is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             cross_attention_kwargs=cross_attention_kwargs,
+            clip_skip=clip_skip,
         )
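For reviewers, here is a minimal standalone sketch of what the new `clip_skip` branch computes, written against the plain `transformers` CLIP text encoder rather than the pipeline. The checkpoint id, prompt, and `clip_skip` value are illustrative assumptions; the indexing and the manual `final_layer_norm` call mirror the patched `get_unweighted_text_embeddings`.

```python
# Sketch of the clip_skip branch added in this patch (checkpoint id is an
# assumption for illustration; any SD 1.x text encoder behaves the same).
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer(
    "a photo of an astronaut riding a horse",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
)

clip_skip = 2  # example value; 1 means "use the pre-final layer"

with torch.no_grad():
    output = text_encoder(tokens.input_ids, output_hidden_states=True)

# output.hidden_states holds the embedding output plus one tensor per
# transformer layer, so hidden_states[-1] is the final layer's output and
# [-(clip_skip + 1)] walks back clip_skip layers from it.
hidden = output.hidden_states[-(clip_skip + 1)]

# The selected hidden state never passed through the encoder's final
# LayerNorm, so the patch applies it manually, as the diff above does.
embeds = text_encoder.text_model.final_layer_norm(hidden)
print(embeds.shape)  # torch.Size([1, 77, 768]) for CLIP ViT-L/14
```

End to end, the argument threads through every entry point of the community pipeline, so usage looks roughly like the following (model id again just an example):

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", custom_pipeline="lpw_stable_diffusion"
)
image = pipe.text2img("a (masterpiece:1.2), best quality", clip_skip=2).images[0]
```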