
Commit af253e2

For custom devices, the XCCL stream is now compatible with the GPU stream.
1 parent 0c4a400 · commit af253e2

27 files changed (+251, -217 lines)
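The change is mechanical but wide: every XCCL comm-context call that used to receive the phi::stream::Stream wrapper now receives the backend's raw handle via stream.raw_stream(), so custom-device collectives take the same kind of stream argument as the GPU path. The sketch below is a minimal mock of that call-site pattern only; Stream, stream_t, CommContext, and this AllGather signature are stand-ins for illustration, not phi's real declarations.

// Minimal mock sketch (not phi's real API): it only illustrates the call-site
// pattern this commit adopts, where collectives receive the backend's raw
// stream handle instead of the phi::stream::Stream wrapper.
#include <cstdio>

using stream_t = void*;  // hypothetical raw, backend-agnostic stream handle

class Stream {  // wrapper analogous in role to phi::stream::Stream
 public:
  explicit Stream(stream_t raw) : raw_(raw) {}
  stream_t raw_stream() const { return raw_; }  // what the commit now forwards
 private:
  stream_t raw_;
};

class CommContext {  // stand-in for the XCCL comm context
 public:
  // Collectives now take the raw handle, matching the GPU-style signature.
  void AllGather(const float* in, float* out, int n, stream_t stream) {
    (void)stream;
    for (int i = 0; i < n; ++i) out[i] = in[i];  // placeholder "collective"
  }
};

int main() {
  float in[4] = {1, 2, 3, 4}, out[4] = {};
  Stream stream(nullptr);  // the wrapper still flows through the call site
  CommContext ctx;
  // Call-site pattern from the diff: forward stream.raw_stream(), not stream.
  ctx.AllGather(in, out, 4, stream.raw_stream());
  std::printf("out[3] = %g\n", out[3]);
  return 0;
}

The wrapper itself still flows through the lambdas in process_group_custom.cc; only the final argument handed to the communication library changes.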

paddle/fluid/distributed/collective/process_group_custom.cc

Lines changed: 32 additions & 23 deletions
@@ -215,7 +215,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
   return RunFnInXCCLEnv(
       [&](const phi::stream::Stream& stream) {
         auto comm_context = this->GetCommContext();
-        comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream);
+        comm_context->AllGather(
+            out_tensor, in_tensor_maybe_partial, stream.raw_stream());
       },
       in_tensor_maybe_partial,
       CommType::ALLGATHER,
@@ -239,7 +240,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
             out_tensor,
             in_tensor,
             paddle::distributed::ToXCCLRedType(opts.reduce_op),
-            stream);
+            stream.raw_stream());
       },
       in_tensor,
       CommType::ALLREDUCE,
@@ -315,7 +316,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
             rank_,
             size_,
             comm_context->GetXcclComm(),
-            stream);
+            stream.raw_stream());
       },
       in_tensor,
       CommType::ALLTOALL,
@@ -358,7 +359,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
       [&](const phi::stream::Stream& stream) {
         int root = opts.source_rank + opts.source_root;
         auto comm_context = this->GetCommContext();
-        comm_context->Broadcast(out_tensor, in_tensor, root, stream);
+        comm_context->Broadcast(
+            out_tensor, in_tensor, root, stream.raw_stream());
       },
       in_tensor,
       CommType::BROADCAST,
@@ -382,7 +384,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
             in_tensor,
             paddle::distributed::ToXCCLRedType(opts.reduce_op),
             opts.root_rank,
-            stream);
+            stream.raw_stream());
       },
       in_tensor,
       CommType::REDUCE,
@@ -406,7 +408,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::ReduceScatter(
             out_tensor,
             in_tensor,
             paddle::distributed::ToXCCLRedType(opts.reduce_op),
-            stream);
+            stream.raw_stream());
       },
       in_tensor,
       CommType::REDUCE_SCATTER,
@@ -441,7 +443,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
           for (auto i = 0; i < size_; i++) {
             partial_tensor = GetPartialTensor(in_tensor, offset, numel);
             if (i != rank_) {
-              comm_context->Send(partial_tensor, numel, i, stream);
+              comm_context->Send(partial_tensor, numel, i, stream.raw_stream());
             } else {
               phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
                   ->MemoryCopyD2D(out_tensor->data(),
@@ -452,7 +454,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
             offset += numel;
           }
         } else {
-          comm_context->Recv(out_tensor, numel, opts.root_rank, stream);
+          comm_context->Recv(
+              out_tensor, numel, opts.root_rank, stream.raw_stream());
         }
       },
       in_tensor,
@@ -506,7 +509,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
     for (auto i = 0; i < size_; i++) {
       auto& gather_tensor = gather_tensors[i];
       if (i != rank_) {
-        comm_context->Recv(&gather_tensor, gather_tensor.numel(), i, stream);
+        comm_context->Recv(
+            &gather_tensor, gather_tensor.numel(), i, stream.raw_stream());
       } else {
         phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
             ->MemoryCopyD2D(
@@ -518,7 +522,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
       }
     } else {
       // send to root
-      comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream);
+      comm_context->Send(
+          in_tensor, in_tensor.numel(), opts.root_rank, stream.raw_stream());
     }
   };
   return RunFnInXCCLEnv(
@@ -542,7 +547,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
   return RunFnInXCCLEnv(
       [&](const phi::stream::Stream& stream) {
         auto comm_context = this->GetCommContext();
-        comm_context->Recv(tensor, tensor->numel(), src_rank, stream);
+        comm_context->Recv(
+            tensor, tensor->numel(), src_rank, stream.raw_stream());
       },
       *tensor,
       CommType::RECV,
@@ -569,7 +575,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
         comm_context->Send(tensor_maybe_partial,
                            tensor_maybe_partial.numel(),
                            dst_rank,
-                           stream);
+                           stream.raw_stream());
       },
       tensor_maybe_partial,
       CommType::SEND,
@@ -915,7 +921,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
             &output,
             input,
             paddle::distributed::ToXCCLRedType(opts.reduce_op),
-            stream);
+            stream.raw_stream());
       },
       CommType::ALLREDUCE);
 }
@@ -942,7 +948,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
         const auto root =
             opts.source_rank * in_tensors.size() + opts.source_root;
         auto comm_context = this->GetCommContext();
-        comm_context->Broadcast(&output, input, root, stream);
+        comm_context->Broadcast(&output, input, root, stream.raw_stream());
       },
       CommType::BROADCAST);
 }
@@ -988,7 +994,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
          const phi::stream::Stream& stream,
          int dst_rank) {
         auto comm_context = this->GetCommContext();
-        comm_context->Send(input, input.numel(), dst_rank, stream);
+        comm_context->Send(input, input.numel(), dst_rank, stream.raw_stream());
       },
       dst_rank,
       CommType::SEND);
@@ -1008,7 +1014,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
          const phi::stream::Stream& stream,
          int src_rank) {
         auto comm_context = this->GetCommContext();
-        comm_context->Recv(&output, output.numel(), src_rank, stream);
+        comm_context->Recv(
+            &output, output.numel(), src_rank, stream.raw_stream());
       },
       src_rank,
       CommType::RECV);
@@ -1037,7 +1044,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
          const phi::ccl::CCLComm& comm,
          const phi::stream::Stream& stream) {
         auto comm_context = this->GetCommContext();
-        comm_context->AllGather(&output, input, stream);
+        comm_context->AllGather(&output, input, stream.raw_stream());
       },
       CommType::ALLGATHER);
 }
@@ -1089,7 +1096,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
             rank_,
             size_,
             comm_context->GetXcclComm(),
-            stream);
+            stream.raw_stream());
       },
       CommType::ALLTOALL);
 }
@@ -1166,7 +1173,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
             rank_,
             size_,
             comm_context->GetXcclComm(),
-            stream);
+            stream.raw_stream());
       },
       in_tensors,
       CommType::ALLTOALL,
@@ -1197,7 +1204,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
             input,
             paddle::distributed::ToXCCLRedType(opts.reduce_op),
             opts.root_rank,
-            stream);
+            stream.raw_stream());
       },
       CommType::REDUCE);
 }
@@ -1232,13 +1239,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
           for (auto i = 0; i < size_; i++) {
             auto input_data = reinterpret_cast<phi::DenseTensor*>(
                 GetPointerByOffset(input.data(), offset, input.dtype()));
-            comm_context->Send(*input_data, count, i, stream);
+            comm_context->Send(*input_data, count, i, stream.raw_stream());
             offset += count;
           }
-          comm_context->Recv(&output, count, opts.root_rank, stream);
+          comm_context->Recv(
+              &output, count, opts.root_rank, stream.raw_stream());
           comm_context->GroupEnd();
         } else {
-          comm_context->Recv(&output, count, opts.root_rank, stream);
+          comm_context->Recv(
+              &output, count, opts.root_rank, stream.raw_stream());
         }
       },
       CommType::SCATTER);

paddle/fluid/imperative/gradient_accumulator.cc

Lines changed: 1 addition & 1 deletion
@@ -221,7 +221,7 @@ void TensorAdd(const VarType& src, VarType* dst) {
         phi::DeviceContextPool::Instance().Get(place));              \
     phi::stream::Stream stream(place, ctx->stream());                \
     auto device = phi::DeviceManager::GetDeviceWithPlace(place);     \
-    device->BlasAXPBY<T>(stream,                                     \
+    device->BlasAXPBY<T>(stream.raw_stream(),                        \
                          static_cast<size_t>(numel),                 \
                          1.,                                         \
                          src_tensor.data<T>(),                       \

paddle/fluid/imperative/xccl_context.cc

Lines changed: 6 additions & 6 deletions
@@ -36,7 +36,7 @@ namespace imperative {
 
 static void XcclAllReduce(const phi::DenseTensor &src,
                           phi::DenseTensor *dst,
-                          const phi::stream::Stream &stream,
+                          const phi::stream::stream_t &stream,
                           const phi::ccl::CCLComm &comm) {
   const auto &place = src.place();
   PADDLE_ENFORCE_EQ(
@@ -171,15 +171,15 @@ void XCCLParallelContext::AllReduceByStream(const framework::Variable &src,
   platform::XCCLComm *comm =
       platform::XCCLCommContext::Instance(place.GetDeviceType())
           .Get(ring_id, place);
-  auto stream = use_calc_stream ? dev_ctx->GetStream() : comm->stream();
+  auto stream = use_calc_stream ? dev_ctx->stream() : comm->stream();
 
   if (src.IsType<phi::DenseTensor>()) {
     if (!dst->IsType<phi::DenseTensor>()) {
       dst->Clear();
     }
     XcclAllReduce(src.Get<phi::DenseTensor>(),
                   dst->GetMutable<phi::DenseTensor>(),
-                  *stream,
+                  stream,
                   comm->comm());
   } else {
     PADDLE_THROW(common::errors::InvalidArgument(
@@ -207,7 +207,7 @@ void XCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
                                    src_tensor->dtype(),
                                    0,
                                    comm->comm(),
-                                   *stream);
+                                   stream);
 }
 
 phi::DeviceContext *XCCLParallelContext::GetDeviceContext(int ring_id) {
@@ -235,7 +235,7 @@ void XCCLParallelContext::WaitCompute(int ring_id) {
                             ->GetStream();
   auto comm_stream = platform::XCCLCommContext::Instance(place_.GetDeviceType())
                          .Get(ring_id, place_)
-                         ->stream();
+                         ->GetStream();
   auto event = compute_events_[ring_id].get();
 
   // compute_stream-->event-->comm_stream
@@ -261,7 +261,7 @@ void XCCLParallelContext::WaitComm(int ring_id) {
                             ->GetStream();
   auto comm_stream = platform::XCCLCommContext::Instance(place_.GetDeviceType())
                          .Get(ring_id, place_)
-                         ->stream();
+                         ->GetStream();
   auto event = comm_events_[ring_id].get();
 
   // comm_stream-->event-->compute_stream
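The WaitCompute/WaitComm hunks above keep the existing event-based ordering (the in-code comments read compute_stream-->event-->comm_stream and comm_stream-->event-->compute_stream); only the accessor changes from stream() to GetStream(). As a reminder of what that ordering buys, here is a hedged, mock-typed sketch of the pattern; Event, MockStream, RecordEvent, and WaitEvent are invented for illustration and are not phi's API.

// Hedged sketch (mock types, not phi's API) of the ordering described by the
// diff's comments: compute_stream --> event --> comm_stream.  The comm stream
// must not start the collective until the compute stream has produced the
// tensor, so an event recorded on one stream is waited on by the other.
#include <cstdio>
#include <string>
#include <vector>

struct Event { std::string tag; };

struct MockStream {
  std::string name;
  std::vector<std::string> log;  // records issued work in submission order
  void RecordEvent(Event* e) { log.push_back("record " + e->tag); }
  void WaitEvent(const Event& e) { log.push_back("wait " + e.tag); }
  void Launch(const std::string& op) { log.push_back(op); }
};

int main() {
  MockStream compute{"compute"}, comm{"comm"};
  Event ready{"tensor_ready"};

  compute.Launch("produce gradient");  // work the collective depends on
  compute.RecordEvent(&ready);         // WaitCompute(): compute_stream-->event
  comm.WaitEvent(ready);               // ...event-->comm_stream
  comm.Launch("all_reduce");           // safe to launch the collective now

  for (const auto& s : compute.log)
    std::printf("[%s] %s\n", compute.name.c_str(), s.c_str());
  for (const auto& s : comm.log)
    std::printf("[%s] %s\n", comm.name.c_str(), s.c_str());
  return 0;
}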

paddle/phi/backends/callback_manager.cc

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ void CallbackManager::AddCallback(std::function<void()> callback) const {
 void CallbackManager::Wait() const {
   phi::DeviceGuard guard(stream_->GetPlace());
   phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
-      ->SynchronizeStream(stream_);
+      ->SynchronizeStream(stream_->raw_stream());
 
   {
     std::lock_guard<std::mutex> lock(mtx_);
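For context on the callback_manager.cc hunk, a hedged mock of the surrounding structure: Wait() first blocks on the stream (now via the raw handle passed to SynchronizeStream) and only then runs the queued host callbacks under the mutex. MockCallbackManager and the free SynchronizeStream function below are stand-ins, not the real phi classes.

// Hedged sketch (mock, not the real phi::CallbackManager): Wait() synchronizes
// on the raw stream handle first, then drains queued callbacks under a mutex,
// mirroring the structure shown in the diff above.
#include <cstdio>
#include <functional>
#include <mutex>
#include <vector>

using stream_t = void*;  // hypothetical raw stream handle

// Stand-in for the device's SynchronizeStream(raw) entry point.
static void SynchronizeStream(stream_t /*raw*/) {
  // A real backend would block here until all work on the stream finished.
}

class MockCallbackManager {
 public:
  explicit MockCallbackManager(stream_t raw) : raw_(raw) {}

  void AddCallback(std::function<void()> cb) {
    std::lock_guard<std::mutex> lock(mtx_);
    callbacks_.push_back(std::move(cb));
  }

  void Wait() {
    SynchronizeStream(raw_);  // device work first...
    std::lock_guard<std::mutex> lock(mtx_);
    for (auto& cb : callbacks_) cb();  // ...then the host-side callbacks
    callbacks_.clear();
  }

 private:
  stream_t raw_;
  std::mutex mtx_;
  std::vector<std::function<void()>> callbacks_;
};

int main() {
  MockCallbackManager mgr(nullptr);
  mgr.AddCallback([] { std::printf("callback ran after stream sync\n"); });
  mgr.Wait();
  return 0;
}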
