@@ -949,7 +949,9 @@ def forward_core_with_concat(
949
949
c_mod : torch .Tensor ,
950
950
context : Optional [torch .Tensor ] = None ,
951
951
control = None ,
952
+ transformer_options = {},
952
953
) -> torch .Tensor :
954
+ patches_replace = transformer_options .get ("patches_replace" , {})
953
955
if self .register_length > 0 :
954
956
context = torch .cat (
955
957
(
@@ -961,14 +963,25 @@ def forward_core_with_concat(
961
963
962
964
# context is B, L', D
963
965
# x is B, L, D
966
+ blocks_replace = patches_replace .get ("dit" , {})
964
967
blocks = len (self .joint_blocks )
965
968
for i in range (blocks ):
966
- context , x = self .joint_blocks [i ](
967
- context ,
968
- x ,
969
- c = c_mod ,
970
- use_checkpoint = self .use_checkpoint ,
971
- )
969
+ if ("double_block" , i ) in blocks_replace :
970
+ def block_wrap (args ):
971
+ out = {}
972
+ out ["txt" ], out ["img" ] = self .joint_blocks [i ](args ["txt" ], args ["img" ], c = args ["vec" ])
973
+ return out
974
+
975
+ out = blocks_replace [("double_block" , i )]({"img" : x , "txt" : context , "vec" : c_mod }, {"original_block" : block_wrap })
976
+ context = out ["txt" ]
977
+ x = out ["img" ]
978
+ else :
979
+ context , x = self .joint_blocks [i ](
980
+ context ,
981
+ x ,
982
+ c = c_mod ,
983
+ use_checkpoint = self .use_checkpoint ,
984
+ )
972
985
if control is not None :
973
986
control_o = control .get ("output" )
974
987
if i < len (control_o ):
@@ -986,6 +999,7 @@ def forward(
986
999
y : Optional [torch .Tensor ] = None ,
987
1000
context : Optional [torch .Tensor ] = None ,
988
1001
control = None ,
1002
+ transformer_options = {},
989
1003
) -> torch .Tensor :
990
1004
"""
991
1005
Forward pass of DiT.
@@ -1007,7 +1021,7 @@ def forward(
1007
1021
if context is not None :
1008
1022
context = self .context_embedder (context )
1009
1023
1010
- x = self .forward_core_with_concat (x , c , context , control )
1024
+ x = self .forward_core_with_concat (x , c , context , control , transformer_options )
1011
1025
1012
1026
x = self .unpatchify (x , hw = hw ) # (N, out_channels, H, W)
1013
1027
return x [:,:,:hw [- 2 ],:hw [- 1 ]]
@@ -1021,7 +1035,8 @@ def forward(
1021
1035
context : Optional [torch .Tensor ] = None ,
1022
1036
y : Optional [torch .Tensor ] = None ,
1023
1037
control = None ,
1038
+ transformer_options = {},
1024
1039
** kwargs ,
1025
1040
) -> torch .Tensor :
1026
- return super ().forward (x , timesteps , context = context , y = y , control = control )
1041
+ return super ().forward (x , timesteps , context = context , y = y , control = control , transformer_options = transformer_options )
1027
1042
0 commit comments