Commit 5d641ba

Author: chengduozh
Parent: 1aedf4b

    use ncclBroadcast
    test=develop

2 files changed, 12 insertions(+), 12 deletions(-)

paddle/fluid/operators/distributed_ops/broadcast_op.cu.cc (+11 -12)

@@ -34,7 +34,8 @@ template <typename T>
 class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()));
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "The place of ExecutionContext should be CUDAPlace.");
 
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     int dev_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).device;
@@ -43,30 +44,28 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
     out->Resize(in->dims());
+    void* recv_buffer = out->mutable_data<T>(ctx.GetPlace());
+    const void* send_buffer = in->data<void>();
 
-    const int in_dev_id = boost::get<platform::CUDAPlace>(in->place()).device;
+    int in_dev_id = boost::get<platform::CUDAPlace>(in->place()).device;
     PADDLE_ENFORCE_EQ(dev_id, in_dev_id);
 
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     auto comm = dev_ctx.nccl_comm();
     auto stream = dev_ctx.stream();
 
-    void* send_recv_buffer = const_cast<void*>(in->data<void>());
-    if (root_dev_id != in_dev_id) {
-      send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
-    }
+    PADDLE_ENFORCE(platform::dynload::ncclBroadcast(
+        send_buffer, recv_buffer, static_cast<size_t>(in->numel()),
+        platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream));
 
-    VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", ("
-            << static_cast<size_t>(in->numel()) << ")"
+    VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", (" << in->numel() << ")"
             << " From " << root_dev_id << " to " << in_dev_id;
 
-    PADDLE_ENFORCE(platform::dynload::ncclBcast(
-        send_recv_buffer, static_cast<size_t>(in->numel()),
-        platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream));
-
     if (ctx.Attr<bool>("sync_mode")) {
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
     }
+#else
+    PADDLE_THROW("PaddlePaddle should compile with GPU.");
 #endif
   }
 };
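
For context: the core of this change is the switch from ncclBcast to ncclBroadcast. ncclBcast works in place on a single buffer, which is why the old kernel had to branch on root vs. non-root rank to decide whether to pass the input or the output tensor; ncclBroadcast takes separate send and receive pointers, so every rank can unconditionally send from the input and receive into the output. A minimal sketch of the two call shapes, using NCCL's public signatures (the function and variable names below are illustrative, not taken from the Paddle sources):

// Illustrative only: contrasts the in-place ncclBcast with the
// separate-buffer ncclBroadcast introduced in NCCL 2.2.
#include <cuda_runtime.h>
#include <nccl.h>

void broadcast_both_ways(ncclComm_t comm, cudaStream_t stream, int root,
                         const float* d_send, float* d_recv, size_t count) {
  // ncclBcast: one buffer per rank. On the root it holds the payload; on
  // every other rank it is overwritten with the received data, so callers
  // must pick the right buffer per rank themselves.
  ncclBcast(d_recv, count, ncclFloat, root, comm, stream);

  // ncclBroadcast: distinct send/recv pointers. sendbuff is only read on
  // the root rank, so non-root ranks need no special-casing and the input
  // can stay const, as in the new kernel above.
  ncclBroadcast(d_send, d_recv, count, ncclFloat, root, comm, stream);
}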

paddle/fluid/platform/dynload/nccl.h (+1 -0)

@@ -62,6 +62,7 @@ extern void* nccl_dso_handle;
   __macro(ncclCommUserRank);  \
   __macro(ncclAllReduce);     \
   __macro(ncclBcast);         \
+  __macro(ncclBroadcast);     \
   __macro(ncclAllGather);     \
   __macro(ncclGroupStart);    \
   __macro(ncclGroupEnd);      \
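
The nccl.h change registers the new symbol with Paddle's dynamic-loading layer, which is what makes platform::dynload::ncclBroadcast in the kernel above resolve to the function in the NCCL shared library at runtime. A simplified sketch of that pattern, under the assumption that the __macro list expands to one dlsym-backed wrapper per symbol (the helper below is a generic illustration, not Paddle's actual macro expansion):

// Hypothetical stand-in for one expansion of the __macro list: resolve the
// symbol from libnccl once, cache it, and forward all calls to it.
#include <dlfcn.h>
#include <nccl.h>

template <typename... Args>
ncclResult_t dyn_ncclBroadcast(Args... args) {
  using Fn = decltype(&ncclBroadcast);  // pointer type of the real function
  static void* handle = dlopen("libnccl.so", RTLD_NOW | RTLD_GLOBAL);
  static Fn fn = reinterpret_cast<Fn>(dlsym(handle, "ncclBroadcast"));
  return fn(args...);  // call the dynamically resolved ncclBroadcast
}

Without the new __macro(ncclBroadcast) entry, no such wrapper would be generated and the platform::dynload::ncclBroadcast call in broadcast_op.cu.cc would fail to compile.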
