@@ -75,7 +75,7 @@ def save_persistables(vardict, dirname, filename=None):
75
75
_save_var_to_file (vardict , dirname , filename )
76
76
77
77
78
- def load_persistables (vardict , dirname , filename = None ):
78
+ def load_persistables (dirname ):
79
79
"""
80
80
This function trys to load persistable variables from the folder
81
81
`dirname` or the file `filename`.
@@ -86,11 +86,7 @@ def load_persistables(vardict, dirname, filename=None):
86
86
the file name.
87
87
88
88
Args:
89
- vardict(dict of Parameters): The parameters will be loaded.
90
89
dirname(str): The directory path.
91
- filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
92
- saved in differnet files, set it to None.
93
- Default: None
94
90
95
91
Returns:
96
92
dict: The parameter-dict resumed from file
@@ -104,10 +100,7 @@ def load_persistables(vardict, dirname, filename=None):
104
100
param_1 = param_dict['PtbModel_0.w_1']
105
101
106
102
"""
107
- if isinstance (vardict , collections .OrderedDict ):
108
- return _load_var_from_file (vardict , dirname , filename )
109
-
110
- return {}
103
+ return _load_var_from_file (dirname )
111
104
112
105
113
106
def _save_var_to_file (stat_dict , file_dir , file_name ):
@@ -139,41 +132,39 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
139
132
})
140
133
141
134
142
- def _load_var_from_file (stat_dict , file_dir , file_name ):
135
+ def _load_var_from_file (file_dir ):
136
+ def walk_filename (file_dir ):
137
+ base_path = os .path .join (file_dir )
138
+ var_name_list = []
139
+ if os .path .exists (base_path ):
140
+ for dirpath , dirnames , filenames in os .walk (base_path ):
141
+ pt = dirpath .replace (base_path , "" , 1 )
142
+ if pt .startswith ("/" ) or pt .startswith ("\\ " ):
143
+ pt = pt [1 :]
144
+ for fth_name in filenames :
145
+ if fth_name [0 ] != '.' :
146
+ name_path = os .path .join (pt , fth_name )
147
+ if "\\ " in name_path :
148
+ name_path = name_path .replace ("\\ " , "/" )
149
+ var_name_list .append (name_path )
150
+
151
+ return var_name_list
152
+
143
153
load_block = default_main_program ().global_block ()
144
154
load_var_map = {}
145
-
146
- for var_key , each_var in stat_dict .items ():
147
- assert isinstance (each_var , Variable )
148
- if each_var .type == core .VarDesc .VarType .RAW :
149
- continue
150
- new_var = _clone_var_in_block_ (load_block , each_var )
151
- if file_name is None :
152
- load_block .append_op (
153
- type = 'load' ,
154
- inputs = {},
155
- outputs = {'Out' : [new_var ]},
156
- attrs = {
157
- 'file_path' : os .path .join (file_dir ,
158
- os .path .normpath (each_var .name ))
159
- })
160
-
161
- load_var_map [new_var .name ] = new_var
162
-
163
- if file_name is not None :
164
- load_var_list = []
165
- for name in sorted (load_var_map .keys ()):
166
- load_var_list .append (load_var_map [name ])
167
-
155
+ file_var_list = walk_filename (file_dir )
156
+ for var_name in file_var_list :
157
+ new_var = Variable (block = load_block , name = var_name )
168
158
load_block .append_op (
169
- type = 'load_combine ' ,
159
+ type = 'load ' ,
170
160
inputs = {},
171
- outputs = {" Out" : load_var_list },
161
+ outputs = {' Out' : [ new_var ] },
172
162
attrs = {
173
- 'file_path' : os .path .join (file_dir , os .path .normpath (file_name ))
163
+ 'file_path' : os .path .join (file_dir ,
164
+ os .path .normpath (new_var .name ))
174
165
})
175
- for res_var in load_var_list :
176
- load_var_map [res_var .name ] = res_var
166
+
167
+ load_var_map [new_var .name ] = new_var
177
168
178
169
return load_var_map
179
170
0 commit comments