Skip to content

Commit 634161b

Browse files
committed
add basic op for reshape
1 parent 40ac433 commit 634161b

File tree

7 files changed

+197
-168
lines changed

7 files changed

+197
-168
lines changed

paddle/cinn/hlir/dialect/operator/ir/ops.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
func : SliceRawInferMeta
7575
kernel :
7676
func : slice
77+
interfaces : paddle::dialect::InferSymbolicShapeInterface
7778

7879
- op : uniform_random
7980
args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0)

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc

+122-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
16+
#include "paddle/pir/core/builtin_attribute.h"
17+
#include "paddle/pir/core/builtin_type.h"
1618
#include "paddle/pir/dialect/shape/ir/shape_op.h"
1719

1820
namespace paddle::dialect {
@@ -37,6 +39,18 @@ bool InferSymbolicShapeAllEqualUnary(
3739
shape_analysis->value_id_to_shapeordata_[operand_source_id];
3840
return true;
3941
}
42+
43+
bool InferSymbolicShapeAllEqualBinary(
44+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
45+
pir::Value operand_source = op->operand_source(0);
46+
std::string operand_source_id = pir::GetValueId(&operand_source);
47+
pir::OpResult res = op->result(0);
48+
std::string res_id = pir::GetValueId(&res);
49+
shape_analysis->value_id_to_shapeordata_[res_id] =
50+
shape_analysis->value_id_to_shapeordata_[operand_source_id];
51+
return true;
52+
}
53+
4054
} // namespace
4155

4256
bool AbsOpInferSymbolicShape(pir::Operation *op,
@@ -69,15 +83,121 @@ bool Exp_OpInferSymbolicShape(pir::Operation *op,
6983
return InferSymbolicShapeAllEqualUnary(op, shape_analysis);
7084
}
7185

72-
bool TransposeOpInferSymbolicShape(
86+
bool SubtractOpInferSymbolicShape(
7387
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
88+
return InferSymbolicShapeAllEqualBinary(op, shape_analysis);
89+
}
90+
91+
bool Subtract_OpInferSymbolicShape(
92+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
93+
return InferSymbolicShapeAllEqualBinary(op, shape_analysis);
94+
}
95+
96+
bool ShapeOpInferSymbolicShape(pir::Operation *op,
97+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
98+
pir::Value operand_source = op->operand_source(0);
99+
std::string operand_source_id = pir::GetValueId(&operand_source);
100+
pir::OpResult res = op->result(0);
101+
std::string res_id = pir::GetValueId(&res);
102+
103+
std::vector<int64_t> dims =
104+
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());
105+
106+
std::vector<symbol::DimExpr> shapes;
107+
for (int64_t dim : dims) {
108+
symbol::DimExpr dim_expr;
109+
if (dim == -1) {
110+
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
111+
dim_expr = res_dim_expr;
112+
} else {
113+
symbol::DimExpr res_dim_expr(dim);
114+
dim_expr = res_dim_expr;
115+
}
116+
shapes.push_back(dim_expr);
117+
}
118+
119+
symbol::ShapeOrDataDimExprs shape_data{shapes};
120+
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
121+
}
122+
123+
bool ShapeSrOpInferSymbolicShape(
124+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
125+
return ShapeOpInferSymbolicShape(op, shape_analysis);
126+
}
127+
128+
bool StackOpInferSymbolicShape(pir::Operation *op,
129+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
130+
pir::Value operand_source = op->operand_source(0);
131+
std::string operand_source_id = pir::GetValueId(&operand_source);
132+
pir::OpResult res = op->result(0);
133+
std::string res_id = pir::GetValueId(&res);
134+
135+
symbol::ShapeOrDataDimExprs shape_data;
136+
shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id];
137+
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
74138
return true;
75139
}
76140

77-
bool ConcatOpInferSymbolicShape(
141+
bool ReshapeOpInferSymbolicShape(
78142
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
143+
pir::Value operand_source_1 = op->operand_source(1);
144+
std::string operand_source_1_id = pir::GetValueId(&operand_source_1);
145+
pir::OpResult res = op->result(0);
146+
std::string res_id = pir::GetValueId(&res);
147+
148+
symbol::ShapeOrDataDimExprs shape_data;
149+
150+
shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id];
151+
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
79152
return true;
80153
}
81154

155+
bool Reshape_OpInferSymbolicShape(
156+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
157+
return ReshapeOpInferSymbolicShape(op, shape_analysis);
158+
}
159+
82160
} // namespace paddle::dialect
161+
namespace cinn::dialect {
162+
163+
bool SliceOpInferSymbolicShape(pir::Operation *op,
164+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
165+
pir::Value operand_source = op->operand_source(0);
166+
std::string operand_source_id = pir::GetValueId(&operand_source);
167+
pir::OpResult res = op->result(0);
168+
std::string res_id = pir::GetValueId(&res);
169+
170+
std::vector<int64_t> dims =
171+
common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());
172+
173+
std::vector<symbol::DimExpr> shapes;
174+
for (int64_t dim : dims) {
175+
symbol::DimExpr dim_expr;
176+
if (dim == -1) {
177+
symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName());
178+
dim_expr = res_dim_expr;
179+
} else {
180+
symbol::DimExpr res_dim_expr(dim);
181+
dim_expr = res_dim_expr;
182+
}
183+
shapes.push_back(dim_expr);
184+
}
185+
186+
pir::AttributeMap attributes = op->attributes();
187+
188+
auto attr_starts =
189+
attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector();
190+
auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data();
191+
192+
auto attr_ends =
193+
attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector();
194+
auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data();
195+
196+
symbol::ShapeOrDataDimExprs shape_data{shapes};
197+
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
198+
return true;
199+
}
200+
201+
} // namespace cinn::dialect
202+
83203
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface)

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h

