Commit 975aeee

Merge pull request #16904 from junjun315/hot-fix-saveload

fix save and load bug, test=release/1.4

2 parents 24faf9e + d6d18ad

4 files changed: +35, -15 lines changed

paddle/fluid/imperative/layer.h (+5, -1)
@@ -464,7 +464,11 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext

   void SetType(const std::string& name,
                framework::proto::VarType::Type type) override {
-    var_set_[name]->SetType(type);
+    if (name == "kLookupTablePath") {
+      VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
+    } else {
+      var_set_[name]->SetType(type);
+    }
   }

   framework::proto::VarType::Type GetDataType(

python/paddle/fluid/dygraph/checkpoint.py (+17, -7)
@@ -113,14 +113,17 @@ def load_persistables(vardict, dirname, filename=None):
 def _save_var_to_file(stat_dict, file_dir, file_name):
     save_block = default_main_program().global_block()
     save_var_map = {}
-    for each_var in stat_dict.items():
+    for var_key, each_var in stat_dict.items():
         save_var_map[each_var.name] = each_var
         if file_name is None:
             save_block.append_op(
                 type='save',
                 inputs={'X': [each_var]},
                 outputs={},
-                attrs={'file_path': os.path.join(file_dir, each_var.name)})
+                attrs={
+                    'file_path': os.path.join(file_dir,
+                                              os.path.normpath(each_var.name))
+                })

     if file_name is not None:
         save_var_list = []
@@ -131,14 +134,16 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
             type='save_combine',
             inputs={'X': save_var_list},
             outputs={},
-            attrs={'file_path': os.path.join(file_dir, file_name)})
+            attrs={
+                'file_path': os.path.join(file_dir, os.path.normpath(file_name))
+            })


 def _load_var_from_file(stat_dict, file_dir, file_name):
     load_block = default_main_program().global_block()
     load_var_map = {}

-    for each_var in stat_dict.items():
+    for var_key, each_var in stat_dict.items():
         assert isinstance(each_var, Variable)
         if each_var.type == core.VarDesc.VarType.RAW:
             continue
@@ -148,7 +153,10 @@ def _load_var_from_file(stat_dict, file_dir, file_name):
                 type='load',
                 inputs={},
                 outputs={'Out': [new_var]},
-                attrs={'file_path': os.path.join(file_dir, each_var.name)})
+                attrs={
+                    'file_path': os.path.join(file_dir,
+                                              os.path.normpath(each_var.name))
+                })

         load_var_map[new_var.name] = new_var

@@ -161,7 +169,9 @@ def _load_var_from_file(stat_dict, file_dir, file_name):
             type='load_combine',
             inputs={},
             outputs={"Out": load_var_list},
-            attrs={'file_path': os.path.join(file_dir, file_name)})
+            attrs={
+                'file_path': os.path.join(file_dir, os.path.normpath(file_name))
+            })
         for res_var in load_var_list:
             load_var_map[res_var.name] = res_var

@@ -175,5 +185,5 @@ def _clone_var_in_block_(block, var):
         shape=var.shape,
         dtype=var.dtype,
         type=var.type,
-        lod_level=var.lod_level,
+        lod_level=0,
         persistable=True)

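The checkpoint.py change above normalizes every path component before it is handed to the save/load ops. A minimal sketch of the effect, using a hypothetical variable name and save directory rather than anything taken from the commit (output shown for POSIX):

    import os

    # A name carrying a redundant "./" segment; normpath collapses it so the
    # 'save' and 'load' ops end up resolving the same file on disk.
    var_name = "./fc_0.w_0"
    print(os.path.join("save_dir", var_name))                     # save_dir/./fc_0.w_0
    print(os.path.join("save_dir", os.path.normpath(var_name)))   # save_dir/fc_0.w_0
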
python/paddle/fluid/dygraph/layers.py (+4, -1)
@@ -246,7 +246,10 @@ def state_dict(self, destination=None, prefix='', include_sublayers=True):
     def load_dict(self, stat_dict, include_sublayers=True):
         for name, item in self.__dict__.get('_parameters', None).items():
             if item.name in stat_dict:
-                self.__setattr__(name, stat_dict[item.name])
+                var = item._ivar.value()
+                tensor = var.get_tensor()
+                tensor.set(stat_dict[item.name].numpy(),
+                           framework._current_expected_place())

         if include_sublayers:
             for layer_name, layer_item in self._sub_layers.items():

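The load_dict change above stops rebinding the attribute to the stored object and instead copies the loaded values into the tensor the existing parameter already owns. A toy, numpy-only illustration of that design choice (plain Python, not the Paddle API):

    import numpy as np

    class Param(object):
        def __init__(self, value):
            self.value = value  # stands in for the parameter's underlying tensor

    params = {"w": Param(np.zeros(3))}
    loaded = {"w": np.ones(3)}

    # In-place copy: anything already holding a reference to params["w"] keeps
    # seeing the new values, which is what writing into the parameter's tensor
    # achieves in load_dict.
    params["w"].value[...] = loaded["w"]
    print(params["w"].value)  # [1. 1. 1.]
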
python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py (+9, -6)
@@ -99,7 +99,7 @@ def forward(self, inputs):


 class TestDygraphCheckpoint(unittest.TestCase):
-    def save_load_persistables(self):
+    def test_save_load_persistables(self):
         seed = 90
         epoch_num = 1

@@ -135,23 +135,26 @@ def save_load_persistables(self):

                     avg_loss.backward()
                     sgd.minimize(avg_loss)
-                    fluid.dygraph.save_persistables(mnist, "save_dir")
+                    fluid.dygraph.save_persistables(mnist.state_dict(),
+                                                    "save_dir")
                     mnist.clear_gradients()

                     for param in mnist.parameters():
                         dy_param_init_value[param.name] = param.numpy()

                     mnist.load_dict(
-                        fluid.dygraph.load_persistables(mnist, "save_dir"))
+                        fluid.dygraph.load_persistables(mnist.state_dict(),
+                                                        "save_dir"))

                     restore = mnist.parameters()

                     self.assertEqual(len(dy_param_init_value), len(restore))
                     for value in restore:
                         self.assertTrue(
-                            np.allclose(value, dy_param_init_value[value.name]))
-                        self.assertTrue(np.isfinite(value.all()))
-                        self.assertFalse(np.isnan(value.any()))
+                            np.allclose(value.numpy(), dy_param_init_value[
+                                value.name]))
+                        self.assertTrue(np.isfinite(value.numpy().all()))
+                        self.assertFalse(np.isnan(value.numpy().any()))

                     step += 1

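For reference, a minimal save/restore round trip with the API as exercised by the updated test. The MNIST layer and the "save_dir" directory come from the test above; the constructor argument is an assumption, and the surrounding training loop is omitted:

    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        mnist = MNIST("mnist")  # MNIST is the fluid.dygraph.Layer defined in the test

        # Save all persistable parameters of the layer to ./save_dir ...
        fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")

        # ... and copy them back into the layer's existing parameters.
        mnist.load_dict(fluid.dygraph.load_persistables(mnist.state_dict(), "save_dir"))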