From ba8eabc251dfefcc88873887c64718c3a3dcb3b4 Mon Sep 17 00:00:00 2001
From: "nemonameless@qq.com@github.com"
Date: Tue, 12 Dec 2023 07:36:04 +0000
Subject: [PATCH 1/3] [ppdiffusers] add uvit recompute

---
 ppdiffusers/ppdiffusers/models/uvit.py | 39 ++++++++++++++++++--------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/uvit.py b/ppdiffusers/ppdiffusers/models/uvit.py
index 31c5a98a4..1e13ee905 100644
--- a/ppdiffusers/ppdiffusers/models/uvit.py
+++ b/ppdiffusers/ppdiffusers/models/uvit.py
@@ -15,6 +15,7 @@
 from typing import Optional
 
 import einops
+import numpy as np
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
@@ -53,7 +54,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
 
-        self._use_memory_efficient_attention_xformers = False
+        self._use_memory_efficient_attention_xformers = True
         self._attention_op = None
 
     def reshape_heads_to_batch_dim(self, tensor, transpose=True):
@@ -71,9 +72,8 @@ def reshape_batch_dim_to_heads(self, tensor, transpose=True):
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[str] = None
     ):
-        # remove this PR: https://github.com/PaddlePaddle/Paddle/pull/56045
-        # if self.head_size > 128 and attention_op == "flash":
-        #     attention_op = "cutlass"
+        if self.head_size > 128 and attention_op == "flash":
+            attention_op = "cutlass"
         if use_memory_efficient_attention_xformers:
             if not is_ppxformers_available():
                 raise NotImplementedError(
@@ -194,6 +194,7 @@ class UViTModel(ModelMixin, ConfigMixin):
         after concatenat-ing a long skip connection, which stabilizes the training of U-ViT in UniDiffuser.
""" + _supports_gradient_checkpointing = True @register_to_config def __init__( @@ -253,6 +254,7 @@ def __init__( norm_layer = nn.LayerNorm self.pos_drop = nn.Dropout(p=pos_drop_rate) + dpr = np.linspace(0, drop_rate, depth + 1) self.in_blocks = nn.LayerList( [ Block( @@ -261,11 +263,11 @@ def __init__( mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, + drop=dpr[i], attn_drop=attn_drop_rate, norm_layer=norm_layer, ) - for _ in range(depth // 2) + for i in range(depth // 2) ] ) @@ -275,7 +277,7 @@ def __init__( mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, + drop=dpr[depth // 2], attn_drop=attn_drop_rate, norm_layer=norm_layer, ) @@ -288,12 +290,12 @@ def __init__( mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, + drop=dpr[i + 1 + depth // 2], attn_drop=attn_drop_rate, norm_layer=norm_layer, skip=True, ) - for _ in range(depth // 2) + for i in range(depth // 2) ] ) @@ -306,6 +308,10 @@ def __init__( shape=(1, 1, embed_dim), default_initializer=nn.initializer.Constant(0.0) ) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Block)): + module.gradient_checkpointing = value + def forward( self, img: paddle.Tensor, @@ -362,13 +368,22 @@ def forward( skips = [] for blk in self.in_blocks: - x = blk(x) + if self._supports_gradient_checkpointing: + x = paddle.distributed.fleet.utils.recompute(blk, x) + else: + x = blk(x) skips.append(x) - x = self.mid_block(x) + if self._supports_gradient_checkpointing: + x = paddle.distributed.fleet.utils.recompute(self.mid_block, x) + else: + x = self.mid_block(x) for blk in self.out_blocks: - x = blk(x, skips.pop()) + if self._supports_gradient_checkpointing: + x = paddle.distributed.fleet.utils.recompute(blk, x, skips.pop()) + else: + x = blk(x, skips.pop()) x = self.norm(x) From ecc5557a991a77efa88f23d571fd294964488e36 Mon Sep 17 00:00:00 2001 From: "nemonameless@qq.com@github.com" Date: Tue, 12 Dec 2023 08:13:40 +0000 Subject: [PATCH 2/3] [ppdiffusers] fix uvit recompute --- ppdiffusers/ppdiffusers/models/uvit.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/uvit.py b/ppdiffusers/ppdiffusers/models/uvit.py index 1e13ee905..7938d7414 100644 --- a/ppdiffusers/ppdiffusers/models/uvit.py +++ b/ppdiffusers/ppdiffusers/models/uvit.py @@ -307,10 +307,10 @@ def __init__( self.pos_embed_token = self.create_parameter( shape=(1, 1, embed_dim), default_initializer=nn.initializer.Constant(0.0) ) + self.gradient_checkpointing = False def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (Block)): - module.gradient_checkpointing = value + self.gradient_checkpointing = value def forward( self, @@ -368,19 +368,19 @@ def forward( skips = [] for blk in self.in_blocks: - if self._supports_gradient_checkpointing: + if self.gradient_checkpointing: x = paddle.distributed.fleet.utils.recompute(blk, x) else: x = blk(x) skips.append(x) - if self._supports_gradient_checkpointing: + if self.gradient_checkpointing: x = paddle.distributed.fleet.utils.recompute(self.mid_block, x) else: x = self.mid_block(x) for blk in self.out_blocks: - if self._supports_gradient_checkpointing: + if self.gradient_checkpointing: x = paddle.distributed.fleet.utils.recompute(blk, x, skips.pop()) else: x = blk(x, skips.pop()) From 37ecf4c2b22ea5c8d31a590e911cce7dc4f4e931 Mon Sep 17 00:00:00 2001 From: "nemonameless@qq.com@github.com" Date: Tue, 12 Dec 2023 08:17:41 +0000 Subject: [PATCH 3/3] 
Subject: [PATCH 3/3] [ppdiffusers] fix uvit recompute

---
 ppdiffusers/ppdiffusers/models/uvit.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ppdiffusers/ppdiffusers/models/uvit.py b/ppdiffusers/ppdiffusers/models/uvit.py
index 7938d7414..d9e295a2f 100644
--- a/ppdiffusers/ppdiffusers/models/uvit.py
+++ b/ppdiffusers/ppdiffusers/models/uvit.py
@@ -54,7 +54,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
 
-        self._use_memory_efficient_attention_xformers = True
+        self._use_memory_efficient_attention_xformers = is_ppxformers_available()
         self._attention_op = None
 
     def reshape_heads_to_batch_dim(self, tensor, transpose=True):
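
For context: these patches lean on Paddle's activation recompute
(paddle.distributed.fleet.utils.recompute), which re-runs a block during the
backward pass instead of caching its intermediate activations, trading extra
compute for lower peak memory. The snippet below is a minimal, self-contained
sketch of that pattern only; the toy block, tensor shapes, and the local
gradient_checkpointing flag are illustrative assumptions, not the actual
UViTModel forward.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed.fleet.utils import recompute


class ToyBlock(nn.Layer):
    """A stand-in for a transformer block; UViT's Block is far larger."""

    def __init__(self, dim=64):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return F.gelu(self.fc(x))


block = ToyBlock()
x = paddle.randn([2, 16, 64])
x.stop_gradient = False  # recompute needs at least one input that requires grad

gradient_checkpointing = True  # analogous to self.gradient_checkpointing in patch 2
if gradient_checkpointing:
    # Activations inside `block` are recomputed during backward instead of stored.
    y = recompute(block, x)
else:
    y = block(x)

y.mean().backward()

If ppdiffusers' ModelMixin follows the diffusers convention, the
self.gradient_checkpointing flag introduced in patch 2 would normally be
switched on through model.enable_gradient_checkpointing(), which routes through
_set_gradient_checkpointing(value=True); that is an assumption about the
surrounding API, not something shown in these patches.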