1
- # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
1
+ # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
4
# you may not use this file except in compliance with the License.
16
16
17
17
import os
18
18
import collections
19
- from .. import core
20
19
from ..framework import Variable , default_main_program
20
+ import pickle
21
+ from . import learning_rate_scheduler
22
+ import warnings
21
23
22
24
__all__ = ['save_persistables' , 'load_persistables' ]
23
25
24
26
25
- def save_persistables (vardict , dirname , filename = None ):
27
+ def save_persistables (model_dict ,
28
+ optimizer = None ,
29
+ dirname = 'save_dir' ,
30
+ filename = None ):
26
31
"""
27
32
This function filters out all variables in layer.parameters from the
28
33
give `layer` and then trys to load these variables from the folder
@@ -34,12 +39,12 @@ def save_persistables(vardict, dirname, filename=None):
34
39
the file name.
35
40
36
41
Args:
37
- vardict (dict of Parameters): The parameters will
42
+ model_dict (dict of Parameters): The parameters will
38
43
be saved. If it is None, nothing
39
44
will be deal.
40
45
dirname(str): The directory path.
41
46
filename(str|None): The file which saved all variables. If variables were
42
- saved in differnet files, set it to None.
47
+ saved in different files, set it to None.
43
48
Default: None
44
49
45
50
Returns:
@@ -71,11 +76,11 @@ def save_persistables(vardict, dirname, filename=None):
71
76
fluid.dygraph.save_persistables(ptb_model.state_dict(), dirname=param_path,
72
77
layer=ptb_model)
73
78
"""
74
- if isinstance (vardict , collections .OrderedDict ):
75
- _save_var_to_file (vardict , dirname , filename )
79
+ if isinstance (model_dict , collections .OrderedDict ):
80
+ _save_var_to_file (model_dict , optimizer , dirname , filename )
76
81
77
82
78
- def load_persistables (dirname ):
83
+ def load_persistables (dirname = 'save_dir' ):
79
84
"""
80
85
This function trys to load persistable variables from the folder
81
86
`dirname` or the file `filename`.
@@ -86,7 +91,8 @@ def load_persistables(dirname):
86
91
the file name.
87
92
88
93
Args:
89
- dirname(str): The directory path.
94
+ dirname(str): The directory path. default is save_dir
95
+ optimizer(Optimizer): Optimizer to be saved
90
96
91
97
Returns:
92
98
dict: The parameter-dict resumed from file
@@ -103,7 +109,7 @@ def load_persistables(dirname):
103
109
return _load_var_from_file (dirname )
104
110
105
111
106
- def _save_var_to_file (stat_dict , file_dir , file_name ):
112
+ def _save_var_to_file (stat_dict , optimizers , file_dir , file_name ):
107
113
save_block = default_main_program ().global_block ()
108
114
save_var_map = {}
109
115
for var_key , each_var in stat_dict .items ():
@@ -117,6 +123,32 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
117
123
'file_path' : os .path .join (file_dir ,
118
124
os .path .normpath (each_var .name ))
119
125
})
126
+ if isinstance (optimizers , (list , tuple )):
127
+ optimizers = optimizers
128
+ else :
129
+ optimizers = [optimizers ]
130
+ if os .path .exists (os .path .join (file_dir , os .path .normpath ("optimizers" ))):
131
+ pass
132
+ else :
133
+ os .mkdir (os .path .join (file_dir , os .path .normpath ("optimizers" )))
134
+ for optimizer in optimizers :
135
+ if isinstance (optimizer ._learning_rate ,
136
+ learning_rate_scheduler .LearningRateDecay ):
137
+ try :
138
+ f = open (
139
+ os .path .join (file_dir , "optimizers" ,
140
+ os .path .normpath (str (optimizer ._name ))), "wb" )
141
+ pickle .dump (optimizer ._learning_rate , f , 2 )
142
+ f .close ()
143
+ except ():
144
+ raise IOError ("Can't load %s" ,
145
+ os .path .join (
146
+ file_dir , "optimizers" ,
147
+ os .path .normpath (str (optimizer ._name ))))
148
+ else :
149
+ warnings .warn (
150
+ "Optimizer not saved, Only optimizer with 'LearningRateDecay' under DyGraph mode need to be saved"
151
+ )
120
152
121
153
if file_name is not None :
122
154
save_var_list = []
@@ -138,6 +170,8 @@ def walk_filename(file_dir):
138
170
var_name_list = []
139
171
if os .path .exists (base_path ):
140
172
for dirpath , dirnames , filenames in os .walk (base_path ):
173
+ if "optimizers" in dirpath :
174
+ continue
141
175
pt = dirpath .replace (base_path , "" , 1 )
142
176
if pt .startswith ("/" ) or pt .startswith ("\\ " ):
143
177
pt = pt [1 :]
@@ -152,6 +186,7 @@ def walk_filename(file_dir):
152
186
153
187
load_block = default_main_program ().global_block ()
154
188
load_var_map = {}
189
+ load_optimizer_map = {}
155
190
file_var_list = walk_filename (file_dir )
156
191
for var_name in file_var_list :
157
192
new_var = Variable (block = load_block , name = var_name )
@@ -165,8 +200,22 @@ def walk_filename(file_dir):
165
200
})
166
201
167
202
load_var_map [new_var .name ] = new_var
168
-
169
- return load_var_map
203
+ opt_path = os .path .join (file_dir , "optimizers" )
204
+ for _ , _ , optimizers in os .walk (opt_path ):
205
+ for optimizer in optimizers :
206
+ try :
207
+ f = open (os .path .join (opt_path , optimizer ), "rb" )
208
+ load_optimizer_map [optimizer ] = pickle .load (f )
209
+ f .close ()
210
+ except IOError :
211
+ raise IOError ("Can't load %s" ,
212
+ os .path .join (
213
+ file_dir , "optimizers" ,
214
+ os .path .normpath (str (optimizer ._name ))))
215
+ if len (load_optimizer_map ) == 0 :
216
+ warnings .warn ("No optimizer loaded" )
217
+
218
+ return load_var_map , load_optimizer_map
170
219
171
220
172
221
def _clone_var_in_block_ (block , var ):
0 commit comments