@@ -2510,3 +2510,34 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
2510
2510
# materializes the test methods on invocation which cannot be overridden.
2511
2511
return
2512
2512
self ._test_group_offloading_inference_denoiser (offload_type , use_stream )
2513
+
2514
@require_torch_accelerator
def test_lora_loading_model_cpu_offload(self):
    """Round-trip check: LoRA weights saved from one pipeline must load and
    reproduce the same output on a fresh pipeline running with model CPU offload.
    """
    components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
    _, _, inputs = self.get_dummy_inputs(with_generator=False)
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    # The denoiser is either a transformer or a UNet depending on the pipeline family.
    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    denoiser.add_adapter(denoiser_lora_config)
    self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

    # Reference output with the adapter attached in-memory (fixed seed for determinism).
    output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

    with tempfile.TemporaryDirectory() as tmpdir:
        saved_modules = self._get_modules_to_save(pipe, has_denoiser=True)
        state_dicts = self._get_lora_state_dicts(saved_modules)
        self.pipeline_class.save_lora_weights(
            save_directory=tmpdir, safe_serialization=True, **state_dicts
        )
        # reinitialize the pipeline to mimic the inference workflow.
        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
        pipe = self.pipeline_class(**components)
        pipe.enable_model_cpu_offload(device=torch_device)
        pipe.load_lora_weights(tmpdir)
        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        # Output from the reloaded adapter must match the in-memory reference
        # within a small numerical tolerance.
        output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))
0 commit comments