@@ -424,6 +424,17 @@ def _load_lora_into_text_encoder(
 
 
 def _func_optionally_disable_offloading(_pipeline):
+    """
+    Optionally removes offloading in case the pipeline has already been sequentially offloaded to CPU.
+
+    Args:
+        _pipeline (`DiffusionPipeline`):
+            The pipeline to disable offloading for.
+
+    Returns:
+        tuple:
+            A tuple indicating whether `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+    """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
@@ -442,7 +453,8 @@ def _func_optionally_disable_offloading(_pipeline):
                 logger.info(
                     "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                 )
-                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+                if is_sequential_cpu_offload or is_model_cpu_offload:
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
     return (is_model_cpu_offload, is_sequential_cpu_offload)
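For context, a minimal sketch of how a loader typically consumes the flags this function returns: hooks are removed before the LoRA state dict is loaded, then offloading is restored. This flow is illustrative, not code from this PR; `pipe` is an assumed `DiffusionPipeline`, while `enable_model_cpu_offload()` and `enable_sequential_cpu_offload()` are real public diffusers APIs.

```py
# Illustrative only: how the returned flags are typically used by a loader.
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline=pipe)

# ... load the LoRA state dict into the now-unhooked modules here ...

if is_model_cpu_offload:
    pipe.enable_model_cpu_offload()  # restore whole-model CPU offloading
elif is_sequential_cpu_offload:
    pipe.enable_sequential_cpu_offload()  # restore layer-by-layer offloading
```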
@@ -453,6 +465,24 @@ class LoraBaseMixin:
     _lora_loadable_modules = []
     _merged_adapters = set()
 
+    @property
+    def lora_scale(self) -> float:
+        """
+        Returns the LoRA scale, which can be set at runtime by the pipeline. If `_lora_scale` has not been set,
+        returns 1.0.
+        """
+        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+    @property
+    def num_fused_loras(self):
+        """Returns the number of LoRAs that have been fused."""
+        return len(self._merged_adapters)
+
+    @property
+    def fused_loras(self):
+        """Returns the names of the LoRAs that have been fused."""
+        return self._merged_adapters
+
     def load_lora_weights(self, **kwargs):
         raise NotImplementedError("`load_lora_weights()` is not implemented.")
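The three properties are read-only views over pipeline state. A short usage sketch, assuming a `pipeline` that already has an adapter named `"cinematic"` loaded, as in the docstring examples further down this diff:

```py
# Sketch: reading the relocated properties on a LoRA-enabled pipeline.
pipeline.fuse_lora(adapter_names=["cinematic"])

print(pipeline.lora_scale)       # 1.0 unless `_lora_scale` was set at runtime
print(pipeline.num_fused_loras)  # 1
print(pipeline.fused_loras)      # {"cinematic"}
```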
@@ -464,33 +494,6 @@ def save_lora_weights(cls, **kwargs):
     def lora_state_dict(cls, **kwargs):
         raise NotImplementedError("`lora_state_dict()` is not implemented.")
 
-    @classmethod
-    def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
-        return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
-    @classmethod
-    def _fetch_state_dict(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
-        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
-        return _fetch_state_dict(*args, **kwargs)
-
-    @classmethod
-    def _best_guess_weight_name(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
-        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
-        return _best_guess_weight_name(*args, **kwargs)
-
     def unload_lora_weights(self):
         """
         Unloads the LoRA parameters.
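Note that these classmethods are not dropped; they reappear as thin shims at the bottom of the class (final hunk below), and the deprecation messages point callers at the module-level functions instead:

```py
# The replacements named in the deprecation messages; these module-level
# functions live in diffusers.loaders.lora_base.
from diffusers.loaders.lora_base import _best_guess_weight_name, _fetch_state_dict
```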
@@ -661,19 +664,37 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
                                     self._merged_adapters = self._merged_adapters - {adapter}
                             module.unmerge()
 
-    @property
-    def num_fused_loras(self):
-        return len(self._merged_adapters)
-
-    @property
-    def fused_loras(self):
-        return self._merged_adapters
-
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
+        """
+        Set the currently active adapters for use in the pipeline.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            adapter_weights (`Union[float, Dict, List[float], List[Dict]]`, *optional*):
+                The adapter weight(s) to use. If `None`, the weights are set to `1.0` for all the adapters.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        ```
+        """
         if isinstance(adapter_weights, dict):
             components_passed = set(adapter_weights.keys())
             lora_components = set(self._lora_loadable_modules)
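The dict branch above validates that the keys of a dict-valued `adapter_weights` name actual LoRA-loadable components. A hedged sketch of that form; the component keys `"unet"` and `"text_encoder"` are typical for SDXL pipelines, but the valid keys depend on the concrete pipeline's `_lora_loadable_modules`:

```py
# Sketch: the dict form validated by the branch above; keys must be a
# subset of the pipeline's _lora_loadable_modules.
pipeline.set_adapters(
    "cinematic",
    adapter_weights={"unet": 1.0, "text_encoder": 0.5},
)
```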
@@ -743,6 +764,24 @@ def set_adapters(
                 set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
 
     def disable_lora(self):
+        """
+        Disables the active LoRA layers of the pipeline.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.disable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -755,6 +794,24 @@ def disable_lora(self):
                     disable_lora_for_text_encoder(model)
 
     def enable_lora(self):
+        """
+        Enables the active LoRA layers of the pipeline.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.enable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -768,10 +825,26 @@ def enable_lora(self):
 
     def delete_adapters(self, adapter_names: Union[List[str], str]):
         """
+        Delete an adapter's LoRA layers from the pipeline.
+
         Args:
-            Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
             adapter_names (`Union[List[str], str]`):
-                The names of the adapter to delete. Can be a single string or a list of strings
+                The names of the adapters to delete.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.delete_adapters("cinematic")
+        ```
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -872,6 +945,24 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
                             adapter_name
                         ].to(device)
 
+    def enable_lora_hotswap(self, **kwargs) -> None:
+        """
+        Hotswap adapters without triggering recompilation of a model, even if the ranks of the loaded adapters
+        differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle a model that is already compiled. The options are:
+                - "error" (default): raise an error
+                - "warn": issue a warning
+                - "ignore": do nothing
+        """
+        for key, component in self.components.items():
+            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+                component.enable_lora_hotswap(**kwargs)
+
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
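A hedged sketch of the workflow `enable_lora_hotswap` prepares for: call it before loading the first adapter, compile, then swap adapters in place via the `hotswap=True` flag of `load_lora_weights`. The repo paths below are placeholders, not real repositories, and the rank value is an assumption:

```py
import torch

# Sketch of the intended hotswap workflow; "path/to/lora_A"/"path/to/lora_B"
# are placeholders.
pipeline.enable_lora_hotswap(target_rank=64)  # max rank among adapters to be loaded
pipeline.load_lora_weights("path/to/lora_A", adapter_name="default_0")
pipeline.unet = torch.compile(pipeline.unet)
# ... run inference ...
pipeline.load_lora_weights("path/to/lora_B", adapter_name="default_0", hotswap=True)
```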
@@ -887,6 +978,7 @@ def write_lora_layers(
         safe_serialization: bool,
         lora_adapter_metadata: Optional[dict] = None,
     ):
+        """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
        if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
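For reference, `pack_weights` (visible as context in the previous hunk) namespaces a component's state dict before `write_lora_layers` serializes it. A small sketch of the key transformation; the tensor shapes are arbitrary:

```py
import torch

from diffusers.loaders.lora_base import LoraBaseMixin

# pack_weights prefixes every key of a (module or dict) state dict, which is
# how write_lora_layers keeps per-component weights separable on disk.
layers = {"lora_A.weight": torch.zeros(4, 16), "lora_B.weight": torch.zeros(16, 4)}
packed = LoraBaseMixin.pack_weights(layers, prefix="unet")
print(sorted(packed))  # ['unet.lora_A.weight', 'unet.lora_B.weight']
```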
@@ -927,28 +1019,18 @@ def save_function(weights, filename):
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")
 
-    @property
-    def lora_scale(self) -> float:
-        # property function that returns the lora scale which can be set at run time by the pipeline.
-        # if _lora_scale has not been set, return 1
-        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
-
-    def enable_lora_hotswap(self, **kwargs) -> None:
-        """Enables the possibility to hotswap LoRA adapters.
+    @classmethod
+    def _optionally_disable_offloading(cls, _pipeline):
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
-        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
-        the loaded adapters differ.
+    @classmethod
+    def _fetch_state_dict(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
+        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
+        return _fetch_state_dict(*args, **kwargs)
 
-        Args:
-            target_rank (`int`):
-                The highest rank among all the adapters that will be loaded.
-            check_compiled (`str`, *optional*, defaults to `"error"`):
-                How to handle the case when the model is already compiled, which should generally be avoided. The
-                options are:
-                - "error" (default): raise an error
-                - "warn": issue a warning
-                - "ignore": do nothing
-        """
-        for key, component in self.components.items():
-            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
-                component.enable_lora_hotswap(**kwargs)
+    @classmethod
+    def _best_guess_weight_name(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
+        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
+        return _best_guess_weight_name(*args, **kwargs)