Skip to content

Commit 5121808

Browse files
author
zhedong
committed
update instance with gamma
1 parent 2d8aa0f commit 5121808

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

.instance_loss.py.swp

-12 KB
Binary file not shown.

instance_loss.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,22 @@ def l2_norm(v):
1010
return v
1111

1212
class InstanceLoss(nn.Module):
13-
def __init__(self) -> None:
13+
def __init__(self, gamma = 1) -> None:
1414
super(InstanceLoss, self).__init__()
15+
self.gamma = gamma
1516

16-
def forward(self, feature, label) -> Tensor:
17+
def forward(self, feature, label = None) -> Tensor:
1718
# Dual-Path Convolutional Image-Text Embeddings with Instance Loss, ACM TOMM 2020
1819
# https://arxiv.org/abs/1711.05535
19-
# using cross-entropy loss for every class
20+
# using cross-entropy loss for every sample if label is not available. else use given label.
2021
normed_feature = l2_norm(feature)
21-
sim1 = torch.mm(normed_feature, torch.t(normed_feature))
22-
sim2 = sim1.t()
23-
#sim_label = torch.arange(sim1.size(0)).cuda().detach()
24-
_, sim_label = torch.unique(label, return_inverse=True)
25-
loss = F.cross_entropy(sim1, sim_label) + F.cross_entropy(sim2, sim_label)
22+
sim1 = torch.mm(normed_feature*self.gamma, torch.t(normed_feature))
23+
#sim2 = sim1.t()
24+
if label is None:
25+
sim_label = torch.arange(sim1.size(0)).cuda().detach()
26+
else:
27+
_, sim_label = torch.unique(label, return_inverse=True)
28+
loss = F.cross_entropy(sim1, sim_label) #+ F.cross_entropy(sim2, sim_label)
2629
return loss
2730

2831

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
216216
if opt.contrast:
217217
criterion_contrast = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
218218
if opt.instance:
219-
criterion_instance = InstanceLoss()
219+
criterion_instance = InstanceLoss(gamma=1)
220220
if opt.sphere:
221221
criterion_sphere = losses.SphereFaceLoss(num_classes=opt.nclasses, embedding_size=512, margin=4)
222222
for epoch in range(num_epochs):

0 commit comments

Comments
 (0)