Skip to content

Commit ed3f636

Browse files
authored
[CINN] Refine infer_symbol_shape of data, slice op (#72418)
* Refine infer_symbol_shape of the data and slice ops
* Add infer_sym_slice_utils.cc with a SimpleSlice helper
* Delete the rank<=1 check in element-wise binary data compute
* Fix bugs found in review
1 parent 8750974 commit ed3f636

File tree

5 files changed

+148
-37
lines changed

5 files changed

+148
-37
lines changed

paddle/cinn/common/dim_expr_converter.cc

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,43 @@ struct DimExprConverterWithSymbolBindings::
143143
if (std::holds_alternative<ShapeSymbolBinding>(symbol_binding)) {
144144
return inputs_[input_idx]->sym_shape[input_dim_idx]->GetDimExpr();
145145
}
146+
147+
auto LinearToMultiDim = [](int64_t index,
148+
const std::vector<int64_t>& dimensions) {
149+
std::vector<int64_t> result(dimensions.size(), 0);
150+
std::vector<int64_t> strides(dimensions.size(), 1);
151+
for (int64_t i = dimensions.size() - 2; i >= 0; --i) {
152+
strides[i] = strides[i + 1] * dimensions[i + 1];
153+
}
154+
int64_t cur_index = index;
155+
for (int64_t i = 0; i < dimensions.size(); ++i) {
156+
result[i] = cur_index / strides[i];
157+
cur_index %= strides[i];
158+
}
159+
return result;
160+
};
146161
// for data binding [S0, a, b], inputs[a] is Tensor A, return A(b)
147-
return ir::Cast::Make(cinn::common::I64(),
148-
inputs_[input_idx](cinn::ir::Expr(input_dim_idx)));
162+
PADDLE_ENFORCE_LE(inputs_[input_idx].ndims(),
163+
9,
164+
::common::errors::InvalidArgument(
165+
"The rank of the input tensor must be less than or "
166+
"equal to 9, but got %d",
167+
inputs_[input_idx].ndims()));
168+
const std::vector<ir::Expr> indices = [&]() -> std::vector<ir::Expr> {
169+
const auto& dimensions = inputs_[input_idx]->shape;
170+
std::vector<ir::Expr> result(dimensions.size(), 0);
171+
std::vector<int64_t> strides(dimensions.size(), 1);
172+
for (int64_t i = dimensions.size() - 2; i >= 0; --i) {
173+
strides[i] = strides[i + 1] * dimensions[i + 1].as_int64();
174+
}
175+
int64_t cur_index = input_dim_idx;
176+
for (int64_t i = 0; i < dimensions.size(); ++i) {
177+
result[i] = ir::Expr(cur_index / strides[i]);
178+
cur_index %= strides[i];
179+
}
180+
return result;
181+
}();
182+
return ir::Cast::Make(cinn::common::I64(), inputs_[input_idx](indices));
149183
}
150184

