From 48d8256d31f0659575b2ce14bc5aeeb98ad75d09 Mon Sep 17 00:00:00 2001 From: PDillis Date: Sat, 3 Aug 2024 01:02:37 +0200 Subject: [PATCH 1/3] Remove unnecessary logging --- src/diffusers/models/transformers/transformer_flux.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 364867275dc2..391ca1418d34 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -373,7 +373,6 @@ def forward( ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - print(f"{txt_ids.shape=}, {img_ids.shape=}") ids = torch.cat((txt_ids, img_ids), dim=1) image_rotary_emb = self.pos_embed(ids) From 6b6c480e8d75bb01e0f1d1cbbab2bad195d48abe Mon Sep 17 00:00:00 2001 From: PDillis Date: Sat, 3 Aug 2024 02:29:45 +0200 Subject: [PATCH 2/3] Improve console log if user provides latents for `FluxPipeline` (added in `self.check_inputs` and `self.prepare_latents`) --- src/diffusers/pipelines/flux/pipeline_flux.py | 50 +++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 857c213e5c16..f2500f5d04a5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -146,6 +146,11 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin): r""" The Flux pipeline for text-to-image generation. + Note: + This pipeline expects `latents` to be in a packed format. If you're providing + custom latents, make sure to use the `_pack_latents` method to prepare them. + Packed latents should be a 3D tensor of shape (batch_size, num_patches, channels). + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ Args: @@ -391,6 +396,7 @@ def check_inputs( pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, + latents=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -429,6 +435,26 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + if latents is not None: + if not isinstance(latents, torch.Tensor): + raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}") + + if not _are_latents_packed(latents): + raise ValueError(f"`latents` should be a 3-dimensional tensor but has {latents.ndim=} dimensions") + + batch_size, num_patches, channels = latents.shape + if channels != self.transformer.config.in_channels: + raise ValueError( + f"Number of channels in `latents` ({channels}) does not match the number of channels expected by" + f" the transformer ({self.transformer.config.in_channels=})." + ) + + if num_patches != (height // self.vae_scale_factor) * (width // self.vae_scale_factor): + raise ValueError( + f"Number of patches in `latents` ({num_patches}) does not match the number of patches expected by" + f" the transformer ({(height // self.vae_scale_factor) * (width // self.vae_scale_factor)=})." + ) + @staticmethod def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = torch.zeros(height // 2, width // 2, 3) @@ -466,6 +492,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents + @staticmethod + def _are_latents_packed(latents): + return latents.ndim == 3 + def prepare_latents( self, batch_size, @@ -477,15 +507,21 @@ def prepare_latents( generator, latents=None, ): + if latents is not None: + if latents.ndim == 4: + logger.warning( + "Unpacked latents detected. These will be automatically packed. " + "In the future, please provide packed latents to improve performance." + ) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + height = 2 * (int(height) // self.vae_scale_factor) width = 2 * (int(width) // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height, width) - if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids - if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -621,6 +657,7 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, + latents=latents, ) self._guidance_scale = guidance_scale @@ -668,6 +705,11 @@ def __call__( latents, ) + if not self._are_latents_packed(latents): + raise ValueError( + "Latents are not in the correct packed format. Please use `_pack_latents` to prepare them." + ) + # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) image_seq_len = latents.shape[1] From d307205e72ac388c61451855b77d844fbdfe6da8 Mon Sep 17 00:00:00 2001 From: Diego Porres Date: Sun, 4 Aug 2024 23:58:57 +0000 Subject: [PATCH 3/3] Removed unnecessary checks/logs. Left logging as comment and added docstring for shape of latents. --- src/diffusers/pipelines/flux/pipeline_flux.py | 28 ++++--------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index f2500f5d04a5..96af98552a58 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -146,11 +146,6 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin): r""" The Flux pipeline for text-to-image generation. - Note: - This pipeline expects `latents` to be in a packed format. If you're providing - custom latents, make sure to use the `_pack_latents` method to prepare them. - Packed latents should be a 3D tensor of shape (batch_size, num_patches, channels). - Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ Args: @@ -438,10 +433,7 @@ def check_inputs( if latents is not None: if not isinstance(latents, torch.Tensor): raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}") - - if not _are_latents_packed(latents): - raise ValueError(f"`latents` should be a 3-dimensional tensor but has {latents.ndim=} dimensions") - + batch_size, num_patches, channels = latents.shape if channels != self.transformer.config.in_channels: raise ValueError( @@ -492,10 +484,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents - @staticmethod - def _are_latents_packed(latents): - return latents.ndim == 3 - def prepare_latents( self, batch_size, @@ -509,10 +497,7 @@ def prepare_latents( ): if latents is not None: if latents.ndim == 4: - logger.warning( - "Unpacked latents detected. These will be automatically packed. " - "In the future, please provide packed latents to improve performance." - ) + # Packing the latents to be of shape (batch_size, num_patches, channels) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) return latents.to(device=device, dtype=dtype), latent_image_ids @@ -609,7 +594,10 @@ def __call__( latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. + tensor will ge generated by sampling using the supplied random `generator`. Note: This pipeline expects + `latents` to be in a packed format. If you're providing custom latents, make sure to use the + `_pack_latents` method to prepare them. Packed latents should be a 3D tensor of shape + (batch_size, num_patches, channels). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. @@ -705,10 +693,6 @@ def __call__( latents, ) - if not self._are_latents_packed(latents): - raise ValueError( - "Latents are not in the correct packed format. Please use `_pack_latents` to prepare them." - ) # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)