@@ -417,7 +417,7 @@ def __call__(
         cfg_scale_ = cfg_scale if isinstance(cfg_scale, float) else cfg_scale[0]

         timestep = timestep * mask[:, :, :, ::2, ::2].flatten()  # seq_len
-        timestep = timestep.to(dtype=self.config.model_dtype, device=self.device)
+        timestep = timestep.to(dtype=self.dtype, device=self.device)
         # Classifier-free guidance
         noise_pred = self.predict_noise_with_cfg(
             model=model,
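The cast target changes from the configured storage dtype to `self.dtype`, the compute dtype the pipeline tracks (set to bfloat16 below when fp8 weights are configured). A minimal sketch of why this matters, assuming a recent PyTorch with `torch.float8_e4m3fn` support (this snippet is illustrative and not part of the PR): fp8 is a storage format, and most arithmetic is not implemented for it, so casting activations such as `timestep` to the fp8 model dtype would fail inside the transformer.

import torch

# Sketch: fp8 is a storage dtype, not a compute dtype.
t = torch.arange(4, dtype=torch.float32)
t_bf16 = t.to(torch.bfloat16)        # compute dtype: arithmetic works
print(t_bf16 * 2)                    # tensor([0., 2., 4., 6.], dtype=torch.bfloat16)
t_fp8 = t.to(torch.float8_e4m3fn)    # storage dtype only
# t_fp8 * 2 would raise a "not implemented for 'Float8_e4m3fn'" style error,
# which is why activations must be cast to the compute dtype (self.dtype) instead.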
@@ -574,6 +574,18 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPi
         if config.offload_mode is not None:
             pipe.enable_cpu_offload(config.offload_mode)

+        if config.model_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16  # compute dtype
+            pipe.enable_fp8_autocast(
+                model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+            )
+
+        if config.t5_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16  # compute dtype
+            pipe.enable_fp8_autocast(
+                model_names=["text_encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+            )
+
         if config.parallelism > 1:
             return ParallelWrapper(
                 pipe,
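The added block pins the pipeline's compute dtype to bfloat16 and enables fp8 autocast separately for the DiT and the T5 text encoder. The diff does not show `enable_fp8_autocast` itself; as a rough sketch of the underlying idea (class and attribute names below are assumptions, not this repo's API), an fp8-autocast linear keeps its weight in `float8_e4m3fn` to save memory and upcasts it to the compute dtype right before each matmul:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Fp8StoredLinear(nn.Module):
    """Hypothetical stand-in for what an fp8-autocast linear layer might do."""

    def __init__(self, linear: nn.Linear, compute_dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        # Store the weight in fp8 (roughly half the memory of bf16 weights).
        self.weight_fp8 = linear.weight.detach().to(torch.float8_e4m3fn)
        self.bias = None if linear.bias is None else linear.bias.detach().clone()
        self.compute_dtype = compute_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast to the compute dtype for the actual matmul.
        w = self.weight_fp8.to(self.compute_dtype)
        b = None if self.bias is None else self.bias.to(self.compute_dtype)
        return F.linear(x.to(self.compute_dtype), w, b)

# Usage: wrap an existing layer and run it in bf16 while the weight lives in fp8.
layer = Fp8StoredLinear(nn.Linear(16, 16))
out = layer(torch.randn(2, 16))  # out.dtype == torch.bfloat16

The `use_fp8_linear` flag presumably switches to performing the matmul in fp8 itself rather than upcasting, but that detail is not visible in this diff.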