|
13 | 13 | // limitations under the License.
|
14 | 14 |
|
15 | 15 | #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
|
| 16 | +#include "paddle/pir/core/builtin_attribute.h" |
| 17 | +#include "paddle/pir/core/builtin_type.h" |
16 | 18 | #include "paddle/pir/dialect/shape/ir/shape_op.h"
|
17 | 19 |
|
18 | 20 | namespace paddle::dialect {
|
19 | 21 |
|
20 | 22 | bool InferSymbolicShapeInterface::InferSymbolicShape(
|
21 |
| - pir::Builder &builder, |
22 |
| - const std::vector<pir::OpOperand> &operands, |
23 |
| - std::vector<pir::Value> &reified_return_shapes) { |
24 |
| - return impl_->infer_symbolic_shapes( |
25 |
| - operation(), builder, operands, reified_return_shapes); |
| 23 | + pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 24 | + return impl_->infer_symbolic_shapes(operation(), shape_analysis); |
26 | 25 | }
|
27 | 26 | } // namespace paddle::dialect
|
28 | 27 |
|
29 | 28 | namespace paddle::dialect {
|
30 | 29 |
|
31 | 30 | namespace {
|
32 | 31 |
|
33 |
| -bool DeriveShapeFromOperand(pir::Builder *builder, |
34 |
| - pir::Value operand, |
35 |
| - std::vector<pir::Value> *reified_return_shapes) { |
36 |
| - auto shaped_type = operand.type().dyn_cast<pir::ShapedTypeInterface>(); |
37 |
| - if (!shaped_type) return false; |
38 |
| - reified_return_shapes->assign( |
39 |
| - {builder->Build<pir::shape::ShapeOfOp>(operand).result(0)}); |
| 32 | +bool InferSymbolicShapeAllEqualUnary( |
| 33 | + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 34 | + pir::Value operand_source = op->operand_source(0); |
| 35 | + std::string operand_source_id = pir::GetValueId(&operand_source); |
| 36 | + pir::OpResult res = op->result(0); |
| 37 | + std::string res_id = pir::GetValueId(&res); |
| 38 | + shape_analysis->value_id_to_shapeordata_[res_id] = |
| 39 | + shape_analysis->value_id_to_shapeordata_[operand_source_id]; |
40 | 40 | return true;
|
41 | 41 | }
|
42 | 42 |
|
43 |
| -// Returns a new scalar integer value having type `type`. |
44 |
| -// Here `type` must be an integer or index type. |
45 |
| -pir::Value MaybeCastTo(pir::Builder &builder, // NOLINT |
46 |
| - pir::Value value, |
47 |
| - pir::Type type) { |
48 |
| - if (type == value.type()) return value; |
49 |
| - // if (!type.IsIndex() && !value.type().IsIndex()) { |
50 |
| - // Value casted = |
51 |
| - // builder.Build<shape::IndexCastOp>(builder.index_type(), value) |
52 |
| - // .result(0); |
53 |
| - // return builder.Build<shape::IndexCastOp>(type, casted).result(0); |
54 |
| - // } |
55 |
| - // return builder.Build<shape::IndexCastOp>(type, value).result(0); |
| 43 | +bool InferSymbolicShapeAllEqualBinary( |
| 44 | + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 45 | + pir::Value operand_source = op->operand_source(0); |
| 46 | + std::string operand_source_id = pir::GetValueId(&operand_source); |
| 47 | + pir::OpResult res = op->result(0); |
| 48 | + std::string res_id = pir::GetValueId(&res); |
| 49 | + shape_analysis->value_id_to_shapeordata_[res_id] = |
| 50 | + shape_analysis->value_id_to_shapeordata_[operand_source_id]; |
| 51 | + return true; |
56 | 52 | }
|
| 53 | + |
57 | 54 | } // namespace
|
58 | 55 |
|
59 |
| -bool AbsOpInferSymbolicShape( |
60 |
| - pir::Builder &builder, // NOLINT |
61 |
| - const std::vector<pir::OpOperand> &operands, |
62 |
| - std::vector<pir::Value> &reified_return_shapes) { // NOLINT |
63 |
| - return DeriveShapeFromOperand( |
64 |
| - &builder, operands.front().source(), &reified_return_shapes); |
65 |
| -} |
66 |
| - |
67 |
| -bool Abs_OpInferSymbolicShape( |
68 |
| - pir::Builder &builder, // NOLINT |
69 |
| - const std::vector<pir::OpOperand> &operands, |
70 |
| - std::vector<pir::Value> &reified_return_shapes) { // NOLINT |
71 |
| - return DeriveShapeFromOperand( |
72 |
| - &builder, operands.front().source(), &reified_return_shapes); |
73 |
| -} |
74 |
| - |
75 |
| -bool TransposeOpInferSymbolicShape( |
76 |
| - pir::Builder &builder, // NOLINT |
77 |
| - const std::vector<pir::OpOperand> &operands, |
78 |
| - std::vector<pir::Value> &reified_return_shapes) { // NOLINT |
79 |
| - // auto operand_type = operands[0].type().dyn_cast<DenseTensorType>(); |
80 |
| - // // Currently not support unranked type. |
81 |
| - // if (!operand_type) return false; |
82 |
| - // std::vector<int64_t> permutation = this->permutation(); |
83 |
| - // std::vector<Value> shape_values(permutation.size()); |
84 |
| - // Type shape_scalar_type = builder.index_type(); |
85 |
| - // auto to_shape_scalar_type = [&](Value v) { |
86 |
| - // return MaybeCastTo(builder, v, shape_scalar_type); |
87 |
| - // }; |
88 |
| - // auto shaped_type = operand_type.dyn_cast<ShapedTypeInterface>(); |
89 |
| - // auto shape_vector = shaped_type.GetDyShape(); |
90 |
| - // for (auto [idx, element] = std::tuple{0, shape_vector.begin()}; |
91 |
| - // element != shape_vector.end(); |
92 |
| - // ++idx, ++element) { |
93 |
| - // auto it = std::find(permutation.begin(), permutation.end(), idx); |
94 |
| - // // TODO(zhangbopd): Need BuildOrFold |
95 |
| - // Value value_dim = to_shape_scalar_type( |
96 |
| - // builder.Build<shape::TensorDimOp>(operands[0].source(), |
97 |
| - // idx).result(0)); |
98 |
| - // shape_values[std::distance(permutation.begin(), it)] = value_dim; |
99 |
| - // } |
100 |
| - // Value output_shape = |
101 |
| - // builder.Build<shape::FromElementsOp>(shape_values).result(0); |
102 |
| - // reified_return_shapes.push_back(output_shape); |
| 56 | +bool AbsOpInferSymbolicShape(pir::Operation *op, |
| 57 | + pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 58 | + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); |
| 59 | +} |
| 60 | + |
| 61 | +bool Abs_OpInferSymbolicShape(pir::Operation *op, |
| 62 | + pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 63 | + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); |
| 64 | +} |
| 65 | + |
| 66 | +bool CastOpInferSymbolicShape(pir::Operation *op, |
| 67 | + pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 68 | + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); |
| 69 | +} |
| 70 | + |
| 71 | +bool Cast_OpInferSymbolicShape(pir::Operation *op, |
| 72 | + pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 73 | + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); |
| 74 | +} |
| 75 | + |
| 76 | +bool ExpOpInferSymbolicShape(pir::Operation *op, |
| 77 | + pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 78 | + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); |
| 79 | +} |
103 | 80 |
|
| 81 | +bool Exp_OpInferSymbolicShape(pir::Operation *op, |
| 82 | + pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 83 | + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); |
| 84 | +} |
| 85 | + |
| 86 | +bool SubtractOpInferSymbolicShape( |
| 87 | + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 88 | + return InferSymbolicShapeAllEqualBinary(op, shape_analysis); |
| 89 | +} |
| 90 | + |
| 91 | +bool Subtract_OpInferSymbolicShape( |
| 92 | + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 93 | + return InferSymbolicShapeAllEqualBinary(op, shape_analysis); |
| 94 | +} |
| 95 | + |
| 96 | +bool ShapeOpInferSymbolicShape(pir::Operation *op, |
| 97 | + pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 98 | + pir::Value operand_source = op->operand_source(0); |
| 99 | + std::string operand_source_id = pir::GetValueId(&operand_source); |
| 100 | + pir::OpResult res = op->result(0); |
| 101 | + std::string res_id = pir::GetValueId(&res); |
| 102 | + |
| 103 | + std::vector<int64_t> dims = |
| 104 | + common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims()); |
| 105 | + |
| 106 | + std::vector<symbol::DimExpr> shapes; |
| 107 | + for (int64_t dim : dims) { |
| 108 | + symbol::DimExpr dim_expr; |
| 109 | + if (dim == -1) { |
| 110 | + symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); |
| 111 | + dim_expr = res_dim_expr; |
| 112 | + } else { |
| 113 | + symbol::DimExpr res_dim_expr(dim); |
| 114 | + dim_expr = res_dim_expr; |
| 115 | + } |
| 116 | + shapes.push_back(dim_expr); |
| 117 | + } |
| 118 | + |
| 119 | + symbol::ShapeOrDataDimExprs shape_data{shapes}; |
| 120 | + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; |
104 | 121 | return true;
|
105 | 122 | }
|
106 | 123 |
|
107 |
| -bool ConcatOpInferSymbolicShape( |
108 |
| - pir::Builder &builder, // NOLINT |
109 |
| - const std::vector<pir::OpOperand> &operands, |
110 |
| - std::vector<pir::Value> &reified_return_shapes) { // NOLINT |
111 |
| - // std::vector<Value> inputs = {x()}; |
112 |
| - // auto operand_type = inputs[0].type().dyn_cast<DenseTensorType>(); |
113 |
| - // // Currently not support unranked type. |
114 |
| - // if (!operand_type) return false; |
115 |
| - // Type shapeScalarType = builder.index_type(); |
116 |
| - // auto to_shape_scalar_type = [&](Value v) { |
117 |
| - // return MaybeCastTo(builder, v, shapeScalarType); |
118 |
| - // }; |
119 |
| - // std::vector<std::vector<Value>> all_shape_values; |
120 |
| - // for (size_t inputId = 0; inputId < inputs.size(); ++inputId) { |
121 |
| - // Value operand = inputs[inputId]; |
122 |
| - // auto operand_type = operand.type().dyn_cast<DenseTensorType>(); |
123 |
| - // if (!operand_type) return false; |
124 |
| - // std::vector<Value> shape_values; |
125 |
| - // auto shaped_type = operand_type.dyn_cast<ShapedTypeInterface>(); |
126 |
| - // auto shape_vector = shaped_type.GetDyShape(); |
127 |
| - // for (auto [idx, element] = std::tuple{0, shape_vector.begin()}; |
128 |
| - // element != shape_vector.end(); |
129 |
| - // ++idx, ++element) { |
130 |
| - // Value value_dim = to_shape_scalar_type( |
131 |
| - // builder.Build<shape::TensorDimOp>(operand, idx).result(0)); |
132 |
| - // shape_values.push_back(value_dim); |
133 |
| - // } |
134 |
| - // all_shape_values.emplace_back(std::move(shape_values)); |
135 |
| - // } |
136 |
| - // [[maybe_unused]] int axis = this->dimension(); |
137 |
| - // auto &shape_values = all_shape_values[0]; |
138 |
| - // for (size_t vecId = 1; vecId < all_shape_values.size(); ++vecId) { |
139 |
| - // auto &otherShapeValues = all_shape_values[vecId]; |
140 |
| - // if (otherShapeValues.size() != shape_values.size()) return false; |
141 |
| - // TODO(zhangbopd): AddIOp |
142 |
| - // shape_values[axis] = |
143 |
| - // builder.Build<arith::AddIOp>(shape_values[axis], |
144 |
| - // otherShapeValues[axis]); |
145 |
| - // } |
146 |
| - // Value output_shape = |
147 |
| - // builder.Build<shape::FromElementsOp>(shape_values).result(0); |
148 |
| - // reified_return_shapes.push_back(output_shape); |
| 124 | +bool ShapeSrOpInferSymbolicShape( |
| 125 | + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 126 | + return ShapeOpInferSymbolicShape(op, shape_analysis); |
| 127 | +} |
| 128 | + |
| 129 | +bool StackOpInferSymbolicShape(pir::Operation *op, |
| 130 | + pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 131 | + pir::Value operand_source = op->operand_source(0); |
| 132 | + std::string operand_source_id = pir::GetValueId(&operand_source); |
| 133 | + pir::OpResult res = op->result(0); |
| 134 | + std::string res_id = pir::GetValueId(&res); |
| 135 | + |
| 136 | + symbol::ShapeOrDataDimExprs shape_data; |
| 137 | + shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id]; |
| 138 | + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; |
| 139 | + return true; |
| 140 | +} |
| 141 | + |
| 142 | +bool ReshapeOpInferSymbolicShape( |
| 143 | + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 144 | + pir::Value operand_source_1 = op->operand_source(1); |
| 145 | + std::string operand_source_1_id = pir::GetValueId(&operand_source_1); |
| 146 | + pir::OpResult res = op->result(0); |
| 147 | + std::string res_id = pir::GetValueId(&res); |
| 148 | + |
| 149 | + symbol::ShapeOrDataDimExprs shape_data; |
| 150 | + |
| 151 | + shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id]; |
| 152 | + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; |
149 | 153 | return true;
|
150 | 154 | }
|
151 | 155 |
|
| 156 | +bool Reshape_OpInferSymbolicShape( |
| 157 | + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 158 | + return ReshapeOpInferSymbolicShape(op, shape_analysis); |
| 159 | +} |
| 160 | + |
152 | 161 | } // namespace paddle::dialect
|
| 162 | +namespace cinn::dialect { |
| 163 | + |
| 164 | +bool SliceOpInferSymbolicShape(pir::Operation *op, |
| 165 | + pir::ShapeConstraintIRAnalysis *shape_analysis) { |
| 166 | + pir::Value operand_source = op->operand_source(0); |
| 167 | + std::string operand_source_id = pir::GetValueId(&operand_source); |
| 168 | + pir::OpResult res = op->result(0); |
| 169 | + std::string res_id = pir::GetValueId(&res); |
| 170 | + |
| 171 | + std::vector<int64_t> dims = |
| 172 | + common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims()); |
| 173 | + |
| 174 | + std::vector<symbol::DimExpr> shapes; |
| 175 | + for (int64_t dim : dims) { |
| 176 | + symbol::DimExpr dim_expr; |
| 177 | + if (dim == -1) { |
| 178 | + symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); |
| 179 | + dim_expr = res_dim_expr; |
| 180 | + } else { |
| 181 | + symbol::DimExpr res_dim_expr(dim); |
| 182 | + dim_expr = res_dim_expr; |
| 183 | + } |
| 184 | + shapes.push_back(dim_expr); |
| 185 | + } |
| 186 | + |
| 187 | + // pir::AttributeMap attributes = op->attributes(); |
| 188 | + |
| 189 | + // auto attr_starts = |
| 190 | + // attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector(); |
| 191 | + // auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data(); |
| 192 | + |
| 193 | + // auto attr_ends = |
| 194 | + // attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector(); |
| 195 | + // auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data(); |
| 196 | + |
| 197 | + symbol::ShapeOrDataDimExprs shape_data{shapes}; |
| 198 | + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; |
| 199 | + return true; |
| 200 | +} |
| 201 | + |
| 202 | +} // namespace cinn::dialect |
| 203 | + |
153 | 204 | IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface)
|
0 commit comments