Skip to content

Commit 588f0ec

Browse files
committed
add index_elementwise_get_grad kernel
1 parent 108db2c commit 588f0ec

File tree

10 files changed

+292
-62
lines changed

10 files changed

+292
-62
lines changed

paddle/fluid/pybind/slice_utils.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,15 @@ static inline common::DDim infer_size_symdimvector(common::DDim a,
5454
auto sizeA = (dimA >= 0) ? a[dimA] : 1;
5555
auto sizeB = (dimB >= 0) ? b[dimB] : 1;
5656

57-
PADDLE_ENFORCE(sizeA == sizeB || sizeA == 1 || sizeB == 1,
58-
common::errors::Fatal("The size of tensor a (",
59-
sizeA,
60-
") must match the size of tensor b (",
61-
sizeB,
62-
") at non-singleton dimension ",
63-
i));
57+
PADDLE_ENFORCE_EQ(
58+
sizeA == sizeB || sizeA == 1 || sizeB == 1,
59+
true,
60+
common::errors::Fatal("The size of tensor a (",
61+
sizeA,
62+
") must match the size of tensor b (",
63+
sizeB,
64+
") at non-singleton dimension ",
65+
i));
6466

6567
// 1s map to the other size (even 0).
6668
expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;

paddle/phi/infermeta/backward.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,4 +1887,18 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
18871887
value_grad->share_lod(values);
18881888
}
18891889
}
1890+
1891+
void IndexElementwiseGetGradInferMeta(
1892+
const MetaTensor& x,
1893+
const std::vector<const MetaTensor*>& index,
1894+
const MetaTensor& out_grad,
1895+
const std::vector<int64_t>& input_dims,
1896+
const std::vector<int64_t>& input_strides,
1897+
const std::vector<int64_t>& index_dims,
1898+
const std::vector<int64_t>& index_strides,
1899+
MetaTensor* x_grad) {
1900+
if (x_grad) {
1901+
x_grad->share_meta(x);
1902+
}
1903+
}
18901904
} // namespace phi

paddle/phi/infermeta/backward.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,4 +680,14 @@ void SetValueGradInferMeta(const MetaTensor& out_grad,
680680
MetaTensor* x_grad,
681681
MetaTensor* value_grad);
682682

683+
// Infer-meta for the backward of index_elementwise_get: x_grad shares x's
// meta (dims/dtype/layout).  The dims/strides attributes mirror the kernel
// signature and do not affect shape inference.
void IndexElementwiseGetGradInferMeta(
    const MetaTensor& x,
    const std::vector<const MetaTensor*>& index,
    const MetaTensor& out_grad,
    const std::vector<int64_t>& input_dims,
    const std::vector<int64_t>& input_strides,
    const std::vector<int64_t>& index_dims,
    const std::vector<int64_t>& index_strides,
    MetaTensor* x_grad);
692+
683693
} // namespace phi

paddle/phi/kernels/funcs/index_elementwise.cu.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ constexpr int MAX_DIMS = 16;
3636
#else
3737
constexpr int MAX_DIMS = 25;
3838
#endif
39+
constexpr int MAX_DIMS = 9;  // NOTE(review): MAX_DIMS is already defined by the #if/#else block directly above — this redefinition will not compile. Delete this line, or remove the old definitions if 9 is the intended limit.
3940

