20
20
#include " paddle/pir/core/builtin_type_interfaces.h"
21
21
#include " paddle/pir/dialect/shape/ir/shape_attribute.h"
22
22
23
+ // to make codes shorter
24
+ using ShapeOrData = symbol::ShapeOrDataDimExprs;
25
+ using TensorExprs = symbol::TensorShapeOrDataDimExprs;
26
+ using TensorListExprs = symbol::TensorListShapeOrDataDimExprs;
27
+
23
28
template <typename T>
24
29
struct AttributeTrait ;
25
30
@@ -78,9 +83,6 @@ bool SameOperandsAndResultShape(
78
83
symbol::ShapeOrDataDimExprs operand_shape_or_data =
79
84
shape_analysis->GetShapeOrDataForValue (operand_source);
80
85
81
- op->set_attribute (" symbolic_shape" ,
82
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (),
83
- operand_shape_or_data));
84
86
pir::Value res = op->result (0 );
85
87
shape_analysis->SetShapeOrDataForValue (res, operand_shape_or_data);
86
88
return true ;
@@ -143,9 +145,7 @@ bool InferSymbolicShapeElementWiseBinary(
143
145
symbol::ShapeOrDataDimExprs shape_data{
144
146
symbol::TensorShapeOrDataDimExprs (shapes)};
145
147
shape_analysis->SetShapeOrDataForValue (res, shape_data);
146
- op->set_attribute (
147
- " symbolic_shape" ,
148
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
148
+
149
149
return true ;
150
150
}
151
151
@@ -184,9 +184,6 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
184
184
185
185
symbol::ShapeOrDataDimExprs shape_data{
186
186
symbol::TensorShapeOrDataDimExprs (sym_dims)};
187
- op->set_attribute (
188
- " symbolic_shape" ,
189
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
190
187
191
188
pir::Value res = op->result (0 );
192
189
shape_analysis->SetShapeOrDataForValue (res, shape_data);
@@ -263,9 +260,7 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op,
263
260
sym_shape, operand_shape_or_data.shape ())};
264
261
265
262
shape_analysis->SetShapeOrDataForValue (res, shape_or_data);
266
- op->set_attribute (" symbolic_shape" ,
267
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (),
268
- shape_or_data));
263
+
269
264
return true ;
270
265
}
271
266
@@ -305,9 +300,6 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
305
300
symbol::ShapeOrDataDimExprs shape_data (
306
301
symbol::TensorShapeOrDataDimExprs (out_dims, out_dims_data));
307
302
308
- op->set_attribute (
309
- " symbolic_shape" ,
310
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
311
303
pir::Value res = op->result (0 );
312
304
shape_analysis->SetShapeOrDataForValue (res, shape_data);
313
305
return true ;
@@ -368,9 +360,7 @@ bool ReduceInferDim(pir::Operation *op,
368
360
pir::Value res = op->result (0 );
369
361
symbol::ShapeOrDataDimExprs shape_data{
370
362
symbol::TensorShapeOrDataDimExprs (shapes)};
371
- op->set_attribute (
372
- " symbolic_shape" ,
373
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
363
+
374
364
shape_analysis->SetShapeOrDataForValue (res, shape_data);
375
365
return true ;
376
366
}
@@ -441,9 +431,6 @@ bool ReshapeOpInferSymbolicShape(
441
431
442
432
symbol::ShapeOrDataDimExprs shape_data{
443
433
symbol::TensorShapeOrDataDimExprs (out_dims)};
444
- op->set_attribute (
445
- " symbolic_shape" ,
446
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
447
434
448
435
pir::Value res0 = op->result (0 );
449
436
pir::Value res1 = op->result (1 );
@@ -476,10 +463,6 @@ bool FullIntArrayOpInferSymbolicShape(
476
463
symbol::ShapeOrDataDimExprs shape_data{
477
464
symbol::TensorShapeOrDataDimExprs (shape, data)};
478
465
479
- op->set_attribute (
480
- " symbolic_shape" ,
481
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
482
-
483
466
pir::Value res = op->result (0 );
484
467
shape_analysis->SetShapeOrDataForValue (res, shape_data);
485
468
return true ;
@@ -537,10 +520,6 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
537
520
symbol::ShapeOrDataDimExprs shape_data{
538
521
symbol::TensorShapeOrDataDimExprs (sym_shape, out_data)};
539
522
540
- op->set_attribute (
541
- " symbolic_shape" ,
542
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
543
-
544
523
shape_analysis->SetShapeOrDataForValue (res, shape_data);
545
524
return true ;
546
525
}
@@ -580,10 +559,6 @@ bool FullOpInferSymbolicShape(pir::Operation *op,
580
559
symbol::TensorShapeOrDataDimExprs (sym_shape)};
581
560
shape_data.SetData (sym_data);
582
561
583
- op->set_attribute (
584
- " symbolic_shape" ,
585
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
586
-
587
562
pir::Value res = op->result (0 );
588
563
shape_analysis->SetShapeOrDataForValue (res, shape_data);
589
564
return true ;
@@ -629,10 +604,6 @@ bool ConcatOpInferSymbolicShape(
629
604
symbol::ShapeOrDataDimExprs shape_data{
630
605
symbol::TensorShapeOrDataDimExprs (out_dims)};
631
606
632
- op->set_attribute (
633
- " symbolic_shape" ,
634
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
635
-
636
607
pir::Value res = op->result (0 );
637
608
shape_analysis->SetShapeOrDataForValue (res, shape_data);
638
609
@@ -685,10 +656,6 @@ bool GatherNdOpInferSymbolicShape(
685
656
symbol::ShapeOrDataDimExprs shape_data{
686
657
symbol::TensorShapeOrDataDimExprs (result_sym_dims)};
687
658
688
- op->set_attribute (
689
- " symbolic_shape" ,
690
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
691
-
692
659
pir::Value res = op->result (0 );
693
660
shape_analysis->SetShapeOrDataForValue (res, shape_data);
694
661
@@ -702,7 +669,7 @@ bool PowOpInferSymbolicShape(pir::Operation *op,
702
669
bool Pow_OpInferSymbolicShape (pir::Operation *op,
703
670
pir::ShapeConstraintIRAnalysis *shape_analysis) {
704
671
PADDLE_THROW (phi::errors::Unimplemented (
705
- op->name () + " DOES NOT have InferSymbolicShapeInterface! " ));
672
+ op->name () + " 's InferSymbolicShape interface is NOT implemented now. " ));
706
673
return PowOpInferSymbolicShape (op, shape_analysis);
707
674
}
708
675
@@ -812,10 +779,6 @@ bool SqueezeOpInferSymbolicShape(
812
779
symbol::ShapeOrDataDimExprs shape_data{
813
780
symbol::TensorShapeOrDataDimExprs (output_shape_sym)};
814
781
815
- op->set_attribute (
816
- " symbolic_shape" ,
817
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
818
-
819
782
pir::Value res = op->result (0 );
820
783
shape_analysis->SetShapeOrDataForValue (res, shape_data);
821
784
@@ -891,10 +854,6 @@ bool UnsqueezeOpInferSymbolicShape(
891
854
symbol::ShapeOrDataDimExprs shape_data{
892
855
symbol::TensorShapeOrDataDimExprs (result_sym_dims)};
893
856
894
- op->set_attribute (
895
- " symbolic_shape" ,
896
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
897
-
898
857
pir::Value res = op->result (0 );
899
858
shape_analysis->SetShapeOrDataForValue (res, shape_data);
900
859
@@ -949,10 +908,6 @@ bool TileOpInferSymbolicShape(pir::Operation *op,
949
908
symbol::ShapeOrDataDimExprs shape_data{
950
909
symbol::TensorShapeOrDataDimExprs (out_shape)};
951
910
952
- op->set_attribute (
953
- " symbolic_shape" ,
954
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
955
-
956
911
pir::Value res = op->result (0 );
957
912
shape_analysis->SetShapeOrDataForValue (res, shape_data);
958
913
@@ -962,7 +917,7 @@ bool TileOpInferSymbolicShape(pir::Operation *op,
962
917
bool TransposeOpInferSymbolicShape (
963
918
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
964
919
PADDLE_THROW (phi::errors::Unimplemented (
965
- op->name () + " DOES NOT have InferSymbolicShapeInterface! " ));
920
+ op->name () + " 's InferSymbolicShape interface is NOT implemented now. " ));
966
921
return true ;
967
922
}
968
923
bool Transpose_OpInferSymbolicShape (
@@ -1095,35 +1050,132 @@ bool EmbeddingOpInferSymbolicShape(
1095
1050
bool SparseWeightEmbeddingOpInferSymbolicShape (
1096
1051
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
1097
1052
PADDLE_THROW (phi::errors::Unimplemented (
1098
- op->name () + " DOES NOT have InferSymbolicShapeInterface! " ));
1053
+ op->name () + " 's InferSymbolicShape interface is NOT implemented now. " ));
1099
1054
return true ;
1100
1055
}
1101
1056
1102
1057
bool ExpandOpInferSymbolicShape (
1103
1058
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
1104
1059
PADDLE_THROW (phi::errors::Unimplemented (
1105
- op->name () + " DOES NOT have InferSymbolicShapeInterface! " ));
1060
+ op->name () + " 's InferSymbolicShape interface is NOT implemented now. " ));
1106
1061
return true ;
1107
1062
}
1108
1063
1109
1064
bool MatmulOpInferSymbolicShape (
1110
1065
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
1111
- PADDLE_THROW (phi::errors::Unimplemented (
1112
- op->name () + " DOES NOT have InferSymbolicShapeInterface!" ));
1066
+ // x_dims can't be const or ref here, in case to be broadcasted
1067
+ std::vector<symbol::DimExpr> x_dims = [&] {
1068
+ std::vector<symbol::DimExpr> dims;
1069
+ const auto &x_shape_or_data =
1070
+ shape_analysis->GetShapeOrDataForValue (op->operand_source (0 ));
1071
+ if (x_shape_or_data.data ().has_value ()) {
1072
+ dims = x_shape_or_data.data ().value ();
1073
+ } else {
1074
+ dims = x_shape_or_data.shape ();
1075
+ }
1076
+ return dims;
1077
+ }();
1078
+
1079
+ // y_dims can't be const or ref here, in case to be broadcasted
1080
+ std::vector<symbol::DimExpr> y_dims = [&] {
1081
+ std::vector<symbol::DimExpr> dims;
1082
+ const auto y_shape_or_data =
1083
+ shape_analysis->GetShapeOrDataForValue (op->operand_source (1 ));
1084
+ if (y_shape_or_data.data ().has_value ()) {
1085
+ dims = y_shape_or_data.data ().value ();
1086
+ } else {
1087
+ dims = y_shape_or_data.shape ();
1088
+ }
1089
+ return dims;
1090
+ }();
1091
+
1092
+ size_t ndims_x = x_dims.size ();
1093
+ size_t ndims_y = y_dims.size ();
1094
+
1095
+ const bool x_broadcasted = [&] {
1096
+ bool broadcasted = false ;
1097
+ if (ndims_x == 1 ) {
1098
+ x_dims.insert (x_dims.begin (), 1 );
1099
+ ndims_x = 2 ;
1100
+ broadcasted = true ;
1101
+ }
1102
+ return broadcasted;
1103
+ }();
1104
+
1105
+ const bool y_broadcasted = [&] {
1106
+ bool broadcasted = false ;
1107
+ if (ndims_y == 1 ) {
1108
+ y_dims.emplace_back (1 );
1109
+ ndims_y = 2 ;
1110
+ broadcasted = true ;
1111
+ }
1112
+ return broadcasted;
1113
+ }();
1114
+
1115
+ std::vector<symbol::DimExpr> out_dims;
1116
+ if (ndims_x > ndims_y) {
1117
+ out_dims.assign (x_dims.begin (), x_dims.end () - 2 );
1118
+ } else if (ndims_x < ndims_y) {
1119
+ out_dims.assign (y_dims.begin (), y_dims.end () - 2 );
1120
+ } else {
1121
+ symbol::DimExprBuilder builder{nullptr };
1122
+ for (size_t i = 0 ; i < ndims_x - 2 ; ++i) {
1123
+ out_dims.emplace_back (builder.Broadcast (x_dims[i], y_dims[i]));
1124
+ }
1125
+ }
1126
+
1127
+ symbol::DimExpr out_M =
1128
+ op->attributes ().at (" transpose_x" ).dyn_cast <pir::BoolAttribute>().data ()
1129
+ ? x_dims[ndims_x - 1 ]
1130
+ : x_dims[ndims_x - 2 ];
1131
+ symbol::DimExpr out_N =
1132
+ op->attributes ().at (" transpose_y" ).dyn_cast <pir::BoolAttribute>().data ()
1133
+ ? y_dims[ndims_y - 2 ]
1134
+ : y_dims[ndims_y - 1 ];
1135
+ if (!x_broadcasted) {
1136
+ out_dims.emplace_back (out_M);
1137
+ }
1138
+ if (!y_broadcasted) {
1139
+ out_dims.emplace_back (out_N);
1140
+ }
1141
+
1142
+ shape_analysis->SetShapeOrDataForValue (op->result (0 ),
1143
+ ShapeOrData{TensorExprs (out_dims)});
1144
+
1113
1145
return true ;
1114
1146
}
1115
1147
1116
1148
bool MaxOpInferSymbolicShape (pir::Operation *op,
1117
1149
pir::ShapeConstraintIRAnalysis *shape_analysis) {
1118
- PADDLE_THROW (phi::errors::Unimplemented (
1119
- op->name () + " DOES NOT have InferSymbolicShapeInterface!" ));
1120
- return true ;
1150
+ bool keepdim =
1151
+ op->attributes ().at (" keepdim" ).dyn_cast <pir::BoolAttribute>().data ();
1152
+
1153
+ const std::vector<int64_t > axis = [&] {
1154
+ pir::Operation *axis_gen_op = op->operand_source (1 ).defining_op ();
1155
+ std::vector<int64_t > axis_vec;
1156
+ if (axis_gen_op->isa <paddle::dialect::FullIntArrayOp>()) {
1157
+ axis_vec = GetVectorAttr (
1158
+ axis_gen_op->dyn_cast <paddle::dialect::FullIntArrayOp>(), " value" );
1159
+ } else {
1160
+ // TODO(lanxianghit): there's other source: pir::VectorType,
1161
+ // paddle::dialect::DenseTensorType, but after PRIM, maybe always
1162
+ // FullIntArrayOp, to be confirmed
1163
+ PADDLE_THROW (
1164
+ phi::errors::Unimplemented (" MaxOpInferSymbolicShape: 'axis' only "
1165
+ " support FullIntArrayOp's result now." ));
1166
+ }
1167
+ return axis_vec;
1168
+ }();
1169
+
1170
+ bool reduce_all = axis.size () == 0 ? true : false ;
1171
+
1172
+ return ReduceInferDim (op, shape_analysis, axis, keepdim, reduce_all);
1121
1173
}
1122
1174
1123
1175
bool TrilOpInferSymbolicShape (pir::Operation *op,
1124
1176
pir::ShapeConstraintIRAnalysis *shape_analysis) {
1125
1177
PADDLE_THROW (phi::errors::Unimplemented (
1126
- op->name () + " DOES NOT have InferSymbolicShapeInterface! " ));
1178
+ op->name () + " 's InferSymbolicShape interface is NOT implemented now. " ));
1127
1179
return true ;
1128
1180
}
1129
1181
@@ -1135,7 +1187,7 @@ bool Tril_OpInferSymbolicShape(pir::Operation *op,
1135
1187
bool WhereOpInferSymbolicShape (pir::Operation *op,
1136
1188
pir::ShapeConstraintIRAnalysis *shape_analysis) {
1137
1189
PADDLE_THROW (phi::errors::Unimplemented (
1138
- op->name () + " DOES NOT have InferSymbolicShapeInterface! " ));
1190
+ op->name () + " 's InferSymbolicShape interface is NOT implemented now. " ));
1139
1191
return true ;
1140
1192
}
1141
1193
@@ -1189,10 +1241,6 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
1189
1241
};
1190
1242
symbol::ShapeOrDataDimExprs shape_data{GetOutDimExprs ()};
1191
1243
1192
- op->set_attribute (
1193
- " symbolic_shape" ,
1194
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
1195
-
1196
1244
shape_analysis->SetShapeOrDataForValue (op->result (0 ), shape_data);
1197
1245
return true ;
1198
1246
}
@@ -1239,10 +1287,6 @@ bool ConcatOpInferSymbolicShape(
1239
1287
symbol::ShapeOrDataDimExprs shape_data{
1240
1288
symbol::TensorShapeOrDataDimExprs (GetOutDimExprs ())};
1241
1289
1242
- op->set_attribute (
1243
- " symbolic_shape" ,
1244
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
1245
-
1246
1290
shape_analysis->SetShapeOrDataForValue (op->result (0 ), shape_data);
1247
1291
return true ;
1248
1292
}
@@ -1292,9 +1336,7 @@ bool ReshapeOpInferSymbolicShape(
1292
1336
symbol::ShapeOrDataDimExprs shape_data{
1293
1337
symbol::TensorShapeOrDataDimExprs (out_dims)};
1294
1338
shape_analysis->SetShapeOrDataForValue (op->result (0 ), shape_data);
1295
- op->set_attribute (
1296
- " symbolic_shape" ,
1297
- pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
1339
+
1298
1340
return true ;
1299
1341
}
1300
1342
0 commit comments