From 2b4dbde312fb8ff3105bc4d2487f6a9e16b22021 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Tue, 2 Jan 2024 22:36:42 +0800 Subject: [PATCH 1/4] unify the fp16 and bf16 --- .../distributed/passes/auto_parallel_amp.py | 76 +++---------------- .../distributed/passes/auto_parallel_fp16.py | 10 +-- 2 files changed, 11 insertions(+), 75 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index ac533db098c619..986b3fa41b1316 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -26,10 +26,6 @@ ) from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.framework import core -from paddle.static.amp.bf16.amp_utils import ( - AutoMixedPrecisionListsBF16, - _is_in_fp32_varnames, -) from paddle.static.amp.fp16_utils import ( AutoMixedPrecisionLists, _is_in_black_varnames, @@ -88,33 +84,18 @@ def __init__( black_varnames=None, dtype="float16", ): - self._amp_list = None - if dtype == "float16": - self._amp_list = AutoMixedPrecisionLists( - set(white_list), set(black_list), set(black_varnames) - ) - elif dtype == "bfloat16": - self._amp_list = AutoMixedPrecisionListsBF16( - set(white_list), set(black_list), set(black_varnames) - ) - - assert self._amp_list is not None + self._amp_list = AutoMixedPrecisionLists( + set(white_list), set(black_list), set(black_varnames) + ) self._dtype = dtype - self._is_float16 = dtype == "float16" @property def white_list(self): - if self._is_float16: - return self._amp_list.white_list - else: - return self._amp_list.bf16_list + return self._amp_list.white_list @property def black_list(self): - if self._is_float16: - return self._amp_list.black_list - else: - return self._amp_list.fp32_list + return self._amp_list.black_list @property def gray_list(self): @@ -122,14 +103,7 @@ def gray_list(self): @property def black_varnames(self): - if self._is_float16: - return self._amp_list.black_varnames - else: - return self._amp_list.fp32_varnames - - @property - def is_fp16(self): - return self._is_float16 + return self._amp_list.black_varnames @property def dtype(self): @@ -140,36 +114,17 @@ def amp_list(self): return self._amp_list def _is_in_black_fp32_varnames(self, op): - if self._is_float16: - return _is_in_black_varnames(op, self._amp_list) - else: - return _is_in_fp32_varnames(op, self._amp_list) + return _is_in_black_varnames(op, self._amp_list) def _op_keep_fp32_input(self, op, in_name): if not op.amp_options.enable: return True - if self._is_float16: - return _keep_fp32_input(op, in_name) - else: - if op.type in ['batch_norm', 'layer_norm']: - return in_name != 'X' - if op.type == 'fused_bn_add_activation': - return in_name not in {'X', 'Z'} - return False + return _keep_fp32_input(op, in_name) def _op_keep_fp32_output(self, op, out_name): if not op.amp_options.enable: return True - if self._is_float16: - return _keep_fp32_output(op, out_name) - else: - if op.type in [ - 'batch_norm', - 'fused_bn_add_activation', - 'layer_norm', - ]: - return out_name != 'Y' - return False + return _keep_fp32_output(op, out_name) class AMPState: @@ -319,12 +274,6 @@ def _cast_block(self, block): self.dist_context, ) elif self._is_fp16_op(op.desc.original_id()) is True: - if self.amp_dtype == "bfloat16": - if ( - op.has_attr('dtype') - and op.attr('dtype') == core.VarDesc.VarType.FP32 - ): - op._set_attr('dtype', core.VarDesc.VarType.BF16) num_cast_ops = 
self._insert_cast_op_forward( block, op, @@ -360,13 +309,6 @@ def _cast_block(self, block): elif ( self._is_fp16_op(op.desc.original_id()) is True ): # fp16/bf16 - if self.amp_dtype == "bfloat16": - if ( - op.has_attr('dtype') - and op.attr('dtype') - == core.VarDesc.VarType.FP32 - ): - op._set_attr('dtype', core.VarDesc.VarType.BF16) num_cast_ops = self._insert_cast_op_backward( block, op, diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 2985b4da290f40..7e182c4326e103 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -16,6 +16,7 @@ from collections import defaultdict import paddle +import paddle.static.amp.fp16_utils as amp_utils from paddle.common_ops_import import check_type, check_variable_and_dtype from paddle.distributed.auto_parallel.static.dist_attribute import ( OperatorDistAttr, @@ -815,19 +816,12 @@ def _apply_single_impl(self, main_program, startup_program, context): if self.use_optimizer_fp16 is None: self.use_optimizer_fp16 = self.get_attr("level", None) == "o3" + AMPList = amp_utils.AutoMixedPrecisionLists # swith enviroment for fp16 / bf16. if self.target_dtype == "float16": - import paddle.static.amp.fp16_utils as amp_utils - - AMPList = amp_utils.AutoMixedPrecisionLists __target_dtype = core.VarDesc.VarType.FP16 - elif self.target_dtype == "bfloat16": - from paddle.static.amp.bf16 import amp_utils - - AMPList = amp_utils.AutoMixedPrecisionListsBF16 __target_dtype = core.VarDesc.VarType.BF16 - else: raise NotImplementedError( f"target dtype [{self.target_dtype}] is for amp o2 not supported yet." From 089a89aa59119027db643a422d3a7390f11b2d8d Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Wed, 3 Jan 2024 21:02:07 +0800 Subject: [PATCH 2/4] change white_list in AMP --- python/paddle/amp/amp_lists.py | 2 ++ python/paddle/distributed/passes/auto_parallel_amp.py | 2 +- python/paddle/distributed/passes/auto_parallel_fp16.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/paddle/amp/amp_lists.py b/python/paddle/amp/amp_lists.py index b4b4fc95cb0499..a1687b33ac711a 100644 --- a/python/paddle/amp/amp_lists.py +++ b/python/paddle/amp/amp_lists.py @@ -22,6 +22,8 @@ 'max_pool2d_with_index', 'mul', 'fused_gemm_epilogue', + "fused_rotary_position_embedding", + "flash_attn", } # The set of ops that support fp16, and bf16 was unsupported. 
diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index 986b3fa41b1316..ad005ac43965bc 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -85,7 +85,7 @@ def __init__( dtype="float16", ): self._amp_list = AutoMixedPrecisionLists( - set(white_list), set(black_list), set(black_varnames) + set(white_list), set(black_list), set(black_varnames), dtype=dtype ) self._dtype = dtype diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 7e182c4326e103..448f916b2de352 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -834,6 +834,7 @@ def _apply_single_impl(self, main_program, startup_program, context): set(self.get_attr("custom_white_list")), set(self.get_attr("custom_black_list")), None, + dtype=self.target_dtype, ) # NOTE don't not change input data dtype, since it is controled by dataloader From c9b8be325df1ad256b5ea39dc94d74067d7cf005 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Sun, 7 Jan 2024 15:58:29 +0800 Subject: [PATCH 3/4] add dtype support --- .../paddle/distributed/passes/auto_parallel_amp.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index 4b43d1a9add0af..325b05bf919c30 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -464,6 +464,13 @@ def _insert_cast_op_forward( op._set_attr( 'out_dtype', _str_to_dtype(self.amp_dtype) ) + + if ( + op.has_attr('dtype') + and op.attr('dtype') == core.VarDesc.VarType.FP32 + ): + op._set_attr('dtype', _str_to_dtype(self.amp_dtype)) + return num_cast_ops def _insert_cast_op_backward( @@ -641,6 +648,12 @@ def _keep_fp32_output(op, out_name): else: assert out_var.dtype == dst_dtype + if ( + op.has_attr('dtype') + and op.attr('dtype') == core.VarDesc.VarType.FP32 + ): + op._set_attr('dtype', _str_to_dtype(self.amp_dtype)) + return num_cast_ops From 112553b7f70b654dc39096ddd0d066973a257a23 Mon Sep 17 00:00:00 2001 From: lzydev <1528794076@qq.com> Date: Mon, 8 Jan 2024 11:25:01 +0800 Subject: [PATCH 4/4] fix bug in dtype --- .../distributed/passes/auto_parallel_amp.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index 325b05bf919c30..36f5116a5870e6 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -279,6 +279,12 @@ def _cast_block(self, block): self.dist_context, ) elif self._is_fp16_op(op.desc.original_id()) is True: + # deal with op with attribute 'dtype', such as 'fill_constant' + if ( + op.has_attr('dtype') + and op.attr('dtype') == core.VarDesc.VarType.FP32 + ): + op._set_attr('dtype', _str_to_dtype(self.amp_dtype)) num_cast_ops = self._insert_cast_op_forward( block, op, @@ -311,9 +317,13 @@ def _cast_block(self, block): self.dist_context, appended_grad_times, ) - elif ( - self._is_fp16_op(op.desc.original_id()) is True - ): # fp16/bf16 + elif self._is_fp16_op(op.desc.original_id()) is True: + # deal with op with attribute 'dtype', such as 'fill_constant' + if ( + op.has_attr('dtype') + and 
op.attr('dtype') == core.VarDesc.VarType.FP32 + ): + op._set_attr('dtype', _str_to_dtype(self.amp_dtype)) num_cast_ops = self._insert_cast_op_backward( block, op, @@ -465,12 +475,6 @@ def _insert_cast_op_forward( 'out_dtype', _str_to_dtype(self.amp_dtype) ) - if ( - op.has_attr('dtype') - and op.attr('dtype') == core.VarDesc.VarType.FP32 - ): - op._set_attr('dtype', _str_to_dtype(self.amp_dtype)) - return num_cast_ops def _insert_cast_op_backward(
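
Note on the recurring pattern: patch 2 collapses the fp16/bf16 split by constructing a single AutoMixedPrecisionLists with a dtype argument, and patches 3 and 4 converge on one rule for ops that carry a 'dtype' attribute (such as fill_constant): if the attribute still reports FP32, retarget it to the pass's AMP dtype, for both float16 and bfloat16, on the forward and backward sides. Below is a minimal, self-contained sketch of that rule in isolation. It uses a hypothetical FakeOp stand-in and a local VarType enum instead of Paddle's real block ops and core.VarDesc.VarType; the real pass calls has_attr/attr/_set_attr on framework op descriptors and maps the dtype string through its own _str_to_dtype helper, so names below are illustrative, not Paddle's API.

    from enum import Enum, auto


    class VarType(Enum):
        # Stand-in for core.VarDesc.VarType; member names mirror the ones the
        # patch compares against, the numeric values here are arbitrary.
        FP16 = auto()
        FP32 = auto()
        BF16 = auto()


    def _str_to_dtype(dtype_str):
        # Assumed mapping from the pass's dtype string to the enum value.
        return {"float16": VarType.FP16, "bfloat16": VarType.BF16}[dtype_str]


    class FakeOp:
        # Hypothetical stand-in exposing the three accessors the pass uses
        # on real framework ops: has_attr, attr, _set_attr.
        def __init__(self, op_type, attrs):
            self.type = op_type
            self._attrs = dict(attrs)

        def has_attr(self, name):
            return name in self._attrs

        def attr(self, name):
            return self._attrs[name]

        def _set_attr(self, name, value):
            self._attrs[name] = value


    def rewrite_dtype_attr(op, amp_dtype):
        # Deal with ops that carry a 'dtype' attribute, such as fill_constant:
        # if they still declare FP32, retarget them to the AMP dtype.
        if op.has_attr('dtype') and op.attr('dtype') == VarType.FP32:
            op._set_attr('dtype', _str_to_dtype(amp_dtype))


    if __name__ == "__main__":
        op = FakeOp('fill_constant', {'dtype': VarType.FP32, 'value': 1.0})
        rewrite_dtype_attr(op, 'bfloat16')   # the same call covers 'float16'
        assert op.attr('dtype') == VarType.BF16

As the final patch shows, this rewrite is applied in _cast_block, before _insert_cast_op_forward / _insert_cast_op_backward run, rather than inside those helpers, so the attribute is already retargeted by the time the cast ops are inserted.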