|
27 | 27 | from equiformer_pytorch.utils import ( |
28 | 28 | exists, |
29 | 29 | default, |
30 | | - batched_index_select, |
31 | 30 | masked_mean, |
32 | 31 | to_order, |
33 | 32 | cast_tuple, |
|
37 | 36 | pad_for_centering_y_to_x |
38 | 37 | ) |
39 | 38 |
|
| 39 | +from einx import get_at |
| 40 | + |
40 | 41 | from einops import rearrange, repeat, reduce, einsum, pack, unpack |
41 | 42 | from einops.layers.torch import Rearrange |
42 | 43 |
|
@@ -336,7 +337,9 @@ def forward( |
336 | 337 |
|
337 | 338 | xi, xj = source[degree_in], target[degree_in] |
338 | 339 |
|
339 | | - x = batched_index_select(xj, neighbor_indices, dim = 1) |
| 340 | + flattened_neighbor_indices, ps = pack_one(neighbor_indices, 'b *') |
| 341 | + x = get_at('b [i] d m, b k -> b k d m', xj, flattened_neighbor_indices) |
| 342 | + x = unpack_one(x, ps, 'b * d m') |
340 | 343 |
|
341 | 344 | if self.project_xi_xj: |
342 | 345 | xi = rearrange(xi, 'b i d m -> b i 1 d m') |
@@ -1215,15 +1218,16 @@ def forward( |
1215 | 1218 | dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim = -1, largest = False) |
1216 | 1219 | neighbor_mask = dist_values <= valid_radius |
1217 | 1220 |
|
1218 | | - neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim = 2) |
1219 | | - neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim = 2) |
1220 | | - neighbor_indices = batched_index_select(indices, nearest_indices, dim = 2) |
| 1221 | + neighbor_rel_dist = get_at('b i [j], b i k -> b i k', rel_dist, nearest_indices) |
| 1222 | + neighbor_rel_pos = get_at('b i [j] c, b i k -> b i k c', rel_pos, nearest_indices) |
| 1223 | + neighbor_indices = get_at('b i [j], b i k -> b i k', indices, nearest_indices) |
1221 | 1224 |
|
1222 | 1225 | if exists(mask): |
1223 | | - neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim = 2) |
| 1226 | + nearest_mask = get_at('b i [j], b i k -> b i k', mask, nearest_indices) |
| 1227 | + neighbor_mask = neighbor_mask & nearest_mask |
1224 | 1228 |
|
1225 | 1229 | if exists(edges): |
1226 | | - edges = batched_index_select(edges, nearest_indices, dim = 2) |
| 1230 | + edges = get_at('b i [j] d, b i k -> b i k d', edges, nearest_indices) |
1227 | 1231 |
|
1228 | 1232 | # embed relative distances |
1229 | 1233 |
|
|
0 commit comments