Skip to content

Commit a685387

Browse files
HydrogenSulfatefangfangssj
authored andcommitted
Update pop/push instruction for nulltype (PaddlePaddle#71133)
* update pop/push instruction * delete std::cout
1 parent db6f726 commit a685387

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,12 @@ TuplePopInstruction::TuplePopInstruction(size_t id,
4141
std::unordered_map<pir::Value, std::vector<int>> outputs;
4242
for (size_t i = 0; i < tuple_pop_op_.tuple_size(); ++i) {
4343
auto outlet_element_value = tuple_pop_op_.outlet_element(i);
44-
outputs.emplace(outlet_element_value,
45-
GetValueIds(outlet_element_value, *value_exe_info_));
44+
if (outlet_element_value.type()) {
45+
outputs.emplace(outlet_element_value,
46+
GetValueIds(outlet_element_value, *value_exe_info_));
47+
} else {
48+
outputs.emplace(outlet_element_value, std::vector<int>{});
49+
}
4650
}
4751

4852
// NOTE(zhangbo): TuplePop will change the variables corresponding to the
@@ -136,6 +140,9 @@ void TuplePopInstruction::Run() {
136140
VLOG(6) << "pop back var: " << front_var;
137141
auto outlet_element_value = tuple_pop_op_.outlet_element(i);
138142
auto grad_var = value_exe_info_->GetVarByValue(outlet_element_value);
143+
if (front_var == nullptr && grad_var == nullptr) {
144+
continue;
145+
}
139146
ShareVarData(front_var, grad_var);
140147
Variable* gc_front_var = const_cast<Variable*>(front_var);
141148
AddEagerGCVar(gc_front_var);

paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.cc

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,12 @@ TuplePushInstruction::TuplePushInstruction(size_t id,
6565
std::unordered_map<pir::Value, std::vector<int>> inputs;
6666
for (size_t i = 0; i < tuple_push_op_.tuple_size(); ++i) {
6767
auto inlet_element_value = tuple_push_op_.inlet_element(i);
68-
inputs.emplace(inlet_element_value,
69-
GetValueIds(inlet_element_value, *value_exe_info_));
68+
if (inlet_element_value.type()) {
69+
inputs.emplace(inlet_element_value,
70+
GetValueIds(inlet_element_value, *value_exe_info_));
71+
} else {
72+
inputs.emplace(inlet_element_value, std::vector<int>{});
73+
}
7074
}
7175
SetInputs(inputs);
7276

@@ -99,14 +103,22 @@ void TuplePushInstruction::Run() {
99103
for (size_t i = 0; i < tuple_push_op_.tuple_size(); i++) {
100104
auto inlet_element_value = tuple_push_op_.inlet_element(i);
101105
Variable* var = value_exe_info_->GetVarByValue(inlet_element_value);
102-
103-
auto var_name = value_2_var_name.at(inlet_element_value);
106+
bool is_optional = (inlet_element_value.impl() == nullptr ||
107+
!inlet_element_value.type());
104108
auto num_str = std::to_string(stack_element_var_array_->size());
109+
if (!value_2_var_name.count(inlet_element_value)) {
110+
if (!is_optional) {
111+
PADDLE_THROW(common::errors::PermissionDenied(
112+
"Cannot find corresbonding DenseTensor"));
113+
}
114+
stack_element_var_array_->emplace_back(nullptr);
115+
continue;
116+
}
117+
std::string var_name = value_2_var_name.at(inlet_element_value);
105118
std::string new_name = var_name + "_copied_" + num_str + "_in_tuple_" +
106119
std::to_string(op_->id());
107120
auto* copy_var = value_exe_info_->GetScope()->Var(new_name);
108-
bool is_optional = (inlet_element_value.impl() == nullptr ||
109-
!inlet_element_value.type());
121+
110122
DeepCopyVariable(var,
111123
&copy_var,
112124
value_exe_info_,

paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ pir::Type ConvertOpTypeToKernelType(pir::IrContext* ctx,
101101
ConvertOpTypeToKernelType(ctx, vec_type[i], place));
102102
}
103103
return pir::VectorType::get(ctx, vec_target_type);
104+
} else if (!op_type) {
105+
return pir::Type();
104106
}
105107
PADDLE_THROW(common::errors::Unimplemented(
106108
"Not support op type %s in ConvertOpTypeToKernelType.", op_type));
@@ -1731,6 +1733,8 @@ void AddShadowFeedForValue(
17311733
block->push_back(shadow_tensors_op);
17321734
(*map_op_pair)[op_item] = shadow_tensors_op;
17331735
(*map_value_pair)[op_item->result(index)] = shadow_tensors_op->result(0);
1736+
} else if (!op_item->result(index).type()) {
1737+
return;
17341738
} else {
17351739
PADDLE_THROW(
17361740
common::errors::Unimplemented("AddShadowFeed for value only support "

0 commit comments

Comments
 (0)