Commit a8a30e6

add index_elementwise_get_grad kernel

1 parent 108db2c
8 files changed: +268 -48 lines

paddle/phi/infermeta/backward.cc

Lines changed: 14 additions & 0 deletions

@@ -1887,4 +1887,18 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
     value_grad->share_lod(values);
   }
 }
+
+void IndexElementwiseGetGradInferMeta(
+    const MetaTensor& x,
+    const std::vector<const MetaTensor*>& index,
+    const MetaTensor& out_grad,
+    const std::vector<int64_t>& input_dims,
+    const std::vector<int64_t>& input_strides,
+    const std::vector<int64_t>& index_dims,
+    const std::vector<int64_t>& index_strides,
+    MetaTensor* x_grad) {
+  if (x_grad) {
+    x_grad->share_meta(x);
+  }
+}
 }  // namespace phi
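
A note on the InferMeta: share_meta(x) makes x_grad adopt x's shape, dtype, and layout, matching the usual contract that a gradient mirrors the tensor it differentiates. A minimal NumPy analogy (illustrative only, not Paddle API):

import numpy as np

x = np.empty((4, 5), dtype=np.float32)
x_grad = np.zeros_like(x)  # mirrors shape and dtype, as x_grad->share_meta(x) does
assert x_grad.shape == x.shape and x_grad.dtype == x.dtype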

paddle/phi/infermeta/backward.h

Lines changed: 10 additions & 0 deletions

@@ -680,4 +680,14 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
                            MetaTensor* x_grad,
                            MetaTensor* value_grad);
 
+void IndexElementwiseGetGradInferMeta(
+    const MetaTensor& x,
+    const std::vector<const MetaTensor*>& index,
+    const MetaTensor& out_grad,
+    const std::vector<int64_t>& input_dims,
+    const std::vector<int64_t>& input_strides,
+    const std::vector<int64_t>& index_dims,
+    const std::vector<int64_t>& index_strides,
+    MetaTensor* x_grad);
+
 }  // namespace phi

paddle/phi/kernels/funcs/stride_utils.h

Lines changed: 4 additions & 3 deletions

@@ -197,9 +197,10 @@ static inline void reorder_dimensions(const std::vector<int64_t> stride_size,
   permute_dimensions<N>(stride_size, perm_, strides_array, shape_);
 }
 
-std::vector<int64_t> compatible_stride(const std::vector<int64_t>* shape_,
-                                       const int64_t ndim,
-                                       const int64_t element_size) {
+static inline std::vector<int64_t> compatible_stride(
+    const std::vector<int64_t>* shape_,
+    const int64_t ndim,
+    const int64_t element_size) {
   std::vector<int64_t> stride;
   int64_t next_stride = element_size;
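
The change above marks compatible_stride as static inline because it is defined in a header: without that, every translation unit including stride_utils.h would emit its own external definition, and linking could fail with duplicate symbols. The visible initialization next_stride = element_size suggests the function builds contiguous byte strides from the innermost dimension outward; a hedged Python sketch of that idea (an assumption, not a copy of the C++ body):

def compatible_stride(shape, element_size):
    # Innermost dimension is contiguous; each outer stride is the running
    # product of the inner extents, in bytes.
    stride = [0] * len(shape)
    next_stride = element_size
    for i in reversed(range(len(shape))):
        stride[i] = next_stride
        next_stride *= shape[i]
    return stride

# compatible_stride([2, 3], 4) == [12, 4] for a contiguous float32 tensor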

paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu (new file; path inferred from the header include and GPU registration below)

Lines changed: 194 additions & 0 deletions

@@ -0,0 +1,194 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/index_elementwise_get_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/index_elementwise.cu.h"
+#include "paddle/phi/kernels/funcs/stride_utils.h"
+
+namespace phi {
+
+template <typename T, typename IndexT = int>
+void GPUIndexElementwisePutKernel(const phi::GPUContext& ctx,
+                                  const DenseTensor& input,
+                                  const DenseTensor& value,
+                                  const std::vector<const DenseTensor*>& index,
+                                  const std::vector<int64_t>& input_dims,
+                                  const std::vector<int64_t>& input_strides,
+                                  const std::vector<int64_t>& index_dims,
+                                  const std::vector<int64_t>& index_strides,
+                                  DenseTensor* output) {
+  int64_t numel = 0;
+
+  auto num_indices = index_dims.size();
+
+  auto sizes = std::array<int64_t, 25>{};
+  auto strides = std::array<int64_t, 25>{};
+  for (unsigned i = 0; i < num_indices; i++) {
+    sizes[i] = index_dims[i];
+    strides[i] = index_strides[i];
+  }
+  auto index_ptrs = funcs::GetIndexDataPtrs<IndexT>(index);
+
+  std::array<int64_t*, 3> strides_array;
+  std::vector<int64_t> desired_shape;
+
+  funcs::IndexPutStride<3>(input_dims,
+                           input_strides,
+                           phi::SizeOf(input.dtype()),
+                           std::vector<int64_t>(),
+                           std::vector<int64_t>(),
+                           phi::SizeOf(value.dtype()),
+                           common::vectorize<int64_t>(index[0]->dims()),
+                           common::vectorize<int64_t>(index[0]->strides()),
+                           phi::SizeOf(index[0]->dtype()),
+                           &desired_shape,
+                           &strides_array,
+                           &numel);
+
+  const int64_t* template_stride = strides_array[2];
+  PADDLE_ENFORCE(
+      template_stride != nullptr,
+      "strides_array[2] should not be nullptr in GPUIndexElementwisePutKernel");
+  size_t stride_size = desired_shape.size();
+  std::vector<std::vector<int64_t>> strides_vector;
+  strides_vector.reserve(num_indices + 2);
+
+  for (int i = 0; i < 2; ++i) {
+    if (i < strides_array.size() && strides_array[i] != nullptr) {
+      strides_vector.emplace_back(strides_array[i],
+                                  strides_array[i] + stride_size);
+    } else {
+      strides_vector.emplace_back(stride_size, 0);
+    }
+  }
+
+  std::vector<int64_t> template_vec(template_stride,
+                                    template_stride + stride_size);
+  for (size_t i = 0; i < num_indices; ++i) {
+    strides_vector.push_back(template_vec);
+  }
+
+  auto offset_calc = funcs::make_offset_calculator<3>(
+      desired_shape.size(), desired_shape.data(), strides_vector);
+
+  const int64_t N = numel;
+  PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
+                 "N >= 0 && N <= std::numeric_limits<int32_t>::max()");
+
+  constexpr int nt = 128;
+  constexpr int vt = 4;
+  const dim3 block(nt);
+  const dim3 grid((N + block.x * vt - 1) / (block.x * vt));
+  auto stream = ctx.stream();
+
+  using dtype = funcs::OpaqueType<sizeof(T)>;
+
+  const char* in_ptr = reinterpret_cast<const char*>(value.data<T>());
+  char* out_ptr = reinterpret_cast<char*>(output->data<T>());
+
+  funcs::index_elementwise_kernel<nt, vt>
+      <<<grid, block, 0, stream>>>(N, [=] __device__(int idx) {
+        const auto offsets = offset_calc.get(idx);
+        char* const out_data = out_ptr + offsets[0];
+        const char* const in_data = in_ptr + offsets[1];
+
+        int64_t offset = 0;
+#pragma unroll
+        for (int i = 0; i < num_indices; i++) {
+          int64_t index =
+              *reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]);
+          PADDLE_ENFORCE(-sizes[i] <= index && index < sizes[i],
+                         "index out of bounds");
+          if (index < 0) {
+            index += sizes[i];
+          }
+          offset += index * strides[i];
+        }
+        *reinterpret_cast<dtype*>(out_data + offset) =
+            *reinterpret_cast<const dtype*>(in_data);
+      });
+}
+
+template <typename T, typename Context>
+void IndexElementwiseGetGradKernel(const Context& ctx,
+                                   const DenseTensor& x,
+                                   const std::vector<const DenseTensor*>& index,
+                                   const DenseTensor& out_grad,
+                                   const std::vector<int64_t>& input_dims,
+                                   const std::vector<int64_t>& input_strides,
+                                   const std::vector<int64_t>& index_dims,
+                                   const std::vector<int64_t>& index_strides,
+                                   DenseTensor* x_grad) {
+  ctx.template Alloc<T>(x_grad);
+  auto dxt = phi::EigenVector<T>::Flatten(*x_grad);
+  auto& place = *ctx.eigen_device();
+  dxt.device(place) = dxt.constant(static_cast<T>(0));
+  if (out_grad.numel() == 0) return;
+
+  const auto& index_type = index[0]->dtype();
+  PADDLE_ENFORCE_EQ(
+      index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64,
+      true,
+      common::errors::InvalidArgument(
+          "Index holds the wrong type, it holds [%s], but "
+          "desires to be [%s] or [%s].",
+          index_type,
+          phi::DataType::INT32,
+          phi::DataType::INT64));
+
+  if (index_type == phi::DataType::INT32) {
+    GPUIndexElementwisePutKernel<T, int>(ctx,
+                                         x,
+                                         out_grad,
+                                         index,
+                                         input_dims,
+                                         input_strides,
+                                         index_dims,
+                                         index_strides,
+                                         x_grad);
+  } else if (index_type == phi::DataType::INT64) {
+    GPUIndexElementwisePutKernel<T, int64_t>(ctx,
+                                             x,
+                                             out_grad,
+                                             index,
+                                             input_dims,
+                                             input_strides,
+                                             index_dims,
+                                             index_strides,
+                                             x_grad);
+  }
+}
+
+}  // namespace phi
+PD_REGISTER_KERNEL(index_elementwise_get_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::IndexElementwiseGetGradKernel,
+                   bool,
+                   float,
+                   double,
+                   int,
+                   int8_t,
+                   int64_t,
+                   int16_t,
+                   uint8_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
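
The backward reuses a put-style scatter: x_grad is zeroed first, then each out_grad element is written to the x position selected by the index tensors. Note the kernel performs a plain assignment rather than an atomic add, so if the same position is indexed twice, one write wins (nondeterministically on the GPU). A hedged NumPy sketch of these semantics (reference only; the helper name is illustrative):

import numpy as np

def index_elementwise_get_grad_ref(x, indices, out_grad):
    x_grad = np.zeros_like(x)
    # Fancy-index assignment mirrors the kernel's *out = *in: duplicate
    # indices keep one written value instead of accumulating a sum.
    x_grad[indices] = out_grad
    return x_grad

x = np.arange(6.0).reshape(2, 3)
rows, cols = np.array([0, 1]), np.array([2, 0])
print(index_elementwise_get_grad_ref(x, (rows, cols), np.ones(2)))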

paddle/phi/kernels/index_elementwise_get_grad_kernel.h (new file; path taken from the #include in the .cu file above)

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/tensor_array.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void IndexElementwiseGetGradKernel(const Context& ctx,
+                                   const DenseTensor& x,
+                                   const std::vector<const DenseTensor*>& index,
+                                   const DenseTensor& out_grad,
+                                   const std::vector<int64_t>& input_dims,
+                                   const std::vector<int64_t>& input_strides,
+                                   const std::vector<int64_t>& index_dims,
+                                   const std::vector<int64_t>& index_strides,
+                                   DenseTensor* x_grad);
+
+}  // namespace phi

paddle/phi/ops/yaml/backward.yaml

Lines changed: 12 additions & 0 deletions

@@ -1672,6 +1672,18 @@
   inplace : (out_grad -> x_grad)
   backward : index_add_double_grad
 
+- backward_op : index_elementwise_get_grad
+  forward : index_elementwise_get (Tensor x, Tensor[] index, int64_t[] input_dims, int64_t[] input_strides, int64_t[] index_dims, int64_t[] index_stride) -> Tensor(out)
+  args : (Tensor x, Tensor[] index, Tensor out_grad, int64_t[] input_dims, int64_t[] input_strides, int64_t[] index_dims, int64_t[] index_stride)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : IndexElementwiseGetGradInferMeta
+  kernel :
+    func : index_elementwise_get_grad
+    data_type : out_grad
+  data_transform :
+    skip_transform : index
+
 - backward_op : index_put_double_grad
   forward : index_put_grad (Tensor x, Tensor[] indices, Tensor value, Tensor grad_out, bool accumulate=false) -> Tensor(grad_x), Tensor(grad_value)
   args : (Tensor x, Tensor[] indices, Tensor value, Tensor grad_x_grad, Tensor grad_value_grad, bool accumulate=false)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 3 deletions

@@ -2761,9 +2761,7 @@
   kernel :
     func : index_elementwise_get
     data_type : x
-  # backward : index_elementwise_grad
-  # interfaces : paddle::dialect::InferSymbolicShapeInterface
-  traits : paddle::dialect::ForwardOnlyTrait
+  backward : index_elementwise_get_grad
 
 - op : index_put
   args : (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false)
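
With ForwardOnlyTrait removed and the backward op wired up, gradients can now flow through this indexing path. A hedged dygraph sketch, assuming tensor-index __getitem__ lowers to index_elementwise_get for this pattern (the exact lowering is not shown in this diff):

import paddle

x = paddle.randn([4, 5])
x.stop_gradient = False
rows = paddle.to_tensor([0, 2])
out = x[rows]         # forward: assumed to hit index_elementwise_get
out.sum().backward()  # backward: index_elementwise_get_grad
# x.grad should be 1.0 in the gathered rows and 0.0 elsewhere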

test/legacy_test/test_index_elementwise.py

Lines changed: 0 additions & 42 deletions

@@ -16,8 +16,6 @@
 import numpy as np
 
 import paddle
-from paddle import base
-from paddle.base import core
 
 
 def np_index_elementwise(x, index):
@@ -64,45 +62,6 @@ def setUp(self):
 
         self.out_np = np_index_elementwise(self.x_np, self.index_np)
 
-    def test_static_graph(self):
-        paddle.enable_static()
-        startup_program = base.Program()
-        train_program = base.Program()
-
-        with base.program_guard(startup_program, train_program):
-            x = paddle.static.data(
-                name='x', dtype=self.dtype, shape=self.x_shape
-            )
-            index = paddle.static.data(
-                name='index', dtype='bool', shape=self.index_shape
-            )
-            out = x[index]
-
-            place = (
-                base.CUDAPlace(0)
-                if core.is_compiled_with_cuda()
-                else base.CPUPlace()
-            )
-            exe = base.Executor(place)
-
-            result = exe.run(
-                base.default_main_program(),
-                feed={
-                    'x': self.x_np,
-                    'index': self.index_np,
-                },
-                fetch_list=[out],
-            )[0]
-
-            atol = 1e-05 if self.dtype in ["float32", "float64"] else 0
-            rtol = 1e-05 if self.dtype in ["float32", "float64"] else 0
-
-            np.testing.assert_allclose(
-                result, self.out_np, atol=atol, rtol=rtol
-            )
-
-        paddle.disable_static()
-
     def test_dygraph(self):
         paddle.disable_static()
 
@@ -227,7 +186,6 @@ def setUp(self):
         )
         self.out_np = np_index_elementwise(self.x_np, self.index_np)
 
-        self.test_static_graph()
         self.test_dygraph()
 
 