From 0d4d9c4e1392d46a1c2c3588bd4d6eb4fdd0c980 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Mon, 12 Feb 2018 17:39:26 +0800
Subject: [PATCH] fix grpc short connection

---
 paddle/fluid/operators/listen_and_serv_op.cc    |  4 ++--
 paddle/fluid/operators/recv_op.cc               |  4 ++--
 paddle/fluid/operators/send_op.cc               |  4 ++--
 python/paddle/v2/fluid/distribute_transpiler.py | 11 +++--------
 4 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index 426dd0dc0e95b7..8e88a7dcf141dc 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -82,8 +82,8 @@ class ListenAndServOp : public framework::OperatorBase {
     return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
   }
 
-  void Run(const framework::Scope &scope,
-           const platform::Place &dev_place) const override {
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
     framework::Scope &recv_scope = scope.NewScope();
diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc
index c093f60ceed417..17b57b5d45a3a0 100644
--- a/paddle/fluid/operators/recv_op.cc
+++ b/paddle/fluid/operators/recv_op.cc
@@ -32,8 +32,8 @@ class RecvOp : public framework::OperatorBase {
          const framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope& scope,
-           const platform::Place& place) const override {
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc
index b241f738cbf60c..39b6c0e8c515d8 100644
--- a/paddle/fluid/operators/send_op.cc
+++ b/paddle/fluid/operators/send_op.cc
@@ -48,8 +48,8 @@ class SendOp : public framework::OperatorBase {
          const framework::AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Run(const framework::Scope& scope,
-           const platform::Place& place) const override {
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
     auto ins = Inputs("X");
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 689920af0c4fb8..bf2e9e88f33947 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -121,7 +121,6 @@ def split_dense_variable(var_list,
                 block_size += dim1 - remains
         # update split_count after aligning
         split_count = int(math.ceil(var_numel / float(block_size)))
-        print("###split var ", var.name, var.shape, block_size, split_count)
         for block_id in xrange(split_count):
             curr_block_size = min(block_size, var_numel - (
                 (block_id) * block_size))
@@ -207,7 +206,7 @@ def transpile(self,
 
         rpc_client_var = program.global_block().create_var(
             name="RPC_CLIENT_VAR",
-            psersistable=True,
+            persistable=True,
             dtype='float32',  # dtype and shape is not used in fact
             shape=[0])
 
@@ -256,15 +255,13 @@ def _create_vars_from_blocklist(self, program, block_list):
                 splited_shape = [rows]
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
-                print("###splited: ", size, rows, splited_shape)
                 var = program.global_block().create_var(
                     name="%s.block%d" % (varname, i),
-                    psersistable=False,
+                    persistable=False,
                     dtype=orig_var.dtype,
                     type=orig_var.type,
                     shape=splited_shape)  # flattend splited var
                 var_mapping[varname].append(var)
-                print("###created split var ", var)
         return var_mapping
 
     def _clone_var(self, block, var):
@@ -322,7 +319,7 @@ def _create_var_for_trainers(self, block, var, trainers):
         for i in xrange(trainers):
            var_each = block.create_var(
                name="%s.trainer_%d" % (var.name, i),
-               psersistable=var.persistable,
+               persistable=var.persistable,
                dtype=var.dtype,
                type=var.type,
                shape=var.shape)
@@ -531,8 +528,6 @@ def get_pserver_program(self, endpoint):
         """
         # step5
         pserver_program = Program()
-        print("param mapping on pserver: #### ",
-              self.param_grad_ep_mapping[endpoint]["params"])
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
         for v in self.param_grad_ep_mapping[endpoint]["grads"]:
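---

Note (not part of the patch): a minimal standalone sketch of the corrected `create_var` call from the transpiler hunks above; the `fluid` import and program setup are assumed here, not taken from the patch. The point of the Python changes is that the keyword is now spelled `persistable`, so the RPC client variable and the split/trainer variables actually carry the intended persistability; with the earlier `psersistable` spelling the flag presumably never took effect, leaving the variable non-persistent between executor runs, which matches the "grpc short connection" symptom named in the subject.

```python
# Hypothetical sketch; assumes paddle.v2.fluid as used by the patched file.
import paddle.v2.fluid as fluid

program = fluid.Program()

# With the keyword spelled correctly, the variable is marked persistable and
# is kept in the global scope across executor runs instead of being recreated.
rpc_client_var = program.global_block().create_var(
    name="RPC_CLIENT_VAR",
    persistable=True,
    dtype='float32',  # dtype and shape are placeholders, as noted in the patch
    shape=[0])
```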