Skip to content

Commit e6464f3

Browse files
fix dtype mismatch error (#53712) (#53764)
Pcard-70458 cherry-pick #53712
1 parent 20fbafe commit e6464f3

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

python/paddle/static/amp/fp16_utils.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -629,12 +629,15 @@ def cast_model_to_fp16(
629629

630630
def need_process(op):
631631
need_process = True
632-
if op.type in ["cast", "create_py_reader", "read"]:
632+
if op.type in ["create_py_reader", "read"]:
633633
need_process = False
634634
else:
635635
for attr_name in ['out_dtype', 'dtype']:
636-
if op.has_attr(attr_name) and is_float_dtype(
637-
op.attr(attr_name)
636+
# output type of some operators such as fill_constant will be determined by the attribute value.
637+
#
638+
if not op.has_attr('in_dtype') and (
639+
op.has_attr(attr_name)
640+
and is_float_dtype(op.attr(attr_name))
638641
):
639642
need_process = False
640643

@@ -667,6 +670,24 @@ def need_process(op):
667670
"---- Add into keep_fp16_ops because the op in white_list ----"
668671
)
669672
else:
673+
# If a cast op is already in the original program, we only modify its attr and its output's dtype to avoid dtype mismatch errors.
674+
if op.type == 'cast':
675+
in_var = block._find_var_recursive(op.input('X')[0])
676+
out_var = block._find_var_recursive(op.output('Out')[0])
677+
op._set_attr('in_dtype', in_var.dtype)
678+
out_var.desc.set_dtype(paddle.dtype(op.attr('out_dtype')))
679+
_logger.debug(
680+
"---- op type: {}, in var [name: {} dtype: {}], out var [name: {} dtype: {}], attr [in_dtype {} out_dtype {}] ----".format(
681+
op.type,
682+
op.input('X')[0],
683+
in_var.dtype,
684+
op.output('Out')[0],
685+
out_var.dtype,
686+
op.attr('in_dtype'),
687+
op.attr('out_dtype'),
688+
)
689+
)
690+
continue
670691
# divide others ops into fp16/fp32 sets according to promoting principle.
671692
dst_dtype = dest_type
672693
if not use_promote:

0 commit comments

Comments
 (0)