Skip to content

Commit 9254271

Browse files
sayakpaulstevhliu
andauthored
[docs] minor cleanups in the lora docs. (#11770)
* minor cleanups in the lora docs. * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * format docs * fix copies --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
1 parent 6760300 commit 9254271

File tree

4 files changed

+150
-91
lines changed

4 files changed

+150
-91
lines changed

docs/source/en/api/loaders/lora.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
3737

3838
</Tip>
3939

40+
## LoraBaseMixin
41+
42+
[[autodoc]] loaders.lora_base.LoraBaseMixin
43+
4044
## StableDiffusionLoraLoaderMixin
4145

4246
[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
@@ -96,10 +100,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
96100

97101
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
98102

99-
## LoraBaseMixin
100-
101-
[[autodoc]] loaders.lora_base.LoraBaseMixin
102-
103103
## WanLoraLoaderMixin
104104

105105
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin

src/diffusers/loaders/lora_base.py

Lines changed: 141 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,17 @@ def _load_lora_into_text_encoder(
424424

425425

426426
def _func_optionally_disable_offloading(_pipeline):
427+
"""
428+
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
429+
430+
Args:
431+
_pipeline (`DiffusionPipeline`):
432+
The pipeline to disable offloading for.
433+
434+
Returns:
435+
tuple:
436+
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
437+
"""
427438
is_model_cpu_offload = False
428439
is_sequential_cpu_offload = False
429440

@@ -453,6 +464,24 @@ class LoraBaseMixin:
453464
_lora_loadable_modules = []
454465
_merged_adapters = set()
455466

467+
@property
468+
def lora_scale(self) -> float:
469+
"""
470+
Returns the lora scale which can be set at run time by the pipeline. # if `_lora_scale` has not been set,
471+
return 1.
472+
"""
473+
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
474+
475+
@property
476+
def num_fused_loras(self):
477+
"""Returns the number of LoRAs that have been fused."""
478+
return len(self._merged_adapters)
479+
480+
@property
481+
def fused_loras(self):
482+
"""Returns names of the LoRAs that have been fused."""
483+
return self._merged_adapters
484+
456485
def load_lora_weights(self, **kwargs):
457486
raise NotImplementedError("`load_lora_weights()` is not implemented.")
458487

@@ -464,33 +493,6 @@ def save_lora_weights(cls, **kwargs):
464493
def lora_state_dict(cls, **kwargs):
465494
raise NotImplementedError("`lora_state_dict()` is not implemented.")
466495

467-
@classmethod
468-
def _optionally_disable_offloading(cls, _pipeline):
469-
"""
470-
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
471-
472-
Args:
473-
_pipeline (`DiffusionPipeline`):
474-
The pipeline to disable offloading for.
475-
476-
Returns:
477-
tuple:
478-
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
479-
"""
480-
return _func_optionally_disable_offloading(_pipeline=_pipeline)
481-
482-
@classmethod
483-
def _fetch_state_dict(cls, *args, **kwargs):
484-
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
485-
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
486-
return _fetch_state_dict(*args, **kwargs)
487-
488-
@classmethod
489-
def _best_guess_weight_name(cls, *args, **kwargs):
490-
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
491-
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
492-
return _best_guess_weight_name(*args, **kwargs)
493-
494496
def unload_lora_weights(self):
495497
"""
496498
Unloads the LoRA parameters.
@@ -661,19 +663,37 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
661663
self._merged_adapters = self._merged_adapters - {adapter}
662664
module.unmerge()
663665

664-
@property
665-
def num_fused_loras(self):
666-
return len(self._merged_adapters)
667-
668-
@property
669-
def fused_loras(self):
670-
return self._merged_adapters
671-
672666
def set_adapters(
673667
self,
674668
adapter_names: Union[List[str], str],
675669
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
676670
):
671+
"""
672+
Set the currently active adapters for use in the pipeline.
673+
674+
Args:
675+
adapter_names (`List[str]` or `str`):
676+
The names of the adapters to use.
677+
adapter_weights (`Union[List[float], float]`, *optional*):
678+
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
679+
adapters.
680+
681+
Example:
682+
683+
```py
684+
from diffusers import AutoPipelineForText2Image
685+
import torch
686+
687+
pipeline = AutoPipelineForText2Image.from_pretrained(
688+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
689+
).to("cuda")
690+
pipeline.load_lora_weights(
691+
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
692+
)
693+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
694+
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
695+
```
696+
"""
677697
if isinstance(adapter_weights, dict):
678698
components_passed = set(adapter_weights.keys())
679699
lora_components = set(self._lora_loadable_modules)
@@ -743,6 +763,24 @@ def set_adapters(
743763
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
744764

745765
def disable_lora(self):
766+
"""
767+
Disables the active LoRA layers of the pipeline.
768+
769+
Example:
770+
771+
```py
772+
from diffusers import AutoPipelineForText2Image
773+
import torch
774+
775+
pipeline = AutoPipelineForText2Image.from_pretrained(
776+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
777+
).to("cuda")
778+
pipeline.load_lora_weights(
779+
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
780+
)
781+
pipeline.disable_lora()
782+
```
783+
"""
746784
if not USE_PEFT_BACKEND:
747785
raise ValueError("PEFT backend is required for this method.")
748786

@@ -755,6 +793,24 @@ def disable_lora(self):
755793
disable_lora_for_text_encoder(model)
756794

757795
def enable_lora(self):
796+
"""
797+
Enables the active LoRA layers of the pipeline.
798+
799+
Example:
800+
801+
```py
802+
from diffusers import AutoPipelineForText2Image
803+
import torch
804+
805+
pipeline = AutoPipelineForText2Image.from_pretrained(
806+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
807+
).to("cuda")
808+
pipeline.load_lora_weights(
809+
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
810+
)
811+
pipeline.enable_lora()
812+
```
813+
"""
758814
if not USE_PEFT_BACKEND:
759815
raise ValueError("PEFT backend is required for this method.")
760816

@@ -768,10 +824,26 @@ def enable_lora(self):
768824

769825
def delete_adapters(self, adapter_names: Union[List[str], str]):
770826
"""
827+
Delete an adapter's LoRA layers from the pipeline.
828+
771829
Args:
772-
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
773830
adapter_names (`Union[List[str], str]`):
774-
The names of the adapter to delete. Can be a single string or a list of strings
831+
The names of the adapters to delete.
832+
833+
Example:
834+
835+
```py
836+
from diffusers import AutoPipelineForText2Image
837+
import torch
838+
839+
pipeline = AutoPipelineForText2Image.from_pretrained(
840+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
841+
).to("cuda")
842+
pipeline.load_lora_weights(
843+
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
844+
)
845+
pipeline.delete_adapters("cinematic")
846+
```
775847
"""
776848
if not USE_PEFT_BACKEND:
777849
raise ValueError("PEFT backend is required for this method.")
@@ -872,6 +944,24 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
872944
adapter_name
873945
].to(device)
874946

947+
def enable_lora_hotswap(self, **kwargs) -> None:
948+
"""
949+
Hotswap adapters without triggering recompilation of a model or if the ranks of the loaded adapters are
950+
different.
951+
952+
Args:
953+
target_rank (`int`):
954+
The highest rank among all the adapters that will be loaded.
955+
check_compiled (`str`, *optional*, defaults to `"error"`):
956+
How to handle a model that is already compiled. The check can return the following messages:
957+
- "error" (default): raise an error
958+
- "warn": issue a warning
959+
- "ignore": do nothing
960+
"""
961+
for key, component in self.components.items():
962+
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
963+
component.enable_lora_hotswap(**kwargs)
964+
875965
@staticmethod
876966
def pack_weights(layers, prefix):
877967
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
@@ -887,6 +977,7 @@ def write_lora_layers(
887977
safe_serialization: bool,
888978
lora_adapter_metadata: Optional[dict] = None,
889979
):
980+
"""Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
890981
if os.path.isfile(save_directory):
891982
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
892983
return
@@ -927,28 +1018,18 @@ def save_function(weights, filename):
9271018
save_function(state_dict, save_path)
9281019
logger.info(f"Model weights saved in {save_path}")
9291020

930-
@property
931-
def lora_scale(self) -> float:
932-
# property function that returns the lora scale which can be set at run time by the pipeline.
933-
# if _lora_scale has not been set, return 1
934-
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
935-
936-
def enable_lora_hotswap(self, **kwargs) -> None:
937-
"""Enables the possibility to hotswap LoRA adapters.
1021+
@classmethod
1022+
def _optionally_disable_offloading(cls, _pipeline):
1023+
return _func_optionally_disable_offloading(_pipeline=_pipeline)
9381024

939-
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
940-
the loaded adapters differ.
1025+
@classmethod
1026+
def _fetch_state_dict(cls, *args, **kwargs):
1027+
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
1028+
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
1029+
return _fetch_state_dict(*args, **kwargs)
9411030

942-
Args:
943-
target_rank (`int`):
944-
The highest rank among all the adapters that will be loaded.
945-
check_compiled (`str`, *optional*, defaults to `"error"`):
946-
How to handle the case when the model is already compiled, which should generally be avoided. The
947-
options are:
948-
- "error" (default): raise an error
949-
- "warn": issue a warning
950-
- "ignore": do nothing
951-
"""
952-
for key, component in self.components.items():
953-
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
954-
component.enable_lora_hotswap(**kwargs)
1031+
@classmethod
1032+
def _best_guess_weight_name(cls, *args, **kwargs):
1033+
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
1034+
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
1035+
return _best_guess_weight_name(*args, **kwargs)

src/diffusers/loaders/peft.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,6 @@ class PeftAdapterMixin:
8585
@classmethod
8686
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
8787
def _optionally_disable_offloading(cls, _pipeline):
88-
"""
89-
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
90-
91-
Args:
92-
_pipeline (`DiffusionPipeline`):
93-
The pipeline to disable offloading for.
94-
95-
Returns:
96-
tuple:
97-
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
98-
"""
9988
return _func_optionally_disable_offloading(_pipeline=_pipeline)
10089

10190
def load_lora_adapter(
@@ -444,7 +433,7 @@ def set_adapters(
444433
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
445434
):
446435
"""
447-
Set the currently active adapters for use in the UNet.
436+
Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
448437
449438
Args:
450439
adapter_names (`List[str]` or `str`):
@@ -466,7 +455,7 @@ def set_adapters(
466455
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
467456
)
468457
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
469-
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
458+
pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
470459
```
471460
"""
472461
if not USE_PEFT_BACKEND:
@@ -714,7 +703,7 @@ def disable_lora(self):
714703
pipeline.load_lora_weights(
715704
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
716705
)
717-
pipeline.disable_lora()
706+
pipeline.unet.disable_lora()
718707
```
719708
"""
720709
if not USE_PEFT_BACKEND:
@@ -737,7 +726,7 @@ def enable_lora(self):
737726
pipeline.load_lora_weights(
738727
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
739728
)
740-
pipeline.enable_lora()
729+
pipeline.unet.enable_lora()
741730
```
742731
"""
743732
if not USE_PEFT_BACKEND:
@@ -764,7 +753,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
764753
pipeline.load_lora_weights(
765754
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
766755
)
767-
pipeline.delete_adapters("cinematic")
756+
pipeline.unet.delete_adapters("cinematic")
768757
```
769758
"""
770759
if not USE_PEFT_BACKEND:

src/diffusers/loaders/unet.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -394,17 +394,6 @@ def _process_lora(
394394
@classmethod
395395
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
396396
def _optionally_disable_offloading(cls, _pipeline):
397-
"""
398-
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
399-
400-
Args:
401-
_pipeline (`DiffusionPipeline`):
402-
The pipeline to disable offloading for.
403-
404-
Returns:
405-
tuple:
406-
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
407-
"""
408397
return _func_optionally_disable_offloading(_pipeline=_pipeline)
409398

410399
def save_attn_procs(

0 commit comments

Comments
 (0)