7 | 7 | from beartype import beartype
8 | 8 |
9 | 9 | import torch
10 | | -from torch import nn
| 10 | +from torch import nn, is_tensor, Tensor
11 | 11 | import torch.nn.functional as F
12 | 12 |
13 | 13 | from opt_einsum import contract as opt_einsum
@@ -985,48 +985,53 @@ def basis(self, basis):
985 | 985 | def device(self):
986 | 986 | return next(self.parameters()).device
987 | 987 |
| 988 | + @beartype |
988 | 989 | def forward( |
989 | 990 | self, |
990 | | - feats: Union[torch.Tensor, dict[int, torch.Tensor]], |
991 | | - coors: torch.Tensor, |
| 991 | + inputs: Union[Tensor, dict[int, Tensor]], |
| 992 | + coors: Tensor, |
992 | 993 | mask = None, |
993 | 994 | adj_mat = None, |
994 | 995 | edges = None, |
995 | 996 | return_pooled = False, |
996 | 997 | neighbor_mask = None, |
997 | 998 | ): |
998 | | - _mask = mask |
| 999 | + _mask, device = mask, self.device |
999 | 1000 |
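
A note on the new signature: `@beartype` enforces the `inputs: Union[Tensor, dict[int, Tensor]]` annotation at call time, so either a plain type-0 tensor or a degree-keyed dict is accepted and anything else is rejected before the body runs. A minimal sketch of that behavior with a stand-in function (`demo` is not library code):

    from typing import Union

    import torch
    from torch import Tensor
    from beartype import beartype

    @beartype
    def demo(inputs: Union[Tensor, dict[int, Tensor]]):
        # mirrors the branch at the top of forward: plain tensor or degree-keyed dict
        return inputs if torch.is_tensor(inputs) else inputs[0]

    demo(torch.randn(2, 8, 32))              # flattened type-0 tensor - accepted
    demo({0: torch.randn(2, 8, 32, 1)})      # fiber dict - accepted
    # demo([torch.randn(2, 8, 32)])          # a list would raise a beartype violation
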
1000 | 1001 | # apply token embedding and positional embedding to type-0 features |
1001 | 1002 | # (if type-0 feats are passed as a tensor they are expected to be of a flattened shape (batch, seq, n_feats) |
1002 | 1003 | # but if they are passed in a dict (fiber) they are expected to be of a unified shape (batch, seq, n_feats, 1=2*0+1)) |
1003 | 1004 |
1004 | | - if torch.is_tensor(feats): |
1005 | | - type0_feats = feats |
1006 | | - else: |
1007 | | - type0_feats = rearrange(feats[0], '... 1 -> ...') |
| 1005 | + if is_tensor(inputs): |
| 1006 | + inputs = {0: inputs} |
| 1007 | + |
| 1008 | + feats = inputs[0] |
| 1009 | + |
| 1010 | + if feats.ndim == 4: |
| 1011 | + feats = rearrange(feats, '... 1 -> ...') |
1008 | 1012 |
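
The comment above spells out the two accepted layouts for type-0 features, and the block that follows simply normalizes both to a flattened (batch, seq, n_feats) tensor. A standalone sketch of that normalization (the helper name and the sizes are placeholders, not library code):

    import torch
    from einops import rearrange

    def normalize_type0(inputs):
        if torch.is_tensor(inputs):          # flattened form: (batch, seq, n_feats)
            inputs = {0: inputs}

        feats = inputs[0]

        if feats.ndim == 4:                  # fiber form: (batch, seq, n_feats, 2*0+1)
            feats = rearrange(feats, '... 1 -> ...')

        return inputs, feats

    _, f1 = normalize_type0(torch.randn(2, 8, 32))          # tensor input
    _, f2 = normalize_type0({0: torch.randn(2, 8, 32, 1)})  # dict input
    assert f1.shape == f2.shape == (2, 8, 32)
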
1009 | 1013 | if exists(self.token_emb): |
1010 | | - type0_feats = self.token_emb(type0_feats) |
| 1014 | + assert feats.ndim == 2 |
| 1015 | + feats = self.token_emb(feats) |
1011 | 1016 |
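
The new `ndim == 2` assert reflects that when a token embedding is configured, the type-0 input is expected to be integer token ids of shape (batch, seq), which `nn.Embedding` then maps to (batch, seq, dim). A small sketch with placeholder sizes:

    import torch
    from torch import nn

    token_emb = nn.Embedding(28, 32)      # num_tokens = 28, dim = 32 are placeholders
    ids = torch.randint(0, 28, (2, 8))    # (batch, seq) integer tokens, ndim == 2
    feats = token_emb(ids)                # -> (batch, seq, dim) = (2, 8, 32)
    assert feats.shape == (2, 8, 32)
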
1012 | 1017 | if exists(self.pos_emb): |
1013 | | - assert type0_feats.shape[1] <= self.num_positions, 'feature sequence length must be less than the number of positions given at init' |
1014 | | - type0_feats = type0_feats + self.pos_emb(torch.arange(type0_feats.shape[1], device = type0_feats.device)) |
| 1018 | + seq_len = feats.shape[1] |
| 1019 | + assert seq_len <= self.num_positions, 'feature sequence length must be less than the number of positions given at init' |
1015 | 1020 |
1016 | | - type0_feats = self.embedding_grad_frac * type0_feats + (1 - self.embedding_grad_frac) * type0_feats.detach() |
| 1021 | + feats = feats + self.pos_emb(torch.arange(seq_len, device = device)) |
| 1022 | + |
| 1023 | + feats = self.embedding_grad_frac * feats + (1 - self.embedding_grad_frac) * feats.detach() |
1017 | 1024 |
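
The `embedding_grad_frac` line is a gradient-scaling trick: mixing `feats` with its detached copy leaves the forward values untouched while shrinking the gradient that flows back into the token and positional embeddings by that fraction. A quick check of both properties (the 0.25 is an arbitrary example value):

    import torch

    frac = 0.25
    x = torch.randn(4, requires_grad = True)

    y = frac * x + (1 - frac) * x.detach()    # same values as x in the forward pass
    y.sum().backward()

    assert torch.allclose(y, x)                              # value unchanged
    assert torch.allclose(x.grad, torch.full_like(x, frac))  # gradient scaled by frac
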
1018 | 1025 | assert not (self.has_edges and not exists(edges)), 'edge embedding (num_edge_tokens & edge_dim) must be supplied if one were to train on edge types' |
1019 | 1026 |
1020 | | - type0_feats = rearrange(type0_feats, '... -> ... 1') |
1021 | | - if torch.is_tensor(feats): |
1022 | | - feats = {0: type0_feats} |
1023 | | - else: |
1024 | | - feats[0] = type0_feats |
| 1027 | + b, n, d = feats.shape |
| 1028 | + |
| 1029 | + feats = rearrange(feats, 'b n d -> b n d 1') |
1025 | 1030 |
1026 | | - b, n, d, *_, device = *feats[0].shape, feats[0].device |
| 1031 | + inputs[0] = feats |
1027 | 1032 |
1028 | 1033 | assert d == self.type0_feat_dim, f'feature dimension {d} must be equal to dimension given at init {self.type0_feat_dim}' |
1029 | | - assert set(map(int, feats.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree' |
| 1034 | + assert set(map(int, inputs.keys())) == set(range(self.input_degrees)), f'input must have {self.input_degrees} degree' |
1030 | 1035 |
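
For the degree check just above, each entry of the dict form is expected to carry a trailing dimension of size 2*l + 1 for its degree l, which is why the flattened type-0 tensor gets a trailing 1 appended before being written back into `inputs[0]`. A sketch of a plausible `inputs` dict for a hypothetical model with `input_degrees = 2` (batch, sequence and channel sizes are placeholders):

    import torch

    batch, seq, channels = 2, 8, 32

    inputs = {
        0: torch.randn(batch, seq, channels, 1),   # degree 0: 2*0 + 1 = 1
        1: torch.randn(batch, seq, channels, 3),   # degree 1: 2*1 + 1 = 3
    }

    assert set(map(int, inputs.keys())) == set(range(2))   # mirrors the assert above
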
1031 | 1036 | num_degrees, neighbors, max_sparse_neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.max_sparse_neighbors, self.valid_radius |
1032 | 1037 |
@@ -1159,7 +1164,7 @@ def forward( |
1159 | 1164 |
1160 | 1165 | edge_info = EdgeInfo(neighbor_indices, neighbor_mask, edges) |
1161 | 1166 |
1162 | | - x = feats |
| 1167 | + x = inputs |
1163 | 1168 |
1164 | 1169 | # project in |
1165 | 1170 |