 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     _set_state_dict_into_text_encoder,
     cast_training_params,
     compute_density_for_timestep_sampling,
@@ -321,6 +322,12 @@ def parse_args(input_args=None):
         required=False,
         help="A folder containing the training data of class images.",
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument(
         "--instance_prompt",
         type=str,
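
For context (an illustrative aside, not part of the patch): --lora_alpha is an integer flag defaulting to 4, so invocations that omit it still parse, but a run that raises --rank without also raising --lora_alpha now gets a different effective LoRA scale than before, because alpha used to follow the rank. A minimal argparse sketch of the two flags the patch refers to (the --rank default below is an assumption):

# Minimal, self-contained sketch; names mirror the patch, the --rank default is assumed.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--rank", type=int, default=4)        # pre-existing flag, referenced as args.rank in the patch
parser.add_argument("--lora_alpha", type=int, default=4)  # flag added by this patch

args = parser.parse_args(["--rank", "16"])                # alpha not passed, stays at its default of 4
assert (args.rank, args.lora_alpha) == (16, 4)            # before the patch, alpha would have followed rank to 16
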
@@ -1266,7 +1273,7 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
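
Why this matters (illustrative aside): PEFT scales the LoRA update by lora_alpha / r, so hard-wiring alpha to the rank always produced a scale of 1.0, while the dedicated flag makes the scale tunable without changing adapter capacity. A minimal sketch of the changed construction with stand-in values (the target_modules list below is illustrative, not the script's actual list):

# Stand-alone sketch of the new LoraConfig wiring; values are stand-ins for args.*.
from peft import LoraConfig

rank, lora_alpha = 16, 32  # stand-ins for args.rank and args.lora_alpha

transformer_lora_config = LoraConfig(
    r=rank,
    lora_alpha=lora_alpha,   # was lora_alpha=rank before this patch
    lora_dropout=0.0,        # stand-in for args.lora_dropout
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # illustrative target list
)
# Effective scaling applied by PEFT: lora_alpha / r = 2.0 here; the old behavior always gave 1.0.
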
@@ -1295,13 +1302,15 @@ def save_model_hook(models, weights, output_dir):
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
             text_encoder_two_lora_layers_to_save = None
+            modules_to_save = {}

             for model in models:
                 if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                     model = unwrap_model(model)
                     if args.upcast_before_saving:
                         model = model.to(torch.float32)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 elif args.train_text_encoder and isinstance(
                     unwrap_model(model), type(unwrap_model(text_encoder_one))
                 ):  # or text_encoder_two
@@ -1324,6 +1333,7 @@ def save_model_hook(models, weights, output_dir):
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                 text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )

     def load_model_hook(models, input_dir):
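
On the new metadata plumbing (illustrative aside): the hook now collects each unwrapped, PEFT-adapted module into modules_to_save and splats _collate_lora_metadata(modules_to_save) into save_lora_weights as extra keyword arguments, presumably so adapter settings such as rank and lora_alpha travel with the saved weights. The sketch below only conveys the general shape of such a collation step; the keyword naming and the peft_config access are assumptions, and it is not the diffusers implementation:

# Hypothetical stand-in, NOT diffusers' _collate_lora_metadata.
# Assumes each collected module is PEFT-adapted and exposes peft_config with a "default" adapter.
def collate_lora_metadata_sketch(modules_to_save):
    metadata = {}
    for name, module in modules_to_save.items():
        # Serialize the adapter's LoraConfig (rank, lora_alpha, target modules, ...) per module.
        metadata[f"{name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
    return metadata
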
@@ -1925,10 +1935,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
+        modules_to_save = {}
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
+        modules_to_save["transformer"] = transformer
         transformer_lora_layers = get_peft_model_state_dict(transformer)

         if args.train_text_encoder:
@@ -1945,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             transformer_lora_layers=transformer_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
             text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )

         # Final inference