Skip to content

Commit 95e90aa

Browse files
authored
test=develop, add communicator_is_sgd_optimizer flag (#20677)
* test=develop, communicator_is_sgd_optimizer flags
1 parent 74a28f5 commit 95e90aa

File tree

7 files changed

+33
-9
lines changed

7 files changed

+33
-9
lines changed

paddle/fluid/operators/distributed/communicator.cc

+2
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
   VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
   VLOG(0) << "communicator_merge_sparse_grad: "
           << FLAGS_communicator_merge_sparse_grad;
+  VLOG(0) << "communicator_is_sgd_optimizer: "
+          << FLAGS_communicator_is_sgd_optimizer;
 
   if (send_varname_to_ctx.size() == 0) {
     VLOG(0) << "nothing need to be send, will not start send_thread";

paddle/fluid/operators/distributed/communicator.h

+17-5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License. */
 #include <unordered_set>
 #include <utility>
 #include <vector>
+#include "gflags/gflags.h"
 
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/variable.h"
@@ -37,6 +38,8 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"
 
+DECLARE_bool(communicator_is_sgd_optimizer);
+
 namespace paddle {
 namespace operators {
 namespace distributed {
@@ -138,8 +141,10 @@ inline void MergeVars(const std::string& var_name,
       auto in = EigenVector<float>::Flatten(in_t);
       result.device(*cpu_ctx.eigen_device()) = result + in;
     }
-    result.device(*cpu_ctx.eigen_device()) =
-        result / static_cast<float>(vars.size());
+    if (!FLAGS_communicator_is_sgd_optimizer) {
+      result.device(*cpu_ctx.eigen_device()) =
+          result / static_cast<float>(vars.size());
+    }
   } else if (var0->IsType<framework::SelectedRows>()) {
     auto& slr0 = var0->Get<framework::SelectedRows>();
     auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
@@ -151,9 +156,16 @@ inline void MergeVars(const std::string& var_name,
       inputs.push_back(&var->Get<framework::SelectedRows>());
     }
     auto dev_ctx = paddle::platform::CPUDeviceContext();
-    math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float>
-        merge_average;
-    merge_average(dev_ctx, inputs, out_slr);
+    if (FLAGS_communicator_is_sgd_optimizer) {
+      math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
+          merge_add;
+      merge_add(dev_ctx, inputs, out_slr);
+    } else {
+      math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float>
+          merge_average;
+      merge_average(dev_ctx, inputs, out_slr);
+    }
+
     VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
             << " dims: " << slr0.value().dims();
   } else {

paddle/fluid/operators/distributed/communicator_test.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ TEST(communicator, merge_lod_tensors) {
     }
     out_value += static_cast<float>(i);
   }
-  out_value = out_value / 10.0;
   const std::string out_name = "Out";
   std::unique_ptr<framework::Scope> scope;
   scope.reset(new framework::Scope());
@@ -96,7 +95,7 @@ TEST(communicator, merge_selected_rows) {
   std::vector<float> out_values;
   out_values.reserve(10);
   for (auto i = 0; i < 10; ++i) {
-    out_values.push_back(static_cast<float>((i * (10 - i)) / 10.0));
+    out_values.push_back(static_cast<float>(i * (10 - i)));
   }
   for (auto i = 0; i < out_slr.rows().size(); ++i) {
     ASSERT_EQ(out_slr.rows()[i], i);

paddle/fluid/operators/distributed/parameter_send.cc

+7
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
   auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
 
   auto &send_rows = send_slr.rows();
+  if (send_rows.size() == 0) {
+    LOG(WARNING) << "WARNING: The variable sent to pserver is empty, which "
+                    "may cause an unknown error. Please check the state of "
+                    "use_double_buffer in pyreader async mode, you need to "
+                    "turn it false.";
+  }
+
   std::vector<std::vector<size_t>> outs_rows_idx;
   std::vector<std::vector<size_t>> outs_dense_idx;

paddle/fluid/platform/flags.cc

+3-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ DEFINE_bool(
  */
 DEFINE_int32(communicator_max_merge_var_num, 20,
              "max var num to merge and send");
-
+DEFINE_bool(communicator_is_sgd_optimizer, true,
+            "gradient sent to the server is the sum of the gradients "
+            "calculated by each thread if optimizer is sgd");
 /**
  * Distributed related FLAG
  * Name: FLAGS_communicator_send_queue_size

python/paddle/fluid/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def __bootstrap__():
         read_env_flags.append('communicator_fake_rpc')
         read_env_flags.append('communicator_send_wait_times')
         read_env_flags.append('communicator_merge_sparse_grad')
+        read_env_flags.append('communicator_is_sgd_optimizer')
     if core.is_compiled_with_brpc():
         read_env_flags.append('max_body_size')
         #set brpc max body size

python/paddle/fluid/tests/unittests/test_dist_ctr.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def test_dist_ctr(self):
             "FLAGS_communicator_send_queue_size": "2",
             "FLAGS_communicator_max_merge_var_num": "2",
             "FLAGS_communicator_max_send_grad_num_before_recv": "2",
-            "FLAGS_communicator_independent_recv_thread": "0"
+            "FLAGS_communicator_independent_recv_thread": "0",
+            "FLAGS_communicator_is_sgd_optimizer": "0"
         }
 
         self.check_with_place(

0 commit comments

Comments
 (0)