@@ -10,19 +10,22 @@ def l2_norm(v):
10
10
return v
11
11
12
12
class InstanceLoss(nn.Module):
    """Instance loss: treats every sample (or every unique label) as its own class
    and applies a softmax cross-entropy over the pairwise similarity matrix.

    Dual-Path Convolutional Image-Text Embeddings with Instance Loss, ACM TOMM 2020
    https://arxiv.org/abs/1711.05535
    """

    def __init__(self, gamma=1) -> None:
        """
        Args:
            gamma: temperature-like scale applied to the similarity logits
                   before cross-entropy (default 1, i.e. no scaling).
        """
        super().__init__()
        self.gamma = gamma

    def forward(self, feature, label=None) -> Tensor:
        """Compute the instance loss for a batch of features.

        Args:
            feature: 2-D batch of embeddings; rows are L2-normalized by
                     ``l2_norm`` before computing similarities.
                     (assumes shape (batch, dim) — TODO confirm against caller)
            label: optional per-sample labels. When ``None``, every sample is
                   its own class (identity targets); otherwise labels are
                   remapped to contiguous class indices via ``torch.unique``.

        Returns:
            Scalar cross-entropy loss over the scaled similarity matrix.
        """
        normed_feature = l2_norm(feature)
        # Scaled cosine-similarity logits: (batch, batch).
        sim1 = torch.mm(normed_feature * self.gamma, torch.t(normed_feature))
        if label is None:
            # Each sample is its own class: target i for row i.
            # Use the feature's device instead of hard-coding .cuda() so the
            # loss also works on CPU (or any non-default device).
            sim_label = torch.arange(sim1.size(0), device=feature.device)
        else:
            # Remap arbitrary label values to contiguous indices 0..K-1,
            # as required by F.cross_entropy targets.
            _, sim_label = torch.unique(label, return_inverse=True)
        loss = F.cross_entropy(sim1, sim_label)
        return loss
27
30
28
31
0 commit comments