Skip to content

Commit f32fec4

Browse files
authored
[PIR][DynamicShape] Add InferSymbolicShape for matmul, max (#61587)
Add InferSymbolicShape for matmul, max, and refine some codes & err msg
1 parent a4c6d3d commit f32fec4

File tree

2 files changed

+263
-75
lines changed

2 files changed

+263
-75
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc

Lines changed: 117 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
#include "paddle/pir/core/builtin_type_interfaces.h"
2121
#include "paddle/pir/dialect/shape/ir/shape_attribute.h"
2222

23+
// to make codes shorter
24+
using ShapeOrData = symbol::ShapeOrDataDimExprs;
25+
using TensorExprs = symbol::TensorShapeOrDataDimExprs;
26+
using TensorListExprs = symbol::TensorListShapeOrDataDimExprs;
27+
2328
template <typename T>
2429
struct AttributeTrait;
2530

@@ -78,9 +83,6 @@ bool SameOperandsAndResultShape(
7883
symbol::ShapeOrDataDimExprs operand_shape_or_data =
7984
shape_analysis->GetShapeOrDataForValue(operand_source);
8085

81-
op->set_attribute("symbolic_shape",
82-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
83-
operand_shape_or_data));
8486
pir::Value res = op->result(0);
8587
shape_analysis->SetShapeOrDataForValue(res, operand_shape_or_data);
8688
return true;
@@ -143,9 +145,7 @@ bool InferSymbolicShapeElementWiseBinary(
143145
symbol::ShapeOrDataDimExprs shape_data{
144146
symbol::TensorShapeOrDataDimExprs(shapes)};
145147
shape_analysis->SetShapeOrDataForValue(res, shape_data);
146-
op->set_attribute(
147-
"symbolic_shape",
148-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
148+
149149
return true;
150150
}
151151

@@ -184,9 +184,6 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
184184

185185
symbol::ShapeOrDataDimExprs shape_data{
186186
symbol::TensorShapeOrDataDimExprs(sym_dims)};
187-
op->set_attribute(
188-
"symbolic_shape",
189-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
190187

191188
pir::Value res = op->result(0);
192189
shape_analysis->SetShapeOrDataForValue(res, shape_data);
@@ -263,9 +260,7 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op,
263260
sym_shape, operand_shape_or_data.shape())};
264261

265262
shape_analysis->SetShapeOrDataForValue(res, shape_or_data);
266-
op->set_attribute("symbolic_shape",
267-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
268-
shape_or_data));
263+
269264
return true;
270265
}
271266

@@ -305,9 +300,6 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
305300
symbol::ShapeOrDataDimExprs shape_data(
306301
symbol::TensorShapeOrDataDimExprs(out_dims, out_dims_data));
307302