4041
static constexpr int launch_bound2 = 4;
4142
static constexpr int launch_size_nd = 128;
@@ -91,9 +92,11 @@ struct OffsetCalculator {
9192
const int64_t* const* strides,
9293
const int64_t* element_sizes = nullptr)
9394
: dims(dims) {
94-
PADDLE_ENFORCE(dims <= MAX_DIMS,
95-
"The number of dimensions (%d) exceeds MAX_DIMS.",
96-
dims);
95+
PADDLE_ENFORCE_LE(
96+
dims,
97+
MAX_DIMS,
98+
common::errors::InvalidArgument(
99+
"Tensor has too many dims. Maximum dim is %d.", MAX_DIMS));
97100
for (int i = 0; i < dims; i++) {
98101
sizes_[i] = IntDivider<index_t>(sizes[i]);
99102
for (int arg = 0; arg < NARGS; arg++) {
@@ -144,10 +147,12 @@ std::array<char*, DDim::kMaxRank> GetIndexDataPtrs(
144147
for (size_t i = 0; i < index.size(); ++i) {
145148
const IndexT* p_index = index[i]->data<IndexT>();
146149

147-
PADDLE_ENFORCE(p_index != nullptr,
148-
"The pointer p_index must not be nullptr. "
149-
"Please ensure the index tensor is valid and its data "
150-
"is correctly initialized.");
150+
PADDLE_ENFORCE_NOT_NULL(
151+
p_index,
152+
::common::errors::InvalidArgument(
153+
"The pointer p_index is nullptr, "
154+
"please check whether the index tensor is valid and "
155+
"its data is correctly initialized."));
151156

152157
index_ptrs[i] = reinterpret_cast<char*>(const_cast<IndexT*>(p_index));
153158
}

paddle/phi/kernels/funcs/stride_utils.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,10 @@ static inline void reorder_dimensions(const std::vector<int64_t> stride_size,
197197
permute_dimensions<N>(stride_size, perm_, strides_array, shape_);
198198
}
199199

200-
std::vector<int64_t> compatible_stride(const std::vector<int64_t>* shape_,
201-
const int64_t ndim,
202-
const int64_t element_size) {
200+
static inline std::vector<int64_t> compatible_stride(
201+
const std::vector<int64_t>* shape_,
202+
const int64_t ndim,
203+
const int64_t element_size) {
203204
std::vector<int64_t> stride;
204205
int64_t next_stride = element_size;
205206

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/index_elementwise_get_grad_kernel.h"
16+
17+
#include "paddle/phi/backends/gpu/gpu_context.h"
18+
#include "paddle/phi/common/bfloat16.h"
19+
#include "paddle/phi/core/kernel_registry.h"
20+
#include "paddle/phi/kernels/funcs/eigen/common.h"
21+
#include "paddle/phi/kernels/funcs/index_elementwise.cu.h"
22+
#include "paddle/phi/kernels/funcs/stride_utils.h"
23+
24+
namespace phi {
25+
26+
template <typename T, typename IndexT = int>
27+
void GPUIndexElementwisePutKernel(const phi::GPUContext& ctx,
28+
const DenseTensor& input,
29+
const DenseTensor& value,
30+
const std::vector<const DenseTensor*>& index,
31+
const std::vector<int64_t>& input_dims,
32+
const std::vector<int64_t>& input_strides,
33+
const std::vector<int64_t>& index_dims,
34+
const std::vector<int64_t>& index_strides,
35+
DenseTensor* output) {
36+
int64_t numel = 0;
37+
38+
auto num_indices = index_dims.size();
39+
40+
auto sizes = std::array<int64_t, 25>{};
41+
auto strides = std::array<int64_t, 25>{};
42+
for (unsigned i = 0; i < num_indices; i++) {
43+
sizes[i] = index_dims[i];
44+
strides[i] = index_strides[i];
45+
}
46+
auto index_ptrs = funcs::GetIndexDataPtrs<IndexT>(index);
47+
48+
std::array<int64_t*, 3> strides_array;
49+
std::vector<int64_t> desired_shape;
50+
51+
funcs::IndexPutStride<3>(input_dims,
52+
input_strides,
53+
phi::SizeOf(input.dtype()),
54+
std::vector<int64_t>(),
55+
std::vector<int64_t>(),
56+
phi::SizeOf(value.dtype()),
57+
common::vectorize<int64_t>(index[0]->dims()),
58+
common::vectorize<int64_t>(index[0]->strides()),
59+
phi::SizeOf(index[0]->dtype()),
60+
&desired_shape,
61+
&strides_array,
62+
&numel);
63+
64+
const int64_t* template_stride = strides_array[2];
65+
PADDLE_ENFORCE_NOT_NULL(template_stride,
66+
::common::errors::InvalidArgument(
67+
"strides_array[2] should not be nullptr in "
68+
"GPUIndexElementwiseGetKernel"));
69+
70+
size_t stride_size = desired_shape.size();
71+
std::vector<std::vector<int64_t>> strides_vector;
72+
strides_vector.reserve(num_indices + 2);
73+
74+
for (int i = 0; i < 2; ++i) {
75+
if (i < strides_array.size() && strides_array[i] != nullptr) {
76+
strides_vector.emplace_back(strides_array[i],
77+
strides_array[i] + stride_size);
78+
} else {
79+
strides_vector.emplace_back(stride_size, 0);
80+
}
81+
}
82+
83+
std::vector<int64_t> template_vec(template_stride,
84+
template_stride + stride_size);
85+
for (size_t i = 0; i < num_indices; ++i) {
86+
strides_vector.push_back(template_vec);
87+
}
88+
89+
auto offset_calc = funcs::make_offset_calculator<3>(
90+
desired_shape.size(), desired_shape.data(), strides_vector);
91+
92+
const int64_t N = numel;
93+
PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
94+
"Output numel be in the range [0, "
95+
"std::numeric_limits<int32_t>::max()]");
96+
97+
constexpr int nt = 128;
98+
constexpr int vt = 4;
99+
const dim3 block(nt);
100+
const dim3 grid((N + block.x * vt - 1) / (block.x * vt));
101+
auto stream = ctx.stream();
102+
103+
using dtype = funcs::OpaqueType<sizeof(T)>;
104+
105+
const char* in_ptr = reinterpret_cast<const char*>(value.data<T>());
106+
char* out_ptr = reinterpret_cast<char*>(output->data<T>());
107+
108+
funcs::index_elementwise_kernel<nt, vt>
109+
<<<grid, block, 0, stream>>>(N, [=] __device__(int idx) {
110+
const auto offsets = offset_calc.get(idx);
111+
char* const out_data = out_ptr + offsets[0];
112+
const char* const in_data = in_ptr + offsets[1];
113+
114+
int64_t offset = 0;
115+
#pragma unroll
116+
for (int i = 0; i < num_indices; i++) {
117+
int64_t index =
118+
*reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]);
119+
PADDLE_ENFORCE(-sizes[i] <= index && index < sizes[i],
120+
"index out of bounds");
121+
if (index < 0) {
122+
index += sizes[i];
123+
}
124+
offset += index * strides[i];
125+
}
126+
*reinterpret_cast<dtype*>(out_data + offset) =
127+
*reinterpret_cast<const dtype*>(in_data);
128+
});
129+
}
130+
131+
template <typename T, typename Context>
132+
void IndexElementwiseGetGradKernel(const Context& ctx,
133+
const DenseTensor& x,
134+
const std::vector<const DenseTensor*>& index,
135+
const DenseTensor& out_grad,
136+
const std::vector<int64_t>& input_dims,
137+
const std::vector<int64_t>& input_strides,
138+
const std::vector<int64_t>& index_dims,
139+
const std::vector<int64_t>& index_strides,
140+
DenseTensor* x_grad) {
141+
ctx.template Alloc<T>(x_grad);
142+
auto dxt = phi::EigenVector<T>::Flatten(*x_grad);
143+
auto& place = *ctx.eigen_device();
144+
dxt.device(place) = dxt.constant(static_cast<T>(0));
145+
if (out_grad.numel() == 0) return;
146+
147+
const auto& index_type = index[0]->dtype();
148+
PADDLE_ENFORCE_EQ(
149+
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64,
150+
true,
151+
common::errors::InvalidArgument(
152+
"Index holds the wrong type, it holds [%s], but "
153+
"desires to be [%s] or [%s].",
154+
index_type,
155+
phi::DataType::INT32,
156+
phi::DataType::INT64));
157+
158+
if (index_type == phi::DataType::INT32) {
159+
GPUIndexElementwisePutKernel<T, int>(ctx,
160+
x,
161+
out_grad,
162+
index,
163+
input_dims,
164+
input_strides,
165+
index_dims,
166+
index_strides,
167+
x_grad);
168+
} else if (index_type == phi::DataType::INT64) {
169+
GPUIndexElementwisePutKernel<T, int64_t>(ctx,
170+
x,
171+
out_grad,
172+
index,
173+
input_dims,
174+
input_strides,
175+
index_dims,
176+
index_strides,
177+
x_grad);
178+
}
179+
}
180+
181+
} // namespace phi
182+
// Register the GPU backward kernel for index_elementwise_get over all
// layouts and the dtypes listed below (matching the forward op's coverage).
PD_REGISTER_KERNEL(index_elementwise_get_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::IndexElementwiseGetGradKernel,
                   bool,
                   float,
                   double,
                   int,
                   int8_t,
                   int64_t,
                   int16_t,
                   uint8_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/phi/core/dense_tensor.h"
18+
#include "paddle/phi/core/tensor_array.h"
19+
20+
namespace phi {
21+
22+
// Computes x_grad for index_elementwise_get by zero-filling x_grad and
// scattering out_grad back through the index tensors.  The dims/strides
// vectors carry the pre-computed layouts of the input and index tensors.
template <typename T, typename Context>
void IndexElementwiseGetGradKernel(const Context& ctx,
                                   const DenseTensor& x,
                                   const std::vector<const DenseTensor*>& index,
                                   const DenseTensor& out_grad,
                                   const std::vector<int64_t>& input_dims,
                                   const std::vector<int64_t>& input_strides,
                                   const std::vector<int64_t>& index_dims,
                                   const std::vector<int64_t>& index_strides,
                                   DenseTensor* x_grad);
32+
33+
} // namespace phi

paddle/phi/ops/yaml/backward.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,6 +1672,18 @@
16721672
inplace : (out_grad -> x_grad)
16731673
backward : index_add_double_grad
16741674

1675+
- backward_op : index_elementwise_get_grad
1676+
forward : index_elementwise_get (Tensor x, Tensor[] index, int64_t[] input_dims, int64_t[] input_strides, int64_t[] index_dims, int64_t[] index_stride) -> Tensor(out)
1677+
args : (Tensor x, Tensor[] index, Tensor out_grad, int64_t[] input_dims, int64_t[] input_strides, int64_t[] index_dims, int64_t[] index_stride)
1678+
output : Tensor(x_grad)
1679+
infer_meta :
1680+
func : IndexElementwiseGetGradInferMeta
1681+
kernel :
1682+
func : index_elementwise_get_grad
1683+
data_type : out_grad
1684+
data_transform :
1685+
skip_transform : index
1686+
16751687
- backward_op : index_put_double_grad
16761688
forward : index_put_grad (Tensor x, Tensor[] indices, Tensor value, Tensor grad_out, bool accumulate=false) -> Tensor(grad_x), Tensor(grad_value)
16771689
args : (Tensor x, Tensor[] indices, Tensor value, Tensor grad_x_grad, Tensor grad_value_grad, bool accumulate=false)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2761,9 +2761,7 @@
27612761
kernel :
27622762
func : index_elementwise_get
27632763
data_type : x
2764-
# backward : index_elementwise_grad
2765-
# interfaces : paddle::dialect::InferSymbolicShapeInterface
2766-
traits : paddle::dialect::ForwardOnlyTrait
2764+
backward : index_elementwise_get_grad
27672765

27682766
- op : index_put
27692767
args : (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false)

0 commit comments

Comments
 (0)