Skip to content

Commit 5ad51a7

Browse files
committed
[PHI] Optimize Gather kernel with vectorization
1 parent 0a85d41 commit 5ad51a7

File tree

2 files changed

+104
-126
lines changed

2 files changed

+104
-126
lines changed

paddle/phi/kernels/funcs/gather.cu.h

+60-80
Original file line numberDiff line numberDiff line change
@@ -22,29 +22,13 @@ limitations under the License. */
2222
#include "paddle/phi/backends/gpu/gpu_primitives.h"
2323
#include "paddle/phi/common/place.h"
2424
#include "paddle/phi/core/dense_tensor.h"
25+
#include "paddle/phi/kernels/funcs/aligned_vector.h"
2526
#include "paddle/phi/kernels/funcs/math_function.h"
27+
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
2628

2729
namespace phi {
2830
namespace funcs {
2931

30-
template <typename T, typename IndexT = int>
31-
__global__ void GatherCUDAKernel(const T* params,
32-
const IndexT* indices,
33-
T* output,
34-
size_t index_size,
35-
size_t slice_size,
36-
int64_t index_dim_size) {
37-
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
38-
int64_t indices_i = i / slice_size;
39-
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
40-
IndexT gather_i =
41-
(indices[indices_i] < 0 ? (indices[indices_i] + index_dim_size)
42-
: indices[indices_i]);
43-
int64_t params_i = gather_i * slice_size + slice_i;
44-
*(output + i) = *(params + params_i);
45-
}
46-
}
47-
4832
template <typename T, typename IndexT = int>
4933
__global__ void GatherNdCUDAKernel(const T* input,
5034
const Dim<DDim::kMaxRank> input_dims,
@@ -81,48 +65,6 @@ __global__ void GatherNdCUDAKernel(const T* input,
8165
}
8266
}
8367

84-
/**
85-
* A thin wrapper on gpu tensor
86-
* Return a new tensor from source tensor, gathered according to index
87-
* input[src]: type-T source Tensor
88-
* input[index]: type-IndexT index Tensor (1-D)
89-
* return: output tensor
90-
*/
91-
template <typename T, typename IndexT = int>
92-
void GPUGather(const phi::GPUContext& ctx,
93-
const DenseTensor& src,
94-
const DenseTensor& index,
95-
DenseTensor* output) {
96-
if (index.dims().size() == 2) {
97-
PADDLE_ENFORCE_EQ(
98-
index.dims()[1],
99-
1,
100-
common::errors::InvalidArgument("If the index's rank of gather_op is 2,"
101-
" the second dimension should be 1."));
102-
}
103-
104-
// index size
105-
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
106-
107-
auto src_dims = src.dims();
108-
109-
// slice size
110-
int64_t slice_size = 1;
111-
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
112-
113-
const T* p_src = src.data<T>();
114-
const IndexT* p_index = index.data<IndexT>();
115-
T* p_output = output->data<T>();
116-
117-
int block = 512;
118-
int64_t n = slice_size * index_size;
119-
dim3 grid = dim3((n + block - 1) / block);
120-
phi::backends::gpu::LimitGridDim(ctx, &grid);
121-
122-
GatherCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
123-
p_src, p_index, p_output, index_size, slice_size, src_dims[0]);
124-
}
125-
12668
template <typename T, typename IndexT = int>
12769
void GPUGatherNd(const phi::GPUContext& ctx,
12870
const DenseTensor& input,
@@ -170,20 +112,20 @@ void GPUGatherNd(const phi::GPUContext& ctx,
170112
end_size);
171113
}
172114

