Skip to content

Commit 3fac9fe

Browse files
committed
add slice
1 parent 75fe3ae commit 3fac9fe

File tree

7 files changed

+163
-122
lines changed

7 files changed

+163
-122
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc

+130-116
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
16+
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
1617
#include "paddle/pir/core/builtin_attribute.h"
1718
#include "paddle/pir/core/builtin_type.h"
1819
#include "paddle/pir/dialect/shape/ir/shape_attribute.h"
@@ -25,113 +26,112 @@ bool InferSymbolicShapeInterface::InferSymbolicShape(
2526
}
2627
} // namespace paddle::dialect
2728

28-
namespace paddle::dialect {
29-
3029
namespace {
3130

32-
bool InferSymbolicShapeAllEqualUnary(
31+
bool SameOperandsAndResultShape(
3332
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
3433
pir::Value operand_source = op->operand_source(0);
35-
std::string operand_source_id = pir::GetValueId(&operand_source);
3634
pir::OpResult res = op->result(0);
37-
std::string res_id = pir::GetValueId(&res);
38-
shape_analysis->value_id_to_shapeordata_[res_id] =
39-
shape_analysis->value_id_to_shapeordata_[operand_source_id];
40-
return true;
41-
}
35+
symbol::ShapeOrDataDimExprs operand_shape_or_data =
36+
shape_analysis->value_to_shape_or_data_[operand_source];
37+
shape_analysis->value_to_shape_or_data_[res] = operand_shape_or_data;
4238

43-
bool InferSymbolicShapeAllEqualBinary(
44-
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
45-
pir::Value operand_source = op->operand_source(0);
46-
std::string operand_source_id = pir::GetValueId(&operand_source);
47-
pir::OpResult res = op->result(0);
48-
std::string res_id = pir::GetValueId(&res);
49-
shape_analysis->value_id_to_shapeordata_[res_id] =
50-
shape_analysis->value_id_to_shapeordata_[operand_source_id];
39+
op->set_attribute("symbolic_shape",
40+
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
41+
operand_shape_or_data));
5142
return true;
5243
}
5344

5445
} // namespace
5546

47+
namespace paddle::dialect {
5648
bool AbsOpInferSymbolicShape(pir::Operation *op,
5749
pir::ShapeConstraintIRAnalysis *shape_analysis) {
58-
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
50+
return SameOperandsAndResultShape(op, shape_analysis);
5951
}
6052

6153
bool Abs_OpInferSymbolicShape(pir::Operation *op,
6254
pir::ShapeConstraintIRAnalysis *shape_analysis) {
63-
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
55+
return SameOperandsAndResultShape(op, shape_analysis);
6456
}
6557

6658
bool DataOpInferSymbolicShape(pir::Operation *op,
6759
pir::ShapeConstraintIRAnalysis *shape_analysis) {
68-
symbol::ShapeOrDataDimExprs sss;
60+
auto attributes = op->attributes();
61+
pir::Attribute attr = attributes["shape"];
62+
std::vector<int64_t> dims =
63+
attr.dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData();
6964

65+
std::vector<symbol::DimExpr> sym_dims;
66+
for (auto dim : dims) {
67+
symbol::DimExpr dim_expr;
68+
if (dim == -1) {
69+
symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName());
70+
dim_expr = symbolic_dim_expr;
71+
} else {
72+
symbol::DimExpr numeric_dim_expr(dim);
73+
dim_expr = numeric_dim_expr;
74+
}
75+
sym_dims.push_back(dim_expr);
76+
}
77+
78+
symbol::ShapeOrDataDimExprs shape_data{sym_dims};
7079
op->set_attribute(
71-
"sym_shape",
72-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), sss));
80+
"symbolic_shape",
81+
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
82+
83+
pir::OpResult res = op->result(0);
84+
shape_analysis->value_to_shape_or_data_[res] = shape_data;
7385

74-
// auto attributes = op->attributes();
75-
// pir::Attribute attr = attributes["shape"];
76-
// const auto &vec = attr.dyn_cast<pir::ArrayAttribute>().AsVector();
7786
return true;
7887
}
7988

