Commit 1d82f2f

some cleanup
1 parent 874fc0e commit 1d82f2f

4 files changed: +46 -30 lines changed

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 25 additions & 20 deletions
@@ -7,7 +7,7 @@
 from beartype import beartype

 import torch
-from torch import nn
+from torch import nn, is_tensor, Tensor
 import torch.nn.functional as F

 from opt_einsum import contract as opt_einsum
@@ -985,48 +985,53 @@ def basis(self, basis):
     def device(self):
         return next(self.parameters()).device

+    @beartype
     def forward(
         self,
-        feats: Union[torch.Tensor, dict[int, torch.Tensor]],
-        coors: torch.Tensor,
+        inputs: Union[Tensor, dict[int, Tensor]],
+        coors: Tensor,
         mask = None,
         adj_mat = None,
         edges = None,
         return_pooled = False,
         neighbor_mask = None,
     ):
-        _mask = mask
+        _mask, device = mask, self.device

         # apply token embedding and positional embedding to type-0 features
         # (if type-0 feats are passed as a tensor they are expected to be of a flattened shape (batch, seq, n_feats)
         # but if they are passed in a dict (fiber) they are expected to be of a unified shape (batch, seq, n_feats, 1=2*0+1))

-        if torch.is_tensor(feats):
-            type0_feats = feats
-        else:
-            type0_feats = rearrange(feats[0], '... 1 -> ...')
+        if is_tensor(inputs):
+            inputs = {0: inputs}
+
+        feats = inputs[0]
+
+        if feats.ndim == 4:
+            feats = rearrange(feats, '... 1 -> ...')

         if exists(self.token_emb):
-            type0_feats = self.token_emb(type0_feats)
+            assert feats.ndim == 2
+            feats = self.token_emb(feats)

         if exists(self.pos_emb):
-            assert type0_feats.shape[1] <= self.num_positions, 'feature sequence length must be less than the number of positions given at init'
-            type0_feats = type0_feats + self.pos_emb(torch.arange(type0_feats.shape[1], device = type0_feats.device))
+            seq_len = feats.shape[1]
+            assert seq_len <= self.num_positions, 'feature sequence length must be less than the number of positions given at init'

-        type0_feats = self.embedding_grad_frac * type0_feats + (1 - self.embedding_grad_frac) * type0_feats.detach()
+            feats = feats + self.pos_emb(torch.arange(seq_len, device = device))
+
+        feats = self.embedding_grad_frac * feats + (1 - self.embedding_grad_frac) * feats.detach()

         assert not (self.has_edges and not exists(edges)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types'

-        type0_feats = rearrange(type0_feats, '... -> ... 1')
-        if torch.is_tensor(feats):
-            feats = {0: type0_feats}
-        else:
-            feats[0] = type0_feats
+        b, n, d = feats.shape
+
+        feats = rearrange(feats, 'b n d -> b n d 1')

-        b, n, d, *_, device = *feats[0].shape, feats[0].device
+        inputs[0] = feats

         assert d == self.type0_feat_dim, f'feature dimension {d} must be equal to dimension given at init {self.type0_feat_dim}'
-        assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree'
+        assert set(map(int, inputs.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree'

         num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius

@@ -1159,7 +1164,7 @@ def forward(

         edge_info = EdgeInfo(neighbor_indices, neighbor_mask, edges)

-        x = feats
+        x = inputs

         # project in
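The net effect of these hunks: forward() now normalizes any input into the fiber-dict form up front (a bare tensor becomes {0: tensor}, and a (batch, seq, dim, 1) type-0 fiber is flattened) before token and positional embeddings are applied. A minimal sketch of the two call forms this enables — the sizes and constructor values below are illustrative, not taken from the diff:

    import torch
    from equiformer_pytorch.equiformer_pytorch import Equiformer

    # hypothetical sizes for illustration
    model = Equiformer(dim = 32, dim_in = (16, 8), input_degrees = 2, depth = 2)

    coors = torch.randn(1, 32, 3)
    mask = torch.ones(1, 32).bool()

    # fiber dict form: degree -> (batch, seq, dim, 2 * degree + 1)
    inputs = {
        0: torch.randn(1, 32, 16, 1),
        1: torch.randn(1, 32, 8, 3)
    }
    out = model(inputs, coors, mask)

    # bare tensor form: wrapped as {0: tensor} internally; both the
    # (batch, seq, dim) and (batch, seq, dim, 1) type-0 shapes pass
    # through the new feats.ndim == 4 normalization
    model0 = Equiformer(dim = 32, dim_in = 16, input_degrees = 1, depth = 2)
    out = model0(torch.randn(1, 32, 16), coors, mask)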

equiformer_pytorch/utils.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def safe_cat(arr, el, dim):
         return el
     return torch.cat((arr, el), dim = dim)

-def cast_tuple(val, depth):
+def cast_tuple(val, depth = 1):
     return val if isinstance(val, tuple) else (val,) * depth

 def batched_index_select(values, indices, dim = 1):
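The new default lets callers drop the explicit depth for the common single-element case — this is what allows the bare cast_tuple(dim_in) call in the updated test below. From the one-line definition above, the behavior is:

    cast_tuple(16)         # (16,)    - scalar wrapped, default depth = 1
    cast_tuple(16, 2)      # (16, 16) - scalar repeated to depth
    cast_tuple((16, 8))    # (16, 8)  - tuples pass through unchanged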

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.3.5'
+__version__ = '0.3.6'

tests/test_equivariance.py

Lines changed: 19 additions & 8 deletions
@@ -3,7 +3,13 @@
 import torch
 from equiformer_pytorch.equiformer_pytorch import Equiformer
 from equiformer_pytorch.irr_repr import rot
-from equiformer_pytorch.utils import torch_default_dtype
+
+from equiformer_pytorch.utils import (
+    torch_default_dtype,
+    cast_tuple,
+    to_order,
+    exists
+)

 # test output shape

@@ -35,11 +41,12 @@ def test_equivariance(
     l2_dist_attention,
     reversible
 ):
+    dim_in = cast_tuple(dim_in)

     model = Equiformer(
         dim = dim,
-        dim_in=dim_in,
-        input_degrees = len(dim_in) if isinstance(dim_in, tuple) else 1,
+        dim_in = dim_in,
+        input_degrees = len(dim_in),
         depth = 2,
         l2_dist_attention = l2_dist_attention,
         reversible = reversible,
@@ -48,16 +55,20 @@ def test_equivariance(
         init_out_zero = False
     )

-    if isinstance(dim_in, tuple):
-        feats = {deg: torch.randn(1, 32, dim, 2*deg + 1) for deg, dim in enumerate(dim_in)}
-    else:
-        feats = torch.randn(1, 32, dim_in)
+    feats = {deg: torch.randn(1, 32, dim, to_order(deg)) for deg, dim in enumerate(dim_in)}
+    type0, type1 = feats[0], feats.get(1, None)

     coors = torch.randn(1, 32, 3)
     mask = torch.ones(1, 32).bool()

     R = rot(*torch.randn(3))
-    _, out1 = model({0: feats[0], 1: feats[1] @ R} if isinstance(feats, dict) else feats, coors @ R, mask)
+
+    maybe_rotated_feats = {0: type0}
+
+    if exists(type1):
+        maybe_rotated_feats[1] = type1 @ R
+
+    _, out1 = model(maybe_rotated_feats, coors @ R, mask)
     out2 = model(feats, coors, mask)[1] @ R

     assert torch.allclose(out1, out2, atol = 1e-4), 'is not equivariant'
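The refactor leaves the property under test unchanged: rotating the type-1 (vector) features and the coordinates before the model must match rotating the type-1 output afterwards, while type-0 (scalar) features are rotation invariant and are passed through untouched. to_order(deg) is the irrep dimension the old test spelled out as 2*deg + 1, so for a hypothetical two-degree dim_in the comprehension builds:

    # what the feats comprehension produces for dim_in = (16, 8)
    # to_order(0) == 1 (scalars), to_order(1) == 3 (vectors)
    feats = {
        0: torch.randn(1, 32, 16, 1),   # (batch, seq, dim, to_order(0))
        1: torch.randn(1, 32, 8, 3)     # (batch, seq, dim, to_order(1))
    }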
