Skip to content

Commit bfefb99

Browse files
committed
Add ADE20k and PascalVOC modes in pspnet.py
1 parent a11e766 commit bfefb99

File tree

2 files changed

+48
-20
lines changed

2 files changed

+48
-20
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ __pycache__/
55

66
# Torch Models
77
*.pkl
8+
*.pth
89
current_train.py
9-
video_test.py
10+
video_test*.py
1011

1112
# C extensions
1213
*.so

ptsemseg/models/pspnet.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,33 @@
33
import torch.nn as nn
44

55
from math import ceil
6+
from torch.autograd import Variable
67

78
from ptsemseg import caffe_pb2
89
from ptsemseg.models.utils import *
910

11+
pspnet_specs = {
12+
'pascalvoc':
13+
{
14+
'n_classes': 21,
15+
'input_size': (473, 473),
16+
'block_config': [3, 4, 23, 3],
17+
},
18+
19+
'cityscapes':
20+
{
21+
'n_classes': 19,
22+
'input_size': (713, 713),
23+
'block_config': [3, 4, 23, 3],
24+
},
25+
26+
'ade20k':
27+
{
28+
'n_classes': 150,
29+
'input_size': (473, 473),
30+
'block_config': [3, 4, 6, 3],
31+
},
32+
}
1033

1134
class pspnet(nn.Module):
1235

@@ -23,12 +46,18 @@ class pspnet(nn.Module):
2346
2447
"""
2548

26-
def __init__(self, n_classes=21, block_config=[3, 4, 23, 3]):
49+
def __init__(self,
50+
n_classes=21,
51+
block_config=[3, 4, 23, 3],
52+
input_size=(473,473),
53+
version=None):
54+
2755
super(pspnet, self).__init__()
2856

29-
self.block_config = block_config
30-
self.n_classes = n_classes
31-
57+
self.block_config = pspnet_specs[version]['block_config'] if version is not None else block_config
58+
self.n_classes = pspnet_specs[version]['n_classes'] if version is not None else n_classes
59+
self.input_size = pspnet_specs[version]['input_size'] if version is not None else input_size
60+
3261
# Encoder
3362
self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=64,
3463
padding=1, stride=2, bias=False)
@@ -51,7 +80,7 @@ def __init__(self, n_classes=21, block_config=[3, 4, 23, 3]):
5180
# Final conv layers
5281
self.cbr_final = conv2DBatchNormRelu(4096, 512, 3, 1, 1, False)
5382
self.dropout = nn.Dropout2d(p=0.1, inplace=True)
54-
self.classification = nn.Conv2d(512, n_classes, 1, 1, 0)
83+
self.classification = nn.Conv2d(512, self.n_classes, 1, 1, 0)
5584

5685
def forward(self, x):
5786
inp_shape = x.shape[2:]
@@ -224,7 +253,7 @@ def _transfer_residual(prefix, block):
224253
_transfer_residual(k, v)
225254

226255

227-
def tile_predict(self, img, side=713, n_classes=19):
256+
def tile_predict(self, img):
228257
"""
229258
Predict by takin overlapping tiles from the image.
230259
@@ -236,6 +265,8 @@ def tile_predict(self, img, side=713, n_classes=19):
236265
:param n_classes: int with number of classes in seg output.
237266
"""
238267

268+
side = self.input_size[0]
269+
n_classes = self.n_classes
239270
h, w = img.shape[1:]
240271
n = int(max(h,w) / float(side) + 1)
241272
stride_x = ( h - side ) / float(n)
@@ -280,16 +311,16 @@ def tile_predict(self, img, side=713, n_classes=19):
280311
if __name__ == '__main__':
281312
cd = 0
282313
from torch.autograd import Variable
314+
import matplotlib.pyplot as plt
283315
import scipy.misc as m
284316
from ptsemseg.loader.cityscapes_loader import cityscapesLoader as cl
285-
psp = pspnet(n_classes=19)
317+
psp = pspnet(version='ade20k')
286318

287319
# Just need to do this one time
288320
#psp.load_pretrained_model(model_path='/home/meet/models/pspnet101_cityscapes.caffemodel')
321+
psp.load_pretrained_model(model_path='/home/meet/models/pspnet50_ADE20K.caffemodel')
289322

290-
torch.save(psp.state_dict(), "psp.pth")
291-
psp.load_state_dict(torch.load('psp.pth'))
292-
323+
# psp.load_state_dict(torch.load('psp.pth'))
293324

294325
psp.float()
295326
psp.cuda(cd)
@@ -303,17 +334,13 @@ def tile_predict(self, img, side=713, n_classes=19):
303334
img = img.astype(np.float64)
304335
img -= np.array([123.68, 116.779, 103.939])[:, None, None]
305336
img = np.copy(img[::-1, :, :])
306-
flp = np.copy(img[:, :, ::-1])
307-
308-
309-
# Warmup model
310-
#warmup = Variable(torch.unsqueeze(torch.from_numpy(flp).float(), 0).cuda(cd))
311-
#for i in range(5):
312-
# _ = psp(warmup[:,:,300:300+713,300:300+713])
337+
flp = np.copy(img[:, :, ::-1])
313338

314339
out = psp.tile_predict(img)
315340
pred = np.argmax(out, axis=0)
316-
decoded = dst.decode_segmap(pred)
317-
m.imsave('frankfurt_tiled.png', decoded)
341+
#decoded = dst.decode_segmap(pred)
342+
# m.imsave('ade20k_sttutgart_tiled.png', decoded)
343+
m.imsave('ade20k_sttutgart_tiled.png', pred)
318344

345+
torch.save(psp.state_dict(), "psp_ade20k.pth")
319346
print("Output Shape {} \t Input Shape {}".format(out.shape, img.shape))

0 commit comments

Comments
 (0)