@@ -82,20 +82,10 @@ class AutoMixedPrecisionPass : public pir::Pass {
82
82
for (size_t i = 0 ; i < op->num_regions (); ++i) {
83
83
auto & region = op->region (i);
84
84
for (auto & block : region) {
85
- VLOG (6 ) << " ===========Get Op Precision============" << std::endl;
86
85
GetOpPrecision (&block);
87
- VLOG (6 ) << " ===========Update Op Precision============" << std::endl;
88
86
UpdateOpPrecision (&block);
89
-
90
- VLOG (6 ) << " ===========" << op_run_low_precision_.size () << " of "
91
- << block.size () << " ops"
92
- << " run low precision" << std::endl;
93
87
pir::Builder builder = pir::Builder (context_, &block);
94
- VLOG (6 ) << " ===========Process Op Precision============" << std::endl;
95
-
96
88
ProcessBlock (&block, builder);
97
- VLOG (6 ) << " ===========Insert Cast Op Num : " << insert_cast_op_num_
98
- << " ============" << std::endl;
99
89
}
100
90
}
101
91
}
@@ -144,7 +134,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
144
134
void GetOpPrecision (pir::Block* block) {
145
135
for (auto & op_item : *block) {
146
136
auto op = &op_item;
147
- VLOG (6 ) << " op name " << op->name ();
148
137
auto op_name = op->name ();
149
138
bool support_low_precision = true ;
150
139
if (black_list_.count (op_name)) {
@@ -167,10 +156,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
167
156
}
168
157
if (support_low_precision) {
169
158
op_run_low_precision_.insert (op);
170
- VLOG (6 ) << " op " << op->name () << " support low precision" << std::endl;
171
- } else {
172
- VLOG (6 ) << " op " << op->name () << " doesn't support low precision"
173
- << std::endl;
174
159
}
175
160
}
176
161
}
@@ -235,8 +220,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
235
220
}
236
221
if (!OpRunLowPrecision (op)) continue ;
237
222
if (CheckOutputIsScalarAttribute (op)) { // Output is ScalarAttribute
238
- VLOG (6 ) << " op " << op->name () << " output is ScalarAttribute"
239
- << std::endl;
240
223
op_run_low_precision_.erase (op);
241
224
precision_updated = true ;
242
225
}
@@ -261,21 +244,10 @@ class AutoMixedPrecisionPass : public pir::Pass {
261
244
}
262
245
}
263
246
} while (precision_updated);
264
- for (auto & op_item : *block) {
265
- auto op = &op_item;
266
- if (op_should_not_handle_.count (op)) {
267
- VLOG (6 ) << " op " << op->name () << " should not be handled" << std::endl;
268
- } else if (op_run_low_precision_.count (op)) {
269
- VLOG (6 ) << " op " << op->name () << " run low precision" << std::endl;
270
- } else {
271
- VLOG (6 ) << " op " << op->name () << " run high precision" << std::endl;
272
- }
273
- }
274
247
}
275
248
276
249
void RewriteOp (pir::Operation* op,
277
250
pir::Builder& builder) { // NOLINT
278
- VLOG (6 ) << " Rewrite op " << op->name () << std::endl;
279
251
if (IsBuiltinOp (op)) {
280
252
RewriteBuiltinOp (op, builder);
281
253
return ;
@@ -318,7 +290,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
318
290
phi::DataType precision,
319
291
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) const {
320
292
auto & phi_op_type = op_type;
321
- VLOG (6 ) << " phi_op_type = " << phi_op_type << std::endl;
322
293
323
294
bool support =
324
295
PhiKernelSupportPrecision (phi_op_type, backend, precision, layout);
@@ -419,8 +390,8 @@ class AutoMixedPrecisionPass : public pir::Pass {
419
390
auto new_vec_type = pir::VectorType::get (context, results_type);
420
391
result.set_type (new_vec_type);
421
392
} else {
422
- VLOG ( 6 ) << " result type is not DenseTensorType or VectorType "
423
- << std::endl ;
393
+ PADDLE_THROW ( phi::errors::Unimplemented (
394
+ " result type is not DenseTensorType or VectorType " )) ;
424
395
}
425
396
}
426
397
@@ -452,7 +423,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
452
423
IsVectorTypeFloat (result.type ().dyn_cast <pir::VectorType>())) {
453
424
}
454
425
}
455
- VLOG (6 ) << " op " << op->name () << " doesn't have float result" << std::endl;
456
426
return false ;
457
427
}
458
428
@@ -517,10 +487,8 @@ class AutoMixedPrecisionPass : public pir::Pass {
517
487
518
488
void RewriteBuiltinOp (pir::Operation* op,
519
489
pir::Builder& builder) { // NOLINT
520
- VLOG (6 ) << " Rewrite builtin op " << op->name () << std::endl;
521
490
// Rewrite CombineOp
522
491
if (op->isa <pir::CombineOp>()) {
523
- // auto vec_type = op->result(0).type().dyn_cast<pir::VectorType>();
524
492
auto input_num = op->num_operands ();
525
493
if (OpRunLowPrecision (op)) {
526
494
for (size_t i = 0 ; i < input_num; ++i) {
@@ -572,10 +540,8 @@ class AutoMixedPrecisionPass : public pir::Pass {
572
540
573
541
void RewritePdOp (pir::Operation* op,
574
542
pir::Builder& builder) { // NOLINT
575
- VLOG (6 ) << " Rewrite pd op " << op->name () << std::endl;
576
- phi::Backend backend = ConvertPlaceToBackend (place_);
577
543
std::string op_type = op->name ().substr (op->name ().find (" ." ) + 1 );
578
-
544
+ phi::Backend backend = ConvertPlaceToBackend (place_);
579
545
// Rewrite FetchOp
580
546
if (op->isa <paddle::dialect::FetchOp>()) {
581
547
auto fetch_operand = op->operand (0 );
@@ -587,7 +553,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
587
553
auto result_dtype = paddle::dialect::TransToPhiDataType (
588
554
pir::GetDataTypeFromValue (op->result (0 )));
589
555
if (fetch_operand_phi_dtype != result_dtype) {
590
- VLOG (6 ) << " Insert CastOp for FetchOp" << std::endl;
591
556
DoInsertCastOp (op, fetch_operand, result_dtype, builder);
592
557
}
593
558
return ;
@@ -607,9 +572,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
607
572
// Other pd ops
608
573
if (OpRunLowPrecision (op)) {
609
574
// change result's dtype to low precision
610
- VLOG (6 ) << " Change result's dtype to low precision " << op->name ()
611
- << std::endl;
612
-
613
575
if (op->HasAttribute (" dtype" ) &&
614
576
IsPhiDataTypeFloat (
615
577
op->attribute <paddle::dialect::DataTypeAttribute>(" dtype" )
@@ -644,8 +606,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
644
606
auto result = op->result (i);
645
607
if (!result.type ()) continue ;
646
608
phi::DataType out_phi_dtype = output_defs[i].dtype ;
647
- VLOG (6 ) << " result dtype = " << phi::DataTypeToString (out_phi_dtype)
648
- << std::endl;
649
609
if (out_phi_dtype == phi::DataType::UNDEFINED)
650
610
out_phi_dtype = precision_mode_;
651
611
if (!IsPhiDataTypeFloat (out_phi_dtype))
@@ -663,8 +623,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
663
623
auto operand_phi_dtype = GetPhiDataTypeFromOpOperand (operand);
664
624
if (IsPhiDataTypeFloat (operand_phi_dtype) &&
665
625
operand_phi_dtype != in_phi_dtype) {
666
- VLOG (6 ) << " Support low precision, insert CastOp for " << op->name ()
667
- << " operand " << i << std::endl;
668
626
DoInsertCastOp (op, operand, in_phi_dtype, builder);
669
627
}
670
628
}
@@ -677,8 +635,6 @@ class AutoMixedPrecisionPass : public pir::Pass {
677
635
auto operand_phi_dtype = GetPhiDataTypeFromOpOperand (operand);
678
636
if (IsPhiDataTypeFloat (operand_phi_dtype) &&
679
637
operand_phi_dtype == precision_mode_) {
680
- VLOG (6 ) << " Not support low precision, insert CastOp for "
681
- << op->name () << " operand " << i << std::endl;
682
638
DoInsertCastOp (op, operand, phi_dtype, builder);
683
639
}
684
640
}
0 commit comments