13
13
// limitations under the License.
14
14
15
15
#include " paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
16
+ #include " paddle/pir/core/builtin_attribute.h"
17
+ #include " paddle/pir/core/builtin_type.h"
16
18
#include " paddle/pir/dialect/shape/ir/shape_op.h"
17
19
18
20
namespace paddle ::dialect {
@@ -37,6 +39,18 @@ bool InferSymbolicShapeAllEqualUnary(
37
39
shape_analysis->value_id_to_shapeordata_ [operand_source_id];
38
40
return true ;
39
41
}
42
+
43
+ bool InferSymbolicShapeAllEqualBinary (
44
+ pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
45
+ pir::Value operand_source = op->operand_source (0 );
46
+ std::string operand_source_id = pir::GetValueId (&operand_source);
47
+ pir::OpResult res = op->result (0 );
48
+ std::string res_id = pir::GetValueId (&res);
49
+ shape_analysis->value_id_to_shapeordata_ [res_id] =
50
+ shape_analysis->value_id_to_shapeordata_ [operand_source_id];
51
+ return true ;
52
+ }
53
+
40
54
} // namespace
41
55
42
56
bool AbsOpInferSymbolicShape (pir::Operation *op,
@@ -69,15 +83,121 @@ bool Exp_OpInferSymbolicShape(pir::Operation *op,
69
83
return InferSymbolicShapeAllEqualUnary (op, shape_analysis);
70
84
}
71
85
72
- bool TransposeOpInferSymbolicShape (
86
+ bool SubtractOpInferSymbolicShape (
73
87
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
88
+ return InferSymbolicShapeAllEqualBinary (op, shape_analysis);
89
+ }
90
+
91
+ bool Subtract_OpInferSymbolicShape (
92
+ pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
93
+ return InferSymbolicShapeAllEqualBinary (op, shape_analysis);
94
+ }
95
+
96
+ bool ShapeOpInferSymbolicShape (pir::Operation *op,
97
+ pir::ShapeConstraintIRAnalysis *shape_analysis) {
98
+ pir::Value operand_source = op->operand_source (0 );
99
+ std::string operand_source_id = pir::GetValueId (&operand_source);
100
+ pir::OpResult res = op->result (0 );
101
+ std::string res_id = pir::GetValueId (&res);
102
+
103
+ std::vector<int64_t > dims =
104
+ common::vectorize (res.type ().dyn_cast <pir::DenseTensorType>().dims ());
105
+
106
+ std::vector<symbol::DimExpr> shapes;
107
+ for (int64_t dim : dims) {
108
+ symbol::DimExpr dim_expr;
109
+ if (dim == -1 ) {
110
+ symbol::DimExpr res_dim_expr (shape_analysis->GetNextSymName ());
111
+ dim_expr = res_dim_expr;
112
+ } else {
113
+ symbol::DimExpr res_dim_expr (dim);
114
+ dim_expr = res_dim_expr;
115
+ }
116
+ shapes.push_back (dim_expr);
117
+ }
118
+
119
+ symbol::ShapeOrDataDimExprs shape_data{shapes};
120
+ shape_analysis->value_id_to_shapeordata_ [res_id] = shape_data;
121
+ }
122
+
123
+ bool ShapeSrOpInferSymbolicShape (
124
+ pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
125
+ return ShapeOpInferSymbolicShape (op, shape_analysis);
126
+ }
127
+
128
+ bool StackOpInferSymbolicShape (pir::Operation *op,
129
+ pir::ShapeConstraintIRAnalysis *shape_analysis) {
130
+ pir::Value operand_source = op->operand_source (0 );
131
+ std::string operand_source_id = pir::GetValueId (&operand_source);
132
+ pir::OpResult res = op->result (0 );
133
+ std::string res_id = pir::GetValueId (&res);
134
+
135
+ symbol::ShapeOrDataDimExprs shape_data;
136
+ shape_data = shape_analysis->value_id_to_shapeordata_ [operand_source_id];
137
+ shape_analysis->value_id_to_shapeordata_ [res_id] = shape_data;
74
138
return true ;
75
139
}
76
140
77
- bool ConcatOpInferSymbolicShape (
141
+ bool ReshapeOpInferSymbolicShape (
78
142
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
143
+ pir::Value operand_source_1 = op->operand_source (1 );
144
+ std::string operand_source_1_id = pir::GetValueId (&operand_source_1);
145
+ pir::OpResult res = op->result (0 );
146
+ std::string res_id = pir::GetValueId (&res);
147
+
148
+ symbol::ShapeOrDataDimExprs shape_data;
149
+
150
+ shape_data = shape_analysis->value_id_to_shapeordata_ [operand_source_1_id];
151
+ shape_analysis->value_id_to_shapeordata_ [res_id] = shape_data;
79
152
return true ;
80
153
}
81
154
155
+ bool Reshape_OpInferSymbolicShape (
156
+ pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
157
+ return ReshapeOpInferSymbolicShape (op, shape_analysis);
158
+ }
159
+
82
160
} // namespace paddle::dialect
161
+ namespace cinn ::dialect {
162
+
163
+ bool SliceOpInferSymbolicShape (pir::Operation *op,
164
+ pir::ShapeConstraintIRAnalysis *shape_analysis) {
165
+ pir::Value operand_source = op->operand_source (0 );
166
+ std::string operand_source_id = pir::GetValueId (&operand_source);
167
+ pir::OpResult res = op->result (0 );
168
+ std::string res_id = pir::GetValueId (&res);
169
+
170
+ std::vector<int64_t > dims =
171
+ common::vectorize (res.type ().dyn_cast <pir::DenseTensorType>().dims ());
172
+
173
+ std::vector<symbol::DimExpr> shapes;
174
+ for (int64_t dim : dims) {
175
+ symbol::DimExpr dim_expr;
176
+ if (dim == -1 ) {
177
+ symbol::DimExpr res_dim_expr (shape_analysis->GetNextSymName ());
178
+ dim_expr = res_dim_expr;
179
+ } else {
180
+ symbol::DimExpr res_dim_expr (dim);
181
+ dim_expr = res_dim_expr;
182
+ }
183
+ shapes.push_back (dim_expr);
184
+ }
185
+
186
+ pir::AttributeMap attributes = op->attributes ();
187
+
188
+ auto attr_starts =
189
+ attributes[" starts" ].dyn_cast <pir::ArrayAttribute>().AsVector ();
190
+ auto start = attr_starts[0 ].dyn_cast <pir::Int64Attribute>().data ();
191
+
192
+ auto attr_ends =
193
+ attributes[" ends" ].dyn_cast <pir::ArrayAttribute>().AsVector ();
194
+ auto end = attr_ends[0 ].dyn_cast <pir::Int64Attribute>().data ();
195
+
196
+ symbol::ShapeOrDataDimExprs shape_data{shapes};
197
+ shape_analysis->value_id_to_shapeordata_ [res_id] = shape_data;
198
+ return true ;
199
+ }
200
+
201
+ } // namespace cinn::dialect
202
+
83
203
IR_DEFINE_EXPLICIT_TYPE_ID (paddle::dialect::InferSymbolicShapeInterface)
0 commit comments