Skip to content

Commit e92c3f5

Browse files
authored
[clean old comm]Revert "Revert old comm" (PaddlePaddle#71791)
* Revert "Revert "[clean old comm][fluid_ops] c_allreduce_op.h" (PaddlePaddle#70929)" This reverts commit feb941a.
1 parent 678771c commit e92c3f5

File tree

7 files changed: 108 additions (+108) and 315 deletions (-315).

paddle/fluid/operators/collective/c_allreduce_op.h

Lines changed: 38 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@ limitations under the License. */
2929
defined(PADDLE_WITH_XPU_BKCL)
3030
#include "paddle/common/flags.h"
3131
#include "paddle/phi/core/platform/collective_helper.h"
32-
COMMON_DECLARE_bool(dynamic_static_unified_comm);
3332
#endif
3433

3534
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -136,11 +135,7 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
136135
int rid = ctx.Attr<int>("ring_id");
137136

138137
auto place = ctx.GetPlace();
139-
BKCLDataType dtype = phi::ToBKCLDataType(in->dtype());
140-
int64_t numel = in->numel();
141-
const void* sendbuff = in->data<T>();
142138
out->Resize(in->dims());
143-
void* recvbuff = out->mutable_data<T>(place);
144139

145140
auto map = phi::distributed::ProcessGroupMapFromGid::getInstance();
146141
if (map->has(rid)) {
@@ -180,30 +175,24 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
180175

181176
const auto& comm_context_manager =
182177
phi::distributed::CommContextManager::GetInstance();
183-
if (FLAGS_dynamic_static_unified_comm) {
184-
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
185-
true,
186-
common::errors::InvalidArgument(
187-
"You choose to use new communication library by "
188-
"setting environment "
189-
"variable FLAGS_dynamic_static_unified_comm True. "
190-
"But ring_id(%d) is "
191-
"not found in comm_context_manager.",
192-
std::to_string(rid)));
193-
comm_ctx = static_cast<phi::distributed::BKCLCommContext*>(
194-
comm_context_manager.Get(std::to_string(rid)));
195-
PADDLE_ENFORCE_NE(comm_ctx,
196-
nullptr,
197-
common::errors::Unavailable(
198-
"BKCLCommContext is nullptr, collective op should "
199-
"has ring_id attr."));
200-
stream = comm_ctx->GetStream();
201-
VLOG(3) << "new comm_context_manager has rid " << rid;
202-
} else {
203-
comm = platform::BKCLCommContext::Instance().Get(rid, place);
204-
stream = comm->stream();
205-
VLOG(3) << "old BKCLCommContext has rid " << rid;
206-
}
178+
179+
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
180+
true,
181+
common::errors::InvalidArgument(
182+
"You choose to use new communication library. "
183+
"But ring_id(%d) is "
184+
"not found in comm_context_manager.",
185+
std::to_string(rid)));
186+
comm_ctx = static_cast<phi::distributed::BKCLCommContext*>(
187+
comm_context_manager.Get(std::to_string(rid)));
188+
PADDLE_ENFORCE_NE(comm_ctx,
189+
nullptr,
190+
common::errors::Unavailable(
191+
"BKCLCommContext is nullptr, collective op should "
192+
"has ring_id attr."));
193+
stream = comm_ctx->GetStream();
194+
VLOG(3) << "new comm_context_manager has rid " << rid;
195+
207196
if (ctx.Attr<bool>("use_calc_stream")) {
208197
auto dev_ctx = phi::DeviceContextPool::Instance().Get(place);
209198
stream = static_cast<phi::XPUContext*>(dev_ctx)->x_context()->xpu_stream;
@@ -232,17 +221,7 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
232221
red_type));
233222
}
234223

235-
if (comm_ctx) {
236-
comm_ctx->AllReduce(out, *in, bkcl_red_type, stream);
237-
} else {
238-
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_all_reduce(comm->comm(),
239-
sendbuff,
240-
recvbuff,
241-
numel,
242-
dtype,
243-
bkcl_red_type,
244-
stream));
245-
}
224+
comm_ctx->AllReduce(out, *in, bkcl_red_type, stream);
246225
#else
247226
PADDLE_THROW(common::errors::PreconditionNotMet(
248227
"PaddlePaddle should be compiled with XPU."));
@@ -280,12 +259,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
280259
auto out = ctx.Output<phi::DenseTensor>("Out");
281260
int rid = ctx.Attr<int>("ring_id");
282261

283-
auto place = ctx.GetPlace();
284262
ncclDataType_t dtype = phi::ToNCCLDataType(in->dtype());
285263
int64_t numel = in->numel();
286264
const void* sendbuff = in->data<T>();
287265
out->Resize(in->dims());
288-
void* recvbuff = out->mutable_data<T>(place);
289266

290267
auto map = phi::distributed::ProcessGroupMapFromGid::getInstance();
291268
if (map->has(rid)) {
@@ -325,30 +302,24 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
325302

326303
const auto& comm_context_manager =
327304
phi::distributed::CommContextManager::GetInstance();
328-
if (FLAGS_dynamic_static_unified_comm) {
329-
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
330-
true,
331-
common::errors::InvalidArgument(
332-
"You choose to use new communication library by "
333-
"setting environment "
334-
"variable FLAGS_dynamic_static_unified_comm True. "
335-
"But ring_id(%d) is "
336-
"not found in comm_context_manager.",
337-
std::to_string(rid)));
338-
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
339-
comm_context_manager.Get(std::to_string(rid)));
340-
PADDLE_ENFORCE_NE(comm_ctx,
341-
nullptr,
342-
common::errors::Unavailable(
343-
"NCCLCommContext is nullptr, collective op should "
344-
"has ring_id attr."));
345-
stream = comm_ctx->GetStream();
346-
VLOG(3) << "new comm_context_manager has rid " << rid;
347-
} else {
348-
comm = platform::NCCLCommContext::Instance().Get(rid, place);
349-
stream = comm->stream();
350-
VLOG(3) << "old NCCLCommContext has rid " << rid;
351-
}
305+
306+
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
307+
true,
308+
common::errors::InvalidArgument(
309+
"You choose to use new communication library. "
310+
"But ring_id(%d) is "
311+
"not found in comm_context_manager.",
312+
std::to_string(rid)));
313+
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
314+
comm_context_manager.Get(std::to_string(rid)));
315+
PADDLE_ENFORCE_NE(comm_ctx,
316+
nullptr,
317+
common::errors::Unavailable(
318+
"NCCLCommContext is nullptr, collective op should "
319+
"has ring_id attr."));
320+
stream = comm_ctx->GetStream();
321+
VLOG(3) << "new comm_context_manager has rid " << rid;
322+
352323
if (ctx.Attr<bool>("use_calc_stream")) {
353324
// should not use global ctx for calc stream.
354325
// auto dev_ctx = phi::DeviceContextPool::Instance().Get(place);
@@ -390,17 +361,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
390361
red_type));
391362
}
392363

393-
if (comm_ctx) {
394-
comm_ctx->AllReduce(out, *in, nccl_red_type, stream);
395-
} else {
396-
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(sendbuff,
397-
recvbuff,
398-
numel,
399-
dtype,
400-
nccl_red_type,
401-
comm->comm(),
402-
stream));
403-
}
364+
comm_ctx->AllReduce(out, *in, nccl_red_type, stream);
404365
#else
405366
PADDLE_THROW(common::errors::PreconditionNotMet(
406367
"PaddlePaddle should compile with GPU."));

paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc

Lines changed: 0 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@ limitations under the License. */
2424
#include "paddle/phi/core/platform/device_context.h"
2525
#include "paddle/phi/core/platform/gen_comm_id_helper.h"
2626

27-
COMMON_DECLARE_bool(dynamic_static_unified_comm);
2827
namespace paddle {
2928
namespace operators {
3029

@@ -63,30 +62,13 @@ class CGenBKCLIdOp : public framework::OperatorBase {
6362

6463
void RunImpl(const framework::Scope& scope,
6564
const phi::Place& dev_place) const override {
66-
int rank = Attr<int>("rank");
67-
int ring_id = Attr<int>("ring_id");
68-
6965
std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
7066
return Output("Out");
7167
};
7268

73-
std::string endpoint = Attr<std::string>("endpoint");
74-
7569
std::vector<BKCLUniqueId> bkcl_ids;
7670
bkcl_ids.resize(1);
7771

78-
if (!FLAGS_dynamic_static_unified_comm) {
79-
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
80-
if (rank == 0) {
81-
GenBKCLID(&bkcl_ids);
82-
std::vector<std::string> endpoint_list =
83-
Attr<std::vector<std::string>>("other_endpoints");
84-
platform::SendBroadCastCommID(endpoint_list, &bkcl_ids, ring_id);
85-
} else {
86-
platform::RecvBroadCastCommID(server_fd, endpoint, &bkcl_ids, ring_id);
87-
}
88-
}
89-
9072
CopyBKCLIDToVar(bkcl_ids, func, scope);
9173
}
9274
};

paddle/fluid/operators/collective/c_gen_nccl_id_op.cc

Lines changed: 0 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@ limitations under the License. */
2323
#include "paddle/phi/core/platform/device_context.h"
2424
#include "paddle/phi/core/platform/gen_comm_id_helper.h"
2525

26-
COMMON_DECLARE_bool(dynamic_static_unified_comm);
2726
namespace paddle::operators {
2827

2928
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -58,30 +57,13 @@ class CGenNCCLIdOp : public framework::OperatorBase {
5857

5958
void RunImpl(const framework::Scope& scope,
6059
const phi::Place& dev_place) const override {
61-
int rank = Attr<int>("rank");
62-
int ring_id = Attr<int>("ring_id");
63-
6460
std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
6561
return Output("Out");
6662
};
6763

68-
std::string endpoint = Attr<std::string>("endpoint");
69-
7064
std::vector<ncclUniqueId> nccl_ids;
7165
nccl_ids.resize(1);
7266

73-
if (!FLAGS_dynamic_static_unified_comm) {
74-
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
75-
if (rank == 0) {
76-
GenNCCLID(&nccl_ids);
77-
std::vector<std::string> endpoint_list =
78-
Attr<std::vector<std::string>>("other_endpoints");
79-
platform::SendBroadCastCommID(endpoint_list, &nccl_ids, ring_id);
80-
} else {
81-
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids, ring_id);
82-
}
83-
}
84-
8567
CopyNCCLIDToVar(nccl_ids, func, scope);
8668
}
8769
};

paddle/fluid/operators/collective/c_wait_comm_op.cc

Lines changed: 14 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@ class Scope;
2222
#include "paddle/phi/core/distributed/comm_context_manager.h"
2323
#include "paddle/phi/core/distributed/nccl_comm_context.h"
2424
#include "paddle/phi/core/platform/collective_helper.h"
25-
COMMON_DECLARE_bool(dynamic_static_unified_comm);
2625
#endif
2726

