@@ -215,7 +215,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
215
215
return RunFnInXCCLEnv (
216
216
[&](const phi::stream::Stream& stream) {
217
217
auto comm_context = this ->GetCommContext ();
218
- comm_context->AllGather (out_tensor, in_tensor_maybe_partial, stream);
218
+ comm_context->AllGather (
219
+ out_tensor, in_tensor_maybe_partial, stream.raw_stream ());
219
220
},
220
221
in_tensor_maybe_partial,
221
222
CommType::ALLGATHER,
@@ -239,7 +240,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
239
240
out_tensor,
240
241
in_tensor,
241
242
paddle::distributed::ToXCCLRedType (opts.reduce_op ),
242
- stream);
243
+ stream. raw_stream () );
243
244
},
244
245
in_tensor,
245
246
CommType::ALLREDUCE,
@@ -315,7 +316,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
315
316
rank_,
316
317
size_,
317
318
comm_context->GetXcclComm (),
318
- stream);
319
+ stream. raw_stream () );
319
320
},
320
321
in_tensor,
321
322
CommType::ALLTOALL,
@@ -358,7 +359,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
358
359
[&](const phi::stream::Stream& stream) {
359
360
int root = opts.source_rank + opts.source_root ;
360
361
auto comm_context = this ->GetCommContext ();
361
- comm_context->Broadcast (out_tensor, in_tensor, root, stream);
362
+ comm_context->Broadcast (
363
+ out_tensor, in_tensor, root, stream.raw_stream ());
362
364
},
363
365
in_tensor,
364
366
CommType::BROADCAST,
@@ -382,7 +384,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
382
384
in_tensor,
383
385
paddle::distributed::ToXCCLRedType (opts.reduce_op ),
384
386
opts.root_rank ,
385
- stream);
387
+ stream. raw_stream () );
386
388
},
387
389
in_tensor,
388
390
CommType::REDUCE,
@@ -406,7 +408,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::ReduceScatter(
406
408
out_tensor,
407
409
in_tensor,
408
410
paddle::distributed::ToXCCLRedType (opts.reduce_op ),
409
- stream);
411
+ stream. raw_stream () );
410
412
},
411
413
in_tensor,
412
414
CommType::REDUCE_SCATTER,
@@ -441,7 +443,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
441
443
for (auto i = 0 ; i < size_; i++) {
442
444
partial_tensor = GetPartialTensor (in_tensor, offset, numel);
443
445
if (i != rank_) {
444
- comm_context->Send (partial_tensor, numel, i, stream);
446
+ comm_context->Send (partial_tensor, numel, i, stream. raw_stream () );
445
447
} else {
446
448
phi::DeviceManager::GetDeviceWithPlace (stream.GetPlace ())
447
449
->MemoryCopyD2D (out_tensor->data (),
@@ -452,7 +454,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
452
454
offset += numel;
453
455
}
454
456
} else {
455
- comm_context->Recv (out_tensor, numel, opts.root_rank , stream);
457
+ comm_context->Recv (
458
+ out_tensor, numel, opts.root_rank , stream.raw_stream ());
456
459
}
457
460
},
458
461
in_tensor,
@@ -506,7 +509,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
506
509
for (auto i = 0 ; i < size_; i++) {
507
510
auto & gather_tensor = gather_tensors[i];
508
511
if (i != rank_) {
509
- comm_context->Recv (&gather_tensor, gather_tensor.numel (), i, stream);
512
+ comm_context->Recv (
513
+ &gather_tensor, gather_tensor.numel (), i, stream.raw_stream ());
510
514
} else {
511
515
phi::DeviceManager::GetDeviceWithPlace (stream.GetPlace ())
512
516
->MemoryCopyD2D (
@@ -518,7 +522,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
518
522
}
519
523
} else {
520
524
// send to root
521
- comm_context->Send (in_tensor, in_tensor.numel (), opts.root_rank , stream);
525
+ comm_context->Send (
526
+ in_tensor, in_tensor.numel (), opts.root_rank , stream.raw_stream ());
522
527
}
523
528
};
524
529
return RunFnInXCCLEnv (
@@ -542,7 +547,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
542
547
return RunFnInXCCLEnv (
543
548
[&](const phi::stream::Stream& stream) {
544
549
auto comm_context = this ->GetCommContext ();
545
- comm_context->Recv (tensor, tensor->numel (), src_rank, stream);
550
+ comm_context->Recv (
551
+ tensor, tensor->numel (), src_rank, stream.raw_stream ());
546
552
},
547
553
*tensor,
548
554
CommType::RECV,
@@ -569,7 +575,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
569
575
comm_context->Send (tensor_maybe_partial,
570
576
tensor_maybe_partial.numel (),
571
577
dst_rank,
572
- stream);
578
+ stream. raw_stream () );
573
579
},
574
580
tensor_maybe_partial,
575
581
CommType::SEND,
@@ -915,7 +921,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
915
921
&output,
916
922
input,
917
923
paddle::distributed::ToXCCLRedType (opts.reduce_op ),
918
- stream);
924
+ stream. raw_stream () );
919
925
},
920
926
CommType::ALLREDUCE);
921
927
}
@@ -942,7 +948,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
942
948
const auto root =
943
949
opts.source_rank * in_tensors.size () + opts.source_root ;
944
950
auto comm_context = this ->GetCommContext ();
945
- comm_context->Broadcast (&output, input, root, stream);
951
+ comm_context->Broadcast (&output, input, root, stream. raw_stream () );
946
952
},
947
953
CommType::BROADCAST);
948
954
}
@@ -988,7 +994,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
988
994
const phi::stream::Stream& stream,
989
995
int dst_rank) {
990
996
auto comm_context = this ->GetCommContext ();
991
- comm_context->Send (input, input.numel (), dst_rank, stream);
997
+ comm_context->Send (input, input.numel (), dst_rank, stream. raw_stream () );
992
998
},
993
999
dst_rank,
994
1000
CommType::SEND);
@@ -1008,7 +1014,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Recv(
1008
1014
const phi::stream::Stream& stream,
1009
1015
int src_rank) {
1010
1016
auto comm_context = this ->GetCommContext ();
1011
- comm_context->Recv (&output, output.numel (), src_rank, stream);
1017
+ comm_context->Recv (
1018
+ &output, output.numel (), src_rank, stream.raw_stream ());
1012
1019
},
1013
1020
src_rank,
1014
1021
CommType::RECV);
@@ -1037,7 +1044,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
1037
1044
const phi::ccl::CCLComm& comm,
1038
1045
const phi::stream::Stream& stream) {
1039
1046
auto comm_context = this ->GetCommContext ();
1040
- comm_context->AllGather (&output, input, stream);
1047
+ comm_context->AllGather (&output, input, stream. raw_stream () );
1041
1048
},
1042
1049
CommType::ALLGATHER);
1043
1050
}
@@ -1089,7 +1096,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
1089
1096
rank_,
1090
1097
size_,
1091
1098
comm_context->GetXcclComm (),
1092
- stream);
1099
+ stream. raw_stream () );
1093
1100
},
1094
1101
CommType::ALLTOALL);
1095
1102
}
@@ -1166,7 +1173,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
1166
1173
rank_,
1167
1174
size_,
1168
1175
comm_context->GetXcclComm (),
1169
- stream);
1176
+ stream. raw_stream () );
1170
1177
},
1171
1178
in_tensors,
1172
1179
CommType::ALLTOALL,
@@ -1197,7 +1204,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
1197
1204
input,
1198
1205
paddle::distributed::ToXCCLRedType (opts.reduce_op ),
1199
1206
opts.root_rank ,
1200
- stream);
1207
+ stream. raw_stream () );
1201
1208
},
1202
1209
CommType::REDUCE);
1203
1210
}
@@ -1232,13 +1239,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
1232
1239
for (auto i = 0 ; i < size_; i++) {
1233
1240
auto input_data = reinterpret_cast <phi::DenseTensor*>(
1234
1241
GetPointerByOffset (input.data (), offset, input.dtype ()));
1235
- comm_context->Send (*input_data, count, i, stream);
1242
+ comm_context->Send (*input_data, count, i, stream. raw_stream () );
1236
1243
offset += count;
1237
1244
}
1238
- comm_context->Recv (&output, count, opts.root_rank , stream);
1245
+ comm_context->Recv (
1246
+ &output, count, opts.root_rank , stream.raw_stream ());
1239
1247
comm_context->GroupEnd ();
1240
1248
} else {
1241
- comm_context->Recv (&output, count, opts.root_rank , stream);
1249
+ comm_context->Recv (
1250
+ &output, count, opts.root_rank , stream.raw_stream ());
1242
1251
}
1243
1252
},
1244
1253
CommType::SCATTER);
0 commit comments