Skip to content

Commit f0abdb3

Browse files
merge uapi paddleseg (#3718)
1 parent 243e816 commit f0abdb3

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

paddleseg/core/train.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,15 +345,29 @@ def train(model,
345345
"iter_{}".format(iter))
346346
if not os.path.isdir(current_save_dir):
347347
os.makedirs(current_save_dir)
348+
states_dict = {
349+
'mIoU': mean_iou,
350+
'Acc': acc,
351+
'iter': iter
352+
}
348353
paddle.save(model.state_dict(),
349354
os.path.join(current_save_dir, 'model.pdparams'))
350355
paddle.save(optimizer.state_dict(),
351356
os.path.join(current_save_dir, 'model.pdopt'))
357+
paddle.save(states_dict,
358+
os.path.join(current_save_dir, 'model.pdstates'))
352359

353360
if use_ema:
361+
ema_states_dict = {
362+
'mIoU': ema_mean_iou,
363+
'Acc': ema_acc,
364+
'iter': iter
365+
}
354366
paddle.save(
355367
ema_model.state_dict(),
356368
os.path.join(current_save_dir, 'ema_model.pdparams'))
369+
paddle.save(ema_states_dict,
370+
os.path.join(current_save_dir, 'ema_model.pdstates'))
357371

358372
save_models.append(current_save_dir)
359373
if len(save_models) > keep_checkpoint_max > 0:
@@ -369,6 +383,8 @@ def train(model,
369383
paddle.save(
370384
model.state_dict(),
371385
os.path.join(best_model_dir, 'model.pdparams'))
386+
paddle.save(states_dict,
387+
os.path.join(best_model_dir, 'model.pdstates'))
372388
elif mean_iou < best_mean_iou:
373389
stop_count += 1
374390

@@ -391,6 +407,8 @@ def train(model,
391407
paddle.save(ema_model.state_dict(),
392408
os.path.join(best_ema_model_dir,
393409
'ema_model.pdparams'))
410+
paddle.save(ema_states_dict,
411+
os.path.join(best_ema_model_dir, 'ema_model.pdstates'))
394412
logger.info(
395413
'[EVAL] The EMA model with the best validation mIoU ({:.4f}) was saved at iter {}.'
396414
.format(best_ema_mean_iou, best_ema_model_iter))

0 commit comments

Comments
 (0)