@@ -22,29 +22,13 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/primitive/kernel_primitives.h"
 
 namespace phi {
 namespace funcs {
 
-template <typename T, typename IndexT = int>
-__global__ void GatherCUDAKernel(const T* params,
-                                 const IndexT* indices,
-                                 T* output,
-                                 size_t index_size,
-                                 size_t slice_size,
-                                 int64_t index_dim_size) {
-  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
-    int64_t indices_i = i / slice_size;
-    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
-    IndexT gather_i =
-        (indices[indices_i] < 0 ? (indices[indices_i] + index_dim_size)
-                                : indices[indices_i]);
-    int64_t params_i = gather_i * slice_size + slice_i;
-    *(output + i) = *(params + params_i);
-  }
-}
-
 template <typename T, typename IndexT = int>
 __global__ void GatherNdCUDAKernel(const T* input,
                                    const Dim<DDim::kMaxRank> input_dims,
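Note: the two added headers carry the building blocks used by the vectorized path further down. As a rough mental model only (VectorTypeSketch and GetVectorizedSizeSketch are hypothetical names, not the phi API; the real helpers live in aligned_vector.h and kernel_primitives.h), they amount to an over-aligned fixed-size chunk plus an alignment probe:

#include <cstdint>

// Sketch of an aligned vector chunk: VecSize elements moved by one load/store.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VectorTypeSketch {
  T val[VecSize];
};

// Sketch of picking the widest vector width (4, 2, or 1) that a pointer's
// alignment allows.
template <typename T>
int GetVectorizedSizeSketch(const T* ptr) {
  auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  if (addr % (sizeof(T) * 4) == 0) return 4;
  if (addr % (sizeof(T) * 2) == 0) return 2;
  return 1;
}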
@@ -81,48 +65,6 @@ __global__ void GatherNdCUDAKernel(const T* input,
   }
 }
 
-/**
- * A thin wrapper on gpu tensor
- * Return a new tensor from source tensor, gathered according to index
- * input[src]: type-T source Tensor
- * input[index]: type-IndexT index Tensor (1-D)
- * return: output tensor
- */
-template <typename T, typename IndexT = int>
-void GPUGather(const phi::GPUContext& ctx,
-               const DenseTensor& src,
-               const DenseTensor& index,
-               DenseTensor* output) {
-  if (index.dims().size() == 2) {
-    PADDLE_ENFORCE_EQ(
-        index.dims()[1],
-        1,
-        common::errors::InvalidArgument("If the index's rank of gather_op is 2,"
-                                        " the second dimension should be 1."));
-  }
-
-  // index size
-  int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
-
-  auto src_dims = src.dims();
-
-  // slice size
-  int64_t slice_size = 1;
-  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
-
-  const T* p_src = src.data<T>();
-  const IndexT* p_index = index.data<IndexT>();
-  T* p_output = output->data<T>();
-
-  int block = 512;
-  int64_t n = slice_size * index_size;
-  dim3 grid = dim3((n + block - 1) / block);
-  phi::backends::gpu::LimitGridDim(ctx, &grid);
-
-  GatherCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
-      p_src, p_index, p_output, index_size, slice_size, src_dims[0]);
-}
-
 template <typename T, typename IndexT = int>
 void GPUGatherNd(const phi::GPUContext& ctx,
                  const DenseTensor& input,
@@ -170,20 +112,20 @@ void GPUGatherNd(const phi::GPUContext& ctx,
                                          end_size);
 }
 
-template <typename T, typename U>
+template <typename T, typename U, int VecSize>
 __global__ void GatherGPUKernel(const T* input,
                                 const U* index,
                                 T* out,
                                 int64_t outer_dim_size,
-                                int64_t inner_dim_size,
                                 int64_t out_index_dim_size,
                                 int64_t input_index_dim_size,
                                 int64_t size) {
-  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+  int64_t block_size = blockDim.x;
+  int64_t idx = (blockIdx.x * block_size + threadIdx.x) * VecSize;
   int64_t outer_size = outer_dim_size * out_index_dim_size;
-  for (; idx < size; idx += blockDim.x * gridDim.x) {
+  for (; idx < size; idx += gridDim.x * block_size * VecSize) {
     int64_t inner_dim_index = idx / outer_size;
-    int64_t next_idx = idx - outer_size * inner_dim_index;
+    int64_t next_idx = idx % outer_size;
     int64_t index_dim_index = next_idx / outer_dim_size;
     U index_val = index[index_dim_index];
 
@@ -201,11 +143,15 @@ __global__ void GatherGPUKernel(const T* input,
       index_val += input_index_dim_size;
     }
 
-    int64_t out_dim_index = next_idx - outer_dim_size * index_dim_index;
+    int64_t out_dim_index = next_idx % outer_dim_size;
     int64_t input_index =
         inner_dim_index * (outer_dim_size * input_index_dim_size) +
         index_val * outer_dim_size + out_dim_index;
-    out[idx] = input[input_index];
+
+    using VecType = kps::details::VectorType<T, VecSize>;
+    const VecType* src = reinterpret_cast<const VecType*>(&input[input_index]);
+    VecType* dst = reinterpret_cast<VecType*>(&out[idx]);
+    *dst = *src;
   }
 }
 
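Note: the vectorized copy above is only legal because, with outer_dim_size divisible by VecSize and idx a multiple of VecSize, the VecSize consecutive output positions share the same inner_dim_index and index_dim_index, so their input offsets are consecutive as well. A minimal host-side check of that arithmetic (illustration only; the sizes are made up, variable names mirror the kernel):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t outer_dim_size = 8;        // product of dims after the axis
  const int64_t out_index_dim_size = 3;    // number of gathered indices
  const int64_t input_index_dim_size = 5;  // size of the gathered axis
  const int VecSize = 4;                   // requires outer_dim_size % VecSize == 0

  const int64_t outer_size = outer_dim_size * out_index_dim_size;
  const int64_t index_val = 4;             // pretend index[1] == 4

  // Output chunk starting at a VecSize-aligned idx: inner = 1, index slot = 1,
  // offset 4 inside the outer_dim run.
  const int64_t idx = 1 * outer_size + 1 * outer_dim_size + 4;

  int64_t first_input_index = -1;
  for (int v = 0; v < VecSize; ++v) {
    int64_t inner_dim_index = (idx + v) / outer_size;
    int64_t next_idx = (idx + v) % outer_size;
    int64_t index_dim_index = next_idx / outer_dim_size;
    int64_t out_dim_index = next_idx % outer_dim_size;
    int64_t input_index =
        inner_dim_index * (outer_dim_size * input_index_dim_size) +
        index_val * outer_dim_size + out_dim_index;
    // Same (inner, index) slot for the whole chunk, contiguous input offsets.
    assert(inner_dim_index == 1 && index_dim_index == 1);
    if (v == 0) first_input_index = input_index;
    assert(input_index == first_input_index + v);
  }
  return 0;
}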
@@ -248,12 +194,10 @@ void GatherV2CUDAFunction(const DenseTensor* input,
   int axis_index = axis;
   int64_t index_dim_size = input_dim[axis_index];
 
-  int64_t inner_dim_size = 1;
   int64_t outer_dim_size = 1;
   std::vector<int64_t> out_dim_vec;
 
   for (int i = 0; i < axis_index; i++) {
-    inner_dim_size *= input_dim[i];
     out_dim_vec.push_back(input_dim[i]);
   }
   if (index->dims().size() != 0) {
@@ -270,18 +214,54 @@ void GatherV2CUDAFunction(const DenseTensor* input,
   int64_t out_size = out->numel();
   if (out_size == 0) return;
 
-  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, out_size);
+  int vec_size = 4;
+  vec_size = std::min(phi::GetVectorizedSize(input), vec_size);
+  vec_size = std::min(phi::GetVectorizedSize(out), vec_size);
+  while (vec_size > 1 && outer_dim_size % vec_size != 0) {
+    vec_size /= 2;
+  }
+
+  constexpr int loop_count = 4;
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+      ctx, out_size, vec_size * loop_count);
   auto stream = ctx.stream();
-  GatherGPUKernel<T, U>
-      <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
-          input_data,
-          index_data,
-          out_data,
-          outer_dim_size,
-          inner_dim_size,
-          index_size,
-          index_dim_size,
-          out_size);
+
+  switch (vec_size) {
+#define CASE_VEC_SIZE(__Sz)                                                \
+  case __Sz:                                                               \
+    GatherGPUKernel<T, U, __Sz>                                            \
+        <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(   \
+            input_data,                                                    \
+            index_data,                                                    \
+            out_data,                                                      \
+            outer_dim_size,                                                \
+            index_size,                                                    \
+            index_dim_size,                                                \
+            out_size);                                                     \
+    break
+    CASE_VEC_SIZE(4);
+    CASE_VEC_SIZE(2);
+    CASE_VEC_SIZE(1);
+#undef CASE_VEC_SIZE
+    default:
+      PADDLE_THROW(common::errors::Unimplemented(
+          "Unsupported vectorized size: %d", vec_size));
+  }
+}
+
+/**
+ * A thin wrapper on gpu tensor
+ * Return a new tensor from source tensor, gathered according to index
+ * input[src]: type-T source Tensor
+ * input[index]: type-IndexT index Tensor (1-D)
+ * return: output tensor
+ */
+template <typename T, typename IndexT = int>
+void GPUGather(const phi::GPUContext& ctx,
+               const DenseTensor& src,
+               const DenseTensor& index,
+               DenseTensor* output) {
+  GatherV2CUDAFunction<T, IndexT>(&src, &index, /* axis= */ 0, output, ctx);
 }
 
 template <typename T, typename U>
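Note: with this change GPUGather keeps its old signature but is now a thin forwarder to GatherV2CUDAFunction with axis 0. For reference, the width selection above can be modeled in isolation roughly as below (illustration only; PickGatherVecSize is a hypothetical stand-in and phi::GetVectorizedSize is approximated by a plain alignment check): start at 4, cap by the alignment of both buffers, then halve until outer_dim_size divides evenly, so one of the CASE_VEC_SIZE branches always matches.

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Sketch of the vec_size choice in GatherV2CUDAFunction, using raw pointers
// instead of DenseTensors.
template <typename T>
int PickGatherVecSize(const T* input, const T* out, std::int64_t outer_dim_size) {
  auto width_for = [](const void* p, std::size_t elem) {
    auto addr = reinterpret_cast<std::uintptr_t>(p);
    if (addr % (elem * 4) == 0) return 4;
    if (addr % (elem * 2) == 0) return 2;
    return 1;
  };
  int vec_size = 4;
  vec_size = std::min(width_for(input, sizeof(T)), vec_size);
  vec_size = std::min(width_for(out, sizeof(T)), vec_size);
  // The kernel copies vec_size elements within one outer_dim run, so the run
  // length must divide evenly by the chosen width.
  while (vec_size > 1 && outer_dim_size % vec_size != 0) {
    vec_size /= 2;
  }
  return vec_size;  // always 4, 2, or 1
}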