Skip to content

Commit 2d8aa0f

Browse files
author
zhedong
committed
add instance loss
1 parent 4ebfb0b commit 2d8aa0f

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

.instance_loss.py.swp

12 KB
Binary file not shown.

instance_loss.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Tuple
2+
3+
import torch
4+
from torch import nn, Tensor
5+
import torch.nn.functional as F
6+
7+
def l2_norm(v):
    """Scale every row of ``v`` to (approximately) unit L2 length.

    The norm is taken along ``dim=1`` and a small constant (1e-6) is added
    to it before dividing, so an all-zero row yields zeros instead of NaNs.
    """
    row_norms = v.norm(p=2, dim=1, keepdim=True)
    # keepdim=True leaves a size-1 axis, so the division broadcasts per row.
    return v / (row_norms + 1e-6)
11+
12+
class InstanceLoss(nn.Module):
    """Cross-entropy loss over the in-batch cosine-similarity matrix.

    Reference: "Dual-Path Convolutional Image-Text Embeddings with Instance
    Loss", ACM TOMM 2020, https://arxiv.org/abs/1711.05535
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, feature, label=None) -> Tensor:
        """Compute the loss for one batch of embeddings.

        Args:
            feature: 2-D tensor with one embedding per row; rows are
                L2-normalised here before similarities are computed.
            label: Optional tensor with one label per row of ``feature``.
                Label values are remapped to contiguous indices with
                ``torch.unique``. When omitted, each sample is treated as
                its own class (row ``i`` gets target ``i``).

        Returns:
            Scalar tensor: the sum of the cross-entropy over the similarity
            matrix and over its transpose.
        """
        normed_feature = l2_norm(feature)
        # Pairwise cosine similarities between all samples in the batch.
        sim1 = torch.mm(normed_feature, normed_feature.t())
        sim2 = sim1.t()
        if label is None:
            # Build the targets on the same device as the logits so this
            # works on CPU as well as on GPU.
            sim_label = torch.arange(sim1.size(0), device=sim1.device)
        else:
            # Map arbitrary label values to indices 0..K-1 (K = number of
            # distinct labels in the batch, so K <= batch size).
            _, sim_label = torch.unique(label, return_inverse=True)
        # NOTE(review): the targets are class indices while the columns of
        # the logits are batch samples, so target k selects column k.
        # Confirm this pairing is the intended formulation.
        return F.cross_entropy(sim1, sim_label) + F.cross_entropy(sim2, sim_label)
27+
28+
29+
if __name__ == "__main__":
    # Smoke test: random unit-length embeddings with random labels drawn
    # from 10 identities, pushed through the loss and printed.
    batch_size, feat_dim, num_ids = 256, 64, 10
    raw = torch.rand(batch_size, feat_dim, requires_grad=True)
    embeddings = nn.functional.normalize(raw)
    identities = torch.randint(high=num_ids, size=(batch_size,))

    print(InstanceLoss()(embeddings, identities))

train.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import yaml
2323
from shutil import copyfile
2424
from circle_loss import CircleLoss, convert_label_to_similarity
25+
from instance_loss import InstanceLoss
2526

2627
version = torch.__version__
2728
#fp16
@@ -61,6 +62,7 @@
6162
parser.add_argument('--circle', action='store_true', help='use Circle loss' )
6263
parser.add_argument('--cosface', action='store_true', help='use CosFace loss' )
6364
parser.add_argument('--contrast', action='store_true', help='use contrast loss' )
65+
parser.add_argument('--instance', action='store_true', help='use instance loss' )
6466
parser.add_argument('--triplet', action='store_true', help='use triplet loss' )
6567
parser.add_argument('--lifted', action='store_true', help='use lifted loss' )
6668
parser.add_argument('--sphere', action='store_true', help='use sphere loss' )
@@ -213,6 +215,8 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
213215
criterion_lifted = losses.GeneralizedLiftedStructureLoss(neg_margin=1, pos_margin=0)
214216
if opt.contrast:
215217
criterion_contrast = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
218+
if opt.instance:
219+
criterion_instance = InstanceLoss()
216220
if opt.sphere:
217221
criterion_sphere = losses.SphereFaceLoss(num_classes=opt.nclasses, embedding_size=512, margin=4)
218222
for epoch in range(num_epochs):
@@ -258,7 +262,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
258262

259263
sm = nn.Softmax(dim=1)
260264
log_sm = nn.LogSoftmax(dim=1)
261-
return_feature = opt.arcface or opt.cosface or opt.circle or opt.triplet or opt.contrast or opt.lifted or opt.sphere
265+
return_feature = opt.arcface or opt.cosface or opt.circle or opt.triplet or opt.contrast or opt.instance or opt.lifted or opt.sphere
262266
if return_feature:
263267
logits, ff = outputs
264268
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
@@ -278,6 +282,8 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
278282
loss += criterion_lifted(ff, labels) #/now_batch_size
279283
if opt.contrast:
280284
loss += criterion_contrast(ff, labels) #/now_batch_size
285+
if opt.instance:
286+
loss += criterion_instance(ff, labels)
281287
if opt.sphere:
282288
loss += criterion_sphere(ff, labels)/now_batch_size
283289
elif opt.PCB: # PCB
@@ -421,7 +427,7 @@ def save_network(network, epoch_label):
421427
# Load a pretrained model and reset final fully connected layer.
422428
#
423429

424-
return_feature = opt.arcface or opt.cosface or opt.circle or opt.triplet or opt.contrast or opt.lifted or opt.sphere
430+
return_feature = opt.arcface or opt.cosface or opt.circle or opt.triplet or opt.contrast or opt.instance or opt.lifted or opt.sphere
425431

426432
if opt.use_dense:
427433
model = ft_net_dense(len(class_names), opt.droprate, circle = return_feature)

0 commit comments

Comments
 (0)