@@ -752,7 +752,7 @@ def forward(
             condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
             feat_seq = feat_seq + self.task_embedding[control_idx]
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 inputs.append(feat_seq.unsqueeze(1))
                 condition_list.append(condition)
             else:
@@ -772,7 +772,7 @@ def forward(
         for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 controlnet_cond_fuser += condition + alpha
             else:
                 controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ def forward(
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 scales = scales * conditioning_scale[0]
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
-        elif from_multi:
+        elif from_multi or len(control_type_idx) == 1:
             down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
             mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
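For context, here is a minimal standalone sketch of the behavior this change enables: with the added check on len(control_type_idx) == 1, a single control type passed directly (that is, with from_multi=False) now takes the same global-scaling branch as the multi-control path. The function and tensor names below are illustrative stand-ins, not the actual ControlNetUnionModel.forward signature.

import torch

# Hypothetical, simplified stand-in for the scaling branch touched by this diff.
def scale_residuals(down_block_res_samples, mid_block_res_sample, conditioning_scale, from_multi, control_type_idx):
    # After this change, a single control type behaves like the multi-control path:
    # every residual is scaled by the one global conditioning scale.
    if from_multi or len(control_type_idx) == 1:
        down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
        mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
    return down_block_res_samples, mid_block_res_sample

# Toy tensors standing in for the ControlNet residuals.
down = [torch.ones(1, 4, 8, 8) for _ in range(3)]
mid = torch.ones(1, 4, 4, 4)
down, mid = scale_residuals(down, mid, conditioning_scale=[0.5], from_multi=False, control_type_idx=[0])
print(mid.max())  # tensor(0.5000): scaled even though from_multi is False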