Skip to content

Commit 1387c46

Browse files
yuanlehomeBeingGod
authored andcommitted
1 parent 75669e3 commit 1387c46

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

paddle/fluid/framework/ir/auto_mixed_precision_pass.cc

+12-10
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,10 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
300300
VLOG(4) << "SetVarPrecision done";
301301
ConvertWeightsData();
302302
VLOG(4) << "ConvertWeightsData done";
303-
ProcessOpWithDtypeAttr();
304-
VLOG(4) << "ProcessOpWithDtypeAttr done";
305303
InsertCastOp();
306304
VLOG(4) << "InsertCastOp done";
305+
ProcessOpWithDtypeAttr();
306+
VLOG(4) << "ProcessOpWithDtypeAttr done";
307307
RestoreOpOriginType();
308308
VLOG(4) << "RestoreOpOriginType done";
309309
LOG(INFO) << "The number of ops run at low precision ["
@@ -355,7 +355,9 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
355355

356356
if (op_node->Op()->HasAttr("in_dtype")) {
357357
auto* var_node = op_node->inputs[0];
358-
auto* real_var_node = real_vars_[var_node->Var()->Name()];
358+
auto* real_var_node = real_vars_.count(var_node->Var()->Name())
359+
? real_vars_.at(var_node->Var()->Name())
360+
: var_node;
359361
if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) {
360362
op_node->Op()->SetAttr(
361363
"in_dtype",
@@ -455,7 +457,7 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
455457
// not run at low precision.
456458
for (auto* in_var_node : op_node->inputs) {
457459
CHECK_EQ(in_var_node->IsVar(), true);
458-
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
460+
auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
459461
if (real_in_var_node->Var()->Persistable()) continue;
460462

461463
support_low_precision =
@@ -464,7 +466,7 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
464466
}
465467
for (auto* out_var_node : op_node->outputs) {
466468
CHECK_EQ(out_var_node->IsVar(), true);
467-
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
469+
auto* real_out_var_node = real_vars_.at(out_var_node->Var()->Name());
468470
if (real_out_var_node->Var()->Persistable()) continue;
469471

470472
support_low_precision =
@@ -554,7 +556,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
554556
CHECK_EQ(in_var_node->IsVar(), true);
555557
if (!VarNodeHasDtype(in_var_node)) continue;
556558

557-
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
559+
auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
558560
if (real_in_var_node->Var()->Persistable()) continue;
559561

560562
if (vars_should_not_low_precision.count(
@@ -573,7 +575,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
573575
CHECK_EQ(out_var_node->IsVar(), true);
574576
if (!VarNodeHasDtype(out_var_node)) continue;
575577

576-
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
578+
auto* real_out_var_node = real_vars_.at(out_var_node->Var()->Name());
577579
if (real_out_var_node->Var()->Persistable()) continue;
578580

579581
bool not_run_low_precision = false;
@@ -742,7 +744,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
742744
for (auto* in_var_node : op_node->inputs) {
743745
CHECK_EQ(in_var_node->IsVar(), true);
744746

745-
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
747+
auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
746748
auto in_var_name = real_in_var_node->Var()->Name();
747749

748750
if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
@@ -761,7 +763,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
761763
for (auto* out_var_node : op_node->outputs) {
762764
CHECK_EQ(out_var_node->IsVar(), true);
763765

764-
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
766+
auto* real_out_var_node = real_vars_.at(out_var_node->Var()->Name());
765767
auto out_var_name = real_out_var_node->Var()->Name();
766768

767769
if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue;
@@ -877,7 +879,7 @@ void AutoMixedPrecisionPass::InsertCastOp() const {
877879
if (!VarNodeHasDtype(in_var_node)) continue;
878880
if (in_var_node->Var()->Persistable()) continue;
879881

880-
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
882+
auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name());
881883

882884
auto in_var_type = real_in_var_node->Var()->GetDataType();
883885

paddle/fluid/framework/ir/identity_op_clean_pass.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ int IdentityOpCleanPass::CleanTwoCastOp(ir::Graph* graph) const {
201201

202202
void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
203203
Init(name_scope_, graph);
204-
int found_count = CleanUselessOp(graph) + CleanTwoCastOp(graph);
204+
int found_count = CleanUselessOp(graph);
205+
found_count += CleanTwoCastOp(graph);
205206
AddStatis(found_count);
206207
}
207208

0 commit comments

Comments
 (0)