@@ -65,14 +65,11 @@ def global_scatter(x,
65
65
to global_count.
66
66
67
67
Args:
68
- x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
69
- should be float16, float32, float64, int32 or int64.
68
+ x (Tensor): Tensor. The tensor data type should be float16, float32, float64, int32 or int64.
70
69
local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
71
- how many data needed to be sent. Every element in the list must be a Tensor whose
72
- data type should be int64.
70
+ how many data needed to be sent. The tensor data type should be int64.
73
71
global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
74
- how many data needed to be received. Every element in the list must be a Tensor whose
75
- data type should be int64.
72
+ how many data needed to be received. The tensor data type should be int64.
76
73
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
77
74
use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
78
75
@@ -161,19 +158,16 @@ def global_gather(x,
161
158
to global_count.
162
159
163
160
Args:
164
- x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
165
- should be float16, float32, float64, int32 or int64.
161
+ x (Tensor): Tensor. Tensor whose data type should be float16, float32, float64, int32 or int64.
166
162
local_count (Tensor): Tensor which have n_expert * world_size elements that indicates
167
- how many data needed to be received. Every element in the list must be a Tensor whose
168
- data type should be int64.
163
+ how many data needed to be received. Tensor data type should be int64.
169
164
global_count (Tensor): Tensor which have n_expert * world_size elements that indicates
170
- how many data needed to be sent. Every element in the list must be a Tensor whose
171
- data type should be int64.
165
+ how many data needed to be sent. Tensor data type should be int64.
172
166
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
173
167
use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
174
168
175
169
Returns:
176
- None.
170
+ out (Tensor): The data received from all experts.
177
171
178
172
Examples:
179
173
.. code-block:: python
0 commit comments