
Commit e8dd19b

Merge pull request #248 from JamesLim-sy/support_O2
Support for BF16-O2
2 parents: e20d6ea + 5e1dba5

6 files changed (160 additions, 29 deletions)

apps/protein_folding/helixfold/alphafold_paddle/model/all_atom.py

Lines changed: 3 additions & 3 deletions
@@ -868,8 +868,8 @@ def between_residue_clash_loss(
     dists_mask *= (residue_index1 < residue_index2)
 
     # Backbone C--N bond between subsequent residues is no clash.
-    c_one_hot = nn.functional.one_hot(paddle.to_tensor([2]), num_classes=14)
-    n_one_hot = nn.functional.one_hot(paddle.to_tensor([0]), num_classes=14)
+    c_one_hot = nn.functional.one_hot(paddle.full(shape=[1], fill_value=2, dtype="int64"), num_classes=14)
+    n_one_hot = nn.functional.one_hot(paddle.full(shape=[1], fill_value=0, dtype="int64"), num_classes=14)
     neighbour_mask = ((residue_index1 + 1) == residue_index2)
     tmp_c_one_hot = paddle.unsqueeze(c_one_hot, axis=[1,2,4])
     tmp_n_one_hot = paddle.unsqueeze(n_one_hot, axis=[1,2,3])
@@ -879,7 +879,7 @@ def between_residue_clash_loss(
 
     # Disulfide bridge between two cysteines is no clash.
     cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG')
-    cys_sg_one_hot = nn.functional.one_hot(paddle.to_tensor(cys_sg_idx), num_classes=14)
+    cys_sg_one_hot = nn.functional.one_hot(paddle.full(shape=[1], fill_value=cys_sg_idx, dtype="int64"), num_classes=14)
     cys_sg_one_hot1 = paddle.unsqueeze(cys_sg_one_hot, axis=[1,2,4])
     cys_sg_one_hot2 = paddle.unsqueeze(cys_sg_one_hot, axis=[1,2,3])
     disulfide_bonds = (cys_sg_one_hot1 * cys_sg_one_hot2)
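The change above swaps paddle.to_tensor([...]) for paddle.full(..., dtype="int64") when building the one-hot indices, keeping the index dtype explicit. A minimal standalone sketch of the new construction (not part of the diff itself):

import paddle
import paddle.nn as nn

# Build the index tensor with an explicit int64 dtype, as the new code does.
c_idx = paddle.full(shape=[1], fill_value=2, dtype="int64")
c_one_hot = nn.functional.one_hot(c_idx, num_classes=14)  # shape [1, 14], 1.0 at index 2
print(c_one_hot)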

apps/protein_folding/helixfold/alphafold_paddle/model/modules.py

Lines changed: 3 additions & 1 deletion
@@ -290,6 +290,7 @@ def __init__(self, channel_num, config, global_config):
         }
 
         self.used_heads = []
+        self.heads = []
         for head_name, head_config in sorted(self.config.heads.items()):
             if head_name not in Head_modules:
                 continue
@@ -300,6+301,7 @@ def __init__(self, channel_num, config, global_config):
 
             head_name_ = Head_names.get(head_name, head_name)
             setattr(self, head_name_, module)
+            self.heads.append(module)
 
     def forward(self,
                 ensembled_batch,
@@ -2173,7 +2175,7 @@ def forward(self, query_embedding, batch, mask_2d):
         Returns:
             A template embedding [N_res, N_res, c_z].
         """
-        assert mask_2d.dtype == query_embedding.dtype
+        assert mask_2d.dtype == query_embedding.dtype, f"mask_2d.dtype ({mask_2d.dtype}) is not the same with query_embedding.dtype ({query_embedding.dtype})!"
         dtype = query_embedding.dtype
         num_res = batch['template_aatype'].shape[1]
         template_mask = batch['template_pseudo_beta_mask']
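The new self.heads list collects the head sub-modules as they are registered so that train.py can later hand them to paddle.amp.decorate(excluded_layers=...) and keep the heads in float32 under BF16-O2. A toy sketch of the same pattern (hypothetical layer names, not the repo's classes; assumes a Paddle build and device with bfloat16 support):

import paddle
import paddle.nn as nn

class ToyIteration(nn.Layer):
    def __init__(self):
        super().__init__()
        self.evoformer = nn.Linear(16, 16)
        self.heads = []
        for name in ["distogram", "predicted_lddt"]:
            module = nn.Linear(16, 8)
            setattr(self, name, module)
            self.heads.append(module)   # same pattern as the new self.heads

model = ToyIteration()
# Exclude the collected heads from the O2 parameter cast.
model = paddle.amp.decorate(models=model, level="O2", dtype="bfloat16",
                            excluded_layers=model.heads)
print(model.evoformer.weight.dtype)  # expected: paddle.bfloat16
print(model.heads[0].weight.dtype)   # expected: stays float32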

apps/protein_folding/helixfold/train.py

Lines changed: 43 additions & 11 deletions
@@ -30,7 +30,7 @@
 from tensorboardX import SummaryWriter
 
 from utils.utils import get_model_parameter_size, add_to_data_writer, upload_to_hadoop, csv_print
-from utils.utils import get_bf16_op_list
+from utils.utils import get_custom_amp_list
 from utils.metric import ResultsCollect
 from utils.model import RunModel
 from utils.exponential_moving_average import ExponentialMovingAverage, EMA
@@ -53,7 +53,8 @@ def time_me():
     # paddle.device.cuda.synchronize()
     return time.time()
 
-def get_optimizer(opt_config, model):
+
+def get_optimizer(args, opt_config, model):
     if opt_config.grad_clip == 0:
         grad_clip = None
     else:
@@ -74,11 +75,13 @@ def get_optimizer(opt_config, model):
 
     parameters = get_fused_param_groups(model, args.dap_degree > 1 or args.bp_degree > 1)
 
+    multi_precision = (args.precision == "bf16" and args.amp_level == "O2")
     optimizer = paddle.optimizer.Adam(
         learning_rate=lr_scheduler,
         epsilon=1e-06,
         grad_clip=grad_clip,
-        parameters = parameters
+        parameters=parameters,
+        multi_precision=multi_precision,
     )
     return optimizer, lr_scheduler
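Under BF16-O2 the decorated parameters are bfloat16, so the optimizer is now constructed with multi_precision=True, which keeps float32 master weights for the Adam update. A minimal sketch with a toy layer (not the repo's model; assumes a Paddle build and device with bfloat16 support):

import paddle

model = paddle.nn.Linear(4, 4)
model = paddle.amp.decorate(models=model, level="O2", dtype="bfloat16")

# multi_precision=True mirrors the new get_optimizer(): float32 master copies are
# kept for the bf16 parameters and the update itself runs in float32.
opt = paddle.optimizer.Adam(learning_rate=1e-3,
                            epsilon=1e-06,
                            parameters=model.parameters(),
                            multi_precision=True)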

@@ -141,9 +144,22 @@ def eval(args, model, eval_dataset, compute_loss, cache_dir=None):
             batch['feat'] = align_feat(batch['feat'], args.dap_degree)
             batch['label'] = align_label(batch['label'], args.dap_degree)
 
-        res = model(batch, compute_loss=compute_loss)
+        if args.precision == "bf16" and args.amp_level == "O2":
+            black_list, white_list = get_custom_amp_list()
+            with paddle.amp.auto_cast(enable=True,
+                                      custom_white_list=white_list,
+                                      custom_black_list=black_list,
+                                      level=args.amp_level,
+                                      dtype='bfloat16'):
+                res = model(batch, compute_loss=compute_loss)
+        else:
+            res = model(batch, compute_loss=compute_loss)
         if compute_loss:
             results, loss = res
+            if loss.dtype == paddle.bfloat16:
+                loss = loss.cast("float32").item()
+            else:
+                loss = loss.item()
         else:
             results, loss = res, np.zeros([1])
         s2 = time_me()
@@ -218,8 +234,12 @@ def train(args, cur_step, model, train_data_gen, distill_data_gen, train_config,
     # train
     def _forward_with_precision(batch):
         if args.precision == "bf16":
-            black_list, white_list = get_bf16_op_list()
-            with paddle.amp.auto_cast(level='O1', custom_white_list=white_list, custom_black_list=black_list, dtype='bfloat16'):
+            black_list, white_list = get_custom_amp_list()
+            with paddle.amp.auto_cast(enable=True,
+                                      custom_white_list=white_list,
+                                      custom_black_list=black_list,
+                                      level=args.amp_level,
+                                      dtype='bfloat16'):
                 return model(batch)
         elif args.precision == "fp32":
             return model(batch)
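The forward pass now runs under paddle.amp.auto_cast with the level taken from --amp_level, so the same code path serves both O1 and O2. A minimal sketch with a toy model and abbreviated op lists (the real lists come from get_custom_amp_list(); assumes a device with bfloat16 support):

import paddle

amp_level = "O2"                                    # would come from --amp_level
black_list = {"reduce_sum"}                         # abbreviated stand-in lists
white_list = {"matmul_v2", "layer_norm", "softmax"}

model = paddle.nn.Linear(8, 1)
if amp_level == "O2":
    model = paddle.amp.decorate(models=model, level="O2", dtype="bfloat16")

x = paddle.rand([2, 8])
with paddle.amp.auto_cast(enable=True,
                          custom_white_list=white_list,
                          custom_black_list=black_list,
                          level=amp_level,
                          dtype="bfloat16"):
    loss = model(x).mean()

# As in the updated train()/eval(), bring a bf16 loss back to float32 before logging.
loss = loss.cast("float32") if loss.dtype == paddle.bfloat16 else loss
print(float(loss))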
@@ -250,8 +270,7 @@ def _forward_with_precision(batch):
         ema.update()
         optimizer.clear_grad()
 
-    if args.precision == "bf16":
-        loss = loss.cast("float32")
+    loss = loss.cast("float32") if loss.dtype == paddle.bfloat16 else loss
 
     s5 = time_me()
     batch_cost = s5 - s0
@@ -283,6 +302,7 @@ def main(args):
     set_logging_level(args.logging_level)
 
     """main function"""
+    print(f'>>> PaddlePaddle commit: {paddle.version.commit}')
     print(f'>>> args:\n{args}')
     data_config = ml_collections.ConfigDict(json.load(open(args.data_config, 'r')))
     print(f'>>> data_config:\n{data_config}')
@@ -314,7 +334,7 @@ def worker_init_fn(worker_id):
         model_config.model.global_config.dist_model = True
     if args.bp_degree > 1:
         model_config.model.global_config.outer_product_mean_position = 'end'
-    # print(f'>>> model_config:\n{model_config}')
+    print(f'>>> model_config:\n{model_config}')
 
     model = RunModel(train_config, model_config)
 
@@ -377,13 +397,22 @@ def worker_init_fn(worker_id):
 
         model.alphafold.set_state_dict(pd_params)
 
-    optimizer, lr_scheduler = get_optimizer(train_config.optimizer, model)
+    if args.precision == "bf16" and args.amp_level == "O2":
+        print(f"args.amp_level : {args.amp_level}")
+        model = paddle.amp.decorate(
+            models=model,
+            level=args.amp_level,
+            dtype='bfloat16',
+            excluded_layers=model.alphafold.alphafold_iteration.heads
+        )
+
+    optimizer, lr_scheduler = get_optimizer(args, train_config.optimizer, model)
     args.grad_clip = train_config.optimizer.grad_clip
 
     # ema = ExponentialMovingAverage(model, 0.999)
     ema = EMA(optimizer._param_groups, 0.999)
     ema.register()
-
+
     ### load dataset
     if not args.only_test:
         train_dataset = AF2Dataset(
@@ -473,6 +502,7 @@ def worker_init_fn(worker_id):
     for _ in range(cur_step):
         lr_scheduler.step()
     logging.info('[Main] Start training.')
+
     while True:
         # reset train log info
         if cur_step == 5:
@@ -484,6 +514,7 @@ def worker_init_fn(worker_id):
         # train
         train(args, cur_step, model, train_data_gen, distill_data_gen, train_config, model_config, \
                 lr_scheduler, optimizer, res_collect, train_logger, ema)
+
         if cur_step % args.log_step == 0:
             train_results = res_collect.get_result()
             train_results['lr'] = lr_scheduler.get_lr()
@@ -522,6 +553,7 @@ def worker_init_fn(worker_id):
     parser.add_argument("--model_name", type=str, help='used to choose model config')
     parser.add_argument("--init_model", type=str, default='')
     parser.add_argument("--precision", type=str, choices=['fp32', 'bf16'], default='fp32')
+    parser.add_argument("--amp_level", type=str, default='O1')
     parser.add_argument("--start_step", type=int, default=0)
     parser.add_argument("--train_step", type=int, default=1000)
     parser.add_argument("--batch_size", type=int, default=1)

apps/protein_folding/helixfold/utils/exponential_moving_average.py

Lines changed: 8 additions & 5 deletions
@@ -89,8 +89,8 @@ def register(self):
             for p in group['params']:
                 if p.stop_gradient is True:
                     continue
-                self._shadow[id(p)] = paddle.zeros_like(p)
-                self._shadow[id(p)].set_value(p)
+                self._shadow[id(p)] = paddle.zeros_like(p, dtype="float32")
+                self._shadow[id(p)].set_value(p.astype("float32"))
 
     @paddle.no_grad()
     def update(self):
@@ -104,7 +104,7 @@ def update(self):
                     continue
                 new_val = p.detach().clone()
                 old_val = self._shadow[id(p)]
-                new_average = decay * old_val + (1 - decay) * new_val
+                new_average = decay * old_val + (1 - decay) * new_val.astype("float32")
                 self._shadow[id(p)] = new_average
 
         self._update_step += 1
@@ -121,7 +121,10 @@ def apply_shadow(self):
                 assert id(p) in self._shadow
 
                 self._backup[id(p)] = p.detach().clone()
-                p.set_value(self._shadow[id(p)])
+                if p.dtype == paddle.bfloat16:
+                    p.set_value(self._shadow[id(p)].astype(paddle.bfloat16))
+                else:
+                    p.set_value(self._shadow[id(p)])
 
     @paddle.no_grad()
     def restore(self):
@@ -133,4 +136,4 @@ def restore(self):
                 continue
             assert id(p) in self._backup
             p.set_value(self._backup[id(p)])
-        self._backup = {}
+        self._backup = {}
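The EMA changes keep the shadow weights in float32 even when the live parameters are bfloat16 under O2: register() and update() accumulate in float32, and apply_shadow() casts back to bfloat16 only when writing into a bf16 parameter. A standalone single-tensor sketch of the same idea (not the repo's EMA class; uses a float32 stand-in so it runs on any device):

import paddle

decay = 0.999
param = paddle.create_parameter(shape=[4], dtype="float32")   # stand-in for a (possibly bf16) parameter
shadow = paddle.zeros_like(param, dtype="float32")
shadow.set_value(param.astype("float32"))                     # register(): float32 shadow

new_val = param.detach().clone()
shadow = decay * shadow + (1 - decay) * new_val.astype("float32")   # update(): float32 accumulation

# apply_shadow(): cast back to the parameter's own dtype only if it is bf16
param.set_value(shadow.astype(paddle.bfloat16) if param.dtype == paddle.bfloat16 else shadow)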

apps/protein_folding/helixfold/utils/metric.py

Lines changed: 9 additions & 1 deletion
@@ -286,7 +286,15 @@ def get_result(self):
 
     def _extract_loss_dict(self, results):
         """extract value with 'loss' or 'fape' in key"""
+        def _calc_tensor_mean(x):
+            if x.dtype == paddle.bfloat16:
+                x = x.cast("float32")
+            if len(x.shape) == 0:
+                return x.item()
+            else:
+                return x.numpy().mean()
+
         res = tree_flatten(results)
         res = tree_filter(lambda k: 'loss' in k or 'fape' in k, None, res)
-        res = tree_map(lambda x: x.numpy().mean(), res)
+        res = tree_map(lambda x: _calc_tensor_mean(x), res)
         return res
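The helper added above reduces every logged loss tensor to a Python float, casting bfloat16 to float32 first (NumPy has no bfloat16 dtype) and using .item() for 0-d tensors. A standalone sketch with a plain dict in place of the repo's tree utilities:

import paddle

def calc_tensor_mean(x):
    # Same logic as the new _calc_tensor_mean(): bf16 -> float32, then scalar-ify.
    if x.dtype == paddle.bfloat16:
        x = x.cast("float32")
    return x.item() if len(x.shape) == 0 else float(x.numpy().mean())

results = {"fape_loss": paddle.to_tensor(0.3), "structure_loss": paddle.rand([4]), "plddt": paddle.rand([4])}
scalars = {k: calc_tensor_mean(v) for k, v in results.items() if "loss" in k or "fape" in k}
print(scalars)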

apps/protein_folding/helixfold/utils/utils.py

Lines changed: 94 additions & 8 deletions
@@ -19,21 +19,107 @@
 import numpy as np
 import paddle
 
-def get_bf16_op_list():
+
+def get_custom_amp_list():
     """tbd."""
 
     black_list = {"reduce_sum"}
-    white_list = {"concat", "elementwise_add", "elementwise_div", "elementwise_mul", "elementwise_sub", "fill_any_like", "fill_constant", "gather", "gaussian_random",
-        "softmax", "layer_norm", "log_softmax", "matmul_v2", "p_norm", "py_layer", "relu", "scale", "sigmoid", "slice", "softplus", "split", "sqrt", "square", "stack",
-        "sum", "transpose2", "fused_gate_attention", "dropout_nd"}
+    white_list = {
+        "concat",
+        "dropout_nd",
+        "einsum",
+        "elementwise_add",
+        "elementwise_div",
+        "elementwise_mul",
+        "elementwise_sub",
+        "fill_any_like",
+        "fill_constant",
+        "fused_gate_attention",
+        "gather",
+        "gaussian_random",
+        "layer_norm",
+        "log_softmax",
+        "matmul_v2",
+        "p_norm",
+        "py_layer",
+        "relu",
+        "scale",
+        "sigmoid",
+        "slice",
+        "softmax",
+        "softplus",
+        "split",
+        "split_with_num",
+        "sqrt",
+        "square",
+        "stack",
+        "sum",
+        "transpose2",
+        "unsqueeze2",
+        "unstack",
+        "where"
+    }
     return black_list, white_list
 
+
 def get_structure_module_bf16_op_list():
-    black_list = {"reduce_sum", "elementwise_add", "elementwise_div", "elementwise_mul", "elementwise_sub", "fill_any_like", "fill_constant", "gaussian_random", "uniform_random",
-        "softmax", "log_softmax", "p_norm", "py_layer", "scale", "sigmoid", "softplus", "sqrt", "square", "linspace", "squared_l2_norm", "reduce_mean", "reduce_min", "reduce_prod", "sum", "fused_gate_attention", "dropout_nd", "clip"}
-    white_list = {"layer_norm", "relu", "split", "stack", "gather", "concat", "transpose2", "matmul_v2", "unsqueeze2", "squeeze2", "tile", "slice", "one_hot_v2", "reshape2", "elementwise_max", "elementwise_min", "equal", "greater_than", "less_than", "reduce_max", "eye", "bitwise_or", "abs", "reduce_max", }
+    black_list = {
+        "clip",
+        "dropout_nd",
+        "elementwise_add",
+        "elementwise_div",
+        "elementwise_mul",
+        "elementwise_sub",
+        "fill_any_like",
+        "fill_constant",
+        "fused_gate_attention",
+        "gaussian_random",
+        "linspace",
+        "log_softmax",
+        "p_norm",
+        "py_layer",
+        "reduce_mean",
+        "reduce_min",
+        "reduce_prod",
+        "reduce_sum",
+        "scale",
+        "sigmoid",
+        "softmax",
+        "softplus",
+        "sqrt",
+        "square",
+        "squared_l2_norm",
+        "sum",
+        "uniform_random",
+    }
+    white_list = {
+        "abs",
+        "bitwise_or",
+        "concat",
+        "elementwise_max",
+        "elementwise_min",
+        "equal",
+        "eye",
+        "gather",
+        "greater_than",
+        "layer_norm",
+        "less_than",
+        "matmul_v2",
+        "one_hot_v2",
+        "reduce_max",
+        "relu",
+        "reshape2",
+        "slice",
+        "split",
+        "squeeze2",
+        "stack",
+        "transpose2",
+        "unsqueeze2",
+        "tile",
+    }
     return black_list, white_list
 
+
 def get_model_parameter_size(model):
     """tbd"""
     size = 0
@@ -119,4 +205,4 @@ def csv_print(d):
     keys = sorted(list(d.keys()))
     values = [str(d[k]) for k in keys]
     print(' '.join([str(x) for x in keys]))
-    print(' '.join([str(x) for x in values]))
+    print(' '.join([str(x) for x in values]))
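A quick sanity check for the expanded lists (not part of the commit; assumes the helixfold utils package is importable): the entries are Paddle operator types such as "matmul_v2" and "transpose2", and the black and white lists handed to paddle.amp.auto_cast are expected to be disjoint.

# Assumes the repo root is on PYTHONPATH so utils.utils resolves to this file.
from utils.utils import get_custom_amp_list, get_structure_module_bf16_op_list

for black_list, white_list in (get_custom_amp_list(),
                               get_structure_module_bf16_op_list()):
    overlap = black_list & white_list
    assert not overlap, f"ops present in both lists: {sorted(overlap)}"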
