16
16
17
17
#include " paddle/fluid/distributed/collective/common.h"
18
18
#include " paddle/fluid/distributed/collective/custom_ccl_tools.h"
19
+ #include " paddle/fluid/distributed/collective/utils.h"
19
20
#include " paddle/fluid/memory/malloc.h"
20
21
#include " paddle/fluid/platform/device_context.h"
21
22
#include " paddle/fluid/platform/place.h"
22
23
#include " paddle/phi/api/lib/utils/allocator.h"
23
24
#include " paddle/phi/common/place.h"
25
+ #include " paddle/phi/core/distributed/check/static_check.h"
24
26
25
27
DECLARE_bool (xccl_blocking_wait);
26
28
@@ -234,10 +236,21 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
234
236
const phi::DenseTensor& in_tensor,
235
237
int64_t offset,
236
238
int64_t numel,
237
- bool sync_op // for compatibility, no use now
238
- ) {
239
- std::vector<phi::DenseTensor> in_wrapper{in_tensor};
239
+ bool sync_op, // for compatibility, no use now
240
+ bool use_calc_stream) {
241
+ // numel > 0 indicates the tensor need to be sliced
242
+ const phi::DenseTensor& in_tensor_maybe_partial =
243
+ numel > 0
244
+ ? paddle::distributed::GetPartialTensor (in_tensor, offset, numel)
245
+ : in_tensor;
246
+ phi::distributed::CommStaticCheck::GatherLikeShape (*out_tensor,
247
+ in_tensor_maybe_partial,
248
+ /* dst_rank*/ rank_,
249
+ /* cur_rank*/ rank_,
250
+ size_);
251
+ std::vector<phi::DenseTensor> in_wrapper{in_tensor_maybe_partial};
240
252
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
253
+
241
254
return Collective (
242
255
in_wrapper,
243
256
out_wrapper,
@@ -247,80 +260,23 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
247
260
const phi::stream::Stream& stream) {
248
261
return phi::DeviceManager::CCLAllGather (
249
262
device_type_,
250
- XcclGetPointerByOffset ( input.data (), offset, input. dtype () ),
263
+ input.data (),
251
264
output.data (),
252
- numel,
265
+ input. numel () ,
253
266
phi::ccl::ToCCLDataType (input.dtype ()),
254
267
comm,
255
268
stream);
256
269
},
257
270
CommType::ALLGATHER);
258
271
}
259
272
260
- std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce (
261
- phi::DenseTensor* out_tensor,
262
- const phi::DenseTensor& in_tensor,
263
- const AllreduceOptions& opts,
264
- bool sync_op // for compatibility, no use now
265
- ) {
266
- std::vector<phi::DenseTensor> in_wrapper{in_tensor};
267
- std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
268
- return AllReduce (in_wrapper, out_wrapper, opts);
269
- }
270
-
271
- std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast (
273
+ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather (
272
274
phi::DenseTensor* out_tensor,
273
275
const phi::DenseTensor& in_tensor,
274
- const BroadcastOptions& opts,
275
- bool sync_op // for compatibility, no use now
276
- ) {
277
- std::vector<phi::DenseTensor> in_wrapper{in_tensor};
278
- std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
279
- return Broadcast (in_wrapper, out_wrapper, opts);
280
- }
281
-
282
- std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier (
283
- const BarrierOptions& opts) {
284
- // Only support single card single process
285
- PADDLE_ENFORCE_GE (opts.device_id ,
286
- 0 ,
287
- platform::errors::PreconditionNotMet (
288
- " The barrier device id must greater or equal than 0." ));
289
- platform::CustomPlace place (device_type_, opts.device_id );
290
- auto allocator = std::unique_ptr<phi::Allocator>(
291
- new paddle::experimental::DefaultAllocator (place));
292
- phi::DenseTensorMeta meta (phi::DataType::FLOAT32, phi::DDim{1 });
293
- phi::DenseTensor barrier_tensor{allocator.get (), meta};
294
-
295
- auto task = ProcessGroupCustom::AllReduce (&barrier_tensor,
296
- barrier_tensor,
297
- {},
298
- /* sync_op*/ true );
299
- auto xccl_task = dynamic_cast <ProcessGroupCustom::CustomTask*>(task.get ());
300
- xccl_task->barrierTensors_ = {barrier_tensor};
301
- return task;
302
- }
303
-
304
- phi::DeviceContext* ProcessGroupCustom::GetDeviceContext (
305
- const Place& place) const {
306
- const std::string key = GetKeyFromPlace (place);
307
- const auto & iter = places_to_ctx_.find (key);
308
- PADDLE_ENFORCE_NE (
309
- iter,
310
- places_to_ctx_.end (),
311
- platform::errors::NotFound (
312
- " Cannot find the device context in this process group." ));
313
- return iter->second [0 ].get ();
314
- }
315
-
316
- phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm (const Place& place) const {
317
- std::vector<Place> places = {place};
318
- const auto & iter = places_to_customcomm_.find (GetKeyFromPlaces (places));
319
- PADDLE_ENFORCE_NE (iter,
320
- places_to_customcomm_.end (),
321
- platform::errors::InvalidArgument (
322
- " Cannot find nccl comm in process group." ));
323
- return iter->second [0 ]->GetCustomCCLComm ();
276
+ int64_t offset,
277
+ int64_t numel,
278
+ bool sync_op) {
279
+ return AllGather (out_tensor, in_tensor, offset, numel, sync_op);
324
280
}
325
281
326
282
// TODO(sunyilun): methods below will be removed later
@@ -356,6 +312,28 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
356
312
CommType::ALLGATHER);
357
313
}
358
314
315
+ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce (
316
+ phi::DenseTensor* out_tensor,
317
+ const phi::DenseTensor& in_tensor,
318
+ const AllreduceOptions& opts,
319
+ bool sync_op, // for compatibility, no use now
320
+ bool use_calc_stream) {
321
+ std::vector<phi::DenseTensor> in_wrapper{in_tensor};
322
+ std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
323
+ return AllReduce (in_wrapper, out_wrapper, opts);
324
+ }
325
+
326
+ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce (
327
+ phi::DenseTensor* out_tensor,
328
+ const phi::DenseTensor& in_tensor,
329
+ const AllreduceOptions& opts,
330
+ bool sync_op // for compatibility, no use now
331
+ ) {
332
+ std::vector<phi::DenseTensor> in_wrapper{in_tensor};
333
+ std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
334
+ return AllReduce (in_wrapper, out_wrapper, opts);
335
+ }
336
+
359
337
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce (
360
338
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
361
339
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
@@ -390,6 +368,72 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
390
368
CommType::ALLREDUCE);
391
369
}
392
370
371
+ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast (
372
+ phi::DenseTensor* out_tensor,
373
+ const phi::DenseTensor& in_tensor,
374
+ const BroadcastOptions& opts,
375
+ bool sync_op, // for compatibility, no use now
376
+ bool use_calc_stream) {
377
+ std::vector<phi::DenseTensor> in_wrapper{in_tensor};
378
+ std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
379
+ return Broadcast (in_wrapper, out_wrapper, opts);
380
+ }
381
+
382
+ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast (
383
+ phi::DenseTensor* out_tensor,
384
+ const phi::DenseTensor& in_tensor,
385
+ const BroadcastOptions& opts,
386
+ bool sync_op) {
387
+ std::vector<phi::DenseTensor> in_wrapper{in_tensor};
388
+ std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
389
+ return Broadcast (in_wrapper, out_wrapper, opts);
390
+ }
391
+
392
+ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier (
393
+ const BarrierOptions& opts) {
394
+ // Only support single card single process
395
+ PADDLE_ENFORCE_GE (opts.device_id ,
396
+ 0 ,
397
+ platform::errors::PreconditionNotMet (
398
+ " The barrier device id must greater or equal than 0." ));
399
+ platform::CustomPlace place (device_type_, opts.device_id );
400
+ auto allocator = std::unique_ptr<phi::Allocator>(
401
+ new paddle::experimental::DefaultAllocator (place));
402
+ phi::DenseTensorMeta meta (phi::DataType::FLOAT32, phi::DDim{1 });
403
+ phi::DenseTensor barrier_tensor{allocator.get (), meta};
404
+
405
+ auto task = ProcessGroupCustom::AllReduce (&barrier_tensor,
406
+ barrier_tensor,
407
+ {},
408
+ /* sync_op*/ true ,
409
+ false );
410
+ auto xccl_task = dynamic_cast <ProcessGroupCustom::CustomTask*>(task.get ());
411
+ xccl_task->barrierTensors_ = {barrier_tensor};
412
+ return task;
413
+ }
414
+
415
+ phi::DeviceContext* ProcessGroupCustom::GetDeviceContext (
416
+ const Place& place) const {
417
+ const std::string key = GetKeyFromPlace (place);
418
+ const auto & iter = places_to_ctx_.find (key);
419
+ PADDLE_ENFORCE_NE (
420
+ iter,
421
+ places_to_ctx_.end (),
422
+ platform::errors::NotFound (
423
+ " Cannot find the device context in this process group." ));
424
+ return iter->second [0 ].get ();
425
+ }
426
+
427
+ phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm (const Place& place) const {
428
+ std::vector<Place> places = {place};
429
+ const auto & iter = places_to_customcomm_.find (GetKeyFromPlaces (places));
430
+ PADDLE_ENFORCE_NE (iter,
431
+ places_to_customcomm_.end (),
432
+ platform::errors::InvalidArgument (
433
+ " Cannot find nccl comm in process group." ));
434
+ return iter->second [0 ]->GetCustomCCLComm ();
435
+ }
436
+
393
437
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast (
394
438
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
395
439
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
0 commit comments