 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import os
 import re
@@ -291,6 +292,20 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
 
         return modules_to_save
 
+    def _get_exclude_modules(self, pipe):
+        from diffusers.utils.peft_utils import _derive_exclude_modules
+
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
+        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
+        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
+        pipe.unload_lora_weights()
+        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
+        exclude_modules = _derive_exclude_modules(
+            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
+        )
+        return exclude_modules
+
     def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2326,6 +2341,58 @@ def test_lora_unload_add_adapter(self):
         )
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
+    @require_peft_version_greater("0.13.2")
+    def test_lora_exclude_modules(self):
+        """
+        Test that `exclude_modules` works as expected. The test works as follows:
+        we first create a pipeline and insert a LoRA config into it. We then derive a `set`
+        of modules to exclude by inspecting its denoiser state dict and denoiser LoRA
+        state dict.
+
+        We then create a new LoRA config that includes the `exclude_modules` and run the checks.
+        """
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # `exclude_modules` is only supported for the denoiser for now
+        pipe_cp = copy.deepcopy(pipe)
+        pipe_cp, _ = self.add_adapters_to_pipeline(
+            pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
+        pipe_cp.to("cpu")
+        del pipe_cp
+
+        denoiser_lora_config.exclude_modules = denoiser_exclude_modules
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(tmpdir)
+
+            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertTrue(
+            not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+            "LoRA should change outputs.",
+        )
+        self.assertTrue(
+            np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+            "LoRA outputs should match.",
+        )
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
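As for the behaviour the new test exercises, below is a minimal, self-contained sketch of how `exclude_modules` on a peft `LoraConfig` keeps LoRA layers out of selected modules. `TinyDenoiser` and its module names are hypothetical placeholders, and the `exclude_modules` argument assumes a peft release that supports it (the test gates on `require_peft_version_greater("0.13.2")`).

```python
import torch
from peft import LoraConfig, inject_adapter_in_model


class TinyDenoiser(torch.nn.Module):
    # Hypothetical stand-in for a pipeline's denoiser.
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)
        self.to_k = torch.nn.Linear(8, 8)
        self.proj_out = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj_out(self.to_q(x) + self.to_k(x))


lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "proj_out"],
    exclude_modules=["proj_out"],  # assumes a peft version that supports `exclude_modules`
)
model = inject_adapter_in_model(lora_config, TinyDenoiser())

# Only the non-excluded modules should carry LoRA parameters.
lora_param_names = [name for name, _ in model.named_parameters() if "lora_" in name]
assert lora_param_names, "expected LoRA layers on `to_q`/`to_k`"
assert all("proj_out" not in name for name in lora_param_names)
```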