
Commit e12543a

update
1 parent ade7108 commit e12543a

7 files changed: +40 −36 lines changed

paddle/fluid/pybind/slice_utils.h

Lines changed: 17 additions & 1 deletion
@@ -493,6 +493,20 @@ static paddle::Tensor dealWithAdvancedIndex(
   return transed_tensor;
 }
 
+inline std::vector<int64_t> ComputeIndexStrides(const paddle::Tensor& input,
+                                                const size_t index_dims_size) {
+  const auto& input_strides = input.strides();
+  size_t element_size_bytes = phi::SizeOf(input.dtype());
+  std::vector<int64_t> strides(index_dims_size, 0);
+  const size_t min_size =
+      std::min(static_cast<size_t>(input_strides.size()), index_dims_size);
+  for (size_t i = 0; i < min_size; ++i) {
+    strides[i] = input_strides[i] * element_size_bytes;
+  }
+
+  return strides;
+}
+
 static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
                                             const paddle::Tensor& bool_index) {
   PADDLE_ENFORCE(bool_index.shape().size() <= tensor.shape().size(),

@@ -540,8 +554,10 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
     indices.emplace_back(sliced_tensor);
   }
   auto index_dims_vec = common::vectorize<int64_t>(bool_index.dims());
+  auto index_stride = ComputeIndexStrides(tensor, index_dims_vec.size());
 
-  return index_elementwise_ad_func(tensor, indices, index_dims_vec);
+  return index_elementwise_ad_func(
+      tensor, indices, index_dims_vec, index_stride);
 #else
 
   return gather_nd_ad_func(tensor, bool_2_idx);
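
Note: the new ComputeIndexStrides helper converts the input tensor's element strides into byte strides (element stride multiplied by the element size from phi::SizeOf), padding with zeros for any index dimensions beyond the input's rank. Below is a minimal standalone sketch of that computation using plain std::vector instead of paddle::Tensor; the helper name, the example shape, and the dtype size are illustrative assumptions, not part of this commit.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the logic of ComputeIndexStrides: convert element strides into
// byte strides, padding with 0 for dimensions beyond the input's rank.
// Hypothetical standalone helper for illustration only.
std::vector<int64_t> ComputeIndexStridesSketch(
    const std::vector<int64_t>& input_strides,  // element strides of the input
    size_t element_size_bytes,                  // e.g. 4 for float32 (assumed)
    size_t index_dims_size) {
  std::vector<int64_t> strides(index_dims_size, 0);
  const size_t min_size = std::min(input_strides.size(), index_dims_size);
  for (size_t i = 0; i < min_size; ++i) {
    strides[i] = input_strides[i] * static_cast<int64_t>(element_size_bytes);
  }
  return strides;
}

int main() {
  // A contiguous 2x3 float32 tensor has element strides {3, 1};
  // the resulting byte strides are {12, 4}.
  for (int64_t s : ComputeIndexStridesSketch({3, 1}, 4, 2)) {
    std::cout << s << " ";  // prints: 12 4
  }
  std::cout << "\n";
  return 0;
}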

paddle/phi/infermeta/binary.cc

Lines changed: 1 addition & 1 deletion
@@ -2149,9 +2149,9 @@ void GatherNdInferMeta(const MetaTensor& x,
 void IndexElementwiseInferMeta(const MetaTensor& x,
                                const std::vector<const MetaTensor*>& index,
                                const std::vector<int64_t>& index_dims,
+                               const std::vector<int64_t>& index_stride,
                                MetaTensor* out) {
   const auto& x_dims = x.dims();
-  // auto index_dims = index.dims();
 
   PADDLE_ENFORCE_LE(
       index_dims.size(),

paddle/phi/infermeta/binary.h

Lines changed: 1 addition & 0 deletions
@@ -404,6 +404,7 @@ void GatherNdInferMeta(const MetaTensor& x,
 void IndexElementwiseInferMeta(const MetaTensor& x,
                                const std::vector<const MetaTensor*>& index,
                                const std::vector<int64_t>& index_dims,
+                               const std::vector<int64_t>& index_stride,
                                MetaTensor* out);
 
 void GatherTreeMeta(const MetaTensor& ids,

paddle/phi/kernels/funcs/index_elementwise.cu.h

Lines changed: 13 additions & 29 deletions
@@ -1,4 +1,4 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.

@@ -38,7 +38,6 @@ constexpr int MAX_DIMS = 25;
 #endif
 
 static constexpr int launch_bound2 = 4;
-
 static constexpr int launch_size_nd = 128;
 
 template <int nt, int vt, typename func_t>

@@ -92,8 +91,9 @@ struct OffsetCalculator {
                    const int64_t* const* strides,
                    const int64_t* element_sizes = nullptr)
       : dims(dims) {
-    PADDLE_ENFORCE(
-        dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
+    PADDLE_ENFORCE(dims <= MAX_DIMS,
+                   "The number of dimensions (%d) exceeds MAX_DIMS.",
+                   dims);
     for (int i = 0; i < dims; i++) {
       sizes_[i] = IntDivider<index_t>(sizes[i]);
       for (int arg = 0; arg < NARGS; arg++) {

@@ -131,41 +131,23 @@ struct OffsetCalculator {
   stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
 };
 
-template <typename T>
-std::array<int64_t, DDim::kMaxRank> ComputeStrides(
-    const phi::DenseTensor& input, const size_t index_dims_size) {
-  const auto& input_strides = input.strides();
-  const size_t element_size_bytes = sizeof(T);
-
-  std::array<int64_t, DDim::kMaxRank> strides{};
-
-  for (int i = 0; i < index_dims_size; ++i) {
-    if (i < input_strides.size()) {
-      strides[i] = input_strides[i] * element_size_bytes;
-    } else {
-      strides[i] = 0;
-    }
-  }
-
-  return strides;
-}
-
 template <typename IndexT>
 std::array<char*, DDim::kMaxRank> GetIndexDataPtrs(
     const std::vector<const DenseTensor*> index) {
   std::array<char*, DDim::kMaxRank> index_ptrs{};
 
   PADDLE_ENFORCE_LE(index.size(),
                     DDim::kMaxRank,
-                    "The number of index tensors exceeds the maximum rank.");
+                    "The rank of the index tensor must be less than or "
+                    "equal to DDim::kMaxRank.");
 
   for (size_t i = 0; i < index.size(); ++i) {
     const IndexT* p_index = index[i]->data<IndexT>();
 
     PADDLE_ENFORCE(p_index != nullptr,
-                   "The pointer p_index is nullptr, "
-                   "please check whether the index tensor is valid and "
-                   "its data is correctly initialized.");
+                   "The pointer p_index must not be nullptr. "
+                   "Please ensure the index tensor is valid and its data "
+                   "is correctly initialized.");
 
     index_ptrs[i] = reinterpret_cast<char*>(const_cast<IndexT*>(p_index));
   }

@@ -234,10 +216,10 @@ void IndexElementwiseKernel(const phi::GPUContext& ctx,
                             const DenseTensor& input,
                             const std::vector<const DenseTensor*> index,
                             const std::vector<int64_t>& index_dims,
+                            const std::vector<int64_t>& index_stride,
                             DenseTensor* output) {
   auto num_indices = index_dims.size();
 
-  auto index_stride = ComputeStrides<T>(input, num_indices);
   auto index_ptrs = GetIndexDataPtrs<IndexT>(index);
 
   auto sizes = std::array<int64_t, DDim::kMaxRank>{};
@@ -252,7 +234,9 @@ void IndexElementwiseKernel(const phi::GPUContext& ctx,
 
   const int64_t N = output->numel();
   PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
-                 "N >= 0 && N <= std::numeric_limits<int32_t>::max()");
+
+                 "Output numel must be in the range [0, "
+                 "std::numeric_limits<int32_t>::max()].");
   constexpr int nt = launch_size_nd;
   constexpr int vt = launch_bound2;
   const dim3 block(nt);
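
For context on why index_stride is expressed in bytes: GetIndexDataPtrs hands the kernel raw char* pointers, and the per-dimension byte strides are what turn an index value into a byte offset into the input buffer. Below is a rough CPU-side sketch of that addressing scheme; it is a conceptual illustration only, not the actual CUDA kernel, and every name in it is made up.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Conceptual CPU reference: gather elements from `input` (viewed as raw
// bytes) using one int64 index array per indexed dimension and the
// per-dimension byte strides. Not the GPU implementation in this commit.
template <typename T>
std::vector<T> GatherByByteStride(
    const std::vector<T>& input,
    const std::vector<std::vector<int64_t>>& indices,  // one array per dim
    const std::vector<int64_t>& byte_strides) {         // one stride per dim
  const char* base = reinterpret_cast<const char*>(input.data());
  const size_t n = indices.empty() ? 0 : indices[0].size();
  std::vector<T> out(n);
  for (size_t i = 0; i < n; ++i) {
    int64_t offset = 0;
    for (size_t d = 0; d < indices.size(); ++d) {
      offset += indices[d][i] * byte_strides[d];  // byte offset per dimension
    }
    std::memcpy(&out[i], base + offset, sizeof(T));
  }
  return out;
}

int main() {
  // 2x3 row-major float tensor; byte strides {12, 4} as in the earlier sketch.
  std::vector<float> x = {0, 1, 2, 3, 4, 5};
  auto y = GatherByByteStride<float>(x, {{0, 1}, {2, 0}}, {12, 4});
  std::cout << y[0] << " " << y[1] << "\n";  // prints: 2 3 (x[0][2], x[1][0])
  return 0;
}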

paddle/phi/kernels/gpu/index_elementwise_kernel.cu

Lines changed: 5 additions & 3 deletions
@@ -1,4 +1,4 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.

@@ -25,6 +25,7 @@ void IndexElementwiseKernel(const Context& ctx,
                             const DenseTensor& x,
                             const std::vector<const DenseTensor*>& index,
                             const std::vector<int64_t>& index_dims,
+                            const std::vector<int64_t>& index_stride,
                             DenseTensor* out) {
   const auto& index_type = index[0]->dtype();
   PADDLE_ENFORCE_EQ(

@@ -48,10 +49,11 @@ void IndexElementwiseKernel(const Context& ctx,
   ctx.template Alloc<T>(out);
 
   if (index_type == phi::DataType::INT32) {
-    phi::funcs::IndexElementwiseKernel<T, int>(ctx, x, index, index_dims, out);
+    phi::funcs::IndexElementwiseKernel<T, int>(
+        ctx, x, index, index_dims, index_stride, out);
   } else if (index_type == phi::DataType::INT64) {
     phi::funcs::IndexElementwiseKernel<T, int64_t>(
-        ctx, x, index, index_dims, out);
+        ctx, x, index, index_dims, index_stride, out);
   }
 }

paddle/phi/kernels/index_elementwise_kernel.h

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,4 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.

@@ -24,6 +24,7 @@ void IndexElementwiseKernel(const Context &ctx,
                             const DenseTensor &x,
                             const std::vector<const DenseTensor *> &index,
                             const std::vector<int64_t> &index_dims,
+                            const std::vector<int64_t> &index_stride,
                             DenseTensor *out);
 
 }  // namespace phi

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
@@ -2754,7 +2754,7 @@
   interfaces : paddle::dialect::InferSymbolicShapeInterface, paddle::dialect::LayoutTransformationInterface
 
 - op : index_elementwise
-  args : (Tensor x, Tensor[] index, int64_t[] index_dims)
+  args : (Tensor x, Tensor[] index, int64_t[] index_dims, int64_t[] index_stride)
   output : Tensor (out)
   infer_meta :
     func : IndexElementwiseInferMeta
