@@ -133,16 +133,18 @@ struct OffsetCalculator {
133
133
134
134
template <typename T>
135
135
std::array<int64_t , DDim::kMaxRank > ComputeStrides (
136
- const std::vector< int64_t >& index_dims ) {
137
- const int rank = index_dims. size ();
138
- std::array< int64_t , DDim:: kMaxRank > strides{} ;
136
+ const phi::DenseTensor& input, const size_t index_dims_size ) {
137
+ const auto & input_strides = input. strides ();
138
+ const size_t element_size_bytes = sizeof (T) ;
139
139
140
- if (rank >= 1 ) {
141
- strides[rank - 1 ] = index_dims[rank - 1 ] * sizeof (T);
142
- }
140
+ std::array<int64_t , DDim::kMaxRank > strides{};
143
141
144
- for (int i = rank - 2 ; i >= 0 ; --i) {
145
- strides[i] = index_dims[i] * strides[i + 1 ];
142
+ for (int i = 0 ; i < index_dims_size; ++i) {
143
+ if (i < input_strides.size ()) {
144
+ strides[i] = input_strides[i] * element_size_bytes;
145
+ } else {
146
+ strides[i] = 0 ;
147
+ }
146
148
}
147
149
148
150
return strides;
@@ -235,7 +237,7 @@ void IndexElementwiseKernel(const phi::GPUContext& ctx,
235
237
DenseTensor* output) {
236
238
auto num_indices = index_dims.size ();
237
239
238
- auto index_stride = ComputeStrides<T>(index_dims );
240
+ auto index_stride = ComputeStrides<T>(input, num_indices );
239
241
auto index_ptrs = GetIndexDataPtrs<IndexT>(index);
240
242
241
243
auto sizes = std::array<int64_t , DDim::kMaxRank >{};
0 commit comments