Commit e93c3e9

Update O2 inference.

1 parent e8dd19b commit e93c3e9

5 files changed: +69 -29 lines changed

apps/protein_folding/helixfold/gpu_infer.sh (+8 -4)

@@ -28,11 +28,14 @@ distributed=false
 # 'fp32' or 'bf16'
 PRECISION='bf16'
 
-# Disable C++ enisum, using python enisum
-export FLAGS_new_einsum=0
+# 'O1' or 'O2'
+AMP_LEVEL='O1'
 
-# Enable bf16 optimization
-export FLAGS_use_autotune=1
+# Enable C++ enisum instead of python enisum
+export FLAGS_new_einsum=1
+
+# Enable/Disable bf16 optimization
+export FLAGS_use_autotune=0
 
 if [ $distributed == true ]
 then
@@ -79,5 +82,6 @@ $python_cmd run_helixfold.py \
         --preset='full_dbs' \
         --random_seed=0 \
         --precision=${PRECISION} \
+        --amp_level=${AMP_LEVEL} \
         --subbatch_size=${SUBBATCH_SIZE} \
         ${@:2}

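Note: the new AMP_LEVEL variable picks Paddle's auto-mixed-precision mode for the bf16 path. At 'O1' only white-listed ops are auto-cast to bfloat16, while 'O2' runs essentially the whole forward in bfloat16 apart from black-listed ops. A minimal sketch of the two modes, assuming a Paddle build with bfloat16 AMP support; the linear layer is a toy stand-in, not HelixFold's model:

```python
# Toy comparison of AMP levels 'O1' and 'O2' under a bfloat16 auto_cast scope.
import paddle

layer = paddle.nn.Linear(8, 8)
x = paddle.randn([2, 8])

# O1: only ops on the (default or custom) white list run in bf16.
with paddle.amp.auto_cast(enable=True, level='O1', dtype='bfloat16'):
    y_o1 = layer(x)

# O2: the whole computation runs in bf16 apart from black-listed ops.
with paddle.amp.auto_cast(enable=True, level='O2', dtype='bfloat16'):
    y_o2 = layer(x)

print(y_o1.dtype, y_o2.dtype)
```
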
apps/protein_folding/helixfold/gpu_infer_long.sh (+8 -4)

@@ -37,11 +37,14 @@ distributed=true
 # 'fp32' or 'bf16'
 PRECISION='bf16'
 
-# Disable C++ enisum, using python enisum
-export FLAGS_new_einsum=0
+# 'O1' or 'O2'
+AMP_LEVEL='O1'
 
-# Enable bf16 optimization
-export FLAGS_use_autotune=1
+# Enable C++ enisum instead of python enisum
+export FLAGS_new_einsum=1
+
+# Enable/Disable bf16 optimization
+export FLAGS_use_autotune=0
 
 if [ $distributed == true ]
 then
@@ -89,5 +92,6 @@ $python_cmd run_helixfold.py \
         --preset='full_dbs' \
         --random_seed=0 \
         --precision=${PRECISION} \
+        --amp_level=${AMP_LEVEL} \
        --subbatch_size=${SUBBATCH_SIZE} \
         ${@:2}

apps/protein_folding/helixfold/gpu_train.sh (+31 -7)

@@ -9,6 +9,7 @@ python_bin="/opt/conda/envs/helixfold/bin/python"
 # python_bin="python3"
 
 # export NCCL_DEBUG=INFO
+# export LD_LIBRARY_PATH=/usr/local/cuda/compat:$LD_LIBRARY_PATH
 export PYTHONPATH=$root_path:$PYTHONPATH
 # export PADDLE_NODE_NUM=$PADDLE_TRAINERS_NUM
 # export PADDLE_NODE_NUM=1
@@ -17,11 +18,11 @@ LDDT_SCORE_BIN="$root_path/tools/lddt"
 chmod +x $TM_SCORE_BIN
 chmod +x $LDDT_SCORE_BIN
 
-# Disable C++ enisum, using python enisum
-export FLAGS_new_einsum=0
+# Enable C++ enisum instead of python enisum
+export FLAGS_new_einsum=1
 
-# Enable bf16 optimization
-export FLAGS_use_autotune=1
+# Enable/Disable bf16 optimization
+export FLAGS_use_autotune=0
 
 train_af2_single() {
     start_step=0
@@ -37,6 +38,7 @@ train_af2_single() {
         --start_step=${start_step} \
         --train_step=${train_step} \
         --precision=${precision} \
+        --amp_level=${amp_level} \
         --num_workers 6 \
         --seed 2022 \
         --batch_size=$batch_size \
@@ -66,6 +68,7 @@ train_af2_distributed() {
         --start_step=${start_step} \
         --train_step=${train_step} \
         --precision=${precision} \
+        --amp_level=${amp_level} \
         --num_workers 6 \
         --seed 2022 \
         --batch_size=$batch_size \
@@ -95,6 +98,8 @@ mkdir -p debug_log debug_models
     model_name="initial"
     precision="bf16"
     # precision="fp32"
+    # amp_level="O1"
+    amp_level="O2"
     log_step="--log_step=20"
     eval_step="--eval_step=1000"
     save_step="--save_step=1000"
@@ -116,10 +121,13 @@ mkdir -p debug_log debug_models
     model_name="finetune"
     precision="bf16"
     # precision="fp32"
+    # amp_level="O1"
+    amp_level="O2"
     log_step="--log_step=20"
     eval_step="--eval_step=1000"
     save_step="--save_step=1000"
-    # init_model="$root_path/data/pd_params/model_5.pdparams"
+    # init_model="$root_path/data/params/params_model_1.npz"
+    # init_model="$root_path/data/pd_params/model_1.pdparams"
     train_af2_single
 fi
 }
@@ -139,6 +147,8 @@ mkdir -p debug_log debug_models
     model_name="initial"
     precision="bf16"
     # precision="fp32"
+    # amp_level="O1"
+    amp_level="O2"
     log_step="--log_step=20"
     eval_step="--eval_step=1000"
     save_step="--save_step=1000"
@@ -163,10 +173,13 @@ mkdir -p debug_log debug_models
     model_name="finetune"
     precision="bf16"
     # precision="fp32"
+    # amp_level="O1"
+    amp_level="O2"
     log_step="--log_step=20"
     eval_step="--eval_step=1000"
     save_step="--save_step=1000"
-    # init_model="$root_path/data/pd_params/model_5.pdparams"
+    # init_model="$root_path/data/params/params_model_1.npz"
+    # init_model="$root_path/data/pd_params/model_1.pdparams"
     train_af2_distributed
 fi
 }
@@ -186,6 +199,8 @@ mkdir -p debug_log debug_models
     model_name="initial"
    precision="bf16"
     # precision="fp32"
+    # amp_level="O1"
+    amp_level="O2"
     log_step="--log_step=20"
     eval_step="--eval_step=1000"
     save_step="--save_step=1000"
@@ -210,10 +225,13 @@ mkdir -p debug_log debug_models
     model_name="finetune"
     precision="bf16"
     # precision="fp32"
+    # amp_level="O1"
+    amp_level="O2"
     log_step="--log_step=20"
     eval_step="--eval_step=1000"
     save_step="--save_step=1000"
-    # init_model="$root_path/data/pd_params/model_5.pdparams"
+    # init_model="$root_path/data/params/params_model_1.npz"
+    # init_model="$root_path/data/pd_params/model_1.pdparams"
     train_af2_distributed
 fi
 }
@@ -232,6 +250,8 @@ mkdir -p debug_log debug_models
     model_name="initial"
     precision="bf16"
     # precision="fp32"
+    # amp_level="O1"
+    amp_level="O2"
     log_step="--log_step=20"
     eval_step="--eval_step=1000"
     save_step="--save_step=1000"
@@ -253,6 +273,8 @@ mkdir -p debug_log debug_models
     model_name="initial"
     precision="bf16"
     # precision="fp32"
+    # amp_level="O1"
+    amp_level="O2"
     log_step="--log_step=20"
     eval_step="--eval_step=1000"
     save_step="--save_step=1000"
@@ -274,6 +296,8 @@ mkdir -p debug_log debug_models
     model_name="initial"
     precision="bf16"
     # precision="fp32"
+    # amp_level="O1"
+    amp_level="O2"
     log_step="--log_step=20"
     eval_step="--eval_step=1000"
     save_step="--save_step=1000"

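Note: the training configurations above now pair precision="bf16" with amp_level="O2". A minimal sketch of one bf16 training step at level O2, assuming a recent Paddle release where paddle.amp.decorate accepts dtype='bfloat16'; the model, optimizer, and loss are toy stand-ins, not HelixFold's training loop:

```python
# Sketch of a single bf16 training step at AMP level O2 (toy model/optimizer).
import paddle

model = paddle.nn.Linear(16, 1)
opt = paddle.optimizer.Adam(parameters=model.parameters())

# O2 holds the network's parameters in bf16; O1 would keep fp32 parameters
# and only cast white-listed ops inside auto_cast.
model, opt = paddle.amp.decorate(models=model, optimizers=opt,
                                 level='O2', dtype='bfloat16')

x = paddle.randn([4, 16])
with paddle.amp.auto_cast(enable=True, level='O2', dtype='bfloat16'):
    loss = model(x).mean()

# bf16 keeps fp32's exponent range, so no GradScaler is used in this sketch.
loss.backward()
opt.step()
opt.clear_grad()
```
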
apps/protein_folding/helixfold/run_helixfold.py (+5 -4)

@@ -36,7 +36,7 @@
 
 from utils.init_env import init_seed, init_distributed_env
 from ppfleetx.distributed.protein_folding import dp, dap, bp
-from utils.utils import get_bf16_op_list
+from utils.utils import get_custom_amp_list
 
 logging.basicConfig()
 logger = logging.getLogger(__file__)
@@ -120,8 +120,8 @@ def predict_structure(
 
     def _forward_with_precision(processed_feature_dict):
         if args.precision == "bf16":
-            black_list, white_list = get_bf16_op_list()
-            with paddle.amp.auto_cast(level='O1', custom_white_list=white_list, custom_black_list=black_list, dtype='bfloat16'):
+            black_list, white_list = get_custom_amp_list()
+            with paddle.amp.auto_cast(enable=True, custom_white_list=white_list, custom_black_list=black_list, level=args.amp_level, dtype='bfloat16'):
                 return model_runner.predict(
                     processed_feature_dict,
                     ensemble_representations=True,
@@ -235,7 +235,7 @@ def main(args):
 
     data_dir = pathlib.Path(args.data_dir)
     params = f'params_{model_name}'
-    model_params = data_dir.joinpath('params', f'{params}.pd')
+    model_params = data_dir.joinpath('params', f'{params}.pdparams')
     if not model_params.exists():
         model_params = data_dir.joinpath('params', f'{params}.npz')
 
@@ -356,6 +356,7 @@ def main(args):
                         help='The random seed for the data pipeline. '
                              'By default, this is randomly generated.')
     parser.add_argument("--precision", type=str, choices=['fp32', 'bf16'], default='fp32')
+    parser.add_argument("--amp_level", type=str, default='O1')
     parser.add_argument('--distributed',
                         action='store_true', default=False,
                         help='Whether to use distributed DAP inference.')

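Note: run_helixfold.py now threads the script-level AMP_LEVEL through the new --amp_level argument into paddle.amp.auto_cast, replacing the previously hard-coded level='O1'. A condensed, self-contained sketch of that plumbing; the forward function stands in for model_runner.predict, and None stands in for the lists returned by get_custom_amp_list():

```python
# Condensed sketch of the --amp_level plumbing added in this commit.
import argparse
import paddle

parser = argparse.ArgumentParser()
parser.add_argument("--precision", type=str, choices=['fp32', 'bf16'], default='fp32')
parser.add_argument("--amp_level", type=str, default='O1')
args = parser.parse_args(["--precision", "bf16", "--amp_level", "O2"])

def forward(x):  # stand-in for model_runner.predict(...)
    return paddle.nn.functional.relu(x).sum()

x = paddle.randn([2, 8])
if args.precision == "bf16":
    # In the real script the white/black lists come from get_custom_amp_list().
    with paddle.amp.auto_cast(enable=True,
                              custom_white_list=None,
                              custom_black_list=None,
                              level=args.amp_level,
                              dtype='bfloat16'):
        out = forward(x)
else:
    out = forward(x)
print(out)
```
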
apps/protein_folding/helixfold/train.py (+17 -10)

@@ -144,16 +144,23 @@ def eval(args, model, eval_dataset, compute_loss, cache_dir=None):
             batch['feat'] = align_feat(batch['feat'], args.dap_degree)
             batch['label'] = align_label(batch['label'], args.dap_degree)
 
-        if args.precision == "bf16" and args.amp_level == "O2":
-            black_list, white_list = get_custom_amp_list()
-            with paddle.amp.auto_cast(enable=True,
-                                      custom_white_list=white_list,
-                                      custom_black_list=black_list,
-                                      level=args.amp_level,
-                                      dtype='bfloat16'):
-                res = model(batch, compute_loss=compute_loss)
-        else:
-            res = model(batch, compute_loss=compute_loss)
+        # inference
+        def _forward_with_precision(batch):
+            if args.precision == "bf16":
+                black_list, white_list = get_custom_amp_list()
+                with paddle.amp.auto_cast(enable=True,
+                                          custom_white_list=white_list,
+                                          custom_black_list=black_list,
+                                          level=args.amp_level,
+                                          dtype='bfloat16'):
+                    return model(batch, compute_loss=compute_loss)
+            elif args.precision == "fp32":
+                return model(batch, compute_loss=compute_loss)
+            else:
+                raise ValueError("Please choose precision from bf16 and fp32! ")
+
+        # res = model(batch, compute_loss=compute_loss)
+        res = _forward_with_precision(batch)
         if compute_loss:
             results, loss = res
             if loss.dtype == paddle.bfloat16:

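Note: the unchanged context line `if loss.dtype == paddle.bfloat16:` shows why the eval wrapper matters: under bf16 AMP the returned loss can be a bfloat16 tensor, which is normally cast back to fp32 before being read on the host. A hedged one-liner illustrating that normalization; the tensor below is a stand-in for the eval loss:

```python
# Stand-in for a bf16 eval loss being normalized to fp32 before logging.
import paddle

loss = paddle.to_tensor([1.0]).cast('bfloat16')  # pretend this came from the model
if loss.dtype == paddle.bfloat16:
    loss = loss.cast('float32')
print(float(loss))  # 1.0
```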