Commit 30c0c81

Add a way to patch blocks in SD3.
1 parent 13b0ff8 commit 30c0c81

File tree

  • comfy/ldm/modules/diffusionmodules/mmdit.py

1 file changed: +23 -8 lines

comfy/ldm/modules/diffusionmodules/mmdit.py

Lines changed: 23 additions & 8 deletions
@@ -949,7 +949,9 @@ def forward_core_with_concat(
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
+        patches_replace = transformer_options.get("patches_replace", {})
         if self.register_length > 0:
             context = torch.cat(
                 (
@@ -961,14 +963,25 @@ def forward_core_with_concat(
 
         # context is B, L', D
         # x is B, L, D
+        blocks_replace = patches_replace.get("dit", {})
         blocks = len(self.joint_blocks)
         for i in range(blocks):
-            context, x = self.joint_blocks[i](
-                context,
-                x,
-                c=c_mod,
-                use_checkpoint=self.use_checkpoint,
-            )
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
+                context = out["txt"]
+                x = out["img"]
+            else:
+                context, x = self.joint_blocks[i](
+                    context,
+                    x,
+                    c=c_mod,
+                    use_checkpoint=self.use_checkpoint,
+                )
             if control is not None:
                 control_o = control.get("output")
                 if i < len(control_o):
@@ -986,6 +999,7 @@ def forward(
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
     ) -> torch.Tensor:
         """
         Forward pass of DiT.
@@ -1007,7 +1021,7 @@ def forward(
         if context is not None:
             context = self.context_embedder(context)
 
-        x = self.forward_core_with_concat(x, c, context, control)
+        x = self.forward_core_with_concat(x, c, context, control, transformer_options)
 
         x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]
@@ -1021,7 +1035,8 @@ def forward(
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
         control = None,
+        transformer_options = {},
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y, control=control)
+        return super().forward(x, timesteps, context=context, y=y, control=control, transformer_options=transformer_options)
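For reference, a minimal sketch (not part of the commit) of a patch callable matching the calling convention added above: each entry registered under patches_replace["dit"] with key ("double_block", i) is called with a dict of the block inputs ("img", "txt", "vec") and an extra dict whose "original_block" entry runs the unpatched joint block, and must return a dict with the (possibly modified) "txt" and "img" tensors. The function name my_double_block_patch and the choice of block index 0 are illustrative only.

def my_double_block_patch(args, extra_options):
    # Hypothetical patch: run the unpatched joint block through the
    # wrapper the model supplies, then post-process its outputs.
    out = extra_options["original_block"](args)
    out["img"] = out["img"]  # identity here; modify "img"/"txt" as needed
    return out

# Register the patch for joint block 0 and thread it through the forward call:
transformer_options = {"patches_replace": {"dit": {("double_block", 0): my_double_block_patch}}}
# model.forward(x, timesteps, context=context, y=y, transformer_options=transformer_options)

Blocks with no entry in blocks_replace fall through to the unchanged else branch, so a patch can target a single joint block without affecting the rest of the model.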