
Commit de7cdf6

a-r-r-o-w and DN6 authored
Merge modular diffusers with main (#11893)
* [CI] Fix big GPU test marker (#11786)
* update
* update
* First Block Cache (#11180)
* update
* modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code)
* remove debug logs
* update
* cache context for different batches of data
* fix hs residual bug for single return outputs; support ltx
* fix controlnet flux
* support flux, ltx i2v, ltx condition
* update
* update
* Update docs/source/en/api/cache.md
* Update src/diffusers/hooks/hooks.py
* address review comments pt. 1
* address review comments pt. 2
* cache context refactor; address review pt. 3
* address review comments
* metadata registration with decorators instead of centralized
* support cogvideox
* support mochi
* fix
* remove unused function
* remove central registry based on review
* update
* fix

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 73c5fe8 · commit de7cdf6


42 files changed, +687 -268 lines

.github/workflows/nightly_tests.yml

Lines changed: 1 addition & 1 deletion

@@ -248,7 +248,7 @@ jobs:
           BIG_GPU_MEMORY: 40
         run: |
           python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-            -m "big_gpu_with_torch_cuda" \
+            -m "big_accelerator" \
             --make-reports=tests_big_gpu_torch_cuda \
             --report-log=tests_big_gpu_torch_cuda.log \
             tests/
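For reference, individual tests opt into the renamed marker with the standard `pytest.mark` decorator. A minimal sketch (the test name is illustrative; the actual marked tests live under `tests/`):

```python
import pytest

@pytest.mark.big_accelerator  # renamed from big_gpu_with_torch_cuda
def test_pipeline_on_big_accelerator():
    # Selected by the workflow above via: pytest -m "big_accelerator"
    ...
```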

docs/source/en/api/cache.md

Lines changed: 6 additions & 0 deletions

@@ -28,3 +28,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
 [[autodoc]] FasterCacheConfig

 [[autodoc]] apply_faster_cache
+
+### FirstBlockCacheConfig
+
+[[autodoc]] FirstBlockCacheConfig
+
+[[autodoc]] apply_first_block_cache
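For orientation, here is a minimal usage sketch of the API documented above. The pipeline, checkpoint, and `threshold` value are illustrative assumptions rather than something this commit prescribes:

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# First Block Cache compares the first transformer block's output residual across
# denoising steps; when it changes less than the threshold, the remaining blocks
# are skipped and the cached residual is reused.
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```

Larger thresholds skip more often, trading output quality for speed.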

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -147,11 +147,13 @@
     _import_structure["hooks"].extend(
         [
             "FasterCacheConfig",
+            "FirstBlockCacheConfig",
             "HookRegistry",
             "LayerSkipConfig",
             "PyramidAttentionBroadcastConfig",
             "SmoothedEnergyGuidanceConfig",
             "apply_faster_cache",
+            "apply_first_block_cache",
             "apply_layer_skip",
             "apply_pyramid_attention_broadcast",
         ]
@@ -793,11 +795,13 @@
     )
     from .hooks import (
         FasterCacheConfig,
+        FirstBlockCacheConfig,
         HookRegistry,
         LayerSkipConfig,
         PyramidAttentionBroadcastConfig,
         SmoothedEnergyGuidanceConfig,
         apply_faster_cache,
+        apply_first_block_cache,
         apply_layer_skip,
         apply_pyramid_attention_broadcast,
     )
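With the re-exports above, both import paths resolve to the same objects once torch is available (the top-level names are registered lazily through `_import_structure`):

```python
from diffusers import FirstBlockCacheConfig, apply_first_block_cache
from diffusers.hooks import FirstBlockCacheConfig as HooksFirstBlockCacheConfig

assert FirstBlockCacheConfig is HooksFirstBlockCacheConfig
```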

src/diffusers/hooks/__init__.py

Lines changed: 15 additions & 0 deletions

@@ -1,8 +1,23 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from ..utils import is_torch_available


 if is_torch_available():
     from .faster_cache import FasterCacheConfig, apply_faster_cache
+    from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
     from .group_offloading import apply_group_offloading
     from .hooks import HookRegistry, ModelHook
     from .layer_skip import LayerSkipConfig, apply_layer_skip

src/diffusers/hooks/_helpers.py

Lines changed: 67 additions & 74 deletions

@@ -12,23 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Type
-
-from ..models.attention import BasicTransformerBlock
-from ..models.attention_processor import AttnProcessor2_0
-from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
-from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock
-from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
-from ..models.transformers.transformer_hunyuan_video import (
-    HunyuanVideoSingleTransformerBlock,
-    HunyuanVideoTokenReplaceSingleTransformerBlock,
-    HunyuanVideoTokenReplaceTransformerBlock,
-    HunyuanVideoTransformerBlock,
-)
-from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
-from ..models.transformers.transformer_mochi import MochiTransformerBlock
-from ..models.transformers.transformer_wan import WanTransformerBlock
+from typing import Any, Callable, Dict, Type


 @dataclass
@@ -38,40 +24,90 @@ class AttentionProcessorMetadata:

 @dataclass
 class TransformerBlockMetadata:
-    skip_block_output_fn: Callable[[Any], Any]
     return_hidden_states_index: int = None
     return_encoder_hidden_states_index: int = None

+    _cls: Type = None
+    _cached_parameter_indices: Dict[str, int] = None
+
+    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        if identifier in kwargs:
+            return kwargs[identifier]
+        if self._cached_parameter_indices is not None:
+            return args[self._cached_parameter_indices[identifier]]
+        if self._cls is None:
+            raise ValueError("Model class is not set for metadata.")
+        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
+        parameters = parameters[1:]  # skip `self`
+        self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
+        if identifier not in self._cached_parameter_indices:
+            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
+        index = self._cached_parameter_indices[identifier]
+        if index >= len(args):
+            raise ValueError(f"Expected {index} arguments but got {len(args)}.")
+        return args[index]
+

 class AttentionProcessorRegistry:
     _registry = {}
+    # TODO(aryan): this is only required for the time being because we need to do the registrations
+    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+    # import errors because of the models imported in this file.
+    _is_registered = False

     @classmethod
     def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
+        cls._register()
         cls._registry[model_class] = metadata

     @classmethod
     def get(cls, model_class: Type) -> AttentionProcessorMetadata:
+        cls._register()
         if model_class not in cls._registry:
             raise ValueError(f"Model class {model_class} not registered.")
         return cls._registry[model_class]

+    @classmethod
+    def _register(cls):
+        if cls._is_registered:
+            return
+        cls._is_registered = True
+        _register_attention_processors_metadata()
+

 class TransformerBlockRegistry:
     _registry = {}
+    # TODO(aryan): this is only required for the time being because we need to do the registrations
+    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+    # import errors because of the models imported in this file.
+    _is_registered = False

     @classmethod
     def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
+        cls._register()
+        metadata._cls = model_class
         cls._registry[model_class] = metadata

     @classmethod
     def get(cls, model_class: Type) -> TransformerBlockMetadata:
+        cls._register()
         if model_class not in cls._registry:
             raise ValueError(f"Model class {model_class} not registered.")
         return cls._registry[model_class]

+    @classmethod
+    def _register(cls):
+        if cls._is_registered:
+            return
+        cls._is_registered = True
+        _register_transformer_blocks_metadata()
+

 def _register_attention_processors_metadata():
+    from ..models.attention_processor import AttnProcessor2_0
+    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+
     # AttnProcessor2_0
     AttentionProcessorRegistry.register(
         model_class=AttnProcessor2_0,
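Two structural changes land in the hunk above: (1) `TransformerBlockMetadata` now resolves a block's `forward` arguments by name through cached signature introspection, replacing the per-class `skip_block_output_fn` callables, and (2) registration is deferred to first registry access, so the model imports can move inside the `_register_*` functions and the circular-import chain the TODO describes is broken. A stripped-down, self-contained sketch of the same pattern (all names here are illustrative, not the diffusers API):

```python
import inspect
from dataclasses import dataclass
from typing import Dict, Type


@dataclass
class BlockMetadata:
    _cls: Type = None
    _cached_parameter_indices: Dict[str, int] = None

    def get_param(self, identifier, args=(), kwargs=None):
        # Resolve a forward() argument by name, whether it was passed
        # positionally or as a keyword.
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]
        if self._cached_parameter_indices is None:
            # Build the name -> positional-index map once per block class.
            params = list(inspect.signature(self._cls.forward).parameters)[1:]  # drop `self`
            self._cached_parameter_indices = {p: i for i, p in enumerate(params)}
        return args[self._cached_parameter_indices[identifier]]


class BlockRegistry:
    _registry: Dict[Type, BlockMetadata] = {}
    _is_registered = False

    @classmethod
    def register(cls, model_class, metadata):
        cls._ensure_registered()
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class):
        cls._ensure_registered()
        return cls._registry[model_class]

    @classmethod
    def _ensure_registered(cls):
        # Run the default registrations on first access instead of at import
        # time; registering can then import model classes without cycles.
        if cls._is_registered:
            return
        cls._is_registered = True
        _register_defaults()


class ToyBlock:
    def forward(self, hidden_states, encoder_hidden_states=None):
        return hidden_states


def _register_defaults():
    BlockRegistry.register(ToyBlock, BlockMetadata())


meta = BlockRegistry.get(ToyBlock)  # triggers _register_defaults() lazily
assert meta.get_param("hidden_states", args=("h",)) == "h"
assert meta.get_param("hidden_states", kwargs={"hidden_states": "h"}) == "h"
```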
@@ -90,11 +126,24 @@ def _register_attention_processors_metadata():


 def _register_transformer_blocks_metadata():
+    from ..models.attention import BasicTransformerBlock
+    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
+    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
+    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+    from ..models.transformers.transformer_hunyuan_video import (
+        HunyuanVideoSingleTransformerBlock,
+        HunyuanVideoTokenReplaceSingleTransformerBlock,
+        HunyuanVideoTokenReplaceTransformerBlock,
+        HunyuanVideoTransformerBlock,
+    )
+    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
+    from ..models.transformers.transformer_mochi import MochiTransformerBlock
+    from ..models.transformers.transformer_wan import WanTransformerBlock
+
     # BasicTransformerBlock
     TransformerBlockRegistry.register(
         model_class=BasicTransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock,
             return_hidden_states_index=0,
             return_encoder_hidden_states_index=None,
         ),
@@ -104,7 +153,6 @@ def _register_transformer_blocks_metadata():
     TransformerBlockRegistry.register(
         model_class=CogVideoXBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
             return_hidden_states_index=0,
             return_encoder_hidden_states_index=1,
         ),
@@ -114,7 +162,6 @@ def _register_transformer_blocks_metadata():
     TransformerBlockRegistry.register(
         model_class=CogView4TransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
             return_hidden_states_index=0,
             return_encoder_hidden_states_index=1,
         ),
@@ -124,15 +171,13 @@ def _register_transformer_blocks_metadata():
     TransformerBlockRegistry.register(
         model_class=FluxTransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
             return_hidden_states_index=1,
             return_encoder_hidden_states_index=0,
         ),
     )
     TransformerBlockRegistry.register(
         model_class=FluxSingleTransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
             return_hidden_states_index=1,
             return_encoder_hidden_states_index=0,
         ),
@@ -142,31 +187,27 @@ def _register_transformer_blocks_metadata():
     TransformerBlockRegistry.register(
         model_class=HunyuanVideoTransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
             return_hidden_states_index=0,
             return_encoder_hidden_states_index=1,
         ),
     )
     TransformerBlockRegistry.register(
         model_class=HunyuanVideoSingleTransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
             return_hidden_states_index=0,
             return_encoder_hidden_states_index=1,
         ),
     )
     TransformerBlockRegistry.register(
         model_class=HunyuanVideoTokenReplaceTransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
             return_hidden_states_index=0,
             return_encoder_hidden_states_index=1,
         ),
     )
     TransformerBlockRegistry.register(
         model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
             return_hidden_states_index=0,
             return_encoder_hidden_states_index=1,
         ),
@@ -176,7 +217,6 @@ def _register_transformer_blocks_metadata():
     TransformerBlockRegistry.register(
         model_class=LTXVideoTransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
             return_hidden_states_index=0,
             return_encoder_hidden_states_index=None,
         ),
@@ -186,7 +226,6 @@ def _register_transformer_blocks_metadata():
     TransformerBlockRegistry.register(
         model_class=MochiTransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
             return_hidden_states_index=0,
             return_encoder_hidden_states_index=1,
         ),
@@ -196,7 +235,6 @@ def _register_transformer_blocks_metadata():
     TransformerBlockRegistry.register(
         model_class=WanTransformerBlock,
         metadata=TransformerBlockMetadata(
-            skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
             return_hidden_states_index=0,
             return_encoder_hidden_states_index=None,
         ),
@@ -223,49 +261,4 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):

 _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
 _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
-
-
-def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
-    hidden_states = kwargs.get("hidden_states", None)
-    if hidden_states is None and len(args) > 0:
-        hidden_states = args[0]
-    return hidden_states
-
-
-def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
-    hidden_states = kwargs.get("hidden_states", None)
-    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
-    if hidden_states is None and len(args) > 0:
-        hidden_states = args[0]
-    if encoder_hidden_states is None and len(args) > 1:
-        encoder_hidden_states = args[1]
-    return hidden_states, encoder_hidden_states
-
-
-def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
-    hidden_states = kwargs.get("hidden_states", None)
-    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
-    if hidden_states is None and len(args) > 0:
-        hidden_states = args[0]
-    if encoder_hidden_states is None and len(args) > 1:
-        encoder_hidden_states = args[1]
-    return encoder_hidden_states, hidden_states
-
-
-_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
-_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
-_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
-_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
-_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
-_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
 # fmt: on
-
-
-_register_attention_processors_metadata()
-_register_transformer_blocks_metadata()
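For context on the deletions in the final hunk: the dozen `_skip_block_output_fn_*` helpers are now expressible as a single generic function driven by the metadata indices together with `_get_parameter_from_args_kwargs`. A hedged sketch of that equivalence (`skip_block_output` is a hypothetical name; the actual replacement lives in the hooks that consume the registry):

```python
def skip_block_output(metadata, *args, **kwargs):
    # Generic stand-in for the deleted per-class helpers: return the block's
    # inputs unchanged, ordered the way that block class orders its outputs
    # (e.g. Flux blocks return (encoder_hidden_states, hidden_states)).
    hidden_states = metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
    if metadata.return_encoder_hidden_states_index is None:
        return hidden_states
    encoder_hidden_states = metadata._get_parameter_from_args_kwargs("encoder_hidden_states", args, kwargs)
    outputs = [None, None]
    outputs[metadata.return_hidden_states_index] = hidden_states
    outputs[metadata.return_encoder_hidden_states_index] = encoder_hidden_states
    return tuple(outputs)
```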
