diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc
index fea5c6a1883f63..058a5d660fa803 100644
--- a/paddle/fluid/distributed/collective/process_group_custom.cc
+++ b/paddle/fluid/distributed/collective/process_group_custom.cc
@@ -215,7 +215,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
   return RunFnInXCCLEnv(
       [&](const phi::stream::Stream& stream) {
         auto comm_context = this->GetCommContext();
-        comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream);
+        comm_context->AllGather(
+            out_tensor, in_tensor_maybe_partial, stream.raw_stream());
       },
       in_tensor_maybe_partial,
       CommType::ALLGATHER,
@@ -239,7 +240,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
             out_tensor,
             in_tensor,
             paddle::distributed::ToXCCLRedType(opts.reduce_op),
-            stream);
+            stream.raw_stream());
       },
       in_tensor,
       CommType::ALLREDUCE,
@@ -315,7 +316,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
             rank_,
             size_,
             comm_context->GetXcclComm(),
-            stream);
+            stream.raw_stream());
       },
       in_tensor,
       CommType::ALLTOALL,
@@ -358,7 +359,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
       [&](const phi::stream::Stream& stream) {
         int root = opts.source_rank + opts.source_root;
         auto comm_context = this->GetCommContext();
-        comm_context->Broadcast(out_tensor, in_tensor, root, stream);
+        comm_context->Broadcast(
+            out_tensor, in_tensor, root, stream.raw_stream());
       },
       in_tensor,
       CommType::BROADCAST,
@@ -382,7 +384,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
             in_tensor,
             paddle::distributed::ToXCCLRedType(opts.reduce_op),
             opts.root_rank,
-            stream);
+            stream.raw_stream());
       },
       in_tensor,
       CommType::REDUCE,
@@ -406,7 +408,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::ReduceScatter(
             out_tensor,
             in_tensor,
             paddle::distributed::ToXCCLRedType(opts.reduce_op),
-            stream);
+            stream.raw_stream());
       },
       in_tensor,
       CommType::REDUCE_SCATTER,
@@ -441,7 +443,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
           for (auto i = 0; i < size_; i++) {
             partial_tensor = GetPartialTensor(in_tensor, offset, numel);
             if (i != rank_) {
-              comm_context->Send(partial_tensor, numel, i, stream);
+              comm_context->Send(partial_tensor, numel, i, stream.raw_stream());
             } else {
               phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
                   ->MemoryCopyD2D(out_tensor->data(),
@@ -452,7 +454,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
             offset += numel;
           }
         } else {
-          comm_context->Recv(out_tensor, numel, opts.root_rank, stream);
+          comm_context->Recv(
+              out_tensor, numel, opts.root_rank, stream.raw_stream());
         }
       },
       in_tensor,
@@ -506,7 +509,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
     for (auto i = 0; i < size_; i++) {
       auto& gather_tensor = gather_tensors[i];
       if (i != rank_) {
-        comm_context->Recv(&gather_tensor, gather_tensor.numel(), i, stream);
+        comm_context->Recv(
+            &gather_tensor, gather_tensor.numel(), i, stream.raw_stream());
       } else {
         phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
             ->MemoryCopyD2D(
@@ -518,7 +522,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
       }
    } else {
      // send to root
-      comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream);
+      comm_context->Send(
+          in_tensor, in_tensor.numel(), opts.root_rank, stream.raw_stream());
    }
  };
  return RunFnInXCCLEnv(
@@ -542,7 +547,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
   return RunFnInXCCLEnv(
       [&](const phi::stream::Stream& stream) {
         auto comm_context = this->GetCommContext();
-        comm_context->Recv(tensor, tensor->numel(), src_rank, stream);
+        comm_context->Recv(
+            tensor, tensor->numel(), src_rank, stream.raw_stream());
       },
       *tensor,
       CommType::RECV,
@@ -569,7 +575,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
         comm_context->Send(tensor_maybe_partial,
                            tensor_maybe_partial.numel(),
                            dst_rank,
-                           stream);
+                           stream.raw_stream());
       },
       tensor_maybe_partial,
       CommType::SEND,
@@ -915,7 +921,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
             &output,
             input,
             paddle::distributed::ToXCCLRedType(opts.reduce_op),
-            stream);
+            stream.raw_stream());
       },
       CommType::ALLREDUCE);
 }
@@ -942,7 +948,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
         const auto root =
             opts.source_rank * in_tensors.size() + opts.source_root;
         auto comm_context = this->GetCommContext();
-        comm_context->Broadcast(&output, input, root, stream);
+        comm_context->Broadcast(&output, input, root, stream.raw_stream());
       },
       CommType::BROADCAST);
 }
@@ -988,7 +994,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
           const phi::stream::Stream& stream,
           int dst_rank) {
         auto comm_context = this->GetCommContext();
-        comm_context->Send(input, input.numel(), dst_rank, stream);
+        comm_context->Send(input, input.numel(), dst_rank, stream.raw_stream());
       },
       dst_rank,
       CommType::SEND);
@@ -1008,7 +1014,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
           const phi::stream::Stream& stream,
           int src_rank) {
         auto comm_context = this->GetCommContext();
-        comm_context->Recv(&output, output.numel(), src_rank, stream);
+        comm_context->Recv(
+            &output, output.numel(), src_rank, stream.raw_stream());
       },
       src_rank,
       CommType::RECV);
@@ -1037,7 +1044,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
           const phi::ccl::CCLComm& comm,
           const phi::stream::Stream& stream) {
         auto comm_context = this->GetCommContext();
-        comm_context->AllGather(&output, input, stream);
+        comm_context->AllGather(&output, input, stream.raw_stream());
       },
       CommType::ALLGATHER);
 }
@@ -1089,7 +1096,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
             rank_,
             size_,
             comm_context->GetXcclComm(),
-            stream);
+            stream.raw_stream());
       },
       CommType::ALLTOALL);
 }
@@ -1166,7 +1173,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
             rank_,
             size_,
             comm_context->GetXcclComm(),
-            stream);
+            stream.raw_stream());
       },
       in_tensors,
       CommType::ALLTOALL,
@@ -1197,7 +1204,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
             input,
             paddle::distributed::ToXCCLRedType(opts.reduce_op),
             opts.root_rank,
-            stream);
+            stream.raw_stream());
       },
       CommType::REDUCE);
 }
@@ -1232,13 +1239,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
           for (auto i = 0; i < size_; i++) {
             auto input_data = reinterpret_cast(
                 GetPointerByOffset(input.data(), offset, input.dtype()));
-            comm_context->Send(*input_data, count, i, stream);
+            comm_context->Send(*input_data, count, i, stream.raw_stream());
             offset += count;
           }
-          comm_context->Recv(&output, count, opts.root_rank, stream);
+          comm_context->Recv(
+              &output, count, opts.root_rank, stream.raw_stream());
           comm_context->GroupEnd();
         } else {
-          comm_context->Recv(&output, count, opts.root_rank, stream);
+          comm_context->Recv(
+              &output, count, opts.root_rank, stream.raw_stream());
         }
       },
       CommType::SCATTER);
diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index 3c3a3fe6043074..c4393517b446e1 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -221,7 +221,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
             phi::DeviceContextPool::Instance().Get(place));      \
     phi::stream::Stream stream(place, ctx->stream());            \
     auto device = phi::DeviceManager::GetDeviceWithPlace(place); \
-    device->BlasAXPBY<T>(stream,                                 \
+    device->BlasAXPBY<T>(stream.raw_stream(),                    \
                          static_cast(numel),                     \
                          1.,                                     \
                          src_tensor.data<T>(),                   \
diff --git a/paddle/fluid/imperative/xccl_context.cc b/paddle/fluid/imperative/xccl_context.cc
index 47a558b3ec26fe..dde863b43d45f9 100644
--- a/paddle/fluid/imperative/xccl_context.cc
+++ b/paddle/fluid/imperative/xccl_context.cc
@@ -36,7 +36,7 @@ namespace imperative {
 
 static void XcclAllReduce(const phi::DenseTensor &src,
                           phi::DenseTensor *dst,
-                          const phi::stream::Stream &stream,
+                          const phi::stream::stream_t &stream,
                           const phi::ccl::CCLComm &comm) {
   const auto &place = src.place();
   PADDLE_ENFORCE_EQ(
@@ -171,7 +171,7 @@ void XCCLParallelContext::AllReduceByStream(const framework::Variable &src,
   platform::XCCLComm *comm =
       platform::XCCLCommContext::Instance(place.GetDeviceType())
           .Get(ring_id, place);
-  auto stream = use_calc_stream ? dev_ctx->GetStream() : comm->stream();
+  auto stream = use_calc_stream ? dev_ctx->stream() : comm->stream();
 
   if (src.IsType<phi::DenseTensor>()) {
     if (!dst->IsType<phi::DenseTensor>()) {
@@ -179,7 +179,7 @@ void XCCLParallelContext::AllReduceByStream(const framework::Variable &src,
     }
     XcclAllReduce(src.Get<phi::DenseTensor>(),
                   dst->GetMutable<phi::DenseTensor>(),
-                  *stream,
+                  stream,
                   comm->comm());
   } else {
     PADDLE_THROW(common::errors::InvalidArgument(
@@ -207,7 +207,7 @@ void XCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
                                  src_tensor->dtype(),
                                  0,
                                  comm->comm(),
-                                 *stream);
+                                 stream);
 }
 
 phi::DeviceContext *XCCLParallelContext::GetDeviceContext(int ring_id) {
@@ -235,7 +235,7 @@ void XCCLParallelContext::WaitCompute(int ring_id) {
                             ->GetStream();
   auto comm_stream = platform::XCCLCommContext::Instance(place_.GetDeviceType())
                          .Get(ring_id, place_)
-                         ->stream();
+                         ->GetStream();
   auto event = compute_events_[ring_id].get();
 
   // compute_stream-->event-->comm_stream
@@ -261,7 +261,7 @@ void XCCLParallelContext::WaitComm(int ring_id) {
                             ->GetStream();
   auto comm_stream = platform::XCCLCommContext::Instance(place_.GetDeviceType())
                          .Get(ring_id, place_)
-                         ->stream();
+                         ->GetStream();
   auto event = comm_events_[ring_id].get();
 
   // comm_stream-->event-->compute_stream
diff --git a/paddle/phi/backends/callback_manager.cc b/paddle/phi/backends/callback_manager.cc
index 0d658258fa4c05..6fde1a9e6d3564 100644
--- a/paddle/phi/backends/callback_manager.cc
+++ b/paddle/phi/backends/callback_manager.cc
@@ -41,7 +41,7 @@ void CallbackManager::AddCallback(std::function<void()> callback) const {
 void CallbackManager::Wait() const {
   phi::DeviceGuard guard(stream_->GetPlace());
   phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
-      ->SynchronizeStream(stream_);
+      ->SynchronizeStream(stream_->raw_stream());
 
   {
     std::lock_guard<std::mutex> lock(mtx_);
diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc
index 851e4b5f8293f5..a3c360caeab522 100644
--- a/paddle/phi/backends/custom/custom_device.cc
+++ b/paddle/phi/backends/custom/custom_device.cc
@@ -151,30 +151,29 @@ class CustomDevice : public DeviceInterface {
     stream->set_stream(c_stream);
   }
 
-  void DestroyStream(size_t dev_id, stream::Stream* stream) override {
+  void DestroyStream(size_t dev_id, stream::stream_t stream) override {
     if (pimpl_->destroy_stream) {
       const auto device = &devices_pool[dev_id];
-      PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->destroy_stream(
-          device, reinterpret_cast<C_Stream>(stream->raw_stream())));
+      PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
+          pimpl_->destroy_stream(device, reinterpret_cast<C_Stream>(stream)));
     }
   }
 
-  void SynchronizeStream(size_t dev_id, const stream::Stream* stream) override {
+  void SynchronizeStream(size_t dev_id, stream::stream_t stream) override {
     if (pimpl_->synchronize_stream) {
      const auto device = &devices_pool[dev_id];
       PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->synchronize_stream(
-          device, reinterpret_cast<C_Stream>(stream->raw_stream())));
+          device, reinterpret_cast<C_Stream>(stream)));
     }
   }
 
-  bool QueryStream(size_t dev_id, const stream::Stream* stream) override {
+  bool QueryStream(size_t dev_id, stream::stream_t stream) override {
     if (!pimpl_->query_stream) {
       SynchronizeStream(dev_id, stream);
       return true;
     } else {
       const auto device = &devices_pool[dev_id];
-      return pimpl_->query_stream(
-                 device, reinterpret_cast<C_Stream>(stream->raw_stream())) ==
+      return pimpl_->query_stream(device, reinterpret_cast<C_Stream>(stream)) ==
              C_SUCCESS;
     }
   }
@@ -732,16 +731,16 @@ class CustomDevice : public DeviceInterface {
                     phi::DataType data_type,
                     ccl::CCLReduceOp op,
                     const ccl::CCLComm& comm,
-                    const stream::Stream& stream) override {
+                    const stream::stream_t& stream) override {
     CHECK_PTR(pimpl_->xccl_all_reduce);
-    PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_reduce(
-        send_buf,
-        recv_buf,
-        count,
-        ToCDataType(data_type),
-        ToXCCLReduceOp(op),
-        reinterpret_cast<C_CCLComm>(comm),
-        reinterpret_cast<C_Stream>(stream.raw_stream())));
+    PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
+        pimpl_->xccl_all_reduce(send_buf,
+                                recv_buf,
+                                count,
+                                ToCDataType(data_type),
+                                ToXCCLReduceOp(op),
+                                reinterpret_cast<C_CCLComm>(comm),
+                                reinterpret_cast<C_Stream>(stream)));
   }
 
   void CCLBroadcast(void* buf,
@@ -749,15 +748,15 @@ class CustomDevice : public DeviceInterface {
                     phi::DataType data_type,
                     size_t root,
                     const ccl::CCLComm& comm,
-                    const stream::Stream& stream) override {
+                    const stream::stream_t& stream) override {
     CHECK_PTR(pimpl_->xccl_broadcast);
-    PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_broadcast(
-        buf,
-        count,
-        ToCDataType(data_type),
-        root,
-        reinterpret_cast<C_CCLComm>(comm),
-        reinterpret_cast<C_Stream>(stream.raw_stream())));
+    PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
+        pimpl_->xccl_broadcast(buf,
+                               count,
+                               ToCDataType(data_type),
+                               root,
+                               reinterpret_cast<C_CCLComm>(comm),
+                               reinterpret_cast<C_Stream>(stream)));
   }
 
   void CCLReduce(void* in_data,
@@ -767,7 +766,7 @@ class CustomDevice : public DeviceInterface {
                  ccl::CCLReduceOp reduce_op,
                  size_t root_id,
                  const ccl::CCLComm& comm,
-                 const stream::Stream& stream) override {
+                 const stream::stream_t& stream) override {
     CHECK_PTR(pimpl_->xccl_reduce);
     PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
         pimpl_->xccl_reduce(in_data,
@@ -777,7 +776,7 @@ class CustomDevice : public DeviceInterface {
                             ToXCCLReduceOp(reduce_op),
                             root_id,
                             reinterpret_cast<C_CCLComm>(comm),
-                            reinterpret_cast<C_Stream>(stream.raw_stream())));
+                            reinterpret_cast<C_Stream>(stream)));
   }
 
   void CCLAllGather(void* send_buf,
@@ -785,15 +784,15 @@ class CustomDevice : public DeviceInterface {
                     size_t count,
                     phi::DataType data_type,
                     const ccl::CCLComm& comm,
-                    const stream::Stream& stream) override {
+                    const stream::stream_t& stream) override {
     CHECK_PTR(pimpl_->xccl_all_gather);
-    PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_gather(
-        send_buf,
-        recv_buf,
-        count,
-        ToCDataType(data_type),
-        reinterpret_cast<C_CCLComm>(comm),
-        reinterpret_cast<C_Stream>(stream.raw_stream())));
+    PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
+        pimpl_->xccl_all_gather(send_buf,
+                                recv_buf,
+                                count,
+                                ToCDataType(data_type),
+                                reinterpret_cast<C_CCLComm>(comm),
+                                reinterpret_cast<C_Stream>(stream)));
   }
 
   void CCLReduceScatter(void* send_buf,
@@ -802,16 +801,16 @@ class CustomDevice : public DeviceInterface {
                         phi::DataType data_type,
                         ccl::CCLReduceOp reduce_op,
                         const ccl::CCLComm& comm,
-                        const stream::Stream& stream) override {
+                        const stream::stream_t& stream) override {
     CHECK_PTR(pimpl_->xccl_reduce_scatter);
-    PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_reduce_scatter(
-        send_buf,
-        recv_buf,
-        count,
-        ToCDataType(data_type),
-        ToXCCLReduceOp(reduce_op),
-        reinterpret_cast<C_CCLComm>(comm),
-        reinterpret_cast<C_Stream>(stream.raw_stream())));
+    PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
+        pimpl_->xccl_reduce_scatter(send_buf,
+                                    recv_buf,
+                                    count,
+                                    ToCDataType(data_type),
+                                    ToXCCLReduceOp(reduce_op),
+                                    reinterpret_cast<C_CCLComm>(comm),
+                                    reinterpret_cast<C_Stream>(stream)));
   }
 
   void CCLGroupStart() override {
@@ -831,7 +830,7 @@ class CustomDevice : public DeviceInterface {
                phi::DataType data_type,
                size_t dest_rank,
                const ccl::CCLComm& comm,
-               const stream::Stream& stream) override {
+               const stream::stream_t& stream) override {
     CHECK_PTR(pimpl_->xccl_send);
     PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
         pimpl_->xccl_send(send_buf,
@@ -839,7 +838,7 @@ class CustomDevice : public DeviceInterface {
                           ToCDataType(data_type),
                           dest_rank,
                           reinterpret_cast<C_CCLComm>(comm),
-                          reinterpret_cast<C_Stream>(stream.raw_stream())));
+                          reinterpret_cast<C_Stream>(stream)));
   }
 
   void CCLRecv(void* recv_buf,
@@ -847,7 +846,7 @@ class CustomDevice : public DeviceInterface {
               phi::DataType data_type,
               size_t src_rank,
               const ccl::CCLComm& comm,
-              const stream::Stream& stream) override {
+              const stream::stream_t& stream) override {
     CHECK_PTR(pimpl_->xccl_recv);
     PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
         pimpl_->xccl_recv(recv_buf,
@@ -855,7 +854,7 @@ class CustomDevice : public DeviceInterface {
                           ToCDataType(data_type),
                           src_rank,
                           reinterpret_cast<C_CCLComm>(comm),
-                          reinterpret_cast<C_Stream>(stream.raw_stream())));
+                          reinterpret_cast<C_Stream>(stream)));
   }
 
   void CCLAllToAll(const void** send_buf,
@@ -867,24 +866,24 @@ class CustomDevice : public DeviceInterface {
                    size_t rank,
                    size_t nranks,
                    const ccl::CCLComm& comm,
-                   const stream::Stream& stream) override {
+                   const stream::stream_t& stream) override {
     if (pimpl_->xccl_all_to_all) {
       std::vector<C_DataType> c_send_dtype, c_recv_dtype;
       for (size_t i = 0; i < nranks; ++i) {
         c_send_dtype.push_back(ToCDataType(send_dtype[i]));
         c_recv_dtype.push_back(ToCDataType(recv_dtype[i]));
       }
-      PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_to_all(
-          send_buf,
-          send_count,
-          c_send_dtype.data(),
-          recv_buf,
-          recv_count,
-          c_recv_dtype.data(),
-          rank,
-          nranks,
-          reinterpret_cast<C_CCLComm>(comm),
-          reinterpret_cast<C_Stream>(stream.raw_stream())));
+      PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
+          pimpl_->xccl_all_to_all(send_buf,
+                                  send_count,
+                                  c_send_dtype.data(),
+                                  recv_buf,
+                                  recv_count,
+                                  c_recv_dtype.data(),
+                                  rank,
+                                  nranks,
+                                  reinterpret_cast<C_CCLComm>(comm),
+                                  reinterpret_cast<C_Stream>(stream)));
     } else if (pimpl_->xccl_send && pimpl_->xccl_recv) {
       // NOTE(wangran16): fallback to send and recv, while avoiding some devices
       // not supporting asynchronous send and recv.
@@ -895,24 +894,26 @@ class CustomDevice : public DeviceInterface {
                               ToCDataType(recv_dtype[i]),
                               i,
                               reinterpret_cast<C_CCLComm>(comm),
-                              reinterpret_cast<C_Stream>(stream.raw_stream())));
+                              reinterpret_cast<C_Stream>(stream)));
       }
       for (size_t i = 0; i < nranks; ++i) {
         if (i != rank) {
-          PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_send(
-              const_cast<void*>(send_buf[i]),
-              send_count[i],
-              ToCDataType(send_dtype[i]),
-              i,
-              reinterpret_cast<C_CCLComm>(comm),
-              reinterpret_cast<C_Stream>(stream.raw_stream())));
+          PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
+              pimpl_->xccl_send(const_cast<void*>(send_buf[i]),
+                                send_count[i],
+                                ToCDataType(send_dtype[i]),
+                                i,
+                                reinterpret_cast<C_CCLComm>(comm),
+                                reinterpret_cast<C_Stream>(stream)));
        }
      }
+      const phi::stream::Stream stream_wrapper(
+          Place(AllocationType::CUSTOM, Type()), stream);
      MemoryCopyD2D(rank,
                    recv_buf[rank],
                    send_buf[rank],
                    send_count[rank] * phi::SizeOf(send_dtype[rank]),
-                    &stream);
+                    &stream_wrapper);
      for (size_t i = rank + 1; i < nranks; ++i) {
        PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
            pimpl_->xccl_recv(recv_buf[i],
@@ -920,7 +921,7 @@ class CustomDevice : public DeviceInterface {
                              ToCDataType(recv_dtype[i]),
                              i,
                              reinterpret_cast<C_CCLComm>(comm),
-                              reinterpret_cast<C_Stream>(stream.raw_stream())));
+                              reinterpret_cast<C_Stream>(stream)));
      }
    } else {
      PADDLE_THROW(common::errors::Unavailable(
   }
 }
 
   void BlasAXPBY(size_t dev_id,
-                 const stream::Stream& stream,
+                 const stream::stream_t& stream,
                  phi::DataType dtype,
                  size_t numel,
                  float alpha,
@@ -940,7 +941,7 @@ class CustomDevice : public DeviceInterface {
     const auto device = &devices_pool[dev_id];
     PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
         pimpl_->blas_axpby(device,
-                           reinterpret_cast<C_Stream>(stream.raw_stream()),
+                           reinterpret_cast<C_Stream>(stream),
                            ToCDataType(dtype),
                            numel,
                            alpha,
diff --git a/paddle/phi/backends/device_base.cc b/paddle/phi/backends/device_base.cc
index d55c2e26493789..4174edb0318c72 100644
--- a/paddle/phi/backends/device_base.cc
+++ b/paddle/phi/backends/device_base.cc
@@ -103,16 +103,16 @@ void DeviceInterface::CreateStream(size_t dev_id,
   INTERFACE_UNIMPLEMENT;
 }
 
-void DeviceInterface::DestroyStream(size_t dev_id, stream::Stream* stream) {
+void DeviceInterface::DestroyStream(size_t dev_id, stream::stream_t stream) {
   INTERFACE_UNIMPLEMENT;
 }
 
 void DeviceInterface::SynchronizeStream(size_t dev_id,
-                                        const stream::Stream* stream) {
+                                        stream::stream_t stream) {
   INTERFACE_UNIMPLEMENT;
 }
 
-bool DeviceInterface::QueryStream(size_t dev_id, const stream::Stream* stream) {
+bool DeviceInterface::QueryStream(size_t dev_id, stream::stream_t stream) {
   INTERFACE_UNIMPLEMENT;
   return true;
 }
@@ -323,7 +323,7 @@ void DeviceInterface::CCLBroadcast(void* data,
                                    phi::DataType data_type,
                                    size_t root,
                                    const ccl::CCLComm& ccl_comm,
-                                   const stream::Stream& stream) {
+                                   const stream::stream_t& stream) {
   INTERFACE_UNIMPLEMENT;
 }
 
@@ -333,7 +333,7 @@ void DeviceInterface::CCLAllReduce(void* in_data,
                                    phi::DataType data_type,
                                    ccl::CCLReduceOp reduce_op,
                                    const ccl::CCLComm& ccl_comm,
-                                   const stream::Stream& stream) {
+                                   const stream::stream_t& stream) {
   INTERFACE_UNIMPLEMENT;
 }
 
@@ -344,7 +344,7 @@ void DeviceInterface::CCLReduce(void* in_data,
                                 ccl::CCLReduceOp reduce_op,
                                 size_t root_id,
                                 const ccl::CCLComm& ccl_comm,
-                                const stream::Stream& stream) {
+                                const stream::stream_t& stream) {
   INTERFACE_UNIMPLEMENT;
 }
 
@@ -353,7 +353,7 @@ void DeviceInterface::CCLAllGather(void* in_data,
                                    size_t num,
                                    phi::DataType data_type,
                                    const ccl::CCLComm& ccl_comm,
-                                   const stream::Stream& stream) {
+                                   const stream::stream_t& stream) {
   INTERFACE_UNIMPLEMENT;
 }
 
@@ -363,7 +363,7 @@ void DeviceInterface::CCLReduceScatter(void* in_data,
                                        phi::DataType data_type,
                                        ccl::CCLReduceOp op,
                                        const ccl::CCLComm& ccl_comm,
-                                       const stream::Stream& stream) {
+                                       const stream::stream_t& stream) {
   INTERFACE_UNIMPLEMENT;
 }
 
@@ -376,7 +376,7 @@ void DeviceInterface::CCLSend(void* sendbuf,
                               phi::DataType data_type,
                              size_t dst_rank,
                              const ccl::CCLComm& ccl_comm,
-                              const stream::Stream& stream) {
+                              const stream::stream_t& stream) {
   INTERFACE_UNIMPLEMENT;
 }
 
@@ -385,7 +385,7 @@ void DeviceInterface::CCLRecv(void* recvbuf,
                              phi::DataType data_type,
                              size_t src_rank,
                              const ccl::CCLComm& ccl_comm,
-                              const stream::Stream& stream) {
+                              const stream::stream_t& stream) {
   INTERFACE_UNIMPLEMENT;
 }
 
@@ -398,13 +398,13 @@ void DeviceInterface::CCLAllToAll(const void** send_buf,
                                   size_t rank,
                                   size_t nranks,
                                   const ccl::CCLComm& comm,
-                                  const stream::Stream& stream) {
+                                  const stream::stream_t& stream) {
   INTERFACE_UNIMPLEMENT;
 }
 
 // blas
 void DeviceInterface::BlasAXPBY(size_t dev_id,
-                                const stream::Stream& stream,
+                                const stream::stream_t& stream,
                                 phi::DataType dtype,
                                 size_t numel,
                                 float alpha,
diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h
index dad931ff27f3ce..8626ab708e98c6 100644
--- a/paddle/phi/backends/device_base.h
+++ b/paddle/phi/backends/device_base.h
@@ -92,13 +92,13 @@ class DeviceInterface {  // Driver / Runtime
       const stream::Stream::Flag& flag = stream::Stream::Flag::kDefaultFlag);
 
   // ! Destroys an asynchronous stream.
-  virtual void DestroyStream(size_t dev_id, stream::Stream* stream);
+  virtual void DestroyStream(size_t dev_id, stream::stream_t stream);
 
   // ! Waits for stream tasks to complete.
-  virtual void SynchronizeStream(size_t dev_id, const stream::Stream* stream);
+  virtual void SynchronizeStream(size_t dev_id, stream::stream_t stream);
 
   // ! Queries an asynchronous stream for completion status.
-  virtual bool QueryStream(size_t dev_id, const stream::Stream* stream);
+  virtual bool QueryStream(size_t dev_id, stream::stream_t stream);
 
   // ! Add a callback to a compute stream.
   virtual void AddCallback(size_t dev_id,
@@ -201,7 +201,7 @@ class DeviceInterface {  // Driver / Runtime
                             phi::DataType data_type,
                             size_t root,
                             const ccl::CCLComm& ccl_comm,
-                            const stream::Stream& stream);
+                            const stream::stream_t& stream);
 
   virtual void CCLAllReduce(void* in_data,
                             void* out_data,
@@ -209,7 +209,7 @@ class DeviceInterface {  // Driver / Runtime
                             phi::DataType data_type,
                             ccl::CCLReduceOp reduce_op,
                             const ccl::CCLComm& ccl_comm,
-                            const stream::Stream& stream);
+                            const stream::stream_t& stream);
   virtual void CCLReduce(void* in_data,
                          void* out_data,
                          size_t num,
@@ -217,20 +217,20 @@ class DeviceInterface {  // Driver / Runtime
                          ccl::CCLReduceOp reduce_op,
                          size_t root_id,
                          const ccl::CCLComm& ccl_comm,
-                         const stream::Stream& stream);
+                         const stream::stream_t& stream);
   virtual void CCLAllGather(void* in_data,
                             void* out_data,
                             size_t num,
                             phi::DataType data_type,
                             const ccl::CCLComm& ccl_comm,
-                            const stream::Stream& stream);
+                            const stream::stream_t& stream);
   virtual void CCLReduceScatter(void* in_data,
                                 void* out_data,
                                 size_t num,
                                 phi::DataType data_type,
                                 ccl::CCLReduceOp op,
                                 const ccl::CCLComm& ccl_comm,
-                                const stream::Stream& stream);
+                                const stream::stream_t& stream);
   virtual void CCLGroupStart();
   virtual void CCLGroupEnd();
   virtual void CCLSend(void* sendbuf,
@@ -238,13 +238,13 @@ class DeviceInterface {  // Driver / Runtime
                        phi::DataType data_type,
                        size_t dst_rank,
                        const ccl::CCLComm& ccl_comm,
-                       const stream::Stream& stream);
+                       const stream::stream_t& stream);
   virtual void CCLRecv(void* recvbuf,
                        size_t num,
                        phi::DataType data_type,
                        size_t src_rank,
                        const ccl::CCLComm& ccl_comm,
-                       const stream::Stream& stream);
+                       const stream::stream_t& stream);
 
   virtual void CCLAllToAll(const void** send_buf,
                            const size_t* send_count,
@@ -255,10 +255,10 @@ class DeviceInterface {  // Driver / Runtime
                            size_t rank,
                            size_t nranks,
                            const ccl::CCLComm& comm,
-                           const stream::Stream& stream);
+                           const stream::stream_t& stream);
   // blas
   virtual void BlasAXPBY(size_t dev_id,
-                         const stream::Stream& stream,
+                         const stream::stream_t& stream,
                          phi::DataType dtype,
                          size_t numel,
                          float alpha,
diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc
index daebb5d1aab584..f3cfa269537533 100644
--- a/paddle/phi/backends/device_manager.cc
+++ b/paddle/phi/backends/device_manager.cc
@@ -50,17 +50,17 @@ void Device::CreateStream(stream::Stream* stream,
   impl_->CreateStream(dev_id_, stream, priority, flag);
 }
 
-void Device::DestroyStream(stream::Stream* stream) {
+void Device::DestroyStream(stream::stream_t stream) {
   CheckInitialized();
   impl_->DestroyStream(dev_id_, stream);
 }
 
-void Device::SynchronizeStream(const stream::Stream* stream) {
+void Device::SynchronizeStream(stream::stream_t stream) {
   CheckInitialized();
   impl_->SynchronizeStream(dev_id_, stream);
 }
 
-bool Device::QueryStream(const stream::Stream* stream) {
+bool Device::QueryStream(stream::stream_t stream) {
   CheckInitialized();
   return impl_->QueryStream(dev_id_, stream);
 }
@@ -172,7 +172,7 @@ void Device::MemorySet(void* ptr, uint8_t value, size_t size) {
 }
 
 template <typename T>
-void Device::BlasAXPBY(const stream::Stream& stream,
+void Device::BlasAXPBY(const stream::stream_t& stream,
                        size_t numel,
                        float alpha,
                        const T* x,
@@ -189,57 +189,57 @@ void Device::BlasAXPBY(const stream::Stream& stream,
                     reinterpret_cast(y));
 }
 
-template void Device::BlasAXPBY<paddle::float16>(const stream::Stream& stream,
+template void Device::BlasAXPBY<paddle::float16>(const stream::stream_t& stream,
                                                  size_t numel,
                                                  float alpha,
                                                  const paddle::float16* x,
                                                  float beta,
                                                  paddle::float16* y);
-template void Device::BlasAXPBY<float>(const stream::Stream& stream,
+template void Device::BlasAXPBY<float>(const stream::stream_t& stream,
                                        size_t numel,
                                        float alpha,
                                        const float* x,
                                        float beta,
                                        float* y);
-template void Device::BlasAXPBY<double>(const stream::Stream& stream,
+template void Device::BlasAXPBY<double>(const stream::stream_t& stream,
                                         size_t numel,
                                         float alpha,
                                         const double* x,
                                         float beta,
                                         double* y);
-template void Device::BlasAXPBY<int8_t>(const stream::Stream& stream,
+template void Device::BlasAXPBY<int8_t>(const stream::stream_t& stream,
                                         size_t numel,
                                         float alpha,
                                         const int8_t* x,
                                         float beta,
                                         int8_t* y);
-template void Device::BlasAXPBY<int16_t>(const stream::Stream& stream,
+template void Device::BlasAXPBY<int16_t>(const stream::stream_t& stream,
                                          size_t numel,
                                          float alpha,
                                          const int16_t* x,
                                          float beta,
                                          int16_t* y);
-template void Device::BlasAXPBY<int32_t>(const stream::Stream& stream,
+template void Device::BlasAXPBY<int32_t>(const stream::stream_t& stream,
                                          size_t numel,
                                          float alpha,
                                          const int32_t* x,
                                          float beta,
                                          int32_t* y);
-template void Device::BlasAXPBY<int64_t>(const stream::Stream& stream,
+template void Device::BlasAXPBY<int64_t>(const stream::stream_t& stream,
                                          size_t numel,
                                          float alpha,
                                          const int64_t* x,
                                          float beta,
                                          int64_t* y);
 template void Device::BlasAXPBY<phi::dtype::complex<float>>(
-    const stream::Stream& stream,
+    const stream::stream_t& stream,
     size_t numel,
     float alpha,
     const phi::dtype::complex<float>* x,
     float beta,
     phi::dtype::complex<float>* y);
 template void Device::BlasAXPBY<phi::dtype::complex<double>>(
-    const stream::Stream& stream,
+    const stream::stream_t& stream,
     size_t numel,
     float alpha,
     const phi::dtype::complex<double>* x,
@@ -608,7 +608,7 @@ void DeviceManager::CCLBroadcast(const std::string& device_type,
                                  phi::DataType data_type,
                                  size_t root_id,
                                  const ccl::CCLComm& ccl_comm,
-                                 const stream::Stream& stream) {
+                                 const stream::stream_t& stream) {
   auto dev_impl = GetDeviceInterfaceWithType(device_type);
   dev_impl->CCLBroadcast(data, num, data_type, root_id, ccl_comm, stream);
 }
@@ -620,7 +620,7 @@ void DeviceManager::CCLAllReduce(const std::string& device_type,
                                  phi::DataType data_type,
                                  ccl::CCLReduceOp reduce_op,
                                  const ccl::CCLComm& ccl_comm,
-                                 const stream::Stream& stream) {
+                                 const stream::stream_t& stream) {
   auto dev_impl = GetDeviceInterfaceWithType(device_type);
   dev_impl->CCLAllReduce(
       in_data, out_data, num, data_type, reduce_op, ccl_comm, stream);
@@ -634,7 +634,7 @@ void DeviceManager::CCLReduce(const std::string& device_type,
                               ccl::CCLReduceOp reduce_op,
                               size_t root_id,
                               const ccl::CCLComm& ccl_comm,
-                              const stream::Stream& stream) {
+                              const stream::stream_t& stream) {
   auto dev_impl = GetDeviceInterfaceWithType(device_type);
   dev_impl->CCLReduce(
       in_data, out_data, num, data_type, reduce_op, root_id, ccl_comm, stream);
@@ -646,7 +646,7 @@ void DeviceManager::CCLAllGather(const std::string& device_type,
                                  size_t num,
                                  phi::DataType data_type,
                                  const ccl::CCLComm& ccl_comm,
-                                 const stream::Stream& stream) {
+                                 const stream::stream_t& stream) {
   auto dev_impl = GetDeviceInterfaceWithType(device_type);
   dev_impl->CCLAllGather(in_data, out_data, num, data_type, ccl_comm, stream);
 }
@@ -658,7 +658,7 @@ void DeviceManager::CCLReduceScatter(const std::string& device_type,
                                      phi::DataType data_type,
                                      ccl::CCLReduceOp op,
                                      const ccl::CCLComm& ccl_comm,
-                                     const stream::Stream& stream) {
+                                     const stream::stream_t& stream) {
   auto dev_impl = GetDeviceInterfaceWithType(device_type);
   dev_impl->CCLReduceScatter(
       in_data, out_data, num, data_type, op, ccl_comm, stream);
@@ -680,7 +680,7 @@ void DeviceManager::CCLSend(const std::string& device_type,
                             phi::DataType data_type,
                             size_t dst_rank,
                             const ccl::CCLComm& ccl_comm,
-                            const stream::Stream& stream) {
+                            const stream::stream_t& stream) {
   auto dev_impl = GetDeviceInterfaceWithType(device_type);
   dev_impl->CCLSend(sendbuf, num, data_type, dst_rank, ccl_comm, stream);
 }
@@ -691,7 +691,7 @@ void DeviceManager::CCLRecv(const std::string& device_type,
                             phi::DataType data_type,
                             size_t src_rank,
                             const ccl::CCLComm& ccl_comm,
-                            const stream::Stream& stream) {
+                            const stream::stream_t& stream) {
   auto dev_impl = GetDeviceInterfaceWithType(device_type);
   dev_impl->CCLRecv(recvbuf, num, data_type, src_rank, ccl_comm, stream);
 }
@@ -706,7 +706,7 @@ void DeviceManager::CCLAllToAll(const std::string& device_type,
                                 size_t rank,
                                 size_t nranks,
                                 const ccl::CCLComm& comm,
-                                const stream::Stream& stream) {
+                                const stream::stream_t& stream) {
   auto dev_impl = GetDeviceInterfaceWithType(device_type);
   dev_impl->CCLAllToAll(send_buf,
                         send_count,
diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h
index d89356b77e7fab..3b90daf4339087 100644
--- a/paddle/phi/backends/device_manager.h
+++ b/paddle/phi/backends/device_manager.h
@@ -46,13 +46,13 @@ class Device final {
       const stream::Stream::Flag& flag = stream::Stream::Flag::kDefaultFlag);
 
   // ! Destroys an asynchronous stream.
-  void DestroyStream(stream::Stream* stream);
+  void DestroyStream(stream::stream_t stream);
 
   // ! Waits for stream tasks to complete.
-  void SynchronizeStream(const stream::Stream* stream);
+  void SynchronizeStream(stream::stream_t stream);
 
   // ! Queries an asynchronous stream for completion status.
-  bool QueryStream(const stream::Stream* stream);
+  bool QueryStream(stream::stream_t stream);
 
   // ! Add a callback to a compute stream.
   void AddCallback(stream::Stream* stream, stream::Stream::Callback* callback);
@@ -116,7 +116,7 @@ class Device final {
   // Blas
   // ! y = alpha * x + beta * y
   template <typename T>
-  void BlasAXPBY(const stream::Stream& stream,
+  void BlasAXPBY(const stream::stream_t& stream,
                  size_t numel,
                  float alpha,
                  const T* x,
@@ -217,7 +217,7 @@ class DeviceManager {
                            phi::DataType data_type,
                            size_t root,
                            const ccl::CCLComm& ccl_comm,
-                           const stream::Stream& stream);
+                           const stream::stream_t& stream);
   static void CCLAllReduce(const std::string& device_type,
                            void* in_data,
                            void* out_data,
@@ -225,7 +225,7 @@ class DeviceManager {
                            phi::DataType data_type,
                            ccl::CCLReduceOp reduce_op,
                            const ccl::CCLComm& ccl_comm,
-                           const stream::Stream& stream);
+                           const stream::stream_t& stream);
   static void CCLReduce(const std::string& device_type,
                         void* in_data,
                         void* out_data,
@@ -234,14 +234,14 @@ class DeviceManager {
                         ccl::CCLReduceOp reduce_op,
                         size_t root_id,
                         const ccl::CCLComm& ccl_comm,
-                        const stream::Stream& stream);
+                        const stream::stream_t& stream);
   static void CCLAllGather(const std::string& device_type,
                            void* in_data,
                            void* out_data,
                            size_t num,
                            phi::DataType data_type,
                            const ccl::CCLComm& ccl_comm,
-                           const stream::Stream& stream);
+                           const stream::stream_t& stream);
   static void CCLReduceScatter(const std::string& device_type,
                                void* in_data,
                                void* out_data,
@@ -249,7 +249,7 @@ class DeviceManager {
                                phi::DataType data_type,
                                ccl::CCLReduceOp op,
                                const ccl::CCLComm& ccl_comm,
-                               const stream::Stream& stream);
+                               const stream::stream_t& stream);
   static void CCLGroupStart(const std::string& device_type);
   static void CCLGroupEnd(const std::string& device_type);
   static void CCLSend(const std::string& device_type,
@@ -258,14 +258,14 @@ class DeviceManager {
                       phi::DataType data_type,
                       size_t dst_rank,
                       const ccl::CCLComm& ccl_comm,
-                      const stream::Stream& stream);
+                      const stream::stream_t& stream);
   static void CCLRecv(const std::string& device_type,
                       void* recvbuf,
                       size_t num,
                       phi::DataType data_type,
                       size_t src_rank,
                       const ccl::CCLComm& ccl_comm,
-                      const stream::Stream& stream);
+                      const stream::stream_t& stream);
 
   static void CCLAllToAll(const std::string& device_type,
                           const void** send_buf,
@@ -277,7 +277,7 @@ class DeviceManager {
                           size_t rank,
                           size_t nranks,
                           const ccl::CCLComm& comm,
-                          const stream::Stream& stream);
+                          const stream::stream_t& stream);
 
   // profiler
   static void ProfilerInitialize(const std::string& dev_type,
                                  phi::TraceEventCollector* collector,
diff --git a/paddle/phi/backends/stream.cc b/paddle/phi/backends/stream.cc
index c3efabdc0954bd..a146d904e47e01 100644
--- a/paddle/phi/backends/stream.cc
+++ b/paddle/phi/backends/stream.cc
@@ -87,10 +87,10 @@ void Stream::WaitEvent(event::Event* event) const {
 
 void Stream::Wait() const {
 #if !defined(_WIN32)
-  device_->SynchronizeStream(this);
+  device_->SynchronizeStream(this->raw_stream());
 #else
   while (1) {
-    if (device_->QueryStream(this)) {
+    if (device_->QueryStream(this->raw_stream())) {
       break;
     }
   }
@@ -104,7 +104,7 @@ void Stream::Destroy() {
     if (own_data_ &&
         phi::DeviceManager::HasDeviceType(place_.GetDeviceType())) {
       phi::DeviceManager::SetDevice(place_);
-      device_->DestroyStream(this);
+      device_->DestroyStream(this->raw_stream());
     }
     own_data_ = false;
     stream_ = nullptr;
@@ -112,9 +112,11 @@ void Stream::Destroy() {
   }
 }
 
-bool Stream::Query() const { return device_->QueryStream(this); }
+bool Stream::Query() const { return device_->QueryStream(this->raw_stream()); }
 
-void Stream::Synchronize() const { device_->SynchronizeStream(this); }
+void Stream::Synchronize() const {
+  device_->SynchronizeStream(this->raw_stream());
+}
 
 const Place& Stream::GetPlace() const { return place_; }
diff --git a/paddle/phi/core/distributed/xccl_comm_context.cc b/paddle/phi/core/distributed/xccl_comm_context.cc
index 4dd2bcc48857c3..d64500621847e6 100644
--- a/paddle/phi/core/distributed/xccl_comm_context.cc
+++ b/paddle/phi/core/distributed/xccl_comm_context.cc
@@ -70,7 +70,7 @@ XCCLCommContext::XCCLCommContext(const phi::Place& place,
 void XCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
                                 const phi::DenseTensor& in_tensor,
                                 int root,
-                                const phi::stream::Stream& stream) const {
+                                const phi::stream::stream_t& stream) const {
   CommStaticCheck::SameShape(*out_tensor,
                              in_tensor,
                              /*dst_rank*/ rank_,
@@ -98,7 +98,7 @@ void XCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
 
 void XCCLCommContext::AllGather(phi::DenseTensor* out_tensor,
                                 const phi::DenseTensor& in_tensor,
-                                const phi::stream::Stream& stream) const {
+                                const phi::stream::stream_t& stream) const {
   phi::distributed::CommStaticCheck::GatherLikeShape(
       *out_tensor,
       in_tensor,
@@ -117,7 +117,7 @@ void XCCLCommContext::AllGather(phi::DenseTensor* out_tensor,
 void XCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor,
                                     const phi::DenseTensor& in_tensor,
                                     phi::ccl::CCLReduceOp reduce_type,
-                                    const phi::stream::Stream& stream) const {
+                                    const phi::stream::stream_t& stream) const {
   phi::distributed::CommStaticCheck::ScatterLikeShape(
       *out_tensor,
       in_tensor,
@@ -138,7 +138,7 @@ void XCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor,
 void XCCLCommContext::Send(const phi::DenseTensor& in_tensor,
                            const int64_t& count,
                            const int& peer,
-                           const phi::stream::Stream& stream) const {
+                           const phi::stream::stream_t& stream) const {
   phi::distributed::CommStaticCheck::CheckShape(
       in_tensor, rank_, size_, phi::AllocationType::CUSTOM);
   phi::DeviceManager::CCLSend(place_.GetDeviceType(),
@@ -155,7 +155,7 @@ void XCCLCommContext::Send(const phi::DenseTensor& in_tensor,
 void XCCLCommContext::Recv(phi::DenseTensor* out_tensor,
                            const int64_t& count,
                            const int& peer,
-                           const phi::stream::Stream& stream) const {
+                           const phi::stream::stream_t& stream) const {
   phi::distributed::CommStaticCheck::CheckShape(
       *out_tensor, rank_, size_, phi::AllocationType::CUSTOM);
   phi::DeviceManager::CCLRecv(place_.GetDeviceType(),
@@ -172,7 +172,7 @@ void XCCLCommContext::Recv(phi::DenseTensor* out_tensor,
 void XCCLCommContext::AllReduce(phi::DenseTensor* out_tensor,
                                 const phi::DenseTensor& in_tensor,
                                 phi::ccl::CCLReduceOp reduce_type,
-                                const phi::stream::Stream& stream) const {
+                                const phi::stream::stream_t stream) const {
   phi::distributed::CommStaticCheck::SameShape(*out_tensor,
                                                in_tensor,
                                                /*dst_rank*/ rank_,
@@ -193,7 +193,7 @@ void XCCLCommContext::Reduce(phi::DenseTensor* out_tensor,
                              const phi::DenseTensor& in_tensor,
                              phi::ccl::CCLReduceOp reduce_type,
                              int root,
-                             const phi::stream::Stream& stream) const {
+                             const phi::stream::stream_t& stream) const {
   phi::distributed::CommStaticCheck::SameShape(*out_tensor,
                                                in_tensor,
                                                /*dst_rank*/ root,
diff --git a/paddle/phi/core/distributed/xccl_comm_context.h b/paddle/phi/core/distributed/xccl_comm_context.h
index ba739c26b464d0..cc4e4d87b7d5c2 100644
--- a/paddle/phi/core/distributed/xccl_comm_context.h
+++ b/paddle/phi/core/distributed/xccl_comm_context.h
@@ -36,6 +36,7 @@ class XCCLCommContext final : public CommContext {
   ccl::CCLComm GetXcclComm() const { return xccl_comm_; }
 
   std::shared_ptr<phi::stream::Stream> GetStream() const { return stream_; }
+  phi::stream::stream_t stream() const { return stream_->raw_stream(); }
 
   std::string GetDeviceType() const { return place_.GetDeviceType(); }
 
@@ -48,37 +49,37 @@ class XCCLCommContext final : public CommContext {
   void Broadcast(phi::DenseTensor* out_tensor,
                  const phi::DenseTensor& in_tensor,
                  int root,
-                 const phi::stream::Stream& stream) const;
+                 const phi::stream::stream_t& stream) const;
 
   void Send(const phi::DenseTensor& in_tensor,
             const int64_t& count,
             const int& peer,
-            const phi::stream::Stream& stream) const;
+            const phi::stream::stream_t& stream) const;
 
   void Recv(phi::DenseTensor* out_tensor,
             const int64_t& count,
             const int& peer,
-            const phi::stream::Stream& stream) const;
+            const phi::stream::stream_t& stream) const;
 
   void ReduceScatter(phi::DenseTensor* out_tensor,
                      const phi::DenseTensor& in_tensor,
                      phi::ccl::CCLReduceOp reduce_type,
-                     const phi::stream::Stream& stream) const;
+                     const phi::stream::stream_t& stream) const;
 
   void AllGather(phi::DenseTensor* out_tensor,
                  const phi::DenseTensor& in_tensor,
-                 const phi::stream::Stream& stream) const;
+                 const phi::stream::stream_t& stream) const;
 
   void AllReduce(phi::DenseTensor* out_tensor,
                  const phi::DenseTensor& in_tensor,
                  phi::ccl::CCLReduceOp reduce_type,
-                 const phi::stream::Stream& stream) const;
+                 const phi::stream::stream_t stream) const;
 
   void Reduce(phi::DenseTensor* out_tensor,
               const phi::DenseTensor& in_tensor,
               phi::ccl::CCLReduceOp reduce_type,
               int root,
-              const phi::stream::Stream& stream) const;
+              const phi::stream::stream_t& stream) const;
 
   void GroupStart() const;
diff --git a/paddle/phi/core/memory/memcpy.cc b/paddle/phi/core/memory/memcpy.cc
index 59aa30af119adb..85874ff1f6bb15 100644
--- a/paddle/phi/core/memory/memcpy.cc
+++ b/paddle/phi/core/memory/memcpy.cc
@@ -514,7 +514,8 @@ void Copy(phi::Place dst_place,
 
 #endif
 
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
+    !defined(PADDLE_WITH_CUSTOM_DEVICE)
 static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024;  // 64K
 
 #ifdef PADDLE_WITH_HIP
diff --git a/paddle/phi/core/platform/collective_helper.cc b/paddle/phi/core/platform/collective_helper.cc
index 8ee93ed5f62fe8..810750d63f6c83 100644
--- a/paddle/phi/core/platform/collective_helper.cc
+++ b/paddle/phi/core/platform/collective_helper.cc
@@ -397,10 +397,12 @@ class XCCLCommImpl : public XCCLComm {
   void set_comm(phi::ccl::CCLComm comm) { comm_ = comm; }
   phi::ccl::CCLComm comm() const override { return comm_; }
 
-  std::shared_ptr<phi::stream::Stream> stream() const override {
+  phi::stream::stream_t stream() const override {
+    return dev_ctx_->GetStream()->raw_stream();
+  }
+  std::shared_ptr<phi::stream::Stream> GetStream() const override {
     return dev_ctx_->GetStream();
   }
-
   void set_dev_ctx(std::unique_ptr&& dev_ctx) {
     dev_ctx_ = std::move(dev_ctx);
   }
diff --git a/paddle/phi/core/platform/collective_helper.h b/paddle/phi/core/platform/collective_helper.h
index 558ff0e9446af1..5765b136384629 100644
--- a/paddle/phi/core/platform/collective_helper.h
+++ b/paddle/phi/core/platform/collective_helper.h
@@ -251,7 +251,8 @@ class XCCLComm {
   virtual int rank() const = 0;
   virtual int device_id() const = 0;
   virtual phi::ccl::CCLComm comm() const = 0;
-  virtual std::shared_ptr<phi::stream::Stream> stream() const = 0;
+  virtual phi::stream::stream_t stream() const = 0;
+  virtual std::shared_ptr<phi::stream::Stream> GetStream() const = 0;
   virtual std::shared_ptr<phi::event::Event> compute_event() const = 0;
   virtual std::shared_ptr<phi::event::Event> comm_event() const = 0;
   virtual phi::CustomContext* dev_context() const = 0;
diff --git a/paddle/phi/kernels/cpu/all_gather_kernel.cc b/paddle/phi/kernels/cpu/all_gather_kernel.cc
index f16dbe06e9c18a..d27eb7ac5dcf7c 100644
--- a/paddle/phi/kernels/cpu/all_gather_kernel.cc
+++ b/paddle/phi/kernels/cpu/all_gather_kernel.cc
@@ -71,7 +71,7 @@ void AllGatherKernel(const phi::CustomContext& dev_ctx,
       errors::InvalidArgument(
           "nranks: %s should equal to %s", nranks, comm_ctx->GetSize()));
 
-  comm_ctx->AllGather(out, x, *dev_ctx.GetStream());
+  comm_ctx->AllGather(out, x, dev_ctx.stream());
 }
 #endif
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/all_reduce_kernel.cc b/paddle/phi/kernels/cpu/all_reduce_kernel.cc
index f3b247d2fc0a46..9773a637d1a406 100644
--- a/paddle/phi/kernels/cpu/all_reduce_kernel.cc
+++ b/paddle/phi/kernels/cpu/all_reduce_kernel.cc
@@ -67,7 +67,7 @@ void AllReduceKernel(const phi::CustomContext& dev_ctx,
       errors::Unavailable("XCCLCommContext is nullptr, collective op should "
                           "has ring_id attr."));
   comm_ctx->AllReduce(
-      out, x, phi::ccl::ToXCCLReduceOp(reduce_type), *dev_ctx.GetStream());
+      out, x, phi::ccl::ToXCCLReduceOp(reduce_type), dev_ctx.stream());
 }
 #endif
diff --git a/paddle/phi/kernels/cpu/all_to_all_kernel.cc b/paddle/phi/kernels/cpu/all_to_all_kernel.cc
index 5df84c5360de77..7b777474dc1fc0 100644
--- a/paddle/phi/kernels/cpu/all_to_all_kernel.cc
+++ b/paddle/phi/kernels/cpu/all_to_all_kernel.cc
@@ -60,7 +60,7 @@ void AllToAllKernel(const phi::CustomContext& dev_ctx,
                                 rank,
                                 nranks,
                                 comm_ctx->GetXcclComm(),
-                                *dev_ctx.GetStream());
+                                dev_ctx.stream());
 }
 #endif
diff --git a/paddle/phi/kernels/cpu/reduce_kernel.cc b/paddle/phi/kernels/cpu/reduce_kernel.cc
index 4fd012a2ccbfaf..87e218d3047a38 100644
--- a/paddle/phi/kernels/cpu/reduce_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_kernel.cc
@@ -75,11 +75,8 @@ void ReduceKernel(const phi::CustomContext& dev_ctx,
       nullptr,
       errors::Unavailable("XCCLCommContext is nullptr, collective op should "
                           "has ring_id attr."));
-  comm_ctx->Reduce(out,
-                   x,
-                   phi::ccl::ToXCCLReduceOp(reduce_type),
-                   root,
-                   *dev_ctx.GetStream());
+  comm_ctx->Reduce(
+      out, x, phi::ccl::ToXCCLReduceOp(reduce_type), root, dev_ctx.stream());
 }
 #endif
diff --git a/paddle/phi/kernels/custom/barrier_kernel.cc b/paddle/phi/kernels/custom/barrier_kernel.cc
index f11fd9e1451871..25a053150fea8f 100644
--- a/paddle/phi/kernels/custom/barrier_kernel.cc
+++ b/paddle/phi/kernels/custom/barrier_kernel.cc
@@ -47,7 +47,7 @@ void BarrierKernel(const Context& dev_ctx,
                                  in->dtype(),
                                  phi::ccl::CCLReduceOp::SUM,
                                  comm->GetXcclComm(),
-                                 *stream);
+                                 stream->raw_stream());
 }
 }  // namespace phi
diff --git a/paddle/phi/kernels/custom/c_allreduce_kernel_impl.h b/paddle/phi/kernels/custom/c_allreduce_kernel_impl.h
index 436342d8eca327..1039e0901032c4 100644
--- a/paddle/phi/kernels/custom/c_allreduce_kernel_impl.h
+++ b/paddle/phi/kernels/custom/c_allreduce_kernel_impl.h
@@ -95,7 +95,7 @@ void CAllReduceKernel(const Context& dev_ctx,
                                    dtype,
                                    red_type,
                                    comm->GetXcclComm(),
-                                   *stream);
+                                   stream->raw_stream());
 }
 
 template <typename T, typename Context>
@@ -123,7 +123,7 @@ void AllReduceKernel(const Context& dev_ctx,
                                    dtype,
                                    red_type,
                                    comm->GetXcclComm(),
-                                   *stream);
+                                   stream->raw_stream());
 }
 }  // namespace phi
diff --git a/paddle/phi/kernels/custom/c_broadcast_kernel.cc b/paddle/phi/kernels/custom/c_broadcast_kernel.cc
index e79f3f3fc918df..d0ae73573d926d 100644
--- a/paddle/phi/kernels/custom/c_broadcast_kernel.cc
+++ b/paddle/phi/kernels/custom/c_broadcast_kernel.cc
@@ -48,7 +48,7 @@ void CBroadcastKernel(const Context& dev_ctx,
                                   dtype,
                                   root,
                                   comm->GetXcclComm(),
-                                  *stream);
+                                  stream->raw_stream());
     VLOG(3) << "rank " << comm->GetRank() << " invoke Bcast. sent "
             << x->numel();
 
    if (out != x) {
@@ -65,7 +65,7 @@ void CBroadcastKernel(const Context& dev_ctx,
                                   dtype,
                                   root,
                                   comm->GetXcclComm(),
-                                  *stream);
+                                  stream->raw_stream());
     VLOG(3) << "rank " << comm->GetRank() << " invoke Bcast. received "
             << common::product(out->dims());
   }
diff --git a/paddle/phi/kernels/custom/c_concat_kernel.cc b/paddle/phi/kernels/custom/c_concat_kernel.cc
index 4312acc8a6d988..81af6be4c79667 100644
--- a/paddle/phi/kernels/custom/c_concat_kernel.cc
+++ b/paddle/phi/kernels/custom/c_concat_kernel.cc
@@ -95,7 +95,7 @@ void CConcatKernel(const Context& dev_ctx,
                                 send_numel,
                                 x->dtype(),
                                 comm->GetXcclComm(),
-                                stream);
+                                stream.raw_stream());
   }
   std::vector<phi::DenseTensor> inputs;
   int axis = x->dims().size() - 1;
diff --git a/paddle/phi/kernels/custom/c_softmax_with_entropy_kernel.cc b/paddle/phi/kernels/custom/c_softmax_with_entropy_kernel.cc
index c8bea36826f6c6..2786cc633f8e46 100644
--- a/paddle/phi/kernels/custom/c_softmax_with_entropy_kernel.cc
+++ b/paddle/phi/kernels/custom/c_softmax_with_entropy_kernel.cc
@@ -69,7 +69,7 @@ void CSoftmaxWithEntropyKernel(const Context& dev_ctx,
                                  logits_2d_max->dtype(),
                                  phi::ccl::CCLReduceOp::MAX,
                                  comm->GetXcclComm(),
-                                 stream);
+                                 stream.raw_stream());
 
   // step 2, obtain logit - logit_max
   auto logits_2d_sub_max = paddle::experimental::clip(
@@ -112,7 +112,7 @@ void CSoftmaxWithEntropyKernel(const Context& dev_ctx,
                                  predicted_logits->dtype(),
                                  phi::ccl::CCLReduceOp::SUM,
                                  comm->GetXcclComm(),
-                                 stream);
+                                 stream.raw_stream());
 
   // step 4, obtain exp(logit)
   auto softmax_2d_tensor = logits_2d_sub_max.exp();
@@ -130,7 +130,7 @@ void CSoftmaxWithEntropyKernel(const Context& dev_ctx,
                                  sum_exp_logits->dtype(),
                                  phi::ccl::CCLReduceOp::SUM,
                                  comm->GetXcclComm(),
-                                 stream);
+                                 stream.raw_stream());
 
   auto softmax_out = softmax_2d_tensor.divide(
       paddle::experimental::reshape(sum_exp_logits_tensor, {N, 1}));
diff --git a/paddle/phi/kernels/custom/global_gather_kernel.cc b/paddle/phi/kernels/custom/global_gather_kernel.cc
index a94b9a15365751..ad67db01fb55b9 100644
--- a/paddle/phi/kernels/custom/global_gather_kernel.cc
+++ b/paddle/phi/kernels/custom/global_gather_kernel.cc
@@ -104,7 +104,7 @@ void GlobalGatherKernel(const Context& dev_ctx,
                                    x->dtype(),
                                    j,
                                    comm->GetXcclComm(),
-                                   *stream);
+                                   stream->raw_stream());
       }
     }
     for (auto j = 0; j < nranks; ++j) {
@@ -119,7 +119,7 @@ void GlobalGatherKernel(const Context& dev_ctx,
                                    x->dtype(),
                                    j,
                                    comm->GetXcclComm(),
-                                   *stream);
+                                   stream->raw_stream());
       } else {
         phi::DeviceManager::GetDeviceWithPlace(place)->MemoryCopyD2D(
             reinterpret_cast(recv_buf + expert_ptr[idx] * in_feat),
@@ -139,7 +139,7 @@ void GlobalGatherKernel(const Context& dev_ctx,
                                    x->dtype(),
                                    j,
                                    comm->GetXcclComm(),
-                                   *stream);
+                                   stream->raw_stream());
       }
     }
   }
diff --git a/paddle/phi/kernels/custom/global_scatter_kernel.cc b/paddle/phi/kernels/custom/global_scatter_kernel.cc
index aad97eedc25b22..96b4fafa7fbff4 100644
--- a/paddle/phi/kernels/custom/global_scatter_kernel.cc
+++ b/paddle/phi/kernels/custom/global_scatter_kernel.cc
@@ -104,7 +104,7 @@ void GlobalScatterKernel(const Context& dev_ctx,
                                    x->dtype(),
                                    j,
                                    comm->GetXcclComm(),
-                                   *stream);
+                                   stream->raw_stream());
         recv_ptr += cpu_global_count_data[idx];
       }
     }
@@ -120,7 +120,7 @@ void GlobalScatterKernel(const Context& dev_ctx,
                                    x->dtype(),
                                    j,
                                    comm->GetXcclComm(),
-                                   *stream);
+                                   stream->raw_stream());
       }
     }
   }
@@ -142,7 +142,7 @@ void GlobalScatterKernel(const Context& dev_ctx,
                                    x->dtype(),
                                    j,
                                    comm->GetXcclComm(),
-                                   *stream);
+                                   stream->raw_stream());
         recv_ptr += cpu_global_count_data[idx];
       }
     }
diff --git a/test/cpp/fluid/platform/device/custom/custom_device_test.cc b/test/cpp/fluid/platform/device/custom/custom_device_test.cc
index 7ad8c84f4de406..72153ff085d122 100644
--- a/test/cpp/fluid/platform/device/custom/custom_device_test.cc
+++ b/test/cpp/fluid/platform/device/custom/custom_device_test.cc
@@ -186,8 +186,13 @@ void TestCustomCCL(const phi::Place& place) {
   phi::DeviceManager::CCLDestroyComm(dev_type, nullptr);
   phi::DeviceManager::CCLGetUniqueId(dev_type, &root_id);
   phi::DeviceManager::CCLCommInitRank(dev_type, 0, &root_id, 0, nullptr);
-  phi::DeviceManager::CCLBroadcast(
-      dev_type, nullptr, 0, phi::DataType::FLOAT32, 0, comm, stream);
+  phi::DeviceManager::CCLBroadcast(dev_type,
+                                   nullptr,
+                                   0,
+                                   phi::DataType::FLOAT32,
+                                   0,
+                                   comm,
+                                   stream.raw_stream());
   phi::DeviceManager::CCLAllReduce(dev_type,
                                    nullptr,
                                    nullptr,
@@ -195,7 +200,7 @@ void TestCustomCCL(const phi::Place& place) {
                                    phi::DataType::FLOAT32,
                                    phi::ccl::CCLReduceOp::SUM,
                                    comm,
-                                   stream);
+                                   stream.raw_stream());
   phi::DeviceManager::CCLReduce(dev_type,
                                 nullptr,
                                 nullptr,
@@ -204,9 +209,14 @@ void TestCustomCCL(const phi::Place& place) {
                                 phi::ccl::CCLReduceOp::SUM,
                                 0,
                                 comm,
-                                stream);
-  phi::DeviceManager::CCLAllGather(
-      dev_type, nullptr, nullptr, 0, phi::DataType::FLOAT32, comm, stream);
+                                stream.raw_stream());
+  phi::DeviceManager::CCLAllGather(dev_type,
+                                   nullptr,
+                                   nullptr,
+                                   0,
+                                   phi::DataType::FLOAT32,
+                                   comm,
+                                   stream.raw_stream());
   phi::DeviceManager::CCLReduceScatter(dev_type,
                                        nullptr,
                                        nullptr,
@@ -214,13 +224,23 @@ void TestCustomCCL(const phi::Place& place) {
                                        phi::DataType::FLOAT32,
                                        phi::ccl::CCLReduceOp::SUM,
                                        comm,
-                                       stream);
+                                       stream.raw_stream());
   phi::DeviceManager::CCLGroupStart(dev_type);
   phi::DeviceManager::CCLGroupEnd(dev_type);
-  phi::DeviceManager::CCLSend(
-      dev_type, nullptr, 0, phi::DataType::FLOAT32, 0, comm, stream);
-  phi::DeviceManager::CCLRecv(
-      dev_type, nullptr, 0, phi::DataType::FLOAT32, 0, comm, stream);
+  phi::DeviceManager::CCLSend(dev_type,
+                              nullptr,
+                              0,
+                              phi::DataType::FLOAT32,
+                              0,
+                              comm,
+                              stream.raw_stream());
+  phi::DeviceManager::CCLRecv(dev_type,
+                              nullptr,
+                              0,
+                              phi::DataType::FLOAT32,
+                              0,
+                              comm,
+                              stream.raw_stream());
 }
 
 TEST(CustomDevice, Tensor) {
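The patch consistently replaces `phi::stream::Stream` objects with the raw `phi::stream::stream_t` handle at the XCCL and custom-device boundaries. The sketch below is illustrative only and not part of the patch: `XCCLCommContext::AllReduce`, `CustomContext::stream()`, and `phi::ccl::CCLReduceOp::SUM` are taken from the diff above, while the function name, surrounding setup, and the `phi::distributed` namespace qualification are assumptions.

```cpp
// Minimal sketch (assumed caller, not from the patch): kernel-side code now
// hands the raw stream_t handle to the comm context instead of dereferencing
// the Stream wrapper.
#include "paddle/phi/core/distributed/xccl_comm_context.h"

void AllReduceSumSketch(const phi::CustomContext& dev_ctx,
                        phi::distributed::XCCLCommContext* comm_ctx,
                        phi::DenseTensor* out,
                        const phi::DenseTensor& x) {
  // Old form: comm_ctx->AllReduce(out, x, op, *dev_ctx.GetStream());
  // New form: pass the stream_t handle returned by CustomContext::stream().
  comm_ctx->AllReduce(out, x, phi::ccl::CCLReduceOp::SUM, dev_ctx.stream());
}
```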