@@ -291,9 +291,7 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):

         return modules_to_save

-    def check_if_adapters_added_correctly(
-        self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
-    ):
+    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -345,7 +343,7 @@ def test_simple_inference_with_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -428,7 +426,7 @@ def test_low_cpu_mem_usage_with_loading(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -484,7 +482,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -522,7 +520,7 @@ def test_simple_inference_with_text_lora_fused(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.fuse_lora()
         # Fusing should still keep the LoRA layers
@@ -554,7 +552,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -589,7 +587,7 @@ def test_simple_inference_with_text_lora_save_load(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -640,7 +638,7 @@ def test_simple_inference_with_partial_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         state_dict = {}
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -691,7 +689,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -734,7 +732,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -775,7 +773,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -819,7 +817,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

@@ -857,7 +855,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -893,7 +891,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
         self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
@@ -1010,7 +1008,7 @@ def test_wrong_adapter_name_raises_error(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )

@@ -1032,7 +1030,7 @@ def test_multiple_wrong_adapter_name_raises_error(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )

@@ -1759,7 +1757,7 @@ def test_simple_inference_with_dora(self):
         output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_dora_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -1850,7 +1848,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
         pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -1937,7 +1935,7 @@ def test_set_adapters_match_attention_kwargs(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         lora_scale = 0.5
         attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2119,7 +2117,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
             pipe = pipe.to(torch_device, dtype=compute_dtype)
             pipe.set_progress_bar_config(disable=None)

-            pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+            pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

             if storage_dtype is not None:
                 denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2237,7 +2235,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
         )
         pipe = self.pipeline_class(**components)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )

@@ -2290,7 +2288,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2309,6 +2307,25 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
             np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
         )

+    def test_lora_unload_add_adapter(self):
+        """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # unload and then add.
+        pipe.unload_lora_weights()
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes: