Skip to content

Commit 4b43823

Browse files
author
chengduozh
committed
remove ncclBroadcast
test=develop
1 parent 5d641ba commit 4b43823

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc

+10-9
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,24 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
4343

4444
auto in = ctx.Input<framework::Tensor>("X");
4545
auto out = ctx.Output<framework::Tensor>("Out");
46-
out->Resize(in->dims());
47-
void* recv_buffer = out->mutable_data<T>(ctx.GetPlace());
48-
const void* send_buffer = in->data<void>();
49-
50-
int in_dev_id = boost::get<platform::CUDAPlace>(in->place()).device;
51-
PADDLE_ENFORCE_EQ(dev_id, in_dev_id);
46+
PADDLE_ENFORCE(out->IsInitialized(),
47+
"Currently, the output of broadcast op must be initialized, "
48+
"because this op can only be an In-Place operation.");
49+
void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
50+
PADDLE_ENFORCE_EQ(
51+
send_recv_buffer, in->data<void>(),
52+
"Currently, the broadcast op can only be an In-Place operation.");
5253

5354
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
5455
auto comm = dev_ctx.nccl_comm();
5556
auto stream = dev_ctx.stream();
5657

57-
PADDLE_ENFORCE(platform::dynload::ncclBroadcast(
58-
send_buffer, recv_buffer, static_cast<size_t>(in->numel()),
58+
PADDLE_ENFORCE(platform::dynload::ncclBcast(
59+
send_recv_buffer, static_cast<size_t>(in->numel()),
5960
platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream));
6061

6162
VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", (" << in->numel() << ")"
62-
<< " From " << root_dev_id << " to " << in_dev_id;
63+
<< " From " << root_dev_id << " to " << dev_id;
6364

6465
if (ctx.Attr<bool>("sync_mode")) {
6566
PADDLE_ENFORCE(cudaStreamSynchronize(stream));

paddle/fluid/platform/dynload/nccl.h

-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ extern void* nccl_dso_handle;
6262
__macro(ncclCommUserRank); \
6363
__macro(ncclAllReduce); \
6464
__macro(ncclBcast); \
65-
__macro(ncclBroadcast); \
6665
__macro(ncclAllGather); \
6766
__macro(ncclGroupStart); \
6867
__macro(ncclGroupEnd); \

0 commit comments

Comments
 (0)