8089
bool CastOpInferSymbolicShape(pir::Operation *op,
8190
pir::ShapeConstraintIRAnalysis *shape_analysis) {
82-
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
91+
return SameOperandsAndResultShape(op, shape_analysis);
8392
}
8493

8594
bool Cast_OpInferSymbolicShape(pir::Operation *op,
8695
pir::ShapeConstraintIRAnalysis *shape_analysis) {
87-
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
96+
return SameOperandsAndResultShape(op, shape_analysis);
8897
}
8998

9099
bool ExpOpInferSymbolicShape(pir::Operation *op,
91100
pir::ShapeConstraintIRAnalysis *shape_analysis) {
92-
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
101+
return SameOperandsAndResultShape(op, shape_analysis);
93102
}
94103

95104
bool Exp_OpInferSymbolicShape(pir::Operation *op,
96105
pir::ShapeConstraintIRAnalysis *shape_analysis) {
97-
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
106+
return SameOperandsAndResultShape(op, shape_analysis);
98107
}
99108

100109
bool SubtractOpInferSymbolicShape(
101110
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
102-
return InferSymbolicShapeAllEqualBinary(op, shape_analysis);
111+
return SameOperandsAndResultShape(op, shape_analysis);
103112
}
104113

105114
bool Subtract_OpInferSymbolicShape(
106115
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
107-
return InferSymbolicShapeAllEqualBinary(op, shape_analysis);
116+
return SameOperandsAndResultShape(op, shape_analysis);
108117
}
109118

110119
bool ShapeOpInferSymbolicShape(pir::Operation *op,
111120
pir::ShapeConstraintIRAnalysis *shape_analysis) {
112121
pir::Value operand_source = op->operand_source(0);
113-
std::string operand_source_id = pir::GetValueId(&operand_source);
114122
pir::OpResult res = op->result(0);
115-
std::string res_id = pir::GetValueId(&res);
116123

117-
std::vector<int64_t> dims =
118-
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());
124+
symbol::ShapeOrDataDimExprs operand_shape_or_data =
125+
shape_analysis->value_to_shape_or_data_[operand_source];
119126

120-
std::vector<symbol::DimExpr> shapes;
121-
for (int64_t dim : dims) {
122-
symbol::DimExpr dim_expr;
123-
if (dim == -1) {
124-
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
125-
dim_expr = res_dim_expr;
126-
} else {
127-
symbol::DimExpr res_dim_expr(dim);
128-
dim_expr = res_dim_expr;
129-
}
130-
shapes.push_back(dim_expr);
131-
}
127+
symbol::ShapeOrDataDimExprs extend_shape_or_data =
128+
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(
129+
operand_shape_or_data);
132130

133-
symbol::ShapeOrDataDimExprs shape_data{shapes};
134-
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
131+
shape_analysis->value_to_shape_or_data_[res] = extend_shape_or_data;
132+
op->set_attribute("symbolic_shape",
133+
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
134+
extend_shape_or_data));
135135
return true;
136136
}
137137

@@ -147,8 +147,8 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
147147
pir::OpResult res = op->result(0);
148148
std::string res_id = pir::GetValueId(&res);
149149

