@@ -177,12 +177,61 @@ def forward(self, s, h, len_mask=None):
         return c, a


+def split_last(x, shape):
+    "split the last dimension to given shape"
+    shape = list(shape)
+    assert shape.count(-1) <= 1
+    if -1 in shape:
+        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
+    return x.view(*x.size()[:-1], *shape)
+
+
+def merge_last(x, n_dims):
+    "merge the last n_dims to a dimension"
+    s = x.size()
+    assert n_dims > 1 and n_dims < len(s)
+    return x.view(*s[:-n_dims], -1)
+
+
+class MultiHeadedSelfAttention(nn.Module):
+    """ Multi-Headed Dot Product Attention """
+    def __init__(self, state_vec_size, listen_vec_size, proj_hidden_size=512, num_heads=1, dropout=0.1):
+        super().__init__()
+        self.proj_q = nn.Linear(state_vec_size, proj_hidden_size)
+        self.proj_k = nn.Linear(listen_vec_size, proj_hidden_size)
+        self.proj_v = nn.Linear(listen_vec_size, proj_hidden_size)
+        self.drop = nn.Dropout(dropout)
+        self.scores = None  # for visualization
+        self.n_heads = num_heads
+
+    def forward(self, q, k, mask):
+        """
+        q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
+        mask : (B(batch_size) x S(seq_len))
+        * split D(dim) into (H(n_heads), W(width of head)); D = H * W
+        """
+        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
+        q, k, v = self.proj_q(q), self.proj_k(k), self.proj_v(k)
+        q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
+        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
+        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
+        if mask is not None:
+            mask = mask[:, None, None, :].float()
+            scores -= 10000.0 * (1.0 - mask)
+        scores = self.drop(F.softmax(scores, dim=-1))
+        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
+        h = (scores @ v).transpose(1, 2).contiguous()
+        # -merge-> (B, S, D)
+        h = merge_last(h, 2)
+        self.scores = scores
+        return h
+
+
 class Speller(nn.Module):

     def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None, eos=None,
                  rnn_type=nn.LSTM, rnn_hidden_size=512, rnn_num_layers=2,
-                 apply_attend_proj=False, proj_hidden_size=256, num_attend_heads=1,
-                 masked_attend=True):
+                 proj_hidden_size=256, num_attend_heads=1, masked_attend=True):
         super().__init__()

         assert sos is not None and 0 <= sos < label_vec_size
@@ -204,8 +253,7 @@ def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None,
         self.norm = nn.LayerNorm(Hs, elementwise_affine=False)

         self.attention = Attention(state_vec_size=Hs, listen_vec_size=Hc,
-                                   apply_proj=apply_attend_proj, proj_hidden_size=proj_hidden_size,
-                                   num_heads=num_attend_heads)
+                                   proj_hidden_size=proj_hidden_size, num_heads=num_attend_heads)

         self.masked_attend = masked_attend

@@ -330,7 +378,7 @@ def forward(self, x):
 class ListenAttendSpell(nn.Module):

     def __init__(self, label_vec_size=p.NUM_CTC_LABELS, listen_vec_size=256,
-                 state_vec_size=256, num_attend_heads=1, input_folding=2, smoothing=0.001):
+                 state_vec_size=256, num_attend_heads=4, input_folding=2, smoothing=0.001):
         super().__init__()

         self.label_vec_size = label_vec_size + 2  # to add <sos>, <eos>
@@ -347,7 +395,7 @@ def __init__(self, label_vec_size=p.NUM_CTC_LABELS, listen_vec_size=256,
         self.spell = Speller(listen_vec_size=listen_vec_size, label_vec_size=self.label_vec_size,
                              sos=self.sos, eos=self.eos, max_seq_lens=256,
                              rnn_hidden_size=state_vec_size, rnn_num_layers=2,
-                             apply_attend_proj=True, proj_hidden_size=128, num_attend_heads=num_attend_heads)
+                             proj_hidden_size=256, num_attend_heads=num_attend_heads)

         self.attentions = None
         self.regions = None
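
For context on the new attention path, below is a minimal standalone sketch of the shape bookkeeping that MultiHeadedSelfAttention relies on. It reuses the split_last/merge_last helpers exactly as added in the diff and checks the (B, S, D) -> (B, H, S, W) -> (B, S, D) round trip plus the -10000 padding-mask trick before softmax. The sizes B=2, S=5, D=512, H=4 and the all-ones mask are illustrative assumptions, not values taken from this repository.

# Sketch only: mirrors the head split/merge and masked scaled dot-product
# used by MultiHeadedSelfAttention above; all sizes here are assumptions.
import numpy as np
import torch
import torch.nn.functional as F

def split_last(x, shape):
    "split the last dimension to given shape"
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)

def merge_last(x, n_dims):
    "merge the last n_dims to a dimension"
    s = x.size()
    assert n_dims > 1 and n_dims < len(s)
    return x.view(*s[:-n_dims], -1)

B, S, D, H = 2, 5, 512, 4          # batch, seq len, projected dim, heads (illustrative)
q = k = v = torch.randn(B, S, D)   # stands in for the proj_q/proj_k/proj_v outputs
mask = torch.ones(B, S)            # 1 = keep, 0 = padded frame

# (B, S, D) -> (B, S, H, W) -> (B, H, S, W), with W = D // H
q, k, v = (split_last(x, (H, -1)).transpose(1, 2) for x in (q, k, v))
assert q.shape == (B, H, S, D // H)

# per-head scaled dot product: (B, H, S, S); -10000 drives padded keys to ~0 after softmax
scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
scores = scores - 10000.0 * (1.0 - mask[:, None, None, :])
scores = F.softmax(scores, dim=-1)

# weighted sum, swap back, merge heads: (B, H, S, W) -> (B, S, H, W) -> (B, S, D)
h = merge_last((scores @ v).transpose(1, 2).contiguous(), 2)
assert h.shape == (B, S, D)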