Skip to content

Commit cab5dcf

Browse files
committed
reduce out dimension in edges readme example
1 parent 80000f6 commit cab5dcf

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

README.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ model = Equiformer(
8585
edge_dim = 16, # dimension of edge embedding
8686
depth = 2,
8787
input_degrees = 1,
88-
num_degrees = 3
88+
num_degrees = 3,
89+
reduce_dim_out = True
8990
)
9091

9192
atoms = torch.randint(0, 28, (2, 32))
@@ -94,6 +95,7 @@ coors = torch.randn(2, 32, 3)
9495
mask = torch.ones(2, 32).bool()
9596

9697
out = model(atoms, coors, mask, edges = bonds)
98+
9799
out.type0 # (2, 32)
98100
out.type1 # (2, 32, 3)
99101
```
@@ -111,7 +113,7 @@ model = Equiformer(
111113
dim_head = 64,
112114
num_degrees = 2,
113115
valid_radius = 10,
114-
reversible = True,
116+
reduce_dim_out = True,
115117
attend_sparse_neighbors = True, # this must be set to true, in which case it will assert that you pass in the adjacency matrix
116118
num_neighbors = 0, # if you set this to 0, it will only consider the connected neighbors as defined by the adjacency matrix. but if you set a value greater than 0, it will continue to fetch the closest points up to this many, excluding the ones already specified by the adjacency matrix
117119
num_adj_degrees_embed = 2, # this will derive the second degree connections and embed it correctly
@@ -128,7 +130,10 @@ mask = torch.ones(1, 128).bool()
128130
i = torch.arange(128)
129131
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
130132

131-
out = model(feats, coors, mask, adj_mat = adj_mat) # (1, 128, 512)
133+
out = model(feats, coors, mask, adj_mat = adj_mat)
134+
135+
out.type0 # (1, 128)
136+
out.type1 # (1, 128, 3)
132137
```
133138

134139
## Appreciation

tests/test_edges.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ def test_edges_equivariance(
2323
num_degrees = 3,
2424
l2_dist_attention = l2_dist_attention,
2525
reversible = reversible,
26-
init_out_zero = False
26+
init_out_zero = False,
27+
reduce_dim_out = True
2728
)
2829

2930
atoms = torch.randint(0, 28, (2, 32))
@@ -58,7 +59,8 @@ def test_adj_mat_equivariance(
5859
num_neighbors = 0,
5960
num_adj_degrees_embed = 2,
6061
max_sparse_neighbors = 8,
61-
init_out_zero = False
62+
init_out_zero = False,
63+
reduce_dim_out = True
6264
)
6365

6466
feats = torch.randn(1, 128, 32)

0 commit comments

Comments
 (0)