
Commit 022dfed

Add optimizer save and load (#16986)
* save optimizer related vars in dygraph
* test=develop, add optimizer save and load
* test=develop, add optimizer save and load
* test=develop, merge code and add multi-optimizer save and load
* test=develop, fix test_imperative_checkpoint
* test=develop, fix include error
* test=develop, fix include error
* test=develop, renew api spec
* test=develop, refine code
* test=develop, set default value for checkpoint
* test=develop, fix ci error
* test=develop, change API.spec and make api more readable
* test=develop, refine version and time stamp
* test=develop, add example code and refine code
* test=develop, refine doc
* test=develop, change version
1 parent 453a49b commit 022dfed
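
The diffs below change the dygraph checkpoint API in two visible ways: save_persistables now takes the optimizer (or a list of optimizers) next to the parameter dict, and load_persistables returns a (parameters, optimizers) pair. A minimal sketch of the round trip, adapted from the docstring example added in optimizer.py and from the updated checkpoint test; the FC layer, the PiecewiseDecay schedule, and the directory name are illustrative, not part of the commit:

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import FC
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.learning_rate_scheduler import PiecewiseDecay

# Train briefly, then persist parameters and optimizer state together.
with fluid.dygraph.guard():
    fc = FC("fc", 10)
    sgd = SGDOptimizer(
        learning_rate=PiecewiseDecay([100, 200], [0.1, 0.05, 0.01], 0))

    x = to_variable(np.random.rand(4, 8).astype("float32"))
    loss = fluid.layers.reduce_mean(fc(x))
    loss.backward()
    sgd.minimize(loss)
    fc.clear_gradients()

    # New: optimizers are saved alongside the parameter dict; each optimizer's
    # LearningRateDecay object is pickled under <dirname>/optimizers.
    fluid.dygraph.save_persistables(fc.state_dict(), [sgd], "save_dir")

# Restore into freshly constructed objects.
with fluid.dygraph.guard():
    fc_new = FC("fc", 10)
    sgd_new = SGDOptimizer(
        learning_rate=PiecewiseDecay([100, 200], [0.1, 0.05, 0.01], 0))

    # New: load_persistables now returns (parameters, optimizers).
    params, optimizers = fluid.dygraph.load_persistables("save_dir")
    fc_new.load_dict(params)
    sgd_new.load(optimizers)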

File tree

7 files changed: +372 -23 lines changed

paddle/fluid/API.spec

+13
Large diffs are not rendered by default.

python/paddle/fluid/dygraph/base.py

+3

@@ -92,3 +92,6 @@ def to_variable(value, block=None, name=None):
         return py_var
     elif isinstance(value, framework.Variable):
         return value
+    else:
+        raise TypeError(
+            "to_variable only accepts 'ndarray' and 'Variable' as value's input")

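A small sketch of the new error path: a value that is neither a numpy ndarray nor an existing Variable used to fall through and return None; it now raises a TypeError (the list argument below is illustrative):

import paddle.fluid as fluid

with fluid.dygraph.guard():
    try:
        fluid.dygraph.to_variable([1.0, 2.0, 3.0])  # plain lists are not accepted
    except TypeError as e:
        print(e)  # to_variable only accepts 'ndarray' and 'Variable' as value's input
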
python/paddle/fluid/dygraph/checkpoint.py

+61-12

@@ -1,4 +1,4 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,13 +16,18 @@
 
 import os
 import collections
-from .. import core
 from ..framework import Variable, default_main_program
+import pickle
+from . import learning_rate_scheduler
+import warnings
 
 __all__ = ['save_persistables', 'load_persistables']
 
 
-def save_persistables(vardict, dirname, filename=None):
+def save_persistables(model_dict,
+                      optimizer=None,
+                      dirname='save_dir',
+                      filename=None):
     """
     This function filters out all variables in layer.parameters from the
     give `layer` and then trys to load these variables from the folder
@@ -34,12 +39,12 @@ def save_persistables(vardict, dirname, filename=None):
        the file name.
 
     Args:
-        vardict(dict of Parameters): The parameters will
+        model_dict(dict of Parameters): The parameters will
                                      be saved. If it is None, nothing
                                      will be deal.
        dirname(str): The directory path.
        filename(str|None): The file which saved all variables. If variables were
-                            saved in differnet files, set it to None.
+                            saved in different files, set it to None.
            Default: None
 
    Returns:
@@ -71,11 +76,11 @@ def save_persistables(vardict, dirname, filename=None):
            fluid.dygraph.save_persistables(ptb_model.state_dict(), dirname=param_path,
                                            layer=ptb_model)
    """
-    if isinstance(vardict, collections.OrderedDict):
-        _save_var_to_file(vardict, dirname, filename)
+    if isinstance(model_dict, collections.OrderedDict):
+        _save_var_to_file(model_dict, optimizer, dirname, filename)
 
 
-def load_persistables(dirname):
+def load_persistables(dirname='save_dir'):
    """
    This function trys to load persistable variables from the folder
    `dirname` or the file `filename`.
@@ -86,7 +91,8 @@ def load_persistables(dirname):
        the file name.
 
    Args:
-        dirname(str): The directory path.
+        dirname(str): The directory path. default is save_dir
+        optimizer(Optimizer): Optimizer to be saved
 
    Returns:
        dict: The parameter-dict resumed from file
@@ -103,7 +109,7 @@ def load_persistables(dirname):
    return _load_var_from_file(dirname)
 
 
-def _save_var_to_file(stat_dict, file_dir, file_name):
+def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
    save_block = default_main_program().global_block()
    save_var_map = {}
    for var_key, each_var in stat_dict.items():
@@ -117,6 +123,32 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
                    'file_path': os.path.join(file_dir,
                                              os.path.normpath(each_var.name))
                })
+    if isinstance(optimizers, (list, tuple)):
+        optimizers = optimizers
+    else:
+        optimizers = [optimizers]
+    if os.path.exists(os.path.join(file_dir, os.path.normpath("optimizers"))):
+        pass
+    else:
+        os.mkdir(os.path.join(file_dir, os.path.normpath("optimizers")))
+    for optimizer in optimizers:
+        if isinstance(optimizer._learning_rate,
+                      learning_rate_scheduler.LearningRateDecay):
+            try:
+                f = open(
+                    os.path.join(file_dir, "optimizers",
+                                 os.path.normpath(str(optimizer._name))), "wb")
+                pickle.dump(optimizer._learning_rate, f, 2)
+                f.close()
+            except ():
+                raise IOError("Can't load %s",
+                              os.path.join(
+                                  file_dir, "optimizers",
+                                  os.path.normpath(str(optimizer._name))))
+        else:
+            warnings.warn(
+                "Optimizer not saved, Only optimizer with 'LearningRateDecay' under DyGraph mode need to be saved"
+            )
 
    if file_name is not None:
        save_var_list = []
@@ -138,6 +170,8 @@ def walk_filename(file_dir):
        var_name_list = []
        if os.path.exists(base_path):
            for dirpath, dirnames, filenames in os.walk(base_path):
+                if "optimizers" in dirpath:
+                    continue
                pt = dirpath.replace(base_path, "", 1)
                if pt.startswith("/") or pt.startswith("\\"):
                    pt = pt[1:]
@@ -152,6 +186,7 @@ def walk_filename(file_dir):
 
    load_block = default_main_program().global_block()
    load_var_map = {}
+    load_optimizer_map = {}
    file_var_list = walk_filename(file_dir)
    for var_name in file_var_list:
        new_var = Variable(block=load_block, name=var_name)
@@ -165,8 +200,22 @@ def walk_filename(file_dir):
            })
 
        load_var_map[new_var.name] = new_var
-
-    return load_var_map
+    opt_path = os.path.join(file_dir, "optimizers")
+    for _, _, optimizers in os.walk(opt_path):
+        for optimizer in optimizers:
+            try:
+                f = open(os.path.join(opt_path, optimizer), "rb")
+                load_optimizer_map[optimizer] = pickle.load(f)
+                f.close()
+            except IOError:
+                raise IOError("Can't load %s",
+                              os.path.join(
+                                  file_dir, "optimizers",
+                                  os.path.normpath(str(optimizer._name))))
+    if len(load_optimizer_map) == 0:
+        warnings.warn("No optimizer loaded")
 
+    return load_var_map, load_optimizer_map
 
 
 def _clone_var_in_block_(block, var):
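
To make the new on-disk format concrete, a hedged sketch of what _save_var_to_file leaves behind when filename is None; the exact file names below are illustrative, since parameter files take their names from the state_dict keys and each optimizer pickle takes the optimizer's generated _name:

import os

# Assuming an earlier fluid.dygraph.save_persistables(model.state_dict(), [sgd], "save_dir"):
for dirpath, dirnames, filenames in os.walk("save_dir"):
    print(dirpath, filenames)
# Illustrative layout:
#   save_dir/...           one file per persistable variable (nested by name scope)
#   save_dir/optimizers/   one pickled LearningRateDecay per optimizer, e.g. "SGDOptimizer_0"

On the way back, _load_var_from_file skips the optimizers folder while collecting variable files, unpickles every file inside it into load_optimizer_map, warns if that map stays empty, and returns the (load_var_map, load_optimizer_map) pair that load_persistables hands to the caller.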

python/paddle/fluid/dygraph/learning_rate_scheduler.py

+2-2

@@ -63,13 +63,13 @@ def __init__(self, boundaries, values, begin, step=1, dtype='float32'):
 
        self.vars = []
        for value in values:
-            self.vars.append(self.create_lr_var(value))
+            self.vars.append(value)
 
    def step(self):
        for i in range(len(self.boundaries)):
            if self.step_num < self.boundaries[i]:
                return self.vars[i]
-        return self.vars[len(self.values) - 1]
+        return self.create_lr_var(self.vars[len(self.values) - 1])
 
 
 class NaturalExpDecay(LearningRateDecay):
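
This two-line change is what lets the checkpoint code above pickle the scheduler: self.vars now keeps the raw Python values, and a framework Variable is only created inside step(). A hedged sketch, assuming the modified class is PiecewiseDecay (its constructor signature matches the hunk header):

import pickle
import paddle.fluid as fluid
from paddle.fluid.dygraph.learning_rate_scheduler import PiecewiseDecay

with fluid.dygraph.guard():
    decay = PiecewiseDecay(boundaries=[100, 200], values=[0.1, 0.05, 0.01], begin=0)
    blob = pickle.dumps(decay, 2)   # protocol 2, as used by _save_var_to_file
    restored = pickle.loads(blob)
    print(restored.vars)            # [0.1, 0.05, 0.01] -- plain floats, not Variables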

python/paddle/fluid/optimizer.py

+95-7

@@ -1,4 +1,4 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,25 +16,25 @@
 
 import numpy as np
 from collections import defaultdict
-from functools import reduce
 
-from paddle.fluid import core
 from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program
-from paddle.fluid.layers import tensor
 
 from . import framework
 from . import layers
 from . import unique_name
 from .backward import append_backward
 from .clip import append_gradient_clip_ops, error_clip_callback
-from .dygraph import base as imperative_base
-from .dygraph.learning_rate_scheduler import LearningRateDecay
 from .framework import program_guard
 from .initializer import Constant
 from .layer_helper import LayerHelper
 from .layers import ops
 from .regularizer import append_regularization_ops
+from .dygraph import base as imperative_base
+from .dygraph.learning_rate_scheduler import LearningRateDecay
+from paddle.fluid import core
+from paddle.fluid.layers import tensor
+from functools import reduce
 from .wrapped_decorator import signature_safe_contextmanager
 
 __all__ = [
@@ -63,14 +63,18 @@ def __init__(self, learning_rate, regularization=None, name=None):
                raise TypeError(
                    "learning rate should be float or LearningRateDecay, got %s here"
                    % type(learning_rate))
+            if name is not None:
+                self._name = unique_name.generate(name)
+            else:
+                self._name = unique_name.generate(self.__class__.__name__)
        else:
            if not isinstance(learning_rate, float) and \
                    not isinstance(learning_rate, framework.Variable):
                raise TypeError(
                    "learning rate should be float or Variable, got %s here" %
                    type(learning_rate))
+            self._name = name
 
-        self._name = name
        self.regularization = regularization
        self._learning_rate = learning_rate
        # the learning rate type should be inferenced from loss
@@ -89,6 +93,90 @@ def __init__(self, learning_rate, regularization=None, name=None):
        self.helper = None
        self._opti_name_list = []
 
+    def load(self, stat_dict):
+        """
+        load optimizer with learning rate decay in dygraph mode
+        :return: None
+
+        Args:
+            stat_dict: the dict load by load_persistable method
+
+        Examples:
+
+        .. code-block:: python
+
+            from __future__ import print_function
+            import numpy as np
+            import paddle
+            import paddle.fluid as fluid
+            from paddle.fluid.optimizer import SGDOptimizer
+            from paddle.fluid.dygraph.nn import FC
+            from paddle.fluid.dygraph.base import to_variable
+
+            class MLP(fluid.Layer):
+                def __init__(self, name_scope):
+                    super(MLP, self).__init__(name_scope)
+
+                    self._fc1 = FC(self.full_name(), 10)
+                    self._fc2 = FC(self.full_name(), 10)
+
+                def forward(self, inputs):
+                    y = self._fc1(inputs)
+                    y = self._fc2(y)
+                    return y
+
+            with fluid.dygraph.guard():
+                mlp = MLP('mlp')
+                optimizer2 = SGDOptimizer(
+                    learning_rate=fluid.layers.natural_exp_decay(
+                        learning_rate=0.1,
+                        decay_steps=10000,
+                        decay_rate=0.5,
+                        staircase=True))
+
+                train_reader = paddle.batch(
+                    paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
+
+                for batch_id, data in enumerate(train_reader()):
+                    dy_x_data = np.array(
+                        [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
+
+                    y_data = np.array([x[1] for x in data]).astype('int64').reshape(
+                        128, 1)
+
+                    img = to_variable(dy_x_data)
+                    label = to_variable(y_data)
+                    label._stop_gradient = True
+                    cost = mlp(img)
+                    avg_loss = fluid.layers.reduce_mean(cost)
+                    avg_loss.backward()
+                    optimizer.minimize(avg_loss)
+                    mlp.clear_gradients()
+                    fluid.dygraph.save_persistables(
+                        mlp.state_dict(), [optimizer, optimizer2], "save_dir_2")
+                    if batch_id == 2:
+                        break
+
+            with fluid.dygraph.guard():
+                mlp_load = MLP('mlp')
+                optimizer_load2 = SGDOptimizer(
+                    learning_rate=fluid.layers.natural_exp_decay(
+                        learning_rate=0.1,
+                        decay_steps=10000,
+                        decay_rate=0.5,
+                        staircase=True))
+                parameters, optimizers = fluid.dygraph.load_persistables(
+                    "save_dir_2")
+                mlp_load.load_dict(parameters)
+                optimizer_load2.load(optimizers)
+                self.assertTrue(optimizer2._learning_rate.__dict__ == optimizer_load2._learning_rate.__dict__)
+
+        """
+        if framework.in_dygraph_mode():
+            self._learning_rate = stat_dict[self._name]
+        else:
+            raise TypeError("load can only be used under DyGraph mode")
+
    def get_opti_var_name_list(self):
        return self._opti_name_list
 
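
Outside the long docstring, the contract of load() is small: it replaces self._learning_rate with the pickled scheduler stored under this optimizer's generated _name, and it refuses to run outside dygraph mode. A minimal sketch of that error path (the SGDOptimizer and float learning rate are illustrative):

import paddle.fluid as fluid
from paddle.fluid.optimizer import SGDOptimizer

sgd = SGDOptimizer(learning_rate=0.1)  # constructed without a dygraph guard
try:
    sgd.load({})                       # the dict contents are never read here
except TypeError as e:
    print(e)                           # load can only be used under DyGraph mode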

python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py

+2-2

@@ -144,14 +144,14 @@ def test_save_load_persistables(self):
 
                avg_loss.backward()
                sgd.minimize(avg_loss)
-                fluid.dygraph.save_persistables(mnist.state_dict(),
+                fluid.dygraph.save_persistables(mnist.state_dict(), [sgd],
                                                "save_dir")
                mnist.clear_gradients()
 
                for param in mnist.parameters():
                    dy_param_init_value[param.name] = param.numpy()
 
-                restore = fluid.dygraph.load_persistables("save_dir")
+                restore, _ = fluid.dygraph.load_persistables("save_dir")
                mnist.load_dict(restore)
 
                self.assertEqual(len(dy_param_init_value), len(restore))

0 commit comments
