From f0d4e5f6e37dc1f08d08927ad3dcbf60c33e81db Mon Sep 17 00:00:00 2001 From: victolee0 Date: Tue, 18 Mar 2025 20:49:53 +0900 Subject: [PATCH 1/4] feat: Add gradient checkpointing support for AutoencoderKLWan --- .../models/autoencoders/autoencoder_kl_wan.py | 87 +++++++++++++------ 1 file changed, 62 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index fafb1fe867e3..4b8470c79929 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -351,14 +351,24 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", def forward(self, x, feat_cache=None, feat_idx=[0]): # First residual block - x = self.resnets[0](x, feat_cache, feat_idx) + if torch.is_grad_enabled() and self.gradient_checkpointing: + x = self._gradient_checkpointing_func(self.resnets[0], x, feat_cache, feat_idx) - # Process through attention and residual blocks - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - x = attn(x) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) - x = resnet(x, feat_cache, feat_idx) + x = self._gradient_checkpointing_func(resnet, x, feat_cache, feat_idx) + + else: + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) return x @@ -443,15 +453,26 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): else: x = self.conv_in(x) - ## downsamples - for layer in self.down_blocks: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) + if torch.is_grad_enabled() and self.gradient_checkpointing: + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = self._gradient_checkpointing_func(layer, x, feat_cache, feat_idx) + else: + x = self._gradient_checkpointing_func(layer, x) + + ## middle + x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx) + else: + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) - ## middle - x = self.mid_block(x, feat_cache, feat_idx) + ## middle + x = self.mid_block(x, feat_cache, feat_idx) ## head x = self.norm_out(x) @@ -525,11 +546,19 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): Returns: torch.Tensor: Output tensor """ - for resnet in self.resnets: - if feat_cache is not None: - x = resnet(x, feat_cache, feat_idx) - else: - x = resnet(x) + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + if feat_cache is not None: + x = self._gradient_checkpointing_func(resnet, x, feat_cache, feat_idx) + else: + x = self._gradient_checkpointing_func(resnet, x) + + else: + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) if self.upsamplers is not None: if feat_cache is not None: @@ -632,12 +661,20 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): else: x = self.conv_in(x) - ## middle - x = self.mid_block(x, feat_cache, feat_idx) + if torch.is_grad_enabled() and self.gradient_checkpointing: + ## middle + x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + 
x = self._gradient_checkpointing_func(up_block, x, feat_cache, feat_idx) + else: + ## middle + x = self.mid_block(x, feat_cache, feat_idx) - ## upsamples - for up_block in self.up_blocks: - x = up_block(x, feat_cache, feat_idx) + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) ## head x = self.norm_out(x) @@ -665,7 +702,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): for all models (such as downloading or saving). """ - _supports_gradient_checkpointing = False + _supports_gradient_checkpointing = True @register_to_config def __init__( From c88f65bb57992742ab85a3a644531de633c1e537 Mon Sep 17 00:00:00 2001 From: victolee0 Date: Tue, 18 Mar 2025 23:09:35 +0900 Subject: [PATCH 2/4] test: add test for gradient checkpointing --- tests/models/autoencoders/test_models_autoencoder_wan.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py index ffc474039889..03ed5bdcf073 100644 --- a/tests/models/autoencoders/test_models_autoencoder_wan.py +++ b/tests/models/autoencoders/test_models_autoencoder_wan.py @@ -62,9 +62,14 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict - @unittest.skip("Gradient checkpointing has not been implemented yet") def test_gradient_checkpointing_is_applied(self): - pass + expected_set = { + "WanDecoder3d", + "WanEncoder3d", + "WanMidBlock", + "WanUpBlock", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) @unittest.skip("Test not supported") def test_forward_with_norm_groups(self): From ecddbda9bb1d8145a3c45bc0807415163081fb55 Mon Sep 17 00:00:00 2001 From: victolee0 Date: Mon, 7 Apr 2025 21:05:56 +0900 Subject: [PATCH 3/4] fix indexerror --- .../models/autoencoders/autoencoder_kl_wan.py | 105 +++++++----------- 1 file changed, 40 insertions(+), 65 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 4b8470c79929..a0300fbe3b53 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -351,24 +351,14 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", def forward(self, x, feat_cache=None, feat_idx=[0]): # First residual block - if torch.is_grad_enabled() and self.gradient_checkpointing: - x = self._gradient_checkpointing_func(self.resnets[0], x, feat_cache, feat_idx) - - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - x = attn(x) - - x = self._gradient_checkpointing_func(resnet, x, feat_cache, feat_idx) - - else: - x = self.resnets[0](x, feat_cache, feat_idx) + x = self.resnets[0](x, feat_cache, feat_idx) - # Process through attention and residual blocks - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - x = attn(x) + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) - x = resnet(x, feat_cache, feat_idx) + x = resnet(x, feat_cache, feat_idx) return x @@ -453,26 +443,15 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): else: x = self.conv_in(x) - if torch.is_grad_enabled() and self.gradient_checkpointing: - ## downsamples - for layer in self.down_blocks: - if feat_cache is not None: - x = 
self._gradient_checkpointing_func(layer, x, feat_cache, feat_idx) - else: - x = self._gradient_checkpointing_func(layer, x) - - ## middle - x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx) - else: - ## downsamples - for layer in self.down_blocks: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) - ## middle - x = self.mid_block(x, feat_cache, feat_idx) + ## middle + x = self.mid_block(x, feat_cache, feat_idx) ## head x = self.norm_out(x) @@ -546,19 +525,11 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): Returns: torch.Tensor: Output tensor """ - if torch.is_grad_enabled() and self.gradient_checkpointing: - for resnet in self.resnets: - if feat_cache is not None: - x = self._gradient_checkpointing_func(resnet, x, feat_cache, feat_idx) - else: - x = self._gradient_checkpointing_func(resnet, x) - - else: - for resnet in self.resnets: - if feat_cache is not None: - x = resnet(x, feat_cache, feat_idx) - else: - x = resnet(x) + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) if self.upsamplers is not None: if feat_cache is not None: @@ -647,7 +618,14 @@ def __init__( self.gradient_checkpointing = False - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], latent=0): + if torch.is_grad_enabled() and self.gradient_checkpointing: + return self._gradient_checkpointing_func(self._decode, x, feat_cache, feat_idx, latent) + else: + return self._decode(x, feat_cache, feat_idx) + + def _decode(self, x, in_cache=None, feat_idx=[0], latent=0): + feat_cache = in_cache.copy() ## conv1 if feat_cache is not None: idx = feat_idx[0] @@ -661,20 +639,12 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): else: x = self.conv_in(x) - if torch.is_grad_enabled() and self.gradient_checkpointing: - ## middle - x = self._gradient_checkpointing_func(self.mid_block, x, feat_cache, feat_idx) - - ## upsamples - for up_block in self.up_blocks: - x = self._gradient_checkpointing_func(up_block, x, feat_cache, feat_idx) - else: - ## middle - x = self.mid_block(x, feat_cache, feat_idx) + ## middle + x = self.mid_block(x, feat_cache, feat_idx) - ## upsamples - for up_block in self.up_blocks: - x = up_block(x, feat_cache, feat_idx) + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) ## head x = self.norm_out(x) @@ -690,7 +660,8 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): feat_idx[0] += 1 else: x = self.conv_out(x) - return x + feat_idx[0] = 0 + return x, feat_cache class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): @@ -836,9 +807,13 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut for i in range(iter_): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out, self._feat_map = self.decoder( + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, latent=i + ) else: - out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out_, self._feat_map = self.decoder( + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, latent=i + ) out = torch.cat([out, out_], 2) out = torch.clamp(out, min=-1.0, max=1.0) From 
c559b9b9f922f5d7e31fd7f952de2c5e4fc4caba Mon Sep 17 00:00:00 2001 From: victolee0 Date: Tue, 8 Apr 2025 21:36:02 +0900 Subject: [PATCH 4/4] fix: remove latent variable --- .../models/autoencoders/autoencoder_kl_wan.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index a0300fbe3b53..cbfeab09562a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -618,13 +618,13 @@ def __init__( self.gradient_checkpointing = False - def forward(self, x, feat_cache=None, feat_idx=[0], latent=0): + def forward(self, x, feat_cache=None, feat_idx=[0]): if torch.is_grad_enabled() and self.gradient_checkpointing: - return self._gradient_checkpointing_func(self._decode, x, feat_cache, feat_idx, latent) + return self._gradient_checkpointing_func(self._decode, x, feat_cache, feat_idx) else: return self._decode(x, feat_cache, feat_idx) - def _decode(self, x, in_cache=None, feat_idx=[0], latent=0): + def _decode(self, x, in_cache=None, feat_idx=[0]): feat_cache = in_cache.copy() ## conv1 if feat_cache is not None: @@ -808,11 +808,11 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut self._conv_idx = [0] if i == 0: out, self._feat_map = self.decoder( - x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, latent=i + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx ) else: out_, self._feat_map = self.decoder( - x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, latent=i + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx ) out = torch.cat([out, out_], 2)
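
Usage sketch (illustrative only, not taken from the diffs above). The checkpoint id, tensor sizes, and dtype are assumptions for the sake of a small, self-contained run; enable_gradient_checkpointing() is the standard ModelMixin toggle that becomes usable once _supports_gradient_checkpointing is set to True by this series.

    import torch
    import torch.nn.functional as F
    from diffusers import AutoencoderKLWan

    # Assumed checkpoint id; substitute whichever Wan VAE weights you are testing against.
    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
    )
    vae.enable_gradient_checkpointing()
    vae.train()

    # Dummy video batch: (batch, channels, frames, height, width); sizes kept small on purpose.
    video = torch.randn(1, 3, 9, 64, 64)

    posterior = vae.encode(video).latent_dist
    recon = vae.decode(posterior.sample()).sample
    loss = F.mse_loss(recon, video)
    # Checkpointed blocks recompute their activations during backward instead of storing them.
    loss.backward()

The trade-off is the usual one for activation checkpointing: extra forward compute during the backward pass in exchange for lower peak memory, which is what makes fine-tuning against full video reconstructions with this VAE more practical.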