diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index c1a7010d919a..7210451c9ffc 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -379,6 +379,7 @@ def check_inputs(
         pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
+        latents=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -417,6 +418,23 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
+        if latents is not None:
+            if not isinstance(latents, torch.Tensor):
+                raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}")
+
+            batch_size, num_patches, channels = latents.shape
+            if channels != self.transformer.config.in_channels:
+                raise ValueError(
+                    f"Number of channels in `latents` ({channels}) does not match the number of channels expected by"
+                    f" the transformer ({self.transformer.config.in_channels=})."
+                )
+
+            if num_patches != (height // self.vae_scale_factor) * (width // self.vae_scale_factor):
+                raise ValueError(
+                    f"Number of patches in `latents` ({num_patches}) does not match the number of patches expected by"
+                    f" the transformer ({(height // self.vae_scale_factor) * (width // self.vae_scale_factor)=})."
+                )
+
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
         latent_image_ids = torch.zeros(height // 2, width // 2, 3)
@@ -465,15 +483,18 @@ def prepare_latents(
         generator,
         latents=None,
     ):
+        if latents is not None:
+            if latents.ndim == 4:
+                # Packing the latents to be of shape (batch_size, num_patches, channels)
+                latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+            return latents.to(device=device, dtype=dtype), latent_image_ids
+
         height = 2 * (int(height) // self.vae_scale_factor)
         width = 2 * (int(width) // self.vae_scale_factor)
 
         shape = (batch_size, num_channels_latents, height, width)
 
-        if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
-
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -561,7 +582,10 @@ def __call__(
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-               tensor will ge generated by sampling using the supplied random `generator`.
+               tensor will be generated by sampling using the supplied random `generator`. Note: This pipeline expects
+               `latents` to be in a packed format. If you are providing custom latents, make sure to use the
+               `_pack_latents` method to prepare them. Packed latents should be a 3D tensor of shape
+               (batch_size, num_patches, channels).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
@@ -609,6 +633,7 @@ def __call__(
             pooled_prompt_embeds=pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
+            latents=latents,
         )
 
         self._guidance_scale = guidance_scale
@@ -656,6 +681,7 @@ def __call__(
             latents,
         )
 
+        # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         image_seq_len = latents.shape[1]
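Not part of the patch itself, but for reviewers: a minimal usage sketch of how a caller could supply pre-packed latents under this change. The checkpoint name, prompt, and resolution are illustrative assumptions; `_pack_latents`, `vae_scale_factor`, and `transformer.config.in_channels` are the same attributes the diff relies on.

```python
import torch
from diffusers import FluxPipeline

# Illustrative checkpoint; any Flux checkpoint compatible with FluxPipeline should work.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

height = width = 1024
batch_size = 1
# Unpacked latent channels are in_channels // 4; packing multiplies them back by 4.
num_channels_latents = pipe.transformer.config.in_channels // 4
latent_height = 2 * (height // pipe.vae_scale_factor)
latent_width = 2 * (width // pipe.vae_scale_factor)

# Draw unpacked latents of shape (batch_size, channels, latent_height, latent_width),
# then pack them into the (batch_size, num_patches, channels) layout that
# check_inputs now validates against the transformer config.
unpacked = torch.randn(batch_size, num_channels_latents, latent_height, latent_width, dtype=torch.bfloat16)
packed = pipe._pack_latents(unpacked, batch_size, num_channels_latents, latent_height, latent_width)

image = pipe("a photo of a forest", height=height, width=width, latents=packed.to("cuda")).images[0]
```

With the `prepare_latents` change above, passing the 4D `unpacked` tensor directly would also work, since it is packed automatically when `latents.ndim == 4`; the explicit `_pack_latents` call just mirrors what the updated docstring recommends.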