[GPT-3] Support Grad Merge with FP32 main grad for BF16 training of GPT-3 model #1053

Merged: 31 commits into develop from bf16_grad_merge, Mar 16, 2023

Commits (31)
5cf6743  Add the support of bfloat16 amp training. (Xreki, Feb 25, 2023)
52c7a83  Allow to use multi_precision to grad_clip. (Xreki, Feb 25, 2023)
840ce2f  Merge branch 'develop' into support_bf16 (Xreki, Feb 27, 2023)
2f9a039  Applay GradScaler for bfloat16. (Xreki, Feb 27, 2023)
2915f57  Fix ci error. (Xreki, Feb 28, 2023)
59546ba  Fix typo and add print of loss_scale. (Xreki, Mar 1, 2023)
86b574c  Merge branch 'develop' into support_bf16 (Xreki, Mar 1, 2023)
21760c0  Merge branch 'develop' into bf16_grad_merge (haohongxiang, Mar 2, 2023)
c91472d  fix format (haohongxiang, Mar 2, 2023)
cd14fac  Merge branch 'develop' into bf16_grad_merge (haohongxiang, Mar 3, 2023)
8a60a76  rename use_pure_fp16 as mix_precision.enable (haohongxiang, Mar 3, 2023)
6bb1e03  support main_grad (haohongxiang, Mar 4, 2023)
cf7fa38  update (haohongxiang, Mar 4, 2023)
3012033  update (haohongxiang, Mar 5, 2023)
75b9a9f  update for exps part1 (haohongxiang, Mar 5, 2023)
58a337f  update for exps part2 (haohongxiang, Mar 5, 2023)
d239077  update for exps part3 (haohongxiang, Mar 6, 2023)
b8606bb  update for exps part4 (haohongxiang, Mar 6, 2023)
bc951b2  update (haohongxiang, Mar 8, 2023)
43023a5  update (haohongxiang, Mar 8, 2023)
a589a35  update (haohongxiang, Mar 10, 2023)
65d3d17  Merge branch 'develop' into bf16_grad_merge (haohongxiang, Mar 10, 2023)
1ea9406  update (haohongxiang, Mar 14, 2023)
674e5fa  Merge branch 'develop' into bf16_grad_merge (haohongxiang, Mar 14, 2023)
ba932a9  update (haohongxiang, Mar 14, 2023)
ea77797  update (haohongxiang, Mar 14, 2023)
325d81c  update (haohongxiang, Mar 15, 2023)
fa6fa48  update (haohongxiang, Mar 15, 2023)
1874cb6  update (haohongxiang, Mar 15, 2023)
62f32a9  update (haohongxiang, Mar 15, 2023)
974e1d2  update (haohongxiang, Mar 16, 2023)

Files changed
@@ -87,7 +87,7 @@ function _train(){
     -o Global.micro_batch_size=${micro_batch_size} \
     -o Engine.max_steps=${max_iter} \
     -o Engine.eval_freq=${eval_freq} \
-    -o Engine.mix_precision.use_pure_fp16=${use_pure_fp16} \
+    -o Engine.mix_precision.enable=${use_pure_fp16} \
     -o Engine.save_load.save_steps=100000 \
     -o Model.hidden_size=1024 \
     -o Model.num_hidden_layers=${num_layers} \
@@ -86,7 +86,7 @@ function _train(){
     -o Global.micro_batch_size=${micro_batch_size} \
     -o Engine.max_steps=${max_iter} \
     -o Engine.eval_freq=${eval_freq} \
-    -o Engine.mix_precision.use_pure_fp16=${use_pure_fp16} \
+    -o Engine.mix_precision.enable=${use_pure_fp16} \
     -o Engine.save_load.save_steps=100000 \
     -o Model.hidden_size=1024 \
     -o Model.num_layers=${num_layers} \
@@ -81,7 +81,7 @@ function _train(){
     -o Global.micro_batch_size=${micro_batch_size} \
     -o Engine.max_steps=${max_iter} \
     -o Engine.eval_freq=${eval_freq} \
-    -o Engine.mix_precision.use_pure_fp16=${use_pure_fp16} \
+    -o Engine.mix_precision.enable=${use_pure_fp16} \
     -o Engine.save_load.save_steps=100000 \
     -o Model.use_recompute=${use_recompute} \
     -o Distributed.dp_degree=${dp_degree} \
docs/standard.md (9 additions, 5 deletions)
@@ -102,7 +102,9 @@ Engine:
   eval_iters: 10
   test_iters:
   mix_precision:
-    use_pure_fp16: True
+    enable: True
+    dtype: "float16"
+    level: "O2"
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
@@ -123,10 +125,12 @@ Engine:
 | logging_freq | Frequency (in steps) of printing training logs |
 | eval_freq | Interval between model evaluations |
 | eval_iters | Number of iterations over the evaluation/test set during model evaluation |
-| use_pure_fp16 | Whether to train in pure fp16 precision |
-| scale_loss | Loss scaling factor when training in fp16 precision |
-| custom_black_list | Custom operator blacklist. Operators on this list are treated as numerically dangerous under float16, and their effects may also be observed in downstream operators. These operators are usually not cast to float16. |
-| custom_white_list | Custom operator whitelist. Operators on this list are treated as numerically safe under float16 and critical to performance. If a whitelist is set, operators on it are computed in float16. |
+| enable | Whether to train with the mixed-precision strategy |
+| dtype | Data type for mixed-precision training, float16 or bfloat16; defaults to float16 |
+| level | Mixed-precision training mode; defaults to ``O2`` |
+| scale_loss | Loss scaling factor under the fp16 mixed-precision strategy |
+| custom_black_list | Custom operator blacklist. Operators on this list are treated as numerically dangerous under mixed precision, and their effects may also be observed in downstream operators. These operators are usually not cast to float16/bfloat16 |
+| custom_white_list | Custom operator whitelist. Operators on this list are treated as numerically safe under mixed precision and critical to performance. If a whitelist is set, operators on it are computed in float16/bfloat16 |
 | save_steps | Interval (in steps) for saving the model |
 | save_epoch | Interval (in epochs) for saving the model |
 | output_dir | Directory for output files |
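For reference, here is a rough sketch of how a mix_precision block like the one above maps onto Paddle's AMP APIs. This is illustrative rather than this repo's engine code; the cfg dict and toy model are hypothetical, and it assumes a device with bfloat16 support:

```python
import paddle

# Hypothetical stand-in for the Engine.mix_precision config above.
cfg = {
    "enable": True,
    "dtype": "bfloat16",  # or "float16"
    "level": "O2",
    "scale_loss": 32768.0,
    "custom_black_list": ["reduce_sum", "elementwise_div"],
    "custom_white_list": ["lookup_table", "lookup_table_v2"],
}

model = paddle.nn.Linear(8, 8)
# multi_precision keeps fp32 master weights for low-precision parameters.
optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(), multi_precision=cfg["enable"])

if cfg["enable"]:
    # Level O2 casts the model weights to the low-precision dtype up front.
    model, optimizer = paddle.amp.decorate(
        models=model, optimizers=optimizer,
        level=cfg["level"], dtype=cfg["dtype"])
scaler = paddle.amp.GradScaler(
    enable=cfg["enable"], init_loss_scaling=cfg["scale_loss"])

x = paddle.randn([4, 8])
with paddle.amp.auto_cast(
        enable=cfg["enable"],
        custom_black_list=cfg["custom_black_list"],
        custom_white_list=cfg["custom_white_list"],
        level=cfg["level"],
        dtype=cfg["dtype"]):
    loss = model(x).mean()

# Loss scaling mainly protects float16; bfloat16's wide exponent range
# rarely underflows, but the same code path serves both dtypes.
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()
```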
examples/transformer/models/GPT/docs/README.md (9 additions, 5 deletions)
@@ -102,7 +102,9 @@ cd .. # back to the GPT directory
   eval_iters: 10
   test_iters:
   mix_precision:
-    use_pure_fp16: True
+    enable: True
+    dtype: "float16"
+    level: "O2"
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
@@ -128,10 +130,12 @@ cd .. # back to the GPT directory
 | eval_freq | Interval between model evaluations |
 | eval_iters | Number of iterations over the evaluation/test set during model evaluation |
 | test_iters | Number of iterations for model testing or inference |
-| use_pure_fp16 | Whether to train in pure fp16 precision |
-| scale_loss | Loss scaling factor when training in fp16 precision |
-| custom_black_list | Custom operator blacklist. Operators on this list are treated as numerically dangerous under float16, and their effects may also be observed in downstream operators. These operators are usually not cast to float16. |
-| custom_white_list | Custom operator whitelist. Operators on this list are treated as numerically safe under float16 and critical to performance. If a whitelist is set, operators on it are computed in float16. |
+| enable | Whether to train with the mixed-precision strategy |
+| dtype | Data type for mixed-precision training, float16 or bfloat16; defaults to float16 |
+| level | Mixed-precision training mode; defaults to ``O2`` |
+| scale_loss | Loss scaling factor under the fp16 mixed-precision strategy |
+| custom_black_list | Custom operator blacklist. Operators on this list are treated as numerically dangerous under mixed precision, and their effects may also be observed in downstream operators. These operators are usually not cast to float16/bfloat16 |
+| custom_white_list | Custom operator whitelist. Operators on this list are treated as numerically safe under mixed precision and critical to performance. If a whitelist is set, operators on it are computed in float16/bfloat16 |
 | save_steps | Interval in steps for saving the model |
 | save_epoch | Interval in epochs for saving the model |
 | output_dir | Directory for output files |
@@ -11,7 +11,7 @@ Global:
   logging_freq: 10
   eval_freq: 1
   mix_precision:
-    use_pure_fp16: True
+    enable: True
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div", "reduce_mean"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
examples/transformer/models/GPT/finetune/impls.py (1 addition, 1 deletion)
@@ -190,7 +190,7 @@ def fit_impl(config, batch, forward_func, **kwargs):
 def eval_impl(config, batch, model, loss_fn, eval_metric):
     model.eval()

-    use_fp16 = config.Global.mix_precision.use_pure_fp16
+    use_fp16 = config.Global.mix_precision.enable
     black_list = config.Global.mix_precision.custom_black_list
     white_list = config.Global.mix_precision.custom_white_list
examples/transformer/models/GPT/finetune/run.py (3 additions, 3 deletions)
@@ -62,7 +62,7 @@
 # build GPT model
 model, tokenizer, train_loss_fn, eval_loss_fn = impls.build_model(config)

-if config.Global.mix_precision.use_pure_fp16:
+if config.Global.mix_precision.enable:
     scaler = paddle.amp.GradScaler(
         init_loss_scaling=config.Global.mix_precision.scale_loss)
     # Note: Save dtype is the same as model dtype. Also can set save_dtype='float32' when
@@ -98,14 +98,14 @@

 if 'multi_precision' in config.Optimizer:
     assert config.Optimizer.pop('multi_precision') \
-        == config.Global.mix_precision.use_pure_fp16
+        == config.Global.mix_precision.enable

 lr_scheduler = cpn.build_lr_scheduler(config.Optimizer.lr)
 optimizer = cpn.build_optimizer(
     config.Optimizer,
     model,
     lr_scheduler,
-    multi_precision=config.Global.mix_precision.use_pure_fp16)
+    multi_precision=config.Global.mix_precision.enable)

 # call fleet wrapper
 if nranks > 1:
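The assertion above keeps Optimizer.multi_precision in lockstep with mix_precision.enable for a reason: when parameters live in bf16 or fp16, the optimizer needs fp32 master weights, or small updates round away entirely. A dependency-free illustration (my sketch, using bit truncation rather than Paddle's round-to-nearest):

```python
import struct

def to_bf16(x: float) -> float:
    # bfloat16 is (approximately) the top 16 bits of an IEEE-754 float32;
    # truncation is used here instead of round-to-nearest for simplicity.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (y,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return y

w, update = 1.0, 1e-4        # e.g. lr * grad
print(to_bf16(w + update))   # 1.0    -- the update vanishes in bf16
print(w + update)            # 1.0001 -- an fp32 master weight keeps it
```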
@@ -14,7 +14,7 @@ Global:
   eval_iters: 10
   test_iters:
   mix_precision:
-    use_pure_fp16: True
+    enable: True
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
examples/transformer/models/GPT/generation/export.py (2 additions, 2 deletions)
@@ -51,8 +51,8 @@
 cfg.process_configs(config)
 cfg.print_config(config)

-if config.Global.mix_precision.use_pure_fp16:
-    logger.info("NOTE: disable use_pure_fp16 in export mode")
+if config.Global.mix_precision.enable:
+    logger.info("NOTE: disable mix_precision in export mode")

 # build GPT model
 model, _ = impls.build_model(config)
@@ -14,7 +14,7 @@ Global:
   eval_iters: 10
   test_iters:
   mix_precision:
-    use_pure_fp16: True
+    enable: True
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
examples/transformer/models/GPT/offline-eval/impls.py (1 addition, 1 deletion)
@@ -61,7 +61,7 @@ def build_model(config):
 def eval_impl(config, batch, model):
     model.eval()

-    use_fp16 = config.Global.mix_precision.use_pure_fp16
+    use_fp16 = config.Global.mix_precision.enable
     black_list = config.Global.mix_precision.custom_black_list
     white_list = config.Global.mix_precision.custom_white_list
examples/transformer/models/GPT/offline-eval/run.py (1 addition, 1 deletion)
@@ -77,7 +77,7 @@
 ]
 model, quanter = qat.compress_model(config, model, input_spec)

-if config.Global.mix_precision.use_pure_fp16:
+if config.Global.mix_precision.enable:
     scaler = paddle.amp.GradScaler(
         init_loss_scaling=config.Global.mix_precision.scale_loss)
     # Note: Save dtype is the same as model dtype. Also can set save_dtype='float32' when
@@ -14,7 +14,7 @@ Global:
   eval_iters: 10
   test_iters:
   mix_precision:
-    use_pure_fp16: True
+    enable: True
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
examples/transformer/models/GPT/pretrain/export.py (2 additions, 2 deletions)
@@ -51,8 +51,8 @@
 cfg.process_configs(config)
 cfg.print_config(config)

-if config.Global.mix_precision.use_pure_fp16:
-    logger.info("NOTE: disable use_pure_fp16 in export mode")
+if config.Global.mix_precision.enable:
+    logger.info("NOTE: disable mix_precision in export mode")

 # build GPT model
 model, _, _ = impls.build_model(config)
examples/transformer/models/GPT/pretrain/impls.py (3 additions, 3 deletions)
@@ -101,7 +101,7 @@ def build_model(config):

 def model_forward_backward(config, batch, forward_func, **kwargs):
     acc_steps = config.Global.accumulate_steps
-    use_fp16 = config.Global.mix_precision.use_pure_fp16
+    use_fp16 = config.Global.mix_precision.enable
     black_list = config.Global.mix_precision.custom_black_list
     white_list = config.Global.mix_precision.custom_white_list
@@ -165,7 +165,7 @@ def model_forward_backward(config, batch, forward_func, **kwargs):

 def optim_update_params(config, **kwargs):
     hcg = env.get_hcg()
-    use_fp16 = config.Global.mix_precision.use_pure_fp16
+    use_fp16 = config.Global.mix_precision.enable

     dp_degree = config.Distributed.dp_degree
     sharding_stage = config.Distributed.sharding.sharding_stage
@@ -221,7 +221,7 @@ def fit_impl(config, batch, forward_func, **kwargs):
 def eval_impl(config, batch, model, loss_fn):
     model.eval()

-    use_fp16 = config.Global.mix_precision.use_pure_fp16
+    use_fp16 = config.Global.mix_precision.enable
     black_list = config.Global.mix_precision.custom_black_list
     white_list = config.Global.mix_precision.custom_white_list
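The pretrain path is where the PR's headline feature, grad merge with an FP32 main grad, applies: during gradient accumulation across micro-batches, bf16 gradients are merged into fp32 buffers so repeated low-precision summation does not erode precision before the optimizer step. The hunks shown here only rename the config flag; the main_grad support lands in commit 6bb1e03. A hedged sketch of the idea, not the repo's actual fleet/hook plumbing, with merge_grads_fp32 as an illustrative helper:

```python
import paddle

def merge_grads_fp32(model, main_grads):
    """Accumulate the current bf16 grads into fp32 buffers, then drop them."""
    for p in model.parameters():
        if p.grad is None:
            continue
        g32 = p.grad.cast("float32")  # up-cast before summing
        main_grads[p.name] = (main_grads[p.name] + g32
                              if p.name in main_grads else g32)
        p.clear_gradient()  # free the bf16 grad between micro-batches

# Hypothetical accumulation loop:
#   main_grads = {}
#   for micro_batch in micro_batches:
#       loss = model_forward(micro_batch)
#       loss.backward()
#       merge_grads_fp32(model, main_grads)
#   # hand the merged fp32 grads to a multi_precision optimizer step
```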
examples/transformer/models/GPT/pretrain/run.py (2 additions, 2 deletions)
@@ -83,7 +83,7 @@
 ]
 model, quanter = qat.compress_model(config, model, input_spec)

-if config.Global.mix_precision.use_pure_fp16:
+if config.Global.mix_precision.enable:
     scaler = paddle.amp.GradScaler(
         init_loss_scaling=config.Global.mix_precision.scale_loss)
     # Note: Save dtype is the same as model dtype. Also can set save_dtype='float32' when
@@ -104,7 +104,7 @@
     config.Optimizer,
     model,
     lr_scheduler,
-    multi_precision=config.Global.mix_precision.use_pure_fp16)
+    multi_precision=config.Global.mix_precision.enable)

 # call fleet wrapper
 if nranks > 1:
@@ -7,7 +7,7 @@ Global:
   max_steps: 20000
   logging_freq: 10
   mix_precision:
-    use_pure_fp16: True
+    enable: True

 Data:
   Train:
@@ -14,7 +14,7 @@ Global:
   eval_iters: 10
   test_iters:
   mix_precision:
-    use_pure_fp16: True
+    enable: True
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
examples/transformer/models/GPT/pretrain_moe/impls.py (4 additions, 4 deletions)
@@ -59,7 +59,7 @@ def _get_model_size(l, h, v, s, ne, ei):
             # gate
             P += (h * nei + nei)
             # experts
-            P += nei * (8 * h * h + 5 * h)
+            P += nei * (8 * h * h + 5 * h)
             # FFN Layer
         else:
             P += 8 * h * h + 5 * h
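For scale: with h=1024 and nei=8 local experts, the experts term contributes 8 * (8*1024^2 + 5*1024) ≈ 67.1M parameters per MoE layer, versus roughly 8.4M for the dense FFN branch below it.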
@@ -120,7 +120,7 @@ def build_model(config):

 def model_forward_backward(config, batch, forward_func, **kwargs):
     acc_steps = config.Global.accumulate_steps
-    use_fp16 = config.Global.mix_precision.use_pure_fp16
+    use_fp16 = config.Global.mix_precision.enable
     black_list = config.Global.mix_precision.custom_black_list
     white_list = config.Global.mix_precision.custom_white_list
@@ -199,7 +199,7 @@ def model_forward_backward(config, batch, forward_func, **kwargs):

 def optim_update_params(config, **kwargs):
     hcg = env.get_hcg()
-    use_fp16 = config.Global.mix_precision.use_pure_fp16
+    use_fp16 = config.Global.mix_precision.enable

     dp_degree = config.Distributed.dp_degree
     sharding_stage = config.Distributed.sharding.sharding_stage
@@ -255,7 +255,7 @@ def fit_impl(config, batch, forward_func, **kwargs):
 def eval_impl(config, batch, model, loss_fn):
     model.eval()

-    use_fp16 = config.Global.mix_precision.use_pure_fp16
+    use_fp16 = config.Global.mix_precision.enable
     black_list = config.Global.mix_precision.custom_black_list
     white_list = config.Global.mix_precision.custom_white_list
examples/transformer/models/GPT/pretrain_moe/run.py (2 additions, 2 deletions)
@@ -84,7 +84,7 @@
 ]
 model, quanter = qat.compress_model(config, model, input_spec)

-if config.Global.mix_precision.use_pure_fp16:
+if config.Global.mix_precision.enable:
     scaler = paddle.amp.GradScaler(
         init_loss_scaling=config.Global.mix_precision.scale_loss)
     # Note: Save dtype is the same as model dtype. Also can set save_dtype='float32' when
@@ -105,7 +105,7 @@
     config.Optimizer,
     model,
     lr_scheduler,
-    multi_precision=config.Global.mix_precision.use_pure_fp16)
+    multi_precision=config.Global.mix_precision.enable)

 # call fleet wrapper
 if nranks > 1:
examples/transformer/utils/config.py (1 addition, 1 deletion)
@@ -413,7 +413,7 @@ def process_global_configs(config):
     global_cfg['mix_precision'] = global_cfg.get('mix_precision', {})
     amp_cfg = global_cfg.mix_precision

-    amp_cfg['use_pure_fp16'] = amp_cfg.get('use_pure_fp16', False)
+    amp_cfg['enable'] = amp_cfg.get('enable', False)
     amp_cfg['scale_loss'] = amp_cfg.get('scale_loss', 32768)
     amp_cfg['custom_black_list'] = amp_cfg.get('custom_black_list', None)
     amp_cfg['custom_white_list'] = amp_cfg.get('custom_white_list', None)
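A quick illustration of the defaulting above (my example, not repo code): a config that only flips enable on still resolves to a complete AMP config:

```python
amp_cfg = {"enable": True}  # user config sets nothing else
amp_cfg["enable"] = amp_cfg.get("enable", False)
amp_cfg["scale_loss"] = amp_cfg.get("scale_loss", 32768)
amp_cfg["custom_black_list"] = amp_cfg.get("custom_black_list", None)
amp_cfg["custom_white_list"] = amp_cfg.get("custom_white_list", None)
print(amp_cfg)
# {'enable': True, 'scale_loss': 32768,
#  'custom_black_list': None, 'custom_white_list': None}
```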
ppfleetx/configs/multimodal/imagen/imagen_base.yaml (1 addition, 1 deletion)
@@ -15,7 +15,7 @@ Engine:
   eval_freq: 10000000
   eval_iters: 10000000
   mix_precision:
-    use_pure_fp16: False
+    enable: False
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
@@ -37,7 +37,7 @@ Engine:
   eval_freq: 10000000
   eval_iters: 10000000
   mix_precision:
-    use_pure_fp16: False
+    enable: False
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
ppfleetx/configs/nlp/ernie/auto/pretrain_ernie_base.yaml (1 addition, 1 deletion)
@@ -17,7 +17,7 @@ Engine:
   eval_iters: 10
   test_iters: -1
   mix_precision:
-    use_pure_fp16: False
+    enable: False
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
ppfleetx/configs/nlp/ernie/finetune_ernie_base.yaml (1 addition, 1 deletion)
@@ -17,7 +17,7 @@ Engine:
   eval_iters: 10
   test_iters: -1
   mix_precision:
-    use_pure_fp16: False
+    enable: False
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
ppfleetx/configs/nlp/ernie/pretrain_ernie_base.yaml (1 addition, 1 deletion)
@@ -17,7 +17,7 @@ Engine:
   eval_iters: 10
   test_iters: -1
   mix_precision:
-    use_pure_fp16: False
+    enable: False
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
ppfleetx/configs/nlp/ernie/qat_ernie_base.yaml (1 addition, 1 deletion)
@@ -17,7 +17,7 @@ Engine:
   eval_iters: 10
   test_iters: -1
   mix_precision:
-    use_pure_fp16: False
+    enable: False
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]
@@ -13,7 +13,7 @@ Engine:
   logging_freq: 10
   eval_freq: 1
   mix_precision:
-    use_pure_fp16: True
+    enable: True
     scale_loss: 32768.0
     custom_black_list: ["reduce_sum", "c_softmax_with_cross_entropy", "elementwise_div", "reduce_mean"]
     custom_white_list: ["lookup_table", "lookup_table_v2"]