Skip to content

Commit 6e82ecb

Browse files
authored
support Wan/fp8 (#145)
* enable fp8 cast in wan video
1 parent d46d93c commit 6e82ecb

File tree

1 file changed: 13 additions and 1 deletion

1 file changed: 13 additions and 1 deletion

diffsynth_engine/pipelines/wan_video.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -417,7 +417,7 @@ def __call__(
417 417
cfg_scale_ = cfg_scale if isinstance(cfg_scale, float) else cfg_scale[0]
418 418

419 419
timestep = timestep * mask[:, :, :, ::2, ::2].flatten() # seq_len
420-
timestep = timestep.to(dtype=self.config.model_dtype, device=self.device)
420+
timestep = timestep.to(dtype=self.dtype, device=self.device)
421 421
# Classifier-free guidance
422 422
noise_pred = self.predict_noise_with_cfg(
423 423
model=model,
@@ -574,6 +574,18 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPi
574 574
if config.offload_mode is not None:
575 575
pipe.enable_cpu_offload(config.offload_mode)
576 576

577+
if config.model_dtype == torch.float8_e4m3fn:
578+
pipe.dtype = torch.bfloat16 # compute dtype
579+
pipe.enable_fp8_autocast(
580+
model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
581+
)
582+
583+
if config.t5_dtype == torch.float8_e4m3fn:
584+
pipe.dtype = torch.bfloat16 # compute dtype
585+
pipe.enable_fp8_autocast(
586+
model_names=["text_encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
587+
)
588+
577 589
if config.parallelism > 1:
578 590
return ParallelWrapper(
579 591
pipe,

0 commit comments

Comments
 (0)