[AMP] Allow enabling multi_precision through paddle.static.amp.decorate and add documentation for some APIs. #53012

Merged · 8 commits · Apr 24, 2023
101 changes: 91 additions & 10 deletions python/paddle/static/amp/debugging.py
@@ -13,8 +13,14 @@
# limitations under the License.

import copy
import logging

import paddle
from paddle.fluid.log_helper import get_logger

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class OperatorStatsUnit:
@@ -76,7 +82,7 @@ def _get_var_dtype_from_block(block, op, arg_name, is_input):
var = block._var_recursive(var_name)
return var.dtype
except:
print(
_logger.warning(
"Operator < {} > gets {} < {} : {} > error!".format(
op.type, "input" if is_input else "output", arg_name, var_name
)
@@ -99,7 +105,7 @@ def _extract_compute_dtype(op, block):
if _is_floating_point(compute_dtype) and _is_floating_point(
var_dtype
):
print(
_logger.warning(
"Operator < {} > has different input data types, input_names = {}, output_names = {}.".format(
op.type, op.input_names, op.output_names
)
@@ -125,7 +131,7 @@ def _extract_compute_dtype(op, block):
if _is_floating_point(compute_dtype) and _is_floating_point(
var_dtype
):
print(
_logger.warning(
"Operator < {} > has different input / output data types, input_names = {}, output_names = {}.".format(
op.type, op.input_names, op.output_names
)
@@ -145,6 +151,15 @@ def _merge_op_stats(op_stats_list):


def _get_op_stats_list(program):
def _is_special_ops_with_input_x(op_type):
# Operators that take input X and may have inputs with different dtypes.
special_op_list = ['cast', 'batch_norm', 'instance_norm', 'layer_norm']
if op_type in special_op_list:
return True
if op_type.replace("_grad", "") in special_op_list:
return True
return False

op_stats_list = []
for block in program.blocks:
block_op_stats_dict = {}
@@ -161,13 +176,7 @@ def _get_op_stats_list(program):
'create_double_buffer_reader',
]:
compute_dtype = None
elif op.type in [
'cast',
'layer_norm',
'layer_norm_grad',
'batch_norm',
'batch_norm_grad',
]:
elif _is_special_ops_with_input_x(op.type):
# Do not check the input and output dtype difference for these operators.
compute_dtype = _get_var_dtype_from_block(block, op, 'X', True)
elif "Param" in op.input_names:
@@ -183,6 +192,78 @@ def _get_op_stats_list(program):


def collect_operator_stats(program=None, print_subblocks=False):
"""
Collect the number of operators for different data types by parsing
the program. The statistics are categorized into four data types:
float32, float16, bfloat16, and others.

Args:
program(Program, optional): The program to parse. Default None, and the default main_program will be parsed.
print_subblocks(bool, optional): Whether to print the operator stats for each subblock. Default False.

Examples:

.. code-block:: python

import paddle

paddle.enable_static()

class SimpleConvNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = paddle.nn.Linear(in_features=26, out_features=10)

def forward(self, x):
out = self.conv(x)
out = paddle.nn.functional.relu(out)
out = self.linear(out)
out = paddle.nn.functional.softmax(out)
return out

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet()
x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = paddle.optimizer.AdamW()
optimizer = paddle.static.amp.decorate(optimizer)
optimizer.minimize(loss)
paddle.static.amp.debugging.collect_operator_stats(main_program)
# <------------------------------------------------ op list of all blocks ------------------------------------------------->
# <------------------------------------------------------- op list -------------------------------------------------------->
# <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | --- FP32 Calls--- | -- Other Calls -->
# adamw | 0 | 0 | 4 | 0
# cast | 5 | 0 | 6 | 0
# check_finite_and_unscale | 0 | 0 | 1 | 0
# conv2d | 1 | 0 | 0 | 0
# conv2d_grad | 1 | 0 | 0 | 0
# elementwise_add | 2 | 0 | 0 | 0
# elementwise_add_grad | 2 | 0 | 0 | 0
# elementwise_mul | 0 | 0 | 1 | 0
# elementwise_mul_grad | 0 | 0 | 1 | 0
# fill_constant | 0 | 0 | 1 | 0
# matmul_v2 | 1 | 0 | 0 | 0
# matmul_v2_grad | 1 | 0 | 0 | 0
# memcpy | 0 | 0 | 0 | 1
# reduce_mean | 0 | 0 | 1 | 0
# reduce_mean_grad | 0 | 0 | 1 | 0
# relu | 1 | 0 | 0 | 0
# relu_grad | 1 | 0 | 0 | 0
# reshape2 | 0 | 0 | 1 | 0
# reshape2_grad | 0 | 0 | 1 | 0
# softmax | 0 | 0 | 1 | 0
# softmax_grad | 0 | 0 | 1 | 0
# update_loss_scaling | 0 | 0 | 1 | 0
# <----------------------------------------------------- op count: 22 ----------------------------------------------------->
"""

def _convert_to_list(op_stats_unit_dict):
for key, value in op_stats_unit_dict.items():
op_stats_unit_dict[key] = value.convert_to_list()
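
Note on the debugging.py change above: the hard-coded special-case list in _get_op_stats_list is replaced by the _is_special_ops_with_input_x helper, which also covers instance_norm and the _grad variants of the listed operators. Below is a minimal standalone sketch of that matching logic, using only the op types shown in this diff; the free-standing function is illustrative and not part of the Paddle API.

    special_op_list = ['cast', 'batch_norm', 'instance_norm', 'layer_norm']

    def is_special_op_with_input_x(op_type):
        # An op qualifies if it is in the list itself or is the backward
        # ("_grad") variant of a listed op, e.g. 'layer_norm_grad'.
        if op_type in special_op_list:
            return True
        return op_type.replace("_grad", "") in special_op_list

    assert is_special_op_with_input_x("instance_norm_grad")  # newly covered by this change
    assert not is_special_op_with_input_x("matmul_v2")
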
111 changes: 106 additions & 5 deletions python/paddle/static/amp/decorator.py
@@ -34,6 +34,21 @@
from .function_overload import FunctionType, overload


def _set_multi_precision(optimizer, multi_precision):
if not isinstance(
optimizer,
(paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
):
raise RuntimeError(
"Current AMP training level is O2, optimizer is expected to be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".format(
type(optimizer)
)
)

if multi_precision and hasattr(optimizer, "_multi_precision"):
optimizer._multi_precision = multi_precision


class OptimizerWithMixedPrecision:
"""
Optimizer with mixed-precision (MP) training. This is a wrapper of a common
@@ -767,29 +782,115 @@ def decorate(
amp_lists=None,
level='O1',
dtype='float16',
master_weight=None,
init_loss_scaling=2**15,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.8,
use_dynamic_loss_scaling=True,
use_dynamic_loss_scaling=None,
use_amp_guard=False,
use_promote=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
"""
amp_dtype = check_amp_dtype(dtype)
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

Args:
optimizer(Optimizer): A common Optimizer.
amp_lists(CustomOpLists, optional): A CustomOpLists object. The default
white_list and black_list will be used for AMP training when it is
not set. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represents mixed precision, where the input data type of
each operator is cast according to white_list and black_list;
O2 represents pure FP16 / BF16 training, where all operator parameters
and input data are cast to FP16 / BF16, except for operators in
black_list, operators without FP16 / BF16 kernels, and batch_norm. Default is O1.
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool, optional): For level='O2', whether to use multi-precision
during weight updating. If master_weight is None, the optimizer will use
multi-precision at the O2 level. Default is None.
init_loss_scaling(float, optional): The initial loss scaling factor.
Default is 32768.
incr_every_n_steps(int, optional): Increases loss scaling every n
consecutive steps with finite gradients. Default is 1000.
decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
accumulated steps with nan or inf gradients. Default is 2.
incr_ratio(float, optional): The multiplier to use when increasing the
loss scaling. Default is 2.
decr_ratio(float, optional): The less-than-one-multiplier to use when
decreasing the loss scaling. Default is 0.8.
use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss
scaling. Default is None, which means True for float16 and False
for bfloat16.

Returns:
An optimizer acting like a normal one but with mixed-precision training enabled.

Examples:

.. code-block:: python

import paddle

paddle.enable_static()

class SimpleConvNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = paddle.nn.Linear(in_features=26, out_features=10)

def forward(self, x):
out = self.conv(x)
out = paddle.nn.functional.relu(out)
out = self.linear(out)
out = paddle.nn.functional.softmax(out)
return out

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet()
x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = paddle.optimizer.AdamW()
optimizer = paddle.static.amp.decorate(optimizer, level="O2", dtype="float16")
optimizer.minimize(loss)

if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_program)

# Call `amp_init` after FP32 parameters initialization, such as `exe.run(startup_program)`,
# to convert FP32 parameters to low precision FP16 / BF16.
optimizer.amp_init(place, scope=paddle.static.global_scope())

"""
# check amp_level: O0-O2
level = level.upper()
if level not in ['O0', 'O1', 'O2']:
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)

amp_dtype = check_amp_dtype(dtype)
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

if use_dynamic_loss_scaling is None:
use_dynamic_loss_scaling = dtype == "float16"

if optimizer is not None:
# support master_weight
multi_precision = not (master_weight is False)
_set_multi_precision(optimizer, multi_precision)

mp_optimizer = OptimizerWithMixedPrecision(
optimizer,
amp_lists,
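
Note on the decorator.py change above: decorate now resolves two of its new defaults before building the OptimizerWithMixedPrecision wrapper. Below is a small sketch of that resolution logic, mirroring the code in this diff; the helper function is illustrative only and not part of the API.

    def resolve_amp_defaults(master_weight=None, use_dynamic_loss_scaling=None, dtype="float16"):
        # master_weight=None or True -> multi_precision is enabled on the
        # wrapped optimizer; only an explicit False disables it.
        multi_precision = not (master_weight is False)
        # Dynamic loss scaling defaults to True for float16 and False for bfloat16.
        if use_dynamic_loss_scaling is None:
            use_dynamic_loss_scaling = dtype == "float16"
        return multi_precision, use_dynamic_loss_scaling

    assert resolve_amp_defaults() == (True, True)
    assert resolve_amp_defaults(master_weight=False, dtype="bfloat16") == (False, False)
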
1 change: 0 additions & 1 deletion test/amp/amp_base_models.py
@@ -42,7 +42,6 @@ def _build_optimizer(
beta2=0.836,
epsilon=1e-4,
weight_decay=0.01,
multi_precision=True,
)
if use_amp:
optimizer = paddle.static.amp.decorate(
23 changes: 23 additions & 0 deletions test/amp/test_model_cast_to_bf16.py
@@ -221,11 +221,29 @@ def test_graph_cast(self):


class TestProgramBF16(AmpTestBase):
def _check_optimizer(self, program, expected_num_mp):
optimizers = []
for block in program.blocks:
for op in block.ops:
if "Param" in op.input_names and "Grad" in op.input_names:
optimizers.append(op)

actual_num_mp = 0
for op in optimizers:
if op.has_attr("multi_precision") and op.attr("multi_precision"):
actual_num_mp += 1
self.assertEqual(
actual_num_mp,
expected_num_mp,
f"The number of optimizers with multi_precison = True is expected to be {expected_num_mp}, but recieved {actual_num_mp}.",
)

def test_amp_bf16_o1(self):
main_program, startup_program = build_embedding_model(
True, "bfloat16", "O1"
)
self.assertEqual(main_program.num_blocks, 1)
self._check_optimizer(main_program, 0)

amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
@@ -255,6 +273,11 @@ def test_amp_bf16_o2(self):
"squared_l2_norm": 2,
"adamw": 2,
}
self._check_optimizer(
main_program,
expected_bf16_calls["matmul_v2"]
+ expected_bf16_calls["elementwise_add"],
)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)

