@@ -34,7 +34,8 @@ template <typename T>
34
34
class NCCLBroadcastOpKernel : public framework ::OpKernel<T> {
35
35
public:
36
36
void Compute (const framework::ExecutionContext& ctx) const override {
37
- PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()));
37
+ PADDLE_ENFORCE (platform::is_gpu_place (ctx.GetPlace ()),
38
+ " The place of ExecutionContext should be CUDAPlace." );
38
39
39
40
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
40
41
int dev_id = boost::get<platform::CUDAPlace>(ctx.GetPlace ()).device ;
@@ -43,30 +44,28 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
43
44
auto in = ctx.Input <framework::Tensor>(" X" );
44
45
auto out = ctx.Output <framework::Tensor>(" Out" );
45
46
out->Resize (in->dims ());
47
+ void * recv_buffer = out->mutable_data <T>(ctx.GetPlace ());
48
+ const void * send_buffer = in->data <void >();
46
49
47
- const int in_dev_id = boost::get<platform::CUDAPlace>(in->place ()).device ;
50
+ int in_dev_id = boost::get<platform::CUDAPlace>(in->place ()).device ;
48
51
PADDLE_ENFORCE_EQ (dev_id, in_dev_id);
49
52
50
53
auto & dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
51
54
auto comm = dev_ctx.nccl_comm ();
52
55
auto stream = dev_ctx.stream ();
53
56
54
- void * send_recv_buffer = const_cast <void *>(in->data <void >());
55
- if (root_dev_id != in_dev_id) {
56
- send_recv_buffer = out->mutable_data <T>(ctx.GetPlace ());
57
- }
57
+ PADDLE_ENFORCE (platform::dynload::ncclBroadcast (
58
+ send_buffer, recv_buffer, static_cast <size_t >(in->numel ()),
59
+ platform::ToNCCLDataType (in->type ()), root_dev_id, comm, stream));
58
60
59
- VLOG (3 ) << " Bcast " << ctx.Inputs (" X" )[0 ] << " , ("
60
- << static_cast <size_t >(in->numel ()) << " )"
61
+ VLOG (3 ) << " Bcast " << ctx.Inputs (" X" )[0 ] << " , (" << in->numel () << " )"
61
62
<< " From " << root_dev_id << " to " << in_dev_id;
62
63
63
- PADDLE_ENFORCE (platform::dynload::ncclBcast (
64
- send_recv_buffer, static_cast <size_t >(in->numel ()),
65
- platform::ToNCCLDataType (in->type ()), root_dev_id, comm, stream));
66
-
67
64
if (ctx.Attr <bool >(" sync_mode" )) {
68
65
PADDLE_ENFORCE (cudaStreamSynchronize (stream));
69
66
}
67
+ #else
68
+ PADDLE_THROW (" PaddlePaddle should compile with GPU." );
70
69
#endif
71
70
}
72
71
};
0 commit comments