@@ -64,7 +64,8 @@ def get_chi_atom_indices():
64
64
atom_indices .append (
65
65
[residue_constants .atom_order [atom ] for atom in chi_angle ])
66
66
for _ in range (4 - len (atom_indices )):
67
- atom_indices .append ([0 , 0 , 0 , 0 ]) # For chi angles not defined on the AA.
67
+ atom_indices .append (
68
+ [0 , 0 , 0 , 0 ]) # For chi angles not defined on the AA.
68
69
chi_atom_indices .append (atom_indices )
69
70
70
71
chi_atom_indices .append ([[0 , 0 , 0 , 0 ]] * 4 ) # For UNKNOWN residue.
@@ -274,8 +275,7 @@ def atom37_to_torsion_angles(
274
275
aatype : paddle .Tensor , # (B, T, N)
275
276
all_atom_pos : paddle .Tensor , # (B, T, N, 37, 3)
276
277
all_atom_mask : paddle .Tensor , # (B, T, N, 37)
277
- placeholder_for_undefined = False ,
278
- ) -> Dict [str , paddle .Tensor ]:
278
+ placeholder_for_undefined = False , ) -> Dict [str , paddle .Tensor ]:
279
279
"""Computes the 7 torsion angles (in sin, cos encoding) for each residue.
280
280
281
281
The 7 torsion angles are in the order
@@ -300,44 +300,57 @@ def atom37_to_torsion_angles(
300
300
"""
301
301
302
302
# Map aatype > 20 to 'Unknown' (20).
303
- aatype = paddle .minimum (aatype .astype ('int' ), paddle .to_tensor ([20 ]).astype ('int' ))
304
-
303
+ aatype = paddle .minimum (
304
+ aatype .astype ('int' ), paddle .full (shape = [1 ], fill_value = 20 , dtype = "int" ))
305
+
305
306
num_batch , num_temp , num_res = aatype .shape
306
307
307
308
# Compute the backbone angles.
308
309
pad = paddle .zeros ([num_batch , num_temp , 1 , 37 , 3 ])
309
- prev_all_atom_pos = paddle .concat ([pad , all_atom_pos [..., :- 1 , :, :]], axis = - 3 )
310
+ prev_all_atom_pos = paddle .concat (
311
+ [pad , all_atom_pos [..., :- 1 , :, :]], axis = - 3 )
310
312
311
313
pad = paddle .zeros ([num_batch , num_temp , 1 , 37 ])
312
- prev_all_atom_mask = paddle .concat ([pad , all_atom_mask [..., :- 1 , :]], axis = - 2 )
314
+ prev_all_atom_mask = paddle .concat (
315
+ [pad , all_atom_mask [..., :- 1 , :]], axis = - 2 )
313
316
314
317
# For each torsion angle collect the 4 atom positions that define this angle.
315
318
# shape (B, T, N, atoms=4, xyz=3)
316
319
pre_omega_atom_pos = paddle .concat (
317
- [prev_all_atom_pos [..., 1 :3 , :], # prev CA, C
318
- all_atom_pos [..., 0 :2 , :] # this N, CA
319
- ], axis = - 2 )
320
+ [
321
+ prev_all_atom_pos [..., 1 :3 , :], # prev CA, C
322
+ all_atom_pos [..., 0 :2 , :] # this N, CA
323
+ ],
324
+ axis = - 2 )
320
325
321
326
phi_atom_pos = paddle .concat (
322
- [prev_all_atom_pos [..., 2 :3 , :], # prev C
323
- all_atom_pos [..., 0 :3 , :] # this N, CA, C
324
- ], axis = - 2 )
327
+ [
328
+ prev_all_atom_pos [..., 2 :3 , :], # prev C
329
+ all_atom_pos [..., 0 :3 , :] # this N, CA, C
330
+ ],
331
+ axis = - 2 )
325
332
326
333
psi_atom_pos = paddle .concat (
327
- [all_atom_pos [..., 0 :3 , :], # this N, CA, C
328
- all_atom_pos [..., 4 :5 , :] # this O
329
- ], axis = - 2 )
334
+ [
335
+ all_atom_pos [..., 0 :3 , :], # this N, CA, C
336
+ all_atom_pos [..., 4 :5 , :] # this O
337
+ ],
338
+ axis = - 2 )
330
339
331
340
# Collect the masks from these atoms.
332
341
# Shape [batch, n_temp, num_res]
333
342
pre_omega_mask = (
334
- paddle .prod (prev_all_atom_mask [..., 1 :3 ], axis = - 1 ) # prev CA, C
335
- * paddle .prod (all_atom_mask [..., 0 :2 ], axis = - 1 )) # this N, CA
343
+ paddle .prod (
344
+ prev_all_atom_mask [..., 1 :3 ], axis = - 1 ) # prev CA, C
345
+ * paddle .prod (
346
+ all_atom_mask [..., 0 :2 ], axis = - 1 )) # this N, CA
336
347
phi_mask = (
337
348
prev_all_atom_mask [..., 2 ] # prev C
338
- * paddle .prod (all_atom_mask [..., 0 :3 ], axis = - 1 )) # this N, CA, C
349
+ * paddle .prod (
350
+ all_atom_mask [..., 0 :3 ], axis = - 1 )) # this N, CA, C
339
351
psi_mask = (
340
- paddle .prod (all_atom_mask [..., 0 :3 ], axis = - 1 ) * # this N, CA, C
352
+ paddle .prod (
353
+ all_atom_mask [..., 0 :3 ], axis = - 1 ) * # this N, CA, C
341
354
all_atom_mask [..., 4 ]) # this O
342
355
343
356
# Collect the atoms for the chi-angles.
@@ -375,18 +388,18 @@ def atom37_to_torsion_angles(
375
388
# Stack all torsion angle atom positions.
376
389
# Shape (B, T, N, torsions=7, atoms=4, xyz=3)
377
390
torsions_atom_pos = paddle .concat (
378
- [pre_omega_atom_pos [:, :, :, None , :, :],
379
- phi_atom_pos [:, :, :, None , :, :],
380
- psi_atom_pos [:, :, :, None , :, :],
391
+ [pre_omega_atom_pos . unsqueeze ( axis = - 3 ), # [:, :, :, None, :, :]
392
+ phi_atom_pos . unsqueeze ( axis = - 3 ), # [:, :, :, None, :, :]
393
+ psi_atom_pos . unsqueeze ( axis = - 3 ), # [:, :, :, None, :, :]
381
394
chis_atom_pos
382
395
], axis = 3 )
383
396
384
397
# Stack up masks for all torsion angles.
385
398
# shape (B, T, N, torsions=7)
386
399
torsion_angles_mask = paddle .concat (
387
- [pre_omega_mask [..., None ],
388
- phi_mask [..., None ],
389
- psi_mask [..., None ],
400
+ [pre_omega_mask . unsqueeze ( axis = - 1 ), # [..., None]
401
+ phi_mask . unsqueeze ( axis = - 1 ), # [..., None]
402
+ psi_mask . unsqueeze ( axis = - 1 ), # [..., None]
390
403
chis_mask
391
404
], axis = - 1 )
392
405
@@ -417,7 +430,7 @@ def atom37_to_torsion_angles(
417
430
418
431
# Mirror psi, because we computed it from the Oxygen-atom.
419
432
torsion_angles_sin_cos *= paddle .to_tensor (
420
- [1. , 1. , - 1. , 1. , 1. , 1. , 1. ])[None , None , None , :, None ]
433
+ [1. , 1. , - 1. , 1. , 1. , 1. , 1. ]). reshape ([ 1 , 1 , 1 , 7 , 1 ]) # [None, None, None, :, None]
421
434
422
435
# Create alternative angles for ambiguous atom names.
423
436
chi_is_ambiguous = utils .batched_gather (
@@ -428,7 +441,7 @@ def atom37_to_torsion_angles(
428
441
1.0 - 2.0 * chi_is_ambiguous ], axis = - 1 )
429
442
# mirror_torsion_angles (B, T, N, torsions=7)
430
443
alt_torsion_angles_sin_cos = (
431
- torsion_angles_sin_cos * mirror_torsion_angles [:, :, :, :, None ] )
444
+ torsion_angles_sin_cos * mirror_torsion_angles . unsqueeze ( axis = - 1 ) )
432
445
433
446
if placeholder_for_undefined :
434
447
# Add placeholder torsions in place of undefined torsion angles
@@ -437,10 +450,8 @@ def atom37_to_torsion_angles(
437
450
paddle .ones (torsion_angles_sin_cos .shape [:- 1 ]),
438
451
paddle .zeros (torsion_angles_sin_cos .shape [:- 1 ])
439
452
], axis = - 1 )
440
- torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask [
441
- ..., None ] + placeholder_torsions * (1 - torsion_angles_mask [..., None ])
442
- alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask [
443
- ..., None ] + placeholder_torsions * (1 - torsion_angles_mask [..., None ])
453
+ torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask .unsqueeze (axis = - 1 ) + placeholder_torsions * (1 - torsion_angles_mask .unsqueeze (axis = - 1 ))
454
+ alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask .unsqueeze (axis = - 1 ) + placeholder_torsions * (1 - torsion_angles_mask .unsqueeze (axis = - 1 ))
444
455
445
456
return {
446
457
'torsion_angles_sin_cos' : torsion_angles_sin_cos , # (B, T, N, 7, 2)
@@ -579,7 +590,7 @@ def frames_and_literature_positions_to_atom14_pos(
579
590
"""
580
591
# Pick the appropriate transform for every atom.
581
592
restype_atom14_to_rigid_group = paddle .to_tensor (
582
- residue_constants .restype_atom14_to_rigid_group )[ None , ...]
593
+ residue_constants .restype_atom14_to_rigid_group ). unsqueeze ( axis = 0 )
583
594
584
595
# [1, 21, 14] -> # [n_batch, 21, 14]
585
596
n_batch = aatype .shape [0 ]
@@ -612,7 +623,7 @@ def _convert(x, y):
612
623
# Gather the literature atom positions for each residue.
613
624
# r3.Vecs with shape (B, N, 14)
614
625
restype_atom14_rigid_group_positions = paddle .to_tensor (
615
- residue_constants .restype_atom14_rigid_group_positions )[ None , ...]
626
+ residue_constants .restype_atom14_rigid_group_positions ). unsqueeze ( axis = 0 )
616
627
# [1, 21, 14, 3] -> [B, 21, 14, 3]
617
628
if n_batch > 1 :
618
629
restype_atom14_rigid_group_positions = paddle .tile (
@@ -629,7 +640,7 @@ def _convert(x, y):
629
640
630
641
# Mask out non-existing atoms.
631
642
restype_atom14_mask = paddle .to_tensor (
632
- residue_constants .restype_atom14_mask )[ None , ...]
643
+ residue_constants .restype_atom14_mask ). unsqueeze ( axis = 0 )
633
644
# [1, 21, 14] -> [B, 21, 14]
634
645
if n_batch > 1 :
635
646
restype_atom14_mask = paddle .tile (
0 commit comments