Skip to content

Commit d72db90

Browse files
committed
Fix dygraph save load problem
test=release/1.4
1 parent f2ae7e3 commit d72db90

File tree

3 files changed

+42
-46
lines changed

3 files changed

+42
-46
lines changed

python/paddle/fluid/dygraph/checkpoint.py

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def save_persistables(vardict, dirname, filename=None):
7575
_save_var_to_file(vardict, dirname, filename)
7676

7777

78-
def load_persistables(vardict, dirname, filename=None):
78+
def load_persistables(dirname):
7979
"""
8080
This function trys to load persistable variables from the folder
8181
`dirname` or the file `filename`.
@@ -86,11 +86,7 @@ def load_persistables(vardict, dirname, filename=None):
8686
the file name.
8787
8888
Args:
89-
vardict(dict of Parameters): The parameters will be loaded.
9089
dirname(str): The directory path.
91-
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
92-
saved in differnet files, set it to None.
93-
Default: None
9490
9591
Returns:
9692
dict: The parameter-dict resumed from file
@@ -104,10 +100,7 @@ def load_persistables(vardict, dirname, filename=None):
104100
param_1 = param_dict['PtbModel_0.w_1']
105101
106102
"""
107-
if isinstance(vardict, collections.OrderedDict):
108-
return _load_var_from_file(vardict, dirname, filename)
109-
110-
return {}
103+
return _load_var_from_file(dirname)
111104

112105

113106
def _save_var_to_file(stat_dict, file_dir, file_name):
@@ -139,41 +132,39 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
139132
})
140133

141134

142-
def _load_var_from_file(stat_dict, file_dir, file_name):
135+
def _load_var_from_file(file_dir):
136+
def walk_filename(file_dir):
137+
base_path = os.path.join(file_dir)
138+
var_name_list = []
139+
if os.path.exists(base_path):
140+
for dirpath, dirnames, filenames in os.walk(base_path):
141+
pt = dirpath.replace(base_path, "", 1)
142+
if pt.startswith("/") or pt.startswith("\\"):
143+
pt = pt[1:]
144+
for fth_name in filenames:
145+
if fth_name[0] != '.':
146+
name_path = os.path.join(pt, fth_name)
147+
if "\\" in name_path:
148+
name_path = name_path.replace("\\", "/")
149+
var_name_list.append(name_path)
150+
151+
return var_name_list
152+
143153
load_block = default_main_program().global_block()
144154
load_var_map = {}
145-
146-
for var_key, each_var in stat_dict.items():
147-
assert isinstance(each_var, Variable)
148-
if each_var.type == core.VarDesc.VarType.RAW:
149-
continue
150-
new_var = _clone_var_in_block_(load_block, each_var)
151-
if file_name is None:
152-
load_block.append_op(
153-
type='load',
154-
inputs={},
155-
outputs={'Out': [new_var]},
156-
attrs={
157-
'file_path': os.path.join(file_dir,
158-
os.path.normpath(each_var.name))
159-
})
160-
161-
load_var_map[new_var.name] = new_var
162-
163-
if file_name is not None:
164-
load_var_list = []
165-
for name in sorted(load_var_map.keys()):
166-
load_var_list.append(load_var_map[name])
167-
155+
file_var_list = walk_filename(file_dir)
156+
for var_name in file_var_list:
157+
new_var = Variable(block=load_block, name=var_name)
168158
load_block.append_op(
169-
type='load_combine',
159+
type='load',
170160
inputs={},
171-
outputs={"Out": load_var_list},
161+
outputs={'Out': [new_var]},
172162
attrs={
173-
'file_path': os.path.join(file_dir, os.path.normpath(file_name))
163+
'file_path': os.path.join(file_dir,
164+
os.path.normpath(new_var.name))
174165
})
175-
for res_var in load_var_list:
176-
load_var_map[res_var.name] = res_var
166+
167+
load_var_map[new_var.name] = new_var
177168

178169
return load_var_map
179170

python/paddle/fluid/dygraph/layers.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(self, name_scope, dtype=core.VarDesc.VarType.FP32):
4545
self._dtype = dtype
4646
self._parameters = collections.OrderedDict()
4747
self._sub_layers = collections.OrderedDict()
48+
self._loaddict_holder = collections.OrderedDict()
4849

4950
self._helper = LayerObjectHelper(self._full_name)
5051

@@ -193,6 +194,9 @@ def add_parameter(self, name, parameter):
193194
"""
194195
assert isinstance(parameter, framework.Parameter)
195196
self._parameters[name] = parameter
197+
if parameter.name in self._loaddict_holder:
198+
self._parameters[name] = self._loaddict_holder[parameter.name]
199+
parameter = self._loaddict_holder[parameter.name]
196200
return parameter
197201

198202
def __getattr__(self, name):
@@ -207,7 +211,10 @@ def __setattr__(self, name, value):
207211
if params is None:
208212
raise ValueError(
209213
"super(YourLayer, self).__init__() should be called first")
210-
params[name] = value
214+
if value.name in self._loaddict_holder:
215+
params[name] = self._loaddict_holder[value.name]
216+
else:
217+
params[name] = value
211218
elif isinstance(value, core.Layer):
212219
layers = self.__dict__.get('_sub_layers', None)
213220
if layers is None:
@@ -244,6 +251,7 @@ def state_dict(self, destination=None, prefix='', include_sublayers=True):
244251
return destination
245252

246253
def load_dict(self, stat_dict, include_sublayers=True):
254+
self._loaddict_holder = stat_dict
247255
for name, item in self.__dict__.get('_parameters', None).items():
248256
if item.name in stat_dict:
249257
var = item._ivar.value()

python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,11 @@ def test_save_load_persistables(self):
142142
for param in mnist.parameters():
143143
dy_param_init_value[param.name] = param.numpy()
144144

145-
mnist.load_dict(
146-
fluid.dygraph.load_persistables(mnist.state_dict(),
147-
"save_dir"))
148-
149-
restore = mnist.parameters()
145+
restore = fluid.dygraph.load_persistables("save_dir")
146+
mnist.load_dict(restore)
150147

151148
self.assertEqual(len(dy_param_init_value), len(restore))
152-
for value in restore:
149+
for ky, value in restore.items():
153150
self.assertTrue(
154151
np.allclose(value.numpy(), dy_param_init_value[
155152
value.name]))
@@ -158,7 +155,7 @@ def test_save_load_persistables(self):
158155

159156
step += 1
160157

161-
if step > 20:
158+
if step > 10:
162159
break
163160

164161

0 commit comments

Comments
 (0)