173-
template <typename T, typename U>
115+
template <typename T, typename U, int VecSize>
174116
__global__ void GatherGPUKernel(const T* input,
175117
const U* index,
176118
T* out,
177119
int64_t outer_dim_size,
178-
int64_t inner_dim_size,
179120
int64_t out_index_dim_size,
180121
int64_t input_index_dim_size,
181122
int64_t size) {
182-
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
123+
int64_t block_size = blockDim.x;
124+
int64_t idx = (blockIdx.x * block_size + threadIdx.x) * VecSize;
183125
int64_t outer_size = outer_dim_size * out_index_dim_size;
184-
for (; idx < size; idx += blockDim.x * gridDim.x) {
126+
for (; idx < size; idx += gridDim.x * block_size * VecSize) {
185127
int64_t inner_dim_index = idx / outer_size;
186-
int64_t next_idx = idx - outer_size * inner_dim_index;
128+
int64_t next_idx = idx % outer_size;
187129
int64_t index_dim_index = next_idx / outer_dim_size;
188130
U index_val = index[index_dim_index];
189131

@@ -201,11 +143,15 @@ __global__ void GatherGPUKernel(const T* input,
201143
index_val += input_index_dim_size;
202144
}
203145

204-
int64_t out_dim_index = next_idx - outer_dim_size * index_dim_index;
146+
int64_t out_dim_index = next_idx % outer_dim_size;
205147
int64_t input_index =
206148
inner_dim_index * (outer_dim_size * input_index_dim_size) +
207149
index_val * outer_dim_size + out_dim_index;
208-
out[idx] = input[input_index];
150+
151+
using VecType = kps::details::VectorType<T, VecSize>;
152+
const VecType* src = reinterpret_cast<const VecType*>(&input[input_index]);
153+
VecType* dst = reinterpret_cast<VecType*>(&out[idx]);
154+
*dst = *src;
209155
}
210156
}
211157

@@ -248,12 +194,10 @@ void GatherV2CUDAFunction(const DenseTensor* input,
248194
int axis_index = axis;
249195
int64_t index_dim_size = input_dim[axis_index];
250196

251-
int64_t inner_dim_size = 1;
252197
int64_t outer_dim_size = 1;
253198
std::vector<int64_t> out_dim_vec;
254199

255200
for (int i = 0; i < axis_index; i++) {
256-
inner_dim_size *= input_dim[i];
257201
out_dim_vec.push_back(input_dim[i]);
258202
}
259203
if (index->dims().size() != 0) {
@@ -270,18 +214,54 @@ void GatherV2CUDAFunction(const DenseTensor* input,
270214
int64_t out_size = out->numel();
271215
if (out_size == 0) return;
272216

273-
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, out_size);
217+
int vec_size = 4;
218+
vec_size = std::min(phi::GetVectorizedSize(input), vec_size);
219+
vec_size = std::min(phi::GetVectorizedSize(out), vec_size);
220+
while (vec_size > 1 && outer_dim_size % vec_size != 0) {
221+
vec_size /= 2;
222+
}
223+
224+
constexpr int loop_count = 4;
225+
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
226+
ctx, out_size, vec_size * loop_count);
274227
auto stream = ctx.stream();
275-
GatherGPUKernel<T, U>
276-
<<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
277-
input_data,
278-
index_data,
279-
out_data,
280-
outer_dim_size,
281-
inner_dim_size,
282-
index_size,
283-
index_dim_size,
284-
out_size);
228+
229+
switch (vec_size) {
230+
#define CASE_VEC_SIZE(__Sz) \
231+
case __Sz: \
232+
GatherGPUKernel<T, U, __Sz> \
233+
<<<config.block_per_grid, config.thread_per_block, 0, stream>>>( \
234+
input_data, \
235+
index_data, \
236+
out_data, \
237+
outer_dim_size, \
238+
index_size, \
239+
index_dim_size, \
240+
out_size); \
241+
break
242+
CASE_VEC_SIZE(4);
243+
CASE_VEC_SIZE(2);
244+
CASE_VEC_SIZE(1);
245+
#undef CASE_VEC_SIZE
246+
default:
247+
PADDLE_THROW(common::errors::Unimplemented(
248+
"Unsupported vectorized size: %d", vec_size));
249+
}
250+
}
251+
252+
/**
253+
* A thin wrapper on gpu tensor
254+
* Return a new tensor from source tensor, gathered according to index
255+
* input[src]: type-T source Tensor
256+
* input[index]: type-IndexT index Tensor (1-D)
257+
* return: output tensor
258+
*/
259+
template <typename T, typename IndexT = int>
260+
void GPUGather(const phi::GPUContext& ctx,
261+
const DenseTensor& src,
262+
const DenseTensor& index,
263+
DenseTensor* output) {
264+
GatherV2CUDAFunction<T, IndexT>(&src, &index, /* axis= */ 0, output, ctx);
285265
}
286266

