|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import os
|
| 16 | +import gc |
16 | 17 | import time
|
17 | 18 | import yaml
|
18 | 19 | import json
|
@@ -367,13 +368,15 @@ def train(model,
|
367 | 368 | os.path.join(current_save_dir, 'model.pdopt'))
|
368 | 369 | if uniform_output_enabled:
|
369 | 370 | export(cli_args, model, current_save_dir)
|
| 371 | + gc.collect() |
370 | 372 |
|
371 | 373 | if use_ema:
|
372 | 374 | paddle.save(
|
373 | 375 | ema_model.state_dict(),
|
374 | 376 | os.path.join(current_save_dir, 'ema_model.pdparams'))
|
375 | 377 | if uniform_output_enabled:
|
376 | 378 | export(cli_args, ema_model, current_save_dir, use_ema)
|
| 379 | + gc.collect() |
377 | 380 |
|
378 | 381 | save_models.append(current_save_dir)
|
379 | 382 | if len(save_models) > keep_checkpoint_max > 0:
|
@@ -405,6 +408,7 @@ def train(model,
|
405 | 408 | os.path.join(best_model_dir, 'model.pdstates'))
|
406 | 409 | if uniform_output_enabled:
|
407 | 410 | export(cli_args, model, best_model_dir)
|
| 411 | + gc.collect() |
408 | 412 | save_model_info(states_dict, best_model_dir)
|
409 | 413 | update_train_results(cli_args,
|
410 | 414 | "best_model",
|
@@ -450,6 +454,7 @@ def train(model,
|
450 | 454 | if uniform_output_enabled:
|
451 | 455 | export(cli_args, ema_model, best_ema_model_dir,
|
452 | 456 | use_ema)
|
| 457 | + gc.collect() |
453 | 458 | save_model_info(ema_states_dict,
|
454 | 459 | best_ema_model_dir)
|
455 | 460 | update_train_results(cli_args,
|
|
0 commit comments