12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import inspect
15
16
from dataclasses import dataclass
16
- from typing import Any , Callable , Type
17
-
18
- from ..models .attention import BasicTransformerBlock
19
- from ..models .attention_processor import AttnProcessor2_0
20
- from ..models .transformers .cogvideox_transformer_3d import CogVideoXBlock
21
- from ..models .transformers .transformer_cogview4 import CogView4AttnProcessor , CogView4TransformerBlock
22
- from ..models .transformers .transformer_flux import FluxSingleTransformerBlock , FluxTransformerBlock
23
- from ..models .transformers .transformer_hunyuan_video import (
24
- HunyuanVideoSingleTransformerBlock ,
25
- HunyuanVideoTokenReplaceSingleTransformerBlock ,
26
- HunyuanVideoTokenReplaceTransformerBlock ,
27
- HunyuanVideoTransformerBlock ,
28
- )
29
- from ..models .transformers .transformer_ltx import LTXVideoTransformerBlock
30
- from ..models .transformers .transformer_mochi import MochiTransformerBlock
31
- from ..models .transformers .transformer_wan import WanTransformerBlock
17
+ from typing import Any , Callable , Dict , Type
32
18
33
19
34
20
@dataclass
@@ -38,40 +24,90 @@ class AttentionProcessorMetadata:
38
24
39
25
@dataclass
class TransformerBlockMetadata:
    """Metadata describing the output layout of a transformer block's forward pass.

    Attributes:
        return_hidden_states_index: Index of `hidden_states` in the block's return
            tuple, or None if the block does not return it.
        return_encoder_hidden_states_index: Index of `encoder_hidden_states` in the
            block's return tuple, or None if the block does not return it.
    """

    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    # Set by TransformerBlockRegistry.register(); used to inspect the block's
    # forward signature lazily.
    _cls: Type = None
    # Lazily-built map of forward-parameter name -> positional index (excluding
    # `self`), cached per metadata instance.
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        """Return the value passed for forward-parameter `identifier`.

        Looks in `kwargs` first, then falls back to the positional `args`, using
        the signature of `self._cls.forward` to resolve positional indices.

        Raises:
            ValueError: If no model class is set, if `identifier` is not a
                parameter of the forward signature, or if `args` is too short to
                contain the requested positional parameter.
        """
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]
        if self._cls is None:
            raise ValueError("Model class is not set for metadata.")
        if self._cached_parameter_indices is None:
            parameters = list(inspect.signature(self._cls.forward).parameters.keys())
            parameters = parameters[1:]  # skip `self`
            self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
        # Validate on every call (not just the first) so a missing identifier or a
        # too-short `args` consistently raises ValueError instead of leaking a
        # KeyError/IndexError once the cache is populated.
        if identifier not in self._cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
        index = self._cached_parameter_indices[identifier]
        if index >= len(args):
            raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")
        return args[index]
45
51
46
52
class AttentionProcessorRegistry:
    """Process-wide mapping from attention-processor classes to their metadata."""

    _registry = {}
    # TODO(aryan): this flag exists only because the metadata registrations must run
    # lazily. Running them eagerly at module scope would trigger circular-import
    # errors via the model classes imported inside the registration function.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        """Associate `metadata` with `model_class`, overwriting any prior entry."""
        cls._register()
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        """Return the metadata registered for `model_class`.

        Raises:
            ValueError: If `model_class` was never registered.
        """
        cls._register()
        if model_class in cls._registry:
            return cls._registry[model_class]
        raise ValueError(f"Model class {model_class} not registered.")

    @classmethod
    def _register(cls):
        # One-shot lazy initialization: populate the registry on first use.
        if not cls._is_registered:
            cls._is_registered = True
            _register_attention_processors_metadata()
59
78
60
79
class TransformerBlockRegistry:
    """Process-wide mapping from transformer-block classes to their metadata."""

    _registry = {}
    # TODO(aryan): this flag exists only because the metadata registrations must run
    # lazily. Running them eagerly at module scope would trigger circular-import
    # errors via the model classes imported inside the registration function.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        """Associate `metadata` with `model_class` and record the class on the metadata."""
        cls._register()
        # The metadata needs the owning class to resolve forward-signature lookups.
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        """Return the metadata registered for `model_class`.

        Raises:
            ValueError: If `model_class` was never registered.
        """
        cls._register()
        if model_class in cls._registry:
            return cls._registry[model_class]
        raise ValueError(f"Model class {model_class} not registered.")

    @classmethod
    def _register(cls):
        # One-shot lazy initialization: populate the registry on first use.
        if not cls._is_registered:
            cls._is_registered = True
            _register_transformer_blocks_metadata()
