Skip to content

Commit 8d4fb21

Browse files
authored
Use DimExpr and change InferSymbolicShapeInterface (#60371)
* Use DimExpr and change InferSymbolicShapeInterface * static infer lib
1 parent b989f8a commit 8d4fb21

File tree

11 files changed

+335
-420
lines changed

11 files changed

+335
-420
lines changed

paddle/cinn/hlir/dialect/operator/ir/ops.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
func : SliceRawInferMeta
7575
kernel :
7676
func : slice
77+
interfaces : paddle::dialect::InferSymbolicShapeInterface
7778

7879
- op : uniform_random
7980
args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0)

paddle/fluid/inference/CMakeLists.txt

+3-2
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ set(KERNEL_LIST
6464

6565
# shared inference library deps
6666
list(REMOVE_DUPLICATES fluid_modules)
67-
#windows GPU static library over the limit, so not create_static_lib, and cc_library is dummy
68-
if(WIN32 AND WITH_GPU)
67+
# windows static library(both CPU and GPU)over the limit, so no longer create_static_lib,
68+
# and cc_library is dummy
69+
if(WIN32)
6970
cc_library(paddle_inference DEPS ${fluid_modules} ${STATIC_INFERENCE_API}
7071
${utils_modules})
7172
else()

paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@
1313
# limitations under the License.
1414

1515
OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE = """
16-
bool {op_name}::InferSymbolicShape(pir::Builder &builder,
17-
const std::vector<pir::OpOperand> &operands,
18-
std::vector<pir::Value> &reified_return_shapes) {{
16+
bool {op_name}::InferSymbolicShape(pir::ShapeConstraintIRAnalysis* shape_analysis) {{
1917
VLOG(4) << "Infer symbolic shape for op: {op_name}";
20-
return {op_name}InferSymbolicShape(builder, operands, reified_return_shapes);
18+
return {op_name}InferSymbolicShape(this->operation(), shape_analysis);
2119
}}
2220
"""
2321

paddle/fluid/pir/dialect/op_generator/op_gen.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
133133
"""
134134

135135
infer_symbolic_shape_template = """
136-
static bool InferSymbolicShape(pir::Builder &builder,
137-
const std::vector<pir::OpOperand> &operands,
138-
std::vector<pir::Value> &reified_return_shapes);
136+
bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis* shape_analysis);
139137
"""
140138

141139
# =====================================

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc

+162-111
Original file line numberDiff line numberDiff line change
@@ -13,141 +13,192 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
16+
#include "paddle/pir/core/builtin_attribute.h"
17+
#include "paddle/pir/core/builtin_type.h"
1618
#include "paddle/pir/dialect/shape/ir/shape_op.h"
1719

1820
namespace paddle::dialect {
1921

2022
bool InferSymbolicShapeInterface::InferSymbolicShape(
21-
pir::Builder &builder,
22-
const std::vector<pir::OpOperand> &operands,
23-
std::vector<pir::Value> &reified_return_shapes) {
24-
return impl_->infer_symbolic_shapes(
25-
operation(), builder, operands, reified_return_shapes);
23+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
24+
return impl_->infer_symbolic_shapes(operation(), shape_analysis);
2625
}
2726
} // namespace paddle::dialect
2827

2928
namespace paddle::dialect {
3029

3130
namespace {
3231

33-
bool DeriveShapeFromOperand(pir::Builder *builder,
34-
pir::Value operand,
35-
std::vector<pir::Value> *reified_return_shapes) {
36-
auto shaped_type = operand.type().dyn_cast<pir::ShapedTypeInterface>();
37-
if (!shaped_type) return false;
38-
reified_return_shapes->assign(
39-
{builder->Build<pir::shape::ShapeOfOp>(operand).result(0)});
32+
bool InferSymbolicShapeAllEqualUnary(
33+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
34+
pir::Value operand_source = op->operand_source(0);
35+
std::string operand_source_id = pir::GetValueId(&operand_source);
36+
pir::OpResult res = op->result(0);
37+
std::string res_id = pir::GetValueId(&res);
38+
shape_analysis->value_id_to_shapeordata_[res_id] =
39+
shape_analysis->value_id_to_shapeordata_[operand_source_id];
4040
return true;
4141
}
4242

43-
// Returns a new scalar integer value having type `type`.
44-
// Here `type` must be an integer or index type.
45-
pir::Value MaybeCastTo(pir::Builder &builder, // NOLINT
46-
pir::Value value,
47-
pir::Type type) {
48-
if (type == value.type()) return value;
49-
// if (!type.IsIndex() && !value.type().IsIndex()) {
50-
// Value casted =
51-
// builder.Build<shape::IndexCastOp>(builder.index_type(), value)
52-
// .result(0);
53-
// return builder.Build<shape::IndexCastOp>(type, casted).result(0);
54-
// }
55-
// return builder.Build<shape::IndexCastOp>(type, value).result(0);
43+
bool InferSymbolicShapeAllEqualBinary(
44+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
45+
pir::Value operand_source = op->operand_source(0);
46+
std::string operand_source_id = pir::GetValueId(&operand_source);
47+
pir::OpResult res = op->result(0);
48+
std::string res_id = pir::GetValueId(&res);
49+
shape_analysis->value_id_to_shapeordata_[res_id] =
50+
shape_analysis->value_id_to_shapeordata_[operand_source_id];
51+
return true;
5652
}
53+
5754
} // namespace
5855

59-
bool AbsOpInferSymbolicShape(
60-
pir::Builder &builder, // NOLINT
61-
const std::vector<pir::OpOperand> &operands,
62-
std::vector<pir::Value> &reified_return_shapes) { // NOLINT
63-
return DeriveShapeFromOperand(
64-
&builder, operands.front().source(), &reified_return_shapes);
65-
}
66-
67-
bool Abs_OpInferSymbolicShape(
68-
pir::Builder &builder, // NOLINT
69-
const std::vector<pir::OpOperand> &operands,
70-
std::vector<pir::Value> &reified_return_shapes) { // NOLINT
71-
return DeriveShapeFromOperand(
72-
&builder, operands.front().source(), &reified_return_shapes);
73-
}
74-
75-
bool TransposeOpInferSymbolicShape(
76-
pir::Builder &builder, // NOLINT
77-
const std::vector<pir::OpOperand> &operands,
78-
std::vector<pir::Value> &reified_return_shapes) { // NOLINT
79-
// auto operand_type = operands[0].type().dyn_cast<DenseTensorType>();
80-
// // Currently not support unranked type.
81-
// if (!operand_type) return false;
82-
// std::vector<int64_t> permutation = this->permutation();
83-
// std::vector<Value> shape_values(permutation.size());
84-
// Type shape_scalar_type = builder.index_type();
85-
// auto to_shape_scalar_type = [&](Value v) {
86-
// return MaybeCastTo(builder, v, shape_scalar_type);
87-
// };
88-
// auto shaped_type = operand_type.dyn_cast<ShapedTypeInterface>();
89-
// auto shape_vector = shaped_type.GetDyShape();
90-
// for (auto [idx, element] = std::tuple{0, shape_vector.begin()};
91-
// element != shape_vector.end();
92-
// ++idx, ++element) {
93-
// auto it = std::find(permutation.begin(), permutation.end(), idx);
94-
// // TODO(zhangbopd): Need BuildOrFold
95-
// Value value_dim = to_shape_scalar_type(
96-
// builder.Build<shape::TensorDimOp>(operands[0].source(),
97-
// idx).result(0));
98-
// shape_values[std::distance(permutation.begin(), it)] = value_dim;
99-
// }
100-
// Value output_shape =
101-
// builder.Build<shape::FromElementsOp>(shape_values).result(0);
102-
// reified_return_shapes.push_back(output_shape);
56+
bool AbsOpInferSymbolicShape(pir::Operation *op,
57+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
58+
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
59+
}
60+
61+
bool Abs_OpInferSymbolicShape(pir::Operation *op,
62+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
63+
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
64+
}
65+
66+
bool CastOpInferSymbolicShape(pir::Operation *op,
67+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
68+
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
69+
}
70+
71+
bool Cast_OpInferSymbolicShape(pir::Operation *op,
72+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
73+
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
74+
}
75+
76+
bool ExpOpInferSymbolicShape(pir::Operation *op,
77+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
78+
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
79+
}
10380

81+
bool Exp_OpInferSymbolicShape(pir::Operation *op,
82+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
83+
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
84+
}
85+
86+
bool SubtractOpInferSymbolicShape(
87+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
88+
return InferSymbolicShapeAllEqualBinary(op, shape_analysis);
89+
}
90+
91+
bool Subtract_OpInferSymbolicShape(
92+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
93+
return InferSymbolicShapeAllEqualBinary(op, shape_analysis);
94+
}
95+
96+
bool ShapeOpInferSymbolicShape(pir::Operation *op,
97+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
98+
pir::Value operand_source = op->operand_source(0);
99+
std::string operand_source_id = pir::GetValueId(&operand_source);
100+
pir::OpResult res = op->result(0);
101+
std::string res_id = pir::GetValueId(&res);
102+
103+
std::vector<int64_t> dims =
104+
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());
105+
106+
std::vector<symbol::DimExpr> shapes;
107+
for (int64_t dim : dims) {
108+
symbol::DimExpr dim_expr;
109+
if (dim == -1) {
110+
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
111+
dim_expr = res_dim_expr;
112+
} else {
113+
symbol::DimExpr res_dim_expr(dim);
114+
dim_expr = res_dim_expr;
115+
}
116+
shapes.push_back(dim_expr);
117+
}
118+
119+
symbol::ShapeOrDataDimExprs shape_data{shapes};
120+
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
104121
return true;
105122
}
106123

107-
bool ConcatOpInferSymbolicShape(
108-
pir::Builder &builder, // NOLINT
109-
const std::vector<pir::OpOperand> &operands,
110-
std::vector<pir::Value> &reified_return_shapes) { // NOLINT
111-
// std::vector<Value> inputs = {x()};
112-
// auto operand_type = inputs[0].type().dyn_cast<DenseTensorType>();
113-
// // Currently not support unranked type.
114-
// if (!operand_type) return false;
115-
// Type shapeScalarType = builder.index_type();
116-
// auto to_shape_scalar_type = [&](Value v) {
117-
// return MaybeCastTo(builder, v, shapeScalarType);
118-
// };
119-
// std::vector<std::vector<Value>> all_shape_values;
120-
// for (size_t inputId = 0; inputId < inputs.size(); ++inputId) {
121-
// Value operand = inputs[inputId];
122-
// auto operand_type = operand.type().dyn_cast<DenseTensorType>();
123-
// if (!operand_type) return false;
124-
// std::vector<Value> shape_values;
125-
// auto shaped_type = operand_type.dyn_cast<ShapedTypeInterface>();
126-
// auto shape_vector = shaped_type.GetDyShape();
127-
// for (auto [idx, element] = std::tuple{0, shape_vector.begin()};
128-
// element != shape_vector.end();
129-
// ++idx, ++element) {
130-
// Value value_dim = to_shape_scalar_type(
131-
// builder.Build<shape::TensorDimOp>(operand, idx).result(0));
132-
// shape_values.push_back(value_dim);
133-
// }
134-
// all_shape_values.emplace_back(std::move(shape_values));
135-
// }
136-
// [[maybe_unused]] int axis = this->dimension();
137-
// auto &shape_values = all_shape_values[0];
138-
// for (size_t vecId = 1; vecId < all_shape_values.size(); ++vecId) {
139-
// auto &otherShapeValues = all_shape_values[vecId];
140-
// if (otherShapeValues.size() != shape_values.size()) return false;
141-
// TODO(zhangbopd): AddIOp
142-
// shape_values[axis] =
143-
// builder.Build<arith::AddIOp>(shape_values[axis],
144-
// otherShapeValues[axis]);
145-
// }
146-
// Value output_shape =
147-
// builder.Build<shape::FromElementsOp>(shape_values).result(0);
148-
// reified_return_shapes.push_back(output_shape);
124+
bool ShapeSrOpInferSymbolicShape(
125+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
126+
return ShapeOpInferSymbolicShape(op, shape_analysis);
127+
}
128+
129+
bool StackOpInferSymbolicShape(pir::Operation *op,
130+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
131+
pir::Value operand_source = op->operand_source(0);
132+
std::string operand_source_id = pir::GetValueId(&operand_source);
133+
pir::OpResult res = op->result(0);
134+
std::string res_id = pir::GetValueId(&res);
135+
136+
symbol::ShapeOrDataDimExprs shape_data;
137+
shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id];
138+
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
139+
return true;
140+
}
141+
142+
bool ReshapeOpInferSymbolicShape(
143+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
144+
pir::Value operand_source_1 = op->operand_source(1);
145+
std::string operand_source_1_id = pir::GetValueId(&operand_source_1);
146+
pir::OpResult res = op->result(0);
147+
std::string res_id = pir::GetValueId(&res);
148+
149+
symbol::ShapeOrDataDimExprs shape_data;
150+
151+
shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id];
152+
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
149153
return true;
150154
}
151155

156+
bool Reshape_OpInferSymbolicShape(
157+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
158+
return ReshapeOpInferSymbolicShape(op, shape_analysis);
159+
}
160+
152161
} // namespace paddle::dialect
162+
namespace cinn::dialect {
163+
164+
bool SliceOpInferSymbolicShape(pir::Operation *op,
165+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
166+
pir::Value operand_source = op->operand_source(0);
167+
std::string operand_source_id = pir::GetValueId(&operand_source);
168+
pir::OpResult res = op->result(0);
169+
std::string res_id = pir::GetValueId(&res);
170+
171+
std::vector<int64_t> dims =
172+
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());
173+
174+
std::vector<symbol::DimExpr> shapes;
175+
for (int64_t dim : dims) {
176+
symbol::DimExpr dim_expr;
177+
if (dim == -1) {
178+
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
179+
dim_expr = res_dim_expr;
180+
} else {
181+
symbol::DimExpr res_dim_expr(dim);
182+
dim_expr = res_dim_expr;
183+
}
184+
shapes.push_back(dim_expr);
185+
}
186+
187+
// pir::AttributeMap attributes = op->attributes();
188+
189+
// auto attr_starts =
190+
// attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector();
191+
// auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data();
192+
193+
// auto attr_ends =
194+
// attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector();
195+
// auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data();
196+
197+
symbol::ShapeOrDataDimExprs shape_data{shapes};
198+
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
199+
return true;
200+
}
201+
202+
} // namespace cinn::dialect
203+
153204
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface)

0 commit comments

Comments
 (0)