
Commit c1d87fc

fix save/load in fleet (#17675)

* fix save/load in Fleet
* add UT framework of Fleet

1 parent 1810bfb commit c1d87fc


15 files changed (+664 lines, -37 lines)


paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ paddle.fluid.io.save_persistables (ArgSpec(args=['executor', 'dirname', 'main_pr
 paddle.fluid.io.load_vars (ArgSpec(args=['executor', 'dirname', 'main_program', 'vars', 'predicate', 'filename'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', '1bb9454cf09d71f190bb51550c5a3ac9'))
 paddle.fluid.io.load_params (ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)), ('document', '944291120d37bdb037a689d2c86d0a6e'))
 paddle.fluid.io.load_persistables (ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)), ('document', '28df5bfe26ca7a077f91156abb0fe6d2'))
-paddle.fluid.io.save_inference_model (ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True)), ('document', '89539e459eb959145f15c9c3e38fa97c'))
+paddle.fluid.io.save_inference_model (ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment', 'program_only'], varargs=None, keywords=None, defaults=(None, None, None, True, False)), ('document', 'fc82bfd137a9b1ab8ebd1651bd35b6e5'))
 paddle.fluid.io.load_inference_model (ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '2f54d7c206b62f8c10f4f9d78c731cfd'))
 paddle.fluid.io.PyReader.__init__ (ArgSpec(args=['self', 'feed_list', 'capacity', 'use_double_buffer', 'iterable', 'return_list'], varargs=None, keywords=None, defaults=(None, None, True, True, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.io.PyReader.decorate_batch_generator (ArgSpec(args=['self', 'reader', 'places'], varargs=None, keywords=None, defaults=(None,)), ('document', '4a072de39998ee4e0de33fcec11325a6'))

paddle/fluid/operators/distributed/request_handler_impl.cc

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
   } else {
     if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
       if (enable_dc_asgd_) {
-        // NOTE: the format is determined by distributed_transpiler.py
+        // NOTE: the format is determined by distribute_transpiler.py
         std::string param_bak_name =
             string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
         VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;

python/paddle/fluid/incubate/fleet/base/fleet_base.py

Lines changed: 6 additions & 7 deletions
@@ -15,23 +15,22 @@
 from __future__ import print_function

 import abc
-from enum import Enum

 import paddle.fluid as fluid
 from paddle.fluid.executor import Executor
 from paddle.fluid.optimizer import SGD

-from role_maker import MPISymetricRoleMaker
-from role_maker import RoleMakerBase
-from role_maker import UserDefinedRoleMaker
+from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
+from paddle.fluid.incubate.fleet.base.role_maker import RoleMakerBase
+from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedRoleMaker


-class Mode(Enum):
+class Mode:
     """
     There are various mode for fleet, each of them is designed for different model.
     """
-    TRANSPILER = 1,
-    PSLIB = 2,
+    TRANSPILER = 1
+    PSLIB = 2
     COLLECTIVE = 3

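The dropped trailing commas are the substantive part of this hunk: once Mode is a plain class rather than an Enum, `TRANSPILER = 1,` would bind the attribute to the tuple `(1,)` instead of the integer `1`. A quick hypothetical illustration of that pitfall (not code from this commit):

# Hypothetical illustration of the trailing-comma pitfall; BadMode is not in the commit.
class BadMode:
    TRANSPILER = 1,   # trailing comma makes this the tuple (1,)
    PSLIB = 2

class Mode:
    TRANSPILER = 1
    PSLIB = 2
    COLLECTIVE = 3

assert BadMode.TRANSPILER == (1,) and BadMode.TRANSPILER != 1
assert Mode.TRANSPILER == 1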

python/paddle/fluid/incubate/fleet/base/role_maker.py

Lines changed: 3 additions & 4 deletions
@@ -13,16 +13,15 @@
 # limitations under the License.

 from __future__ import print_function
-from enum import Enum

 __all__ = [
     'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
     'UserDefinedCollectiveRoleMaker'
 ]


-class Role(Enum):
-    WORKER = 1,
+class Role:
+    WORKER = 1
     SERVER = 2


@@ -313,7 +312,7 @@ def __init__(self,
             raise ValueError("current_id must be gather or equal 0")
         self._current_id = current_id

-        if not isinstance(role, Role):
+        if role != Role.WORKER and role != Role.SERVER:
             raise TypeError("role must be as Role")
         else:
             self._role = role
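
With Role reduced to a plain class, the constructor can no longer rely on isinstance, hence the explicit comparison against Role.WORKER and Role.SERVER above. A minimal sketch of the call this check guards; only current_id and role appear in this hunk, so the worker_num and server_endpoints keywords below are assumptions:

from paddle.fluid.incubate.fleet.base.role_maker import Role, UserDefinedRoleMaker

# describe this node's place in a two-worker, one-pserver job (a sketch)
role = UserDefinedRoleMaker(
    current_id=0,                            # index of this node within its role
    role=Role.WORKER,                        # Role.WORKER == 1, Role.SERVER == 2
    worker_num=2,                            # assumed keyword, see lead-in
    server_endpoints=["127.0.0.1:36011"])    # assumed keyword, see lead-in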

python/paddle/fluid/incubate/fleet/collective/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -17,9 +17,9 @@
 import paddle.fluid.io as io
 import paddle.fluid.transpiler.distribute_transpiler as dist_transpiler

-from ..base.fleet_base import Fleet
-from ..base.fleet_base import Mode
-from ..base.fleet_base import DistributedOptimizer
+from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
+from paddle.fluid.incubate.fleet.base.fleet_base import Mode
+from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer


 class Collective(Fleet):

python/paddle/fluid/incubate/fleet/parameter_server/distributed_transpiler/__init__.py renamed to python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py

Lines changed: 54 additions & 13 deletions
@@ -15,14 +15,16 @@
 
 import paddle.fluid.io as io
 from paddle.fluid.communicator import Communicator
+from paddle.fluid.framework import default_main_program
 from paddle.fluid.framework import default_startup_program
+from paddle.fluid.framework import Program
 from paddle.fluid.optimizer import Optimizer
 from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler
 from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig

-from ...base.fleet_base import DistributedOptimizer
-from ...base.fleet_base import Fleet
-from ...base.fleet_base import Mode
+from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
+from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
+from paddle.fluid.incubate.fleet.base.fleet_base import Mode


 class DistributedTranspiler(Fleet):
@@ -34,6 +36,7 @@ def __init__(self):
         super(DistributedTranspiler, self).__init__(Mode.TRANSPILER)
         self._transpile_config = None
         self._transpiler = None
+        self._origin_program = None
         self.startup_program = None
         self.main_program = None
         self._communicator = None
@@ -75,8 +78,7 @@ def init_server(self, model_dir=None):
             if not os.path.isdir(model_dir):
                 raise ValueError("There is no directory named '%s'", model_dir)

-            io.load_persistables(self._executor, model_dir,
-                                 self.startup_program)
+            io.load_persistables(self._executor, model_dir, self.main_program)

     def run_server(self):
         """
@@ -137,9 +139,31 @@ def save_inference_model(self,
         Prune the given `main_program` to build a new program especially for inference,
         and then save it and all related parameters to given `dirname` by the `executor`.
         """
-        io.save_inference_model(dirname, feeded_var_names, target_vars,
-                                executor, main_program, None, None,
-                                export_for_deployment)
+        if main_program is not None:
+            io.save_inference_model(dirname, feeded_var_names, target_vars,
+                                    executor, main_program, None, None,
+                                    export_for_deployment)
+        else:
+            io.save_inference_model(
+                dirname,
+                feeded_var_names,
+                target_vars,
+                executor,
+                self._origin_program,
+                None,
+                None,
+                export_for_deployment,
+                model_only=True)
+
+            model_basename = "__model__"
+            model_filename = os.path.join(dirname, model_basename)
+
+            with open(model_filename, "rb") as f:
+                program_desc_str = f.read()
+
+            program = Program.parse_from_string(program_desc_str)
+            program._copy_dist_param_info_from(self.main_program)
+            self.save_persistables(executor, dirname, program)

     def save_persistables(self, executor, dirname, main_program=None):
         """
@@ -152,6 +176,14 @@ def save_persistables(self, executor, dirname, main_program=None):
         files, set `filename` None; if you would like to save all variables in a
         single file, use `filename` to specify the file name.
         """
+
+        if main_program is None:
+            main_program = self.main_program
+
+        if not main_program._is_distributed:
+            raise ValueError(
+                "main_program is for local, may not use fleet.save_persistables")
+
         io.save_persistables(executor, dirname, main_program, None)

     def _transpile(self, config):
@@ -162,18 +194,27 @@ def _transpile(self, config):
         if not config.sync_mode:
             config.runtime_split_send_recv = True

+        # _origin_program is a deep copy for default_main_program, for inference
+        self._origin_program = default_main_program().clone(for_test=False)
+
         self._transpile_config = config
         self._transpiler = OriginTranspiler(config)
-        self._transpiler.transpile(
-            trainer_id=fleet.worker_index(),
-            pservers=fleet.server_endpoints(to_string=True),
-            trainers=fleet.worker_num(),
-            sync_mode=config.sync_mode)

         if self.is_worker():
+            self._transpiler.transpile(
+                trainer_id=fleet.worker_index(),
+                pservers=fleet.server_endpoints(to_string=True),
+                trainers=fleet.worker_num(),
+                sync_mode=config.sync_mode)
             self.main_program = self._transpiler.get_trainer_program()
             self.startup_program = default_startup_program()
         else:
+            self._transpiler.transpile(
+                trainer_id=fleet.worker_index(),
+                pservers=fleet.server_endpoints(to_string=True),
+                trainers=fleet.worker_num(),
+                sync_mode=config.sync_mode,
+                current_endpoint=self.server_endpoints()[self.server_index()])
             self.main_program, self.startup_program = \
                 self._transpiler.get_pserver_programs(self.server_endpoints()[self.server_index()])
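
Taken together, these hunks change where Fleet reads and writes parameters: save_persistables now insists on the distributed (transpiled) main_program, save_inference_model falls back to the pre-transpile _origin_program, and init_server loads into main_program instead of startup_program. A rough sketch of the intended save/load round trip, assuming the usual fleet setup; the executor, directory and variable names below are illustrative, not taken from the commit:

import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

exe = fluid.Executor(fluid.CPUPlace())
# assumes fleet.init(role_maker) and training have already run

# trainer side: persist the distributed parameters after training
if fleet.is_worker():
    fleet.save_persistables(executor=exe, dirname="fleet_checkpoint")
    # prune and save an inference program from the pre-transpile graph;
    # `prediction` stands for the model's fetch variable (illustrative)
    fleet.save_inference_model(executor=exe,
                               dirname="inference_model",
                               feeded_var_names=["dnn_data", "lr_data"],
                               target_vars=[prediction])

# pserver side: on the next start-up, reload the checkpoint before serving
if fleet.is_server():
    fleet.init_server("fleet_checkpoint")
    fleet.run_server()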

python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and

 import sys
-from .optimizer_factory import *
+from optimizer_factory import *
 from google.protobuf import text_format

 import paddle.fluid as fluid
 from paddle.fluid.framework import Program

-from ...base.fleet_base import Fleet
-from ...base.fleet_base import Mode
-from ...base.role_maker import MPISymetricRoleMaker
-from ...base.fleet_base import DistributedOptimizer
+from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
+from paddle.fluid.incubate.fleet.base.fleet_base import Mode
+from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
+from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker


 class PSLib(Fleet):

python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 
 import paddle.fluid as fluid
 import paddle.fluid.incubate.fleet.base.role_maker as role_maker
-from paddle.fluid.incubate.fleet.parameter_server.distributed_transpiler import fleet
+from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
 from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig

 import ctr_dataset_reader
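
For context, the test wires the corrected fleet module together roughly as follows; apart from fleet, role_maker and DistributeTranspilerConfig (all imported above), the names are illustrative rather than taken from the test body:

# assuming `role` was built with role_maker.UserDefinedRoleMaker as sketched earlier
fleet.init(role)

config = DistributeTranspilerConfig()
config.sync_mode = True

optimizer = fluid.optimizer.SGD(learning_rate=0.0001)
optimizer = fleet.distributed_optimizer(optimizer, config)
optimizer.minimize(avg_cost)    # avg_cost: the model's loss variable (illustrative)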

python/paddle/fluid/io.py

Lines changed: 9 additions & 1 deletion
@@ -907,7 +907,8 @@ def save_inference_model(dirname,
                          main_program=None,
                          model_filename=None,
                          params_filename=None,
-                         export_for_deployment=True):
+                         export_for_deployment=True,
+                         program_only=False):
     """
     Prune the given `main_program` to build a new program especially for inference,
     and then save it and all related parameters to given `dirname` by the `executor`.
@@ -938,6 +939,7 @@ def save_inference_model(dirname,
                                        more information will be stored for flexible
                                        optimization and re-training. Currently, only
                                        True is supported.
+        program_only(bool): If True, It will save inference program only, and do not save params of Program.

     Returns:
         target_var_name_list(list): The fetch variables' name list
@@ -1071,6 +1073,12 @@ def save_inference_model(dirname,
         with open(model_basename + ".main_program", "wb") as f:
             f.write(main_program.desc.serialize_to_string())

+    if program_only:
+        warnings.warn(
+            "save_inference_model specified the param `program_only` to True, It will not save params of Program."
+        )
+        return target_var_name_list
+
     main_program._copy_dist_param_info_from(origin_program)

     if params_filename is not None:
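
A minimal sketch of calling the extended API from user code; the variable and directory names are hypothetical, and the parameters still have to be saved separately (for example via fluid.io.save_persistables), which is exactly what the fleet wrapper above does:

import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
# ... build and train a program whose output variable is `prediction` (illustrative) ...

# save only the pruned inference program; persistable parameters are skipped
fluid.io.save_inference_model(
    dirname="inference_model",
    feeded_var_names=["dnn_data", "lr_data"],
    target_vars=[prediction],
    executor=exe,
    program_only=True)

# the parameters are then written by a separate call
fluid.io.save_persistables(exe, "inference_model")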

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ if(NOT WITH_DISTRIBUTE)
     LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification)
     LIST(REMOVE_ITEM TEST_OPS test_nce_remote_table_op)
     LIST(REMOVE_ITEM TEST_OPS test_hsigmoid_remote_table_op)
+    LIST(REMOVE_ITEM TEST_OPS test_dist_fleet_ctr)
 endif(NOT WITH_DISTRIBUTE)

 LIST(REMOVE_ITEM TEST_OPS test_launch)
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import logging
+import tarfile
+import os
+
+import paddle
+import paddle.fluid.incubate.data_generator as data_generator
+
+logging.basicConfig()
+logger = logging.getLogger("paddle")
+logger.setLevel(logging.INFO)
+
+DATA_URL = "http://paddle-ctr-data.bj.bcebos.com/avazu_ctr_data.tgz"
+DATA_MD5 = "c11df99fbd14e53cd4bfa6567344b26e"
+"""
+avazu_ctr_data/train.txt
+avazu_ctr_data/infer.txt
+avazu_ctr_data/test.txt
+avazu_ctr_data/data.meta.txt
+"""
+
+
+def download_file():
+    file_name = "avazu_ctr_data"
+    path = paddle.dataset.common.download(DATA_URL, file_name, DATA_MD5)
+
+    dir_name = os.path.dirname(path)
+    text_file_dir_name = os.path.join(dir_name, file_name)
+
+    if not os.path.exists(text_file_dir_name):
+        tar = tarfile.open(path, "r:gz")
+        tar.extractall(dir_name)
+    return text_file_dir_name
+
+
+def load_dnn_input_record(sent):
+    return list(map(int, sent.split()))
+
+
+def load_lr_input_record(sent):
+    res = []
+    for _ in [x.split(':') for x in sent.split()]:
+        res.append(int(_[0]))
+    return res
+
+
+class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
+    def generate_sample(self, line):
+        def iter():
+            fs = line.strip().split('\t')
+            dnn_input = load_dnn_input_record(fs[0])
+            lr_input = load_lr_input_record(fs[1])
+            click = [int(fs[2])]
+            yield ("dnn_data", dnn_input), \
+                  ("lr_data", lr_input), \
+                  ("click", click)
+
+        return iter
+
+
+def prepare_data():
+    """
+    load data meta info from path, return (dnn_input_dim, lr_input_dim)
+    """
+    file_dir_name = download_file()
+    meta_file_path = os.path.join(file_dir_name, 'data.meta.txt')
+    train_file_path = os.path.join(file_dir_name, 'train.txt')
+    with open(meta_file_path, "r") as f:
+        lines = f.readlines()
+    err_info = "wrong meta format"
+    assert len(lines) == 2, err_info
+    assert 'dnn_input_dim:' in lines[0] and 'lr_input_dim:' in lines[
+        1], err_info
+    res = map(int, [_.split(':')[1] for _ in lines])
+    res = list(res)
+    dnn_input_dim = res[0]
+    lr_input_dim = res[1]
+    logger.info('dnn input dim: %d' % dnn_input_dim)
+    logger.info('lr input dim: %d' % lr_input_dim)
+    return dnn_input_dim, lr_input_dim, train_file_path
+
+
+if __name__ == "__main__":
+    pairwise_reader = DatasetCtrReader()
+    pairwise_reader.run_from_stdin()
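
The generator is meant to be run as a pipe command that converts raw Avazu CTR lines into Paddle's multi-slot format. One plausible way a test could hook it up to a Dataset feeder; the feed variables dnn_data, lr_data and click are assumptions that mirror the slots yielded above, and the wiring itself is a sketch rather than the test's actual code:

import paddle.fluid as fluid

# feed variables matching the three slots the generator yields
dnn_data = fluid.layers.data(name="dnn_data", shape=[1], dtype="int64", lod_level=1)
lr_data = fluid.layers.data(name="lr_data", shape=[1], dtype="int64", lod_level=1)
click = fluid.layers.data(name="click", shape=[1], dtype="int64")

dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_use_var([dnn_data, lr_data, click])
# each training file is piped through this script on the fly
dataset.set_pipe_command("python ctr_dataset_reader.py")
dataset.set_batch_size(128)
dataset.set_filelist([train_file_path])    # train_file_path from prepare_data()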
