Add optimizer save and load #16986
Changes from 23 commits
@@ -16,13 +16,18 @@
 import os
 import collections
 from .. import core
 from ..framework import Variable, default_main_program
+import pickle
+from . import learning_rate_scheduler
+import warnings

 __all__ = ['save_persistables', 'load_persistables']


-def save_persistables(vardict, dirname, filename=None):
+def save_persistables(model_dict,
+                      optimizer=None,
+                      dirname='save_dir',
+                      filename=None):
     """
     This function filters out all variables in layer.parameters from the
     give `layer` and then trys to load these variables from the folder
@@ -34,12 +39,12 @@ def save_persistables(vardict, dirname, filename=None):
     the file name.

     Args:
-        vardict(dict of Parameters): The parameters will
+        model_dict(dict of Parameters): The parameters will
                                       be saved. If it is None, nothing
                                       will be deal.
         dirname(str): The directory path.
         filename(str|None): The file which saved all variables. If variables were
-                            saved in differnet files, set it to None.
+                            saved in different files, set it to None.
                             Default: None

     Returns:
@@ -71,11 +76,11 @@ def save_persistables(vardict, dirname, filename=None):
             fluid.dygraph.save_persistables(ptb_model.state_dict(), dirname=param_path,
                                             layer=ptb_model)
     """
-    if isinstance(vardict, collections.OrderedDict):
-        _save_var_to_file(vardict, dirname, filename)
+    if isinstance(model_dict, collections.OrderedDict):
+        _save_var_to_file(model_dict, optimizer, dirname, filename)


-def load_persistables(dirname):
+def load_persistables(dirname='save_dir'):
     """
     This function trys to load persistable variables from the folder
     `dirname` or the file `filename`.
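Taken together, these hunks change the public call pattern: the optimizer is now passed to save_persistables alongside the model's state dict. The sketch below shows how the updated API might be exercised under dygraph; MyModel, the fluid.dygraph.FC layer, the ExponentialDecay hyper-parameters, and the "./save_dir" path are illustrative assumptions, not code from this PR.

# Hedged usage sketch for the new save path (assumptions noted above).
import paddle.fluid as fluid

class MyModel(fluid.dygraph.Layer):
    # Hypothetical toy model; any dygraph Layer with parameters would do.
    def __init__(self, name_scope):
        super(MyModel, self).__init__(name_scope)
        self._fc = fluid.dygraph.FC(self.full_name(), size=10)

    def forward(self, x):
        return self._fc(x)

with fluid.dygraph.guard():
    model = MyModel("my_model")
    # Only optimizers driven by a LearningRateDecay scheduler get persisted.
    lr = fluid.dygraph.ExponentialDecay(   # assumed scheduler class/signature
        learning_rate=0.1, decay_steps=1000, decay_rate=0.9)
    sgd = fluid.optimizer.SGDOptimizer(learning_rate=lr)

    # ... forward pass, loss.backward(), sgd.minimize(loss) ...

    # Parameters go to individual files under dirname; the optimizer's
    # LearningRateDecay object is pickled under dirname/optimizers/.
    fluid.dygraph.save_persistables(
        model.state_dict(), optimizer=sgd, dirname="./save_dir")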
@@ -86,7 +91,8 @@ def load_persistables(dirname):
     the file name.

     Args:
-        dirname(str): The directory path.
+        dirname(str): The directory path. default is save_dir
+        optimizer(Optimizer): Optimizer to be save

     Returns:
         dict: The parameter-dict resumed from file
@@ -103,7 +109,7 @@ def load_persistables(dirname):
     return _load_var_from_file(dirname)


-def _save_var_to_file(stat_dict, file_dir, file_name):
+def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
     save_block = default_main_program().global_block()
     save_var_map = {}
     for var_key, each_var in stat_dict.items():
@@ -117,6 +123,32 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
                 'file_path': os.path.join(file_dir,
                                           os.path.normpath(each_var.name))
             })
+    if isinstance(optimizers, (list, tuple)):
+        optimizers = optimizers
+    else:
+        optimizers = [optimizers]

Review comment: use list(obj) instead?
Reply: what's the benefit for this?
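To ground this exchange: in plain Python, list(x) only works when x is already iterable, while a literal [x] wraps any single object into a one-element list, which is what the branch above relies on when a single optimizer is passed. The snippet is a general-Python illustration, not code from the PR.

# General Python behaviour, shown with a stand-in class.
class Opt(object):
    pass

single = Opt()
wrapped = [single]      # always yields a one-element list
try:
    list(single)        # a single non-iterable object cannot be converted
except TypeError as err:
    print(err)          # 'Opt' object is not iterable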
+    if os.path.exists(os.path.join(file_dir, os.path.normpath("optimizers"))):
+        pass
+    else:
+        os.mkdir(os.path.join(file_dir, os.path.normpath("optimizers")))
+    for optimizer in optimizers:
+        if isinstance(optimizer._learning_rate,
+                      learning_rate_scheduler.LearningRateDecay):
+            try:
+                f = open(
+                    os.path.join(file_dir, "optimizers",
+                                 os.path.normpath(str(optimizer._name))), "wb")
+                pickle.dump(optimizer._learning_rate, f, 2)
+                f.close()
+            except ():
+                raise IOError("Can't load %s",
+                              os.path.join(
+                                  file_dir, "optimizers",
+                                  os.path.normpath(str(optimizer._name))))
+        else:
+            warnings.warn(
+                "Optimizer not saved, Only optimizer with 'LearningRateDecay' under DyGraph mode need to be saved"
+            )

     if file_name is not None:
         save_var_list = []
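As a rough picture of what this loop leaves on disk, the snippet below walks the optimizers/ subdirectory and unpickles each entry. The directory path and the example file name in the comment are assumptions; the actual file name is whatever was generated for optimizer._name.

# Hedged inspection sketch: each eligible optimizer contributes one pickle
# file under file_dir/optimizers/, named after optimizer._name (for example
# something like "SGDOptimizer_0"; the exact string is an assumption here).
import os
import pickle

opt_dir = os.path.join("./save_dir", "optimizers")
for fname in os.listdir(opt_dir):
    with open(os.path.join(opt_dir, fname), "rb") as f:
        scheduler = pickle.load(f)      # the saved LearningRateDecay object
    print(fname, type(scheduler).__name__)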
@@ -138,6 +170,8 @@ def walk_filename(file_dir):
         var_name_list = []
         if os.path.exists(base_path):
             for dirpath, dirnames, filenames in os.walk(base_path):
+                if "optimizers" in dirpath:
+                    continue
                 pt = dirpath.replace(base_path, "", 1)
                 if pt.startswith("/") or pt.startswith("\\"):
                     pt = pt[1:]
@@ -152,6 +186,7 @@

     load_block = default_main_program().global_block()
     load_var_map = {}
+    load_optimizer_map = {}
     file_var_list = walk_filename(file_dir)
     for var_name in file_var_list:
         new_var = Variable(block=load_block, name=var_name)
@@ -165,8 +200,22 @@ def walk_filename(file_dir):
             })

         load_var_map[new_var.name] = new_var

-    return load_var_map
+    opt_path = os.path.join(file_dir, "optimizers")
+    for _, _, optimizers in os.walk(opt_path):
+        for optimizer in optimizers:
+            try:
+                f = open(os.path.join(opt_path, optimizer), "rb")
+                load_optimizer_map[optimizer] = pickle.load(f)
+                f.close()
+            except IOError:
+                raise IOError("Can't load %s",
+                              os.path.join(
+                                  file_dir, "optimizers",
+                                  os.path.normpath(str(optimizer._name))))
+    if len(load_optimizer_map) == 0:
+        warnings.warn("No optimizer loaded")

+    return load_var_map, load_optimizer_map


 def _clone_var_in_block_(block, var):
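On the restore side, _load_var_from_file (and therefore load_persistables) now hands back a pair of dicts instead of one. A hedged sketch of unpacking it, reusing the hypothetical MyModel from the earlier save sketch and assuming Layer.load_dict is the parameter-restore entry point in this version; the optimizer half of the restore is shown after the Optimizer.load hunk below.

# Hedged restore sketch (MyModel and "./save_dir" are assumptions, see above).
import paddle.fluid as fluid

with fluid.dygraph.guard():
    model = MyModel("my_model")
    # load_persistables now yields (parameter dict, optimizer dict).
    param_dict, optimizer_dict = fluid.dygraph.load_persistables("./save_dir")
    model.load_dict(param_dict)   # assumed dygraph API for restoring parameters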
@@ -1,4 +1,4 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,25 +16,25 @@

 import numpy as np
 from collections import defaultdict
+from functools import reduce

+from paddle.fluid import core
 from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program
+from paddle.fluid.layers import tensor

 from . import framework
 from . import layers
 from . import unique_name
 from .backward import append_backward
 from .clip import append_gradient_clip_ops, error_clip_callback
+from .dygraph import base as imperative_base
+from .dygraph.learning_rate_scheduler import LearningRateDecay
 from .framework import program_guard
 from .initializer import Constant
 from .layer_helper import LayerHelper
 from .layers import ops
 from .regularizer import append_regularization_ops
-from .dygraph import base as imperative_base
-from .dygraph.learning_rate_scheduler import LearningRateDecay
-from paddle.fluid import core
-from paddle.fluid.layers import tensor
-from functools import reduce

Review comment: Why import those?

 from .wrapped_decorator import signature_safe_contextmanager

 __all__ = [
@@ -63,14 +63,18 @@ def __init__(self, learning_rate, regularization=None, name=None):
                 raise TypeError(
                     "learning rate should be float or LearningRateDecay, got %s here"
                     % type(learning_rate))
+            if name is not None:
+                self._name = unique_name.generate(name)
+            else:
+                self._name = unique_name.generate(self.__class__.__name__)
         else:
             if not isinstance(learning_rate, float) and \
                     not isinstance(learning_rate, framework.Variable):
                 raise TypeError(
                     "learning rate should be float or Variable, got %s here" %
                     type(learning_rate))
+            self._name = name

-        self._name = name
         self.regularization = regularization
         self._learning_rate = learning_rate
         # the learning rate type should be inferenced from loss
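The generated name is what later keys the pickle file on disk and the lookup in Optimizer.load. A small illustration of what unique_name.generate typically produces; the exact suffixes depend on a process-wide counter, so the printed values are examples only.

# Hedged illustration of the naming scheme assumed by the save/load path.
from paddle.fluid import unique_name

print(unique_name.generate("SGDOptimizer"))   # e.g. SGDOptimizer_0
print(unique_name.generate("SGDOptimizer"))   # e.g. SGDOptimizer_1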
@@ -89,6 +93,19 @@ def __init__(self, learning_rate, regularization=None, name=None):
         self.helper = None
         self._opti_name_list = []

+    def load(self, stat_dict):
+        """
+        load optimizer with learning rate decay in dygraph mode
+        :return: None
+
+        Args:
+            stat_dict: the dict load by load_persistable method
+        """
+        if framework.in_dygraph_mode():
+            self._learning_rate = stat_dict[self._name]
+        else:
+            raise TypeError("load_dict can only be used under DyGraph mode")
+
     def get_opti_var_name_list(self):
         return self._opti_name_list
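Continuing the restore sketch from the checkpoint.py diff above, the optimizer half would look roughly like this; it assumes the dict returned by load_persistables is keyed by the same unique_name-generated string as this optimizer's _name, and reuses the assumed ExponentialDecay scheduler.

# Hedged sketch of Optimizer.load in use (names and paths are illustrative).
import paddle.fluid as fluid

with fluid.dygraph.guard():
    lr = fluid.dygraph.ExponentialDecay(   # assumed scheduler class/signature
        learning_rate=0.1, decay_steps=1000, decay_rate=0.9)
    sgd = fluid.optimizer.SGDOptimizer(learning_rate=lr)

    _, optimizer_dict = fluid.dygraph.load_persistables("./save_dir")
    # Replaces sgd._learning_rate with the pickled LearningRateDecay; a
    # KeyError is raised if optimizer_dict has no entry under sgd._name.
    sgd.load(optimizer_dict)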
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to be saved?