Skip to content

Commit 5810879

Browse files
authored
Merge pull request #251 from Xreki/opt/einsum_slice
Optimize the use of getitem (slice) and einsum.
2 parents e93c3e9 + c88d6c5 commit 5810879

File tree

8 files changed: +206 additions, -133 deletions

apps/protein_folding/helixfold/alphafold_paddle/model/all_atom.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def get_chi_atom_indices():
6464
atom_indices.append(
6565
[residue_constants.atom_order[atom] for atom in chi_angle])
6666
for _ in range(4 - len(atom_indices)):
67-
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
67+
atom_indices.append(
68+
[0, 0, 0, 0]) # For chi angles not defined on the AA.
6869
chi_atom_indices.append(atom_indices)
6970

7071
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
@@ -274,8 +275,7 @@ def atom37_to_torsion_angles(
274275
aatype: paddle.Tensor, # (B, T, N)
275276
all_atom_pos: paddle.Tensor, # (B, T, N, 37, 3)
276277
all_atom_mask: paddle.Tensor, # (B, T, N, 37)
277-
placeholder_for_undefined=False,
278-
) -> Dict[str, paddle.Tensor]:
278+
placeholder_for_undefined=False, ) -> Dict[str, paddle.Tensor]:
279279
"""Computes the 7 torsion angles (in sin, cos encoding) for each residue.
280280
281281
The 7 torsion angles are in the order
@@ -300,44 +300,57 @@ def atom37_to_torsion_angles(
300300
"""
301301

302302
# Map aatype > 20 to 'Unknown' (20).
303-
aatype = paddle.minimum(aatype.astype('int'), paddle.to_tensor([20]).astype('int'))
304-
303+
aatype = paddle.minimum(
304+
aatype.astype('int'), paddle.full(shape=[1], fill_value=20, dtype="int"))
305+
305306
num_batch, num_temp, num_res = aatype.shape
306307

307308
# Compute the backbone angles.
308309
pad = paddle.zeros([num_batch, num_temp, 1, 37, 3])
309-
prev_all_atom_pos = paddle.concat([pad, all_atom_pos[..., :-1, :, :]], axis=-3)
310+
prev_all_atom_pos = paddle.concat(
311+
[pad, all_atom_pos[..., :-1, :, :]], axis=-3)
310312

311313
pad = paddle.zeros([num_batch, num_temp, 1, 37])
312-
prev_all_atom_mask = paddle.concat([pad, all_atom_mask[..., :-1, :]], axis=-2)
314+
prev_all_atom_mask = paddle.concat(
315+
[pad, all_atom_mask[..., :-1, :]], axis=-2)
313316

314317
# For each torsion angle collect the 4 atom positions that define this angle.
315318
# shape (B, T, N, atoms=4, xyz=3)
316319
pre_omega_atom_pos = paddle.concat(
317-
[prev_all_atom_pos[..., 1:3, :], # prev CA, C
318-
all_atom_pos[..., 0:2, :] # this N, CA
319-
], axis=-2)
320+
[
321+
prev_all_atom_pos[..., 1:3, :], # prev CA, C
322+
all_atom_pos[..., 0:2, :] # this N, CA
323+
],
324+
axis=-2)
320325

321326
phi_atom_pos = paddle.concat(
322-
[prev_all_atom_pos[..., 2:3, :], # prev C
323-
all_atom_pos[..., 0:3, :] # this N, CA, C
324-
], axis=-2)
327+
[
328+
prev_all_atom_pos[..., 2:3, :], # prev C
329+
all_atom_pos[..., 0:3, :] # this N, CA, C
330+
],
331+
axis=-2)
325332

326333
psi_atom_pos = paddle.concat(
327-
[all_atom_pos[..., 0:3, :], # this N, CA, C
328-
all_atom_pos[..., 4:5, :] # this O
329-
], axis=-2)
334+
[
335+
all_atom_pos[..., 0:3, :], # this N, CA, C
336+
all_atom_pos[..., 4:5, :] # this O
337+
],
338+
axis=-2)
330339

