Commit 8e47674

follow the way unet is conditioned with the relative distances to tensor products
1 parent b74579d

2 files changed (+49 −22 lines)
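This commit mirrors how diffusion U-Nets are conditioned on time: a scalar (there the timestep, here the relative distance between a node pair) is embedded once by a small MLP at the top of the network, and every downstream block consumes that hidden vector instead of the raw scalar. A minimal standalone sketch of the pattern, assuming plain PyTorch (nn.LayerNorm stands in for the repo's own LayerNorm; names and shapes are illustrative, not the library's exact code):

import torch
from torch import nn

radial_hidden_dim = 64

# embed the scalar relative distance once per forward pass
to_rel_dist_hidden = nn.Sequential(
    nn.Linear(1, radial_hidden_dim),
    nn.SiLU(),
    nn.LayerNorm(radial_hidden_dim),
    nn.Linear(radial_hidden_dim, radial_hidden_dim),
)

rel_dist = torch.rand(2, 32, 16)                 # (batch, nodes, neighbors)
hidden = to_rel_dist_hidden(rel_dist[..., None]) # (batch, nodes, neighbors, radial_hidden_dim)
assert hidden.shape == (2, 32, 16, 64)

Every radial kernel then receives this shared embedding rather than re-projecting the raw distance itself, which is what the hunks below wire up.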

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 48 additions & 21 deletions
@@ -191,6 +191,7 @@ def __init__(
         project_out = True, # whether to do a project out after the "tensor product"
         pool = True,
         edge_dim = 0,
+        radial_hidden_dim = 16,
         splits = 4
     ):
         super().__init__()
@@ -209,7 +210,7 @@ def __init__(
         self.kernel_unary = nn.ModuleDict()

         for (di, mi), (do, mo) in fiber_product(self.fiber_in, self.fiber_out):
-            self.kernel_unary[f'({di},{do})'] = PairwiseTP(di, mi, do, mo, edge_dim = edge_dim)
+            self.kernel_unary[f'({di},{do})'] = PairwiseTP(di, mi, do, mo, radial_hidden_dim = radial_hidden_dim, edge_dim = edge_dim)

         if self_interaction:
             self.self_interact = Linear(fiber_in, fiber_out)
@@ -227,7 +228,6 @@ def forward(
     ):
         splits = self.splits
         neighbor_indices, neighbor_masks, edges = edge_info
-        rel_dist = rearrange(rel_dist, 'b m n -> b m n 1')

         kernels = {}
         outputs = {}
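The rearrange removed here has not disappeared: the trailing singleton axis is now added once in the network's top-level forward (see the last hunk of this file), just before the distances pass through the new embedding MLP. A quick sketch of what that einops pattern does, for reference:

import torch
from einops import rearrange

d = torch.rand(2, 32, 16)          # (batch, nodes, neighbors) scalar distances
d = rearrange(d, '... -> ... 1')   # append a feature axis so nn.Linear can consume it
assert d.shape == (2, 32, 16, 1)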
@@ -309,19 +309,17 @@ def __init__(
         in_dim,
         out_dim,
         edge_dim = None,
-        mid_dim = 128
+        mid_dim = 64,
     ):
         super().__init__()
-        self.num_freq = num_freq
         self.in_dim = in_dim
         self.mid_dim = mid_dim
         self.out_dim = out_dim

+        edge_dim = default(edge_dim, 0)
+
         self.net = nn.Sequential(
-            nn.Linear(default(edge_dim, 0) + 1, mid_dim),
-            nn.SiLU(),
-            LayerNorm(mid_dim),
-            nn.Linear(mid_dim, mid_dim),
+            nn.Linear(edge_dim + mid_dim, mid_dim),
             nn.SiLU(),
             LayerNorm(mid_dim),
             nn.Linear(mid_dim, num_freq * in_dim * out_dim)
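With the conditioning moved upstream, RadialFunc no longer sees the raw scalar distance: its first layer used to take edge_dim + 1 inputs and now takes edge_dim + mid_dim, where mid_dim matches the width of the shared distance embedding, and one Linear/SiLU/LayerNorm stage is dropped because that projection now lives in the top-level MLP. A shape-only sketch of the new contract, reusing the diff's names with nn.LayerNorm standing in for the repo's LayerNorm:

import torch
from torch import nn

num_freq, in_dim, out_dim = 4, 16, 16
edge_dim, mid_dim = 0, 64

net = nn.Sequential(
    nn.Linear(edge_dim + mid_dim, mid_dim),           # distance hidden in, not the raw scalar
    nn.SiLU(),
    nn.LayerNorm(mid_dim),
    nn.Linear(mid_dim, num_freq * in_dim * out_dim),  # radial weights per frequency
)

feat = torch.randn(2, 32, 16, edge_dim + mid_dim)
assert net(feat).shape == (2, 32, 16, num_freq * in_dim * out_dim)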
@@ -338,7 +336,8 @@ def __init__(
         nc_in,
         degree_out,
         nc_out,
-        edge_dim = 0
+        edge_dim = 0,
+        radial_hidden_dim = 16
     ):
         super().__init__()
         self.degree_in = degree_in
@@ -350,7 +349,7 @@ def __init__(
         self.d_out = to_order(degree_out)
         self.edge_dim = edge_dim

-        self.rp = RadialFunc(self.num_freq, nc_in, nc_out, edge_dim)
+        self.rp = RadialFunc(self.num_freq, nc_in, nc_out, mid_dim = radial_hidden_dim, edge_dim = edge_dim)

     def forward(self, feat, basis):
         R = self.rp(feat)
@@ -426,6 +425,7 @@ def __init__(
         attend_self = False,
         edge_dim = None,
         single_headed_kv = False,
+        radial_hidden_dim = 16,
         splits = 4
     ):
         super().__init__()
@@ -450,7 +450,7 @@ def __init__(
         self.prenorm = Norm(fiber)

         self.to_q = Linear(fiber, hidden_fiber)
-        self.to_kv = TP(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = attend_self, splits = splits)
+        self.to_kv = TP(fiber, kv_hidden_fiber, radial_hidden_dim = radial_hidden_dim, edge_dim = edge_dim, pool = False, self_interaction = attend_self, splits = splits)

         self.to_out = Linear(hidden_fiber, fiber)

@@ -521,6 +521,7 @@ def __init__(
         single_headed_kv = False,
         attn_leakyrelu_slope = 0.1,
         attn_hidden_dim_mult = 4,
+        radial_hidden_dim = 16,
         **kwargs
     ):
         super().__init__()
@@ -558,7 +559,7 @@ def __init__(

         # main branch tensor product

-        self.to_attn_and_v = TP(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = attend_self, splits = splits)
+        self.to_attn_and_v = TP(fiber, intermediate_fiber, radial_hidden_dim = radial_hidden_dim, edge_dim = edge_dim, pool = False, self_interaction = attend_self, splits = splits)

         # non-linear projection of attention branch into the attention logits

@@ -654,6 +655,7 @@ def __init__(
         valid_radius = 1e5,
         num_neighbors = float('inf'),
         reduce_dim_out = False,
+        radial_hidden_dim = 64,
         num_tokens = None,
         num_positions = None,
         num_edge_tokens = None,
@@ -709,12 +711,21 @@ def __init__(

         # edges

-        assert not (exists(num_edge_tokens) and not exists(edge_dim)), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens'
+        assert not (exists(num_edge_tokens) and not exists(edge_dim)), 'edge dimension (edge_dim) must be supplied if equiformer is to have edge tokens'

         self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
         self.has_edges = exists(edge_dim) and edge_dim > 0

-        # whether to differentiate through basis, needed for alphafold2
+        # initial MLP of relative distances to intermediate representation, akin to time conditioning in ddpm unets
+
+        self.to_rel_dist_hidden = nn.Sequential(
+            nn.Linear(1, radial_hidden_dim),
+            nn.SiLU(),
+            LayerNorm(radial_hidden_dim),
+            nn.Linear(radial_hidden_dim, radial_hidden_dim),
+        )
+
+        # whether to differentiate through basis, needed gradients for iterative refinement

         self.differentiable_coors = differentiable_coors

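The width of this embedding must agree with what every radial kernel expects, which is why radial_hidden_dim is threaded from here through TP and PairwiseTP down to RadialFunc(mid_dim = radial_hidden_dim). A sketch of the bookkeeping, assuming edge features are concatenated with the distance hidden before entering the kernel (the exact concatenation order inside TP.forward may differ):

import torch

b, n, k = 2, 32, 16                    # batch, nodes, neighbors per node
radial_hidden_dim, edge_dim = 64, 8

rel_dist_hidden = torch.randn(b, n, k, radial_hidden_dim)  # output of to_rel_dist_hidden
edges = torch.randn(b, n, k, edge_dim)                     # optional edge features

feat = torch.cat((edges, rel_dist_hidden), dim = -1)
assert feat.shape[-1] == edge_dim + radial_hidden_dim      # RadialFunc's new input width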
@@ -723,13 +734,15 @@ def __init__(
         self.valid_radius = valid_radius
         self.num_neighbors = num_neighbors

-        # define fibers and dimensionality
-
-        tp_kwargs = dict(edge_dim = edge_dim, splits = splits)
-
         # main network

-        self.tp_in = TP(self.dim_in, self.dim, **tp_kwargs)
+        self.tp_in = TP(
+            self.dim_in,
+            self.dim,
+            edge_dim = edge_dim,
+            radial_hidden_dim = radial_hidden_dim,
+            splits = splits
+        )

         # trunk

@@ -739,7 +752,17 @@ def __init__(

         for ind in range(depth):
             self.layers.append(nn.ModuleList([
-                attention_klass(self.dim, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, splits = splits, single_headed_kv = single_headed_kv, **kwargs),
+                attention_klass(
+                    self.dim,
+                    heads = heads,
+                    dim_head = dim_head,
+                    attend_self = attend_self,
+                    edge_dim = edge_dim,
+                    splits = splits,
+                    single_headed_kv = single_headed_kv,
+                    radial_hidden_dim = radial_hidden_dim,
+                    **kwargs
+                ),
                 FeedForward(self.dim, include_htype_norms = ff_include_htype_norms)
             ]))

@@ -784,7 +807,7 @@ def forward(

         num_degrees, neighbors, valid_radius = self.num_degrees, self.num_neighbors, self.valid_radius

-        # se3 transformer by default cannot have a node attend to itself
+        # cannot have a node attend to itself

         exclude_self_mask = rearrange(~torch.eye(n, dtype = torch.bool, device = device), 'i j -> 1 i j')
         remove_self = lambda t: t.masked_select(exclude_self_mask).reshape(b, n, n - 1)
@@ -853,6 +876,10 @@ def forward(
         if exists(edges):
             edges = batched_index_select(edges, nearest_indices, dim = 2)

+        # embed relative distances
+
+        neighbor_rel_dist = self.to_rel_dist_hidden(rearrange(neighbor_rel_dist, '... -> ... 1'))
+
         # calculate basis

         basis = get_basis(neighbor_rel_pos, num_degrees - 1, differentiable = self.differentiable_coors)

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-__version__ = '0.0.22'
+__version__ = '0.0.23'

 __cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'
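For reference, a hedged usage sketch with the new radial_hidden_dim knob exposed at the top level. Constructor arguments follow the project README of this era; exact signatures and defaults may differ at 0.0.23:

import torch
from equiformer_pytorch import Equiformer

model = Equiformer(
    dim = 32,
    num_degrees = 2,
    depth = 2,
    radial_hidden_dim = 64,  # new: width of the shared relative-distance embedding
)

feats = torch.randn(1, 16, 32)
coors = torch.randn(1, 16, 3)
mask  = torch.ones(1, 16).bool()

out = model(feats, coors, mask)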
