Skip to content

Commit 76a62ac

Browse files
asomozahlky
andauthored
[ControlnetUnion] Multiple Fixes (#11888)
fixes --------- Co-authored-by: hlky <hlky@hlky.ac>
1 parent 1c6ab9e commit 76a62ac

File tree

3 files changed

+223
-100
lines changed

3 files changed

+223
-100
lines changed

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def forward(
752752
condition = self.controlnet_cond_embedding(cond)
753753
feat_seq = torch.mean(condition, dim=(2, 3))
754754
feat_seq = feat_seq + self.task_embedding[control_idx]
755-
if from_multi:
755+
if from_multi or len(control_type_idx) == 1:
756756
inputs.append(feat_seq.unsqueeze(1))
757757
condition_list.append(condition)
758758
else:
@@ -772,7 +772,7 @@ def forward(
772772
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
773773
alpha = self.spatial_ch_projs(x[:, idx])
774774
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
775-
if from_multi:
775+
if from_multi or len(control_type_idx) == 1:
776776
controlnet_cond_fuser += condition + alpha
777777
else:
778778
controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ def forward(
819819
# 6. scaling
820820
if guess_mode and not self.config.global_pool_conditions:
821821
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
822-
if from_multi:
822+
if from_multi or len(control_type_idx) == 1:
823823
scales = scales * conditioning_scale[0]
824824
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
825825
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
826-
elif from_multi:
826+
elif from_multi or len(control_type_idx) == 1:
827827
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
828828
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
829829

0 commit comments

Comments
 (0)