Commit 4ba7ec4

Merge pull request #146 from wintermelon008/dev
[BREAKING] Rewrite the quesnet code for pretrain
2 parents: be1aef1 + 06461fd

File tree

13 files changed: +320 additions, −291 deletions


EduNLP/ModelZoo/disenqnet/__init__.py

Lines changed: 1 addition & 2 deletions

@@ -1,4 +1,3 @@
 # -*- coding: utf-8 -*-
 
-from .disenqnet import DisenQNet, DisenQNetForPreTraining, \
-    DisenQNetForPropertyPrediction, DisenQNetForKnowledgePrediction
+from .disenqnet import *
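The wildcard keeps the previously explicit names importable, assuming disenqnet.py still defines them as public names (or lists them in __all__); a quick hedged check:

    # Assumes disenqnet.py still exposes these classes publicly (or via
    # __all__); old import sites should then keep working unchanged.
    from EduNLP.ModelZoo.disenqnet import (
        DisenQNet,
        DisenQNetForPreTraining,
        DisenQNetForPropertyPrediction,
        DisenQNetForKnowledgePrediction,
    )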

EduNLP/ModelZoo/quesnet/quesnet.py

Lines changed: 34 additions & 26 deletions

@@ -114,6 +114,10 @@ def make_batch(self, data, device, pretrain=False):
         ans_input = []
         ans_output = []
         false_options = [[] for i in range(3)]
+
+        if not isinstance(data, list):
+            data = [data]
+
         for q in data:
             meta = torch.zeros(len(self.stoi[self.meta])).to(device)
             meta[q.labels.get(self.meta) or []] = 1
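With this guard, make_batch accepts either a single question or a list of questions; a hedged calling sketch, where model and q are hypothetical stand-ins rather than objects from this diff:

    # Hedged sketch: both calls should now produce the same batch, because
    # make_batch wraps a lone question into a one-element list itself.
    device = "cpu"
    single = model.make_batch(q, device, pretrain=True)
    listed = model.make_batch([q], device, pretrain=True)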
@@ -192,6 +196,7 @@ def make_batch(self, data, device, pretrain=False):
         words = torch.cat(words, dim=0) if words else None
         ims = torch.cat(ims, dim=0) if ims else None
         metas = torch.cat(metas, dim=0) if metas else None
+
         if pretrain:
             return (
                 lembs, rembs, words, ims, metas, wmask, imask, mmask,
@@ -302,67 +307,70 @@ def __init__(self, _stoi=None, pretrained_embs: np.ndarray = None, pretrained_im
         self.config = PretrainedConfig.from_dict(self.config)
 
     def forward(self, batch):
-        left, right, words, ims, metas, wmask, imask, mmask, inputs, ans_input, ans_output, false_opt_input = batch
+        left, right, words, ims, metas, wmask, imask, mmask, inputs, ans_input, ans_output, false_opt_input = batch[0]
 
         # high-level loss
         outputs = self.quesnet(inputs)
         embeded = outputs.embeded
         h = outputs.hidden
 
         x = ans_input.packed()
-        y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x.data), x.batch_sizes),
+
+        y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x[0].data), x.batch_sizes),
                                h.repeat(self.config.layers, 1, 1))
         floss = F.cross_entropy(self.ans_output(y.data),
                                 ans_output.packed().data)
         floss = floss + F.binary_cross_entropy_with_logits(self.ans_judge(y.data),
                                                            torch.ones_like(self.ans_judge(y.data)))
         for false_opt in false_opt_input:
             x = false_opt.packed()
-            y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x.data), x.batch_sizes),
+            if x == (None, None):
+                continue
+            y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x[0].data), x.batch_sizes),
                                    h.repeat(self.config.layers, 1, 1))
             floss = floss + F.binary_cross_entropy_with_logits(self.ans_judge(y.data),
                                                                torch.zeros_like(self.ans_judge(y.data)))
         loss = floss * self.lambda_loss[1]
         # low-level loss
-        left_hid = self.quesnet(left).pack_embeded.data[:, :self.rnn_size]
-        right_hid = self.quesnet(right).pack_embeded.data[:, self.rnn_size:]
+        left_hid = self.quesnet(left).pack_embeded.data[:, :self.rnn_size].clone()
+        right_hid = self.quesnet(right).pack_embeded.data[:, self.rnn_size:].clone()
 
         wloss = iloss = mloss = None
 
         if words is not None:
-            lwfea = torch.masked_select(left_hid, wmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            lout = self.lwoutput(lwfea)
-            rwfea = torch.masked_select(right_hid, wmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            rout = self.rwoutput(rwfea)
-            out = self.woutput(torch.cat([lwfea, rwfea], dim=1))
+            lwfea = torch.masked_select(left_hid.clone(), wmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            lout = self.lwoutput(lwfea.clone())
+            rwfea = torch.masked_select(right_hid.clone(), wmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            rout = self.rwoutput(rwfea.clone())
+            out = self.woutput(torch.cat([lwfea.clone(), rwfea.clone()], dim=1).clone())
             wloss = (F.cross_entropy(out, words) + F.cross_entropy(lout, words) + F.
                      cross_entropy(rout, words)) * self.quesnet.lambda_input[0] / 3
             wloss *= self.lambda_loss[0]
             loss = loss + wloss
 
         if ims is not None:
-            lifea = torch.masked_select(left_hid, imask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            lout = self.lioutput(lifea)
-            rifea = torch.masked_select(right_hid, imask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            rout = self.rioutput(rifea)
-            out = self.ioutput(torch.cat([lifea, rifea], dim=1))
+            lifea = torch.masked_select(left_hid.clone(), imask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            lout = self.lioutput(lifea.clone())
+            rifea = torch.masked_select(right_hid.clone(), imask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            rout = self.rioutput(rifea.clone())
+            out = self.ioutput(torch.cat([lifea.clone(), rifea.clone()], dim=1).clone())
             iloss = (self.quesnet.ie.loss(ims, out) + self.quesnet.ie.loss(ims, lout) + self.quesnet.ie.
                      loss(ims, rout)) * self.quesnet.lambda_input[1] / 3
             iloss *= self.lambda_loss[0]
             loss = loss + iloss
 
         if metas is not None:
-            lmfea = torch.masked_select(left_hid, mmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            lout = self.lmoutput(lmfea)
-            rmfea = torch.masked_select(right_hid, mmask.unsqueeze(1).bool()) \
-                .view(-1, self.rnn_size)
-            rout = self.rmoutput(rmfea)
-            out = self.moutput(torch.cat([lmfea, rmfea], dim=1))
+            lmfea = torch.masked_select(left_hid.clone(), mmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            lout = self.lmoutput(lmfea.clone())
+            rmfea = torch.masked_select(right_hid.clone(), mmask.unsqueeze(1).bool()) \
+                .view(-1, self.rnn_size).clone()
+            rout = self.rmoutput(rmfea.clone())
+            out = self.moutput(torch.cat([lmfea.clone(), rmfea.clone()], dim=1).clone())
             mloss = (self.quesnet.me.loss(metas, out) + self.quesnet.me.loss(metas, lout) + self.quesnet.me.
                      loss(metas, rout)) * self.quesnet.lambda_input[2] / 3
             mloss *= self.lambda_loss[0]
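Two patterns recur in this rewrite. First, SeqBatch.packed() can now return a (None, None) sentinel (see util.py below), so forward indexes the real PackedSequence as x[0] and skips empty false options. Second, the added .clone() calls copy each masked selection into fresh storage, presumably to avoid autograd in-place errors on views of the shared pack_embeded buffer. A minimal sketch of the masked-select-and-clone pattern, with illustrative shapes not taken from this diff:

    import torch

    # Illustrative shapes only: 10 positions, hidden size 4.
    rnn_size = 4
    left_hid = torch.randn(10, rnn_size, requires_grad=True)
    wmask = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0, 1, 1])

    # masked_select broadcasts the (10, 1) bool mask over the hidden dim and
    # returns a flat tensor; view() restores the rows, and clone() copies the
    # result into new storage (still on the autograd graph) so later in-place
    # ops cannot alias the buffer left_hid shares with other consumers.
    lwfea = torch.masked_select(left_hid, wmask.unsqueeze(1).bool()) \
        .view(-1, rnn_size).clone()
    print(lwfea.shape)  # torch.Size([6, 4]) -- one row per masked position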

EduNLP/ModelZoo/quesnet/util.py

Lines changed: 12 additions & 1 deletion

@@ -11,14 +11,19 @@ def __init__(self, seqs, dtype=None, device=None):
         self.dtype = dtype
         self.device = device
         self.seqs = seqs
-        self.lens = [len(x) for x in seqs]
+
+        if not seqs:
+            self.lens = [0]
+        else:
+            self.lens = [len(x) for x in seqs]
 
         self.ind = argsort(self.lens)[::-1]
         self.inv = argsort(self.ind)
         self.lens.sort(reverse=True)
         self._prefix = [0]
         self._index = {}
         c = 0
+
         for i in range(self.lens[0]):
             for j in range(len(self.lens)):
                 if self.lens[j] <= i:

@@ -28,10 +33,16 @@ def __init__(self, seqs, dtype=None, device=None):
 
     def packed(self):
         ind = torch.tensor(self.ind, dtype=torch.long, device=self.device)
+        if not ind.numel() or ind.max() >= self.padded()[0].size(1):
+            return None, None
         padded = self.padded()[0].index_select(1, ind)
         return pack_padded_sequence(padded, torch.tensor(self.lens))
 
     def padded(self, max_len=None, batch_first=False):
+        if not self.seqs:
+            return torch.empty((0, 0), dtype=self.dtype, device=self.device), \
+                torch.empty((0, 0), dtype=torch.bool, device=self.device)
+
         seqs = [torch.tensor(s, dtype=self.dtype, device=self.device)
                 if not isinstance(s, torch.Tensor) else s
                 for s in self.seqs]
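Together these guards make an empty batch a representable state instead of a crash; a hedged usage sketch, assuming the class in this file is the SeqBatch consumed by quesnet.py (the class name is not visible in the hunks above):

    from EduNLP.ModelZoo.quesnet.util import SeqBatch  # name assumed, see above

    empty = SeqBatch([])          # should now construct: lens falls back to [0]
    data, mask = empty.padded()   # two empty (0, 0) tensors
    print(data.shape, mask.shape)

    # packed() sees ind.max() >= the empty padded width and returns the
    # (None, None) sentinel that QuesNet's forward() skips over.
    print(empty.packed())         # (None, None)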