2827
namespace paddle::operators {
@@ -56,31 +55,20 @@ class CWaitCommOp : public framework::OperatorBase {
5655

5756
const auto& comm_context_manager =
5857
phi::distributed::CommContextManager::GetInstance();
59-
if (FLAGS_dynamic_static_unified_comm) {
60-
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
61-
true,
62-
common::errors::InvalidArgument(
63-
"You choose to use new communication library by "
64-
"setting environment "
65-
"variable FLAGS_dynamic_static_unified_comm True. "
66-
"But ring_id(%d) is "
67-
"not found in comm_context_manager.",
68-
std::to_string(ring_id)));
69-
phi::distributed::NCCLCommContext* comm_ctx =
70-
static_cast<phi::distributed::NCCLCommContext*>(
71-
comm_context_manager.Get(std::to_string(ring_id)));
72-
comm_stream = comm_ctx->GetStream();
73-
event = comm_ctx->GetComputeEvent();
74-
VLOG(3) << "new comm_context_manager has rid " << ring_id;
75-
} else {
76-
comm_stream =
77-
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
78-
79-
event = platform::NCCLCommContext::Instance()
80-
.Get(ring_id, place)
81-
->comm_event();
82-
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
83-
}
58+
59+
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
60+
true,
61+
common::errors::InvalidArgument(
62+
"You choose to use new communication library. "
63+
"But ring_id(%d) is "
64+
"not found in comm_context_manager.",
65+
std::to_string(ring_id)));
66+
phi::distributed::NCCLCommContext* comm_ctx =
67+
static_cast<phi::distributed::NCCLCommContext*>(
68+
comm_context_manager.Get(std::to_string(ring_id)));
69+
comm_stream = comm_ctx->GetStream();
70+
event = comm_ctx->GetComputeEvent();
71+
VLOG(3) << "new comm_context_manager has rid " << ring_id;
8472

8573
// comm_stream-->event-->compute_stream
8674
#ifdef PADDLE_WITH_HIP

paddle/fluid/operators/collective/c_wait_compute_op.cc

Lines changed: 14 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@ class Scope;
2222
#include "paddle/phi/core/distributed/comm_context_manager.h"
2323
#include "paddle/phi/core/distributed/nccl_comm_context.h"
2424
#include "paddle/phi/core/platform/collective_helper.h"
25-
COMMON_DECLARE_bool(dynamic_static_unified_comm);
2625
#endif
2726

2827
namespace paddle::operators {
@@ -56,31 +55,20 @@ class CWaitComputeOp : public framework::OperatorBase {
5655

5756
const auto& comm_context_manager =
5857
phi::distributed::CommContextManager::GetInstance();
59-
if (FLAGS_dynamic_static_unified_comm) {
60-
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
61-
true,
62-
common::errors::InvalidArgument(
63-
"You choose to use new communication library by "
64-
"setting environment "
65-
"variable FLAGS_dynamic_static_unified_comm True. "
66-
"But ring_id(%d) is "
67-
"not found in comm_context_manager.",
68-
std::to_string(ring_id)));
69-
phi::distributed::NCCLCommContext* comm_ctx =
70-
static_cast<phi::distributed::NCCLCommContext*>(
71-
comm_context_manager.Get(std::to_string(ring_id)));
72-
comm_stream = comm_ctx->GetStream();
73-
event = comm_ctx->GetComputeEvent();
74-
VLOG(3) << "new comm_context_manager has rid " << ring_id;
75-
} else {
76-
comm_stream =
77-
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
78-
79-
event = platform::NCCLCommContext::Instance()
80-
.Get(ring_id, place)
81-
->compute_event();
82-
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
83-
}
58+
59+
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
60+
true,
61+
common::errors::InvalidArgument(
62+
"You choose to use new communication library. "
63+
"But ring_id(%d) is "
64+
"not found in comm_context_manager.",
65+
std::to_string(ring_id)));
66+
phi::distributed::NCCLCommContext* comm_ctx =
67+
static_cast<phi::distributed::NCCLCommContext*>(
68+
comm_context_manager.Get(std::to_string(ring_id)));
69+
comm_stream = comm_ctx->GetStream();
70+
event = comm_ctx->GetComputeEvent();
71+
VLOG(3) << "new comm_context_manager has rid " << ring_id;
8472

8573
// compute_stream-->event-->comm_stream
8674
#ifdef PADDLE_WITH_HIP

0 commit comments

Comments
 (0)