@@ -424,6 +424,17 @@ def _load_lora_into_text_encoder(
424
424
425
425
426
426
def _func_optionally_disable_offloading (_pipeline ):
427
+ """
428
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
429
+
430
+ Args:
431
+ _pipeline (`DiffusionPipeline`):
432
+ The pipeline to disable offloading for.
433
+
434
+ Returns:
435
+ tuple:
436
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
437
+ """
427
438
is_model_cpu_offload = False
428
439
is_sequential_cpu_offload = False
429
440
@@ -453,6 +464,24 @@ class LoraBaseMixin:
453
464
_lora_loadable_modules = []
454
465
_merged_adapters = set ()
455
466
467
+ @property
468
+ def lora_scale (self ) -> float :
469
+ """
470
+ Returns the lora scale which can be set at run time by the pipeline. # if `_lora_scale` has not been set,
471
+ return 1.
472
+ """
473
+ return self ._lora_scale if hasattr (self , "_lora_scale" ) else 1.0
474
+
475
+ @property
476
+ def num_fused_loras (self ):
477
+ """Returns the number of LoRAs that have been fused."""
478
+ return len (self ._merged_adapters )
479
+
480
+ @property
481
+ def fused_loras (self ):
482
+ """Returns names of the LoRAs that have been fused."""
483
+ return self ._merged_adapters
484
+
456
485
def load_lora_weights (self , ** kwargs ):
457
486
raise NotImplementedError ("`load_lora_weights()` is not implemented." )
458
487
@@ -464,33 +493,6 @@ def save_lora_weights(cls, **kwargs):
464
493
def lora_state_dict (cls , ** kwargs ):
465
494
raise NotImplementedError ("`lora_state_dict()` is not implemented." )
466
495
467
- @classmethod
468
- def _optionally_disable_offloading (cls , _pipeline ):
469
- """
470
- Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
471
-
472
- Args:
473
- _pipeline (`DiffusionPipeline`):
474
- The pipeline to disable offloading for.
475
-
476
- Returns:
477
- tuple:
478
- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
479
- """
480
- return _func_optionally_disable_offloading (_pipeline = _pipeline )
481
-
482
- @classmethod
483
- def _fetch_state_dict (cls , * args , ** kwargs ):
484
- deprecation_message = f"Using the `_fetch_state_dict()` method from { cls } has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
485
- deprecate ("_fetch_state_dict" , "0.35.0" , deprecation_message )
486
- return _fetch_state_dict (* args , ** kwargs )
487
-
488
- @classmethod
489
- def _best_guess_weight_name (cls , * args , ** kwargs ):
490
- deprecation_message = f"Using the `_best_guess_weight_name()` method from { cls } has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
491
- deprecate ("_best_guess_weight_name" , "0.35.0" , deprecation_message )
492
- return _best_guess_weight_name (* args , ** kwargs )
493
-
494
496
def unload_lora_weights (self ):
495
497
"""
496
498
Unloads the LoRA parameters.
@@ -661,19 +663,37 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
661
663
self ._merged_adapters = self ._merged_adapters - {adapter }
662
664
module .unmerge ()
663
665
664
- @property
665
- def num_fused_loras (self ):
666
- return len (self ._merged_adapters )
667
-
668
- @property
669
- def fused_loras (self ):
670
- return self ._merged_adapters
671
-
672
666
def set_adapters (
673
667
self ,
674
668
adapter_names : Union [List [str ], str ],
675
669
adapter_weights : Optional [Union [float , Dict , List [float ], List [Dict ]]] = None ,
676
670
):
671
+ """
672
+ Set the currently active adapters for use in the pipeline.
673
+
674
+ Args:
675
+ adapter_names (`List[str]` or `str`):
676
+ The names of the adapters to use.
677
+ adapter_weights (`Union[List[float], float]`, *optional*):
678
+ The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
679
+ adapters.
680
+
681
+ Example:
682
+
683
+ ```py
684
+ from diffusers import AutoPipelineForText2Image
685
+ import torch
686
+
687
+ pipeline = AutoPipelineForText2Image.from_pretrained(
688
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
689
+ ).to("cuda")
690
+ pipeline.load_lora_weights(
691
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
692
+ )
693
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
694
+ pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
695
+ ```
696
+ """
677
697
if isinstance (adapter_weights , dict ):
678
698
components_passed = set (adapter_weights .keys ())
679
699
lora_components = set (self ._lora_loadable_modules )
@@ -743,6 +763,24 @@ def set_adapters(
743
763
set_adapters_for_text_encoder (adapter_names , model , _component_adapter_weights [component ])
744
764
745
765
def disable_lora (self ):
766
+ """
767
+ Disables the active LoRA layers of the pipeline.
768
+
769
+ Example:
770
+
771
+ ```py
772
+ from diffusers import AutoPipelineForText2Image
773
+ import torch
774
+
775
+ pipeline = AutoPipelineForText2Image.from_pretrained(
776
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
777
+ ).to("cuda")
778
+ pipeline.load_lora_weights(
779
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
780
+ )
781
+ pipeline.disable_lora()
782
+ ```
783
+ """
746
784
if not USE_PEFT_BACKEND :
747
785
raise ValueError ("PEFT backend is required for this method." )
748
786
@@ -755,6 +793,24 @@ def disable_lora(self):
755
793
disable_lora_for_text_encoder (model )
756
794
757
795
def enable_lora (self ):
796
+ """
797
+ Enables the active LoRA layers of the pipeline.
798
+
799
+ Example:
800
+
801
+ ```py
802
+ from diffusers import AutoPipelineForText2Image
803
+ import torch
804
+
805
+ pipeline = AutoPipelineForText2Image.from_pretrained(
806
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
807
+ ).to("cuda")
808
+ pipeline.load_lora_weights(
809
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
810
+ )
811
+ pipeline.enable_lora()
812
+ ```
813
+ """
758
814
if not USE_PEFT_BACKEND :
759
815
raise ValueError ("PEFT backend is required for this method." )
760
816
@@ -768,10 +824,26 @@ def enable_lora(self):
768
824
769
825
def delete_adapters (self , adapter_names : Union [List [str ], str ]):
770
826
"""
827
+ Delete an adapter's LoRA layers from the pipeline.
828
+
771
829
Args:
772
- Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
773
830
adapter_names (`Union[List[str], str]`):
774
- The names of the adapter to delete. Can be a single string or a list of strings
831
+ The names of the adapters to delete.
832
+
833
+ Example:
834
+
835
+ ```py
836
+ from diffusers import AutoPipelineForText2Image
837
+ import torch
838
+
839
+ pipeline = AutoPipelineForText2Image.from_pretrained(
840
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
841
+ ).to("cuda")
842
+ pipeline.load_lora_weights(
843
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
844
+ )
845
+ pipeline.delete_adapters("cinematic")
846
+ ```
775
847
"""
776
848
if not USE_PEFT_BACKEND :
777
849
raise ValueError ("PEFT backend is required for this method." )
@@ -872,6 +944,24 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
872
944
adapter_name
873
945
].to (device )
874
946
947
+ def enable_lora_hotswap (self , ** kwargs ) -> None :
948
+ """
949
+ Hotswap adapters without triggering recompilation of a model or if the ranks of the loaded adapters are
950
+ different.
951
+
952
+ Args:
953
+ target_rank (`int`):
954
+ The highest rank among all the adapters that will be loaded.
955
+ check_compiled (`str`, *optional*, defaults to `"error"`):
956
+ How to handle a model that is already compiled. The check can return the following messages:
957
+ - "error" (default): raise an error
958
+ - "warn": issue a warning
959
+ - "ignore": do nothing
960
+ """
961
+ for key , component in self .components .items ():
962
+ if hasattr (component , "enable_lora_hotswap" ) and (key in self ._lora_loadable_modules ):
963
+ component .enable_lora_hotswap (** kwargs )
964
+
875
965
@staticmethod
876
966
def pack_weights (layers , prefix ):
877
967
layers_weights = layers .state_dict () if isinstance (layers , torch .nn .Module ) else layers
@@ -887,6 +977,7 @@ def write_lora_layers(
887
977
safe_serialization : bool ,
888
978
lora_adapter_metadata : Optional [dict ] = None ,
889
979
):
980
+ """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
890
981
if os .path .isfile (save_directory ):
891
982
logger .error (f"Provided path ({ save_directory } ) should be a directory, not a file" )
892
983
return
@@ -927,28 +1018,18 @@ def save_function(weights, filename):
927
1018
save_function (state_dict , save_path )
928
1019
logger .info (f"Model weights saved in { save_path } " )
929
1020
930
- @property
931
- def lora_scale (self ) -> float :
932
- # property function that returns the lora scale which can be set at run time by the pipeline.
933
- # if _lora_scale has not been set, return 1
934
- return self ._lora_scale if hasattr (self , "_lora_scale" ) else 1.0
935
-
936
- def enable_lora_hotswap (self , ** kwargs ) -> None :
937
- """Enables the possibility to hotswap LoRA adapters.
1021
+ @classmethod
1022
+ def _optionally_disable_offloading (cls , _pipeline ):
1023
+ return _func_optionally_disable_offloading (_pipeline = _pipeline )
938
1024
939
- Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
940
- the loaded adapters differ.
1025
+ @classmethod
1026
+ def _fetch_state_dict (cls , * args , ** kwargs ):
1027
+ deprecation_message = f"Using the `_fetch_state_dict()` method from { cls } has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
1028
+ deprecate ("_fetch_state_dict" , "0.35.0" , deprecation_message )
1029
+ return _fetch_state_dict (* args , ** kwargs )
941
1030
942
- Args:
943
- target_rank (`int`):
944
- The highest rank among all the adapters that will be loaded.
945
- check_compiled (`str`, *optional*, defaults to `"error"`):
946
- How to handle the case when the model is already compiled, which should generally be avoided. The
947
- options are:
948
- - "error" (default): raise an error
949
- - "warn": issue a warning
950
- - "ignore": do nothing
951
- """
952
- for key , component in self .components .items ():
953
- if hasattr (component , "enable_lora_hotswap" ) and (key in self ._lora_loadable_modules ):
954
- component .enable_lora_hotswap (** kwargs )
1031
+ @classmethod
1032
+ def _best_guess_weight_name (cls , * args , ** kwargs ):
1033
+ deprecation_message = f"Using the `_best_guess_weight_name()` method from { cls } has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
1034
+ deprecate ("_best_guess_weight_name" , "0.35.0" , deprecation_message )
1035
+ return _best_guess_weight_name (* args , ** kwargs )
0 commit comments