diff --git a/paddleseg/core/export.py b/paddleseg/core/export.py
index 839f22980..f2f18a61b 100644
--- a/paddleseg/core/export.py
+++ b/paddleseg/core/export.py
@@ -54,15 +54,15 @@ def export(args, model=None, save_dir=None, use_ema=False):
         input_spec = [paddle.static.InputSpec(shape=shape, dtype='float32')]
     model.eval()
     model = paddle.jit.to_static(model, input_spec=input_spec)
 
-    uniform_output_enabled = cfg.dic.get('uniform_output_enabled', False)
-    if args.for_fd or uniform_output_enabled:
+    export_during_train = cfg.dic.get('export_during_train', False)
+    if args.for_fd or export_during_train:
         save_name = 'inference'
         yaml_name = 'inference.yml'
     else:
         save_name = 'model'
         yaml_name = 'deploy.yaml'
-    if uniform_output_enabled:
+    if export_during_train:
         inference_model_path = os.path.join(save_dir, "inference", save_name)
         yml_file = os.path.join(save_dir, "inference", yaml_name)
     if use_ema:
diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py
index 4c35a0e78..ca608c827 100644
--- a/paddleseg/core/train.py
+++ b/paddleseg/core/train.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+import gc
 import time
 import yaml
 import json
@@ -124,6 +125,7 @@ def train(model,
             param.stop_gradient = True
 
     uniform_output_enabled = kwargs.pop("uniform_output_enabled", False)
+    export_during_train = kwargs.pop("export_during_train", False)
    cli_args = kwargs.pop("cli_args", None)
     model.train()
     nranks = paddle.distributed.ParallelEnv().nranks
@@ -365,15 +367,17 @@ def train(model,
                             os.path.join(current_save_dir, 'model.pdparams'))
                 paddle.save(optimizer.state_dict(),
                             os.path.join(current_save_dir, 'model.pdopt'))
-                if uniform_output_enabled:
+                if export_during_train:
                     export(cli_args, model, current_save_dir)
+                    gc.collect()
                 if use_ema:
                     paddle.save(
                         ema_model.state_dict(),
                         os.path.join(current_save_dir, 'ema_model.pdparams'))
-                    if uniform_output_enabled:
+                    if export_during_train:
                         export(cli_args, ema_model, current_save_dir,
                                use_ema)
+                        gc.collect()
 
                 save_models.append(current_save_dir)
                 if len(save_models) > keep_checkpoint_max > 0:
@@ -403,8 +407,10 @@ def train(model,
                         paddle.save(
                             states_dict,
                             os.path.join(best_model_dir, 'model.pdstates'))
-                        if uniform_output_enabled:
+                        if export_during_train:
                             export(cli_args, model, best_model_dir)
+                            gc.collect()
+                        if uniform_output_enabled:
                             save_model_info(states_dict, best_model_dir)
                             update_train_results(cli_args,
                                                  "best_model",
@@ -447,9 +453,11 @@ def train(model,
                                 ema_states_dict,
                                 os.path.join(best_ema_model_dir,
                                              'ema_model.pdstates'))
-                            if uniform_output_enabled:
+                            if export_during_train:
                                 export(cli_args, ema_model,
                                        best_ema_model_dir, use_ema)
+                                gc.collect()
+                            if uniform_output_enabled:
                                 save_model_info(ema_states_dict,
                                                 best_ema_model_dir)
                                 update_train_results(cli_args,
diff --git a/tools/train.py b/tools/train.py
index b0f26840a..d5b67674b 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -180,6 +180,7 @@ def main(args):
     utils.set_device(args.device)
     utils.set_cv2_num_threads(args.num_workers)
     uniform_output_enabled = cfg.dic.get("uniform_output_enabled", False)
+    export_during_train = cfg.dic.get("export_during_train", False)
     if uniform_output_enabled:
         if not os.path.exists(args.save_dir):
             os.makedirs(args.save_dir)
@@ -244,7 +245,9 @@ def main(args):
         print_mem_info=print_mem_info,
         shuffle=shuffle,
         uniform_output_enabled=uniform_output_enabled,
-        cli_args=None if not uniform_output_enabled else args)
+        export_during_train=export_during_train,
+        cli_args=args
+        if export_during_train or uniform_output_enabled else None)
 
 
 if __name__ == '__main__':