@@ -114,6 +114,10 @@ def make_batch(self, data, device, pretrain=False):
114
114
ans_input = []
115
115
ans_output = []
116
116
false_options = [[] for i in range (3 )]
117
+
118
+ if not isinstance (data , list ):
119
+ data = [data ]
120
+
117
121
for q in data :
118
122
meta = torch .zeros (len (self .stoi [self .meta ])).to (device )
119
123
meta [q .labels .get (self .meta ) or []] = 1
@@ -192,6 +196,7 @@ def make_batch(self, data, device, pretrain=False):
192
196
words = torch .cat (words , dim = 0 ) if words else None
193
197
ims = torch .cat (ims , dim = 0 ) if ims else None
194
198
metas = torch .cat (metas , dim = 0 ) if metas else None
199
+
195
200
if pretrain :
196
201
return (
197
202
lembs , rembs , words , ims , metas , wmask , imask , mmask ,
@@ -302,67 +307,70 @@ def __init__(self, _stoi=None, pretrained_embs: np.ndarray = None, pretrained_im
302
307
self .config = PretrainedConfig .from_dict (self .config )
303
308
304
309
def forward (self , batch ):
305
- left , right , words , ims , metas , wmask , imask , mmask , inputs , ans_input , ans_output , false_opt_input = batch
310
+ left , right , words , ims , metas , wmask , imask , mmask , inputs , ans_input , ans_output , false_opt_input = batch [ 0 ]
306
311
307
312
# high-level loss
308
313
outputs = self .quesnet (inputs )
309
314
embeded = outputs .embeded
310
315
h = outputs .hidden
311
316
312
317
x = ans_input .packed ()
313
- y , _ = self .ans_decode (PackedSequence (self .quesnet .we (x .data ), x .batch_sizes ),
318
+
319
+ y , _ = self .ans_decode (PackedSequence (self .quesnet .we (x [0 ].data ), x .batch_sizes ),
314
320
h .repeat (self .config .layers , 1 , 1 ))
315
321
floss = F .cross_entropy (self .ans_output (y .data ),
316
322
ans_output .packed ().data )
317
323
floss = floss + F .binary_cross_entropy_with_logits (self .ans_judge (y .data ),
318
324
torch .ones_like (self .ans_judge (y .data )))
319
325
for false_opt in false_opt_input :
320
326
x = false_opt .packed ()
321
- y , _ = self .ans_decode (PackedSequence (self .quesnet .we (x .data ), x .batch_sizes ),
327
+ if x == (None , None ):
328
+ continue
329
+ y , _ = self .ans_decode (PackedSequence (self .quesnet .we (x [0 ].data ), x .batch_sizes ),
322
330
h .repeat (self .config .layers , 1 , 1 ))
323
331
floss = floss + F .binary_cross_entropy_with_logits (self .ans_judge (y .data ),
324
332
torch .zeros_like (self .ans_judge (y .data )))
325
333
loss = floss * self .lambda_loss [1 ]
326
334
# low-level loss
327
- left_hid = self .quesnet (left ).pack_embeded .data [:, :self .rnn_size ]
328
- right_hid = self .quesnet (right ).pack_embeded .data [:, self .rnn_size :]
335
+ left_hid = self .quesnet (left ).pack_embeded .data [:, :self .rnn_size ]. clone ()
336
+ right_hid = self .quesnet (right ).pack_embeded .data [:, self .rnn_size :]. clone ()
329
337
330
338
wloss = iloss = mloss = None
331
339
332
340
if words is not None :
333
- lwfea = torch .masked_select (left_hid , wmask .unsqueeze (1 ).bool ()) \
334
- .view (- 1 , self .rnn_size )
335
- lout = self .lwoutput (lwfea )
336
- rwfea = torch .masked_select (right_hid , wmask .unsqueeze (1 ).bool ()) \
337
- .view (- 1 , self .rnn_size )
338
- rout = self .rwoutput (rwfea )
339
- out = self .woutput (torch .cat ([lwfea , rwfea ], dim = 1 ))
341
+ lwfea = torch .masked_select (left_hid . clone () , wmask .unsqueeze (1 ).bool ()) \
342
+ .view (- 1 , self .rnn_size ). clone ()
343
+ lout = self .lwoutput (lwfea . clone () )
344
+ rwfea = torch .masked_select (right_hid . clone () , wmask .unsqueeze (1 ).bool ()) \
345
+ .view (- 1 , self .rnn_size ). clone ()
346
+ rout = self .rwoutput (rwfea . clone () )
347
+ out = self .woutput (torch .cat ([lwfea . clone () , rwfea . clone () ], dim = 1 ). clone ( ))
340
348
wloss = (F .cross_entropy (out , words ) + F .cross_entropy (lout , words ) + F .
341
349
cross_entropy (rout , words )) * self .quesnet .lambda_input [0 ] / 3
342
350
wloss *= self .lambda_loss [0 ]
343
351
loss = loss + wloss
344
352
345
353
if ims is not None :
346
- lifea = torch .masked_select (left_hid , imask .unsqueeze (1 ).bool ()) \
347
- .view (- 1 , self .rnn_size )
348
- lout = self .lioutput (lifea )
349
- rifea = torch .masked_select (right_hid , imask .unsqueeze (1 ).bool ()) \
350
- .view (- 1 , self .rnn_size )
351
- rout = self .rioutput (rifea )
352
- out = self .ioutput (torch .cat ([lifea , rifea ], dim = 1 ))
354
+ lifea = torch .masked_select (left_hid . clone () , imask .unsqueeze (1 ).bool ()) \
355
+ .view (- 1 , self .rnn_size ). clone ()
356
+ lout = self .lioutput (lifea . clone () )
357
+ rifea = torch .masked_select (right_hid . clone () , imask .unsqueeze (1 ).bool ()) \
358
+ .view (- 1 , self .rnn_size ). clone ()
359
+ rout = self .rioutput (rifea . clone () )
360
+ out = self .ioutput (torch .cat ([lifea . clone () , rifea . clone () ], dim = 1 ). clone ( ))
353
361
iloss = (self .quesnet .ie .loss (ims , out ) + self .quesnet .ie .loss (ims , lout ) + self .quesnet .ie .
354
362
loss (ims , rout )) * self .quesnet .lambda_input [1 ] / 3
355
363
iloss *= self .lambda_loss [0 ]
356
364
loss = loss + iloss
357
365
358
366
if metas is not None :
359
- lmfea = torch .masked_select (left_hid , mmask .unsqueeze (1 ).bool ()) \
360
- .view (- 1 , self .rnn_size )
361
- lout = self .lmoutput (lmfea )
362
- rmfea = torch .masked_select (right_hid , mmask .unsqueeze (1 ).bool ()) \
363
- .view (- 1 , self .rnn_size )
364
- rout = self .rmoutput (rmfea )
365
- out = self .moutput (torch .cat ([lmfea , rmfea ], dim = 1 ))
367
+ lmfea = torch .masked_select (left_hid . clone () , mmask .unsqueeze (1 ).bool ()) \
368
+ .view (- 1 , self .rnn_size ). clone ()
369
+ lout = self .lmoutput (lmfea . clone () )
370
+ rmfea = torch .masked_select (right_hid . clone () , mmask .unsqueeze (1 ).bool ()) \
371
+ .view (- 1 , self .rnn_size ). clone ()
372
+ rout = self .rmoutput (rmfea . clone () )
373
+ out = self .moutput (torch .cat ([lmfea . clone () , rmfea . clone () ], dim = 1 ). clone ( ))
366
374
mloss = (self .quesnet .me .loss (metas , out ) + self .quesnet .me .loss (metas , lout ) + self .quesnet .me .
367
375
loss (metas , rout )) * self .quesnet .lambda_input [2 ] / 3
368
376
mloss *= self .lambda_loss [0 ]
0 commit comments