@@ -73,6 +73,9 @@ const std::unordered_set<std::string> UnchangeOutputOps = {
 };
 const std::unordered_set<std::string> SpecialLowerOps = {
     pir::CombineOp::name(),
+    pir::SetParameterOp::name(),
+    pir::GetParameterOp::name(),
+    pir::ShadowOutputOp::name(),
     pir::SliceOp::name(),
     pir::SplitOp::name(),
     pir::YieldOp::name(),
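Context for these additions: ops in SpecialLowerOps bypass phi-kernel selection entirely and are routed to HandleForSpecialOp, so SetParameterOp, GetParameterOp, and ShadowOutputOp now get structural handling (added in the hunks below) instead of kernel lowering. A minimal sketch of that dispatch, assuming the lowering loop looks roughly like this (the actual call site is not part of this diff):

    // Hypothetical shape of the lowering loop's dispatch.
    if (SpecialLowerOps.count(op_item->name())) {
      // Structural ops keep their form; no phi kernel is selected.
      HandleForSpecialOp(place, op_item, block, ctx, map_op_pair, map_value_pair);
      continue;
    }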
@@ -328,7 +331,6 @@ static pir::Type BuildDtypeTransferOutputType(pir::Type type,
 
 static pir::Type BuildOutputType(pir::Type type,
                                  const phi::Place& place,
-                                 phi::DataType data_type,
                                  pir::IrContext* ctx) {
   if (type.isa<DenseTensorType>()) {
     auto out_dtype = type.dyn_cast<DenseTensorType>().dtype();
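BuildOutputType loses its phi::DataType parameter: the output dtype is now always derived from the input type (the dyn_cast on the line above) rather than being overridable by the caller. Call sites updated later in this diff shrink accordingly, e.g.:

    op_output_types.push_back(BuildOutputType(result_type, out_place, ctx));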
@@ -563,6 +565,33 @@ static phi::Backend GetKernelBackendByYaml(
   return kernel_backend;
 }
 
+std::unique_ptr<OpYamlInfoParser> GetOpYamlInfoParser(pir::Operation* op) {
+  OpYamlInfoInterface op_info_interface = op->dyn_cast<OpYamlInfoInterface>();
+
+  std::unique_ptr<OpYamlInfoParser> op_info_parser(nullptr);
+  if (op_info_interface) {
+    op_info_parser = std::make_unique<OpYamlInfoParser>(
+        op_info_interface.GetOpInfo(), IsLegacyOp(op->name()));
+  }
+
+  return op_info_parser;
+}
+
+std::string GetKernelName(const OpYamlInfoParser* op_info_parser,
+                          pir::Operation* op_item) {
+  std::string kernel_fn_str;
+  if (op_info_parser != nullptr) {
+    kernel_fn_str = op_info_parser->OpRuntimeInfo().kernel_func;
+  }
+
+  if (op_item->isa<AddN_Op>() || op_item->isa<AddNWithKernelOp>()) {
+    if (op_item->result(0).type().isa<SelectedRowsType>()) {
+      kernel_fn_str = "add_n_sr";
+    }
+  }
+  return kernel_fn_str;
+}
+
 phi::KernelKey GetKernelKey(
     pir::Operation* op,
     const phi::Place& place,
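These two helpers are not new; they are hoisted from later in the file (see the matching removal at the end of this diff) so that HandleForSpecialOp can resolve a kernel for the place-transfer op it now inserts. Taken from the SetParameterOp branch below, the usage is:

    auto op_info_parser = GetOpYamlInfoParser(op_item);
    auto kernel_name = GetKernelName(op_info_parser.get(), op_item);
    auto kernel_key = GetKernelKey(
        op_item, place, kernel_name, *map_value_pair, op_info_parser.get());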
@@ -948,8 +977,10 @@ void HandleForSpecialOp(
     HandleForWhileOp(place, op_item, block, ctx, map_op_pair, map_value_pair);
     return;
   }
+
   std::vector<pir::Value> vec_inputs;
   std::vector<pir::Type> op_output_types;
+
   if (op_item->isa<::pir::CombineOp>()) {
     // Copy op inputs
     std::vector<pir::Type> vec_inner_types;
@@ -972,6 +1003,11 @@ void HandleForSpecialOp(
     op_output_types.push_back(t1);
   }
 
+  if (op_item->isa<::pir::GetParameterOp>()) {
+    op_output_types.push_back(
+        BuildOutputType(op_item->result(0).type(), place, ctx));
+  }
+
   if (op_item->isa<::pir::SliceOp>()) {
     if (op_item->num_operands() > 0) {
      for (size_t i = 0; i < op_item->num_operands(); ++i) {
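GetParameterOp consumes no operands and produces a single result, so its special handling only has to rebuild the result type: BuildOutputType pins the tensor to the executor's place. Assuming a DenseTensorType result, the effect is roughly the following (rendering and shape are illustrative, not taken from this diff):

    // before: %p = get_parameter : DenseTensorType(fp32, [128, 256], ...)
    // after:  %p = get_parameter : AllocatedDenseTensorType(place, fp32, [128, 256], ...)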
@@ -1023,7 +1059,22 @@ void HandleForSpecialOp(
     }
   }
 
-  if (op_item->isa<::pir::YieldOp>()) {
+  if (op_item->isa<::pir::YieldOp>() || op_item->isa<::pir::ShadowOutputOp>()) {
+    if (op_item->num_operands() > 0) {
+      for (size_t i = 0; i < op_item->num_operands(); ++i) {
+        auto cur_in = op_item->operand_source(i);
+        if (!cur_in) {
+          vec_inputs.emplace_back();
+          continue;
+        }
+        auto new_in = GetNewInput(
+            cur_in, *map_value_pair, static_cast<int>(i), op_item->name());
+        vec_inputs.push_back(new_in);
+      }
+    }
+  }
+
+  if (op_item->isa<::pir::SetParameterOp>()) {
     if (op_item->num_operands() > 0) {
       for (size_t i = 0; i < op_item->num_operands(); ++i) {
         auto cur_in = op_item->operand_source(i);
@@ -1033,6 +1084,41 @@ void HandleForSpecialOp(
         }
         auto new_in = GetNewInput(
             cur_in, *map_value_pair, static_cast<int>(i), op_item->name());
+        // NOTE(zhangbo): parameter place is equal to exe place.
+        if (new_in.type().isa<AllocatedDenseTensorType>()) {
+          auto in_place =
+              new_in.type().dyn_cast<AllocatedDenseTensorType>().place();
+          auto dst_backend = phi::TransToPhiBackend(place);
+          bool need_trans =
+              (in_place.GetType() != phi::AllocationType::UNDEFINED) &&
+              (paddle::experimental::NeedTransformPlace(
+                  in_place, dst_backend, {}));
+          if (need_trans) {
+            VLOG(6) << "need trans from " << in_place << " to " << dst_backend;
+            // build memcopy op
+            auto out_place = phi::TransToPhiPlace(dst_backend);
+            auto new_in_alloc_type =
+                new_in.type().dyn_cast<AllocatedDenseTensorType>();
+            auto out_type =
+                AllocatedDenseTensorType::get(ctx,
+                                              out_place,
+                                              new_in_alloc_type.dtype(),
+                                              new_in_alloc_type.dims(),
+                                              new_in_alloc_type.data_layout(),
+                                              new_in_alloc_type.lod(),
+                                              new_in_alloc_type.offset());
+            auto op_info_parser = GetOpYamlInfoParser(op_item);
+            auto kernel_name = GetKernelName(op_info_parser.get(), op_item);
+            auto kernel_key = GetKernelKey(op_item,
+                                           place,
+                                           kernel_name,
+                                           *map_value_pair,
+                                           op_info_parser.get());
+            VLOG(6) << "kernel type " << kernel_key;
+            new_in = AddPlaceTransferOp(
+                new_in, out_type, in_place, out_place, kernel_key, block);
+          }
+        }
         vec_inputs.push_back(new_in);
       }
     }
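Unlike the Yield/ShadowOutput branch, which only remaps inputs, SetParameterOp pins its input to the executor place: when the lowered input carries a concrete allocated place that NeedTransformPlace reports as incompatible with the destination backend, AddPlaceTransferOp inserts a memcpy-style kernel op ahead of the set_parameter, and the remapped input is replaced by the copy's result. For a GPU-resident value saved while the executor runs on CPU, the lowered IR would look roughly like this (rendering is illustrative, not taken from this diff):

    // %v     : AllocatedDenseTensorType(GPUPlace, ...)
    // %v_cpu = <place-transfer kernel op>(%v)  // inserted by AddPlaceTransferOp
    // set_parameter(%v_cpu)                    // parameter now lives on the exe place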
@@ -1077,6 +1163,7 @@ void HandleForSpecialOp(
       op_output_types.push_back(new_inlet_element.type());
     }
   }
+
   if (op_item->name() == "cinn_runtime.jit_kernel") {
     if (op_item->num_operands() > 0) {
       for (size_t i = 0; i < op_item->num_operands(); ++i) {
@@ -1136,12 +1223,9 @@ std::vector<pir::Type> BuildOutputs(pir::Operation* op_item,
 
   for (size_t i = 0; i < op_item->num_results(); ++i) {
     phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend());
-
-    phi::DataType out_phi_dtype = phi::DataType::UNDEFINED;
     if ((!UnchangeOutputOps.count(op_item->name())) &&
         (!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) {
       out_place = phi::TransToPhiPlace(output_defs[i].backend);
-      out_phi_dtype = output_defs[i].dtype;
     }
 
     auto result_type = op_item->result(i).type();
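With out_phi_dtype gone, BuildOutputs no longer substitutes the kernel definition's output dtype; result dtypes always follow the op result's own type, consistent with the BuildOutputType signature change earlier in this diff. Only the backend-driven placement decision remains per result:

    phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend());
    if ((!UnchangeOutputOps.count(op_item->name())) &&
        (!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) {
      out_place = phi::TransToPhiPlace(output_defs[i].backend);
    }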
@@ -1150,8 +1234,7 @@ std::vector<pir::Type> BuildOutputs(pir::Operation* op_item,
     } else if (result_type.isa<DenseTensorType>() ||
                result_type.isa<SelectedRowsType>() ||
                result_type.isa<DenseTensorArrayType>()) {
-      op_output_types.push_back(
-          BuildOutputType(result_type, out_place, out_phi_dtype, ctx));
+      op_output_types.push_back(BuildOutputType(result_type, out_place, ctx));
     } else if (result_type.isa<pir::VectorType>()) {
       std::vector<pir::Type> vec_inner_types;
       auto base_types = result_type.dyn_cast<pir::VectorType>().data();
@@ -1160,7 +1243,7 @@ std::vector<pir::Type> BuildOutputs(pir::Operation* op_item,
         if (base_type.isa<DenseTensorType>() ||
             base_type.isa<SelectedRowsType>()) {
           vec_inner_types.push_back(
-              BuildOutputType(base_type, out_place, out_phi_dtype, ctx));
+              BuildOutputType(base_type, out_place, ctx));
         } else {
           PADDLE_THROW(phi::errors::Unimplemented(
               "only support dense tensor and selected rows in vector type "
@@ -1287,7 +1370,6 @@ std::vector<pir::Value> BuildInputs(
     // [ todo need update here, support combine data transfomer]
     // deal with pre combine op
     auto pre_define_op = cur_in.dyn_cast<pir::OpResult>().owner();
-
     if (pre_define_op->isa<::pir::CombineOp>()) {
       std::vector<pir::Value> inner_inputs;
       std::vector<pir::Type> types_in_vec;
@@ -1320,8 +1402,6 @@ std::vector<pir::Value> BuildInputs(
             (paddle::experimental::NeedTransformPlace(
                 place, input_backend, {}));
         if (need_trans) {
-          VLOG(6) << "need trans from " << place << " to "
-                  << kernel_key.backend();
           // build memcopy op
           auto out_place = phi::TransToPhiPlace(input_backend);
           pir::Type out_type;
@@ -1528,33 +1608,6 @@ void AddShadowFeedOpForDataOrFeed(
   }
 }
 
-std::unique_ptr<OpYamlInfoParser> GetOpYamlInfoParser(pir::Operation* op) {
-  OpYamlInfoInterface op_info_interface = op->dyn_cast<OpYamlInfoInterface>();
-
-  std::unique_ptr<OpYamlInfoParser> op_info_parser(nullptr);
-  if (op_info_interface) {
-    op_info_parser = std::make_unique<OpYamlInfoParser>(
-        op_info_interface.GetOpInfo(), IsLegacyOp(op->name()));
-  }
-
-  return op_info_parser;
-}
-
-std::string GetKernelName(const OpYamlInfoParser* op_info_parser,
-                          pir::Operation* op_item) {
-  std::string kernel_fn_str;
-  if (op_info_parser != nullptr) {
-    kernel_fn_str = op_info_parser->OpRuntimeInfo().kernel_func;
-  }
-
-  if (op_item->isa<AddN_Op>() || op_item->isa<AddNWithKernelOp>()) {
-    if (op_item->result(0).type().isa<SelectedRowsType>()) {
-      kernel_fn_str = "add_n_sr";
-    }
-  }
-  return kernel_fn_str;
-}
-
 pir::Operation* BuildKernelOp(
     const std::string& kernel_fn_str,
     const phi::KernelKey& kernel_key,