diff --git a/diffsynth_engine/pipelines/wan_video.py b/diffsynth_engine/pipelines/wan_video.py
index 4842d4e..9abb7fc 100644
--- a/diffsynth_engine/pipelines/wan_video.py
+++ b/diffsynth_engine/pipelines/wan_video.py
@@ -417,7 +417,7 @@ def __call__(
             cfg_scale_ = cfg_scale if isinstance(cfg_scale, float) else cfg_scale[0]
             timestep = timestep * mask[:, :, :, ::2, ::2].flatten()  # seq_len
-            timestep = timestep.to(dtype=self.config.model_dtype, device=self.device)
+            timestep = timestep.to(dtype=self.dtype, device=self.device)
             # Classifier-free guidance
             noise_pred = self.predict_noise_with_cfg(
                 model=model,
@@ -574,6 +574,18 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPi
         if config.offload_mode is not None:
             pipe.enable_cpu_offload(config.offload_mode)
 
+        if config.model_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16  # compute dtype
+            pipe.enable_fp8_autocast(
+                model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+            )
+
+        if config.t5_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16  # compute dtype
+            pipe.enable_fp8_autocast(
+                model_names=["text_encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+            )
+
         if config.parallelism > 1:
             return ParallelWrapper(
                 pipe,