
Commit cf127dd

Merge branch 'main' into main

2 parents 6964a6b + 7392c8f

9 files changed: +195 −102 lines

docs/source/en/api/loaders/lora.md

Lines changed: 4 additions & 4 deletions
@@ -37,6 +37,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 
 </Tip>
 
+## LoraBaseMixin
+
+[[autodoc]] loaders.lora_base.LoraBaseMixin
+
 ## StableDiffusionLoraLoaderMixin
 
 [[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
@@ -96,10 +100,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 
 [[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
 
-## LoraBaseMixin
-
-[[autodoc]] loaders.lora_base.LoraBaseMixin
-
 ## WanLoraLoaderMixin
 
 [[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin

src/diffusers/hooks/group_offloading.py

Lines changed: 3 additions & 3 deletions
@@ -96,9 +96,6 @@ def __init__(
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
-        if self.stream is None and self.record_stream:
-            raise ValueError("`record_stream` cannot be True when `stream` is None.")
-
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
@@ -513,6 +510,9 @@ def apply_group_offloading(
         else:
             raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
 
+    if not use_stream and record_stream:
+        raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
+
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
     if offload_type == "block_level":
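
Taken together, the two hunks above move the `record_stream` validation out of the hook's `__init__` and into the public `apply_group_offloading` entry point, so an invalid combination now fails fast, before any hooks are attached. A minimal sketch of how the relocated check surfaces to a caller — the checkpoint, device, and `num_blocks_per_group` value are illustrative assumptions, not part of the commit:

```py
import torch
from diffusers import UNet2DConditionModel
from diffusers.hooks import apply_group_offloading

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

# `record_stream` only makes sense when transfers run on a separate stream,
# so the relocated check rejects this combination before any hooks attach.
try:
    apply_group_offloading(
        unet,
        onload_device=torch.device("cuda"),
        offload_type="block_level",
        num_blocks_per_group=2,
        use_stream=False,
        record_stream=True,  # invalid without use_stream=True
    )
except ValueError as err:
    print(err)  # `record_stream` cannot be True when `use_stream=False`.
```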

src/diffusers/loaders/lora_base.py

Lines changed: 143 additions & 61 deletions
@@ -424,6 +424,17 @@ def _load_lora_into_text_encoder(
 
 
 def _func_optionally_disable_offloading(_pipeline):
+    """
+    Optionally removes offloading in case the pipeline has already been sequentially offloaded to CPU.
+
+    Args:
+        _pipeline (`DiffusionPipeline`):
+            The pipeline to disable offloading for.
+
+    Returns:
+        tuple:
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+    """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
 
@@ -442,7 +453,8 @@ def _func_optionally_disable_offloading(_pipeline):
                 logger.info(
                     "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                 )
-                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+                if is_sequential_cpu_offload or is_model_cpu_offload:
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
     return (is_model_cpu_offload, is_sequential_cpu_offload)
 
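The guard added above makes `remove_hook_from_module` conditional on offloading actually being detected, instead of running unconditionally for every hooked component. A condensed, hypothetical rendering of the helper's control flow after the change (the real function lives in `lora_base.py` and detects more hook variants than shown here):

```py
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module


def disable_offloading_sketch(pipeline):
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    if pipeline is not None and pipeline.hf_device_map is None:
        for _, component in pipeline.components.items():
            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                hook = component._hf_hook
                if not is_model_cpu_offload:
                    is_model_cpu_offload = isinstance(hook, CpuOffload)
                if not is_sequential_cpu_offload:
                    is_sequential_cpu_offload = isinstance(hook, AlignDevicesHook)
                # The new guard: only detach hooks when some form of offloading
                # was detected, rather than calling removal unconditionally.
                if is_sequential_cpu_offload or is_model_cpu_offload:
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
    return (is_model_cpu_offload, is_sequential_cpu_offload)
```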
@@ -453,6 +465,24 @@ class LoraBaseMixin:
     _lora_loadable_modules = []
     _merged_adapters = set()
 
+    @property
+    def lora_scale(self) -> float:
+        """
+        Returns the LoRA scale, which can be set at runtime by the pipeline. If `_lora_scale` has not been set,
+        returns 1.0.
+        """
+        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+    @property
+    def num_fused_loras(self):
+        """Returns the number of LoRAs that have been fused."""
+        return len(self._merged_adapters)
+
+    @property
+    def fused_loras(self):
+        """Returns the names of the LoRAs that have been fused."""
+        return self._merged_adapters
+
     def load_lora_weights(self, **kwargs):
         raise NotImplementedError("`load_lora_weights()` is not implemented.")
 
@@ -464,33 +494,6 @@ def save_lora_weights(cls, **kwargs):
     def lora_state_dict(cls, **kwargs):
         raise NotImplementedError("`lora_state_dict()` is not implemented.")
 
-    @classmethod
-    def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
-        return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
-    @classmethod
-    def _fetch_state_dict(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
-        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
-        return _fetch_state_dict(*args, **kwargs)
-
-    @classmethod
-    def _best_guess_weight_name(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
-        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
-        return _best_guess_weight_name(*args, **kwargs)
-
     def unload_lora_weights(self):
         """
         Unloads the LoRA parameters.
@@ -661,19 +664,37 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
                            self._merged_adapters = self._merged_adapters - {adapter}
                        module.unmerge()
 
-    @property
-    def num_fused_loras(self):
-        return len(self._merged_adapters)
-
-    @property
-    def fused_loras(self):
-        return self._merged_adapters
-
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
+        """
+        Set the currently active adapters for use in the pipeline.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            adapter_weights (`Union[List[float], float]`, *optional*):
+                The adapter(s) weights to use with the pipeline. If `None`, the weights are set to `1.0` for all the
+                adapters.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        ```
+        """
         if isinstance(adapter_weights, dict):
             components_passed = set(adapter_weights.keys())
             lora_components = set(self._lora_loadable_modules)
@@ -743,6 +764,24 @@ def set_adapters(
                set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
 
     def disable_lora(self):
+        """
+        Disables the active LoRA layers of the pipeline.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.disable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
@@ -755,6 +794,24 @@ def disable_lora(self):
                disable_lora_for_text_encoder(model)
 
     def enable_lora(self):
+        """
+        Enables the active LoRA layers of the pipeline.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.enable_lora()
+        ```
+        """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
@@ -768,10 +825,26 @@ def enable_lora(self):
 
     def delete_adapters(self, adapter_names: Union[List[str], str]):
         """
+        Delete an adapter's LoRA layers from the pipeline.
+
         Args:
-        Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
             adapter_names (`Union[List[str], str]`):
-                The names of the adapter to delete. Can be a single string or a list of strings
+                The names of the adapters to delete.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.delete_adapters("cinematic")
+        ```
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -872,6 +945,24 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
                        adapter_name
                    ].to(device)
 
+    def enable_lora_hotswap(self, **kwargs) -> None:
+        """
+        Enables hotswapping of LoRA adapters without triggering recompilation when the model is compiled or when
+        the ranks of the loaded adapters differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle a model that is already compiled. The options are:
+                    - "error" (default): raise an error
+                    - "warn": issue a warning
+                    - "ignore": do nothing
+        """
+        for key, component in self.components.items():
+            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+                component.enable_lora_hotswap(**kwargs)
+
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
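
With `enable_lora_hotswap` now documented alongside the other public methods, a usage sketch may help. This assumes the `hotswap=True` flow that `load_lora_weights` supports in recent diffusers releases; the repos, rank, and adapter name are illustrative:

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Raise the rank ceiling before compiling, so adapters with differing ranks
# can be swapped in later without retriggering compilation.
pipeline.enable_lora_hotswap(target_rank=64)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="default_0")
pipeline.unet = torch.compile(pipeline.unet)

# Swap a different LoRA into the same adapter slot without recompiling.
pipeline.load_lora_weights(
    "jbilcke-hf/sdxl-cinematic-1",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="default_0",
    hotswap=True,
)
```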
@@ -887,6 +978,7 @@ def write_lora_layers(
         safe_serialization: bool,
         lora_adapter_metadata: Optional[dict] = None,
     ):
+        """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
@@ -927,28 +1019,18 @@ def save_function(weights, filename):
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")
 
-    @property
-    def lora_scale(self) -> float:
-        # property function that returns the lora scale which can be set at run time by the pipeline.
-        # if _lora_scale has not been set, return 1
-        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
-
-    def enable_lora_hotswap(self, **kwargs) -> None:
-        """Enables the possibility to hotswap LoRA adapters.
+    @classmethod
+    def _optionally_disable_offloading(cls, _pipeline):
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
-        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
-        the loaded adapters differ.
+    @classmethod
+    def _fetch_state_dict(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
+        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
+        return _fetch_state_dict(*args, **kwargs)
 
-        Args:
-            target_rank (`int`):
-                The highest rank among all the adapters that will be loaded.
-            check_compiled (`str`, *optional*, defaults to `"error"`):
-                How to handle the case when the model is already compiled, which should generally be avoided. The
-                options are:
-                    - "error" (default): raise an error
-                    - "warn": issue a warning
-                    - "ignore": do nothing
-        """
-        for key, component in self.components.items():
-            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
-                component.enable_lora_hotswap(**kwargs)
+    @classmethod
+    def _best_guess_weight_name(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
+        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
+        return _best_guess_weight_name(*args, **kwargs)
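
The remaining `lora_base.py` hunks are a reorganization: the read-only properties (`lora_scale`, `num_fused_loras`, `fused_loras`) gain docstrings and move to the top of `LoraBaseMixin`, while the private classmethods move to the bottom, with `_optionally_disable_offloading` dropping its duplicated docstring in favor of the one on `_func_optionally_disable_offloading`. A short sketch of the newly documented read-only surface, assuming the same SDXL checkpoint and LoRA repo used in the docstring examples above:

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)

# Fuse the adapter, then inspect the now-documented properties.
pipeline.fuse_lora()
print(pipeline.num_fused_loras)  # 1
print(pipeline.fused_loras)      # {"cinematic"}, i.e. `_merged_adapters`
print(pipeline.lora_scale)       # 1.0 unless `_lora_scale` was set at runtime
```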

src/diffusers/loaders/peft.py

Lines changed: 5 additions & 16 deletions
@@ -85,17 +85,6 @@ class PeftAdapterMixin:
     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
     def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
     def load_lora_adapter(
@@ -444,7 +433,7 @@ def set_adapters(
         weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
     ):
         """
-        Set the currently active adapters for use in the UNet.
+        Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
 
         Args:
             adapter_names (`List[str]` or `str`):
@@ -466,7 +455,7 @@
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
         pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
-        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -714,7 +703,7 @@ def disable_lora(self):
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
-        pipeline.disable_lora()
+        pipeline.unet.disable_lora()
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -737,7 +726,7 @@ def enable_lora(self):
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
-        pipeline.enable_lora()
+        pipeline.unet.enable_lora()
         ```
         """
         if not USE_PEFT_BACKEND:
@@ -764,7 +753,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
         pipeline.load_lora_weights(
             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
         )
-        pipeline.delete_adapters("cinematic")
+        pipeline.unet.delete_adapters("cinematic")
         ```
         """
         if not USE_PEFT_BACKEND:
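
The `peft.py` changes correct the docstring examples to call the adapter methods on the model (`pipeline.unet`) rather than on the pipeline, which is what `PeftAdapterMixin` actually provides; the pipeline-level equivalents remain on `LoraBaseMixin`. A minimal sketch of the corrected, model-level usage, assuming the checkpoint and LoRA weights from the docstrings:

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)

# Model-level adapter control through PeftAdapterMixin; the pipeline-level
# equivalents on LoraBaseMixin continue to work as before.
pipeline.unet.set_adapters(["cinematic"], weights=[0.8])
pipeline.unet.disable_lora()                # bypass the LoRA layers
pipeline.unet.enable_lora()                 # re-enable them
pipeline.unet.delete_adapters("cinematic")  # remove the adapter entirely
```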
