【AutoParallel】Unify the fp16 and bf16 in auto-parallel #60514

Merged
python/paddle/amp/amp_lists.py (2 changes: 2 additions, 0 deletions)
@@ -22,6 +22,8 @@
     'max_pool2d_with_index',
     'mul',
     'fused_gemm_epilogue',
+    "fused_rotary_position_embedding",
+    "flash_attn",
 }

 # The set of ops that support fp16, and bf16 was unsupported.
python/paddle/distributed/passes/auto_parallel_amp.py (99 changes: 29 additions, 70 deletions)
@@ -26,10 +26,6 @@
 )
 from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
 from paddle.framework import core
-from paddle.static.amp.bf16.amp_utils import (
-    AutoMixedPrecisionListsBF16,
-    _is_in_fp32_varnames,
-)
 from paddle.static.amp.fp16_utils import (
     AutoMixedPrecisionLists,
     _is_in_black_varnames,
@@ -88,48 +84,26 @@ def __init__(
         black_varnames=None,
         dtype="float16",
     ):
-        self._amp_list = None
-        if dtype == "float16":
-            self._amp_list = AutoMixedPrecisionLists(
-                set(white_list), set(black_list), set(black_varnames)
-            )
-        elif dtype == "bfloat16":
-            self._amp_list = AutoMixedPrecisionListsBF16(
-                set(white_list), set(black_list), set(black_varnames)
-            )
-
-        assert self._amp_list is not None
+        self._amp_list = AutoMixedPrecisionLists(
+            set(white_list), set(black_list), set(black_varnames), dtype=dtype
+        )
         self._dtype = dtype
-        self._is_float16 = dtype == "float16"

     @property
     def white_list(self):
-        if self._is_float16:
-            return self._amp_list.white_list
-        else:
-            return self._amp_list.bf16_list
+        return self._amp_list.white_list

     @property
     def black_list(self):
-        if self._is_float16:
-            return self._amp_list.black_list
-        else:
-            return self._amp_list.fp32_list
+        return self._amp_list.black_list

     @property
     def gray_list(self):
         return self._amp_list.gray_list

     @property
     def black_varnames(self):
-        if self._is_float16:
-            return self._amp_list.black_varnames
-        else:
-            return self._amp_list.fp32_varnames
-
-    @property
-    def is_fp16(self):
-        return self._is_float16
+        return self._amp_list.black_varnames

     @property
     def dtype(self):
@@ -140,36 +114,17 @@ def amp_list(self):
         return self._amp_list

     def _is_in_black_fp32_varnames(self, op):
-        if self._is_float16:
-            return _is_in_black_varnames(op, self._amp_list)
-        else:
-            return _is_in_fp32_varnames(op, self._amp_list)
+        return _is_in_black_varnames(op, self._amp_list)

     def _op_keep_fp32_input(self, op, in_name):
         if not op.amp_options.enable:
             return True
-        if self._is_float16:
-            return _keep_fp32_input(op, in_name)
-        else:
-            if op.type in ['batch_norm', 'layer_norm']:
-                return in_name != 'X'
-            if op.type == 'fused_bn_add_activation':
-                return in_name not in {'X', 'Z'}
-            return False
+        return _keep_fp32_input(op, in_name)

     def _op_keep_fp32_output(self, op, out_name):
         if not op.amp_options.enable:
             return True
-        if self._is_float16:
-            return _keep_fp32_output(op, out_name)
-        else:
-            if op.type in [
-                'batch_norm',
-                'fused_bn_add_activation',
-                'layer_norm',
-            ]:
-                return out_name != 'Y'
-            return False
+        return _keep_fp32_output(op, out_name)


 class AMPState:
@@ -324,12 +279,12 @@ def _cast_block(self, block):
                         self.dist_context,
                     )
                 elif self._is_fp16_op(op.desc.original_id()) is True:
-                    if self.amp_dtype == "bfloat16":
-                        if (
-                            op.has_attr('dtype')
-                            and op.attr('dtype') == core.VarDesc.VarType.FP32
@heavyrain-lzy (Contributor, Author) commented on Jan 5, 2024:

This part of the logic comes from the function rewrite_program_bf16:

    rewrite_program_bf16(self._train_program, self._amp_lists)

In static-graph single-machine AMP, the updated function _insert_cast_op already contains this logic:

    for attr_name in ['in_dtype', 'out_dtype', 'dtype']:

so as it stands there is no need to distinguish bfloat16 from float16 here (a minimal sketch of the unified rewrite follows this hunk).

-                        ):
-                            op._set_attr('dtype', core.VarDesc.VarType.BF16)
+                    # deal with op with attribute 'dtype', such as 'fill_constant'
+                    if (
+                        op.has_attr('dtype')
+                        and op.attr('dtype') == core.VarDesc.VarType.FP32
+                    ):
+                        op._set_attr('dtype', _str_to_dtype(self.amp_dtype))
                     num_cast_ops = self._insert_cast_op_forward(
                         block,
                         op,
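A minimal sketch of the unified dtype-attribute rewrite referred to in the comment above, assuming a static-graph op with has_attr / attr / _set_attr as used in this diff; the helper name rewrite_dtype_attrs and the _TARGET_DTYPE mapping are illustrative, not the pass's actual code (the pass itself calls a _str_to_dtype helper):

    # Sketch: rewrite FP32 dtype-like attributes to the single AMP target dtype,
    # so float16 and bfloat16 share one code path instead of branching.
    from paddle.framework import core

    # Assumed mapping from the amp dtype string to the VarType enum.
    _TARGET_DTYPE = {
        "float16": core.VarDesc.VarType.FP16,
        "bfloat16": core.VarDesc.VarType.BF16,
    }

    def rewrite_dtype_attrs(op, amp_dtype):
        # Same idea as the loop in the updated _insert_cast_op:
        # for attr_name in ['in_dtype', 'out_dtype', 'dtype']: ...
        for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
            if op.has_attr(attr_name) and op.attr(attr_name) == core.VarDesc.VarType.FP32:
                op._set_attr(attr_name, _TARGET_DTYPE[amp_dtype])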
@@ -362,16 +317,13 @@ def _cast_block(self, block):
                             self.dist_context,
                             appended_grad_times,
                         )
-                    elif (
-                        self._is_fp16_op(op.desc.original_id()) is True
-                    ):  # fp16/bf16
-                        if self.amp_dtype == "bfloat16":
-                            if (
-                                op.has_attr('dtype')
-                                and op.attr('dtype')
-                                == core.VarDesc.VarType.FP32
-                            ):
-                                op._set_attr('dtype', core.VarDesc.VarType.BF16)
+                    elif self._is_fp16_op(op.desc.original_id()) is True:
+                        # deal with op with attribute 'dtype', such as 'fill_constant'
+                        if (
+                            op.has_attr('dtype')
+                            and op.attr('dtype') == core.VarDesc.VarType.FP32
+                        ):
+                            op._set_attr('dtype', _str_to_dtype(self.amp_dtype))
                         num_cast_ops = self._insert_cast_op_backward(
                             block,
                             op,
@@ -522,6 +474,7 @@ def _insert_cast_op_forward(
                     op._set_attr(
                         'out_dtype', _str_to_dtype(self.amp_dtype)
                     )
+
         return num_cast_ops

     def _insert_cast_op_backward(
@@ -699,6 +652,12 @@ def _keep_fp32_output(op, out_name):
             else:
                 assert out_var.dtype == dst_dtype

+            if (
+                op.has_attr('dtype')
+                and op.attr('dtype') == core.VarDesc.VarType.FP32
+            ):
+                op._set_attr('dtype', _str_to_dtype(self.amp_dtype))
+
         return num_cast_ops


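The changes below make the fp16 pass construct its op lists the same way: both precisions go through paddle.static.amp.fp16_utils.AutoMixedPrecisionLists, selected only by the dtype argument. A small usage sketch, mirroring the positional call shape used in these passes (the white/black-list op names are placeholders):

    from paddle.static.amp.fp16_utils import AutoMixedPrecisionLists

    # One list class for both precisions; only the dtype string differs.
    fp16_lists = AutoMixedPrecisionLists({'flash_attn'}, {'reduce_sum'}, None, dtype="float16")
    bf16_lists = AutoMixedPrecisionLists({'flash_attn'}, {'reduce_sum'}, None, dtype="bfloat16")

    # The AMP pass then reads .white_list / .black_list / .black_varnames
    # uniformly, instead of switching to .bf16_list / .fp32_list for bfloat16.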
python/paddle/distributed/passes/auto_parallel_fp16.py (11 changes: 3 additions, 8 deletions)
@@ -16,6 +16,7 @@
 from collections import defaultdict

 import paddle
+import paddle.static.amp.fp16_utils as amp_utils
 from paddle.common_ops_import import check_type, check_variable_and_dtype
 from paddle.distributed.auto_parallel.static.dist_attribute import (
     OperatorDistAttr,
@@ -831,19 +832,12 @@ def _apply_single_impl(self, main_program, startup_program, context):
         if self.use_optimizer_fp16 is None:
             self.use_optimizer_fp16 = self.get_attr("level", None) == "o3"

+        AMPList = amp_utils.AutoMixedPrecisionLists
         # swith enviroment for fp16 / bf16.
         if self.target_dtype == "float16":
-            import paddle.static.amp.fp16_utils as amp_utils
-
-            AMPList = amp_utils.AutoMixedPrecisionLists
             __target_dtype = core.VarDesc.VarType.FP16
-
         elif self.target_dtype == "bfloat16":
-            from paddle.static.amp.bf16 import amp_utils
-
-            AMPList = amp_utils.AutoMixedPrecisionListsBF16
             __target_dtype = core.VarDesc.VarType.BF16
-
         else:
             raise NotImplementedError(
                 f"target dtype [{self.target_dtype}] is for amp o2 not supported yet."
@@ -856,6 +850,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
                 set(self.get_attr("custom_white_list")),
                 set(self.get_attr("custom_black_list")),
                 None,
+                dtype=self.target_dtype,
             )

         # NOTE don't not change input data dtype, since it is controled by dataloader