import os
import collections
import torch
import torchvision
import numpy as np
import scipy.misc as m
import scipy.io as io
import matplotlib.pyplot as plt

from torch.utils import data

from ptsemseg.utils import recursive_glob
from ptsemseg.augmentations import *


class NYUv2Loader(data.Dataset):
    """NYUv2 loader.

    Download from (only 13 classes):
        test source: http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz
        train source: http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz
        test_labels source: https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz
        train_labels source: https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz
    """

    def __init__(self, root, split="training", is_transform=False, img_size=(480, 640), augmentations=None):
        self.root = root
        self.is_transform = is_transform
        self.n_classes = 14
        self.augmentations = augmentations
        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        self.mean = np.array([104.00699, 116.66877, 122.67892])
        self.files = collections.defaultdict(list)
        self.cmap = self.color_map(normalized=False)

        # Map the loader's split names onto the directory names on disk.
        split_map = {"training": 'train',
                     "val": 'test'}
        self.split = split_map[split]

        # Index the image files of both splits up front.
        for split in ["train", "test"]:
            file_list = recursive_glob(rootdir=self.root + split + '/', suffix='png')
            self.files[split] = file_list

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        img_path = self.files[self.split][index].rstrip()
        # The label file shares the 4-digit frame number of the RGB image.
        img_number = img_path.split('_')[-1][:4]
        lbl_path = os.path.join(self.root, self.split + '_annot', 'new_nyu_class13_' + img_number + '.png')

        img = m.imread(img_path)
        img = np.array(img, dtype=np.uint8)

        lbl = m.imread(lbl_path)
        lbl = np.array(lbl, dtype=np.uint8)

        # Skip malformed pairs (non-RGB image or non-2D label) by sampling another index.
        if not (len(img.shape) == 3 and len(lbl.shape) == 2):
            return self.__getitem__(np.random.randint(0, self.__len__()))

        if self.augmentations is not None:
            img, lbl = self.augmentations(img, lbl)

        if self.is_transform:
            img, lbl = self.transform(img, lbl)

        return img, lbl

    def transform(self, img, lbl):
        # Reverse the channel order (RGB -> BGR) to match the BGR mean values.
        img = img[:, :, ::-1]
        img = img.astype(np.float64)
        img -= self.mean
        img = m.imresize(img, (self.img_size[0], self.img_size[1]))
        # imresize rescales the image to the 0-255 range, so divide by 255.0.
        img = img.astype(float) / 255.0
        # HWC -> CHW
        img = img.transpose(2, 0, 1)

        classes = np.unique(lbl)
        lbl = lbl.astype(float)
        lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), 'nearest', mode='F')
        lbl = lbl.astype(int)
        # Nearest-neighbour resizing must not introduce or drop any classes.
        assert np.all(classes == np.unique(lbl))

        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()
        return img, lbl

    def color_map(self, N=256, normalized=False):
        """Return the color map in PASCAL VOC format."""

        def bitget(byteval, idx):
            return (byteval & (1 << idx)) != 0

        dtype = 'float32' if normalized else 'uint8'
        cmap = np.zeros((N, 3), dtype=dtype)
        for i in range(N):
            r = g = b = 0
            c = i
            for j in range(8):
                r = r | (bitget(c, 0) << (7 - j))
                g = g | (bitget(c, 1) << (7 - j))
                b = b | (bitget(c, 2) << (7 - j))
                c = c >> 3

            cmap[i] = np.array([r, g, b])

        cmap = cmap / 255.0 if normalized else cmap
        return cmap

    def decode_segmap(self, temp):
        r = temp.copy()
        g = temp.copy()
        b = temp.copy()
        for l in range(0, self.n_classes):
            r[temp == l] = self.cmap[l, 0]
            g[temp == l] = self.cmap[l, 1]
            b[temp == l] = self.cmap[l, 2]

        rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
        rgb[:, :, 0] = r / 255.0
        rgb[:, :, 1] = g / 255.0
        rgb[:, :, 2] = b / 255.0
        return rgb


if __name__ == '__main__':
    augmentations = Compose([Scale(512),
                             RandomRotate(10),
                             RandomHorizontallyFlip()])

    local_path = '/home/meet/datasets/NYUv2/'
    dst = NYUv2Loader(local_path, is_transform=True, augmentations=augmentations)
    bs = 4
    trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
    # Use a name other than `data` so the torch.utils.data module is not shadowed.
    for i, data_samples in enumerate(trainloader):
        imgs, labels = data_samples
        # Undo the BGR channel flip and CHW layout from transform() for display.
        imgs = imgs.numpy()[:, ::-1, :, :]
        imgs = np.transpose(imgs, [0, 2, 3, 1])
        f, axarr = plt.subplots(bs, 2)
        for j in range(bs):
            axarr[j][0].imshow(imgs[j])
            axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
        plt.show()
        a = input()
        if a == 'ex':
            break
        else:
            plt.close()