
[ppdiffusers] add uvit recompute #347

Merged
merged 3 commits into from Dec 12, 2023
39 changes: 27 additions & 12 deletions ppdiffusers/ppdiffusers/models/uvit.py
@@ -15,6 +15,7 @@
from typing import Optional

import einops
+ import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
@@ -53,7 +54,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

- self._use_memory_efficient_attention_xformers = False
+ self._use_memory_efficient_attention_xformers = True
self._attention_op = None

def reshape_heads_to_batch_dim(self, tensor, transpose=True):
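The next hunk restores the attention_op fallback that sends head sizes above 128 from the "flash" kernel to the "cutlass" memory-efficient kernel. A tiny standalone sketch of that selection logic follows; the function name and head sizes are illustrative, not part of the PR:

def pick_attention_op(head_size, attention_op):
    # Mirrors the restored check: heads wider than 128 fall back from
    # "flash" to the "cutlass" memory-efficient attention kernel.
    if head_size > 128 and attention_op == "flash":
        return "cutlass"
    return attention_op

print(pick_attention_op(160, "flash"))  # -> cutlass
print(pick_attention_op(64, "flash"))   # -> flash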
@@ -71,9 +72,8 @@ def reshape_batch_dim_to_heads(self, tensor, transpose=True):
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[str] = None
):
- # remove this PR: https://github.com/PaddlePaddle/Paddle/pull/56045
- # if self.head_size > 128 and attention_op == "flash":
- #     attention_op = "cutlass"
+ if self.head_size > 128 and attention_op == "flash":
+     attention_op = "cutlass"
if use_memory_efficient_attention_xformers:
if not is_ppxformers_available():
raise NotImplementedError(
@@ -194,6 +194,7 @@ class UViTModel(ModelMixin, ConfigMixin):
after concatenating a long skip connection, which stabilizes the training of U-ViT in UniDiffuser.

"""
+ _supports_gradient_checkpointing = True

@register_to_config
def __init__(
@@ -253,6 +254,7 @@ def __init__(
norm_layer = nn.LayerNorm
self.pos_drop = nn.Dropout(p=pos_drop_rate)

+ dpr = np.linspace(0, drop_rate, depth + 1)
self.in_blocks = nn.LayerList(
[
Block(
@@ -261,11 +263,11 @@
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
- drop=drop_rate,
+ drop=dpr[i],
attn_drop=attn_drop_rate,
norm_layer=norm_layer,
)
- for _ in range(depth // 2)
+ for i in range(depth // 2)
]
)

@@ -275,7 +277,7 @@
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
- drop=drop_rate,
+ drop=dpr[depth // 2],
attn_drop=attn_drop_rate,
norm_layer=norm_layer,
)
@@ -288,12 +290,12 @@
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
- drop=drop_rate,
+ drop=dpr[i + 1 + depth // 2],
attn_drop=attn_drop_rate,
norm_layer=norm_layer,
skip=True,
)
- for _ in range(depth // 2)
+ for i in range(depth // 2)
]
)

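For reference, a standalone sketch of how the new np.linspace schedule from the hunks above distributes drop rates across the in/mid/out blocks; the depth and drop_rate values are illustrative, not taken from this PR:

import numpy as np

depth, drop_rate = 4, 0.1  # illustrative values only
dpr = np.linspace(0, drop_rate, depth + 1).tolist()  # depth + 1 evenly spaced drop rates

in_rates = [dpr[i] for i in range(depth // 2)]                    # in_blocks
mid_rate = dpr[depth // 2]                                        # mid_block
out_rates = [dpr[i + 1 + depth // 2] for i in range(depth // 2)]  # out_blocks
print(in_rates, mid_rate, out_rates)  # roughly [0.0, 0.025] 0.05 [0.075, 0.1]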
@@ -305,6 +307,10 @@
self.pos_embed_token = self.create_parameter(
shape=(1, 1, embed_dim), default_initializer=nn.initializer.Constant(0.0)
)
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+     self.gradient_checkpointing = value

def forward(
self,
@@ -362,13 +368,22 @@ def forward(

skips = []
for blk in self.in_blocks:
- x = blk(x)
+ if self.gradient_checkpointing:
+     x = paddle.distributed.fleet.utils.recompute(blk, x)
+ else:
+     x = blk(x)
skips.append(x)

- x = self.mid_block(x)
+ if self.gradient_checkpointing:
+     x = paddle.distributed.fleet.utils.recompute(self.mid_block, x)
+ else:
+     x = self.mid_block(x)

for blk in self.out_blocks:
- x = blk(x, skips.pop())
+ if self.gradient_checkpointing:
+     x = paddle.distributed.fleet.utils.recompute(blk, x, skips.pop())
+ else:
+     x = blk(x, skips.pop())

x = self.norm(x)

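For completeness, a minimal usage sketch of the recompute path this PR adds. It assumes ppdiffusers' ModelMixin exposes the same enable_gradient_checkpointing() helper as diffusers (which drives _set_gradient_checkpointing), so treat it as a sketch under that assumption rather than the PR's own test:

from ppdiffusers.models.uvit import UViTModel

# Illustrative: build the model from its register_to_config defaults; a real
# UniDiffuser run would load a pretrained checkpoint instead.
model = UViTModel()

# Assumed ModelMixin helper (mirroring diffusers): it invokes
# _set_gradient_checkpointing(value=True), so forward() wraps the in/mid/out
# blocks in paddle.distributed.fleet.utils.recompute, trading extra forward
# compute for lower activation memory during training.
model.enable_gradient_checkpointing()
assert model.gradient_checkpointing

model.train()  # recompute only takes effect when gradients are tracked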