@@ -410,22 +410,28 @@ def rots_from_two_vecs(e0_unnormalized: Vecs, e1_unnormalized: Vecs) -> Rots:
410
410
411
411
def rots_mul_rots (a : Rots , b : Rots ) -> Rots :
412
412
"""Composition of rotations 'a' and 'b'."""
413
- c0 = rots_mul_vecs (a , Vecs (b .xx , b .yx , b .zx ))
414
- c1 = rots_mul_vecs (a , Vecs (b .xy , b .yy , b .zy ))
415
- c2 = rots_mul_vecs (a , Vecs (b .xz , b .yz , b .zz ))
416
- return Rots (c0 .x , c1 .x , c2 .x , c0 .y , c1 .y , c2 .y , c0 .z , c1 .z , c2 .z )
413
+ if a .shape == b .shape :$
414
+ return Rots (paddle .matmul (a .rotation , b .rotation ))
415
+ else :
416
+ c0 = rots_mul_vecs (a , Vecs (b .xx , b .yx , b .zx ))
417
+ c1 = rots_mul_vecs (a , Vecs (b .xy , b .yy , b .zy ))
418
+ c2 = rots_mul_vecs (a , Vecs (b .xz , b .yz , b .zz ))
419
+ return Rots (c0 .x , c1 .x , c2 .x , c0 .y , c1 .y , c2 .y , c0 .z , c1 .z , c2 .z )
417
420
418
421
419
422
def rots_mul_vecs (m : Rots , v : Vecs ) -> Vecs :
420
423
"""Apply rotations 'm' to vectors 'v'."""
421
- return Vecs (m .xx * v .x + m .xy * v .y + m .xz * v .z ,
422
- m .yx * v .x + m .yy * v .y + m .yz * v .z ,
423
- m .zx * v .x + m .zy * v .y + m .zz * v .z )
424
+ v_x = v .x
425
+ v_y = v .y
426
+ v_z = v .z
427
+ return Vecs (m .xx * v_x + m .xy * v_y + m .xz * v_z ,
428
+ m .yx * v_x + m .yy * v_y + m .yz * v_z ,
429
+ m .zx * v_x + m .zy * v_y + m .zz * v_z )
424
430
425
431
426
432
def vecs_add (v1 : Vecs , v2 : Vecs ) -> Vecs :
427
433
"""Add two vectors 'v1' and 'v2'."""
428
- return Vecs (v1 .x + v2 .x , v1 . y + v2 . y , v1 . z + v2 . z )
434
+ return Vecs (v1 .translation + v2 .translation )
429
435
430
436
431
437
def vecs_dot_vecs (v1 : Vecs , v2 : Vecs ) -> paddle .Tensor :
0 commit comments