From 14c42cd31a88e5c5c4466b9774ea405c4d326952 Mon Sep 17 00:00:00 2001 From: "chuzihao.czh" Date: Wed, 13 Aug 2025 18:01:37 +0800 Subject: [PATCH 1/3] enable fp8 cast in wan video --- diffsynth_engine/pipelines/wan_video.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/diffsynth_engine/pipelines/wan_video.py b/diffsynth_engine/pipelines/wan_video.py index 4842d4e..d5e988a 100644 --- a/diffsynth_engine/pipelines/wan_video.py +++ b/diffsynth_engine/pipelines/wan_video.py @@ -574,6 +574,18 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPi if config.offload_mode is not None: pipe.enable_cpu_offload(config.offload_mode) + if config.model_dtype == torch.float8_e4m3fn: + pipe.dtype = torch.bfloat16 # compute dtype + pipe.enable_fp8_autocast( + model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear + ) + + if config.t5_dtype == torch.float8_e4m3fn: + pipe.dtype = torch.bfloat16 # compute dtype + pipe.enable_fp8_autocast( + model_names=["text_encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear + ) + if config.parallelism > 1: return ParallelWrapper( pipe, From 6858cf1a9e95c22fbc6937c076b9edce60568920 Mon Sep 17 00:00:00 2001 From: "chuzihao.czh" Date: Thu, 14 Aug 2025 11:03:25 +0800 Subject: [PATCH 2/3] support fp8 in wan video --- diffsynth_engine/models/wan/wan_dit.py | 3 ++- diffsynth_engine/pipelines/wan_video.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/diffsynth_engine/models/wan/wan_dit.py b/diffsynth_engine/models/wan/wan_dit.py index 9d67392..45d3072 100644 --- a/diffsynth_engine/models/wan/wan_dit.py +++ b/diffsynth_engine/models/wan/wan_dit.py @@ -349,7 +349,8 @@ def forward( gguf_inference(), cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg), ): - t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) # (s, d) + t=sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype) + t = self.time_embedding(t) # (s, d) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) # (s, 6, d) context = self.text_embedding(context) if self.has_vae_feature: diff --git a/diffsynth_engine/pipelines/wan_video.py b/diffsynth_engine/pipelines/wan_video.py index d5e988a..33e66eb 100644 --- a/diffsynth_engine/pipelines/wan_video.py +++ b/diffsynth_engine/pipelines/wan_video.py @@ -417,7 +417,7 @@ def __call__( cfg_scale_ = cfg_scale if isinstance(cfg_scale, float) else cfg_scale[0] timestep = timestep * mask[:, :, :, ::2, ::2].flatten() # seq_len - timestep = timestep.to(dtype=self.config.model_dtype, device=self.device) + timestep = timestep.to(device=self.device) # Classifier-free guidance noise_pred = self.predict_noise_with_cfg( model=model, From e97d1064a26324a8d66af602486942b1c0657e7b Mon Sep 17 00:00:00 2001 From: "chuzihao.czh" Date: Fri, 15 Aug 2025 18:06:44 +0800 Subject: [PATCH 3/3] update --- diffsynth_engine/models/wan/wan_dit.py | 3 +-- diffsynth_engine/pipelines/wan_video.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/diffsynth_engine/models/wan/wan_dit.py b/diffsynth_engine/models/wan/wan_dit.py index 45d3072..9d67392 100644 --- a/diffsynth_engine/models/wan/wan_dit.py +++ b/diffsynth_engine/models/wan/wan_dit.py @@ -349,8 +349,7 @@ def forward( gguf_inference(), cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg), ): - t=sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype) - t = self.time_embedding(t) # (s, d) + t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) # (s, d) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) # (s, 6, d) context = self.text_embedding(context) if self.has_vae_feature: diff --git a/diffsynth_engine/pipelines/wan_video.py b/diffsynth_engine/pipelines/wan_video.py index 33e66eb..9abb7fc 100644 --- a/diffsynth_engine/pipelines/wan_video.py +++ b/diffsynth_engine/pipelines/wan_video.py @@ -417,7 +417,7 @@ def __call__( cfg_scale_ = cfg_scale if isinstance(cfg_scale, float) else cfg_scale[0] timestep = timestep * mask[:, :, :, ::2, ::2].flatten() # seq_len - timestep = timestep.to(device=self.device) + timestep = timestep.to(dtype=self.dtype, device=self.device) # Classifier-free guidance noise_pred = self.predict_noise_with_cfg( model=model,