@@ -300,10 +300,10 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
300
300
VLOG (4 ) << " SetVarPrecision done" ;
301
301
ConvertWeightsData ();
302
302
VLOG (4 ) << " ConvertWeightsData done" ;
303
- ProcessOpWithDtypeAttr ();
304
- VLOG (4 ) << " ProcessOpWithDtypeAttr done" ;
305
303
InsertCastOp ();
306
304
VLOG (4 ) << " InsertCastOp done" ;
305
+ ProcessOpWithDtypeAttr ();
306
+ VLOG (4 ) << " ProcessOpWithDtypeAttr done" ;
307
307
RestoreOpOriginType ();
308
308
VLOG (4 ) << " RestoreOpOriginType done" ;
309
309
LOG (INFO) << " The number of ops run at low precision ["
@@ -355,7 +355,9 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
355
355
356
356
if (op_node->Op ()->HasAttr (" in_dtype" )) {
357
357
auto * var_node = op_node->inputs [0 ];
358
- auto * real_var_node = real_vars_[var_node->Var ()->Name ()];
358
+ auto * real_var_node = real_vars_.count (var_node->Var ()->Name ())
359
+ ? real_vars_.at (var_node->Var ()->Name ())
360
+ : var_node;
359
361
if (IsFP16AndBFP16 (real_var_node->Var ()->GetDataType ())) {
360
362
op_node->Op ()->SetAttr (
361
363
" in_dtype" ,
@@ -455,7 +457,7 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
455
457
// not run at low precision.
456
458
for (auto * in_var_node : op_node->inputs ) {
457
459
CHECK_EQ (in_var_node->IsVar (), true );
458
- auto * real_in_var_node = real_vars_[ in_var_node->Var ()->Name ()] ;
460
+ auto * real_in_var_node = real_vars_. at ( in_var_node->Var ()->Name ()) ;
459
461
if (real_in_var_node->Var ()->Persistable ()) continue ;
460
462
461
463
support_low_precision =
@@ -464,7 +466,7 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
464
466
}
465
467
for (auto * out_var_node : op_node->outputs ) {
466
468
CHECK_EQ (out_var_node->IsVar (), true );
467
- auto * real_out_var_node = real_vars_[ out_var_node->Var ()->Name ()] ;
469
+ auto * real_out_var_node = real_vars_. at ( out_var_node->Var ()->Name ()) ;
468
470
if (real_out_var_node->Var ()->Persistable ()) continue ;
469
471
470
472
support_low_precision =
@@ -554,7 +556,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
554
556
CHECK_EQ (in_var_node->IsVar (), true );
555
557
if (!VarNodeHasDtype (in_var_node)) continue ;
556
558
557
- auto * real_in_var_node = real_vars_[ in_var_node->Var ()->Name ()] ;
559
+ auto * real_in_var_node = real_vars_. at ( in_var_node->Var ()->Name ()) ;
558
560
if (real_in_var_node->Var ()->Persistable ()) continue ;
559
561
560
562
if (vars_should_not_low_precision.count (
@@ -573,7 +575,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
573
575
CHECK_EQ (out_var_node->IsVar (), true );
574
576
if (!VarNodeHasDtype (out_var_node)) continue ;
575
577
576
- auto * real_out_var_node = real_vars_[ out_var_node->Var ()->Name ()] ;
578
+ auto * real_out_var_node = real_vars_. at ( out_var_node->Var ()->Name ()) ;
577
579
if (real_out_var_node->Var ()->Persistable ()) continue ;
578
580
579
581
bool not_run_low_precision = false ;
@@ -742,7 +744,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
742
744
for (auto * in_var_node : op_node->inputs ) {
743
745
CHECK_EQ (in_var_node->IsVar (), true );
744
746
745
- auto * real_in_var_node = real_vars_[ in_var_node->Var ()->Name ()] ;
747
+ auto * real_in_var_node = real_vars_. at ( in_var_node->Var ()->Name ()) ;
746
748
auto in_var_name = real_in_var_node->Var ()->Name ();
747
749
748
750
if (!IsFP32AndFP64 (real_in_var_node->Var ()->GetDataType ())) continue ;
@@ -761,7 +763,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
761
763
for (auto * out_var_node : op_node->outputs ) {
762
764
CHECK_EQ (out_var_node->IsVar (), true );
763
765
764
- auto * real_out_var_node = real_vars_[ out_var_node->Var ()->Name ()] ;
766
+ auto * real_out_var_node = real_vars_. at ( out_var_node->Var ()->Name ()) ;
765
767
auto out_var_name = real_out_var_node->Var ()->Name ();
766
768
767
769
if (!IsFP32AndFP64 (real_out_var_node->Var ()->GetDataType ())) continue ;
@@ -877,7 +879,7 @@ void AutoMixedPrecisionPass::InsertCastOp() const {
877
879
if (!VarNodeHasDtype (in_var_node)) continue ;
878
880
if (in_var_node->Var ()->Persistable ()) continue ;
879
881
880
- auto * real_in_var_node = real_vars_[ in_var_node->Var ()->Name ()] ;
882
+ auto * real_in_var_node = real_vars_. at ( in_var_node->Var ()->Name ()) ;
881
883
882
884
auto in_var_type = real_in_var_node->Var ()->GetDataType ();
883
885
0 commit comments