Commit c0fc50d

Merge pull request PaddlePaddle#8409 from typhoonzero/fix_grpc_short_conn
Fix grpc short connection
2 parents (07923ba + 0d4d9c4), merged as commit c0fc50d

File tree

4 files changed: +9, -14 lines

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 2 additions & 2 deletions
@@ -82,8 +82,8 @@ class ListenAndServOp : public framework::OperatorBase {
     return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
   }
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
     framework::Scope &recv_scope = scope.NewScope();

paddle/fluid/operators/recv_op.cc

Lines changed: 2 additions & 2 deletions
@@ -32,8 +32,8 @@ class RecvOp : public framework::OperatorBase {
          const framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope& scope,
-           const platform::Place& place) const override {
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
 

paddle/fluid/operators/send_op.cc

Lines changed: 2 additions & 2 deletions
@@ -48,8 +48,8 @@ class SendOp : public framework::OperatorBase {
          const framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope& scope,
-           const platform::Place& place) const override {
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
     auto ins = Inputs("X");
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
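
All three operator files above make the same mechanical change: the public Run override becomes RunImpl, with the body untouched. The retained override keyword suggests that framework::OperatorBase now declares a virtual RunImpl which a non-virtual Run dispatches to (the non-virtual-interface idiom), letting the framework wrap common work around every operator. A minimal sketch of that shape, using simplified stand-in types rather than the framework's actual declarations:

#include <iostream>
#include <string>

// Simplified stand-ins for framework::Scope and platform::Place.
struct Scope {};
struct Place { std::string name = "CPUPlace"; };

// Sketch of the non-virtual-interface shape implied by the rename: callers
// keep calling Run(), which can wrap common bookkeeping (profiling, device
// guards, ...) around the virtual RunImpl() that each operator overrides.
class OperatorBase {
 public:
  virtual ~OperatorBase() = default;
  void Run(const Scope& scope, const Place& place) const {
    // common pre/post work would go here
    RunImpl(scope, place);
  }

 protected:
  virtual void RunImpl(const Scope& scope, const Place& place) const = 0;
};

class SendOp : public OperatorBase {
 protected:
  void RunImpl(const Scope& /*scope*/, const Place& place) const override {
    std::cout << "SendOp::RunImpl on " << place.name << "\n";
  }
};

int main() {
  SendOp op;
  op.Run(Scope{}, Place{});  // dispatches to SendOp::RunImpl
}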

python/paddle/v2/fluid/distribute_transpiler.py

Lines changed: 3 additions & 8 deletions
@@ -121,7 +121,6 @@ def split_dense_variable(var_list,
                 block_size += dim1 - remains
         # update split_count after aligning
         split_count = int(math.ceil(var_numel / float(block_size)))
-        print("###split var ", var.name, var.shape, block_size, split_count)
         for block_id in xrange(split_count):
             curr_block_size = min(block_size, var_numel - (
                 (block_id) * block_size))
@@ -207,7 +206,7 @@ def transpile(self,
 
         rpc_client_var = program.global_block().create_var(
             name="RPC_CLIENT_VAR",
-            psersistable=True,
+            persistable=True,
             dtype='float32',  # dtype and shape is not used in fact
             shape=[0])
 
@@ -256,15 +255,13 @@ def _create_vars_from_blocklist(self, program, block_list):
                 splited_shape = [rows]
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
-                print("###splited: ", size, rows, splited_shape)
                 var = program.global_block().create_var(
                     name="%s.block%d" % (varname, i),
-                    psersistable=False,
+                    persistable=False,
                     dtype=orig_var.dtype,
                     type=orig_var.type,
                     shape=splited_shape)  # flattend splited var
                 var_mapping[varname].append(var)
-                print("###created split var ", var)
         return var_mapping
 
     def _clone_var(self, block, var):
@@ -322,7 +319,7 @@ def _create_var_for_trainers(self, block, var, trainers):
         for i in xrange(trainers):
             var_each = block.create_var(
                 name="%s.trainer_%d" % (var.name, i),
-                psersistable=var.persistable,
+                persistable=var.persistable,
                 dtype=var.dtype,
                 type=var.type,
                 shape=var.shape)
@@ -531,8 +528,6 @@ def get_pserver_program(self, endpoint):
         """
         # step5
         pserver_program = Program()
-        print("param mapping on pserver: #### ",
-              self.param_grad_ep_mapping[endpoint]["params"])
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
         for v in self.param_grad_ep_mapping[endpoint]["grads"]:
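
The change of psersistable to persistable on RPC_CLIENT_VAR is plausibly the heart of the "short connection" fix: a misspelled keyword argument is presumably swallowed silently, leaving the variable non-persistable, so the runtime can drop and recreate it (and any RPC client cached in it) on every iteration, opening a fresh gRPC connection each step. Marking it persistable lets one client, and its underlying channel, be reused across steps. A generic sketch of that reuse idea in C++, with hypothetical RpcClient / GetOrCreateClient names and an illustrative endpoint rather than PaddlePaddle's actual RPC code:

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for a gRPC-backed client; a real one would own a
// grpc::Channel created once per endpoint.
struct RpcClient {
  explicit RpcClient(const std::string& ep) : endpoint(ep) {
    std::cout << "connecting to " << endpoint << "\n";  // happens only once
  }
  std::string endpoint;
};

// Plays the role of a persistable RPC_CLIENT_VAR: the registry outlives a
// single training step, so the client (and its connection) is reused instead
// of being rebuilt each iteration -- the "long connection" behaviour.
RpcClient* GetOrCreateClient(const std::string& endpoint) {
  static std::map<std::string, std::unique_ptr<RpcClient>> clients;
  auto& slot = clients[endpoint];
  if (!slot) slot = std::make_unique<RpcClient>(endpoint);
  return slot.get();
}

int main() {
  for (int step = 0; step < 3; ++step) {
    RpcClient* c = GetOrCreateClient("127.0.0.1:6174");  // same client each step
    (void)c;
  }
}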
