@@ -24,6 +24,7 @@ limitations under the License. */
24
24
#include < unordered_set>
25
25
#include < utility>
26
26
#include < vector>
27
+ #include " gflags/gflags.h"
27
28
28
29
#include " paddle/fluid/framework/scope.h"
29
30
#include " paddle/fluid/framework/variable.h"
@@ -37,6 +38,8 @@ limitations under the License. */
37
38
#include " paddle/fluid/platform/enforce.h"
38
39
#include " paddle/fluid/platform/place.h"
39
40
41
+ DECLARE_bool (communicator_is_sgd_optimizer);
42
+
40
43
namespace paddle {
41
44
namespace operators {
42
45
namespace distributed {
@@ -138,8 +141,10 @@ inline void MergeVars(const std::string& var_name,
138
141
auto in = EigenVector<float >::Flatten (in_t );
139
142
result.device (*cpu_ctx.eigen_device ()) = result + in;
140
143
}
141
- result.device (*cpu_ctx.eigen_device ()) =
142
- result / static_cast <float >(vars.size ());
144
+ if (!FLAGS_communicator_is_sgd_optimizer) {
145
+ result.device (*cpu_ctx.eigen_device ()) =
146
+ result / static_cast <float >(vars.size ());
147
+ }
143
148
} else if (var0->IsType <framework::SelectedRows>()) {
144
149
auto & slr0 = var0->Get <framework::SelectedRows>();
145
150
auto * out_slr = out_var->GetMutable <framework::SelectedRows>();
@@ -151,9 +156,16 @@ inline void MergeVars(const std::string& var_name,
151
156
inputs.push_back (&var->Get <framework::SelectedRows>());
152
157
}
153
158
auto dev_ctx = paddle::platform::CPUDeviceContext ();
154
- math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float >
155
- merge_average;
156
- merge_average (dev_ctx, inputs, out_slr);
159
+ if (FLAGS_communicator_is_sgd_optimizer) {
160
+ math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float >
161
+ merge_add;
162
+ merge_add (dev_ctx, inputs, out_slr);
163
+ } else {
164
+ math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float >
165
+ merge_average;
166
+ merge_average (dev_ctx, inputs, out_slr);
167
+ }
168
+
157
169
VLOG (3 ) << " merge " << var_name << " SelectedRows height: " << slr0.height ()
158
170
<< " dims: " << slr0.value ().dims ();
159
171
} else {
0 commit comments