
HunyuanVideo pipe.transformer.compile(): torch._dynamo hit config.recompile_limit (8) #10932

@eppaneamd

Describe the bug

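Compiling the HunyuanVideo transformer with `pipe.transformer.compile()` does not work as expected. TorchDynamo keeps recompiling the transformer blocks' `forward` until it hits its recompile limit of 8 (see the logs below), and the generated video is corrupted.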

See also the related discussion here, which includes a working example for FLUX.1-dev.
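Every recompile in the logs below is triggered by a `___check_obj_id` guard on `L['self']._modules['attn'].processor`, with a different object id each time. That pattern is consistent with the same block `forward` code being called with a different attention-processor instance for each block. Each new instance fails the identity guards of all existing cache entries, so Dynamo recompiles. Once 8 entries exist, it reports `hit config.cache_size_limit (8)` and falls back to eager execution for calls that miss every entry. The toy model below sketches that mechanism in plain Python. It is a simplification, not Dynamo's implementation: it assumes one identity guard per entry, and the `Processor` class and the block count of 20 are hypothetical.

```python
# Toy model of TorchDynamo's per-frame cache (a simplification, not Dynamo's code).
# Assumption: each compiled entry has a single guard on the identity of the
# block's attention processor, like the ___check_obj_id guards in the log.
from dataclasses import dataclass, field

CACHE_SIZE_LIMIT = 8  # the limit reported in the log: "hit config.cache_size_limit (8)"


@dataclass
class ToyFrameCache:
    limit: int = CACHE_SIZE_LIMIT
    guarded_ids: list = field(default_factory=list)  # one object id per compiled entry
    compiles: int = 0
    eager_calls: int = 0

    def call(self, processor: object) -> str:
        # An existing entry whose identity guard passes is reused.
        if id(processor) in self.guarded_ids:
            return "cached"
        # Every guard failed and the cache is full: run eagerly instead.
        if len(self.guarded_ids) >= self.limit:
            self.eager_calls += 1
            return "eager"
        # Otherwise "recompile" and add a new entry guarded on this object.
        self.guarded_ids.append(id(processor))
        self.compiles += 1
        return "compiled"


class Processor:
    """Hypothetical stand-in for a per-block attention processor."""


def run_blocks(num_blocks: int, shared_processor: bool) -> ToyFrameCache:
    """Call the same toy block forward once per block, as one denoising step would."""
    shared = Processor()
    processors = [shared if shared_processor else Processor() for _ in range(num_blocks)]
    cache = ToyFrameCache()
    for proc in processors:
        cache.call(proc)
    return cache


per_block = run_blocks(20, shared_processor=False)
print(per_block.compiles, per_block.eager_calls)  # 8 12
one_shared = run_blocks(20, shared_processor=True)
print(one_shared.compiles, one_shared.eager_calls)  # 1 0
```

The shared-processor case only shows that the identity guard is what multiplies the cache entries. Whether sharing one processor instance across HunyuanVideo blocks is safe or fixes the corrupted output is not established here.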

Reproduction

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

MODEL_ID = "tencent/HunyuanVideo"
PROMPT = "A cat walks on the grass, realistic"

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    revision="refs/pr/18",
)
pipe = HunyuanVideoPipeline.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    transformer=transformer,
    torch_dtype=torch.float16,
    revision="refs/pr/18",
)

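pipe.vae.enable_tiling()
pipe.to("cuda")  # move all pipeline components to the GPU
pipe.transformer.compile()

# Run twice: the first call compiles, the second should reuse the compiled code.
for _ in range(2):
    output = pipe(
        prompt=PROMPT,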
        height=320,
        width=512,
        num_frames=61,
        num_inference_steps=30,
    ).frames[0]

export_to_video(output, "output.mp4", fps=15)

Logs

(.repro) root@1bc6f7fdd66e:/workspace/scripts# TORCH_LOGS="+recompiles" python huvideo_compile_repro.py |& tee recompiles.log
Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 61380.06it/s]
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 53.05it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  7.20it/s]it/s]
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  4.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]V0228 19:39:31.096000 1516 torch/_dynamo/guards.py:2791] [1/1] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:464
V0228 19:39:31.096000 1516 torch/_dynamo/guards.py:2791] [1/1] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:31.096000 1516 torch/_dynamo/guards.py:2791] [1/1] [__recompiles]     - 1/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625168112624)
V0228 19:39:32.391000 1516 torch/_dynamo/guards.py:2791] [1/2] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:464
V0228 19:39:32.391000 1516 torch/_dynamo/guards.py:2791] [1/2] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:32.391000 1516 torch/_dynamo/guards.py:2791] [1/2] [__recompiles]     - 1/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625168123856)
V0228 19:39:32.391000 1516 torch/_dynamo/guards.py:2791] [1/2] [__recompiles]     - 1/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625168112624)
V0228 19:39:33.985000 1516 torch/_dynamo/guards.py:2791] [1/3] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:464
V0228 19:39:33.985000 1516 torch/_dynamo/guards.py:2791] [1/3] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:33.985000 1516 torch/_dynamo/guards.py:2791] [1/3] [__recompiles]     - 1/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625168523232)
V0228 19:39:33.985000 1516 torch/_dynamo/guards.py:2791] [1/3] [__recompiles]     - 1/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625168123856)
V0228 19:39:33.985000 1516 torch/_dynamo/guards.py:2791] [1/3] [__recompiles]     - 1/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625168112624)
V0228 19:39:35.268000 1516 torch/_dynamo/guards.py:2791] [1/4] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:464
V0228 19:39:35.268000 1516 torch/_dynamo/guards.py:2791] [1/4] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:35.268000 1516 torch/_dynamo/guards.py:2791] [1/4] [__recompiles]     - 1/3: ___check_obj_id(L['self']._modules['attn'].processor, 131625168524720)
V0228 19:39:35.268000 1516 torch/_dynamo/guards.py:2791] [1/4] [__recompiles]     - 1/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625168523232)
V0228 19:39:35.268000 1516 torch/_dynamo/guards.py:2791] [1/4] [__recompiles]     - 1/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625168123856)
V0228 19:39:35.268000 1516 torch/_dynamo/guards.py:2791] [1/4] [__recompiles]     - 1/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625168112624)
V0228 19:39:36.559000 1516 torch/_dynamo/guards.py:2791] [1/5] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:464
V0228 19:39:36.559000 1516 torch/_dynamo/guards.py:2791] [1/5] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:36.559000 1516 torch/_dynamo/guards.py:2791] [1/5] [__recompiles]     - 1/4: ___check_obj_id(L['self']._modules['attn'].processor, 131625168527168)
V0228 19:39:36.559000 1516 torch/_dynamo/guards.py:2791] [1/5] [__recompiles]     - 1/3: ___check_obj_id(L['self']._modules['attn'].processor, 131625168524720)
V0228 19:39:36.559000 1516 torch/_dynamo/guards.py:2791] [1/5] [__recompiles]     - 1/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625168523232)
V0228 19:39:36.559000 1516 torch/_dynamo/guards.py:2791] [1/5] [__recompiles]     - 1/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625168123856)
V0228 19:39:36.559000 1516 torch/_dynamo/guards.py:2791] [1/5] [__recompiles]     - 1/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625168112624)
V0228 19:39:37.864000 1516 torch/_dynamo/guards.py:2791] [1/6] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:464
V0228 19:39:37.864000 1516 torch/_dynamo/guards.py:2791] [1/6] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:37.864000 1516 torch/_dynamo/guards.py:2791] [1/6] [__recompiles]     - 1/5: ___check_obj_id(L['self']._modules['attn'].processor, 131625168529616)
V0228 19:39:37.864000 1516 torch/_dynamo/guards.py:2791] [1/6] [__recompiles]     - 1/4: ___check_obj_id(L['self']._modules['attn'].processor, 131625168527168)
V0228 19:39:37.864000 1516 torch/_dynamo/guards.py:2791] [1/6] [__recompiles]     - 1/3: ___check_obj_id(L['self']._modules['attn'].processor, 131625168524720)
V0228 19:39:37.864000 1516 torch/_dynamo/guards.py:2791] [1/6] [__recompiles]     - 1/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625168523232)
V0228 19:39:37.864000 1516 torch/_dynamo/guards.py:2791] [1/6] [__recompiles]     - 1/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625168123856)
V0228 19:39:37.864000 1516 torch/_dynamo/guards.py:2791] [1/6] [__recompiles]     - 1/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625168112624)
V0228 19:39:39.192000 1516 torch/_dynamo/guards.py:2791] [1/7] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:464
V0228 19:39:39.192000 1516 torch/_dynamo/guards.py:2791] [1/7] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:39.192000 1516 torch/_dynamo/guards.py:2791] [1/7] [__recompiles]     - 1/6: ___check_obj_id(L['self']._modules['attn'].processor, 131625168532064)
V0228 19:39:39.192000 1516 torch/_dynamo/guards.py:2791] [1/7] [__recompiles]     - 1/5: ___check_obj_id(L['self']._modules['attn'].processor, 131625168529616)
V0228 19:39:39.192000 1516 torch/_dynamo/guards.py:2791] [1/7] [__recompiles]     - 1/4: ___check_obj_id(L['self']._modules['attn'].processor, 131625168527168)
V0228 19:39:39.192000 1516 torch/_dynamo/guards.py:2791] [1/7] [__recompiles]     - 1/3: ___check_obj_id(L['self']._modules['attn'].processor, 131625168524720)
V0228 19:39:39.192000 1516 torch/_dynamo/guards.py:2791] [1/7] [__recompiles]     - 1/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625168523232)
V0228 19:39:39.192000 1516 torch/_dynamo/guards.py:2791] [1/7] [__recompiles]     - 1/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625168123856)
V0228 19:39:39.192000 1516 torch/_dynamo/guards.py:2791] [1/7] [__recompiles]     - 1/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625168112624)
V0228 19:39:40.506000 1516 torch/_dynamo/guards.py:2791] [1/8] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:464
V0228 19:39:40.506000 1516 torch/_dynamo/guards.py:2791] [1/8] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:40.506000 1516 torch/_dynamo/guards.py:2791] [1/8] [__recompiles]     - 1/7: ___check_obj_id(L['self']._modules['attn'].processor, 131625168534512)
V0228 19:39:40.506000 1516 torch/_dynamo/guards.py:2791] [1/8] [__recompiles]     - 1/6: ___check_obj_id(L['self']._modules['attn'].processor, 131625168532064)
V0228 19:39:40.506000 1516 torch/_dynamo/guards.py:2791] [1/8] [__recompiles]     - 1/5: ___check_obj_id(L['self']._modules['attn'].processor, 131625168529616)
V0228 19:39:40.506000 1516 torch/_dynamo/guards.py:2791] [1/8] [__recompiles]     - 1/4: ___check_obj_id(L['self']._modules['attn'].processor, 131625168527168)
V0228 19:39:40.506000 1516 torch/_dynamo/guards.py:2791] [1/8] [__recompiles]     - 1/3: ___check_obj_id(L['self']._modules['attn'].processor, 131625168524720)
V0228 19:39:40.506000 1516 torch/_dynamo/guards.py:2791] [1/8] [__recompiles]     - 1/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625168523232)
V0228 19:39:40.506000 1516 torch/_dynamo/guards.py:2791] [1/8] [__recompiles]     - 1/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625168123856)
V0228 19:39:40.506000 1516 torch/_dynamo/guards.py:2791] [1/8] [__recompiles]     - 1/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625168112624)
W0228 19:39:40.507000 1516 torch/_dynamo/convert_frame.py:906] [1/8] torch._dynamo hit config.cache_size_limit (8)
W0228 19:39:40.507000 1516 torch/_dynamo/convert_frame.py:906] [1/8]    function: 'forward' (/workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:464)
W0228 19:39:40.507000 1516 torch/_dynamo/convert_frame.py:906] [1/8]    last reason: 1/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625168112624)
W0228 19:39:40.507000 1516 torch/_dynamo/convert_frame.py:906] [1/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0228 19:39:40.507000 1516 torch/_dynamo/convert_frame.py:906] [1/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
V0228 19:39:53.500000 1516 torch/_dynamo/guards.py:2791] [2/1] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:386
V0228 19:39:53.500000 1516 torch/_dynamo/guards.py:2791] [2/1] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:53.500000 1516 torch/_dynamo/guards.py:2791] [2/1] [__recompiles]     - 2/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625167862528)
V0228 19:39:54.543000 1516 torch/_dynamo/guards.py:2791] [2/2] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:386
V0228 19:39:54.543000 1516 torch/_dynamo/guards.py:2791] [2/2] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:54.543000 1516 torch/_dynamo/guards.py:2791] [2/2] [__recompiles]     - 2/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625167863296)
V0228 19:39:54.543000 1516 torch/_dynamo/guards.py:2791] [2/2] [__recompiles]     - 2/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625167862528)
V0228 19:39:55.842000 1516 torch/_dynamo/guards.py:2791] [2/3] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:386
V0228 19:39:55.842000 1516 torch/_dynamo/guards.py:2791] [2/3] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:55.842000 1516 torch/_dynamo/guards.py:2791] [2/3] [__recompiles]     - 2/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625166356800)
V0228 19:39:55.842000 1516 torch/_dynamo/guards.py:2791] [2/3] [__recompiles]     - 2/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625167863296)
V0228 19:39:55.842000 1516 torch/_dynamo/guards.py:2791] [2/3] [__recompiles]     - 2/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625167862528)
V0228 19:39:56.968000 1516 torch/_dynamo/guards.py:2791] [2/4] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:386
V0228 19:39:56.968000 1516 torch/_dynamo/guards.py:2791] [2/4] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:56.968000 1516 torch/_dynamo/guards.py:2791] [2/4] [__recompiles]     - 2/3: ___check_obj_id(L['self']._modules['attn'].processor, 131625166357568)
V0228 19:39:56.968000 1516 torch/_dynamo/guards.py:2791] [2/4] [__recompiles]     - 2/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625166356800)
V0228 19:39:56.968000 1516 torch/_dynamo/guards.py:2791] [2/4] [__recompiles]     - 2/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625167863296)
V0228 19:39:56.968000 1516 torch/_dynamo/guards.py:2791] [2/4] [__recompiles]     - 2/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625167862528)
V0228 19:39:58.033000 1516 torch/_dynamo/guards.py:2791] [2/5] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:386
V0228 19:39:58.033000 1516 torch/_dynamo/guards.py:2791] [2/5] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:58.033000 1516 torch/_dynamo/guards.py:2791] [2/5] [__recompiles]     - 2/4: ___check_obj_id(L['self']._modules['attn'].processor, 131625166358336)
V0228 19:39:58.033000 1516 torch/_dynamo/guards.py:2791] [2/5] [__recompiles]     - 2/3: ___check_obj_id(L['self']._modules['attn'].processor, 131625166357568)
V0228 19:39:58.033000 1516 torch/_dynamo/guards.py:2791] [2/5] [__recompiles]     - 2/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625166356800)
V0228 19:39:58.033000 1516 torch/_dynamo/guards.py:2791] [2/5] [__recompiles]     - 2/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625167863296)
V0228 19:39:58.033000 1516 torch/_dynamo/guards.py:2791] [2/5] [__recompiles]     - 2/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625167862528)
V0228 19:39:59.050000 1516 torch/_dynamo/guards.py:2791] [2/6] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:386
V0228 19:39:59.050000 1516 torch/_dynamo/guards.py:2791] [2/6] [__recompiles]     triggered by the following guard failure(s):
V0228 19:39:59.050000 1516 torch/_dynamo/guards.py:2791] [2/6] [__recompiles]     - 2/5: ___check_obj_id(L['self']._modules['attn'].processor, 131625166359152)
V0228 19:39:59.050000 1516 torch/_dynamo/guards.py:2791] [2/6] [__recompiles]     - 2/4: ___check_obj_id(L['self']._modules['attn'].processor, 131625166358336)
V0228 19:39:59.050000 1516 torch/_dynamo/guards.py:2791] [2/6] [__recompiles]     - 2/3: ___check_obj_id(L['self']._modules['attn'].processor, 131625166357568)
V0228 19:39:59.050000 1516 torch/_dynamo/guards.py:2791] [2/6] [__recompiles]     - 2/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625166356800)
V0228 19:39:59.050000 1516 torch/_dynamo/guards.py:2791] [2/6] [__recompiles]     - 2/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625167863296)
V0228 19:39:59.050000 1516 torch/_dynamo/guards.py:2791] [2/6] [__recompiles]     - 2/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625167862528)
V0228 19:40:00.090000 1516 torch/_dynamo/guards.py:2791] [2/7] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:386
V0228 19:40:00.090000 1516 torch/_dynamo/guards.py:2791] [2/7] [__recompiles]     triggered by the following guard failure(s):
V0228 19:40:00.090000 1516 torch/_dynamo/guards.py:2791] [2/7] [__recompiles]     - 2/6: ___check_obj_id(L['self']._modules['attn'].processor, 131625166359968)
V0228 19:40:00.090000 1516 torch/_dynamo/guards.py:2791] [2/7] [__recompiles]     - 2/5: ___check_obj_id(L['self']._modules['attn'].processor, 131625166359152)
V0228 19:40:00.090000 1516 torch/_dynamo/guards.py:2791] [2/7] [__recompiles]     - 2/4: ___check_obj_id(L['self']._modules['attn'].processor, 131625166358336)
V0228 19:40:00.090000 1516 torch/_dynamo/guards.py:2791] [2/7] [__recompiles]     - 2/3: ___check_obj_id(L['self']._modules['attn'].processor, 131625166357568)
V0228 19:40:00.090000 1516 torch/_dynamo/guards.py:2791] [2/7] [__recompiles]     - 2/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625166356800)
V0228 19:40:00.090000 1516 torch/_dynamo/guards.py:2791] [2/7] [__recompiles]     - 2/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625167863296)
V0228 19:40:00.090000 1516 torch/_dynamo/guards.py:2791] [2/7] [__recompiles]     - 2/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625167862528)
V0228 19:40:01.093000 1516 torch/_dynamo/guards.py:2791] [2/8] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:386
V0228 19:40:01.093000 1516 torch/_dynamo/guards.py:2791] [2/8] [__recompiles]     triggered by the following guard failure(s):
V0228 19:40:01.093000 1516 torch/_dynamo/guards.py:2791] [2/8] [__recompiles]     - 2/7: ___check_obj_id(L['self']._modules['attn'].processor, 131625166360736)
V0228 19:40:01.093000 1516 torch/_dynamo/guards.py:2791] [2/8] [__recompiles]     - 2/6: ___check_obj_id(L['self']._modules['attn'].processor, 131625166359968)
V0228 19:40:01.093000 1516 torch/_dynamo/guards.py:2791] [2/8] [__recompiles]     - 2/5: ___check_obj_id(L['self']._modules['attn'].processor, 131625166359152)
V0228 19:40:01.093000 1516 torch/_dynamo/guards.py:2791] [2/8] [__recompiles]     - 2/4: ___check_obj_id(L['self']._modules['attn'].processor, 131625166358336)
V0228 19:40:01.093000 1516 torch/_dynamo/guards.py:2791] [2/8] [__recompiles]     - 2/3: ___check_obj_id(L['self']._modules['attn'].processor, 131625166357568)
V0228 19:40:01.093000 1516 torch/_dynamo/guards.py:2791] [2/8] [__recompiles]     - 2/2: ___check_obj_id(L['self']._modules['attn'].processor, 131625166356800)
V0228 19:40:01.093000 1516 torch/_dynamo/guards.py:2791] [2/8] [__recompiles]     - 2/1: ___check_obj_id(L['self']._modules['attn'].processor, 131625167863296)
V0228 19:40:01.093000 1516 torch/_dynamo/guards.py:2791] [2/8] [__recompiles]     - 2/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625167862528)
W0228 19:40:01.093000 1516 torch/_dynamo/convert_frame.py:906] [2/8] torch._dynamo hit config.cache_size_limit (8)
W0228 19:40:01.093000 1516 torch/_dynamo/convert_frame.py:906] [2/8]    function: 'forward' (/workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:386)
W0228 19:40:01.093000 1516 torch/_dynamo/convert_frame.py:906] [2/8]    last reason: 2/0: ___check_obj_id(L['self']._modules['attn'].processor, 131625167862528)
W0228 19:40:01.093000 1516 torch/_dynamo/convert_frame.py:906] [2/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
  3%|| 1/30 [01:16<36:51, 76.25s/it]V0228 19:40:32.112000 1516 torch/_dynamo/guards.py:2791] [0/1] [__recompiles] Recompiling function forward in /workspace/repos/diffusers/src/diffusers/models/transformers/transformer_hunyuan_video.py:675
V0228 19:40:32.112000 1516 torch/_dynamo/guards.py:2791] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0228 19:40:32.112000 1516 torch/_dynamo/guards.py:2791] [0/1] [__recompiles]     - 0/0: tensor 'L['timestep']' dtype mismatch. expected Float, actual BFloat16
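To see which frames are affected without reading the log line by line, a short standard-library script can tally it. This is a sketch that assumes the PyTorch 2.6 line format shown above; other versions may format these lines differently. The regexes and the `summarize` / `print_summary` helpers are hypothetical, not part of PyTorch.

```python
# Summarize a TORCH_LOGS="+recompiles" log: recompiles per frame, distinct
# processor ids seen in ___check_obj_id guards, and whether the limit was hit.
# Assumes the line format of the PyTorch 2.6 log in this issue.
import re
from collections import defaultdict

RECOMPILE_RE = re.compile(r"\[(\d+)/\d+\] \[__recompiles\] Recompiling function (\w+) in (\S+)")
OBJ_ID_RE = re.compile(r"- (\d+)/\d+: ___check_obj_id\((.+), (\d+)\)")
LIMIT_RE = re.compile(r"\[(\d+)/\d+\] torch\._dynamo hit config\.(\w+) \((\d+)\)")


def summarize(lines):
    """Map each Dynamo frame id to its recompile statistics."""
    frames = defaultdict(lambda: {"function": None, "location": None, "recompiles": 0,
                                  "guarded_ids": set(), "limit_hit": None})
    for line in lines:
        if m := RECOMPILE_RE.search(line):
            frame = frames[int(m.group(1))]
            frame["function"], frame["location"] = m.group(2), m.group(3)
            frame["recompiles"] += 1
        elif m := OBJ_ID_RE.search(line):
            # Object id the failing guard expected (one per earlier cache entry).
            frames[int(m.group(1))]["guarded_ids"].add(int(m.group(3)))
        elif m := LIMIT_RE.search(line):
            frames[int(m.group(1))]["limit_hit"] = (m.group(2), int(m.group(3)))
    return dict(frames)


def print_summary(path):
    """Print one line per frame for a saved log file such as recompiles.log."""
    with open(path, encoding="utf-8") as f:
        stats = summarize(f)
    for frame_id, info in sorted(stats.items()):
        print(f"frame {frame_id}: {info['function']} at {info['location']}, "
              f"{info['recompiles']} recompiles, {len(info['guarded_ids'])} distinct "
              f"processor ids, limit hit: {info['limit_hit']}")
```

For example, `print_summary("recompiles.log")` on the excerpt above should list frames 1 and 2 (the block `forward`s at lines 464 and 386) with 8 recompiles and 8 distinct processor ids each. It should also list frame 0 (line 675) with a single recompile, caused by the `timestep` dtype change.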

System Info

- 🤗 Diffusers version: 0.33.0.dev0 (+git37a5f1b)
- Platform: Linux-6.5.0-27-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.12.3
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.29.1
- Transformers version: 4.49.0
- Accelerate version: 1.4.0
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No

Who can help?

@DN6 @a-r-r-o-w

Metadata

- Assignees: no one assigned
- Labels: bug (Something isn't working), stale (Issues that haven't received updates)
- Type: no type
- Projects: no projects
- Milestone: no milestone
- Relationships: none yet
- Development: no branches or pull requests