308-
op->set_attribute(
309-
"symbolic_shape",
310-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
311303
pir::Value res = op->result(0);
312304
shape_analysis->SetShapeOrDataForValue(res, shape_data);
313305
return true;
@@ -368,9 +360,7 @@ bool ReduceInferDim(pir::Operation *op,
368360
pir::Value res = op->result(0);
369361
symbol::ShapeOrDataDimExprs shape_data{
370362
symbol::TensorShapeOrDataDimExprs(shapes)};
371-
op->set_attribute(
372-
"symbolic_shape",
373-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
363+
374364
shape_analysis->SetShapeOrDataForValue(res, shape_data);
375365
return true;
376366
}
@@ -441,9 +431,6 @@ bool ReshapeOpInferSymbolicShape(
441431

442432
symbol::ShapeOrDataDimExprs shape_data{
443433
symbol::TensorShapeOrDataDimExprs(out_dims)};
444-
op->set_attribute(
445-
"symbolic_shape",
446-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
447434

448435
pir::Value res0 = op->result(0);
449436
pir::Value res1 = op->result(1);
@@ -476,10 +463,6 @@ bool FullIntArrayOpInferSymbolicShape(
476463
symbol::ShapeOrDataDimExprs shape_data{
477464
symbol::TensorShapeOrDataDimExprs(shape, data)};
478465

479-
op->set_attribute(
480-
"symbolic_shape",
481-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
482-
483466
pir::Value res = op->result(0);
484467
shape_analysis->SetShapeOrDataForValue(res, shape_data);
485468
return true;
@@ -537,10 +520,6 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
537520
symbol::ShapeOrDataDimExprs shape_data{
538521
symbol::TensorShapeOrDataDimExprs(sym_shape, out_data)};
539522

540-
op->set_attribute(
541-
"symbolic_shape",
542-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
543-
544523
shape_analysis->SetShapeOrDataForValue(res, shape_data);
545524
return true;
546525
}
@@ -580,10 +559,6 @@ bool FullOpInferSymbolicShape(pir::Operation *op,
580559
symbol::TensorShapeOrDataDimExprs(sym_shape)};
581560
shape_data.SetData(sym_data);
582561

583-
op->set_attribute(
584-
"symbolic_shape",
585-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
586-
587562
pir::Value res = op->result(0);
588563
shape_analysis->SetShapeOrDataForValue(res, shape_data);
589564
return true;
@@ -629,10 +604,6 @@ bool ConcatOpInferSymbolicShape(
629604
symbol::ShapeOrDataDimExprs shape_data{
630605
symbol::TensorShapeOrDataDimExprs(out_dims)};
631606

632-
op->set_attribute(
633-
"symbolic_shape",
634-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
635-
636607
pir::Value res = op->result(0);
637608
shape_analysis->SetShapeOrDataForValue(res, shape_data);
638609

@@ -685,10 +656,6 @@ bool GatherNdOpInferSymbolicShape(
685656
symbol::ShapeOrDataDimExprs shape_data{
686657
symbol::TensorShapeOrDataDimExprs(result_sym_dims)};
687658

688-
op->set_attribute(
689-
"symbolic_shape",
690-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
691-
692659
pir::Value res = op->result(0);
693660
shape_analysis->SetShapeOrDataForValue(res, shape_data);
694661

@@ -702,7 +669,7 @@ bool PowOpInferSymbolicShape(pir::Operation *op,
702669
bool Pow_OpInferSymbolicShape(pir::Operation *op,
703670
pir::ShapeConstraintIRAnalysis *shape_analysis) {
704671
PADDLE_THROW(phi::errors::Unimplemented(
705-
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
672+
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
706673
return PowOpInferSymbolicShape(op, shape_analysis);
707674
}
708675

@@ -812,10 +779,6 @@ bool SqueezeOpInferSymbolicShape(
812779
symbol::ShapeOrDataDimExprs shape_data{
813780
symbol::TensorShapeOrDataDimExprs(output_shape_sym)};
814781

815-
op->set_attribute(
816-
"symbolic_shape",
817-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
818-
819782
pir::Value res = op->result(0);
820783
shape_analysis->SetShapeOrDataForValue(res, shape_data);
821784

@@ -891,10 +854,6 @@ bool UnsqueezeOpInferSymbolicShape(
891854
symbol::ShapeOrDataDimExprs shape_data{
892855
symbol::TensorShapeOrDataDimExprs(result_sym_dims)};
893856

894-
op->set_attribute(
895-
"symbolic_shape",
896-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
897-
898857
pir::Value res = op->result(0);
899858
shape_analysis->SetShapeOrDataForValue(res, shape_data);
900859

@@ -949,10 +908,6 @@ bool TileOpInferSymbolicShape(pir::Operation *op,
949908
symbol::ShapeOrDataDimExprs shape_data{
950909
symbol::TensorShapeOrDataDimExprs(out_shape)};
951910

952-
op->set_attribute(
953-
"symbolic_shape",
954-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
955-
956911
pir::Value res = op->result(0);
957912
shape_analysis->SetShapeOrDataForValue(res, shape_data);
958913

@@ -962,7 +917,7 @@ bool TileOpInferSymbolicShape(pir::Operation *op,
962917
bool TransposeOpInferSymbolicShape(
963918
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
964919
PADDLE_THROW(phi::errors::Unimplemented(
965-
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
920+
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
966921
return true;
967922
}
968923
bool Transpose_OpInferSymbolicShape(
@@ -1095,35 +1050,132 @@ bool EmbeddingOpInferSymbolicShape(
10951050
bool SparseWeightEmbeddingOpInferSymbolicShape(
10961051
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
10971052
PADDLE_THROW(phi::errors::Unimplemented(
1098-
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
1053+
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
10991054
return true;
11001055
}
11011056

11021057
bool ExpandOpInferSymbolicShape(
11031058
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
11041059
PADDLE_THROW(phi::errors::Unimplemented(
1105-
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
1060+
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
11061061
return true;
11071062
}
11081063

11091064
bool MatmulOpInferSymbolicShape(
11101065
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
1111-
PADDLE_THROW(phi::errors::Unimplemented(
1112-
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
1066+
// x_dims can't be const or ref here, in case to be broadcasted
1067+
std::vector<symbol::DimExpr> x_dims = [&] {
1068+
std::vector<symbol::DimExpr> dims;
1069+
const auto &x_shape_or_data =
1070+
shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
1071+
if (x_shape_or_data.data().has_value()) {
1072+
dims = x_shape_or_data.data().value();
1073+
} else {
1074+
dims = x_shape_or_data.shape();
1075+
}
1076+
return dims;
1077+
}();
1078+
1079+
// y_dims can't be const or ref here, in case to be broadcasted
1080+
std::vector<symbol::DimExpr> y_dims = [&] {
1081+
std::vector<symbol::DimExpr> dims;
1082+
const auto y_shape_or_data =
1083+
shape_analysis->GetShapeOrDataForValue(op->operand_source(1));
1084+
if (y_shape_or_data.data().has_value()) {
1085+
dims = y_shape_or_data.data().value();
1086+
} else {
1087+
dims = y_shape_or_data.shape();
1088+
}
1089+
return dims;
1090+
}();
1091+
1092+
size_t ndims_x = x_dims.size();
1093+
size_t ndims_y = y_dims.size();
1094+
1095+
const bool x_broadcasted = [&] {
1096+
bool broadcasted = false;
1097+
if (ndims_x == 1) {
1098+
x_dims.insert(x_dims.begin(), 1);
1099+
ndims_x = 2;
1100+
broadcasted = true;
1101+
}
1102+
return broadcasted;
1103+
}();
1104+
1105+
const bool y_broadcasted = [&] {
1106+
bool broadcasted = false;
1107+
if (ndims_y == 1) {
1108+
y_dims.emplace_back(1);
1109+
ndims_y = 2;
1110+
broadcasted = true;
1111+
}
1112+
return broadcasted;
1113+
}();
1114+
1115+
std::vector<symbol::DimExpr> out_dims;
1116+
if (ndims_x > ndims_y) {
1117+
out_dims.assign(x_dims.begin(), x_dims.end() - 2);
1118+
} else if (ndims_x < ndims_y) {
1119+
out_dims.assign(y_dims.begin(), y_dims.end() - 2);
1120+
} else {
1121+
symbol::DimExprBuilder builder{nullptr};
1122+
for (size_t i = 0; i < ndims_x - 2; ++i) {
1123+
out_dims.emplace_back(builder.Broadcast(x_dims[i], y_dims[i]));
1124+
}
1125+
}
1126+
1127+
symbol::DimExpr out_M =
1128+
op->attributes().at("transpose_x").dyn_cast<pir::BoolAttribute>().data()
1129+
? x_dims[ndims_x - 1]
1130+
: x_dims[ndims_x - 2];
1131+
symbol::DimExpr out_N =
1132+
op->attributes().at("transpose_y").dyn_cast<pir::BoolAttribute>().data()
1133+
? y_dims[ndims_y - 2]
1134+
: y_dims[ndims_y - 1];
1135+
if (!x_broadcasted) {
1136+
out_dims.emplace_back(out_M);
1137+
}
1138+
if (!y_broadcasted) {
1139+
out_dims.emplace_back(out_N);
1140+
}
1141+
1142+
shape_analysis->SetShapeOrDataForValue(op->result(0),
1143+
ShapeOrData{TensorExprs(out_dims)});
1144+
11131145
return true;
11141146
}
11151147

11161148
bool MaxOpInferSymbolicShape(pir::Operation *op,
11171149
pir::ShapeConstraintIRAnalysis *shape_analysis) {
1118-
PADDLE_THROW(phi::errors::Unimplemented(
1119-
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
1120-
return true;
1150+
bool keepdim =
1151+
op->attributes().at("keepdim").dyn_cast<pir::BoolAttribute>().data();
1152+
1153+
const std::vector<int64_t> axis = [&] {
1154+
pir::Operation *axis_gen_op = op->operand_source(1).defining_op();
1155+
std::vector<int64_t> axis_vec;
1156+
if (axis_gen_op->isa<paddle::dialect::FullIntArrayOp>()) {
1157+
axis_vec = GetVectorAttr(
1158+
axis_gen_op->dyn_cast<paddle::dialect::FullIntArrayOp>(), "value");
1159+
} else {
1160+
// TODO(lanxianghit): there's other source: pir::VectorType,
1161+
// paddle::dialect::DenseTensorType, but after PRIM, maybe always
1162+
// FullIntArrayOp, to be confirmed
1163+
PADDLE_THROW(
1164+
phi::errors::Unimplemented("MaxOpInferSymbolicShape: 'axis' only "
1165+
"support FullIntArrayOp's result now."));
1166+
}
1167+
return axis_vec;
1168+
}();
1169+
1170+
bool reduce_all = axis.size() == 0 ? true : false;
1171+
1172+
return ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all);
11211173
}
11221174

11231175
bool TrilOpInferSymbolicShape(pir::Operation *op,
11241176
pir::ShapeConstraintIRAnalysis *shape_analysis) {
11251177
PADDLE_THROW(phi::errors::Unimplemented(
1126-
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
1178+
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
11271179
return true;
11281180
}
11291181

@@ -1135,7 +1187,7 @@ bool Tril_OpInferSymbolicShape(pir::Operation *op,
11351187
bool WhereOpInferSymbolicShape(pir::Operation *op,
11361188
pir::ShapeConstraintIRAnalysis *shape_analysis) {
11371189
PADDLE_THROW(phi::errors::Unimplemented(
1138-
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
1190+
op->name() + " 's InferSymbolicShape interface is NOT implemented now."));
11391191
return true;
11401192
}
11411193

@@ -1189,10 +1241,6 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
11891241
};
11901242
symbol::ShapeOrDataDimExprs shape_data{GetOutDimExprs()};
11911243

1192-
op->set_attribute(
1193-
"symbolic_shape",
1194-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
1195-
11961244
shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data);
11971245
return true;
11981246
}
@@ -1239,10 +1287,6 @@ bool ConcatOpInferSymbolicShape(
12391287
symbol::ShapeOrDataDimExprs shape_data{
12401288
symbol::TensorShapeOrDataDimExprs(GetOutDimExprs())};
12411289

1242-
op->set_attribute(
1243-
"symbolic_shape",
1244-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
1245-
12461290
shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data);
12471291
return true;
12481292
}
@@ -1292,9 +1336,7 @@ bool ReshapeOpInferSymbolicShape(
12921336
symbol::ShapeOrDataDimExprs shape_data{
12931337
symbol::TensorShapeOrDataDimExprs(out_dims)};
12941338
shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data);
1295-
op->set_attribute(
1296-
"symbolic_shape",
1297-
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
1339+
12981340
return true;
12991341
}
13001342

0 commit comments

Comments
 (0)