151185
DimExprToIrExprVisitorWithSymbolBinding(

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.cc

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,6 @@ bool InferSymbolicShapeElementWiseBinary(
6060

6161
if (x_shape.data() && y_shape.data() &&
6262
x_shape.data()->size() == y_shape.data()->size() && DataComputeFunc) {
63-
PADDLE_ENFORCE_LE(
64-
x_shape.shape().size(),
65-
1,
66-
common::errors::InvalidArgument("When compute data, the rank of x "
67-
"should be 0 or 1, but now received %d",
68-
x_shape.shape().size()));
69-
PADDLE_ENFORCE_LE(
70-
y_shape.shape().size(),
71-
1,
72-
common::errors::InvalidArgument("When compute data, the rank of y "
73-
"should be 0 or 1, but now received %d",
74-
y_shape.shape().size()));
7563
std::vector<symbol::DimExpr> out_data;
7664
for (size_t i = 0; i < x_shape.data()->size(); ++i) {
7765
out_data.emplace_back(
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h"
16+
17+
namespace paddle::dialect::slice_utils {
18+
19+
void SliceDfsImpl(const ExprVec &datas,
20+
const std::vector<int64_t> &shape,
21+
int64_t axis,
22+
int64_t start,
23+
int64_t end,
24+
int64_t cur_visit_axis,
25+
int offset,
26+
ExprVec *result) {
27+
int64_t begin = 0;
28+
int64_t stop = shape.at(cur_visit_axis);
29+
if (cur_visit_axis == axis) {
30+
begin = start;
31+
stop = end;
32+
}
33+
const int64_t cur_stride = std::accumulate(shape.begin() + cur_visit_axis + 1,
34+
shape.end(),
35+
1,
36+
std::multiplies<int64_t>());
37+
for (int64_t i = begin; i < stop; ++i) {
38+
const int64_t cur_offset = offset + i * cur_stride;
39+
// last dim
40+
if (cur_visit_axis == static_cast<int64_t>(shape.size() - 1)) {
41+
result->push_back(datas[cur_offset]);
42+
} else {
43+
SliceDfsImpl(datas,
44+
shape,
45+
axis,
46+
start,
47+
end,
48+
cur_visit_axis + 1,
49+
cur_offset,
50+
result);
51+
}
52+
}
53+
}
54+
55+
// Slice `datas` (a flat row-major buffer with extents `shape`) along a
// single `axis`, keeping indices in [start, end); every other axis is kept
// whole. Mirrors paddle.slice restricted to one axis. Returns the sliced
// elements, again flattened in row-major order.
ExprVec SimpleSlice(const ExprVec &datas,
                    const std::vector<int64_t> &shape,
                    int64_t axis,
                    int64_t start,
                    int64_t end) {
  ExprVec sliced;
  // The DFS starts at the outermost axis with a zero flat offset and
  // appends the selected elements into `sliced`.
  SliceDfsImpl(datas, shape, axis, start, end, /*cur_visit_axis=*/0,
               /*offset=*/0, &sliced);
  return sliced;
}
64+
65+
} // namespace paddle::dialect::slice_utils

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,23 @@ inline std::vector<int64_t> FormatSliceAxes(
195195
return axes_vec;
196196
}
197197

198+
/**
199+
* @brief Simple slice function like paddle.slice for a given the data vector.
200+
*
201+
* @param datas Input dataset of type ExprVec.
202+
* @param shape The shape of datas
203+
* @param axis Axis along which to perform the slicing action.
204+
* @param start Starting index for the slice.
205+
* @param end Ending index for the slice.
206+
*
207+
* @return Returns the result after slicing the input data.
208+
*/
209+
ExprVec SimpleSlice(const ExprVec &datas,
210+
const std::vector<int64_t> &shape,
211+
int64_t axis,
212+
int64_t start,
213+
int64_t end);
214+
198215
inline ShapeOrData SliceRawInferSymbolicShape(
199216
const pir::Value x,
200217
const pir::Value out,
@@ -205,20 +222,19 @@ inline ShapeOrData SliceRawInferSymbolicShape(
205222
const std::vector<int64_t> &decrease_axis,
206223
pir::InferSymbolicShapeContext *infer_context) {
207224
const auto &in_shapeordata = infer_context->GetShapeOrDataForValue(x);
225+
const ExprVec &in_dims = in_shapeordata.shape();
208226
ExprVec starts = starts_expr;
209227
ExprVec ends = ends_expr;
210228
std::vector<int64_t> infer_flags = [&infer_flags_raw, &axes_raw] {
211229
return infer_flags_raw.empty() ? std::vector<int64_t>(axes_raw.size(), 1)
212230
: infer_flags_raw;
213231
}();
232+
const std::vector<int64_t> axes = FormatSliceAxes(axes_raw, in_dims.size());
233+
const ExprVec slice_dims =
234+
GetSliceDims(in_dims, axes, starts, ends, &infer_flags);
235+
const ExprVec out_dims = GetDecreasedDims(slice_dims, decrease_axis);
214236

215237
const auto &GetShapeDimExprs = [&]() -> symbol::ShapeOrDataDimExprs {
216-
const ExprVec &in_dims = in_shapeordata.shape();
217-
std::vector<int64_t> axes = FormatSliceAxes(axes_raw, in_dims.size());
218-
ExprVec slice_dims =
219-
GetSliceDims(in_dims, axes, starts, ends, &infer_flags);
220-
ExprVec out_dims = GetDecreasedDims(slice_dims, decrease_axis);
221-
222238
auto IsOne = [](const symbol::DimExpr &expr) {
223239
return expr.isa<int64_t>() && expr.dyn_cast<int64_t>() == 1;
224240
};
@@ -240,8 +256,6 @@ inline ShapeOrData SliceRawInferSymbolicShape(
240256
// When `pd.slice` is operating on a tensor which is produced by a `pd.shape`
241257
// op, the result should be written into data.
242258
const auto &GetDataDimExprs = [&]() -> symbol::ShapeOrDataDimExprs {
243-
std::vector<symbol::DimExpr> out_data;
244-
245259
// Currently, we DO NOT support the case that any element in `axes` `starts`
246260
// or `ends` is a Symbol.
247261
auto vec_int64 = details::VecExpr2Int64(starts);
@@ -263,15 +277,13 @@ inline ShapeOrData SliceRawInferSymbolicShape(
263277
}
264278
return ends_int[0];
265279
}();
280+
const std::vector<int64_t> in_shape =
281+
details::VecExpr2Int64(in_dims).value();
282+
std::vector<symbol::DimExpr> out_data = SimpleSlice(
283+
in_shapeordata.data().value(), in_shape, axes.at(0), start, end);
266284

267-
for (int64_t i = start; i < end; i++) {
268-
out_data.push_back(in_shapeordata.data().value().at(i));
269-
}
270-
271-
const ExprVec shape = GetDecreasedDims(
272-
ExprVec{static_cast<int64_t>(out_data.size())}, decrease_axis);
273285
return symbol::ShapeOrDataDimExprs{
274-
symbol::TensorShapeOrDataDimExprs(shape, out_data)};
286+
symbol::TensorShapeOrDataDimExprs(out_dims, out_data)};
275287
};
276288
bool starts_ends_all_int =
277289
std::all_of(starts_expr.begin(),
@@ -281,10 +293,10 @@ inline ShapeOrData SliceRawInferSymbolicShape(
281293
ends_expr.end(),
282294
[](const symbol::DimExpr &e) { return e.isa<int64_t>(); });
283295

284-
const auto &out_shape =
285-
in_shapeordata.data().has_value() && starts_ends_all_int
286-
? GetDataDimExprs()
287-
: GetShapeDimExprs();
296+
const auto &out_shape = in_shapeordata.data().has_value() &&
297+
starts_ends_all_int && axes_raw.size() == 1
298+
? GetDataDimExprs()
299+
: GetShapeDimExprs();
288300
if (out_shape.data().has_value() && out_shape.shape().empty()) { // 0D tensor
289301
const paddle::dialect::DenseTensorType &tensor_type =
290302
out.type().dyn_cast<paddle::dialect::DenseTensorType>();

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,23 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
193193
value.type().dyn_cast<pir::DenseTensorType>();
194194
const auto &dims = tensor_type.dims();
195195
if (dims.size() == 0) return true;
196-
if (dims.size() != 1) return false;
197-
if (dims[0] >= 1 && dims[0] <= ::common::DDim::kMaxRank) {
198-
return true;
196+
if (dims.size() == 1) {
197+
if (dims[0] >= 1 && dims[0] <= ::common::DDim::kMaxRank) {
198+
return true;
199+
}
200+
return false;
199201
}
200-
return false;
202+
if (common::contain_unknown_dim(dims)) return false;
203+
if (common::product(dims) > ::common::DDim::kMaxRank) return false;
204+
205+
// only one dim is greater than one, and the other dims are 1
206+
int gt_one_dim_count = 0;
207+
for (int i = 0; i < dims.size(); ++i) {
208+
if (dims[i] > 1) {
209+
gt_one_dim_count++;
210+
}
211+
}
212+
return gt_one_dim_count <= 1;
201213
};
202214

203215
auto IsIntType = [&](pir::Value value) {

0 commit comments

Comments
 (0)