
Commit 108db2c

Add stride mechanism and unit tests
1 parent e12543a commit 108db2c

10 files changed: +1046 −226 lines changed


paddle/fluid/pybind/slice_utils.h

Lines changed: 181 additions & 22 deletions
@@ -39,6 +39,152 @@ namespace py = pybind11;
 
 namespace paddle {
 namespace pybind {
+
+static inline common::DDim infer_size_symdimvector(common::DDim a,
+                                                   common::DDim b) {
+  // Use ptrdiff_t to ensure signed comparison.
+  auto dimsA = a.size();
+  auto dimsB = b.size();
+  auto ndim = dimsA > dimsB ? dimsA : dimsB;
+  common::DDim expandedSizes = common::make_ddim(std::vector<int64_t>(ndim, 0));
+
+  for (int64_t i = ndim - 1; i >= 0; --i) {
+    int64_t offset = ndim - 1 - i;
+    int64_t dimA = dimsA - 1 - offset;
+    int64_t dimB = dimsB - 1 - offset;
+    auto sizeA = (dimA >= 0) ? a[dimA] : 1;
+    auto sizeB = (dimB >= 0) ? b[dimB] : 1;
+
+    PADDLE_ENFORCE(sizeA == sizeB || sizeA == 1 || sizeB == 1,
+                   common::errors::Fatal("The size of tensor a (",
+                                         sizeA,
+                                         ") must match the size of tensor b (",
+                                         sizeB,
+                                         ") at non-singleton dimension ",
+                                         i));
+
+    // 1s map to the other size (even 0).
+    expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
+  }
+
+  return expandedSizes;
+}
+
+static inline std::vector<paddle::Tensor> expand_outplace(
+    std::vector<paddle::Tensor> to_expand) {
+  // expands a list of Tensors; ignores undefined (null) tensors
+  bool first = true;
+  common::DDim sizes;
+  for (size_t i = 0; i < to_expand.size(); i++) {
+    if (!to_expand[i].initialized()) {
+      continue;
+    } else if (first) {
+      sizes = to_expand[i].dims();
+      first = false;
+    } else {
+      sizes = infer_size_symdimvector(sizes, to_expand[i].dims());
+    }
+  }
+
+  std::vector<paddle::Tensor> result(to_expand.size());
+  for (size_t i = 0; i < to_expand.size(); i++) {
+    if (!to_expand[i].initialized()) {
+      continue;
+    } else if (to_expand[i].dims() == sizes) {
+      result[i] = to_expand[i];
+    } else {
+      result[i] =
+          expand_ad_func(to_expand[i], common::vectorize<int64_t>(sizes));
+    }
+  }
+  return result;
+}
+
+struct AdvancedIndex {
+  AdvancedIndex(paddle::Tensor src, std::vector<paddle::Tensor> indices);
+
+  paddle::Tensor src;
+  std::vector<paddle::Tensor> indices;
+  std::vector<int64_t> indexed_sizes;
+  std::vector<int64_t> indexed_strides;
+  std::vector<int64_t> src_sizes;
+  std::vector<int64_t> src_strides;
+  int64_t dims_before;
+  int64_t dims_after;
+};
+
+inline static void restride_src(std::vector<int64_t>* shape,
+                                std::vector<int64_t>* strides,
+                                int64_t dims_before,
+                                int64_t dims_indexed,
+                                std::vector<int64_t> replacement_shape) {
+  int64_t end = dims_before + dims_indexed;
+  shape->erase(shape->begin() + dims_before, shape->begin() + end);
+  strides->erase(strides->begin() + dims_before, strides->begin() + end);
+  shape->insert(shape->begin() + dims_before,
+                replacement_shape.begin(),
+                replacement_shape.end());
+  strides->insert(strides->begin() + dims_before, replacement_shape.size(), 0);
+}
+
+// move to cuda kernel
+inline static paddle::Tensor reshape_indexer(paddle::Tensor* index,
+                                             int64_t dims_before,
+                                             int64_t dims_after) {
+  auto orig_shape = common::vectorize<int64_t>(index->dims());
+  auto shape = std::vector<int64_t>{};
+  shape.insert(shape.end(), dims_before, 1);
+  shape.insert(shape.end(), orig_shape.begin(), orig_shape.end());
+  shape.insert(shape.end(), dims_after, 1);
+  *index = reshape_ad_func(*index, shape);
+  return *index;
+}
+
+inline AdvancedIndex::AdvancedIndex(paddle::Tensor src,
+                                    std::vector<paddle::Tensor> indices_list) {
+  uint32_t element_size_bytes = phi::SizeOf(src.dtype());
+  int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
+  std::vector<int64_t> shape_vec = common::vectorize<int64_t>(src.dims());
+  std::vector<int64_t> stride_vec = common::vectorize<int64_t>(src.strides());
+  std::vector<int64_t> replacement_shape;
+  std::vector<int64_t> idx_shape_vec = {};
+  std::vector<int64_t> idx_stride_vec = {};
+
+  for (size_t dim = 0; dim < indices_list.size(); dim++) {
+    if (!indices_list[dim].defined() || indices_list[dim].dims().size() == 0) {
+      if (dims_indexed == 0) {
+        dims_before++;
+      } else {
+        dims_after++;
+      }
+    } else {
+      dims_indexed++;
+      replacement_shape = common::vectorize<int64_t>(indices_list[dim].dims());
+      if (!replacement_shape.empty() && replacement_shape.back() == 1) {
+        replacement_shape.pop_back();
+      }
+
+      idx_shape_vec.push_back(shape_vec[dim]);
+      idx_stride_vec.push_back(stride_vec[dim] * element_size_bytes);
+    }
+  }
+
+  this->dims_before = dims_before;
+  this->dims_after = dims_after;
+  restride_src(
+      &shape_vec, &stride_vec, dims_before, dims_indexed, replacement_shape);
+  this->src_sizes = shape_vec;
+  this->src_strides = stride_vec;
+
+  this->indexed_sizes = idx_shape_vec;
+  this->indexed_strides = idx_stride_vec;
+
+  // use dims_before and dims_after / move to cuda kernel
+  for (auto& index : indices_list) {
+    if (index.defined() && index.dims().size() > 0) {
+      this->indices.push_back(reshape_indexer(&index, dims_before, dims_after));
+    }
+  }
+}
 
 template <typename T>
 inline T GetDenseTensorValue(const phi::DenseTensor* x) {
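
Note: infer_size_symdimvector applies NumPy-style broadcasting to two shapes: shapes are right-aligned and a size of 1 stretches to the other operand's size. A minimal standalone sketch of the same rule on plain std::vector<int64_t> (illustrative only, not code from this commit):

// Illustrative re-implementation of the broadcast rule used above.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> InferBroadcastShape(const std::vector<int64_t>& a,
                                         const std::vector<int64_t>& b) {
  const size_t ndim = std::max(a.size(), b.size());
  std::vector<int64_t> out(ndim, 0);
  for (size_t offset = 0; offset < ndim; ++offset) {
    // Walk both shapes from the trailing dimension; missing dims count as 1.
    const int64_t sa = offset < a.size() ? a[a.size() - 1 - offset] : 1;
    const int64_t sb = offset < b.size() ? b[b.size() - 1 - offset] : 1;
    assert(sa == sb || sa == 1 || sb == 1);  // same check as PADDLE_ENFORCE
    out[ndim - 1 - offset] = (sa == 1) ? sb : sa;
  }
  return out;
}

// Example: InferBroadcastShape({4, 1, 3}, {5, 3}) yields {4, 5, 3}.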
@@ -493,18 +639,33 @@ static paddle::Tensor dealWithAdvancedIndex(
   return transed_tensor;
 }
 
-inline std::vector<int64_t> ComputeIndexStrides(const paddle::Tensor& input,
-                                                const size_t index_dims_size) {
-  const auto& input_strides = input.strides();
-  size_t element_size_bytes = phi::SizeOf(input.dtype());
-  std::vector<int64_t> strides(index_dims_size, 0);
-  const size_t min_size =
-      std::min(static_cast<size_t>(input_strides.size()), index_dims_size);
-  for (size_t i = 0; i < min_size; ++i) {
-    strides[i] = input_strides[i] * element_size_bytes;
-  }
+static std::vector<paddle::Tensor> PrepareIndices(
+    const paddle::Tensor& tensor,
+    const paddle::Tensor& bool_2_idx,
+    const paddle::Tensor& bool_index) {
+  std::vector<paddle::Tensor> indices;
+  for (int j = 0; j < bool_2_idx.shape()[1]; ++j) {
+    paddle::Tensor sliced_tensor =
+        slice_ad_func(bool_2_idx, {1}, {j}, {j + 1}, {1}, {});
 
-  return strides;
+    // Calculate the required dimensionality
+    int64_t original_ndim =
+        tensor.shape().size() - bool_index.shape().size() + 1;
+    int64_t sliced_ndim = sliced_tensor.shape().size();
+    int64_t num_ones_to_add = original_ndim - sliced_ndim;
+
+    // Reshape the tensor by adding 1s if needed
+    if (num_ones_to_add > 0) {
+      std::vector<int64_t> new_shape = sliced_tensor.shape();
+      for (int64_t k = 0; k < num_ones_to_add; ++k) {
+        new_shape.push_back(1);
+      }
+      sliced_tensor = reshape_ad_func(sliced_tensor, new_shape);
+    }
+
+    indices.emplace_back(sliced_tensor);
+  }
+  return indices;
 }
 
 static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
@@ -547,17 +708,15 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
 
   auto bool_2_idx = nonzero_ad_func(bool_index);
 #ifdef PADDLE_WITH_CUDA
-  std::vector<paddle::Tensor> indices;
-  for (int j = 0; j < bool_2_idx.shape()[1]; ++j) {
-    paddle::Tensor sliced_tensor =
-        slice_ad_func(bool_2_idx, {1}, {j}, {j + 1}, {1}, {});
-    indices.emplace_back(sliced_tensor);
-  }
-  auto index_dims_vec = common::vectorize<int64_t>(bool_index.dims());
-  auto index_stride = ComputeIndexStrides(tensor, index_dims_vec.size());
-
-  return index_elementwise_ad_func(
-      tensor, indices, index_dims_vec, index_stride);
+  auto indices = PrepareIndices(tensor, bool_2_idx, bool_index);
+  AdvancedIndex ad = AdvancedIndex(tensor, indices);
+
+  return index_elementwise_get_ad_func(tensor,
+                                       ad.indices,
+                                       ad.src_sizes,
+                                       ad.src_strides,
+                                       ad.indexed_sizes,
+                                       ad.indexed_strides);
 #else
 
   return gather_nd_ad_func(tensor, bool_2_idx);
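
Note: the restriding that AdvancedIndex performs on the source tensor can be seen on plain vectors. A minimal sketch of the same erase/insert sequence as restride_src, on assumed example shapes and strides that are not taken from the commit:

// Illustrative only: source of shape [2, 3, 4] with element strides
// [12, 4, 1], and one index of shape [5] applied at dim 1
// (dims_before = 1, dims_indexed = 1).
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> shape = {2, 3, 4};
  std::vector<int64_t> strides = {12, 4, 1};
  const int64_t dims_before = 1, dims_indexed = 1;
  const std::vector<int64_t> replacement_shape = {5};  // index shape

  // The indexed dim is cut out and replaced by the index's shape,
  // with stride 0 for the newly inserted dims.
  shape.erase(shape.begin() + dims_before,
              shape.begin() + dims_before + dims_indexed);
  strides.erase(strides.begin() + dims_before,
                strides.begin() + dims_before + dims_indexed);
  shape.insert(shape.begin() + dims_before,
               replacement_shape.begin(), replacement_shape.end());
  strides.insert(strides.begin() + dims_before, replacement_shape.size(), 0);

  // shape   is now {2, 5, 4}   -> src_sizes
  // strides is now {12, 0, 1}  -> src_strides
  // indexed_sizes would be {3}, indexed_strides {4 * element size in bytes}.
  return 0;
}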

paddle/phi/infermeta/binary.cc

Lines changed: 7 additions & 5 deletions
@@ -2146,11 +2146,13 @@ void GatherNdInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
-void IndexElementwiseInferMeta(const MetaTensor& x,
-                               const std::vector<const MetaTensor*>& index,
-                               const std::vector<int64_t>& index_dims,
-                               const std::vector<int64_t>& index_stride,
-                               MetaTensor* out) {
+void IndexElementwiseGetInferMeta(const MetaTensor& x,
+                                  const std::vector<const MetaTensor*>& index,
+                                  const std::vector<int64_t>& input_dims,
+                                  const std::vector<int64_t>& input_strides,
+                                  const std::vector<int64_t>& index_dims,
+                                  const std::vector<int64_t>& index_stride,
+                                  MetaTensor* out) {
   const auto& x_dims = x.dims();
 
   PADDLE_ENFORCE_LE(

paddle/phi/infermeta/binary.h

Lines changed: 7 additions & 5 deletions
@@ -401,11 +401,13 @@ void GatherNdInferMeta(const MetaTensor& x,
                        const MetaTensor& index,
                        MetaTensor* out);
 
-void IndexElementwiseInferMeta(const MetaTensor& x,
-                               const std::vector<const MetaTensor*>& index,
-                               const std::vector<int64_t>& index_dims,
-                               const std::vector<int64_t>& index_stride,
-                               MetaTensor* out);
+void IndexElementwiseGetInferMeta(const MetaTensor& x,
+                                  const std::vector<const MetaTensor*>& index,
+                                  const std::vector<int64_t>& input_dims,
+                                  const std::vector<int64_t>& input_strides,
+                                  const std::vector<int64_t>& index_dims,
+                                  const std::vector<int64_t>& index_stride,
+                                  MetaTensor* out);
 
 void GatherTreeMeta(const MetaTensor& ids,
                     const MetaTensor& parents,

paddle/phi/kernels/funcs/index_elementwise.cu.h

Lines changed: 4 additions & 107 deletions
@@ -157,119 +157,16 @@ std::array<char*, DDim::kMaxRank> GetIndexDataPtrs(
 
 template <int N, bool signed_strides = false>
 static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
-    const DenseTensor& output,
-    const DenseTensor& input,
-    const std::vector<const DenseTensor*> index) {
-  int ndim = output.dims().size();
-  const int64_t* shape = output.dims().Get();
-  std::vector<int64_t> shape_vec(shape, shape + ndim);
-  std::reverse(shape_vec.begin(), shape_vec.end());
-  const int64_t* desired_shape = shape_vec.data();
-
-  std::vector<std::vector<int64_t>> strides;
-  std::vector<const DenseTensor*> tensors = {&output, &input};
-
-  for (const auto& idx_tensor : index) {
-    tensors.push_back(idx_tensor);
-  }
-
-  for (const auto& tensor : tensors) {
-    std::vector<int64_t> stride_bytes(ndim, 0);
-    const auto& original_shape = tensor->dims();
-    const auto& original_strides = tensor->strides();
-    int64_t element_size_in_bytes = phi::SizeOf(tensor->dtype());
-    int offset = ndim - original_shape.size();
-
-    if (tensor == &input) {
-      stride_bytes[ndim - 1] = element_size_in_bytes;
-    } else {
-      if (offset > 0) {
-        stride_bytes.resize(ndim, 0);
-      } else {
-        stride_bytes.resize(ndim);
-      }
-
-      for (int i = 0; i < original_shape.size(); ++i) {
-        if (original_shape[i] == 1 && shape[offset + i] != 1) {
-          stride_bytes[offset + i] = 0;
-        } else {
-          stride_bytes[offset + i] =
-              original_strides[i] * element_size_in_bytes;
-        }
-      }
-    }
-    std::reverse(stride_bytes.begin(), stride_bytes.end());
-    strides.push_back(stride_bytes);
-  }
-
+    int ndim,
+    const int64_t* shape,
+    const std::vector<std::vector<int64_t>>& strides) {
   std::array<const int64_t*, N> strides_array;
   for (int i = 0; i < N; ++i) {
    strides_array[i] = strides[i].data();
  }
 
   return OffsetCalculator<N, uint32_t, signed_strides>(
-      ndim, desired_shape, strides_array.data());
-}
-
-template <typename T, typename IndexT = int>
-void IndexElementwiseKernel(const phi::GPUContext& ctx,
-                            const DenseTensor& input,
-                            const std::vector<const DenseTensor*> index,
-                            const std::vector<int64_t>& index_dims,
-                            const std::vector<int64_t>& index_stride,
-                            DenseTensor* output) {
-  auto num_indices = index_dims.size();
-
-  auto index_ptrs = GetIndexDataPtrs<IndexT>(index);
-
-  auto sizes = std::array<int64_t, DDim::kMaxRank>{};
-  auto strides = std::array<int64_t, DDim::kMaxRank>{};
-
-  for (unsigned i = 0; i < num_indices; i++) {
-    sizes[i] = index_dims[i];
-    strides[i] = index_stride[i];
-  }
-
-  auto offset_calc = make_offset_calculator<3>(*output, input, index);
-
-  const int64_t N = output->numel();
-  PADDLE_ENFORCE(N >= 0 && N <= std::numeric_limits<int32_t>::max(),
-
-                 "Output numel be in the range [0, "
-                 "std::numeric_limits<int32_t>::max()]");
-  constexpr int nt = launch_size_nd;
-  constexpr int vt = launch_bound2;
-  const dim3 block(nt);
-  const dim3 grid((N + block.x * vt - 1) / (block.x * vt));
-  auto stream = ctx.stream();
-
-  using dtype = OpaqueType<sizeof(T)>;
-
-  const char* in_ptr = reinterpret_cast<const char*>(input.data<T>());
-  char* out_ptr = reinterpret_cast<char*>(output->data<T>());
-
-  index_elementwise_kernel<nt, vt>
-      <<<grid, block, 0, stream>>>(N, [=] __device__(int idx) {
-        const auto offsets = offset_calc.get(idx);
-        char* const out_data = out_ptr + offsets[0];
-        const char* const in_data = in_ptr + offsets[1];
-
-        int64_t offset = 0;
-#pragma unroll
-        for (int i = 0; i < num_indices; i++) {
-          int64_t index =
-              *reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]);
-          PADDLE_ENFORCE(-sizes[i] <= index && index < sizes[i],
-                         "index out of bounds");
-          if (index < 0) {
-            index += sizes[i];
-          }
-          offset += index * strides[i];
-        }
-
-        *reinterpret_cast<dtype*>(out_data) =
-            *reinterpret_cast<const dtype*>(in_data + offset);
-      });
+      ndim, shape, strides_array.data());
 }
 
 }  // namespace funcs
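
Note: make_offset_calculator no longer inspects the output, input, and index tensors itself; the caller now supplies the rank, the (trailing-dim-first) output shape, and per-operand byte strides. A rough sketch of how one such stride vector could be assembled, mirroring the deleted logic for broadcastable operands (the function name and placement are assumptions; the new call sites are not part of this excerpt):

// Illustrative only: builds one reversed, byte-unit stride vector of the
// output's rank for an operand, padding broadcast dims with stride 0.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> StrideBytesForOperand(
    const std::vector<int64_t>& out_shape,        // output shape, ndim entries
    const std::vector<int64_t>& operand_shape,    // operand shape, rank <= ndim
    const std::vector<int64_t>& operand_strides,  // element-unit strides
    int64_t element_size_bytes) {
  const int ndim = static_cast<int>(out_shape.size());
  const int offset = ndim - static_cast<int>(operand_shape.size());
  std::vector<int64_t> stride_bytes(ndim, 0);
  for (int i = 0; i < static_cast<int>(operand_shape.size()); ++i) {
    // Dimensions broadcast against the output keep stride 0.
    if (!(operand_shape[i] == 1 && out_shape[offset + i] != 1)) {
      stride_bytes[offset + i] = operand_strides[i] * element_size_bytes;
    }
  }
  // The offset calculator expects trailing-dimension-first ordering.
  std::reverse(stride_bytes.begin(), stride_bytes.end());
  return stride_bytes;
}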
