[CINN] Refine infer_symbol_shape of data, slice op #72418

Merged
merged 5 commits on Apr 29, 2025
38 changes: 36 additions & 2 deletions paddle/cinn/common/dim_expr_converter.cc
@@ -143,9 +143,43 @@ struct DimExprConverterWithSymbolBindings::
if (std::holds_alternative<ShapeSymbolBinding>(symbol_binding)) {
return inputs_[input_idx]->sym_shape[input_dim_idx]->GetDimExpr();
}

auto LinearToMultiDim = [](int64_t index,
const std::vector<int64_t>& dimensions) {
std::vector<int64_t> result(dimensions.size(), 0);
std::vector<int64_t> strides(dimensions.size(), 1);
for (int64_t i = dimensions.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * dimensions[i + 1];
}
int64_t cur_index = index;
for (int64_t i = 0; i < dimensions.size(); ++i) {
result[i] = cur_index / strides[i];
cur_index %= strides[i];
}
return result;
};
// for data binding [S0, a, b], inputs[a] is Tensor A, return A(b)
return ir::Cast::Make(cinn::common::I64(),
inputs_[input_idx](cinn::ir::Expr(input_dim_idx)));
PADDLE_ENFORCE_LE(inputs_[input_idx].ndims(),
9,
::common::errors::InvalidArgument(
"The rank of the input tensor must be less than or "
"equal to 9, but got %d",
inputs_[input_idx].ndims()));
const std::vector<ir::Expr> indices = [&]() -> std::vector<ir::Expr> {
const auto& dimensions = inputs_[input_idx]->shape;
std::vector<ir::Expr> result(dimensions.size(), 0);
std::vector<int64_t> strides(dimensions.size(), 1);
for (int64_t i = dimensions.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * dimensions[i + 1].as_int64();
}
int64_t cur_index = input_dim_idx;
for (int64_t i = 0; i < dimensions.size(); ++i) {
result[i] = ir::Expr(cur_index / strides[i]);
cur_index %= strides[i];
}
return result;
}();
return ir::Cast::Make(cinn::common::I64(), inputs_[input_idx](indices));
}

DimExprToIrExprVisitorWithSymbolBinding(
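The new indexing above decodes the flat data index input_dim_idx into per-dimension indices using row-major strides. A standalone sketch of the same arithmetic (plain C++, not part of the diff; the shape {2, 3, 4} and index 17 are made-up values for illustration):

#include <cstdint>
#include <iostream>
#include <vector>

// Decode a flat (row-major) index into per-dimension indices, mirroring the
// stride computation used in the diff above.
std::vector<int64_t> DecodeRowMajorIndex(int64_t index,
                                         const std::vector<int64_t>& dims) {
  std::vector<int64_t> strides(dims.size(), 1);
  for (int64_t i = static_cast<int64_t>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  std::vector<int64_t> result(dims.size(), 0);
  for (size_t i = 0; i < dims.size(); ++i) {
    result[i] = index / strides[i];
    index %= strides[i];
  }
  return result;
}

int main() {
  // For shape {2, 3, 4} the strides are {12, 4, 1}, so flat index 17
  // decodes to {1, 1, 1} (17 = 1 * 12 + 1 * 4 + 1 * 1).
  for (int64_t idx : DecodeRowMajorIndex(17, {2, 3, 4})) {
    std::cout << idx << " ";  // prints: 1 1 1
  }
  std::cout << "\n";
  return 0;
}
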
@@ -60,18 +60,6 @@ bool InferSymbolicShapeElementWiseBinary(

if (x_shape.data() && y_shape.data() &&
x_shape.data()->size() == y_shape.data()->size() && DataComputeFunc) {
PADDLE_ENFORCE_LE(
x_shape.shape().size(),
1,
common::errors::InvalidArgument("When compute data, the rank of x "
"should be 0 or 1, but now received %d",
x_shape.shape().size()));
PADDLE_ENFORCE_LE(
y_shape.shape().size(),
1,
common::errors::InvalidArgument("When compute data, the rank of y "
"should be 0 or 1, but now received %d",
y_shape.shape().size()));
std::vector<symbol::DimExpr> out_data;
for (size_t i = 0; i < x_shape.data()->size(); ++i) {
out_data.emplace_back(
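For context on why the rank check could be dropped here: the symbolic data is kept as a flattened vector, so an element-wise binary compute only needs x and y to carry the same number of data elements, whatever their original rank. A minimal standalone sketch (plain C++, made-up values, addition standing in for DataComputeFunc):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Element-wise combine of two flattened data vectors; only the lengths
  // have to match, the tensor ranks behind them do not matter.
  const std::vector<int64_t> x_data = {2, 3, 4};
  const std::vector<int64_t> y_data = {1, 1, 1};

  std::vector<int64_t> out_data;
  for (size_t i = 0; i < x_data.size(); ++i) {
    out_data.push_back(x_data[i] + y_data[i]);  // stand-in for DataComputeFunc
  }
  for (int64_t v : out_data) std::cout << v << " ";  // prints: 3 4 5
  std::cout << "\n";
  return 0;
}
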
@@ -0,0 +1,65 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h"

namespace paddle::dialect::slice_utils {

void SliceDfsImpl(const ExprVec &datas,
const std::vector<int64_t> &shape,
int64_t axis,
int64_t start,
int64_t end,
int64_t cur_visit_axis,
int offset,
ExprVec *result) {
int64_t begin = 0;
int64_t stop = shape.at(cur_visit_axis);
if (cur_visit_axis == axis) {
begin = start;
stop = end;
}
const int64_t cur_stride = std::accumulate(shape.begin() + cur_visit_axis + 1,
shape.end(),
1,
std::multiplies<int64_t>());
for (int64_t i = begin; i < stop; ++i) {
const int64_t cur_offset = offset + i * cur_stride;
// last dim
if (cur_visit_axis == static_cast<int64_t>(shape.size() - 1)) {
result->push_back(datas[cur_offset]);
} else {
SliceDfsImpl(datas,
shape,
axis,
start,
end,
cur_visit_axis + 1,
cur_offset,
result);
}
}
}

ExprVec SimpleSlice(const ExprVec &datas,
const std::vector<int64_t> &shape,
int64_t axis,
int64_t start,
int64_t end) {
ExprVec result;
SliceDfsImpl(datas, shape, axis, start, end, 0, 0, &result);
return result;
}

} // namespace paddle::dialect::slice_utils
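For intuition about SimpleSlice, a small worked example of the expected result (plain C++, integers standing in for symbol::DimExpr; it does not use the Paddle headers):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Flattened row-major data of a 2 x 3 "shape tensor":
  //   {{10, 20, 30},
  //    {40, 50, 60}}
  const std::vector<int64_t> datas = {10, 20, 30, 40, 50, 60};
  const std::vector<int64_t> shape = {2, 3};

  // SimpleSlice(datas, shape, /*axis=*/1, /*start=*/0, /*end=*/2) is expected
  // to keep columns [0, 2) of every row, i.e. flat offsets r * 3 + c, c < 2.
  std::vector<int64_t> expected;
  for (int64_t r = 0; r < shape[0]; ++r) {
    for (int64_t c = 0; c < 2; ++c) {
      expected.push_back(datas[r * shape[1] + c]);
    }
  }
  for (int64_t v : expected) std::cout << v << " ";  // prints: 10 20 40 50
  std::cout << "\n";
  return 0;
}
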
@@ -195,6 +195,23 @@ inline std::vector<int64_t> FormatSliceAxes(
return axes_vec;
}

/**
* @brief Simple slice function, like paddle.slice, for a given data vector.
*
* @param datas Input data of type ExprVec.
* @param shape The shape of datas.
* @param axis Axis along which to perform the slicing.
* @param start Starting index for the slice.
* @param end Ending index for the slice.
*
* @return Returns the result after slicing the input data.
*/
ExprVec SimpleSlice(const ExprVec &datas,
const std::vector<int64_t> &shape,
int64_t axis,
int64_t start,
int64_t end);

inline ShapeOrData SliceRawInferSymbolicShape(
const pir::Value x,
const pir::Value out,
@@ -205,20 +222,19 @@ inline ShapeOrData SliceRawInferSymbolicShape(
const std::vector<int64_t> &decrease_axis,
pir::InferSymbolicShapeContext *infer_context) {
const auto &in_shapeordata = infer_context->GetShapeOrDataForValue(x);
const ExprVec &in_dims = in_shapeordata.shape();
ExprVec starts = starts_expr;
ExprVec ends = ends_expr;
std::vector<int64_t> infer_flags = [&infer_flags_raw, &axes_raw] {
return infer_flags_raw.empty() ? std::vector<int64_t>(axes_raw.size(), 1)
: infer_flags_raw;
}();
const std::vector<int64_t> axes = FormatSliceAxes(axes_raw, in_dims.size());
const ExprVec slice_dims =
GetSliceDims(in_dims, axes, starts, ends, &infer_flags);
const ExprVec out_dims = GetDecreasedDims(slice_dims, decrease_axis);

const auto &GetShapeDimExprs = [&]() -> symbol::ShapeOrDataDimExprs {
const ExprVec &in_dims = in_shapeordata.shape();
std::vector<int64_t> axes = FormatSliceAxes(axes_raw, in_dims.size());
ExprVec slice_dims =
GetSliceDims(in_dims, axes, starts, ends, &infer_flags);
ExprVec out_dims = GetDecreasedDims(slice_dims, decrease_axis);

auto IsOne = [](const symbol::DimExpr &expr) {
return expr.isa<int64_t>() && expr.dyn_cast<int64_t>() == 1;
};
@@ -240,8 +256,6 @@ inline ShapeOrData SliceRawInferSymbolicShape(
// When `pd.slice` is operating on a tensor which is produced by a `pd.shape`
// op, the result should be written into data.
const auto &GetDataDimExprs = [&]() -> symbol::ShapeOrDataDimExprs {
std::vector<symbol::DimExpr> out_data;

// Currently, we DO NOT support the case that any element in `axes` `starts`
// or `ends` is a Symbol.
auto vec_int64 = details::VecExpr2Int64(starts);
@@ -263,15 +277,13 @@
}
return ends_int[0];
}();
const std::vector<int64_t> in_shape =
details::VecExpr2Int64(in_dims).value();
std::vector<symbol::DimExpr> out_data = SimpleSlice(
in_shapeordata.data().value(), in_shape, axes.at(0), start, end);

for (int64_t i = start; i < end; i++) {
out_data.push_back(in_shapeordata.data().value().at(i));
}

const ExprVec shape = GetDecreasedDims(
ExprVec{static_cast<int64_t>(out_data.size())}, decrease_axis);
return symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(shape, out_data)};
symbol::TensorShapeOrDataDimExprs(out_dims, out_data)};
};
bool starts_ends_all_int =
std::all_of(starts_expr.begin(),
@@ -281,10 +293,10 @@ inline ShapeOrData SliceRawInferSymbolicShape(
ends_expr.end(),
[](const symbol::DimExpr &e) { return e.isa<int64_t>(); });

const auto &out_shape =
in_shapeordata.data().has_value() && starts_ends_all_int
? GetDataDimExprs()
: GetShapeDimExprs();
const auto &out_shape = in_shapeordata.data().has_value() &&
starts_ends_all_int && axes_raw.size() == 1
? GetDataDimExprs()
: GetShapeDimExprs();
if (out_shape.data().has_value() && out_shape.shape().empty()) { // 0D tensor
const paddle::dialect::DenseTensorType &tensor_type =
out.type().dyn_cast<paddle::dialect::DenseTensorType>();
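To make the data branch concrete: when the slice input itself carries data (for example, it is the output of a pd.shape op), the result is computed on that data rather than on symbolic dims, and the revised code only takes this path for a single slice axis with integer starts/ends. A minimal numeric sketch with assumed values, shown for the rank-1 case (plain C++, not the Paddle API):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Suppose `x` is the output of a pd.shape op on a tensor of shape
  // [4, 8, 16, 32]: its ShapeOrData has shape {4} and data {4, 8, 16, 32}.
  const std::vector<int64_t> in_data = {4, 8, 16, 32};

  // pd.slice(x, axes={0}, starts={1}, ends={3}) should then produce
  // data {8, 16} and shape {2}: the stored values are sliced, not the dims.
  const int64_t start = 1, end = 3;
  std::vector<int64_t> out_data(in_data.begin() + start, in_data.begin() + end);

  std::cout << "out shape: {" << out_data.size() << "}, out data: ";
  for (int64_t v : out_data) std::cout << v << " ";  // prints: 8 16
  std::cout << "\n";
  return 0;
}
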
@@ -193,11 +193,23 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
value.type().dyn_cast<pir::DenseTensorType>();
const auto &dims = tensor_type.dims();
if (dims.size() == 0) return true;
if (dims.size() != 1) return false;
if (dims[0] >= 1 && dims[0] <= ::common::DDim::kMaxRank) {
return true;
if (dims.size() == 1) {
if (dims[0] >= 1 && dims[0] <= ::common::DDim::kMaxRank) {
return true;
}
return false;
}
return false;
if (common::contain_unknown_dim(dims)) return false;
if (common::product(dims) > ::common::DDim::kMaxRank) return false;

// at most one dim may be greater than one; all other dims must be 1
int gt_one_dim_count = 0;
for (int i = 0; i < dims.size(); ++i) {
if (dims[i] > 1) {
gt_one_dim_count++;
}
}
return gt_one_dim_count <= 1;
};

auto IsIntType = [&](pir::Value value) {
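A standalone sketch of the shape test performed by the new branch, with assumed example shapes (plain C++; the real code also caps the total element count at ::common::DDim::kMaxRank, which this sketch omits):

#include <cstdint>
#include <iostream>
#include <vector>

// Accept a static shape only if at most one dimension is greater than one,
// i.e. the tensor is effectively a vector, possibly padded with unit dims.
bool HasAtMostOneNonUnitDim(const std::vector<int64_t>& dims) {
  int gt_one_dim_count = 0;
  for (int64_t d : dims) {
    if (d < 0) return false;  // unknown dim
    if (d > 1) ++gt_one_dim_count;
  }
  return gt_one_dim_count <= 1;
}

int main() {
  std::cout << HasAtMostOneNonUnitDim({1, 6, 1}) << "\n";  // 1: one dim > 1
  std::cout << HasAtMostOneNonUnitDim({1, 1, 1}) << "\n";  // 1: all ones
  std::cout << HasAtMostOneNonUnitDim({2, 3}) << "\n";     // 0: two dims > 1
  return 0;
}
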