@@ -29,7 +29,6 @@ limitations under the License. */
29
29
defined (PADDLE_WITH_XPU_BKCL)
30
30
#include " paddle/common/flags.h"
31
31
#include " paddle/phi/core/platform/collective_helper.h"
32
- COMMON_DECLARE_bool (dynamic_static_unified_comm);
33
32
#endif
34
33
35
34
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -136,11 +135,7 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
136
135
int rid = ctx.Attr <int >(" ring_id" );
137
136
138
137
auto place = ctx.GetPlace ();
139
- BKCLDataType dtype = phi::ToBKCLDataType (in->dtype ());
140
- int64_t numel = in->numel ();
141
- const void * sendbuff = in->data <T>();
142
138
out->Resize (in->dims ());
143
- void * recvbuff = out->mutable_data <T>(place);
144
139
145
140
auto map = phi::distributed::ProcessGroupMapFromGid::getInstance ();
146
141
if (map->has (rid)) {
@@ -180,30 +175,24 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
180
175
181
176
const auto & comm_context_manager =
182
177
phi::distributed::CommContextManager::GetInstance ();
183
- if (FLAGS_dynamic_static_unified_comm) {
184
- PADDLE_ENFORCE_EQ (comm_context_manager.Has (std::to_string (rid)),
185
- true ,
186
- common::errors::InvalidArgument (
187
- " You choose to use new communication library by "
188
- " setting environment "
189
- " variable FLAGS_dynamic_static_unified_comm True. "
190
- " But ring_id(%d) is "
191
- " not found in comm_context_manager." ,
192
- std::to_string (rid)));
193
- comm_ctx = static_cast <phi::distributed::BKCLCommContext*>(
194
- comm_context_manager.Get (std::to_string (rid)));
195
- PADDLE_ENFORCE_NE (comm_ctx,
196
- nullptr ,
197
- common::errors::Unavailable (
198
- " BKCLCommContext is nullptr, collective op should "
199
- " has ring_id attr." ));
200
- stream = comm_ctx->GetStream ();
201
- VLOG (3 ) << " new comm_context_manager has rid " << rid;
202
- } else {
203
- comm = platform::BKCLCommContext::Instance ().Get (rid, place);
204
- stream = comm->stream ();
205
- VLOG (3 ) << " old BKCLCommContext has rid " << rid;
206
- }
178
+
179
+ PADDLE_ENFORCE_EQ (comm_context_manager.Has (std::to_string (rid)),
180
+ true ,
181
+ common::errors::InvalidArgument (
182
+ " You choose to use new communication library. "
183
+ " But ring_id(%d) is "
184
+ " not found in comm_context_manager." ,
185
+ std::to_string (rid)));
186
+ comm_ctx = static_cast <phi::distributed::BKCLCommContext*>(
187
+ comm_context_manager.Get (std::to_string (rid)));
188
+ PADDLE_ENFORCE_NE (comm_ctx,
189
+ nullptr ,
190
+ common::errors::Unavailable (
191
+ " BKCLCommContext is nullptr, collective op should "
192
+ " has ring_id attr." ));
193
+ stream = comm_ctx->GetStream ();
194
+ VLOG (3 ) << " new comm_context_manager has rid " << rid;
195
+
207
196
if (ctx.Attr <bool >(" use_calc_stream" )) {
208
197
auto dev_ctx = phi::DeviceContextPool::Instance ().Get (place);
209
198
stream = static_cast <phi::XPUContext*>(dev_ctx)->x_context ()->xpu_stream ;
@@ -232,17 +221,7 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
232
221
red_type));
233
222
}
234
223
235
- if (comm_ctx) {
236
- comm_ctx->AllReduce (out, *in, bkcl_red_type, stream);
237
- } else {
238
- PADDLE_ENFORCE_XPU_SUCCESS (bkcl_all_reduce (comm->comm (),
239
- sendbuff,
240
- recvbuff,
241
- numel,
242
- dtype,
243
- bkcl_red_type,
244
- stream));
245
- }
224
+ comm_ctx->AllReduce (out, *in, bkcl_red_type, stream);
246
225
#else
247
226
PADDLE_THROW (common::errors::PreconditionNotMet (
248
227
" PaddlePaddle should be compiled with XPU." ));
@@ -280,12 +259,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
280
259
auto out = ctx.Output <phi::DenseTensor>(" Out" );
281
260
int rid = ctx.Attr <int >(" ring_id" );
282
261
283
- auto place = ctx.GetPlace ();
284
262
ncclDataType_t dtype = phi::ToNCCLDataType (in->dtype ());
285
263
int64_t numel = in->numel ();
286
264
const void * sendbuff = in->data <T>();
287
265
out->Resize (in->dims ());
288
- void * recvbuff = out->mutable_data <T>(place);
289
266
290
267
auto map = phi::distributed::ProcessGroupMapFromGid::getInstance ();
291
268
if (map->has (rid)) {
@@ -325,30 +302,24 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
325
302
326
303
const auto & comm_context_manager =
327
304
phi::distributed::CommContextManager::GetInstance ();
328
- if (FLAGS_dynamic_static_unified_comm) {
329
- PADDLE_ENFORCE_EQ (comm_context_manager.Has (std::to_string (rid)),
330
- true ,
331
- common::errors::InvalidArgument (
332
- " You choose to use new communication library by "
333
- " setting environment "
334
- " variable FLAGS_dynamic_static_unified_comm True. "
335
- " But ring_id(%d) is "
336
- " not found in comm_context_manager." ,
337
- std::to_string (rid)));
338
- comm_ctx = static_cast <phi::distributed::NCCLCommContext*>(
339
- comm_context_manager.Get (std::to_string (rid)));
340
- PADDLE_ENFORCE_NE (comm_ctx,
341
- nullptr ,
342
- common::errors::Unavailable (
343
- " NCCLCommContext is nullptr, collective op should "
344
- " has ring_id attr." ));
345
- stream = comm_ctx->GetStream ();
346
- VLOG (3 ) << " new comm_context_manager has rid " << rid;
347
- } else {
348
- comm = platform::NCCLCommContext::Instance ().Get (rid, place);
349
- stream = comm->stream ();
350
- VLOG (3 ) << " old NCCLCommContext has rid " << rid;
351
- }
305
+
306
+ PADDLE_ENFORCE_EQ (comm_context_manager.Has (std::to_string (rid)),
307
+ true ,
308
+ common::errors::InvalidArgument (
309
+ " You choose to use new communication library. "
310
+ " But ring_id(%d) is "
311
+ " not found in comm_context_manager." ,
312
+ std::to_string (rid)));
313
+ comm_ctx = static_cast <phi::distributed::NCCLCommContext*>(
314
+ comm_context_manager.Get (std::to_string (rid)));
315
+ PADDLE_ENFORCE_NE (comm_ctx,
316
+ nullptr ,
317
+ common::errors::Unavailable (
318
+ " NCCLCommContext is nullptr, collective op should "
319
+ " has ring_id attr." ));
320
+ stream = comm_ctx->GetStream ();
321
+ VLOG (3 ) << " new comm_context_manager has rid " << rid;
322
+
352
323
if (ctx.Attr <bool >(" use_calc_stream" )) {
353
324
// should not use global ctx for calc stream.
354
325
// auto dev_ctx = phi::DeviceContextPool::Instance().Get(place);
@@ -390,17 +361,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
390
361
red_type));
391
362
}
392
363
393
- if (comm_ctx) {
394
- comm_ctx->AllReduce (out, *in, nccl_red_type, stream);
395
- } else {
396
- PADDLE_ENFORCE_GPU_SUCCESS (phi::dynload::ncclAllReduce (sendbuff,
397
- recvbuff,
398
- numel,
399
- dtype,
400
- nccl_red_type,
401
- comm->comm (),
402
- stream));
403
- }
364
+ comm_ctx->AllReduce (out, *in, nccl_red_type, stream);
404
365
#else
405
366
PADDLE_THROW (common::errors::PreconditionNotMet (
406
367
" PaddlePaddle should compile with GPU." ));
0 commit comments