@@ -345,15 +345,29 @@ def train(model,
345
345
"iter_{}" .format (iter ))
346
346
if not os .path .isdir (current_save_dir ):
347
347
os .makedirs (current_save_dir )
348
+ states_dict = {
349
+ 'mIoU' : mean_iou ,
350
+ 'Acc' : acc ,
351
+ 'iter' : iter
352
+ }
348
353
paddle .save (model .state_dict (),
349
354
os .path .join (current_save_dir , 'model.pdparams' ))
350
355
paddle .save (optimizer .state_dict (),
351
356
os .path .join (current_save_dir , 'model.pdopt' ))
357
+ paddle .save (states_dict ,
358
+ os .path .join (current_save_dir , 'model.pdstates' ))
352
359
353
360
if use_ema :
361
+ ema_states_dict = {
362
+ 'mIoU' : ema_mean_iou ,
363
+ 'Acc' : ema_acc ,
364
+ 'iter' : iter
365
+ }
354
366
paddle .save (
355
367
ema_model .state_dict (),
356
368
os .path .join (current_save_dir , 'ema_model.pdparams' ))
369
+ paddle .save (ema_states_dict ,
370
+ os .path .join (current_save_dir , 'ema_model.pdstates' ))
357
371
358
372
save_models .append (current_save_dir )
359
373
if len (save_models ) > keep_checkpoint_max > 0 :
@@ -369,6 +383,8 @@ def train(model,
369
383
paddle .save (
370
384
model .state_dict (),
371
385
os .path .join (best_model_dir , 'model.pdparams' ))
386
+ paddle .save (states_dict ,
387
+ os .path .join (best_model_dir , 'model.pdstates' ))
372
388
elif mean_iou < best_mean_iou :
373
389
stop_count += 1
374
390
@@ -391,6 +407,8 @@ def train(model,
391
407
paddle .save (ema_model .state_dict (),
392
408
os .path .join (best_ema_model_dir ,
393
409
'ema_model.pdparams' ))
410
+ paddle .save (ema_states_dict ,
411
+ os .path .join (best_ema_model_dir , 'ema_model.pdstates' ))
394
412
logger .info (
395
413
'[EVAL] The EMA model with the best validation mIoU ({:.4f}) was saved at iter {}.'
396
414
.format (best_ema_mean_iou , best_ema_model_iter ))
0 commit comments