Commit 05e7a85

[lora] fix: lora unloading behaviour (#11822)

* fix: lora unloading behaviour
* fix
* update
1 parent 76ec3d1 commit 05e7a85

File tree

2 files changed: 43 additions, 24 deletions

src/diffusers/loaders/peft.py
tests/lora/utils.py


src/diffusers/loaders/peft.py

Lines changed: 2 additions & 0 deletions
@@ -693,6 +693,8 @@ def unload_lora(self):
         recurse_remove_peft_layers(self)
         if hasattr(self, "peft_config"):
             del self.peft_config
+        if hasattr(self, "_hf_peft_config_loaded"):
+            self._hf_peft_config_loaded = None
 
         _maybe_remove_and_reapply_group_offloading(self)
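
Why these two lines matter: the PEFT integration gates `add_adapter()` on `_hf_peft_config_loaded`, so a flag left set after `unload_lora()` deletes `peft_config` can break the next `add_adapter()` call. A simplified, hypothetical sketch of that gating pattern and of the reset this commit adds (illustrative only, not the library's exact code):

class AdapterHostSketch:
    """Toy stand-in for a model that tracks PEFT state via _hf_peft_config_loaded."""

    _hf_peft_config_loaded = False

    def add_adapter(self, adapter_config, adapter_name="default"):
        if not self._hf_peft_config_loaded:
            # First adapter: initialize the config registry.
            self._hf_peft_config_loaded = True
            self.peft_config = {}
        elif adapter_name in self.peft_config:  # AttributeError if peft_config was deleted
            raise ValueError(f"Adapter with name {adapter_name} already exists.")
        self.peft_config[adapter_name] = adapter_config

    def unload_lora(self):
        if hasattr(self, "peft_config"):
            del self.peft_config
        # The commit's addition, in spirit: clear the flag so a later add_adapter()
        # re-initializes peft_config instead of touching the deleted attribute.
        if hasattr(self, "_hf_peft_config_loaded"):
            self._hf_peft_config_loaded = None

Without the reset, an `unload_lora()` followed by `add_adapter()` takes the `elif` branch while `peft_config` no longer exists; with it, the flag is falsy again and the registry is rebuilt from scratch.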

tests/lora/utils.py

Lines changed: 41 additions & 24 deletions
@@ -291,9 +291,7 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
 
         return modules_to_save
 
-    def check_if_adapters_added_correctly(
-        self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
-    ):
+    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -345,7 +343,7 @@ def test_simple_inference_with_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -428,7 +426,7 @@ def test_low_cpu_mem_usage_with_loading(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -484,7 +482,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -522,7 +520,7 @@ def test_simple_inference_with_text_lora_fused(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         pipe.fuse_lora()
         # Fusing should still keep the LoRA layers
@@ -554,7 +552,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -589,7 +587,7 @@ def test_simple_inference_with_text_lora_save_load(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -640,7 +638,7 @@ def test_simple_inference_with_partial_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         state_dict = {}
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -691,7 +689,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -734,7 +732,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -775,7 +773,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -819,7 +817,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
 
@@ -857,7 +855,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -893,7 +891,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
         self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
@@ -1010,7 +1008,7 @@ def test_wrong_adapter_name_raises_error(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )
 
@@ -1032,7 +1030,7 @@ def test_multiple_wrong_adapter_name_raises_error(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )
 
@@ -1759,7 +1757,7 @@ def test_simple_inference_with_dora(self):
         output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -1850,7 +1848,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
         pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -1937,7 +1935,7 @@ def test_set_adapters_match_attention_kwargs(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         lora_scale = 0.5
         attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2119,7 +2117,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
             pipe = pipe.to(torch_device, dtype=compute_dtype)
             pipe.set_progress_bar_config(disable=None)
 
-            pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+            pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
             if storage_dtype is not None:
                 denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2237,7 +2235,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
         )
         pipe = self.pipeline_class(**components)
 
-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
 
@@ -2290,7 +2288,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2309,6 +2307,25 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
             np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
         )
 
+    def test_lora_unload_add_adapter(self):
+        """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # unload and then add.
+        pipe.unload_lora_weights()
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
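
Outside the test harness, the round trip that the new `test_lora_unload_add_adapter` guards looks roughly like the following; the pipeline class, checkpoint id, and LoRA hyperparameters are illustrative placeholders rather than anything prescribed by this commit:

import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Add an adapter directly on the denoiser and run once.
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v"])
pipe.unet.add_adapter(lora_config, adapter_name="first")
_ = pipe("a photo of a cat", num_inference_steps=2).images[0]

# Unload everything, then add an adapter again -- the sequence this commit fixes.
pipe.unload_lora_weights()
pipe.unet.add_adapter(lora_config, adapter_name="second")
_ = pipe("a photo of a cat", num_inference_steps=2).images[0]

Before the fix, the second `add_adapter()` could fail because the denoiser still carried the stale `_hf_peft_config_loaded` flag from the first load; after the fix, unloading returns the model to a clean state.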
