
Commit a2d5290

[FEATURE] basic QANet
1 parent 45a64ce commit a2d5290

File tree

3 files changed: +68 −63 lines

config.py

Lines changed: 4 additions & 3 deletions

@@ -12,6 +12,7 @@
 SQUAD_VERSION = 'v1.1'
 flags.DEFINE_string('squad_version', SQUAD_VERSION, '')
 flags.DEFINE_string("mode", "debug", "train/debug/test")
+flags.DEFINE_string("run_name", "", "")

 # data
 DATA_DIR = os.path.join(BASE_DIR, 'data', 'squad', SQUAD_VERSION)
@@ -47,9 +48,11 @@
 flags.DEFINE_integer('emb_encoder_conv_num', 4, "")
 flags.DEFINE_integer('emb_encoder_conv_kernel_size', 7, "")
 flags.DEFINE_integer('emb_encoder_block_num', 1, "")
+flags.DEFINE_integer('emb_encoder_ff_depth', 3, "")
 flags.DEFINE_integer('output_encoder_conv_num', 2, "")
 flags.DEFINE_integer('output_encoder_conv_kernel_size', 5, "")
 flags.DEFINE_integer('output_encoder_block_num', 7, "")
+flags.DEFINE_integer('output_encoder_ff_depth', 2, "")
 flags.DEFINE_integer('attention_head_num', 8, "")

 # train & test config
@@ -59,7 +62,7 @@
 flags.DEFINE_integer('checkpoint', 1400, "")
 flags.DEFINE_float('lr', 0.001, "")
 flags.DEFINE_integer('lr_warm_up_steps', 1000, "")
-flags.DEFINE_float('adam_beta1', 0.9, "")
+flags.DEFINE_float('adam_beta1', 0.8, "")
 flags.DEFINE_float('adam_beta2', 0.999, "")
 flags.DEFINE_float('adam_eps', 1e-7, "")
 flags.DEFINE_float('adam_decay', 5e-8, "")
@@ -81,5 +84,3 @@
     os.makedirs(DATA_DIR)
 if not os.path.exists(RESULT_DIR):
     os.makedirs(RESULT_DIR)
-if not os.path.exists(LOG_DIR):
-    os.makedirs(LOG_DIR)
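
Note: the new flags follow the same flags.DEFINE_* pattern as the rest of config.py and are read elsewhere in this commit as attributes of config (config.run_name in main.py, config.emb_encoder_ff_depth in models.py). A minimal, self-contained sketch of that pattern, assuming the absl-style flags API; the file wiring here is illustrative, not the repo's actual layout:

from absl import app, flags

# Same style of definitions as config.py (values shown are the defaults above).
flags.DEFINE_string("run_name", "", "")
flags.DEFINE_integer("emb_encoder_ff_depth", 3, "")
flags.DEFINE_integer("output_encoder_ff_depth", 2, "")

def main(_):
    config = flags.FLAGS
    # After parsing, values are plain attributes, e.g.
    #   python demo.py --run_name=qanet_ff --output_encoder_ff_depth=2
    print(config.run_name, config.emb_encoder_ff_depth, config.output_encoder_ff_depth)

if __name__ == "__main__":
    app.run(main)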

main.py

Lines changed: 5 additions & 1 deletion

@@ -1,4 +1,6 @@
 import math
+import os
+
 import numpy as np
 import torch
 import torch.cuda
@@ -192,7 +194,8 @@ def train_entry(config):


 def main(*args, **kwarg):
-    # runner = Runner(loss=nn.CrossEntropyLoss())
+    if not os.path.exists(LOG_DIR) and config.mode == 'train':
+        os.makedirs(LOG_DIR)
     if config.mode == "data":
         preproc(config)
     elif config.mode == "train":
@@ -206,6 +209,7 @@ def main(*args, **kwarg):
     else:
         print("Unknown mode")
         exit(0)
+    print(config.run_name)


 if __name__ == '__main__':
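
For reference, the guard added to main() creates LOG_DIR only in train mode; the same effect can be written with os.makedirs(..., exist_ok=True), which skips the separate existence check. A minimal sketch with illustrative stand-ins for the repo's config.mode and LOG_DIR (names and values assumed, not taken from the repo):

import os

# Stand-ins for config.mode and LOG_DIR.
mode = 'train'
LOG_DIR = os.path.join('result', 'log')

if mode == 'train':
    # No-op when the directory already exists, so no exists() check is needed.
    os.makedirs(LOG_DIR, exist_ok=True)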

models.py

Lines changed: 59 additions & 59 deletions

@@ -70,9 +70,9 @@ def forward(self, cemb: torch.Tensor, wemb: torch.Tensor):
         emb = torch.cat((cemb, wemb), dim=1)
         emb = self.highway(emb)

-        emb = F.relu(self.resizer(emb.transpose(1, 2)))
-        emb = F.dropout(emb, p=config.word_emb_dropout, training=self.training)
-        emb = self.norm(emb).transpose(1, 2)
+        emb = F.relu(self.norm(self.resizer(emb.transpose(1, 2))))
+        emb = F.dropout(emb, p=config.layer_dropout, training=self.training)
+        emb = emb.transpose(1, 2)

         return emb

@@ -141,40 +141,6 @@ def mask_logits(target, mask):
     return target + (-1e30) * (1 - mask)


-class AttentionBlock(nn.Module):
-    def __init__(self, hidden_size, head_number):
-        super(AttentionBlock, self).__init__()
-        self.self_attention = nn.MultiheadAttention(hidden_size, head_number,
-                                                    dropout=config.layer_dropout)
-        self.attention_layer_norm = nn.LayerNorm(hidden_size)
-        self.feedforward_norm1 = nn.LayerNorm(hidden_size)
-        self.feedforward1 = nn.Linear(hidden_size, hidden_size)
-        self.feedforward_norm2 = nn.LayerNorm(hidden_size)
-        self.feedforward2 = nn.Linear(hidden_size, hidden_size)
-        nn.init.kaiming_normal_(self.feedforward1.weight, nonlinearity='relu')
-        nn.init.kaiming_normal_(self.feedforward2.weight, nonlinearity='relu')
-
-    def forward(self, x, mask):
-        raw = x
-        x = x.permute(2, 0, 1)
-        x, _ = self.self_attention(x, x, x, key_padding_mask=(mask.bool() == False))
-        x = x.permute(1, 2, 0)
-        x = F.dropout(x, config.layer_dropout, training=self.training)
-        x = self.attention_layer_norm(raw.transpose(1, 2) + x.transpose(1, 2)).transpose(1, 2)
-
-        raw = x
-        x = self.feedforward1(x.transpose(1, 2)).transpose(1, 2)
-        x = F.dropout(F.relu(x), config.layer_dropout, training=self.training)
-        x = self.feedforward_norm1(raw.transpose(1, 2) + x.transpose(1, 2)).transpose(1, 2)
-
-        raw = x
-        x = self.feedforward2(x.transpose(1, 2)).transpose(1, 2)
-        x = F.dropout(F.relu(x), config.layer_dropout, training=self.training)
-        x = self.feedforward_norm2(raw.transpose(1, 2) + x.transpose(1, 2)).transpose(1, 2)
-
-        return x
-
-
 class EncoderBlock(nn.Module):
     """
     input:
@@ -184,31 +150,66 @@ class EncoderBlock(nn.Module):
     x: shape [batch_size, hidden_size, max length] => [8, 128, 400]
     """

-    def __init__(self, conv_number, hidden_size, kernel_size, head_number):
+    def __init__(self, conv_num, hidden_size, kernel_size, head_number, ff_depth):
         super(EncoderBlock, self).__init__()
-        self.conv_number = conv_number
+        self.conv_num = conv_num
+        self.ff_depth = ff_depth
+        self.total_layer = self.conv_num + self.ff_depth + 1  # one => atten
+
         self.position_encoder = PositionEncoder(hidden_size)
+
+        self.conv_norm_list = nn.ModuleList(
+            [nn.LayerNorm(hidden_size) for _ in range(self.conv_num)]
+        )
         self.conv_list = nn.ModuleList([
             DepthwiseSeparableConv(hidden_size, hidden_size, kernel_size)
-            for _ in range(self.conv_number)
+            for _ in range(self.conv_num)
         ])

-        self.conv_norm_list = nn.ModuleList(
-            [nn.LayerNorm(hidden_size) for _ in range(self.conv_number)]
+        self.atten_layer_norm = nn.LayerNorm(hidden_size)
+        self.self_atten = nn.MultiheadAttention(
+            hidden_size,
+            head_number,
+            dropout=(1 - (self.conv_num + 1) / self.total_layer) * config.layer_dropout
         )
-        self.self_attention = AttentionBlock(hidden_size, head_number)
+
+        self.ff_norm_list = nn.ModuleList(
+            [nn.LayerNorm(hidden_size) for _ in range(self.ff_depth)]
+        )
+        self.ff = nn.ModuleList(
+            [nn.Linear(hidden_size, hidden_size) for _ in range(self.ff_depth)]
+        )
+        for layer in self.ff:
+            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')

     def forward(self, x, mask):
         x = self.position_encoder(x)
-        for i in range(self.conv_number):
-            raw = x.transpose(1, 2)
-            x = self.conv_list[i](x)
-            x = F.dropout(x, config.layer_dropout * (i + 1) / self.conv_number,
-                          training=self.training)
-            x = self.conv_norm_list[i](x.transpose(1, 2) + raw).transpose(1, 2)
-
-        x = self.self_attention(x, mask)
+        for i in range(self.conv_num):
+            raw = x
+            x = self.conv_list[i](self.conv_norm_list[i](x.transpose(1, 2)).transpose(1, 2))
+            x = F.dropout(
+                input=x,
+                p=(1 - (i + 1) / self.total_layer) * config.layer_dropout,
+                training=self.training
+            )
+            x = raw + x

+        raw = x
+        x = self.atten_layer_norm(x.transpose(1, 2)).transpose(1, 2)
+        x = x.permute(2, 0, 1)
+        x, _ = self.self_atten(x, x, x, key_padding_mask=(mask.bool() == False))
+        x = x.permute(1, 2, 0)
+        x = raw + x
+
+        for i in range(self.ff_depth):
+            raw = x
+            x = F.relu(self.ff[i](self.ff_norm_list[i](x.transpose(1, 2)))).transpose(1, 2)
+            x = F.dropout(
+                input=x,
+                p=(1 - (self.conv_num + 1 + (i + 1)) / self.total_layer) * config.layer_dropout,
+                training=self.training
+            )
+            x = raw + x
         return x

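
For reference, the rewritten EncoderBlock scales each sub-layer's dropout by its position: total_layer = conv_num + ff_depth + 1 (the +1 is the attention sub-layer), and sub-layer l uses p = (1 - l / total_layer) * config.layer_dropout, so deeper sub-layers get smaller rates. A small sketch of what that schedule evaluates to, assuming a base rate of 0.1 for config.layer_dropout (its actual value is not shown in this diff) and the embedding-encoder defaults from config.py (conv_num=4, ff_depth=3):

def dropout_schedule(conv_num, ff_depth, base):
    # Mirrors the p=... expressions in EncoderBlock: convolutions first,
    # then the attention sub-layer, then the feed-forward stack.
    total = conv_num + ff_depth + 1
    convs = [(1 - (i + 1) / total) * base for i in range(conv_num)]
    atten = (1 - (conv_num + 1) / total) * base
    ffs = [(1 - (conv_num + 1 + (i + 1)) / total) * base for i in range(ff_depth)]
    return convs, atten, ffs

convs, atten, ffs = dropout_schedule(conv_num=4, ff_depth=3, base=0.1)
print([round(p, 4) for p in convs])  # [0.0875, 0.075, 0.0625, 0.05]
print(round(atten, 4))               # 0.0375
print([round(p, 4) for p in ffs])    # [0.025, 0.0125, 0.0]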

@@ -254,10 +255,7 @@ def forward(self, C, Q, cmask, qmask):

         A = torch.bmm(S_row_sofmax, Q)
         B = torch.bmm(torch.bmm(S_row_sofmax, S_column_softmax.transpose(1, 2)), C)
-        output = torch.cat((C, A, torch.mul(C, A), torch.mul(C, B)), dim=2)
-        output = F.dropout(output, p=config.layer_dropout, training=self.training)
-        output = self.resizer(output)
-        output = F.relu(output)
+        output = F.relu(self.resizer(torch.cat((C, A, torch.mul(C, A), torch.mul(C, B)), dim=2)))
         output = F.dropout(output, p=config.layer_dropout, training=self.training)
         output = output.transpose(1, 2)
         return output
@@ -301,19 +299,21 @@ def __init__(self, word_mat, char_mat):
         self.char_embedding = nn.Embedding.from_pretrained(torch.tensor(char_mat), freeze=False)
         self.embedding = Embedding(word_mat.shape[1], char_mat.shape[1], config.global_hidden_size)
         emb_encoder_block = EncoderBlock(
-            conv_number=config.emb_encoder_conv_num,
+            conv_num=config.emb_encoder_conv_num,
             hidden_size=config.global_hidden_size,
             kernel_size=config.emb_encoder_conv_kernel_size,
-            head_number=config.attention_head_num
+            head_number=config.attention_head_num,
+            ff_depth=config.emb_encoder_ff_depth,
         )
         self.emb_encoder = nn.ModuleList(
             [emb_encoder_block for _ in range(config.emb_encoder_block_num)])
         self.cq_attention = CQAttention(hidden_size=config.global_hidden_size)
         output_encoder_block = EncoderBlock(
-            conv_number=config.output_encoder_conv_num,
+            conv_num=config.output_encoder_conv_num,
             hidden_size=config.global_hidden_size,
             kernel_size=config.output_encoder_conv_kernel_size,
-            head_number=config.attention_head_num
+            head_number=config.attention_head_num,
+            ff_depth=config.output_encoder_ff_depth,
         )
         self.output_encoder = nn.ModuleList(
             [output_encoder_block for _ in range(config.output_encoder_block_num)])
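
One more reference sketch, for the attention call that EncoderBlock.forward now makes directly: nn.MultiheadAttention (without batch_first) expects [length, batch, hidden] input, and key_padding_mask uses True for positions to ignore, which is why the code passes (mask.bool() == False). The shapes below follow the docstring's example ([batch, hidden, length] = [8, 128, 400]); the tensors are dummies, not repo data:

import torch
import torch.nn as nn

batch, hidden, length, heads = 8, 128, 400, 8
x = torch.randn(batch, hidden, length)
mask = torch.ones(batch, length)   # 1 = real token, 0 = padding
mask[:, 300:] = 0                  # pretend the last 100 positions are padding

attn = nn.MultiheadAttention(hidden, heads)
q = x.permute(2, 0, 1)             # [batch, hidden, length] -> [length, batch, hidden]
out, _ = attn(q, q, q, key_padding_mask=(mask.bool() == False))
out = out.permute(1, 2, 0)         # back to [batch, hidden, length]
print(out.shape)                   # torch.Size([8, 128, 400])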