73
106
74
107
def _register_attention_processors_metadata ():
108
+ from ..models .attention_processor import AttnProcessor2_0
109
+ from ..models .transformers .transformer_cogview4 import CogView4AttnProcessor
110
+
75
111
# AttnProcessor2_0
76
112
AttentionProcessorRegistry .register (
77
113
model_class = AttnProcessor2_0 ,
@@ -90,11 +126,24 @@ def _register_attention_processors_metadata():
90
126
91
127
92
128
def _register_transformer_blocks_metadata ():
129
+ from ..models .attention import BasicTransformerBlock
130
+ from ..models .transformers .cogvideox_transformer_3d import CogVideoXBlock
131
+ from ..models .transformers .transformer_cogview4 import CogView4TransformerBlock
132
+ from ..models .transformers .transformer_flux import FluxSingleTransformerBlock , FluxTransformerBlock
133
+ from ..models .transformers .transformer_hunyuan_video import (
134
+ HunyuanVideoSingleTransformerBlock ,
135
+ HunyuanVideoTokenReplaceSingleTransformerBlock ,
136
+ HunyuanVideoTokenReplaceTransformerBlock ,
137
+ HunyuanVideoTransformerBlock ,
138
+ )
139
+ from ..models .transformers .transformer_ltx import LTXVideoTransformerBlock
140
+ from ..models .transformers .transformer_mochi import MochiTransformerBlock
141
+ from ..models .transformers .transformer_wan import WanTransformerBlock
142
+
93
143
# BasicTransformerBlock
94
144
TransformerBlockRegistry .register (
95
145
model_class = BasicTransformerBlock ,
96
146
metadata = TransformerBlockMetadata (
97
- skip_block_output_fn = _skip_block_output_fn_BasicTransformerBlock ,
98
147
return_hidden_states_index = 0 ,
99
148
return_encoder_hidden_states_index = None ,
100
149
),
@@ -104,7 +153,6 @@ def _register_transformer_blocks_metadata():
104
153
TransformerBlockRegistry .register (
105
154
model_class = CogVideoXBlock ,
106
155
metadata = TransformerBlockMetadata (
107
- skip_block_output_fn = _skip_block_output_fn_CogVideoXBlock ,
108
156
return_hidden_states_index = 0 ,
109
157
return_encoder_hidden_states_index = 1 ,
110
158
),
@@ -114,7 +162,6 @@ def _register_transformer_blocks_metadata():
114
162
TransformerBlockRegistry .register (
115
163
model_class = CogView4TransformerBlock ,
116
164
metadata = TransformerBlockMetadata (
117
- skip_block_output_fn = _skip_block_output_fn_CogView4TransformerBlock ,
118
165
return_hidden_states_index = 0 ,
119
166
return_encoder_hidden_states_index = 1 ,
120
167
),
@@ -124,15 +171,13 @@ def _register_transformer_blocks_metadata():
124
171
TransformerBlockRegistry .register (
125
172
model_class = FluxTransformerBlock ,
126
173
metadata = TransformerBlockMetadata (
127
- skip_block_output_fn = _skip_block_output_fn_FluxTransformerBlock ,
128
174
return_hidden_states_index = 1 ,
129
175
return_encoder_hidden_states_index = 0 ,
130
176
),
131
177
)
132
178
TransformerBlockRegistry .register (
133
179
model_class = FluxSingleTransformerBlock ,
134
180
metadata = TransformerBlockMetadata (
135
- skip_block_output_fn = _skip_block_output_fn_FluxSingleTransformerBlock ,
136
181
return_hidden_states_index = 1 ,
137
182
return_encoder_hidden_states_index = 0 ,
138
183
),
@@ -142,31 +187,27 @@ def _register_transformer_blocks_metadata():
142
187
TransformerBlockRegistry .register (
143
188
model_class = HunyuanVideoTransformerBlock ,
144
189
metadata = TransformerBlockMetadata (
145
- skip_block_output_fn = _skip_block_output_fn_HunyuanVideoTransformerBlock ,
146
190
return_hidden_states_index = 0 ,
147
191
return_encoder_hidden_states_index = 1 ,
148
192
),
149
193
)
150
194
TransformerBlockRegistry .register (
151
195
model_class = HunyuanVideoSingleTransformerBlock ,
152
196
metadata = TransformerBlockMetadata (
153
- skip_block_output_fn = _skip_block_output_fn_HunyuanVideoSingleTransformerBlock ,
154
197
return_hidden_states_index = 0 ,
155
198
return_encoder_hidden_states_index = 1 ,
156
199
),
157
200
)
158
201
TransformerBlockRegistry .register (
159
202
model_class = HunyuanVideoTokenReplaceTransformerBlock ,
160
203
metadata = TransformerBlockMetadata (
161
- skip_block_output_fn = _skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock ,
162
204
return_hidden_states_index = 0 ,
163
205
return_encoder_hidden_states_index = 1 ,
164
206
),
165
207
)
166
208
TransformerBlockRegistry .register (
167
209
model_class = HunyuanVideoTokenReplaceSingleTransformerBlock ,
168
210
metadata = TransformerBlockMetadata (
169
- skip_block_output_fn = _skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock ,
170
211
return_hidden_states_index = 0 ,
171
212
return_encoder_hidden_states_index = 1 ,
172
213
),
@@ -176,7 +217,6 @@ def _register_transformer_blocks_metadata():
176
217
TransformerBlockRegistry .register (
177
218
model_class = LTXVideoTransformerBlock ,
178
219
metadata = TransformerBlockMetadata (
179
- skip_block_output_fn = _skip_block_output_fn_LTXVideoTransformerBlock ,
180
220
return_hidden_states_index = 0 ,
181
221
return_encoder_hidden_states_index = None ,
182
222
),
@@ -186,7 +226,6 @@ def _register_transformer_blocks_metadata():
186
226
TransformerBlockRegistry .register (
187
227
model_class = MochiTransformerBlock ,
188
228
metadata = TransformerBlockMetadata (
189
- skip_block_output_fn = _skip_block_output_fn_MochiTransformerBlock ,
190
229
return_hidden_states_index = 0 ,
191
230
return_encoder_hidden_states_index = 1 ,
192
231
),
@@ -196,7 +235,6 @@ def _register_transformer_blocks_metadata():
196
235
TransformerBlockRegistry .register (
197
236
model_class = WanTransformerBlock ,
198
237
metadata = TransformerBlockMetadata (
199
- skip_block_output_fn = _skip_block_output_fn_WanTransformerBlock ,
200
238
return_hidden_states_index = 0 ,
201
239
return_encoder_hidden_states_index = None ,
202
240
),
@@ -223,49 +261,4 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
223
261
224
262
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
225
263
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
226
-
227
-
228
- def _skip_block_output_fn___hidden_states_0___ret___hidden_states (self , * args , ** kwargs ):
229
- hidden_states = kwargs .get ("hidden_states" , None )
230
- if hidden_states is None and len (args ) > 0 :
231
- hidden_states = args [0 ]
232
- return hidden_states
233
-
234
-
235
- def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states (self , * args , ** kwargs ):
236
- hidden_states = kwargs .get ("hidden_states" , None )
237
- encoder_hidden_states = kwargs .get ("encoder_hidden_states" , None )
238
- if hidden_states is None and len (args ) > 0 :
239
- hidden_states = args [0 ]
240
- if encoder_hidden_states is None and len (args ) > 1 :
241
- encoder_hidden_states = args [1 ]
242
- return hidden_states , encoder_hidden_states
243
-
244
-
245
- def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states (self , * args , ** kwargs ):
246
- hidden_states = kwargs .get ("hidden_states" , None )
247
- encoder_hidden_states = kwargs .get ("encoder_hidden_states" , None )
248
- if hidden_states is None and len (args ) > 0 :
249
- hidden_states = args [0 ]
250
- if encoder_hidden_states is None and len (args ) > 1 :
251
- encoder_hidden_states = args [1 ]
252
- return encoder_hidden_states , hidden_states
253
-
254
-
255
- _skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
256
- _skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
257
- _skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
258
- _skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
259
- _skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
260
- _skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
261
- _skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
262
- _skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
263
- _skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
264
- _skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
265
- _skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
266
- _skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
267
264
# fmt: on
268
-
269
-
270
- _register_attention_processors_metadata ()
271
- _register_transformer_blocks_metadata ()
0 commit comments