Skip to content

Commit 4d05f84

Browse files
committed
test=develop, bug fix
1 parent c769431 commit 4d05f84

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

paddle/fluid/operators/distributed/communicator.cc

-3
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ DEFINE_bool(communicator_merge_sparse_grad, true,
4545
"merge sparse gradient before sending");
4646
DEFINE_int32(communicator_merge_sparse_bucket, 2000,
4747
"number of threads for sparse var");
48-
DEFINE_bool(communicator_is_sgd_optimizer, true,
49-
"gradient sent to the server is the sum of the gradients "
50-
"calculated by each thread if optimizer is sgd");
5148

5249
namespace paddle {
5350
namespace operators {

paddle/fluid/operators/distributed/communicator.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,10 @@ inline void MergeVars(const std::string& var_name,
141141
auto in = EigenVector<float>::Flatten(in_t);
142142
result.device(*cpu_ctx.eigen_device()) = result + in;
143143
}
144-
result.device(*cpu_ctx.eigen_device()) =
145-
result / static_cast<float>(vars.size());
144+
if (!FLAGS_communicator_is_sgd_optimizer) {
145+
result.device(*cpu_ctx.eigen_device()) =
146+
result / static_cast<float>(vars.size());
147+
}
146148
} else if (var0->IsType<framework::SelectedRows>()) {
147149
auto& slr0 = var0->Get<framework::SelectedRows>();
148150
auto* out_slr = out_var->GetMutable<framework::SelectedRows>();

paddle/fluid/platform/flags.cc

+3-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ DEFINE_bool(
199199
*/
200200
DEFINE_int32(communicator_max_merge_var_num, 20,
201201
"max var num to merge and send");
202-
202+
DEFINE_bool(communicator_is_sgd_optimizer, true,
203+
"gradient sent to the server is the sum of the gradients "
204+
"calculated by each thread if optimizer is sgd");
203205
/**
204206
* Distributed related FLAG
205207
* Name: FLAGS_communicator_send_queue_size

python/paddle/fluid/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ def __bootstrap__():
206206
read_env_flags.append('communicator_fake_rpc')
207207
read_env_flags.append('communicator_send_wait_times')
208208
read_env_flags.append('communicator_merge_sparse_grad')
209-
read_env_flags.append('communicator_is_sgd_optimizer')
210209
if core.is_compiled_with_brpc():
211210
read_env_flags.append('max_body_size')
212211
#set brpc max body size

0 commit comments

Comments
 (0)