16
16
17
17
import os
18
18
import shutil
19
+ import copy
19
20
import platform
20
21
import paddle
21
22
import paddle .distributed as dist
38
39
from ppcls .utils .ema import ExponentialMovingAverage
39
40
from ppcls .utils .save_load import load_dygraph_pretrain
40
41
from ppcls .utils .save_load import init_model
42
+ from ppcls .utils .save_result import update_train_results
41
43
from ppcls .utils import save_load , save_predict_result
42
44
43
45
from ppcls .data .utils .get_image_list import get_image_list
@@ -169,8 +171,8 @@ def __init__(self, config, mode="train"):
169
171
self .config ["DataLoader" ]["Eval" ], "Gallery" ,
170
172
self .device , self .use_dali )
171
173
self .query_dataloader = build_dataloader (
172
- self .config ["DataLoader" ]["Eval" ], "Query" ,
173
- self .device , self . use_dali )
174
+ self .config ["DataLoader" ]["Eval" ], "Query" , self . device ,
175
+ self .use_dali )
174
176
175
177
# build loss
176
178
if self .mode == "train" :
@@ -210,8 +212,8 @@ def __init__(self, config, mode="train"):
210
212
self .config ["Global" ]["eval_during_train" ]):
211
213
if self .eval_mode == "classification" :
212
214
if "Metric" in self .config and "Eval" in self .config ["Metric" ]:
213
- self .eval_metric_func = build_metrics (self .config ["Metric" ]
214
- [ "Eval" ])
215
+ self .eval_metric_func = build_metrics (self .config ["Metric" ][
216
+ "Eval" ])
215
217
else :
216
218
self .eval_metric_func = None
217
219
elif self .eval_mode == "retrieval" :
@@ -266,8 +268,7 @@ def __init__(self, config, mode="train"):
266
268
self .model = paddle .DataParallel (self .model )
267
269
if self .mode == 'train' and len (self .train_loss_func .parameters (
268
270
)) > 0 :
269
- self .train_loss_func = paddle .DataParallel (
270
- self .train_loss_func )
271
+ self .train_loss_func = paddle .DataParallel (self .train_loss_func )
271
272
272
273
# manually set a different seed on each GPU in a distributed environment
273
274
if seed is None :
@@ -313,6 +314,8 @@ def train(self):
313
314
}
314
315
# global iter counter
315
316
self .global_step = 0
317
+ uniform_output_enabled = self .config ['Global' ].get (
318
+ "uniform_output_enabled" , False )
316
319
317
320
if self .config .Global .checkpoints is not None :
318
321
metric_info = init_model (self .config .Global , self .model ,
@@ -384,41 +387,89 @@ def train(self):
384
387
# save best model from best_acc or best_ema_acc
385
388
if max (acc , acc_ema ) >= max (best_metric ["metric" ],
386
389
best_metric_ema ):
390
+ metric_info = {
391
+ "metric" : max (acc , acc_ema ),
392
+ "epoch" : epoch_id
393
+ }
394
+ prefix = "best_model"
387
395
save_load .save_model (
388
396
self .model ,
389
397
self .optimizer ,
390
- { "metric" : max ( acc , acc_ema ) ,
391
- "epoch" : epoch_id },
392
- self .output_dir ,
398
+ metric_info ,
399
+ os . path . join ( self . output_dir , prefix )
400
+ if uniform_output_enabled else self .output_dir ,
393
401
ema = ema_module ,
394
402
model_name = self .config ["Arch" ]["name" ],
395
- prefix = "best_model" ,
403
+ prefix = prefix ,
396
404
loss = self .train_loss_func ,
397
405
save_student_model = True )
406
+ if uniform_output_enabled :
407
+ save_path = os .path .join (self .output_dir , prefix ,
408
+ "inference" )
409
+ self .export (save_path , uniform_output_enabled )
410
+ if self .ema :
411
+ ema_save_path = os .path .join (
412
+ self .output_dir , prefix , "inference_ema" )
413
+ self .export (ema_save_path , uniform_output_enabled )
414
+ update_train_results (
415
+ self .config , prefix , metric_info , ema = self .ema )
416
+ save_load .save_model_info (metric_info , self .output_dir ,
417
+ prefix )
398
418
399
419
self .model .train ()
400
420
401
421
# save a periodic checkpoint every `save_interval` epochs
402
422
if save_interval > 0 and epoch_id % save_interval == 0 :
423
+ metric_info = {"metric" : acc , "epoch" : epoch_id }
424
+ prefix = "epoch_{}" .format (epoch_id )
403
425
save_load .save_model (
404
426
self .model ,
405
- self .optimizer , {"metric" : acc ,
406
- "epoch" : epoch_id },
407
- self .output_dir ,
427
+ self .optimizer ,
428
+ metric_info ,
429
+ os .path .join (self .output_dir , prefix )
430
+ if uniform_output_enabled else self .output_dir ,
408
431
ema = ema_module ,
409
432
model_name = self .config ["Arch" ]["name" ],
410
- prefix = "epoch_{}" . format ( epoch_id ) ,
433
+ prefix = prefix ,
411
434
loss = self .train_loss_func )
435
+ if uniform_output_enabled :
436
+ save_path = os .path .join (self .output_dir , prefix ,
437
+ "inference" )
438
+ self .export (save_path , uniform_output_enabled )
439
+ if self .ema :
440
+ ema_save_path = os .path .join (self .output_dir , prefix ,
441
+ "inference_ema" )
442
+ self .export (ema_save_path , uniform_output_enabled )
443
+ update_train_results (
444
+ self .config ,
445
+ prefix ,
446
+ metric_info ,
447
+ done_flag = epoch_id == self .config ["Global" ]["epochs" ],
448
+ ema = self .ema )
449
+ save_load .save_model_info (metric_info , self .output_dir ,
450
+ prefix )
412
451
# save the latest model
452
+ metric_info = {"metric" : acc , "epoch" : epoch_id }
453
+ prefix = "latest"
413
454
save_load .save_model (
414
455
self .model ,
415
- self .optimizer , {"metric" : acc ,
416
- "epoch" : epoch_id },
417
- self .output_dir ,
456
+ self .optimizer ,
457
+ metric_info ,
458
+ os .path .join (self .output_dir , prefix )
459
+ if uniform_output_enabled else self .output_dir ,
418
460
ema = ema_module ,
419
461
model_name = self .config ["Arch" ]["name" ],
420
- prefix = "latest" ,
462
+ prefix = prefix ,
421
463
loss = self .train_loss_func )
464
+ if uniform_output_enabled :
465
+ save_path = os .path .join (self .output_dir , prefix , "inference" )
466
+ self .export (save_path , uniform_output_enabled )
467
+ if self .ema :
468
+ ema_save_path = os .path .join (self .output_dir , prefix ,
469
+ "inference_ema" )
470
+ self .export (ema_save_path , uniform_output_enabled )
471
+ save_load .save_model_info (metric_info , self .output_dir , prefix )
472
+ self .model .train ()
422
473
423
474
if self .vdl_writer is not None :
424
475
self .vdl_writer .close ()
@@ -479,33 +530,45 @@ def infer(self):
479
530
image_file_list .clear ()
480
531
except Exception as ex :
481
532
logger .error (
482
- "Exception occured when parse line: {} with msg: {}" .
483
- format ( image_file , ex ))
533
+ "Exception occured when parse line: {} with msg: {}" .format (
534
+ image_file , ex ))
484
535
continue
485
536
if save_path :
486
537
save_predict_result (save_path , results )
487
538
return results
488
539
489
- def export (self ):
490
- assert self .mode == "export"
540
+ def export (self ,
541
+ save_path = None ,
542
+ uniform_output_enabled = False ,
543
+ ema_module = None ):
544
+ assert self .mode == "export" or uniform_output_enabled
545
+ if paddle .distributed .get_rank () != 0 :
546
+ return
491
547
use_multilabel = self .config ["Global" ].get (
492
548
"use_multilabel" ,
493
549
False ) or "ATTRMetric" in self .config ["Metric" ]["Eval" ][0 ]
494
- model = ExportModel (self .config ["Arch" ], self .model , use_multilabel )
495
- if self .config ["Global" ]["pretrained_model" ] is not None :
550
+ model = self .model_ema .module if self .ema else self .model
551
+ if isinstance (self .model , paddle .DataParallel ):
552
+ model = copy .deepcopy (model ._layers )
553
+ else :
554
+ model = copy .deepcopy (model )
555
+ model = ExportModel (self .config ["Arch" ], model
556
+ if not ema_module else ema_module , use_multilabel )
557
+ if self .config ["Global" ][
558
+ "pretrained_model" ] is not None and not uniform_output_enabled :
496
559
load_dygraph_pretrain (model .base_model ,
497
560
self .config ["Global" ]["pretrained_model" ])
498
-
499
561
model .eval ()
500
-
501
562
# for re-parameterization nets
502
- for layer in self . model .sublayers ():
563
+ for layer in model .sublayers ():
503
564
if hasattr (layer , "re_parameterize" ) and not getattr (layer ,
504
565
"is_repped" ):
505
566
layer .re_parameterize ()
506
-
507
- save_path = os .path .join (self .config ["Global" ]["save_inference_dir" ],
508
- "inference" )
567
+ if not save_path :
568
+ save_path = os .path .join (
569
+ self .config ["Global" ]["save_inference_dir" ], "inference" )
570
+ else :
571
+ save_path = os .path .join (save_path , "inference" )
509
572
510
573
model = paddle .jit .to_static (
511
574
model ,
@@ -520,12 +583,12 @@ def export(self):
520
583
save_path + "_int8" )
521
584
else :
522
585
paddle .jit .save (model , save_path )
523
- if self .config ["Global" ].get ("export_for_fd" , False ):
524
- dst_path = os . path . join (
525
- self . config [ "Global" ][ "save_inference_dir" ] , 'inference.yml' )
586
+ if self .config ["Global" ].get ("export_for_fd" ,
587
+ False ) or uniform_output_enabled :
588
+ dst_path = os . path . join ( os . path . dirname ( save_path ) , 'inference.yml' )
526
589
dump_infer_config (self .config , dst_path )
527
590
logger .info (
528
- f"Export succeeded! The inference model exported has been saved in \" { self . config [ 'Global' ][ 'save_inference_dir' ] } \" ."
591
+ f"Export succeeded! The inference model exported has been saved in \" { save_path } \" ."
529
592
)
530
593
531
594
def _init_amp (self ):
0 commit comments