Skip to content

Commit ade7108

Browse files
committed
update ComputeStrides
1 parent 557fff7 commit ade7108

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

paddle/phi/kernels/funcs/index_elementwise.cu.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,18 @@ struct OffsetCalculator {
133133

134134
template <typename T>
135135
std::array<int64_t, DDim::kMaxRank> ComputeStrides(
136-
const std::vector<int64_t>& index_dims) {
137-
const int rank = index_dims.size();
138-
std::array<int64_t, DDim::kMaxRank> strides{};
136+
const phi::DenseTensor& input, const size_t index_dims_size) {
137+
const auto& input_strides = input.strides();
138+
const size_t element_size_bytes = sizeof(T);
139139

140-
if (rank >= 1) {
141-
strides[rank - 1] = index_dims[rank - 1] * sizeof(T);
142-
}
140+
std::array<int64_t, DDim::kMaxRank> strides{};
143141

144-
for (int i = rank - 2; i >= 0; --i) {
145-
strides[i] = index_dims[i] * strides[i + 1];
142+
for (int i = 0; i < index_dims_size; ++i) {
143+
if (i < input_strides.size()) {
144+
strides[i] = input_strides[i] * element_size_bytes;
145+
} else {
146+
strides[i] = 0;
147+
}
146148
}
147149

148150
return strides;
@@ -235,7 +237,7 @@ void IndexElementwiseKernel(const phi::GPUContext& ctx,
235237
DenseTensor* output) {
236238
auto num_indices = index_dims.size();
237239

238-
auto index_stride = ComputeStrides<T>(index_dims);
240+
auto index_stride = ComputeStrides<T>(input, num_indices);
239241
auto index_ptrs = GetIndexDataPtrs<IndexT>(index);
240242

241243
auto sizes = std::array<int64_t, DDim::kMaxRank>{};

0 commit comments

Comments
 (0)