Skip to content

Commit 265840a

Browse files
authored
[LoRA] fix: disabling hooks when loading loras. (#11896)
fix: disabling hooks when loading loras.
1 parent 9f4d997 commit 265840a

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

src/diffusers/loaders/lora_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def _func_optionally_disable_offloading(_pipeline):
470470
for _, component in _pipeline.components.items():
471471
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
472472
continue
473-
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
473+
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
474474

475475
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
476476

tests/lora/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2510,3 +2510,34 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
25102510
# materializes the test methods on invocation which cannot be overridden.
25112511
return
25122512
self._test_group_offloading_inference_denoiser(offload_type, use_stream)
2513+
2514+
@require_torch_accelerator
2515+
def test_lora_loading_model_cpu_offload(self):
2516+
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
2517+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
2518+
pipe = self.pipeline_class(**components)
2519+
pipe = pipe.to(torch_device)
2520+
pipe.set_progress_bar_config(disable=None)
2521+
2522+
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
2523+
denoiser.add_adapter(denoiser_lora_config)
2524+
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
2525+
2526+
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
2527+
2528+
with tempfile.TemporaryDirectory() as tmpdirname:
2529+
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
2530+
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
2531+
self.pipeline_class.save_lora_weights(
2532+
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
2533+
)
2534+
# reinitialize the pipeline to mimic the inference workflow.
2535+
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
2536+
pipe = self.pipeline_class(**components)
2537+
pipe.enable_model_cpu_offload(device=torch_device)
2538+
pipe.load_lora_weights(tmpdirname)
2539+
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
2540+
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
2541+
2542+
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
2543+
self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))

0 commit comments

Comments
 (0)