Commit 5518fe7

1 parent 0522b58 commit 5518fe7

2 files changed: +55 -7 lines changed

asr/models/las/network.py

Lines changed: 54 additions & 6 deletions
@@ -177,12 +177,61 @@ def forward(self, s, h, len_mask=None):
         return c, a
 
 
+def split_last(x, shape):
+    "split the last dimension to given shape"
+    shape = list(shape)
+    assert shape.count(-1) <= 1
+    if -1 in shape:
+        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
+    return x.view(*x.size()[:-1], *shape)
+
+
+def merge_last(x, n_dims):
+    "merge the last n_dims to a dimension"
+    s = x.size()
+    assert n_dims > 1 and n_dims < len(s)
+    return x.view(*s[:-n_dims], -1)
+
+
+class MultiHeadedSelfAttention(nn.Module):
+    """ Multi-Headed Dot Product Attention """
+    def __init__(self, state_vec_size, listen_vec_size, proj_hidden_size=512, num_heads=1, dropout=0.1):
+        super().__init__()
+        self.proj_q = nn.Linear(state_vec_size, proj_hidden_size)
+        self.proj_k = nn.Linear(listen_vec_size, proj_hidden_size)
+        self.proj_v = nn.Linear(listen_vec_size, proj_hidden_size)
+        self.drop = nn.Dropout(dropout)
+        self.scores = None  # for visualization
+        self.n_heads = num_heads
+
+    def forward(self, q, k, mask):
+        """
+        q (query), k (key), v (value) : (B (batch_size), S (seq_len), D (dim))
+        mask : (B (batch_size) x S (seq_len))
+        * split D (dim) into (H (n_heads), W (width of head)); D = H * W
+        """
+        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
+        q, k, v = self.proj_q(q), self.proj_k(k), self.proj_v(k)
+        q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
+        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
+        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
+        if mask is not None:
+            mask = mask[:, None, None, :].float()
+            scores -= 10000.0 * (1.0 - mask)
+        scores = self.drop(F.softmax(scores, dim=-1))
+        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
+        h = (scores @ v).transpose(1, 2).contiguous()
+        # -merge-> (B, S, D)
+        h = merge_last(h, 2)
+        self.scores = scores
+        return h
+
+
 class Speller(nn.Module):
 
     def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None, eos=None,
                  rnn_type=nn.LSTM, rnn_hidden_size=512, rnn_num_layers=2,
-                 apply_attend_proj=False, proj_hidden_size=256, num_attend_heads=1,
-                 masked_attend=True):
+                 proj_hidden_size=256, num_attend_heads=1, masked_attend=True):
         super().__init__()
 
         assert sos is not None and 0 <= sos < label_vec_size
@@ -204,8 +253,7 @@ def __init__(self, listen_vec_size, label_vec_size, max_seq_lens=256, sos=None,
         self.norm = nn.LayerNorm(Hs, elementwise_affine=False)
 
         self.attention = Attention(state_vec_size=Hs, listen_vec_size=Hc,
-                                   apply_proj=apply_attend_proj, proj_hidden_size=proj_hidden_size,
-                                   num_heads=num_attend_heads)
+                                   proj_hidden_size=proj_hidden_size, num_heads=num_attend_heads)
 
         self.masked_attend = masked_attend
 
@@ -330,7 +378,7 @@ def forward(self, x):
 class ListenAttendSpell(nn.Module):
 
     def __init__(self, label_vec_size=p.NUM_CTC_LABELS, listen_vec_size=256,
-                 state_vec_size=256, num_attend_heads=1, input_folding=2, smoothing=0.001):
+                 state_vec_size=256, num_attend_heads=4, input_folding=2, smoothing=0.001):
         super().__init__()
 
         self.label_vec_size = label_vec_size + 2  # to add <sos>, <eos>
@@ -347,7 +395,7 @@ def __init__(self, label_vec_size=p.NUM_CTC_LABELS, listen_vec_size=256,
         self.spell = Speller(listen_vec_size=listen_vec_size, label_vec_size=self.label_vec_size,
                              sos=self.sos, eos=self.eos, max_seq_lens=256,
                              rnn_hidden_size=state_vec_size, rnn_num_layers=2,
-                             apply_attend_proj=True, proj_hidden_size=128, num_attend_heads=num_attend_heads)
+                             proj_hidden_size=256, num_attend_heads=num_attend_heads)
 
         self.attentions = None
         self.regions = None
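
For reference, a minimal shape check of the newly added module (not part of the commit). It assumes the repository's asr package imports cleanly from the repo root and uses made-up tensor sizes; the B/S/D/H/W labels follow the docstring above.

    import torch
    from asr.models.las.network import MultiHeadedSelfAttention  # assumes repo root on PYTHONPATH

    B, S_dec, S_enc = 2, 5, 11                        # batch, decoder steps, listener frames (arbitrary)
    mha = MultiHeadedSelfAttention(state_vec_size=256, listen_vec_size=256,
                                   proj_hidden_size=512, num_heads=4)

    s = torch.randn(B, S_dec, 256)                    # decoder states -> queries
    h = torch.randn(B, S_enc, 256)                    # listener outputs -> keys/values
    mask = torch.ones(B, S_enc)                       # 1 = valid frame, 0 = padding
    mask[1, 8:] = 0                                   # pretend the second utterance is shorter

    c = mha(s, h, mask)                               # context per decoder step
    assert c.shape == (B, S_dec, 512)                 # merged heads: D = H * W = 4 * 128
    assert mha.scores.shape == (B, 4, S_dec, S_enc)   # kept on the module for visualization

Note that proj_hidden_size must be divisible by num_heads, since split_last infers the per-head width W from D / H.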

asr/utils/logger.py

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ def plot_heatmap(ax, tensor, drawbox=None):
     else:
         fig, axs = plt.subplots(tensor.size(0), sharex=True)
         for i, ax in enumerate(axs):
-            plot_heatmap(ax, tensor[i], drawbox[i])
+            plot_heatmap(ax, tensor[i], drawbox)
         fig.subplots_adjust(hspace=2)
 
     fig.patch.set_color('white')
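
A plausible reading of this one-line fix: with multi-headed attention the per-sample scores become (H, S_dec, S_enc), so the subplot loop now iterates heads of one sample and a single drawbox applies to every subplot. The standalone matplotlib sketch below (not from the repo; plot_heatmap's real drawing logic and the drawbox format are assumptions) only illustrates that shape handling.

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    H, S_dec, S_enc = 4, 5, 11
    scores = np.random.rand(H, S_dec, S_enc)          # one attention map per head
    drawbox = (-0.5, -0.5, 8, S_dec)                  # hypothetical shared box: x, y, width, height

    fig, axs = plt.subplots(H, sharex=True)
    for i, ax in enumerate(axs):                      # one subplot per head, same drawbox for all
        ax.imshow(scores[i], aspect="auto")
        x, y, w, h = drawbox
        ax.add_patch(Rectangle((x, y), w, h, fill=False, edgecolor="red"))
    fig.subplots_adjust(hspace=2)
    fig.savefig("attention_heads.png")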