331340
# Collect the masks from these atoms.
332341
# Shape [batch, n_temp, num_res]
333342
pre_omega_mask = (
334-
paddle.prod(prev_all_atom_mask[..., 1:3], axis=-1) # prev CA, C
335-
* paddle.prod(all_atom_mask[..., 0:2], axis=-1)) # this N, CA
343+
paddle.prod(
344+
prev_all_atom_mask[..., 1:3], axis=-1) # prev CA, C
345+
* paddle.prod(
346+
all_atom_mask[..., 0:2], axis=-1)) # this N, CA
336347
phi_mask = (
337348
prev_all_atom_mask[..., 2] # prev C
338-
* paddle.prod(all_atom_mask[..., 0:3], axis=-1)) # this N, CA, C
349+
* paddle.prod(
350+
all_atom_mask[..., 0:3], axis=-1)) # this N, CA, C
339351
psi_mask = (
340-
paddle.prod(all_atom_mask[..., 0:3], axis=-1) * # this N, CA, C
352+
paddle.prod(
353+
all_atom_mask[..., 0:3], axis=-1) * # this N, CA, C
341354
all_atom_mask[..., 4]) # this O
342355

343356
# Collect the atoms for the chi-angles.
@@ -375,18 +388,18 @@ def atom37_to_torsion_angles(
375388
# Stack all torsion angle atom positions.
376389
# Shape (B, T, N, torsions=7, atoms=4, xyz=3)
377390
torsions_atom_pos = paddle.concat(
378-
[pre_omega_atom_pos[:, :, :, None, :, :],
379-
phi_atom_pos[:, :, :, None, :, :],
380-
psi_atom_pos[:, :, :, None, :, :],
391+
[pre_omega_atom_pos.unsqueeze(axis=-3), # [:, :, :, None, :, :]
392+
phi_atom_pos.unsqueeze(axis=-3), # [:, :, :, None, :, :]
393+
psi_atom_pos.unsqueeze(axis=-3), # [:, :, :, None, :, :]
381394
chis_atom_pos
382395
], axis=3)
383396

384397
# Stack up masks for all torsion angles.
385398
# shape (B, T, N, torsions=7)
386399
torsion_angles_mask = paddle.concat(
387-
[pre_omega_mask[..., None],
388-
phi_mask[..., None],
389-
psi_mask[..., None],
400+
[pre_omega_mask.unsqueeze(axis=-1), # [..., None]
401+
phi_mask.unsqueeze(axis=-1), # [..., None]
402+
psi_mask.unsqueeze(axis=-1), # [..., None]
390403
chis_mask
391404
], axis=-1)
392405

@@ -417,7 +430,7 @@ def atom37_to_torsion_angles(
417430

418431
# Mirror psi, because we computed it from the Oxygen-atom.
419432
torsion_angles_sin_cos *= paddle.to_tensor(
420-
[1., 1., -1., 1., 1., 1., 1.])[None, None, None, :, None]
433+
[1., 1., -1., 1., 1., 1., 1.]).reshape([1, 1, 1, 7, 1]) # [None, None, None, :, None]
421434

422435
# Create alternative angles for ambiguous atom names.
423436
chi_is_ambiguous = utils.batched_gather(
@@ -428,7 +441,7 @@ def atom37_to_torsion_angles(
428441
1.0 - 2.0 * chi_is_ambiguous], axis=-1)
429442
# mirror_torsion_angles (B, T, N, torsions=7)
430443
alt_torsion_angles_sin_cos = (
431-
torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, :, None])
444+
torsion_angles_sin_cos * mirror_torsion_angles.unsqueeze(axis=-1))
432445

433446
if placeholder_for_undefined:
434447
# Add placeholder torsions in place of undefined torsion angles
@@ -437,10 +450,8 @@ def atom37_to_torsion_angles(
437450
paddle.ones(torsion_angles_sin_cos.shape[:-1]),
438451
paddle.zeros(torsion_angles_sin_cos.shape[:-1])
439452
], axis=-1)
440-
torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[
441-
..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
442-
alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[
443-
..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
453+
torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask.unsqueeze(axis=-1) + placeholder_torsions * (1 - torsion_angles_mask.unsqueeze(axis=-1))
454+
alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask.unsqueeze(axis=-1) + placeholder_torsions * (1 - torsion_angles_mask.unsqueeze(axis=-1))
444455

445456
return {
446457
'torsion_angles_sin_cos': torsion_angles_sin_cos, # (B, T, N, 7, 2)
@@ -579,7 +590,7 @@ def frames_and_literature_positions_to_atom14_pos(
579590
"""
580591
# Pick the appropriate transform for every atom.
581592
restype_atom14_to_rigid_group = paddle.to_tensor(
582-
residue_constants.restype_atom14_to_rigid_group)[None, ...]
593+
residue_constants.restype_atom14_to_rigid_group).unsqueeze(axis=0)
583594

584595
# [1, 21, 14] -> # [n_batch, 21, 14]
585596
n_batch = aatype.shape[0]
@@ -612,7 +623,7 @@ def _convert(x, y):
612623
# Gather the literature atom positions for each residue.
613624
# r3.Vecs with shape (B, N, 14)
614625
restype_atom14_rigid_group_positions = paddle.to_tensor(
615-
residue_constants.restype_atom14_rigid_group_positions)[None, ...]
626+
residue_constants.restype_atom14_rigid_group_positions).unsqueeze(axis=0)
616627
# [1, 21, 14, 3] -> [B, 21, 14, 3]
617628
if n_batch > 1:
618629
restype_atom14_rigid_group_positions = paddle.tile(
@@ -629,7 +640,7 @@ def _convert(x, y):
629640

630641
# Mask out non-existing atoms.
631642
restype_atom14_mask = paddle.to_tensor(
632-
residue_constants.restype_atom14_mask)[None, ...]
643+
residue_constants.restype_atom14_mask).unsqueeze(axis=0)
633644
# [1, 21, 14] -> [B, 21, 14]
634645
if n_batch > 1:
635646
restype_atom14_mask = paddle.tile(

apps/protein_folding/helixfold/alphafold_paddle/model/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,9 @@ def model_config(name: str) -> ml_collections.ConfigDict:
404404
'use_remat': False,
405405
'zero_init': True,
406406
'low_memory': False,
407+
'fuse_linear': False,
407408
'fuse_attention': True,
409+
'use_flash_attn': True,
408410
'use_dropout_nd': True,
409411
'outer_product_mean_position': 'origin', # 'origin' or 'middle', 'first', 'end', set 'end' if use BP
410412
},

apps/protein_folding/helixfold/alphafold_paddle/model/folding.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __init__(self, channel_num, config, global_config,
5151
self.global_config = global_config
5252
self.dist_epsilon = dist_epsilon
5353

54+
Linear = paddle.incubate.nn.FusedLinear if self.global_config.fuse_linear else paddle.nn.Linear
55+
5456
num_head = self.config.num_head
5557
num_scalar_qk = self.config.num_scalar_qk
5658
num_point_qk = self.config.num_point_qk
@@ -62,15 +64,15 @@ def __init__(self, channel_num, config, global_config,
6264
assert num_point_qk > 0
6365
assert num_point_v > 0
6466

65-
self.q_scalar = nn.Linear(
67+
self.q_scalar = Linear(
6668
channel_num['seq_channel'], num_head * num_scalar_qk)
67-
self.kv_scalar = nn.Linear(
69+
self.kv_scalar = Linear(
6870
channel_num['seq_channel'],
6971
num_head * (num_scalar_v + num_scalar_qk))
7072

71-
self.q_point_local = nn.Linear(
73+
self.q_point_local = Linear(
7274
channel_num['seq_channel'], num_head * 3 * num_point_qk)
73-
self.kv_point_local = nn.Linear(
75+
self.kv_point_local = Linear(
7476
channel_num['seq_channel'],
7577
num_head * 3 * (num_point_qk + num_point_v))
7678

@@ -79,15 +81,15 @@ def __init__(self, channel_num, config, global_config,
7981
[num_head], 'float32',
8082
default_initializer=nn.initializer.Constant(tpw))
8183

82-
self.attention_2d = nn.Linear(channel_num['pair_channel'], num_head)
84+
self.attention_2d = Linear(channel_num['pair_channel'], num_head)
8385

8486
if self.global_config.zero_init:
8587
init_w = nn.initializer.Constant(value=0.0)
8688
else:
8789
init_w = nn.initializer.XavierUniform()
8890

8991
c = num_scalar_v + num_point_v * 4 + channel_num['pair_channel']
90-
self.output_projection = nn.Linear(
92+
self.output_projection = Linear(
9193
num_head * c, num_output,
9294
weight_attr=paddle.ParamAttr(initializer=init_w))
9395

@@ -257,6 +259,8 @@ def __init__(self, channel_num, config, global_config):
257259
self.config = config
258260
self.global_config = global_config
259261

262+
Linear = paddle.incubate.nn.FusedLinear if self.global_config.fuse_linear else paddle.nn.Linear
263+
260264
self.invariant_point_attention = InvariantPointAttention(
261265
channel_num, config, global_config)
262266
self.attention_layer_norm = nn.LayerNorm(channel_num['seq_channel'])
@@ -273,7 +277,7 @@ def __init__(self, channel_num, config, global_config):
273277
if i > 0:
274278
layer_name, c_in = f'transition_{i}', self.config.num_channel
275279

276-
setattr(self, layer_name, nn.Linear(
280+
setattr(self, layer_name, Linear(
277281
c_in, self.config.num_channel,
278282
weight_attr=paddle.ParamAttr(initializer=init_w)))
279283

@@ -287,7 +291,7 @@ def __init__(self, channel_num, config, global_config):
287291
last_init_w = nn.initializer.XavierUniform()
288292

289293
# Jumper et al. (2021) Alg. 23 "Backbone update"
290-
self.affine_update = nn.Linear(
294+
self.affine_update = Linear(
291295
self.config.num_channel, 6,
292296
weight_attr=paddle.ParamAttr(initializer=last_init_w))
293297

@@ -349,8 +353,10 @@ def __init__(self, channel_num, config, global_config):
349353
self.config = config
350354
self.global_config = global_config
351355

356+
Linear = paddle.incubate.nn.FusedLinear if self.global_config.fuse_linear else paddle.nn.Linear
357+
352358
self.single_layer_norm = nn.LayerNorm(channel_num['seq_channel'])
353-
self.initial_projection = nn.Linear(
359+
self.initial_projection = Linear(
354360
channel_num['seq_channel'], config.num_channel)
355361
self.pair_layer_norm = nn.LayerNorm(channel_num['pair_channel'])
356362

@@ -861,10 +867,10 @@ def supervised_chi_loss(ret, batch, value, config):
861867

862868
residue_type_one_hot = paddle.nn.functional.one_hot(batch['aatype_index'],
863869
num_classes=residue_constants.restype_num + 1)
864-
chi_pi_periodic = paddle.einsum('nijk,nkl->nijl', residue_type_one_hot[:, None, ...],
870+
chi_pi_periodic = paddle.einsum('nijk,nkl->nijl', residue_type_one_hot.unsqueeze(axis=1),
865871
paddle.to_tensor(residue_constants.chi_pi_periodic)[None])
866872

867-
sin_cos_true_chi = batch['chi_angles_sin_cos'][:, None, ...]
873+
sin_cos_true_chi = batch['chi_angles_sin_cos'].unsqueeze(axis=1) # [:, None, ...]
868874

869875
# This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic
870876
shifted_mask = (1 - 2 * chi_pi_periodic)[..., None]
@@ -913,7 +919,7 @@ def l2_normalize(x, axis=-1, epsilon=1e-12):
913919
return x / paddle.sqrt(
914920
paddle.maximum(
915921
paddle.sum(paddle.square(x), axis=axis, keepdim=True),
916-
paddle.to_tensor([epsilon], dtype='float32')))
922+
paddle.full(shape=[1], fill_value=epsilon, dtype='float32')))
917923

918924

919925
class MultiRigidSidechain(nn.Layer):
@@ -925,9 +931,11 @@ def __init__(self, channel_num, config, global_config):
925931
self.config = config
926932
self.global_config = global_config
927933

934+
Linear = paddle.incubate.nn.FusedLinear if self.global_config.fuse_linear else paddle.nn.Linear
935+
928936
c = self.config.num_channel
929-
self.input_projection = nn.Linear(channel_num['seq_channel'], c)
930-
self.input_projection_1 = nn.Linear(channel_num['seq_channel'], c)
937+
self.input_projection = Linear(channel_num['seq_channel'], c)
938+
self.input_projection_1 = Linear(channel_num['seq_channel'], c)
931939

932940
for i in range(self.config.num_residual_block):
933941
l1, l2 = 'resblock1', 'resblock2'
@@ -940,12 +948,12 @@ def __init__(self, channel_num, config, global_config):
940948
else:
941949
init_w_2 = nn.initializer.XavierUniform()
942950

943-
setattr(self, l1, nn.Linear(
951+
setattr(self, l1, Linear(
944952
c, c, weight_attr=paddle.ParamAttr(initializer=init_w_1)))
945-
setattr(self, l2, nn.Linear(
953+
setattr(self, l2, Linear(
946954
c, c, weight_attr=paddle.ParamAttr(initializer=init_w_2)))
947955

948-
self.unnormalized_angles = nn.Linear(c, 14)
956+
self.unnormalized_angles = Linear(c, 14)
949957

950958
def forward(self, affine, single_act, init_single_act, aatype):
951959
single_act = self.input_projection(nn.functional.relu(single_act))
@@ -997,4 +1005,4 @@ def forward(self, affine, single_act, init_single_act, aatype):
9971005
# 'frames': all_frames_to_global, # (B, N, 8, 3, 3)
9981006
# })
9991007

1000-
return outputs
1008+
return outputs

0 commit comments

Comments
 (0)