@@ -70,9 +70,9 @@ def forward(self, cemb: torch.Tensor, wemb: torch.Tensor):
         emb = torch.cat((cemb, wemb), dim=1)
         emb = self.highway(emb)
 
-        emb = F.relu(self.resizer(emb.transpose(1, 2)))
-        emb = F.dropout(emb, p=config.word_emb_dropout, training=self.training)
-        emb = self.norm(emb).transpose(1, 2)
+        emb = F.relu(self.norm(self.resizer(emb.transpose(1, 2))))
+        emb = F.dropout(emb, p=config.layer_dropout, training=self.training)
+        emb = emb.transpose(1, 2)
 
         return emb
 
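For reference, a minimal standalone sketch of the reordered projection above: project, then LayerNorm, then ReLU, then a single dropout at the shared `config.layer_dropout` rate (which replaces `word_emb_dropout` here). The module types and dimensions are assumptions for illustration only (`resizer` as an `nn.Linear` to the hidden size, `norm` as an `nn.LayerNorm`), consistent with the shapes in this hunk but not copied from the rest of the file.

```python
# Sketch only: assumed stand-ins for self.resizer / self.norm and for config.layer_dropout.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, in_dim, hidden_size, length = 8, 500, 128, 400
resizer = nn.Linear(in_dim, hidden_size)   # assumed projection to the model width
norm = nn.LayerNorm(hidden_size)           # normalizes over the last (feature) dimension
layer_dropout = 0.1                        # assumed value of config.layer_dropout

emb = torch.randn(batch, in_dim, length)               # [batch, in_dim, length]
emb = F.relu(norm(resizer(emb.transpose(1, 2))))       # project -> normalize -> activate
emb = F.dropout(emb, p=layer_dropout, training=True)   # one dropout, after the activation
emb = emb.transpose(1, 2)                              # back to [batch, hidden_size, length]
print(emb.shape)                                       # torch.Size([8, 128, 400])
```
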
@@ -141,40 +141,6 @@ def mask_logits(target, mask):
     return target + (-1e30) * (1 - mask)
 
 
-class AttentionBlock(nn.Module):
-    def __init__(self, hidden_size, head_number):
-        super(AttentionBlock, self).__init__()
-        self.self_attention = nn.MultiheadAttention(hidden_size, head_number,
-                                                    dropout=config.layer_dropout)
-        self.attention_layer_norm = nn.LayerNorm(hidden_size)
-        self.feedforward_norm1 = nn.LayerNorm(hidden_size)
-        self.feedforward1 = nn.Linear(hidden_size, hidden_size)
-        self.feedforward_norm2 = nn.LayerNorm(hidden_size)
-        self.feedforward2 = nn.Linear(hidden_size, hidden_size)
-        nn.init.kaiming_normal_(self.feedforward1.weight, nonlinearity='relu')
-        nn.init.kaiming_normal_(self.feedforward2.weight, nonlinearity='relu')
-
-    def forward(self, x, mask):
-        raw = x
-        x = x.permute(2, 0, 1)
-        x, _ = self.self_attention(x, x, x, key_padding_mask=(mask.bool() == False))
-        x = x.permute(1, 2, 0)
-        x = F.dropout(x, config.layer_dropout, training=self.training)
-        x = self.attention_layer_norm(raw.transpose(1, 2) + x.transpose(1, 2)).transpose(1, 2)
-
-        raw = x
-        x = self.feedforward1(x.transpose(1, 2)).transpose(1, 2)
-        x = F.dropout(F.relu(x), config.layer_dropout, training=self.training)
-        x = self.feedforward_norm1(raw.transpose(1, 2) + x.transpose(1, 2)).transpose(1, 2)
-
-        raw = x
-        x = self.feedforward2(x.transpose(1, 2)).transpose(1, 2)
-        x = F.dropout(F.relu(x), config.layer_dropout, training=self.training)
-        x = self.feedforward_norm2(raw.transpose(1, 2) + x.transpose(1, 2)).transpose(1, 2)
-
-        return x
-
-
 class EncoderBlock(nn.Module):
     """
     input:
@@ -184,31 +150,66 @@ class EncoderBlock(nn.Module):
         x: shape [batch_size, hidden_size, max length] => [8, 128, 400]
     """
 
-    def __init__(self, conv_number, hidden_size, kernel_size, head_number):
+    def __init__(self, conv_num, hidden_size, kernel_size, head_number, ff_depth):
         super(EncoderBlock, self).__init__()
-        self.conv_number = conv_number
+        self.conv_num = conv_num
+        self.ff_depth = ff_depth
+        self.total_layer = self.conv_num + self.ff_depth + 1  # one => atten
+
         self.position_encoder = PositionEncoder(hidden_size)
+
+        self.conv_norm_list = nn.ModuleList(
+            [nn.LayerNorm(hidden_size) for _ in range(self.conv_num)]
+        )
         self.conv_list = nn.ModuleList([
             DepthwiseSeparableConv(hidden_size, hidden_size, kernel_size)
-            for _ in range(self.conv_number)
+            for _ in range(self.conv_num)
         ])
 
-        self.conv_norm_list = nn.ModuleList(
-            [nn.LayerNorm(hidden_size) for _ in range(self.conv_number)]
+        self.atten_layer_norm = nn.LayerNorm(hidden_size)
+        self.self_atten = nn.MultiheadAttention(
+            hidden_size,
+            head_number,
+            dropout=(1 - (self.conv_num + 1) / self.total_layer) * config.layer_dropout
         )
-        self.self_attention = AttentionBlock(hidden_size, head_number)
+
+        self.ff_norm_list = nn.ModuleList(
+            [nn.LayerNorm(hidden_size) for _ in range(self.ff_depth)]
+        )
+        self.ff = nn.ModuleList(
+            [nn.Linear(hidden_size, hidden_size) for _ in range(self.ff_depth)]
+        )
+        for layer in self.ff:
+            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
 
     def forward(self, x, mask):
         x = self.position_encoder(x)
-        for i in range(self.conv_number):
-            raw = x.transpose(1, 2)
-            x = self.conv_list[i](x)
-            x = F.dropout(x, config.layer_dropout * (i + 1) / self.conv_number,
-                          training=self.training)
-            x = self.conv_norm_list[i](x.transpose(1, 2) + raw).transpose(1, 2)
-
-        x = self.self_attention(x, mask)
+        for i in range(self.conv_num):
+            raw = x
+            x = self.conv_list[i](self.conv_norm_list[i](x.transpose(1, 2)).transpose(1, 2))
+            x = F.dropout(
+                input=x,
+                p=(1 - (i + 1) / self.total_layer) * config.layer_dropout,
+                training=self.training
+            )
+            x = raw + x
 
+        raw = x
+        x = self.atten_layer_norm(x.transpose(1, 2)).transpose(1, 2)
+        x = x.permute(2, 0, 1)
+        x, _ = self.self_atten(x, x, x, key_padding_mask=(mask.bool() == False))
+        x = x.permute(1, 2, 0)
+        x = raw + x
+
+        for i in range(self.ff_depth):
+            raw = x
+            x = F.relu(self.ff[i](self.ff_norm_list[i](x.transpose(1, 2)))).transpose(1, 2)
+            x = F.dropout(
+                input=x,
+                p=(1 - (self.conv_num + 1 + (i + 1)) / self.total_layer) * config.layer_dropout,
+                training=self.training
+            )
+            x = raw + x
         return x
 
 
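Two things the rewritten `EncoderBlock` does differently from the removed `AttentionBlock` path: every sublayer now uses the pre-norm residual form `x = x + Sublayer(LayerNorm(x))` instead of normalizing after the residual sum, and the dropout rate is tied to the sublayer's position among all `conv_num + ff_depth + 1` sublayers rather than to the convolution loop alone. Below is a small sketch of that schedule; `conv_num=4`, `ff_depth=2`, `layer_dropout=0.1` are illustrative assumptions, not values read from this repo's config.

```python
# Depth-dependent dropout schedule as computed in EncoderBlock (illustrative settings).
conv_num, ff_depth, layer_dropout = 4, 2, 0.1
total_layer = conv_num + ff_depth + 1          # convolutions + feed-forward layers + one attention

def sublayer_dropout(layer_index: int) -> float:
    """Rate for the 1-based sublayer index; falls linearly and reaches 0 at the last sublayer."""
    return (1 - layer_index / total_layer) * layer_dropout

rates = [round(sublayer_dropout(l), 4) for l in range(1, total_layer + 1)]
print(rates)   # [0.0857, 0.0714, 0.0571, 0.0429, 0.0286, 0.0143, 0.0]
# Indices 1..4 are the convolutions, 5 is the self-attention sublayer
# (its rate is passed to nn.MultiheadAttention at construction time),
# and 6..7 are the feed-forward layers.
```
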
@@ -254,10 +255,7 @@ def forward(self, C, Q, cmask, qmask):
 
         A = torch.bmm(S_row_sofmax, Q)
         B = torch.bmm(torch.bmm(S_row_sofmax, S_column_softmax.transpose(1, 2)), C)
-        output = torch.cat((C, A, torch.mul(C, A), torch.mul(C, B)), dim=2)
-        output = F.dropout(output, p=config.layer_dropout, training=self.training)
-        output = self.resizer(output)
-        output = F.relu(output)
+        output = F.relu(self.resizer(torch.cat((C, A, torch.mul(C, A), torch.mul(C, B)), dim=2)))
         output = F.dropout(output, p=config.layer_dropout, training=self.training)
         output = output.transpose(1, 2)
         return output
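A shape-only sketch of the fused output path in this hunk: the four attention terms are concatenated, projected back to the hidden size in one step, passed through ReLU, and followed by a single dropout. Dummy tensors stand in for the real similarity matrices, and `resizer` is assumed to be a linear map from `4 * hidden_size` to `hidden_size`, consistent with the concatenation along the feature dimension.

```python
# Shape check for the fused cat -> resizer -> relu path (dummy inputs, assumed resizer).
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, c_len, q_len, hidden = 8, 400, 50, 128
C = torch.randn(batch, c_len, hidden)
Q = torch.randn(batch, q_len, hidden)
S_row_softmax = torch.softmax(torch.randn(batch, c_len, q_len), dim=2)     # row-normalized similarity
S_column_softmax = torch.softmax(torch.randn(batch, c_len, q_len), dim=1)  # column-normalized similarity
resizer = nn.Linear(4 * hidden, hidden)                                    # assumed projection

A = torch.bmm(S_row_softmax, Q)                                            # [batch, c_len, hidden]
B = torch.bmm(torch.bmm(S_row_softmax, S_column_softmax.transpose(1, 2)), C)
output = F.relu(resizer(torch.cat((C, A, C * A, C * B), dim=2)))           # one projection + activation
output = F.dropout(output, p=0.1, training=True)                           # single remaining dropout
print(output.transpose(1, 2).shape)                                        # torch.Size([8, 128, 400])
```
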
@@ -301,19 +299,21 @@ def __init__(self, word_mat, char_mat):
         self.char_embedding = nn.Embedding.from_pretrained(torch.tensor(char_mat), freeze=False)
         self.embedding = Embedding(word_mat.shape[1], char_mat.shape[1], config.global_hidden_size)
         emb_encoder_block = EncoderBlock(
-            conv_number=config.emb_encoder_conv_num,
+            conv_num=config.emb_encoder_conv_num,
             hidden_size=config.global_hidden_size,
             kernel_size=config.emb_encoder_conv_kernel_size,
-            head_number=config.attention_head_num
+            head_number=config.attention_head_num,
+            ff_depth=config.emb_encoder_ff_depth,
         )
         self.emb_encoder = nn.ModuleList(
             [emb_encoder_block for _ in range(config.emb_encoder_block_num)])
         self.cq_attention = CQAttention(hidden_size=config.global_hidden_size)
         output_encoder_block = EncoderBlock(
-            conv_number=config.output_encoder_conv_num,
+            conv_num=config.output_encoder_conv_num,
             hidden_size=config.global_hidden_size,
             kernel_size=config.output_encoder_conv_kernel_size,
-            head_number=config.attention_head_num
+            head_number=config.attention_head_num,
+            ff_depth=config.output_encoder_ff_depth,
         )
         self.output_encoder = nn.ModuleList(
             [output_encoder_block for _ in range(config.output_encoder_block_num)])
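One detail worth noting about the construction kept above: `nn.ModuleList([emb_encoder_block for _ in range(...)])` repeats the same `EncoderBlock` instance, so the stacked blocks apply one shared set of weights rather than independent ones. A minimal illustration with a stand-in module:

```python
# Repeating one module instance in a ModuleList shares its parameters across all entries.
import torch.nn as nn

block = nn.Linear(4, 4)                                   # stand-in for an EncoderBlock
shared = nn.ModuleList([block for _ in range(3)])         # same object three times
independent = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

print(shared[0] is shared[1])                             # True  -> weights reused at every depth
print(independent[0] is independent[1])                   # False -> each block has its own weights
print(sum(p.numel() for p in shared.parameters()))        # 20 (counted once)
print(sum(p.numel() for p in independent.parameters()))   # 60
```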