Skip to content

Commit 3d5a1b5

Browse files
committed
fix get_at usage
1 parent f6d36da commit 3d5a1b5

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,7 @@ def forward(
337337

338338
xi, xj = source[degree_in], target[degree_in]
339339

340-
flattened_neighbor_indices, ps = pack_one(neighbor_indices, 'b *')
341-
x = get_at('b [i] d m, b k -> b k d m', xj, flattened_neighbor_indices)
342-
x = unpack_one(x, ps, 'b * d m')
340+
x = get_at('b [i] d m, b j k -> b j k d m', xj, neighbor_indices)
343341

344342
if self.project_xi_xj:
345343
xi = rearrange(xi, 'b i d m -> b i 1 d m')

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.5.2'
1+
__version__ = '0.5.3'

0 commit comments

Comments
 (0)