+32
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ class InferSymbolicShapeInterface
5858
Concept *impl_;
5959
};
6060

61+
} // namespace paddle::dialect
62+
63+
namespace paddle::dialect {
64+
6165
bool AbsOpInferSymbolicShape(pir::Operation *op,
6266
pir::ShapeConstraintIRAnalysis *shape_analysis);
6367

@@ -76,6 +80,34 @@ bool ExpOpInferSymbolicShape(pir::Operation *op,
7680
bool Exp_OpInferSymbolicShape(pir::Operation *op,
7781
pir::ShapeConstraintIRAnalysis *shape_analysis);
7882

83+
bool SubtractOpInferSymbolicShape(
84+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);
85+
86+
bool Subtract_OpInferSymbolicShape(
87+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);
88+
89+
bool ShapeOpInferSymbolicShape(pir::Operation *op,
90+
pir::ShapeConstraintIRAnalysis *shape_analysis);
91+
92+
bool ShapeSrOpInferSymbolicShape(
93+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);
94+
95+
bool StackOpInferSymbolicShape(pir::Operation *op,
96+
pir::ShapeConstraintIRAnalysis *shape_analysis);
97+
98+
bool ReshapeOpInferSymbolicShape(
99+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);
100+
101+
bool Reshape_OpInferSymbolicShape(
102+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);
103+
79104
} // namespace paddle::dialect
80105

106+
namespace cinn::dialect {
107+
108+
bool SliceOpInferSymbolicShape(pir::Operation *op,
109+
pir::ShapeConstraintIRAnalysis *shape_analysis);
110+
111+
}
112+
81113
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface)

paddle/fluid/pir/dialect/operator/ir/op_dialect.cc

+31
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,31 @@
2929
namespace paddle {
3030
namespace dialect {
3131

32+
struct CombineOpInferSymbolicShapeInterfaceModel
33+
: public InferSymbolicShapeInterface::Concept {
34+
static inline bool InferSymbolicShape(
35+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
36+
symbol::ShapeOrDataDimExprs value_shape;
37+
38+
// for (auto operand_source : op->operands_source()) {
39+
// std::string operand_source_id = pir::GetValueId(&operand_source);
40+
// auto source_shape_vec =
41+
// shape_analysis->value_id_to_shapeordata_[operand_source_id];
42+
// for (int i = 0; i < source_shape_vec.size(); i++) {
43+
// value_shape.second.emplace_back(source_shape_vec[i]);
44+
// }
45+
// }
46+
47+
auto res = op->result(0);
48+
auto res_id = pir::GetValueId(&res);
49+
50+
shape_analysis->value_id_to_shapeordata_[res_id] = value_shape;
51+
}
52+
53+
CombineOpInferSymbolicShapeInterfaceModel()
54+
: InferSymbolicShapeInterface::Concept(InferSymbolicShape) {}
55+
};
56+
3257
OperatorDialect::OperatorDialect(pir::IrContext *ctx)
3358
: pir::Dialect(name(), ctx, pir::TypeId::get<OperatorDialect>()) {
3459
initialize();
@@ -37,6 +62,12 @@ OperatorDialect::OperatorDialect(pir::IrContext *ctx)
3762
info.AttachInterface(std::move(
3863
pir::InterfaceValue::
3964
Get<pir::TuplePushOp, VjpInterface, TuplePushOpVjpInterfaceModel>()));
65+
66+
info = ctx->GetRegisteredOpInfo(pir::CombineOp::name());
67+
info.AttachInterface(std::move(
68+
pir::InterfaceValue::Get<pir::CombineOp,
69+
InferSymbolicShapeInterface,
70+
CombineOpInferSymbolicShapeInterfaceModel>()));
4071
}
4172

4273
void OperatorDialect::initialize() {

paddle/fluid/pir/dialect/operator/ir/ops.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,7 @@
10441044
view: (x -> out)
10451045
intermediate : xshape
10461046
backward: reshape_grad
1047+
interfaces : paddle::dialect::InferSymbolicShapeInterface
10471048

10481049
- op : rnn
10491050
args: (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor dropout_state_in, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false)
@@ -1214,6 +1215,7 @@
12141215
func : subtract
12151216
inplace : (x -> out)
12161217
backward : subtract_grad
1218+
interfaces : paddle::dialect::InferSymbolicShapeInterface
12171219

12181220
- op : sum
12191221
args : (Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false)

0 commit comments

Comments
 (0)