Skip to content

Commit 17b4dd7

Browse files
authored
Fix global gather and global scatter operators (#36517)
* fix global gather and global scatter operators
1 parent 6a572a1 commit 17b4dd7

File tree

2 files changed

+11
-17
lines changed

2 files changed

+11
-17
lines changed

paddle/fluid/operators/collective/global_scatter_op.cu.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
4747
if (platform::is_cpu_place(local_count->place())) {
4848
cpu_local_count_data = local_count->data<int64_t>();
4949
} else {
50-
framework::TensorCopy(*local_count, platform::CPUPlace(),
51-
&cpu_local_count);
50+
framework::TensorCopySync(*local_count, platform::CPUPlace(),
51+
&cpu_local_count);
5252
cpu_local_count_data = cpu_local_count.data<int64_t>();
5353
}
5454
auto global_count_len = 0;
@@ -57,8 +57,8 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
5757
cpu_global_count_data = global_count->data<int64_t>();
5858
global_count_len = global_count->numel();
5959
} else {
60-
framework::TensorCopy(*global_count, platform::CPUPlace(),
61-
&cpu_global_count);
60+
framework::TensorCopySync(*global_count, platform::CPUPlace(),
61+
&cpu_global_count);
6262
cpu_global_count_data = cpu_global_count.data<int64_t>();
6363
global_count_len = cpu_global_count.numel();
6464
}

python/paddle/distributed/utils.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,11 @@ def global_scatter(x,
6565
to global_count.
6666
6767
Args:
68-
x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
69-
should be float16, float32, float64, int32 or int64.
68+
x (Tensor): Tensor. The tensor data type should be float16, float32, float64, int32 or int64.
7069
local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
71-
how many data needed to be sent. Every element in the list must be a Tensor whose
72-
data type should be int64.
70+
how many data needed to be sent. The tensor data type should be int64.
7371
global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
74-
how many data needed to be received. Every element in the list must be a Tensor whose
75-
data type should be int64.
72+
how many data needed to be received. The tensor data type should be int64.
7673
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
7774
use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
7875
@@ -161,19 +158,16 @@ def global_gather(x,
161158
to global_count.
162159
163160
Args:
164-
x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
165-
should be float16, float32, float64, int32 or int64.
161+
x (Tensor): Tensor. Tensor whose data type should be float16, float32, float64, int32 or int64.
166162
local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
167-
how many data needed to be received. Every element in the list must be a Tensor whose
168-
data type should be int64.
163+
how many data needed to be received. Tensor data type should be int64.
169164
global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
170-
how many data needed to be sent. Every element in the list must be a Tensor whose
171-
data type should be int64.
165+
how many data needed to be sent. Tensor data type should be int64.
172166
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
173167
use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
174168
175169
Returns:
176-
None.
170+
out (Tensor): The data received from all experts.
177171
178172
Examples:
179173
.. code-block:: python

0 commit comments

Comments
 (0)