Commit 09bb3fe

zero init the attention and feedforward branch outputs by default, but turn it off for testing
1 parent 3826906 commit 09bb3fe
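
For context: zero-initializing the output projection of a residual branch makes each block compute the identity function at initialization, a common trick for stabilizing early training. A minimal sketch of the pattern (illustrative names and shapes, not this repo's actual classes):

    import torch
    from torch import nn

    class ResidualBlock(nn.Module):
        def __init__(self, dim, init_out_zero = True):
            super().__init__()
            self.branch = nn.Sequential(
                nn.Linear(dim, dim),
                nn.GELU(),
                nn.Linear(dim, dim)  # output projection
            )
            if init_out_zero:
                # zero the output projection so branch(x) == 0 at init
                nn.init.zeros_(self.branch[-1].weight)
                nn.init.zeros_(self.branch[-1].bias)

        def forward(self, x):
            return x + self.branch(x)  # exact identity at initialization

    x = torch.randn(2, 8, 16)
    assert torch.allclose(ResidualBlock(16)(x), x)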

3 files changed: 23 additions, 5 deletions

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 18 additions & 2 deletions
@@ -168,6 +168,10 @@ def __init__(
             self.weights.append(nn.Parameter(torch.randn(dim_in, dim_out) / sqrt(dim_in)))
             self.degrees.append(degree)

+    def init_zero_(self):
+        for weight in self.weights:
+            weight.data.zero_()
+
     def forward(self, x):
         out = {}

@@ -453,7 +457,8 @@ def __init__(
         fiber: Tuple[int, ...],
         fiber_out: Optional[Tuple[int, ...]] = None,
         mult = 4,
-        include_htype_norms = True
+        include_htype_norms = True,
+        init_out_zero = True
     ):
         super().__init__()
         self.fiber = fiber

@@ -474,6 +479,9 @@ def __init__(
         self.gate = Gate(project_in_fiber_hidden)
         self.project_out = Linear(fiber_hidden, fiber_out)

+        if init_out_zero:
+            self.project_out.init_zero_()
+
     def forward(self, features):
         outputs = self.prenorm(features)

@@ -542,7 +550,8 @@ def __init__(
         single_headed_kv = False,
         radial_hidden_dim = 64,
         splits = 4,
-        num_linear_attn_heads = 0
+        num_linear_attn_heads = 0,
+        init_out_zero = True
     ):
         super().__init__()
         num_degrees = len(fiber)

@@ -580,6 +589,9 @@ def __init__(

         self.to_out = Linear(hidden_fiber, fiber)

+        if init_out_zero:
+            self.to_out.init_zero_()
+
     @beartype
     def forward(
         self,

@@ -669,6 +681,7 @@ def __init__(
         attn_hidden_dim_mult = 4,
         radial_hidden_dim = 16,
         num_linear_attn_heads = 0,
+        init_out_zero = True,
         **kwargs
     ):
         super().__init__()

@@ -738,6 +751,9 @@ def __init__(

         self.to_out = Linear(hidden_fiber, fiber)

+        if init_out_zero:
+            self.to_out.init_zero_()
+
     @beartype
     def forward(
         self,
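
Judging from the test changes further down, the new flag also reaches the top-level Equiformer constructor, so callers can opt out at construction time. A hedged usage sketch (argument values are illustrative):

    from equiformer_pytorch import Equiformer

    model = Equiformer(
        dim = 32,
        depth = 2,
        num_degrees = 3,
        init_out_zero = True  # the new default; pass False to keep random output projections
    )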

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.3.0'
+__version__ = '0.3.1'

tests/test_equivariance.py

Lines changed: 4 additions & 2 deletions
@@ -12,7 +12,8 @@ def test_transformer(dim):
     model = Equiformer(
         dim = dim,
         depth = 2,
-        num_degrees = 3
+        num_degrees = 3,
+        init_out_zero = False
     )

     feats = torch.randn(1, 32, dim)

@@ -39,7 +40,8 @@ def test_equivariance(
         l2_dist_attention = l2_dist_attention,
         reversible = reversible,
         num_degrees = 3,
-        reduce_dim_out = True
+        reduce_dim_out = True,
+        init_out_zero = False
     )

     feat_dim = dim if not isinstance(dim, tuple) else dim[0]
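
Disabling the zero init in the tests makes sense: with init_out_zero = True, every attention and feedforward branch contributes exactly zero at initialization, so each residual block reduces to the identity and the equivariance checks would pass trivially without exercising the rotation-dependent code paths. Passing init_out_zero = False keeps the output projections randomly initialized, so the tests remain meaningful.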
