Skip to content

Commit 01f7dbb

Browse files
committed
Optimzie the implementation of StructureModule.
1 parent 850a594 commit 01f7dbb

File tree

1 file changed

+14
-8
lines changed
  • apps/protein_folding/helixfold/alphafold_paddle/model

1 file changed

+14
-8
lines changed

apps/protein_folding/helixfold/alphafold_paddle/model/r3.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -410,22 +410,28 @@ def rots_from_two_vecs(e0_unnormalized: Vecs, e1_unnormalized: Vecs) -> Rots:
410410

411411
def rots_mul_rots(a: Rots, b: Rots) -> Rots:
412412
"""Composition of rotations 'a' and 'b'."""
413-
c0 = rots_mul_vecs(a, Vecs(b.xx, b.yx, b.zx))
414-
c1 = rots_mul_vecs(a, Vecs(b.xy, b.yy, b.zy))
415-
c2 = rots_mul_vecs(a, Vecs(b.xz, b.yz, b.zz))
416-
return Rots(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
413+
if a.shape == b.shape:$
414+
return Rots(paddle.matmul(a.rotation, b.rotation))
415+
else:
416+
c0 = rots_mul_vecs(a, Vecs(b.xx, b.yx, b.zx))
417+
c1 = rots_mul_vecs(a, Vecs(b.xy, b.yy, b.zy))
418+
c2 = rots_mul_vecs(a, Vecs(b.xz, b.yz, b.zz))
419+
return Rots(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
417420

418421

419422
def rots_mul_vecs(m: Rots, v: Vecs) -> Vecs:
420423
"""Apply rotations 'm' to vectors 'v'."""
421-
return Vecs(m.xx * v.x + m.xy * v.y + m.xz * v.z,
422-
m.yx * v.x + m.yy * v.y + m.yz * v.z,
423-
m.zx * v.x + m.zy * v.y + m.zz * v.z)
424+
v_x = v.x
425+
v_y = v.y
426+
v_z = v.z
427+
return Vecs(m.xx * v_x + m.xy * v_y + m.xz * v_z,
428+
m.yx * v_x + m.yy * v_y + m.yz * v_z,
429+
m.zx * v_x + m.zy * v_y + m.zz * v_z)
424430

425431

426432
def vecs_add(v1: Vecs, v2: Vecs) -> Vecs:
427433
"""Add two vectors 'v1' and 'v2'."""
428-
return Vecs(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z)
434+
return Vecs(v1.translation + v2.translation)
429435

430436

431437
def vecs_dot_vecs(v1: Vecs, v2: Vecs) -> paddle.Tensor:

0 commit comments

Comments
 (0)