Commit bc34fa8

[lora]feat: use exclude modules to loraconfig. (#11806)
* feat: use exclude modules to loraconfig.
* version-guard.
* tests and version guard.
* remove print.
* describe the test
* more detailed warning message + shift to debug
* update
* update
* update
* remove test

1 parent 05e7a85 commit bc34fa8

File tree

4 files changed: 131 additions & 12 deletions

src/diffusers/loaders/peft.py

Lines changed: 10 additions & 3 deletions
@@ -244,13 +244,20 @@ def load_lora_adapter(
                     k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }
 
-            # create LoraConfig
-            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
-
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)
 
+            # create LoraConfig
+            lora_config = _create_lora_config(
+                state_dict,
+                network_alphas,
+                metadata,
+                rank,
+                model_state_dict=self.state_dict(),
+                adapter_name=adapter_name,
+            )
+
             # <Unsafe code
             # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
             # Now we remove any existing hooks to `_pipeline`.
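
The reordering above makes `adapter_name` available when the `LoraConfig` is built, and passing `model_state_dict=self.state_dict()` lets the config derivation see which base-model modules the LoRA does not touch. A minimal sketch of the kind of config this produces (the module names are hypothetical, and `exclude_modules` is only honored by `peft>=0.14.0`):

```python
# Hedged sketch, not diffusers code: the LoraConfig that effectively results when a
# LoRA checkpoint covers only part of the model. Module names below are made up.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    # modules that appear in the LoRA state dict
    target_modules=["blocks.0.attn.to_q", "blocks.0.attn.to_k"],
    # modules present in the base model but absent from the LoRA state dict
    exclude_modules=["vace_blocks.0.proj"],
)
```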

src/diffusers/utils/peft_utils.py

Lines changed: 49 additions & 8 deletions
@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
                     module.set_scale(adapter_name, 1.0)
 
 
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(
+    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]
@@ -180,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
         else:
             lora_alpha = set(network_alpha_dict.values()).pop()
 
-    # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
     use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
     # for now we know that the "bias" keys are only associated with `lora_B`.
@@ -195,6 +196,21 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
         "use_dora": use_dora,
         "lora_bias": lora_bias,
     }
+
+    # Example: try load FusionX LoRA into Wan VACE
+    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
+    if exclude_modules:
+        if not is_peft_version(">=", "0.14.0"):
+            msg = """
+It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
+version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
+peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
+https://github.com/huggingface/diffusers/issues/new
+"""
+            logger.debug(msg)
+        else:
+            lora_config_kwargs.update({"exclude_modules": exclude_modules})
+
     return lora_config_kwargs
 
 
@@ -294,19 +310,20 @@ def check_peft_version(min_version: str) -> None:
 
 
 def _create_lora_config(
-    state_dict,
-    network_alphas,
-    metadata,
-    rank_pattern_dict,
-    is_unet: bool = True,
+    state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
 ):
     from peft import LoraConfig
 
     if metadata is not None:
         lora_config_kwargs = metadata
     else:
         lora_config_kwargs = get_peft_kwargs(
-            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+            rank_pattern_dict,
+            network_alpha_dict=network_alphas,
+            peft_state_dict=state_dict,
+            is_unet=is_unet,
+            model_state_dict=model_state_dict,
+            adapter_name=adapter_name,
        )
 
     _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -371,3 +388,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
 
     if warn_msg:
         logger.warning(warn_msg)
+
+
+def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
+    """
+    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
+    `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
+    doesn't exist in `peft_state_dict`.
+    """
+    if model_state_dict is None:
+        return
+    all_modules = set()
+    string_to_replace = f"{adapter_name}." if adapter_name else ""
+
+    for name in model_state_dict.keys():
+        if string_to_replace:
+            name = name.replace(string_to_replace, "")
+        if "." in name:
+            module_name = name.rsplit(".", 1)[0]
+            all_modules.add(module_name)
+
+    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
+    exclude_modules = list(all_modules - target_modules_set)
+
+    return exclude_modules
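
The `_derive_exclude_modules` helper added above reduces to a set difference over parameter names. A self-contained toy illustration of that derivation (the keys below are invented, not from a real checkpoint):

```python
# Toy walk-through of the set-difference logic in `_derive_exclude_modules`.
model_state_dict = {
    "blocks.0.attn.to_q.weight": None,
    "blocks.0.attn.to_k.weight": None,
    "vace_blocks.0.proj.weight": None,  # no LoRA counterpart below
}
peft_state_dict = {
    "blocks.0.attn.to_q.lora_A.weight": None,
    "blocks.0.attn.to_k.lora_A.weight": None,
}

# module names from the base model: strip the trailing parameter name
all_modules = {name.rsplit(".", 1)[0] for name in model_state_dict if "." in name}
# module names targeted by the LoRA: strip everything from ".lora" onwards
target_modules = {name.split(".lora")[0] for name in peft_state_dict}

print(sorted(all_modules - target_modules))  # ['vace_blocks.0.proj']
```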

tests/lora/test_lora_layers_wan.py

Lines changed: 5 additions & 1 deletion
@@ -24,7 +24,11 @@
     WanPipeline,
     WanTransformer3DModel,
 )
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    require_peft_backend,
+    skip_mps,
+)
 
 
 sys.path.append(".")

tests/lora/utils.py

Lines changed: 67 additions & 0 deletions
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import os
 import re
@@ -291,6 +292,20 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
 
         return modules_to_save
 
+    def _get_exclude_modules(self, pipe):
+        from diffusers.utils.peft_utils import _derive_exclude_modules
+
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
+        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
+        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
+        pipe.unload_lora_weights()
+        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
+        exclude_modules = _derive_exclude_modules(
+            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
+        )
+        return exclude_modules
+
     def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2326,6 +2341,58 @@ def test_lora_unload_add_adapter(self):
         )
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
+    @require_peft_version_greater("0.13.2")
+    def test_lora_exclude_modules(self):
+        """
+        Test to check if `exclude_modules` works or not. It works in the following way:
+        we first create a pipeline and insert LoRA config into it. We then derive a `set`
+        of modules to exclude by investigating its denoiser state dict and denoiser LoRA
+        state dict.
+
+        We then create a new LoRA config to include the `exclude_modules` and perform tests.
+        """
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # only supported for `denoiser` now
+        pipe_cp = copy.deepcopy(pipe)
+        pipe_cp, _ = self.add_adapters_to_pipeline(
+            pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
+        pipe_cp.to("cpu")
+        del pipe_cp
+
+        denoiser_lora_config.exclude_modules = denoiser_exclude_modules
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(tmpdir)
+
+            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertTrue(
+            not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+            "LoRA should change outputs.",
+        )
+        self.assertTrue(
+            np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+            "Lora outputs should match.",
+        )
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
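
The new test exercises `exclude_modules` end to end through the pipeline save/load round trip. The same behavior can be sketched directly at the `peft` level on a toy module; this is an illustrative aside (not part of this commit) and assumes `peft>=0.14.0`, where `LoraConfig` accepts `exclude_modules`:

```python
# Hedged sketch on a hypothetical two-layer model: only `fc1` is adapted, `fc2` is excluded.
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))


config = LoraConfig(r=4, lora_alpha=4, target_modules=["fc1"], exclude_modules=["fc2"])
model = get_peft_model(Tiny(), config)

# `fc1` gets lora_A/lora_B parameters; `fc2` stays a plain Linear.
print(any("fc1" in n and "lora" in n for n, _ in model.named_parameters()))  # True
print(any("fc2" in n and "lora" in n for n, _ in model.named_parameters()))  # False
```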
