import torch.nn as nn

from math import ceil
+ from torch.autograd import Variable

from ptsemseg import caffe_pb2
from ptsemseg.models.utils import *

+ pspnet_specs = {
+     'pascalvoc':
+     {
+         'n_classes': 21,
+         'input_size': (473, 473),
+         'block_config': [3, 4, 23, 3],
+     },
+
+     'cityscapes':
+     {
+         'n_classes': 19,
+         'input_size': (713, 713),
+         'block_config': [3, 4, 23, 3],
+     },
+
+     'ade20k':
+     {
+         'n_classes': 150,
+         'input_size': (473, 473),
+         'block_config': [3, 4, 6, 3],
+     },
+ }

class pspnet(nn.Module):
@@ -23,12 +46,18 @@ class pspnet(nn.Module):
    """

-     def __init__(self, n_classes=21, block_config=[3, 4, 23, 3]):
+     def __init__(self,
+                  n_classes=21,
+                  block_config=[3, 4, 23, 3],
+                  input_size=(473, 473),
+                  version=None):
+
        super(pspnet, self).__init__()

-         self.block_config = block_config
-         self.n_classes = n_classes
-
+         self.block_config = pspnet_specs[version]['block_config'] if version is not None else block_config
+         self.n_classes = pspnet_specs[version]['n_classes'] if version is not None else n_classes
+         self.input_size = pspnet_specs[version]['input_size'] if version is not None else input_size
+
        # Encoder
        self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=64,
                                                 padding=1, stride=2, bias=False)
@@ -51,7 +80,7 @@ def __init__(self, n_classes=21, block_config=[3, 4, 23, 3]):
        # Final conv layers
        self.cbr_final = conv2DBatchNormRelu(4096, 512, 3, 1, 1, False)
        self.dropout = nn.Dropout2d(p=0.1, inplace=True)
-         self.classification = nn.Conv2d(512, n_classes, 1, 1, 0)
+         self.classification = nn.Conv2d(512, self.n_classes, 1, 1, 0)

    def forward(self, x):
        inp_shape = x.shape[2:]
@@ -224,7 +253,7 @@ def _transfer_residual(prefix, block):
            _transfer_residual(k, v)


-     def tile_predict(self, img, side=713, n_classes=19):
+     def tile_predict(self, img):
        """
        Predict by taking overlapping tiles from the image.

@@ -236,6 +265,8 @@ def tile_predict(self, img, side=713, n_classes=19):
        :param n_classes: int with number of classes in seg output.
        """

+         side = self.input_size[0]
+         n_classes = self.n_classes
        h, w = img.shape[1:]
        n = int(max(h, w) / float(side) + 1)
        stride_x = (h - side) / float(n)
@@ -280,16 +311,16 @@ def tile_predict(self, img, side=713, n_classes=19):
if __name__ == '__main__':
    cd = 0
    from torch.autograd import Variable
+     import matplotlib.pyplot as plt
    import scipy.misc as m
    from ptsemseg.loader.cityscapes_loader import cityscapesLoader as cl
-     psp = pspnet(n_classes=19)
+     psp = pspnet(version='ade20k')

    # Just need to do this one time
    #psp.load_pretrained_model(model_path='/home/meet/models/pspnet101_cityscapes.caffemodel')
+     psp.load_pretrained_model(model_path='/home/meet/models/pspnet50_ADE20K.caffemodel')

-     torch.save(psp.state_dict(), "psp.pth")
-     psp.load_state_dict(torch.load('psp.pth'))
-
+     # psp.load_state_dict(torch.load('psp.pth'))

    psp.float()
    psp.cuda(cd)
@@ -303,17 +334,13 @@ def tile_predict(self, img, side=713, n_classes=19):
    img = img.astype(np.float64)
    img -= np.array([123.68, 116.779, 103.939])[:, None, None]
    img = np.copy(img[::-1, :, :])
-     flp = np.copy(img[:, :, ::-1])
-
-
-     # Warmup model
-     #warmup = Variable(torch.unsqueeze(torch.from_numpy(flp).float(), 0).cuda(cd))
-     #for i in range(5):
-     # _ = psp(warmup[:,:,300:300+713,300:300+713])
+     flp = np.copy(img[:, :, ::-1])

    out = psp.tile_predict(img)
    pred = np.argmax(out, axis=0)
-     decoded = dst.decode_segmap(pred)
-     m.imsave('frankfurt_tiled.png', decoded)
+     #decoded = dst.decode_segmap(pred)
+     # m.imsave('ade20k_sttutgart_tiled.png', decoded)
+     m.imsave('ade20k_sttutgart_tiled.png', pred)

+     torch.save(psp.state_dict(), "psp_ade20k.pth")

    print("Output Shape {} \t Input Shape {}".format(out.shape, img.shape))
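The `tile_predict` change derives `side` and `n_classes` from the model instead of taking them as arguments. Below is a small worked restatement of the tile-grid arithmetic visible in the diff; `stride_x` is taken verbatim, while the symmetric `stride_y` and the helper name `tile_grid` are assumptions for illustration only.

```python
# Hypothetical helper restating the tiling arithmetic from tile_predict.
def tile_grid(h, w, side):
    """Tile count and fractional strides for covering an h x w image
    with overlapping side x side tiles."""
    n = int(max(h, w) / float(side) + 1)   # enough tiles to span the longer axis
    stride_x = (h - side) / float(n)       # shown in the diff
    stride_y = (w - side) / float(n)       # assumed symmetric counterpart
    return n, stride_x, stride_y

# e.g. a 1024 x 2048 Cityscapes frame with the 713-pixel PSPNet crop:
# n = 3, so tile origins step ~103.7 px vertically and 445 px horizontally.
print(tile_grid(1024, 2048, 713))
```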