Skip to content

Commit 3c913c7

Browse files
committed
small cleanup
1 parent cab5dcf commit 3c913c7

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
pad_for_centering_y_to_x
3636
)
3737

38-
from einops import rearrange, repeat, einsum, pack, unpack
38+
from einops import rearrange, repeat, reduce, einsum, pack, unpack
3939
from einops.layers.torch import Rearrange
4040

4141
# constants
@@ -1039,7 +1039,7 @@ def forward(
10391039
degree = ind + 2
10401040

10411041
next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
1042-
next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
1042+
next_degree_mask = next_degree_adj_mat & ~adj_mat
10431043
adj_indices = adj_indices.masked_fill(next_degree_mask, degree)
10441044
adj_mat = next_degree_adj_mat.clone()
10451045

@@ -1056,7 +1056,7 @@ def forward(
10561056
adj_mat = remove_self(adj_mat)
10571057

10581058
adj_mat_values = adj_mat.float()
1059-
adj_mat_max_neighbors = adj_mat_values.sum(dim = -1).max().item()
1059+
adj_mat_max_neighbors = reduce(adj_mat_values, '... i j -> ... i', 'sum').amax().item()
10601060

10611061
if max_sparse_neighbors < adj_mat_max_neighbors:
10621062
eps = 1e-2

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.3.3'
1+
__version__ = '0.3.4'

0 commit comments

Comments
 (0)