@@ -168,6 +168,10 @@ def __init__(
168168 self .weights .append (nn .Parameter (torch .randn (dim_in , dim_out ) / sqrt (dim_in )))
169169 self .degrees .append (degree )
170170
def init_zero_(self):
    """Set every tensor in ``self.weights`` to zero, in place.

    The parameters keep their identity, shape and ``requires_grad`` flag;
    only their values are overwritten. Nothing is returned, matching the
    trailing-underscore convention for in-place operations.
    """
    for weight in self.weights:
        # nn.init.zeros_ fills the tensor under torch.no_grad(), so the
        # in-place write is not recorded by autograd but still goes
        # through the parameter itself rather than the `.data` alias,
        # which sidesteps autograd's in-place modification checks.
        nn.init.zeros_(weight)
174+
171175 def forward (self , x ):
172176 out = {}
173177
@@ -453,7 +457,8 @@ def __init__(
453457 fiber : Tuple [int , ...],
454458 fiber_out : Optional [Tuple [int , ...]] = None ,
455459 mult = 4 ,
456- include_htype_norms = True
460+ include_htype_norms = True ,
461+ init_out_zero = True
457462 ):
458463 super ().__init__ ()
459464 self .fiber = fiber
@@ -474,6 +479,9 @@ def __init__(
474479 self .gate = Gate (project_in_fiber_hidden )
475480 self .project_out = Linear (fiber_hidden , fiber_out )
476481
482+ if init_out_zero :
483+ self .project_out .init_zero_ ()
484+
477485 def forward (self , features ):
478486 outputs = self .prenorm (features )
479487
@@ -542,7 +550,8 @@ def __init__(
542550 single_headed_kv = False ,
543551 radial_hidden_dim = 64 ,
544552 splits = 4 ,
545- num_linear_attn_heads = 0
553+ num_linear_attn_heads = 0 ,
554+ init_out_zero = True
546555 ):
547556 super ().__init__ ()
548557 num_degrees = len (fiber )
@@ -580,6 +589,9 @@ def __init__(
580589
581590 self .to_out = Linear (hidden_fiber , fiber )
582591
592+ if init_out_zero :
593+ self .to_out .init_zero_ ()
594+
583595 @beartype
584596 def forward (
585597 self ,
@@ -669,6 +681,7 @@ def __init__(
669681 attn_hidden_dim_mult = 4 ,
670682 radial_hidden_dim = 16 ,
671683 num_linear_attn_heads = 0 ,
684+ init_out_zero = True ,
672685 ** kwargs
673686 ):
674687 super ().__init__ ()
@@ -738,6 +751,9 @@ def __init__(
738751
739752 self .to_out = Linear (hidden_fiber , fiber )
740753
754+ if init_out_zero :
755+ self .to_out .init_zero_ ()
756+
741757 @beartype
742758 def forward (
743759 self ,
0 commit comments