Skip to content

[CustomDevice] Achieve compatibility of Xccl stream with gpu stream #72983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 32 additions & 23 deletions paddle/fluid/distributed/collective/process_group_custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream);
comm_context->AllGather(
out_tensor, in_tensor_maybe_partial, stream.raw_stream());
},
in_tensor_maybe_partial,
CommType::ALLGATHER,
Expand All @@ -239,7 +240,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
out_tensor,
in_tensor,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
stream);
stream.raw_stream());
},
in_tensor,
CommType::ALLREDUCE,
Expand Down Expand Up @@ -315,7 +316,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
rank_,
size_,
comm_context->GetXcclComm(),
stream);
stream.raw_stream());
},
in_tensor,
CommType::ALLTOALL,
Expand Down Expand Up @@ -358,7 +359,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
[&](const phi::stream::Stream& stream) {
int root = opts.source_rank + opts.source_root;
auto comm_context = this->GetCommContext();
comm_context->Broadcast(out_tensor, in_tensor, root, stream);
comm_context->Broadcast(
out_tensor, in_tensor, root, stream.raw_stream());
},
in_tensor,
CommType::BROADCAST,
Expand All @@ -382,7 +384,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
in_tensor,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
opts.root_rank,
stream);
stream.raw_stream());
},
in_tensor,
CommType::REDUCE,
Expand All @@ -406,7 +408,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::ReduceScatter(
out_tensor,
in_tensor,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
stream);
stream.raw_stream());
},
in_tensor,
CommType::REDUCE_SCATTER,
Expand Down Expand Up @@ -441,7 +443,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(in_tensor, offset, numel);
if (i != rank_) {
comm_context->Send(partial_tensor, numel, i, stream);
comm_context->Send(partial_tensor, numel, i, stream.raw_stream());
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(out_tensor->data(),
Expand All @@ -452,7 +454,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
offset += numel;
}
} else {
comm_context->Recv(out_tensor, numel, opts.root_rank, stream);
comm_context->Recv(
out_tensor, numel, opts.root_rank, stream.raw_stream());
}
},
in_tensor,
Expand Down Expand Up @@ -506,7 +509,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
for (auto i = 0; i < size_; i++) {
auto& gather_tensor = gather_tensors[i];
if (i != rank_) {
comm_context->Recv(&gather_tensor, gather_tensor.numel(), i, stream);
comm_context->Recv(
&gather_tensor, gather_tensor.numel(), i, stream.raw_stream());
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(
Expand All @@ -518,7 +522,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
}
} else {
// send to root
comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream);
comm_context->Send(
in_tensor, in_tensor.numel(), opts.root_rank, stream.raw_stream());
}
};
return RunFnInXCCLEnv(
Expand All @@ -542,7 +547,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
return RunFnInXCCLEnv(
[&](const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->Recv(tensor, tensor->numel(), src_rank, stream);
comm_context->Recv(
tensor, tensor->numel(), src_rank, stream.raw_stream());
},
*tensor,
CommType::RECV,
Expand All @@ -569,7 +575,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
comm_context->Send(tensor_maybe_partial,
tensor_maybe_partial.numel(),
dst_rank,
stream);
stream.raw_stream());
},
tensor_maybe_partial,
CommType::SEND,
Expand Down Expand Up @@ -915,7 +921,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
&output,
input,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
stream);
stream.raw_stream());
},
CommType::ALLREDUCE);
}
Expand All @@ -942,7 +948,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
const auto root =
opts.source_rank * in_tensors.size() + opts.source_root;
auto comm_context = this->GetCommContext();
comm_context->Broadcast(&output, input, root, stream);
comm_context->Broadcast(&output, input, root, stream.raw_stream());
},
CommType::BROADCAST);
}
Expand Down Expand Up @@ -988,7 +994,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
const phi::stream::Stream& stream,
int dst_rank) {
auto comm_context = this->GetCommContext();
comm_context->Send(input, input.numel(), dst_rank, stream);
comm_context->Send(input, input.numel(), dst_rank, stream.raw_stream());
},
dst_rank,
CommType::SEND);
Expand All @@ -1008,7 +1014,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
const phi::stream::Stream& stream,
int src_rank) {
auto comm_context = this->GetCommContext();
comm_context->Recv(&output, output.numel(), src_rank, stream);
comm_context->Recv(
&output, output.numel(), src_rank, stream.raw_stream());
},
src_rank,
CommType::RECV);
Expand Down Expand Up @@ -1037,7 +1044,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
const phi::ccl::CCLComm& comm,
const phi::stream::Stream& stream) {
auto comm_context = this->GetCommContext();
comm_context->AllGather(&output, input, stream);
comm_context->AllGather(&output, input, stream.raw_stream());
},
CommType::ALLGATHER);
}
Expand Down Expand Up @@ -1089,7 +1096,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
rank_,
size_,
comm_context->GetXcclComm(),
stream);
stream.raw_stream());
},
CommType::ALLTOALL);
}
Expand Down Expand Up @@ -1166,7 +1173,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
rank_,
size_,
comm_context->GetXcclComm(),
stream);
stream.raw_stream());
},
in_tensors,
CommType::ALLTOALL,
Expand Down Expand Up @@ -1197,7 +1204,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
input,
paddle::distributed::ToXCCLRedType(opts.reduce_op),
opts.root_rank,
stream);
stream.raw_stream());
},
CommType::REDUCE);
}
Expand Down Expand Up @@ -1232,13 +1239,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
for (auto i = 0; i < size_; i++) {
auto input_data = reinterpret_cast<phi::DenseTensor*>(
GetPointerByOffset(input.data(), offset, input.dtype()));
comm_context->Send(*input_data, count, i, stream);
comm_context->Send(*input_data, count, i, stream.raw_stream());
offset += count;
}
comm_context->Recv(&output, count, opts.root_rank, stream);
comm_context->Recv(
&output, count, opts.root_rank, stream.raw_stream());
comm_context->GroupEnd();
} else {
comm_context->Recv(&output, count, opts.root_rank, stream);
comm_context->Recv(
&output, count, opts.root_rank, stream.raw_stream());
}
},
CommType::SCATTER);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/gradient_accumulator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
phi::DeviceContextPool::Instance().Get(place)); \
phi::stream::Stream stream(place, ctx->stream()); \
auto device = phi::DeviceManager::GetDeviceWithPlace(place); \
device->BlasAXPBY<T>(stream, \
device->BlasAXPBY<T>(stream.raw_stream(), \
static_cast<size_t>(numel), \
1., \
src_tensor.data<T>(), \
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/imperative/xccl_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace imperative {

static void XcclAllReduce(const phi::DenseTensor &src,
phi::DenseTensor *dst,
const phi::stream::Stream &stream,
const phi::stream::stream_t &stream,
const phi::ccl::CCLComm &comm) {
const auto &place = src.place();
PADDLE_ENFORCE_EQ(
Expand Down Expand Up @@ -171,15 +171,15 @@ void XCCLParallelContext::AllReduceByStream(const framework::Variable &src,
platform::XCCLComm *comm =
platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(ring_id, place);
auto stream = use_calc_stream ? dev_ctx->GetStream() : comm->stream();
auto stream = use_calc_stream ? dev_ctx->stream() : comm->stream();

if (src.IsType<phi::DenseTensor>()) {
if (!dst->IsType<phi::DenseTensor>()) {
dst->Clear();
}
XcclAllReduce(src.Get<phi::DenseTensor>(),
dst->GetMutable<phi::DenseTensor>(),
*stream,
stream,
comm->comm());
} else {
PADDLE_THROW(common::errors::InvalidArgument(
Expand Down Expand Up @@ -207,7 +207,7 @@ void XCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
src_tensor->dtype(),
0,
comm->comm(),
*stream);
stream);
}

phi::DeviceContext *XCCLParallelContext::GetDeviceContext(int ring_id) {
Expand Down Expand Up @@ -235,7 +235,7 @@ void XCCLParallelContext::WaitCompute(int ring_id) {
->GetStream();
auto comm_stream = platform::XCCLCommContext::Instance(place_.GetDeviceType())
.Get(ring_id, place_)
->stream();
->GetStream();
auto event = compute_events_[ring_id].get();

// compute_stream-->event-->comm_stream
Expand All @@ -261,7 +261,7 @@ void XCCLParallelContext::WaitComm(int ring_id) {
->GetStream();
auto comm_stream = platform::XCCLCommContext::Instance(place_.GetDeviceType())
.Get(ring_id, place_)
->stream();
->GetStream();
auto event = comm_events_[ring_id].get();

// comm_stream-->event-->compute_stream
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/backends/callback_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void CallbackManager::AddCallback(std::function<void()> callback) const {
void CallbackManager::Wait() const {
phi::DeviceGuard guard(stream_->GetPlace());
phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
->SynchronizeStream(stream_);
->SynchronizeStream(stream_->raw_stream());

{
std::lock_guard<std::mutex> lock(mtx_);
Expand Down
Loading
Loading