Skip to content

Commit 5e1dba5

Browse files
committed
Fix eval for AMP-O2.
1 parent d89708d commit 5e1dba5

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

apps/protein_folding/helixfold/train.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -144,9 +144,22 @@ def eval(args, model, eval_dataset, compute_loss, cache_dir=None):
144144
batch['feat'] = align_feat(batch['feat'], args.dap_degree)
145145
batch['label'] = align_label(batch['label'], args.dap_degree)
146146

147-
res = model(batch, compute_loss=compute_loss)
147+
if args.precision == "bf16" and args.amp_level == "O2":
148+
black_list, white_list = get_custom_amp_list()
149+
with paddle.amp.auto_cast(enable=True,
150+
custom_white_list=white_list,
151+
custom_black_list=black_list,
152+
level=args.amp_level,
153+
dtype='bfloat16'):
154+
res = model(batch, compute_loss=compute_loss)
155+
else:
156+
res = model(batch, compute_loss=compute_loss)
148157
if compute_loss:
149158
results, loss = res
159+
if loss.dtype == paddle.bfloat16:
160+
loss = loss.cast("float32").item()
161+
else:
162+
loss = loss.item()
150163
else:
151164
results, loss = res, np.zeros([1])
152165
s2 = time_me()
@@ -257,8 +270,7 @@ def _forward_with_precision(batch):
257270
ema.update()
258271
optimizer.clear_grad()
259272

260-
if args.precision == "bf16":
261-
loss = loss.cast("float32")
273+
loss = loss.cast("float32") if loss.dtype == paddle.bfloat16 else loss
262274

263275
s5 = time_me()
264276
batch_cost = s5 - s0

apps/protein_folding/helixfold/utils/metric.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ def get_result(self):
287287
def _extract_loss_dict(self, results):
288288
"""extract value with 'loss' or 'fape' in key"""
289289
def _calc_tensor_mean(x):
290+
if x.dtype == paddle.bfloat16:
291+
x = x.cast("float32")
290292
if len(x.shape) == 0:
291293
return x.item()
292294
else:

0 commit comments

Comments
 (0)