Commit 8df0c7e

Xreki and zhangting2020 authored and committed
[AMP] Allow to enable multi_precision through paddle.static.amp.decorate and add documents for some apis. (PaddlePaddle#53012)
* Add document for some apis. test=docs_preview
* Allow to set master_weight in paddle.static.amp.decorate.
* Polish codes and add unittest.
* Refine docs.
* Remove the repetitive function.
1 parent bcbb959 commit 8df0c7e

File tree (4 files changed: +220 -16 lines)

python/paddle/static/amp/debugging.py
python/paddle/static/amp/decorator.py
test/amp/amp_base_models.py
test/amp/test_model_cast_to_bf16.py
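
The user-facing change is the new master_weight argument of paddle.static.amp.decorate shown in decorator.py below. The following sketch is illustrative only (it is not part of the diff) and assumes a Paddle build that already contains this commit; building the program, network, and loss is elided, as in the docstring examples added further down.

import paddle

paddle.enable_static()

# Build main_program / startup_program, the network and a `loss` as usual
# (see the docstring examples in decorator.py below), then wrap the optimizer:
optimizer = paddle.optimizer.AdamW()
optimizer = paddle.static.amp.decorate(
    optimizer,
    level="O2",
    dtype="float16",
    master_weight=None,  # None or True keeps FP32 master weights; False disables them
)
# optimizer.minimize(loss) would then update parameters via the FP32 master copies.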

python/paddle/static/amp/debugging.py

Lines changed: 91 additions & 10 deletions
@@ -13,8 +13,14 @@
 # limitations under the License.
 
 import copy
+import logging
 
 import paddle
+from paddle.fluid.log_helper import get_logger
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
+)
 
 
 class OperatorStatsUnit:
@@ -76,7 +82,7 @@ def _get_var_dtype_from_block(block, op, arg_name, is_input):
         var = block._var_recursive(var_name)
         return var.dtype
     except:
-        print(
+        _logger.warning(
             "Operator < {} > gets {} < {} : {} > error!".format(
                 op.type, "input" if is_input else "output", arg_name, var_name
             )
@@ -99,7 +105,7 @@ def _extract_compute_dtype(op, block):
             if _is_floating_point(compute_dtype) and _is_floating_point(
                 var_dtype
             ):
-                print(
+                _logger.warning(
                     "Operator < {} > has different input data types, input_names = {}, output_names = {}.".format(
                         op.type, op.input_names, op.output_names
                     )
@@ -125,7 +131,7 @@ def _extract_compute_dtype(op, block):
             if _is_floating_point(compute_dtype) and _is_floating_point(
                 var_dtype
             ):
-                print(
+                _logger.warning(
                     "Operator < {} > has different input / output data types, input_names = {}, output_names = {}.".format(
                         op.type, op.input_names, op.output_names
                     )
@@ -145,6 +151,15 @@ def _merge_op_stats(op_stats_list):
 
 
 def _get_op_stats_list(program):
+    def _is_special_ops_with_input_x(op_type):
+        # operators have input X and have inputs different dtypes.
+        special_op_list = ['cast', 'batch_norm', 'instance_norm', 'layer_norm']
+        if op_type in special_op_list:
+            return True
+        if op_type.replace("_grad", "") in special_op_list:
+            return True
+        return False
+
     op_stats_list = []
     for block in program.blocks:
         block_op_stats_dict = {}
@@ -161,13 +176,7 @@ def _get_op_stats_list(program):
                 'create_double_buffer_reader',
             ]:
                 compute_dtype = None
-            elif op.type in [
-                'cast',
-                'layer_norm',
-                'layer_norm_grad',
-                'batch_norm',
-                'batch_norm_grad',
-            ]:
+            elif _is_special_ops_with_input_x(op.type):
                 # Not check the input and output dtype difference for this operators.
                 compute_dtype = _get_var_dtype_from_block(block, op, 'X', True)
             elif "Param" in op.input_names:
@@ -183,6 +192,78 @@ def _get_op_stats_list(program):
 
 
 def collect_operator_stats(program=None, print_subblocks=False):
+    """
+    Collect the number of operators for different data types through parsing
+    the program. The statistical data are categorized according to four data
+    types, namely float32, float16, bfloat16 and others.
+
+    Args:
+        program(Program, optional): The program to parse. Default is None, meaning the default main_program will be parsed.
+        print_subblocks(bool, optional): Whether to print the operator stats for each subblock. Default is False.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            paddle.enable_static()
+
+            class SimpleConvNet(paddle.nn.Layer):
+                def __init__(self):
+                    super().__init__()
+                    self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
+                    self.linear = paddle.nn.Linear(in_features=26, out_features=10)
+
+                def forward(self, x):
+                    out = self.conv(x)
+                    out = paddle.nn.functional.relu(out)
+                    out = self.linear(out)
+                    out = paddle.nn.functional.softmax(out)
+                    return out
+
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with paddle.utils.unique_name.guard():
+                with paddle.static.program_guard(main_program, startup_program):
+                    model = SimpleConvNet()
+                    x = paddle.static.data(
+                        name='input', shape=[None, 1, 28, 28], dtype='float32'
+                    )
+                    out = model(x)
+                    loss = paddle.mean(out)
+                    optimizer = paddle.optimizer.AdamW()
+                    optimizer = paddle.static.amp.decorate(optimizer)
+                    optimizer.minimize(loss)
+            paddle.static.amp.debugging.collect_operator_stats(main_program)
+            # <------------------------------------------------ op list of all blocks ------------------------------------------------->
+            # <------------------------------------------------------- op list -------------------------------------------------------->
+            # <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | --- FP32 Calls--- | -- Other Calls -->
+            # adamw | 0 | 0 | 4 | 0
+            # cast | 5 | 0 | 6 | 0
+            # check_finite_and_unscale | 0 | 0 | 1 | 0
+            # conv2d | 1 | 0 | 0 | 0
+            # conv2d_grad | 1 | 0 | 0 | 0
+            # elementwise_add | 2 | 0 | 0 | 0
+            # elementwise_add_grad | 2 | 0 | 0 | 0
+            # elementwise_mul | 0 | 0 | 1 | 0
+            # elementwise_mul_grad | 0 | 0 | 1 | 0
+            # fill_constant | 0 | 0 | 1 | 0
+            # matmul_v2 | 1 | 0 | 0 | 0
+            # matmul_v2_grad | 1 | 0 | 0 | 0
+            # memcpy | 0 | 0 | 0 | 1
+            # reduce_mean | 0 | 0 | 1 | 0
+            # reduce_mean_grad | 0 | 0 | 1 | 0
+            # relu | 1 | 0 | 0 | 0
+            # relu_grad | 1 | 0 | 0 | 0
+            # reshape2 | 0 | 0 | 1 | 0
+            # reshape2_grad | 0 | 0 | 1 | 0
+            # softmax | 0 | 0 | 1 | 0
+            # softmax_grad | 0 | 0 | 1 | 0
+            # update_loss_scaling | 0 | 0 | 1 | 0
+            # <----------------------------------------------------- op count: 22 ----------------------------------------------------->
+    """
+
     def _convert_to_list(op_stats_unit_dict):
         for key, value in op_stats_unit_dict.items():
             op_stats_unit_dict[key] = value.convert_to_list()

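One note on the _is_special_ops_with_input_x refactor above: the old hard-coded branch listed cast, layer_norm, layer_norm_grad, batch_norm and batch_norm_grad, while the new helper also covers instance_norm and derives the _grad variants automatically. The self-contained sketch below uses hypothetical names (not the committed identifiers) and simply mirrors that check:

# Hypothetical stand-alone version of the helper added in debugging.py above.
SPECIAL_OPS_WITH_INPUT_X = ['cast', 'batch_norm', 'instance_norm', 'layer_norm']

def is_special_op_with_input_x(op_type):
    # One membership test covers both the forward op and its *_grad variant.
    return (
        op_type in SPECIAL_OPS_WITH_INPUT_X
        or op_type.replace("_grad", "") in SPECIAL_OPS_WITH_INPUT_X
    )

print(is_special_op_with_input_x("layer_norm_grad"))     # True
print(is_special_op_with_input_x("instance_norm_grad"))  # True (newly special-cased)
print(is_special_op_with_input_x("matmul_v2"))           # False
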
python/paddle/static/amp/decorator.py

Lines changed: 106 additions & 5 deletions
@@ -34,6 +34,21 @@
 from .function_overload import FunctionType, overload
 
 
+def _set_multi_precision(optimizer, multi_precision):
+    if not isinstance(
+        optimizer,
+        (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
+    ):
+        raise RuntimeError(
+            "Current AMP training level is O2, optimizer is expected to be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".format(
+                type(optimizer)
+            )
+        )
+
+    if multi_precision and hasattr(optimizer, "_multi_precision"):
+        optimizer._multi_precision = multi_precision
+
+
 class OptimizerWithMixedPrecision:
     """
     Optimizer with mixed-precision (MP) training. This is a wrapper of a common
@@ -767,29 +782,115 @@ def decorate(
     amp_lists=None,
     level='O1',
     dtype='float16',
+    master_weight=None,
     init_loss_scaling=2**15,
     incr_every_n_steps=1000,
     decr_every_n_nan_or_inf=2,
     incr_ratio=2.0,
     decr_ratio=0.8,
-    use_dynamic_loss_scaling=True,
+    use_dynamic_loss_scaling=None,
     use_amp_guard=False,
     use_promote=False,
 ):
     """
     Decorate the given optimizer to adapt to the mixed-precision training.
-    """
-    amp_dtype = check_amp_dtype(dtype)
-    if amp_lists is None:
-        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
 
+    Args:
+        optimizer(Optimizer): A common Optimizer.
+        amp_lists(CustomOpLists, optional): A CustomOpLists object. The default
+            white_list and black_list will be used for AMP training when it is
+            not set. Default is None.
+        level(str, optional): Auto mixed precision level. Accepted values are
+            "O1" and "O2": O1 represents mixed precision, where the input data type of
+            each operator is cast according to the white_list and black_list;
+            O2 represents pure FP16 / BF16 training, where parameters and input data
+            of all operators are cast to FP16 / BF16, except for operators in
+            black_list, operators without FP16 / BF16 kernels, and batch_norm. Default is O1.
+        dtype(str, optional): The data type to use, 'float16' or 'bfloat16'. Default is 'float16'.
+        master_weight(bool, optional): For level='O2', whether to use multi-precision
+            during weight updating. If master_weight is None, the optimizer will
+            use multi-precision at the O2 level. Default is None.
+        init_loss_scaling(float, optional): The initial loss scaling factor.
+            Default is 32768.
+        incr_every_n_steps(int, optional): Increases loss scaling every n
+            consecutive steps with finite gradients. Default is 1000.
+        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
+            accumulated steps with nan or inf gradients. Default is 2.
+        incr_ratio(float, optional): The multiplier to use when increasing the
+            loss scaling. Default is 2.
+        decr_ratio(float, optional): The less-than-one multiplier to use when
+            decreasing the loss scaling. Default is 0.8.
+        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss
+            scaling. Default is None, which means True for float16 and False
+            for bfloat16.
+
+    Returns:
+        An optimizer acting like a normal one but with mixed-precision training enabled.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            paddle.enable_static()
+
+            class SimpleConvNet(paddle.nn.Layer):
+                def __init__(self):
+                    super().__init__()
+                    self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
+                    self.linear = paddle.nn.Linear(in_features=26, out_features=10)
+
+                def forward(self, x):
+                    out = self.conv(x)
+                    out = paddle.nn.functional.relu(out)
+                    out = self.linear(out)
+                    out = paddle.nn.functional.softmax(out)
+                    return out
+
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with paddle.utils.unique_name.guard():
+                with paddle.static.program_guard(main_program, startup_program):
+                    model = SimpleConvNet()
+                    x = paddle.static.data(
+                        name='input', shape=[None, 1, 28, 28], dtype='float32'
+                    )
+                    out = model(x)
+                    loss = paddle.mean(out)
+                    optimizer = paddle.optimizer.AdamW()
+                    optimizer = paddle.static.amp.decorate(optimizer, level="O2", dtype="float16")
+                    optimizer.minimize(loss)
+
+            if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(startup_program)
+
+                # Call `amp_init` after FP32 parameters initialization, such as `exe.run(startup_program)`,
+                # to convert FP32 parameters to low precision FP16 / BF16.
+                optimizer.amp_init(place, scope=paddle.static.global_scope())
+
+    """
     # check amp_level: O0-O2
     level = level.upper()
     if not (level in ['O0', 'O1', 'O2']):
         raise ValueError(
             "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
         )
 
+    amp_dtype = check_amp_dtype(dtype)
+    if amp_lists is None:
+        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
+
+    if use_dynamic_loss_scaling is None:
+        use_dynamic_loss_scaling = dtype == "float16"
+
+    if optimizer is not None:
+        # support master_weight
+        multi_precision = not (master_weight is False)
+        _set_multi_precision(optimizer, multi_precision)
+
     mp_optimizer = OptimizerWithMixedPrecision(
         optimizer,
         amp_lists,

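Two defaults added to decorate deserve a note: master_weight maps onto the wrapped optimizer's multi_precision flag (only an explicit False turns it off), and use_dynamic_loss_scaling=None is now resolved per dtype. The sketch below uses a hypothetical helper name and simply mirrors the logic added above:

# Hypothetical helper reproducing the default resolution added in decorate().
def resolve_amp_defaults(dtype, master_weight, use_dynamic_loss_scaling):
    # None or True -> multi-precision on; only an explicit False turns it off.
    multi_precision = not (master_weight is False)
    # Dynamic loss scaling defaults to on for float16 and off for bfloat16.
    if use_dynamic_loss_scaling is None:
        use_dynamic_loss_scaling = dtype == "float16"
    return multi_precision, use_dynamic_loss_scaling

print(resolve_amp_defaults("float16", None, None))   # (True, True)
print(resolve_amp_defaults("bfloat16", None, None))  # (True, False)
print(resolve_amp_defaults("float16", False, None))  # (False, True)
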
test/amp/amp_base_models.py

Lines changed: 0 additions & 1 deletion
@@ -42,7 +42,6 @@ def _build_optimizer(
         beta2=0.836,
         epsilon=1e-4,
         weight_decay=0.01,
-        multi_precision=True,
     )
     if use_amp:
         optimizer = paddle.static.amp.decorate(

test/amp/test_model_cast_to_bf16.py

Lines changed: 23 additions & 0 deletions
@@ -221,11 +221,29 @@ def test_graph_cast(self):
 
 
 class TestProgramBF16(AmpTestBase):
+    def _check_optimizer(self, program, expected_num_mp):
+        optimizers = []
+        for block in program.blocks:
+            for op in block.ops:
+                if "Param" in op.input_names and "Grad" in op.input_names:
+                    optimizers.append(op)
+
+        actual_num_mp = 0
+        for op in optimizers:
+            if op.has_attr("multi_precision") and op.attr("multi_precision"):
+                actual_num_mp += 1
+        self.assertEqual(
+            actual_num_mp,
+            expected_num_mp,
+            f"The number of optimizers with multi_precision = True is expected to be {expected_num_mp}, but received {actual_num_mp}.",
+        )
+
     def test_amp_bf16_o1(self):
         main_program, startup_program = build_embedding_model(
             True, "bfloat16", "O1"
         )
         self.assertEqual(main_program.num_blocks, 1)
+        self._check_optimizer(main_program, 0)
 
         amp.debugging.collect_operator_stats(main_program)
         op_stats_list = amp.debugging._get_op_stats_list(main_program)
@@ -255,6 +273,11 @@ def test_amp_bf16_o2(self):
             "squared_l2_norm": 2,
             "adamw": 2,
         }
+        self._check_optimizer(
+            main_program,
+            expected_bf16_calls["matmul_v2"]
+            + expected_bf16_calls["elementwise_add"],
+        )
         self._check_op_calls(op_stats_list[0], expected_bf16_calls)
 