Skip to content

Commit 2e33a53

Browse files
committed
test out einstein notation for indexing, using einx.get_at
1 parent f3cb662 commit 2e33a53

File tree

4 files changed

+13
-24
lines changed

4 files changed

+13
-24
lines changed

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from equiformer_pytorch.utils import (
2828
exists,
2929
default,
30-
batched_index_select,
3130
masked_mean,
3231
to_order,
3332
cast_tuple,
@@ -37,6 +36,8 @@
3736
pad_for_centering_y_to_x
3837
)
3938

39+
from einx import get_at
40+
4041
from einops import rearrange, repeat, reduce, einsum, pack, unpack
4142
from einops.layers.torch import Rearrange
4243

@@ -336,7 +337,9 @@ def forward(
336337

337338
xi, xj = source[degree_in], target[degree_in]
338339

339-
x = batched_index_select(xj, neighbor_indices, dim = 1)
340+
flattened_neighbor_indices, ps = pack_one(neighbor_indices, 'b *')
341+
x = get_at('b [i] d m, b k -> b k d m', xj, flattened_neighbor_indices)
342+
x = unpack_one(x, ps, 'b * d m')
340343

341344
if self.project_xi_xj:
342345
xi = rearrange(xi, 'b i d m -> b i 1 d m')
@@ -1215,15 +1218,16 @@ def forward(
12151218
dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim = -1, largest = False)
12161219
neighbor_mask = dist_values <= valid_radius
12171220

1218-
neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim = 2)
1219-
neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim = 2)
1220-
neighbor_indices = batched_index_select(indices, nearest_indices, dim = 2)
1221+
neighbor_rel_dist = get_at('b i [j], b i k -> b i k', rel_dist, nearest_indices)
1222+
neighbor_rel_pos = get_at('b i [j] c, b i k -> b i k c', rel_pos, nearest_indices)
1223+
neighbor_indices = get_at('b i [j], b i k -> b i k', indices, nearest_indices)
12211224

12221225
if exists(mask):
1223-
neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim = 2)
1226+
nearest_mask = get_at('b i [j], b i k -> b i k', mask, nearest_indices)
1227+
neighbor_mask = neighbor_mask & nearest_mask
12241228

12251229
if exists(edges):
1226-
edges = batched_index_select(edges, nearest_indices, dim = 2)
1230+
edges = get_at('b i [j] d, b i k -> b i k d', edges, nearest_indices)
12271231

12281232
# embed relative distances
12291233

equiformer_pytorch/utils.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,6 @@ def safe_cat(arr, el, dim):
5151
def cast_tuple(val, depth = 1):
5252
return val if isinstance(val, tuple) else (val,) * depth
5353

54-
def batched_index_select(values, indices, dim = 1):
55-
value_dims = values.shape[(dim + 1):]
56-
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
57-
indices = indices[(..., *((None,) * len(value_dims)))]
58-
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
59-
value_expand_len = len(indices_shape) - (dim + 1)
60-
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
61-
62-
value_expand_shape = [-1] * len(values.shape)
63-
expand_slice = slice(dim, (dim + value_expand_len))
64-
value_expand_shape[expand_slice] = indices.shape[expand_slice]
65-
values = values.expand(*value_expand_shape)
66-
67-
dim += value_expand_len
68-
return values.gather(dim, indices)
69-
7054
def fast_split(arr, splits, dim=0):
7155
axis_len = arr.shape[dim]
7256
splits = min(axis_len, max(splits, 1))

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.5.1'
1+
__version__ = '0.5.2'

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
install_requires=[
2525
'beartype',
2626
'einops>=0.6',
27+
'einx',
2728
'filelock',
2829
'opt-einsum',
2930
'taylor-series-linear-attention>=0.1.4',

0 commit comments

Comments
 (0)