Commit 6700f1d

[FEATURE]basic QANet
1 parent 45a64ce commit 6700f1d


2 files changed: 62 additions & 60 deletions

config.py

Lines changed: 3 additions & 1 deletion
@@ -47,9 +47,11 @@
 flags.DEFINE_integer('emb_encoder_conv_num', 4, "")
 flags.DEFINE_integer('emb_encoder_conv_kernel_size', 7, "")
 flags.DEFINE_integer('emb_encoder_block_num', 1, "")
+flags.DEFINE_integer('emb_encoder_ff_depth', 3, "")
 flags.DEFINE_integer('output_encoder_conv_num', 2, "")
 flags.DEFINE_integer('output_encoder_conv_kernel_size', 5, "")
 flags.DEFINE_integer('output_encoder_block_num', 7, "")
+flags.DEFINE_integer('output_encoder_ff_depth', 2, "")
 flags.DEFINE_integer('attention_head_num', 8, "")
 
 # train & test config
@@ -59,7 +61,7 @@
 flags.DEFINE_integer('checkpoint', 1400, "")
 flags.DEFINE_float('lr', 0.001, "")
 flags.DEFINE_integer('lr_warm_up_steps', 1000, "")
-flags.DEFINE_float('adam_beta1', 0.9, "")
+flags.DEFINE_float('adam_beta1', 0.8, "")
 flags.DEFINE_float('adam_beta2', 0.999, "")
 flags.DEFINE_float('adam_eps', 1e-7, "")
 flags.DEFINE_float('adam_decay', 5e-8, "")
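
For reference, beta1 = 0.8 together with beta2 = 0.999 and eps = 1e-7 matches the Adam settings reported in the QANet paper. Below is a minimal sketch of how these flags might be consumed when building the optimizer; the stand-in model and the reading of adam_decay as Adam's weight_decay are assumptions, since neither appears in this diff.

import torch
import torch.nn as nn

# Stand-in for the QANet model built in models.py (not shown in this diff).
model = nn.Linear(8, 8)

# Hypothetical wiring of the Adam flags above; only the values come from this commit.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,              # flags: lr
    betas=(0.8, 0.999),    # flags: adam_beta1 / adam_beta2 after this change
    eps=1e-7,              # flags: adam_eps
    weight_decay=5e-8,     # flags: adam_decay, read here as Adam's weight_decay
)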

models.py

Lines changed: 59 additions & 59 deletions
@@ -70,9 +70,9 @@ def forward(self, cemb: torch.Tensor, wemb: torch.Tensor):
         emb = torch.cat((cemb, wemb), dim=1)
         emb = self.highway(emb)
 
-        emb = F.relu(self.resizer(emb.transpose(1, 2)))
-        emb = F.dropout(emb, p=config.word_emb_dropout, training=self.training)
-        emb = self.norm(emb).transpose(1, 2)
+        emb = F.relu(self.norm(self.resizer(emb.transpose(1, 2))))
+        emb = F.dropout(emb, p=config.layer_dropout, training=self.training)
+        emb = emb.transpose(1, 2)
 
         return emb
 
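
The reordering above applies LayerNorm to the resized embedding before the ReLU and switches the dropout rate from word_emb_dropout to layer_dropout. Reading the new pipeline with shapes (batch b, sequence length L), and assuming self.resizer is a linear projection down to the model's hidden size, which this hunk does not show:

# cat((cemb, wemb), dim=1)  -> [b, char_dim + word_dim, L]
# self.highway(...)          shape-preserving, per the context lines
# .transpose(1, 2)          -> [b, L, char_dim + word_dim]
# self.resizer(...)         -> [b, L, hidden]  (assumed linear projection)
# self.norm(...)            -> [b, L, hidden]  LayerNorm, now applied before the ReLU
# F.relu, then F.dropout(p=config.layer_dropout)
# .transpose(1, 2)          -> [b, hidden, L]  the layout EncoderBlock expects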

@@ -141,40 +141,6 @@ def mask_logits(target, mask):
     return target + (-1e30) * (1 - mask)
 
 
-class AttentionBlock(nn.Module):
-    def __init__(self, hidden_size, head_number):
-        super(AttentionBlock, self).__init__()
-        self.self_attention = nn.MultiheadAttention(hidden_size, head_number,
-                                                    dropout=config.layer_dropout)
-        self.attention_layer_norm = nn.LayerNorm(hidden_size)
-        self.feedforward_norm1 = nn.LayerNorm(hidden_size)
-        self.feedforward1 = nn.Linear(hidden_size, hidden_size)
-        self.feedforward_norm2 = nn.LayerNorm(hidden_size)
-        self.feedforward2 = nn.Linear(hidden_size, hidden_size)
-        nn.init.kaiming_normal_(self.feedforward1.weight, nonlinearity='relu')
-        nn.init.kaiming_normal_(self.feedforward2.weight, nonlinearity='relu')
-
-    def forward(self, x, mask):
-        raw = x
-        x = x.permute(2, 0, 1)
-        x, _ = self.self_attention(x, x, x, key_padding_mask=(mask.bool() == False))
-        x = x.permute(1, 2, 0)
-        x = F.dropout(x, config.layer_dropout, training=self.training)
-        x = self.attention_layer_norm(raw.transpose(1, 2) + x.transpose(1, 2)).transpose(1, 2)
-
-        raw = x
-        x = self.feedforward1(x.transpose(1, 2)).transpose(1, 2)
-        x = F.dropout(F.relu(x), config.layer_dropout, training=self.training)
-        x = self.feedforward_norm1(raw.transpose(1, 2) + x.transpose(1, 2)).transpose(1, 2)
-
-        raw = x
-        x = self.feedforward2(x.transpose(1, 2)).transpose(1, 2)
-        x = F.dropout(F.relu(x), config.layer_dropout, training=self.training)
-        x = self.feedforward_norm2(raw.transpose(1, 2) + x.transpose(1, 2)).transpose(1, 2)
-
-        return x
-
-
 class EncoderBlock(nn.Module):
     """
     input:
@@ -184,31 +150,66 @@ class EncoderBlock(nn.Module):
     x: shape [batch_size, hidden_size, max length] => [8, 128, 400]
     """
 
-    def __init__(self, conv_number, hidden_size, kernel_size, head_number):
+    def __init__(self, conv_num, hidden_size, kernel_size, head_number, ff_depth):
         super(EncoderBlock, self).__init__()
-        self.conv_number = conv_number
+        self.conv_num = conv_num
+        self.ff_depth = ff_depth
+        self.total_layer = self.conv_num + self.ff_depth + 1  # one => atten
+
         self.position_encoder = PositionEncoder(hidden_size)
+
+        self.conv_norm_list = nn.ModuleList(
+            [nn.LayerNorm(hidden_size) for _ in range(self.conv_num)]
+        )
         self.conv_list = nn.ModuleList([
             DepthwiseSeparableConv(hidden_size, hidden_size, kernel_size)
-            for _ in range(self.conv_number)
+            for _ in range(self.conv_num)
         ])
 
-        self.conv_norm_list = nn.ModuleList(
-            [nn.LayerNorm(hidden_size) for _ in range(self.conv_number)]
+        self.atten_layer_norm = nn.LayerNorm(hidden_size)
+        self.self_atten = nn.MultiheadAttention(
+            hidden_size,
+            head_number,
+            dropout=(1 - (self.conv_num + 1) / self.total_layer) * config.layer_dropout
         )
-        self.self_attention = AttentionBlock(hidden_size, head_number)
+
+        self.ff_norm_list = nn.ModuleList(
+            [nn.LayerNorm(hidden_size) for _ in range(self.ff_depth)]
+        )
+        self.ff = nn.ModuleList(
+            [nn.Linear(hidden_size, hidden_size) for _ in range(self.ff_depth)]
+        )
+        for layer in self.ff:
+            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
 
     def forward(self, x, mask):
         x = self.position_encoder(x)
-        for i in range(self.conv_number):
-            raw = x.transpose(1, 2)
-            x = self.conv_list[i](x)
-            x = F.dropout(x, config.layer_dropout * (i + 1) / self.conv_number,
-                          training=self.training)
-            x = self.conv_norm_list[i](x.transpose(1, 2) + raw).transpose(1, 2)
-
-        x = self.self_attention(x, mask)
+        for i in range(self.conv_num):
+            raw = x
+            x = self.conv_list[i](self.conv_norm_list[i](x.transpose(1, 2)).transpose(1, 2))
+            x = F.dropout(
+                input=x,
+                p=(1 - (i + 1) / self.total_layer) * config.layer_dropout,
+                training=self.training
+            )
+            x = raw + x
 
+        raw = x
+        x = self.atten_layer_norm(x.transpose(1, 2)).transpose(1, 2)
+        x = x.permute(2, 0, 1)
+        x, _ = self.self_atten(x, x, x, key_padding_mask=(mask.bool() == False))
+        x = x.permute(1, 2, 0)
+        x = raw + x
+
+        for i in range(self.ff_depth):
+            raw = x
+            x = F.relu(self.ff[i](self.ff_norm_list[i](x.transpose(1, 2)))).transpose(1, 2)
+            x = F.dropout(
+                input=x,
+                p=(1 - (self.conv_num + 1 + (i + 1)) / self.total_layer) * config.layer_dropout,
+                training=self.training
+            )
+            x = raw + x
         return x
 
 
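
The rewrite above folds the old AttentionBlock into EncoderBlock and switches every residual connection to a pre-norm layout: LayerNorm is applied to the sublayer input (conv, attention, or feed-forward) and the untouched input is added back with x = raw + x. Dropout is now scheduled per sublayer via total_layer = conv_num + ff_depth + 1, so sublayer l (1-based: convolutions first, then attention, then the feed-forward stack) uses p = (1 - l / total_layer) * layer_dropout; for the attention sublayer the rate is passed as nn.MultiheadAttention's dropout argument. A small sketch of the resulting values, assuming layer_dropout = 0.1 (that flag's value is not part of this diff):

# Dropout schedule implied by the EncoderBlock above; prints the p used per sublayer.
def sublayer_dropout(l, total_layer, layer_dropout):
    return (1 - l / total_layer) * layer_dropout

conv_num, ff_depth, layer_dropout = 4, 3, 0.1      # emb encoder defaults; 0.1 is assumed
total_layer = conv_num + ff_depth + 1              # +1 for the self-attention sublayer
for l in range(1, total_layer + 1):
    print(l, round(sublayer_dropout(l, total_layer, layer_dropout), 4))
# convs: 0.0875, 0.075, 0.0625, 0.05 -- attention: 0.0375 -- ff: 0.025, 0.0125, 0.0
# Note the rate decreases with depth, so the last feed-forward sublayer gets p = 0.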

@@ -254,10 +255,7 @@ def forward(self, C, Q, cmask, qmask):
 
         A = torch.bmm(S_row_sofmax, Q)
         B = torch.bmm(torch.bmm(S_row_sofmax, S_column_softmax.transpose(1, 2)), C)
-        output = torch.cat((C, A, torch.mul(C, A), torch.mul(C, B)), dim=2)
-        output = F.dropout(output, p=config.layer_dropout, training=self.training)
-        output = self.resizer(output)
-        output = F.relu(output)
+        output = F.relu(self.resizer(torch.cat((C, A, torch.mul(C, A), torch.mul(C, B)), dim=2)))
         output = F.dropout(output, p=config.layer_dropout, training=self.training)
         output = output.transpose(1, 2)
         return output
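
The four assignments are fused into one expression, and the dropout that previously ran on the concatenated tensor before the projection is removed, leaving a single dropout after the ReLU. For orientation, the shapes along the fused line (batch b, context length Lc, hidden size d), assuming self.resizer is a linear layer mapping 4*d back to d as in the QANet context-query attention:

# C:                        [b, Lc, d]   encoded context
# A = S_row @ Q:            [b, Lc, d]   context-to-query attention
# B = S_row @ S_col^T @ C:  [b, Lc, d]   query-to-context attention
# cat((C, A, C*A, C*B), 2): [b, Lc, 4*d]
# self.resizer(...):        [b, Lc, d]   then ReLU, dropout, transpose to [b, d, Lc]
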
@@ -301,19 +299,21 @@ def __init__(self, word_mat, char_mat):
         self.char_embedding = nn.Embedding.from_pretrained(torch.tensor(char_mat), freeze=False)
         self.embedding = Embedding(word_mat.shape[1], char_mat.shape[1], config.global_hidden_size)
         emb_encoder_block = EncoderBlock(
-            conv_number=config.emb_encoder_conv_num,
+            conv_num=config.emb_encoder_conv_num,
             hidden_size=config.global_hidden_size,
             kernel_size=config.emb_encoder_conv_kernel_size,
-            head_number=config.attention_head_num
+            head_number=config.attention_head_num,
+            ff_depth=config.emb_encoder_ff_depth,
         )
         self.emb_encoder = nn.ModuleList(
             [emb_encoder_block for _ in range(config.emb_encoder_block_num)])
         self.cq_attention = CQAttention(hidden_size=config.global_hidden_size)
         output_encoder_block = EncoderBlock(
-            conv_number=config.output_encoder_conv_num,
+            conv_num=config.output_encoder_conv_num,
             hidden_size=config.global_hidden_size,
             kernel_size=config.output_encoder_conv_kernel_size,
-            head_number=config.attention_head_num
+            head_number=config.attention_head_num,
+            ff_depth=config.output_encoder_ff_depth,
         )
         self.output_encoder = nn.ModuleList(
             [output_encoder_block for _ in range(config.output_encoder_block_num)])
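
One thing worth flagging in this constructor (unchanged by the commit, but visible in the context lines): the comprehension [output_encoder_block for _ in range(config.output_encoder_block_num)] repeats the same EncoderBlock object, so all seven output-encoder slots share one set of weights, and likewise for the embedding encoder. If independent blocks are intended, as in the stacked encoder blocks of the QANet paper, each slot needs its own copy; a minimal sketch of that variant, using copy.deepcopy (not part of this commit):

import copy

# Hypothetical alternative: one independently parameterized block per slot.
self.output_encoder = nn.ModuleList(
    [copy.deepcopy(output_encoder_block) for _ in range(config.output_encoder_block_num)]
)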
