|
25 | 25 | T5TokenizerFast,
|
26 | 26 | )
|
27 | 27 |
|
| 28 | +from ...callbacks import MultiPipelineCallbacks, PipelineCallback |
28 | 29 | from ...image_processor import PipelineImageInput, VaeImageProcessor
|
29 | 30 | from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
|
30 | 31 | from ...models.autoencoders import AutoencoderKL
|
@@ -184,7 +185,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
184 | 185 |
|
185 | 186 | model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
|
186 | 187 | _optional_components = ["image_encoder", "feature_extractor"]
|
187 |
| - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] |
| 188 | + _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"] |
188 | 189 |
|
189 | 190 | def __init__(
|
190 | 191 | self,
|
@@ -923,6 +924,9 @@ def __call__(
|
923 | 924 | height = height or self.default_sample_size * self.vae_scale_factor
|
924 | 925 | width = width or self.default_sample_size * self.vae_scale_factor
|
925 | 926 |
|
| 927 | + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): |
| 928 | + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs |
| 929 | + |
926 | 930 | # 1. Check inputs. Raise error if not correct
|
927 | 931 | self.check_inputs(
|
928 | 932 | prompt,
|
@@ -1109,10 +1113,7 @@ def __call__(
|
1109 | 1113 |
|
1110 | 1114 | latents = callback_outputs.pop("latents", latents)
|
1111 | 1115 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1112 |
| - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) |
1113 |
| - negative_pooled_prompt_embeds = callback_outputs.pop( |
1114 |
| - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds |
1115 |
| - ) |
| 1116 | + pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds) |
1116 | 1117 |
|
1117 | 1118 | # call the callback, if provided
|
1118 | 1119 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
|
0 commit comments