287267
template <typename T, typename U>

test/ir/inference/test_trt_convert_gather.py

+44-46
Original file line numberDiff line numberDiff line change
@@ -49,59 +49,57 @@ def generate_input3(axis):
4949
return np.array([axis]).astype(np.int32)
5050

5151
for shape in [[32], [16, 64], [32, 16, 16], [32, 64, 16, 32]]:
52-
for index in [[1, 4], [4, 8]]:
52+
for index in [[0, 1]]:
5353
for axis in [0, 1, 2, 3]:
54-
for overwrite in [True, False]:
55-
for input in [
56-
{"X": ["input_data"], "Index": ["index_data"]},
57-
]:
58-
for index_type_int32 in [True, False]:
59-
self.shape = shape
60-
self.axis = axis
61-
self.input_num = len(input)
62-
self.index_type_int32 = index_type_int32
63-
dics = [{"overwrite": overwrite, "axis": axis}]
64-
ops_config = [
54+
for input in [
55+
{"X": ["input_data"], "Index": ["index_data"]},
56+
]:
57+
for index_type_int32 in [True, False]:
58+
self.shape = shape
59+
self.axis = axis
60+
self.input_num = len(input)
61+
self.index_type_int32 = index_type_int32
62+
ops_config = [
63+
{
64+
"op_type": "gather",
65+
"op_inputs": input,
66+
"op_outputs": {"Out": ["output_data"]},
67+
"op_attrs": {"axis": axis},
68+
}
69+
]
70+
ops = self.generate_op_config(ops_config)
71+
72+
program_config = ProgramConfig(
73+
ops=ops,
74+
weights={},
75+
inputs=(
6576
{
66-
"op_type": "gather",
67-
"op_inputs": input,
68-
"op_outputs": {"Out": ["output_data"]},
69-
"op_attrs": dics[0],
77+
"index_data": TensorConfig(
78+
data_gen=partial(
79+
(
80+
generate_input2
81+
if index_type_int32
82+
else generate_input4
83+
),
84+
index,
85+
)
86+
),
87+
"input_data": TensorConfig(
88+
data_gen=partial(
89+
generate_input1, shape
90+
)
91+
),
7092
}
71-
]
72-
ops = self.generate_op_config(ops_config)
73-
74-
program_config = ProgramConfig(
75-
ops=ops,
76-
weights={},
77-
inputs=(
78-
{
79-
"index_data": TensorConfig(
80-
data_gen=partial(
81-
(
82-
generate_input2
83-
if index_type_int32
84-
else generate_input4
85-
),
86-
index,
87-
)
88-
),
89-
"input_data": TensorConfig(
90-
data_gen=partial(
91-
generate_input1, shape
92-
)
93-
),
94-
}
95-
),
96-
outputs=["output_data"],
97-
)
98-
99-
yield program_config
93+
),
94+
outputs=["output_data"],
95+
)
96+
97+
yield program_config
10098

10199
def generate_dynamic_shape(self):
102100
if len(self.shape) == 1:
103101
self.dynamic_shape.min_input_shape = {
104-
"input_data": [1],
102+
"input_data": [2],
105103
"index_data": [2],
106104
}
107105
self.dynamic_shape.max_input_shape = {

0 commit comments

Comments
 (0)