13
13
// limitations under the License.
14
14
15
15
#include " paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
16
+ #include " paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
16
17
#include " paddle/pir/core/builtin_attribute.h"
17
18
#include " paddle/pir/core/builtin_type.h"
18
19
#include " paddle/pir/dialect/shape/ir/shape_attribute.h"
@@ -25,113 +26,112 @@ bool InferSymbolicShapeInterface::InferSymbolicShape(
25
26
}
26
27
} // namespace paddle::dialect
27
28
28
- namespace paddle ::dialect {
29
-
30
29
namespace {
31
30
32
- bool InferSymbolicShapeAllEqualUnary (
31
+ bool SameOperandsAndResultShape (
33
32
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
34
33
pir::Value operand_source = op->operand_source (0 );
35
- std::string operand_source_id = pir::GetValueId (&operand_source);
36
34
pir::OpResult res = op->result (0 );
37
- std::string res_id = pir::GetValueId (&res);
38
- shape_analysis->value_id_to_shapeordata_ [res_id] =
39
- shape_analysis->value_id_to_shapeordata_ [operand_source_id];
40
- return true ;
41
- }
35
+ symbol::ShapeOrDataDimExprs operand_shape_or_data =
36
+ shape_analysis->value_to_shape_or_data_ [operand_source];
37
+ shape_analysis->value_to_shape_or_data_ [res] = operand_shape_or_data;
42
38
43
- bool InferSymbolicShapeAllEqualBinary (
44
- pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
45
- pir::Value operand_source = op->operand_source (0 );
46
- std::string operand_source_id = pir::GetValueId (&operand_source);
47
- pir::OpResult res = op->result (0 );
48
- std::string res_id = pir::GetValueId (&res);
49
- shape_analysis->value_id_to_shapeordata_ [res_id] =
50
- shape_analysis->value_id_to_shapeordata_ [operand_source_id];
39
+ op->set_attribute (" symbolic_shape" ,
40
+ pir::shape::SymbolAttribute::get (pir::IrContext::Instance (),
41
+ operand_shape_or_data));
51
42
return true ;
52
43
}
53
44
54
45
} // namespace
55
46
47
+ namespace paddle ::dialect {
56
48
bool AbsOpInferSymbolicShape (pir::Operation *op,
57
49
pir::ShapeConstraintIRAnalysis *shape_analysis) {
58
- return InferSymbolicShapeAllEqualUnary (op, shape_analysis);
50
+ return SameOperandsAndResultShape (op, shape_analysis);
59
51
}
60
52
61
53
bool Abs_OpInferSymbolicShape (pir::Operation *op,
62
54
pir::ShapeConstraintIRAnalysis *shape_analysis) {
63
- return InferSymbolicShapeAllEqualUnary (op, shape_analysis);
55
+ return SameOperandsAndResultShape (op, shape_analysis);
64
56
}
65
57
66
58
bool DataOpInferSymbolicShape (pir::Operation *op,
67
59
pir::ShapeConstraintIRAnalysis *shape_analysis) {
68
- symbol::ShapeOrDataDimExprs sss;
60
+ auto attributes = op->attributes ();
61
+ pir::Attribute attr = attributes[" shape" ];
62
+ std::vector<int64_t > dims =
63
+ attr.dyn_cast <paddle::dialect::IntArrayAttribute>().data ().GetData ();
69
64
65
+ std::vector<symbol::DimExpr> sym_dims;
66
+ for (auto dim : dims) {
67
+ symbol::DimExpr dim_expr;
68
+ if (dim == -1 ) {
69
+ symbol::DimExpr symbolic_dim_expr (shape_analysis->GetNextSymName ());
70
+ dim_expr = symbolic_dim_expr;
71
+ } else {
72
+ symbol::DimExpr numeric_dim_expr (dim);
73
+ dim_expr = numeric_dim_expr;
74
+ }
75
+ sym_dims.push_back (dim_expr);
76
+ }
77
+
78
+ symbol::ShapeOrDataDimExprs shape_data{sym_dims};
70
79
op->set_attribute (
71
- " sym_shape" ,
72
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), sss));
80
+ " symbolic_shape" ,
81
+ pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
82
+
83
+ pir::OpResult res = op->result (0 );
84
+ shape_analysis->value_to_shape_or_data_ [res] = shape_data;
73
85
74
- // auto attributes = op->attributes();
75
- // pir::Attribute attr = attributes["shape"];
76
- // const auto &vec = attr.dyn_cast<pir::ArrayAttribute>().AsVector();
77
86
return true ;
78
87
}
79
88
80
89
bool CastOpInferSymbolicShape (pir::Operation *op,
81
90
pir::ShapeConstraintIRAnalysis *shape_analysis) {
82
- return InferSymbolicShapeAllEqualUnary (op, shape_analysis);
91
+ return SameOperandsAndResultShape (op, shape_analysis);
83
92
}
84
93
85
94
bool Cast_OpInferSymbolicShape (pir::Operation *op,
86
95
pir::ShapeConstraintIRAnalysis *shape_analysis) {
87
- return InferSymbolicShapeAllEqualUnary (op, shape_analysis);
96
+ return SameOperandsAndResultShape (op, shape_analysis);
88
97
}
89
98
90
99
bool ExpOpInferSymbolicShape (pir::Operation *op,
91
100
pir::ShapeConstraintIRAnalysis *shape_analysis) {
92
- return InferSymbolicShapeAllEqualUnary (op, shape_analysis);
101
+ return SameOperandsAndResultShape (op, shape_analysis);
93
102
}
94
103
95
104
bool Exp_OpInferSymbolicShape (pir::Operation *op,
96
105
pir::ShapeConstraintIRAnalysis *shape_analysis) {
97
- return InferSymbolicShapeAllEqualUnary (op, shape_analysis);
106
+ return SameOperandsAndResultShape (op, shape_analysis);
98
107
}
99
108
100
109
bool SubtractOpInferSymbolicShape (
101
110
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
102
- return InferSymbolicShapeAllEqualBinary (op, shape_analysis);
111
+ return SameOperandsAndResultShape (op, shape_analysis);
103
112
}
104
113
105
114
bool Subtract_OpInferSymbolicShape (
106
115
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
107
- return InferSymbolicShapeAllEqualBinary (op, shape_analysis);
116
+ return SameOperandsAndResultShape (op, shape_analysis);
108
117
}
109
118
110
119
bool ShapeOpInferSymbolicShape (pir::Operation *op,
111
120
pir::ShapeConstraintIRAnalysis *shape_analysis) {
112
121
pir::Value operand_source = op->operand_source (0 );
113
- std::string operand_source_id = pir::GetValueId (&operand_source);
114
122
pir::OpResult res = op->result (0 );
115
- std::string res_id = pir::GetValueId (&res);
116
123
117
- std::vector< int64_t > dims =
118
- common::vectorize (res. type (). dyn_cast <pir::DenseTensorType>(). dims ()) ;
124
+ symbol::ShapeOrDataDimExprs operand_shape_or_data =
125
+ shape_analysis-> value_to_shape_or_data_ [operand_source] ;
119
126
120
- std::vector<symbol::DimExpr> shapes;
121
- for (int64_t dim : dims) {
122
- symbol::DimExpr dim_expr;
123
- if (dim == -1 ) {
124
- symbol::DimExpr res_dim_expr (shape_analysis->GetNextSymName ());
125
- dim_expr = res_dim_expr;
126
- } else {
127
- symbol::DimExpr res_dim_expr (dim);
128
- dim_expr = res_dim_expr;
129
- }
130
- shapes.push_back (dim_expr);
131
- }
127
+ symbol::ShapeOrDataDimExprs extend_shape_or_data =
128
+ symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData (
129
+ operand_shape_or_data);
132
130
133
- symbol::ShapeOrDataDimExprs shape_data{shapes};
134
- shape_analysis->value_id_to_shapeordata_ [res_id] = shape_data;
131
+ shape_analysis->value_to_shape_or_data_ [res] = extend_shape_or_data;
132
+ op->set_attribute (" symbolic_shape" ,
133
+ pir::shape::SymbolAttribute::get (pir::IrContext::Instance (),
134
+ extend_shape_or_data));
135
135
return true ;
136
136
}
137
137
@@ -147,8 +147,8 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
147
147
pir::OpResult res = op->result (0 );
148
148
std::string res_id = pir::GetValueId (&res);
149
149
150
- symbol::ShapeOrDataDimExprs shape_data;
151
- shape_data = shape_analysis->value_id_to_shapeordata_ [operand_source_id];
150
+ symbol::ShapeOrDataDimExprs shape_data =
151
+ shape_analysis->value_id_to_shapeordata_ [operand_source_id];
152
152
shape_analysis->value_id_to_shapeordata_ [res_id] = shape_data;
153
153
return true ;
154
154
}
@@ -160,9 +160,8 @@ bool ReshapeOpInferSymbolicShape(
160
160
pir::OpResult res = op->result (0 );
161
161
std::string res_id = pir::GetValueId (&res);
162
162
163
- symbol::ShapeOrDataDimExprs shape_data;
164
-
165
- shape_data = shape_analysis->value_id_to_shapeordata_ [operand_source_1_id];
163
+ symbol::ShapeOrDataDimExprs shape_data =
164
+ shape_analysis->value_id_to_shapeordata_ [operand_source_1_id];
166
165
shape_analysis->value_id_to_shapeordata_ [res_id] = shape_data;
167
166
return true ;
168
167
}
@@ -174,82 +173,97 @@ bool Reshape_OpInferSymbolicShape(
174
173
175
174
bool FullIntArrayOpInferSymbolicShape (
176
175
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
177
- for (auto &res : op->results ()) {
178
- std::string value_id = pir::GetValueId (&res);
179
- std::vector<int64_t > dims =
180
- common::vectorize (res.type ().dyn_cast <pir::DenseTensorType>().dims ());
181
-
182
- std::vector<symbol::DimExpr> shapes;
183
- for (int64_t dim : dims) {
184
- symbol::DimExpr dim_expr;
185
- if (dim == -1 ) {
186
- symbol::DimExpr res_dim_expr (shape_analysis->GetNextSymName ());
187
- dim_expr = res_dim_expr;
188
- } else {
189
- symbol::DimExpr res_dim_expr (dim);
190
- dim_expr = res_dim_expr;
191
- }
192
- shapes.push_back (dim_expr);
193
- }
176
+ auto attributes = op->attributes ();
177
+ pir::Attribute attr = attributes[" value" ];
178
+ const auto &vec = attr.dyn_cast <pir::ArrayAttribute>().AsVector ();
194
179
195
- auto attributes = op-> attributes () ;
196
- pir::Attribute attr = attributes[ " value " ] ;
197
- const auto &vec = attr. dyn_cast <pir::ArrayAttribute>(). AsVector () ;
180
+ std::vector<symbol::DimExpr> sym_dims ;
181
+ sym_dims. push_back ( symbol::DimExpr ( std::int64_t (vec. size ()))) ;
182
+ symbol::ShapeOrDataDimExprs shape_data{sym_dims} ;
198
183
199
- for (auto item : vec) {
200
- int64_t i = item.dyn_cast <pir::Int64Attribute>().data ();
201
- shapes.push_back (symbol::DimExpr (i));
202
- }
203
-
204
- // for (auto &item : shapes) {
205
- // VLOG(0) << symbol::ToString(item);
206
- // }
184
+ op->set_attribute (
185
+ " symbolic_shape" ,
186
+ pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
207
187
208
- symbol::ShapeOrDataDimExprs shape_data{shapes};
209
- shape_analysis->value_id_to_shapeordata_ [value_id] = shape_data;
210
- return true ;
211
- }
188
+ pir::OpResult res = op->result (0 );
189
+ shape_analysis->value_to_shape_or_data_ [res] = shape_data;
190
+ return true ;
212
191
}
213
192
214
- } // namespace paddle::dialect
215
- namespace cinn ::dialect {
216
-
217
193
bool SliceOpInferSymbolicShape (pir::Operation *op,
218
194
pir::ShapeConstraintIRAnalysis *shape_analysis) {
219
195
pir::Value operand_source = op->operand_source (0 );
220
- std::string operand_source_id = pir::GetValueId (&operand_source);
221
- pir::OpResult res = op->result (0 );
222
- std::string res_id = pir::GetValueId (&res);
223
196
224
- std::vector<int64_t > dims =
225
- common::vectorize (res.type ().dyn_cast <pir::DenseTensorType>().dims ());
226
-
227
- std::vector<symbol::DimExpr> shapes;
228
- for (int64_t dim : dims) {
229
- symbol::DimExpr dim_expr;
230
- if (dim == -1 ) {
231
- symbol::DimExpr res_dim_expr (shape_analysis->GetNextSymName ());
232
- dim_expr = res_dim_expr;
233
- } else {
234
- symbol::DimExpr res_dim_expr (dim);
235
- dim_expr = res_dim_expr;
236
- }
237
- shapes.push_back (dim_expr);
197
+ symbol::ShapeOrDataDimExprs operand_shape_or_data =
198
+ shape_analysis->value_to_shape_or_data_ [operand_source];
199
+
200
+ pir::AttributeMap attributes = op->attributes ();
201
+ // auto attr_axes =
202
+ // attributes["axes"].dyn_cast<pir::ArrayAttribute>().AsVector();
203
+ // auto attr_infer_flags =
204
+ // attributes["infer_flags"].dyn_cast<pir::ArrayAttribute>().AsVector();
205
+ // auto attr_decrease_axis =
206
+ // attributes["decrease_axis"].dyn_cast<pir::ArrayAttribute>().AsVector();
207
+
208
+ // std::vector<int64_t> new_axes;
209
+ // for (size_t i = 0; i < attr_axes.size(); ++i) {
210
+ // if (attr_axes[i].dyn_cast<pir::Int64Attribute>().data() < 0) {
211
+ // new_axes.push_back(
212
+ // std::max(int64_t(0),
213
+ // attr_axes[i].dyn_cast<pir::Int64Attribute>().data() +
214
+ // int64_t(operand_shape_or_data.size())));
215
+ // } else {
216
+ // new_axes.push_back(attr_axes[i].dyn_cast<pir::Int64Attribute>().data());
217
+ // }
218
+ // }
219
+
220
+ // Special case.
221
+ std::vector<int64_t > starts =
222
+ attributes.at (" starts" )
223
+ .dyn_cast <paddle::dialect::IntArrayAttribute>()
224
+ .data ()
225
+ .GetData ();
226
+ int64_t start = starts[0 ];
227
+ std::vector<symbol::DimExpr> out_dims;
228
+ if (operand_shape_or_data.data ().has_value ()) {
229
+ out_dims.push_back (operand_shape_or_data.data ().value ()[start]);
230
+ } else {
231
+ out_dims.push_back (operand_shape_or_data.shape ()[start]);
238
232
}
239
233
240
- // pir::AttributeMap attributes = op->attributes();
234
+ // Note(zhangbopd): Currently we do not consider the case that the
235
+ // new_axes/attr_starts/attr_ends etc. are symoblic.
236
+ // CheckAndUpdateSliceAttrs(operand_shape_or_data,
237
+ // new_axes,
238
+ // &attr_starts,
239
+ // &attr_ends,
240
+ // &attr_infer_flags);
241
+ // auto slice_dims = GetSliceDims(operand_shape_or_data,
242
+ // new_axes,
243
+ // attr_starts,
244
+ // attr_ends,
245
+ // &attr_infer_flags);
246
+ // std::vector<symbol::DimExpr> out_dims =
247
+ // GetDecreasedDims(slice_dims, attr_decrease_axis);
248
+
249
+ symbol::ShapeOrDataDimExprs shape_data{out_dims};
250
+ // unknown Attribute
251
+ // op->set_attribute(
252
+ // "symbolic_shape",
253
+ // pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
254
+ // shape_data));
241
255
242
- // auto attr_starts =
243
- // attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector();
244
- // auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data();
256
+ pir::OpResult res = op->result (0 );
257
+ shape_analysis->value_to_shape_or_data_ [res] = shape_data;
258
+ return true ;
259
+ }
245
260
246
- // auto attr_ends =
247
- // attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector();
248
- // auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data();
261
+ } // namespace paddle::dialect
262
+ namespace cinn ::dialect {
249
263
250
- symbol::ShapeOrDataDimExprs shape_data{shapes};
251
- shape_analysis-> value_id_to_shapeordata_ [res_id] = shape_data;
252
- return true ;
264
+ bool SliceOpInferSymbolicShape (pir::Operation *op,
265
+ pir::ShapeConstraintIRAnalysis *shape_analysis) {
266
+ return paddle::dialect::SliceOpInferSymbolicShape (op, shape_analysis) ;
253
267
}
254
268
255
269
} // namespace cinn::dialect
0 commit comments