
Commit 893235e

[PIR] Refine get_parameter, set_parameter, shadow_output transfer (#59171)

* add interface
* add interface
* add code
* fix
* fix
* fix
* fix
* fix
* fix
* fix
1 parent 617bdb9 commit 893235e

File tree

3 files changed: +109 / -74 lines


paddle/fluid/framework/new_executor/pir_interpreter.cc

Lines changed: 2 additions & 1 deletion
@@ -1535,7 +1535,8 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
     const std::vector<std::string> op_callstack_attr =
         interpreter::GetInstructionCallStack(op->name(), op->attributes());
     framework::InsertCallStackInfo(op->name(), op_callstack_attr, &ex);
-    LOG(WARNING) << instr_node->Name() << " raises an EnforceNotMet exception "
+    LOG(WARNING) << " OP id:" << instr_node->Id() << " " << instr_node->Name()
+                 << " raises an EnforceNotMet exception "
                  << platform::demangle(typeid(ex).name()) << ", " << ex.what();
     exception_holder_.Catch(std::make_exception_ptr(std::move(ex)));
   } catch (platform::EOFException&) {
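
Note: this change is purely diagnostic; the warning now leads with the instruction id so a failing op can be located in large programs. A hypothetical rendering of the new message (id, op name, and error text invented for illustration):

    // Hypothetical output of the new LOG(WARNING) line:
    //   OP id:42 pd_op.conv2d raises an EnforceNotMet exception
    //   paddle::platform::EnforceNotMet, <what() text with call stack>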

paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc

Lines changed: 16 additions & 35 deletions
@@ -39,48 +39,29 @@ class ParamsSyncAmongDevicesPass : public pir::Pass {
         scope_(scope) {}
 
   void Run(pir::Operation* op) override {
+    VLOG(6) << "apply ParamsSyncAmongDevicesPass";
     auto module_op = op->dyn_cast<pir::ModuleOp>();
     PADDLE_ENFORCE_NOT_NULL(
         module_op,
         phi::errors::PreconditionNotMet(
             "params_sync_among_devices_pass should run on module op."));
     auto* block = module_op.block();
     for (auto& inner_op : *block) {
-      if (inner_op.attributes().count("op_name") == 0) {
-        continue;
-      }
-      auto op_name = inner_op.attributes()
-                         .at("op_name")
-                         .dyn_cast<pir::StrAttribute>()
-                         .AsString();
-      if (op_name == pir::GetParameterOp::name()) {
-        auto use_op = pir::GetUseOpsForOutput(&inner_op, 0).front();
-        phi::KernelKey kernel_key;
-        if (use_op->attributes().count("kernel_key")) {
-          kernel_key = use_op->attributes()
-                           .at("kernel_key")
-                           .dyn_cast<paddle::dialect::KernelAttribute>()
-                           .data();
-        }
-        // TODO(liuyuanle): When the kernel_key doesn't exist?
-        if (use_op->attributes().count("kernel_key") &&
-            kernel_key.backend() != phi::Backend::CPU) {
-          std::string param_name = inner_op.attributes()
-                                       .at("parameter_name")
-                                       .dyn_cast<pir::StrAttribute>()
-                                       .AsString();
-          auto* param_var = scope_->FindVar(param_name);
-          if (param_var->IsType<phi::DenseTensor>()) {
-            auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
-            paddle::platform::CPUPlace cpu_place;
-            phi::DenseTensor temp_tensor;
-            temp_tensor.Resize(param_tensor->dims());
-            paddle::framework::TensorCopySync(
-                *param_tensor, cpu_place, &temp_tensor);
-            param_tensor->clear();
-            paddle::framework::TensorCopySync(
-                temp_tensor, place_, param_tensor);
-          }
+      if (inner_op.isa<pir::GetParameterOp>()) {
+        std::string param_name = inner_op.attributes()
+                                     .at("parameter_name")
+                                     .dyn_cast<pir::StrAttribute>()
+                                     .AsString();
+        auto* param_var = scope_->FindVar(param_name);
+        if (param_var->IsType<phi::DenseTensor>()) {
+          auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
+          paddle::platform::CPUPlace cpu_place;
+          phi::DenseTensor temp_tensor;
+          temp_tensor.Resize(param_tensor->dims());
+          paddle::framework::TensorCopySync(
+              *param_tensor, cpu_place, &temp_tensor);
+          param_tensor->clear();
+          paddle::framework::TensorCopySync(temp_tensor, place_, param_tensor);
         }
       }
     }
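
Note: the pass now recognizes parameter-loading ops structurally via isa<pir::GetParameterOp>() instead of string-matching an "op_name" attribute and inspecting the consumer's kernel_key (which also removes the unresolved TODO about a missing kernel_key). Every dense-tensor parameter is round-tripped through CPU (TensorCopySync to a temporary, clear(), then TensorCopySync back to place_) so its allocation lands on the pass's target place. A minimal usage sketch, assuming a factory named CreateParamsSyncAmongDevicesPass(place, scope); that name is hypothetical, since the construction API is not shown in this diff:

    // Hypothetical wiring of the pass; factory name assumed, not from this diff.
    pir::PassManager pass_manager(pir::IrContext::Instance());
    pass_manager.AddPass(
        pir::CreateParamsSyncAmongDevicesPass(phi::GPUPlace(0), &scope));
    pass_manager.Run(&kernel_program);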

paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc

Lines changed: 91 additions & 38 deletions
@@ -73,6 +73,9 @@ const std::unordered_set<std::string> UnchangeOutputOps = {
 };
 const std::unordered_set<std::string> SpecialLowerOps = {
     pir::CombineOp::name(),
+    pir::SetParameterOp::name(),
+    pir::GetParameterOp::name(),
+    pir::ShadowOutputOp::name(),
     pir::SliceOp::name(),
     pir::SplitOp::name(),
     pir::YieldOp::name(),
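
Note: membership in SpecialLowerOps is what routes set_parameter, get_parameter, and shadow_output through HandleForSpecialOp below instead of the generic kernel lowering. A simplified sketch of that dispatch, assuming the pass loop looks roughly like this (not the verbatim loop):

    // Simplified dispatch sketch; the real loop threads more state through.
    for (auto& op_item : *block) {
      if (SpecialLowerOps.count(op_item.name())) {
        HandleForSpecialOp(
            place, &op_item, new_block, ctx, &map_op_pair, &map_value_pair);
        continue;
      }
      // Generic path: resolve kernel key, build inputs/outputs, emit kernel op.
    }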
@@ -328,7 +331,6 @@ static pir::Type BuildDtypeTransferOutputType(pir::Type type,
 
 static pir::Type BuildOutputType(pir::Type type,
                                  const phi::Place& place,
-                                 phi::DataType data_type,
                                  pir::IrContext* ctx) {
   if (type.isa<DenseTensorType>()) {
     auto out_dtype = type.dyn_cast<DenseTensorType>().dtype();
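
Note: the dropped data_type parameter carried no information; as the context line shows, BuildOutputType reads the output dtype from the incoming type itself. Call sites are updated accordingly in the BuildOutputs hunks further down, e.g.:

    op_output_types.push_back(BuildOutputType(result_type, out_place, ctx));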
@@ -563,6 +565,33 @@ static phi::Backend GetKernelBackendByYaml(
   return kernel_backend;
 }
 
+std::unique_ptr<OpYamlInfoParser> GetOpYamlInfoParser(pir::Operation* op) {
+  OpYamlInfoInterface op_info_interface = op->dyn_cast<OpYamlInfoInterface>();
+
+  std::unique_ptr<OpYamlInfoParser> op_info_parser(nullptr);
+  if (op_info_interface) {
+    op_info_parser = std::make_unique<OpYamlInfoParser>(
+        op_info_interface.GetOpInfo(), IsLegacyOp(op->name()));
+  }
+
+  return op_info_parser;
+}
+
+std::string GetKernelName(const OpYamlInfoParser* op_info_parser,
+                          pir::Operation* op_item) {
+  std::string kernel_fn_str;
+  if (op_info_parser != nullptr) {
+    kernel_fn_str = op_info_parser->OpRuntimeInfo().kernel_func;
+  }
+
+  if (op_item->isa<AddN_Op>() || op_item->isa<AddNWithKernelOp>()) {
+    if (op_item->result(0).type().isa<SelectedRowsType>()) {
+      kernel_fn_str = "add_n_sr";
+    }
+  }
+  return kernel_fn_str;
+}
+
 phi::KernelKey GetKernelKey(
     pir::Operation* op,
     const phi::Place& place,
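
Note: GetOpYamlInfoParser and GetKernelName are not new code; they are hoisted from around old line 1531 (see the matching deletion at the end of this file's diff) so that HandleForSpecialOp can resolve a kernel before the helpers' former definition point. The call pattern, as used in the set_parameter hunk below:

    auto op_info_parser = GetOpYamlInfoParser(op_item);
    auto kernel_name = GetKernelName(op_info_parser.get(), op_item);
    auto kernel_key = GetKernelKey(
        op_item, place, kernel_name, *map_value_pair, op_info_parser.get());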
@@ -948,8 +977,10 @@ void HandleForSpecialOp(
     HandleForWhileOp(place, op_item, block, ctx, map_op_pair, map_value_pair);
     return;
   }
+
   std::vector<pir::Value> vec_inputs;
   std::vector<pir::Type> op_output_types;
+
   if (op_item->isa<::pir::CombineOp>()) {
     // Copy op inputs
     std::vector<pir::Type> vec_inner_types;
@@ -972,6 +1003,11 @@
       op_output_types.push_back(t1);
     }
 
+  if (op_item->isa<::pir::GetParameterOp>()) {
+    op_output_types.push_back(
+        BuildOutputType(op_item->result(0).type(), place, ctx));
+  }
+
   if (op_item->isa<::pir::SliceOp>()) {
     if (op_item->num_operands() > 0) {
       for (size_t i = 0; i < op_item->num_operands(); ++i) {
@@ -1023,7 +1059,22 @@
     }
   }
 
-  if (op_item->isa<::pir::YieldOp>()) {
+  if (op_item->isa<::pir::YieldOp>() || op_item->isa<::pir::ShadowOutputOp>()) {
+    if (op_item->num_operands() > 0) {
+      for (size_t i = 0; i < op_item->num_operands(); ++i) {
+        auto cur_in = op_item->operand_source(i);
+        if (!cur_in) {
+          vec_inputs.emplace_back();
+          continue;
+        }
+        auto new_in = GetNewInput(
+            cur_in, *map_value_pair, static_cast<int>(i), op_item->name());
+        vec_inputs.push_back(new_in);
+      }
+    }
+  }
+
+  if (op_item->isa<::pir::SetParameterOp>()) {
     if (op_item->num_operands() > 0) {
       for (size_t i = 0; i < op_item->num_operands(); ++i) {
         auto cur_in = op_item->operand_source(i);
@@ -1033,6 +1084,41 @@
         }
         auto new_in = GetNewInput(
             cur_in, *map_value_pair, static_cast<int>(i), op_item->name());
+        // NOTE(zhangbo): parameter place is equal to exe place.
+        if (new_in.type().isa<AllocatedDenseTensorType>()) {
+          auto in_place =
+              new_in.type().dyn_cast<AllocatedDenseTensorType>().place();
+          auto dst_backend = phi::TransToPhiBackend(place);
+          bool need_trans =
+              (in_place.GetType() != phi::AllocationType::UNDEFINED) &&
+              (paddle::experimental::NeedTransformPlace(
+                  in_place, dst_backend, {}));
+          if (need_trans) {
+            VLOG(6) << "need trans from " << in_place << " to " << dst_backend;
+            // build memcopy op
+            auto out_place = phi::TransToPhiPlace(dst_backend);
+            auto new_in_alloc_type =
+                new_in.type().dyn_cast<AllocatedDenseTensorType>();
+            auto out_type =
+                AllocatedDenseTensorType::get(ctx,
+                                              out_place,
+                                              new_in_alloc_type.dtype(),
+                                              new_in_alloc_type.dims(),
+                                              new_in_alloc_type.data_layout(),
+                                              new_in_alloc_type.lod(),
+                                              new_in_alloc_type.offset());
+            auto op_info_parser = GetOpYamlInfoParser(op_item);
+            auto kernel_name = GetKernelName(op_info_parser.get(), op_item);
+            auto kernel_key = GetKernelKey(op_item,
+                                           place,
+                                           kernel_name,
+                                           *map_value_pair,
+                                           op_info_parser.get());
+            VLOG(6) << "kernel type " << kernel_key;
+            new_in = AddPlaceTransferOp(
+                new_in, out_type, in_place, out_place, kernel_key, block);
+          }
+        }
         vec_inputs.push_back(new_in);
       }
     }
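
Note: a memcpy op is inserted for a set_parameter input only when the value already has a concrete placement and the standard predicate reports a mismatch with the execution place. A minimal sketch isolating that check (the helper name is mine, not from the source):

    // Hypothetical helper isolating the need_trans check used above.
    bool NeedsPlaceTransfer(const phi::Place& in_place,
                            phi::Backend dst_backend) {
      // Values with UNDEFINED placement have not been materialized anywhere
      // yet, so there is nothing to move.
      return in_place.GetType() != phi::AllocationType::UNDEFINED &&
             paddle::experimental::NeedTransformPlace(in_place, dst_backend, {});
    }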
@@ -1077,6 +1163,7 @@
       op_output_types.push_back(new_inlet_element.type());
     }
   }
+
   if (op_item->name() == "cinn_runtime.jit_kernel") {
     if (op_item->num_operands() > 0) {
       for (size_t i = 0; i < op_item->num_operands(); ++i) {
@@ -1136,12 +1223,9 @@ std::vector<pir::Type> BuildOutputs(pir::Operation* op_item,
 
   for (size_t i = 0; i < op_item->num_results(); ++i) {
     phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend());
-
-    phi::DataType out_phi_dtype = phi::DataType::UNDEFINED;
     if ((!UnchangeOutputOps.count(op_item->name())) &&
         (!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) {
       out_place = phi::TransToPhiPlace(output_defs[i].backend);
-      out_phi_dtype = output_defs[i].dtype;
     }
 
     auto result_type = op_item->result(i).type();
@@ -1150,8 +1234,7 @@ std::vector<pir::Type> BuildOutputs(pir::Operation* op_item,
     } else if (result_type.isa<DenseTensorType>() ||
                result_type.isa<SelectedRowsType>() ||
                result_type.isa<DenseTensorArrayType>()) {
-      op_output_types.push_back(
-          BuildOutputType(result_type, out_place, out_phi_dtype, ctx));
+      op_output_types.push_back(BuildOutputType(result_type, out_place, ctx));
     } else if (result_type.isa<pir::VectorType>()) {
       std::vector<pir::Type> vec_inner_types;
       auto base_types = result_type.dyn_cast<pir::VectorType>().data();
@@ -1160,7 +1243,7 @@ std::vector<pir::Type> BuildOutputs(pir::Operation* op_item,
       if (base_type.isa<DenseTensorType>() ||
           base_type.isa<SelectedRowsType>()) {
         vec_inner_types.push_back(
-            BuildOutputType(base_type, out_place, out_phi_dtype, ctx));
+            BuildOutputType(base_type, out_place, ctx));
       } else {
         PADDLE_THROW(phi::errors::Unimplemented(
             "only support dense tensor and selected rows in vector type "
@@ -1287,7 +1370,6 @@ std::vector<pir::Value> BuildInputs(
     // [ todo need update here, support combine data transfomer]
     // deal with pre combine op
     auto pre_define_op = cur_in.dyn_cast<pir::OpResult>().owner();
-
     if (pre_define_op->isa<::pir::CombineOp>()) {
       std::vector<pir::Value> inner_inputs;
       std::vector<pir::Type> types_in_vec;
@@ -1320,8 +1402,6 @@
           (paddle::experimental::NeedTransformPlace(
               place, input_backend, {}));
       if (need_trans) {
-        VLOG(6) << "need trans from " << place << " to "
-                << kernel_key.backend();
         // build memcopy op
         auto out_place = phi::TransToPhiPlace(input_backend);
         pir::Type out_type;
@@ -1528,33 +1608,6 @@ void AddShadowFeedOpForDataOrFeed(
   }
 }
 
-std::unique_ptr<OpYamlInfoParser> GetOpYamlInfoParser(pir::Operation* op) {
-  OpYamlInfoInterface op_info_interface = op->dyn_cast<OpYamlInfoInterface>();
-
-  std::unique_ptr<OpYamlInfoParser> op_info_parser(nullptr);
-  if (op_info_interface) {
-    op_info_parser = std::make_unique<OpYamlInfoParser>(
-        op_info_interface.GetOpInfo(), IsLegacyOp(op->name()));
-  }
-
-  return op_info_parser;
-}
-
-std::string GetKernelName(const OpYamlInfoParser* op_info_parser,
-                          pir::Operation* op_item) {
-  std::string kernel_fn_str;
-  if (op_info_parser != nullptr) {
-    kernel_fn_str = op_info_parser->OpRuntimeInfo().kernel_func;
-  }
-
-  if (op_item->isa<AddN_Op>() || op_item->isa<AddNWithKernelOp>()) {
-    if (op_item->result(0).type().isa<SelectedRowsType>()) {
-      kernel_fn_str = "add_n_sr";
-    }
-  }
-  return kernel_fn_str;
-}
-
 pir::Operation* BuildKernelOp(
     const std::string& kernel_fn_str,
     const phi::KernelKey& kernel_key,
