@@ -43,23 +43,24 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
43
43
44
44
auto in = ctx.Input <framework::Tensor>(" X" );
45
45
auto out = ctx.Output <framework::Tensor>(" Out" );
46
- out->Resize (in->dims ());
47
- void * recv_buffer = out->mutable_data <T>(ctx.GetPlace ());
48
- const void * send_buffer = in->data <void >();
49
-
50
- int in_dev_id = boost::get<platform::CUDAPlace>(in->place ()).device ;
51
- PADDLE_ENFORCE_EQ (dev_id, in_dev_id);
46
+ PADDLE_ENFORCE (out->IsInitialized (),
47
+ " Currently, the output of broadcast op must be initialized, "
48
+ " because this op can only be an In-Place operation." );
49
+ void * send_recv_buffer = out->mutable_data <T>(ctx.GetPlace ());
50
+ PADDLE_ENFORCE_EQ (
51
+ send_recv_buffer, in->data <void >(),
52
+ " Currently, the broadcast op can only be an In-Place operation." );
52
53
53
54
auto & dev_ctx = ctx.template device_context <platform::CUDADeviceContext>();
54
55
auto comm = dev_ctx.nccl_comm ();
55
56
auto stream = dev_ctx.stream ();
56
57
57
- PADDLE_ENFORCE (platform::dynload::ncclBroadcast (
58
- send_buffer, recv_buffer , static_cast <size_t >(in->numel ()),
58
+ PADDLE_ENFORCE (platform::dynload::ncclBcast (
59
+ send_recv_buffer , static_cast <size_t >(in->numel ()),
59
60
platform::ToNCCLDataType (in->type ()), root_dev_id, comm, stream));
60
61
61
62
VLOG (3 ) << " Bcast " << ctx.Inputs (" X" )[0 ] << " , (" << in->numel () << " )"
62
- << " From " << root_dev_id << " to " << in_dev_id ;
63
+ << " From " << root_dev_id << " to " << dev_id ;
63
64
64
65
if (ctx.Attr <bool >(" sync_mode" )) {
65
66
PADDLE_ENFORCE (cudaStreamSynchronize (stream));
0 commit comments