Commit 6964a6b

Add --lora_alpha and metadata handling to train_dreambooth_lora_sd3.py
1 parent 6760300 commit 6964a6b

2 files changed: +56 -1 lines changed

examples/dreambooth/test_dreambooth_lora_sd3.py

Lines changed: 42 additions & 0 deletions
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import logging
 import os
 import sys
 import tempfile

 import safetensors

+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -207,6 +209,46 @@ def test_dreambooth_lora_layer(self):
             starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
             self.assertTrue(starts_with_transformer)

+    def test_dreambooth_lora_sd3_with_metadata(self):
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --checkpointing_steps=2
+                --max_sequence_length 166
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
+
+
     def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
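Outside the test harness, the same check can be done by hand on any checkpoint produced by the script. A minimal sketch, assuming a trained adapter sits at the hypothetical path out/pytorch_lora_weights.safetensors; it relies only on the safetensors API and the LORA_ADAPTER_METADATA_KEY constant already used in the test above:

import json

from safetensors import safe_open

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

# Hypothetical path; point this at whatever --output_dir was used for training.
state_dict_file = "out/pytorch_lora_weights.safetensors"

# safetensors keeps free-form string metadata in the file header.
with safe_open(state_dict_file, framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

# The adapter config is serialized as a JSON string under LORA_ADAPTER_METADATA_KEY.
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw is not None:
    config = json.loads(raw)
    print(config.get("transformer.r"), config.get("transformer.lora_alpha"))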

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 14 additions & 1 deletion
@@ -53,6 +53,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     _set_state_dict_into_text_encoder,
     cast_training_params,
     compute_density_for_timestep_sampling,
@@ -321,6 +322,12 @@ def parse_args(input_args=None):
         required=False,
         help="A folder containing the training data of class images.",
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument(
         "--instance_prompt",
         type=str,
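The flag defaults to 4; since alpha was previously hard-wired to the rank, runs that change --rank may now also want to set --lora_alpha explicitly. A quick sketch of how the new flag is parsed, using placeholder paths and a placeholder model id, and assuming the script is importable as train_dreambooth_lora_sd3 from its own directory:

# Illustration only: exercise parse_args() with the new flag.
from train_dreambooth_lora_sd3 import parse_args

args = parse_args(
    [
        "--pretrained_model_name_or_path", "placeholder/sd3-base",  # placeholder model id
        "--instance_data_dir", "path/to/instance/images",           # placeholder data dir
        "--instance_prompt", "a photo of sks dog",
        "--output_dir", "out",
        "--rank", "4",
        "--lora_alpha", "8",
    ]
)
assert args.rank == 4 and args.lora_alpha == 8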
@@ -1266,7 +1273,7 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
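This is the substantive change: previously lora_alpha was tied to the rank, which under PEFT's standard LoRA scaling (lora_alpha / r, unless rank-stabilized LoRA is enabled) always produced an effective scale of 1.0. A minimal sketch, assuming the peft package is installed, showing how decoupling the two changes the scale applied to the LoRA update:

from peft import LoraConfig

# Old behavior: alpha tied to the rank, so the update is scaled by 4 / 4 = 1.0.
tied = LoraConfig(r=4, lora_alpha=4, init_lora_weights="gaussian")

# New behavior: alpha set independently via --lora_alpha, e.g. the values the new test uses.
decoupled = LoraConfig(r=4, lora_alpha=8, init_lora_weights="gaussian")

for cfg in (tied, decoupled):
    print(f"r={cfg.r}, lora_alpha={cfg.lora_alpha}, effective scale={cfg.lora_alpha / cfg.r}")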
@@ -1295,13 +1302,15 @@ def save_model_hook(models, weights, output_dir):
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
             text_encoder_two_lora_layers_to_save = None
+            modules_to_save = {}

             for model in models:
                 if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                     model = unwrap_model(model)
                     if args.upcast_before_saving:
                         model = model.to(torch.float32)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 elif args.train_text_encoder and isinstance(
                     unwrap_model(model), type(unwrap_model(text_encoder_one))
                 ):  # or text_encoder_two
@@ -1324,6 +1333,7 @@ def save_model_hook(models, weights, output_dir):
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                 text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )

     def load_model_hook(models, input_dir):
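The diff does not show _collate_lora_metadata itself, only that it turns the collected modules into extra keyword arguments for save_lora_weights. Purely as an illustration of the metadata shape that ends up in the checkpoint header (the flattened transformer.r / transformer.lora_alpha keys the new test reads back), a hypothetical flattening helper could look like the following; this is not the diffusers implementation:

# Hypothetical illustration only, not diffusers.training_utils._collate_lora_metadata.
def flatten_lora_metadata(modules_to_save):
    metadata = {}
    for name, module in modules_to_save.items():
        # PEFT-injected adapters expose their config via the module's `peft_config` dict.
        config = module.peft_config["default"].to_dict()
        for key, value in config.items():
            metadata[f"{name}.{key}"] = value  # e.g. "transformer.r", "transformer.lora_alpha"
    return metadata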
@@ -1925,10 +1935,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
+        modules_to_save = {}
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
+        modules_to_save["transformer"] = transformer
         transformer_lora_layers = get_peft_model_state_dict(transformer)

         if args.train_text_encoder:
@@ -1945,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             transformer_lora_layers=transformer_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )

         # Final inference
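For completeness, the adapter saved at the end of training (now carrying its metadata) can be loaded back for inference through the standard LoRA loading path. A minimal sketch, assuming a CUDA device, a hypothetical output directory out/, and the public stabilityai/stable-diffusion-3-medium-diffusers checkpoint standing in for whatever base model was actually trained against:

import torch

from diffusers import StableDiffusion3Pipeline

# Assumed base model; swap in the checkpoint the LoRA was trained on.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Picks up pytorch_lora_weights.safetensors from the (hypothetical) output directory.
pipe.load_lora_weights("out")

# Example DreamBooth-style prompt; replace with the instance prompt used for training.
image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")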