150-
symbol::ShapeOrDataDimExprs shape_data;
151-
shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id];
150+
symbol::ShapeOrDataDimExprs shape_data =
151+
shape_analysis->value_id_to_shapeordata_[operand_source_id];
152152
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
153153
return true;
154154
}
@@ -160,9 +160,8 @@ bool ReshapeOpInferSymbolicShape(
160160
pir::OpResult res = op->result(0);
161161
std::string res_id = pir::GetValueId(&res);
162162

163-
symbol::ShapeOrDataDimExprs shape_data;
164-
165-
shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id];
163+
symbol::ShapeOrDataDimExprs shape_data =
164+
shape_analysis->value_id_to_shapeordata_[operand_source_1_id];
166165
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
167166
return true;
168167
}
@@ -174,82 +173,97 @@ bool Reshape_OpInferSymbolicShape(
174173

175174
bool FullIntArrayOpInferSymbolicShape(
176175
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
177-
for (auto &res : op->results()) {
178-
std::string value_id = pir::GetValueId(&res);
179-
std::vector<int64_t> dims =
180-
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());
181-
182-
std::vector<symbol::DimExpr> shapes;
183-
for (int64_t dim : dims) {
184-
symbol::DimExpr dim_expr;
185-
if (dim == -1) {
186-
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
187-
dim_expr = res_dim_expr;
188-
} else {
189-
symbol::DimExpr res_dim_expr(dim);
190-
dim_expr = res_dim_expr;
191-
}
192-
shapes.push_back(dim_expr);
193-
}
176+
auto attributes = op->attributes();
177+
pir::Attribute attr = attributes["value"];
178+
const auto &vec = attr.dyn_cast<pir::ArrayAttribute>().AsVector();
194179

195-
auto attributes = op->attributes();
196-
pir::Attribute attr = attributes["value"];
197-
const auto &vec = attr.dyn_cast<pir::ArrayAttribute>().AsVector();
180+
std::vector<symbol::DimExpr> sym_dims;
181+
sym_dims.push_back(symbol::DimExpr(std::int64_t(vec.size())));
182+
symbol::ShapeOrDataDimExprs shape_data{sym_dims};
198183

199-
for (auto item : vec) {
200-
int64_t i = item.dyn_cast<pir::Int64Attribute>().data();
201-
shapes.push_back(symbol::DimExpr(i));
202-
}
203-
204-
// for (auto &item : shapes) {
205-
// VLOG(0) << symbol::ToString(item);
206-
// }
184+
op->set_attribute(
185+
"symbolic_shape",
186+
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
207187

208-
symbol::ShapeOrDataDimExprs shape_data{shapes};
209-
shape_analysis->value_id_to_shapeordata_[value_id] = shape_data;
210-
return true;
211-
}
188+
pir::OpResult res = op->result(0);
189+
shape_analysis->value_to_shape_or_data_[res] = shape_data;
190+
return true;
212191
}
213192

214-
} // namespace paddle::dialect
215-
namespace cinn::dialect {
216-
217193
bool SliceOpInferSymbolicShape(pir::Operation *op,
218194
pir::ShapeConstraintIRAnalysis *shape_analysis) {
219195
pir::Value operand_source = op->operand_source(0);
220-
std::string operand_source_id = pir::GetValueId(&operand_source);
221-
pir::OpResult res = op->result(0);
222-
std::string res_id = pir::GetValueId(&res);
223196

224-
std::vector<int64_t> dims =
225-
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());
226-
227-
std::vector<symbol::DimExpr> shapes;
228-
for (int64_t dim : dims) {
229-
symbol::DimExpr dim_expr;
230-
if (dim == -1) {
231-
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
232-
dim_expr = res_dim_expr;
233-
} else {
234-
symbol::DimExpr res_dim_expr(dim);
235-
dim_expr = res_dim_expr;
236-
}
237-
shapes.push_back(dim_expr);
197+
symbol::ShapeOrDataDimExprs operand_shape_or_data =
198+
shape_analysis->value_to_shape_or_data_[operand_source];
199+
200+
pir::AttributeMap attributes = op->attributes();
201+
// auto attr_axes =
202+
// attributes["axes"].dyn_cast<pir::ArrayAttribute>().AsVector();
203+
// auto attr_infer_flags =
204+
// attributes["infer_flags"].dyn_cast<pir::ArrayAttribute>().AsVector();
205+
// auto attr_decrease_axis =
206+
// attributes["decrease_axis"].dyn_cast<pir::ArrayAttribute>().AsVector();
207+
208+
// std::vector<int64_t> new_axes;
209+
// for (size_t i = 0; i < attr_axes.size(); ++i) {
210+
// if (attr_axes[i].dyn_cast<pir::Int64Attribute>().data() < 0) {
211+
// new_axes.push_back(
212+
// std::max(int64_t(0),
213+
// attr_axes[i].dyn_cast<pir::Int64Attribute>().data() +
214+
// int64_t(operand_shape_or_data.size())));
215+
// } else {
216+
// new_axes.push_back(attr_axes[i].dyn_cast<pir::Int64Attribute>().data());
217+
// }
218+
// }
219+
220+
// Special case.
221+
std::vector<int64_t> starts =
222+
attributes.at("starts")
223+
.dyn_cast<paddle::dialect::IntArrayAttribute>()
224+
.data()
225+
.GetData();
226+
int64_t start = starts[0];
227+
std::vector<symbol::DimExpr> out_dims;
228+
if (operand_shape_or_data.data().has_value()) {
229+
out_dims.push_back(operand_shape_or_data.data().value()[start]);
230+
} else {
231+
out_dims.push_back(operand_shape_or_data.shape()[start]);
238232
}
239233

240-
// pir::AttributeMap attributes = op->attributes();
234+
// Note(zhangbopd): Currently we do not consider the case that the
235+
// new_axes/attr_starts/attr_ends etc. are symoblic.
236+
// CheckAndUpdateSliceAttrs(operand_shape_or_data,
237+
// new_axes,
238+
// &attr_starts,
239+
// &attr_ends,
240+
// &attr_infer_flags);
241+
// auto slice_dims = GetSliceDims(operand_shape_or_data,
242+
// new_axes,
243+
// attr_starts,
244+
// attr_ends,
245+
// &attr_infer_flags);
246+
// std::vector<symbol::DimExpr> out_dims =
247+
// GetDecreasedDims(slice_dims, attr_decrease_axis);
248+
249+
symbol::ShapeOrDataDimExprs shape_data{out_dims};
250+
// unknown Attribute
251+
// op->set_attribute(
252+
// "symbolic_shape",
253+
// pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
254+
// shape_data));
241255

242-
// auto attr_starts =
243-
// attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector();
244-
// auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data();
256+
pir::OpResult res = op->result(0);
257+
shape_analysis->value_to_shape_or_data_[res] = shape_data;
258+
return true;
259+
}
245260

246-
// auto attr_ends =
247-
// attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector();
248-
// auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data();
261+
} // namespace paddle::dialect
262+
namespace cinn::dialect {
249263

250-
symbol::ShapeOrDataDimExprs shape_data{shapes};
251-
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
252-
return true;
264+
bool SliceOpInferSymbolicShape(pir::Operation *op,
265+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
266+
return paddle::dialect::SliceOpInferSymbolicShape(op, shape_analysis);
253267
}
254268

255269
} // namespace cinn::dialect

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h

