Skip to content

Commit 48a6d29

Browse files
asomoza and yiyixuxu authored
[SD3] CFG Cutoff fix and official callback (#11890)
CFG cutoff fix and official callback.
Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 2d3d376 commit 48a6d29

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

src/diffusers/callbacks.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,38 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
207207
if step_index == cutoff_step:
208208
pipeline.set_ip_adapter_scale(0.0)
209209
return callback_kwargs
210+
211+
212+
class SD3CFGCutoffCallback(PipelineCallback):
    """
    Callback for Stable Diffusion 3 pipelines that disables classifier-free guidance (CFG) once a
    chosen denoising step is reached. The cutoff step is configured either directly via
    `cutoff_step_index` or as a fraction of the total steps via `cutoff_step_ratio`.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        ratio = self.config.cutoff_step_ratio
        index = self.config.cutoff_step_index

        # An explicit step index takes precedence; otherwise derive the step from the ratio.
        if index is not None:
            cutoff = index
        else:
            cutoff = int(pipeline.num_timesteps * ratio)

        if step_index != cutoff:
            return callback_kwargs

        embeds_key, pooled_key = self.tensor_inputs
        # Keep only the last batch entry ("-1"): the embeddings for the conditional text tokens.
        callback_kwargs[embeds_key] = callback_kwargs[embeds_key][-1:]
        # Same for the conditional pooled text-token embeddings.
        callback_kwargs[pooled_key] = callback_kwargs[pooled_key][-1:]

        # Zeroing the guidance scale disables CFG for the remaining steps.
        pipeline._guidance_scale = 0.0

        return callback_kwargs

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
T5TokenizerFast,
2626
)
2727

28+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2829
from ...image_processor import PipelineImageInput, VaeImageProcessor
2930
from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
3031
from ...models.autoencoders import AutoencoderKL
@@ -184,7 +185,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
184185

185186
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
186187
_optional_components = ["image_encoder", "feature_extractor"]
187-
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
188+
_callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]
188189

189190
def __init__(
190191
self,
@@ -923,6 +924,9 @@ def __call__(
923924
height = height or self.default_sample_size * self.vae_scale_factor
924925
width = width or self.default_sample_size * self.vae_scale_factor
925926

927+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
928+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
929+
926930
# 1. Check inputs. Raise error if not correct
927931
self.check_inputs(
928932
prompt,
@@ -1109,10 +1113,7 @@ def __call__(
11091113

11101114
latents = callback_outputs.pop("latents", latents)
11111115
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1112-
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1113-
negative_pooled_prompt_embeds = callback_outputs.pop(
1114-
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1115-
)
1116+
pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
11161117

11171118
# call the callback, if provided
11181119
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

0 commit comments

Comments
 (0)