+3
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ bool Reshape_OpInferSymbolicShape(
107107
bool FullIntArrayOpInferSymbolicShape(
108108
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);
109109

110+
bool SliceOpInferSymbolicShape(pir::Operation *op,
111+
pir::ShapeConstraintIRAnalysis *shape_analysis);
112+
110113
} // namespace paddle::dialect
111114

112115
namespace cinn::dialect {

paddle/fluid/pir/dialect/operator/ir/ops.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,7 @@
11561156
kernel :
11571157
func : slice
11581158
backward : slice_grad
1159+
interfaces : paddle::dialect::InferSymbolicShapeInterface
11591160

11601161
- op : soft_relu
11611162
args : (Tensor x, float threshold = 20.0f)

paddle/fluid/pir/transforms/shape_optimization_pass.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -442,14 +442,13 @@ void DebugPrintOpInfo(
442442
pir::Operation* op,
443443
pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) {
444444
for (auto& res : op->results()) {
445-
auto value_id = pir::GetValueId(&res);
446445
std::ostringstream print_stream;
447446

448447
print_stream << "result(" << res.index() << ") "
449448
<< "ShapeOrData: ";
450449

451450
if (shape_analysis != nullptr) {
452-
auto shape_data = shape_analysis->value_id_to_shapeordata_[value_id];
451+
auto shape_data = shape_analysis->value_to_shape_or_data_[res];
453452
print_stream << "shape: [";
454453

455454
for (auto str : shape_data.shape()) {

0 commit comments

Comments
 (0)