import os
import collections
import torch
import torchvision
import numpy as np
import scipy.misc as m  # imread/imresize were removed from newer SciPy releases; an older SciPy (with Pillow) is assumed
import scipy.io as io
import matplotlib.pyplot as plt

from torch.utils import data

from ptsemseg.utils import recursive_glob
from ptsemseg.augmentations import *


class SUNRGBDLoader(data.Dataset):
    """SUNRGBD loader

    Download from:
        test source: http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-test_images.tgz
        train source: http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-train_images.tgz

        the first 5050 images are the test set, the remaining 5051 are the train set
        test and train labels source: https://github.com/ankurhanda/sunrgbd-meta-data/raw/master/sunrgbd_train_test_labels.tar.gz
    """

    def __init__(self, root, split="training", is_transform=False, img_size=(480, 640), augmentations=None):
        self.root = root
        self.is_transform = is_transform
        self.n_classes = 38
        self.augmentations = augmentations
        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        # per-channel mean in BGR order (images are flipped to BGR in transform before subtraction)
        self.mean = np.array([104.00699, 116.66877, 122.67892])
        self.files = collections.defaultdict(list)
        self.anno_files = collections.defaultdict(list)
        self.cmap = self.color_map(normalized=False)

        split_map = {"training": 'train',
                     "val": 'test'}
        self.split = split_map[split]

        # collect RGB image paths for both splits
        for split in ["train", "test"]:
            file_list = sorted(recursive_glob(rootdir=self.root + split + '/', suffix='jpg'))
            self.files[split] = file_list

        # collect the corresponding label image paths
        for split in ["train", "test"]:
            file_list = sorted(recursive_glob(rootdir=self.root + 'annotations/' + split + '/', suffix='png'))
            self.anno_files[split] = file_list

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        img_path = self.files[self.split][index].rstrip()
        lbl_path = self.anno_files[self.split][index].rstrip()
        # img_number = img_path.split('/')[-1]
        # lbl_path = os.path.join(self.root, 'annotations', img_number).replace('jpg', 'png')

        img = m.imread(img_path)
        img = np.array(img, dtype=np.uint8)

        lbl = m.imread(lbl_path)
        lbl = np.array(lbl, dtype=np.uint8)

        # skip malformed samples (e.g. grayscale images or multi-channel labels)
        # by resampling a random index instead
        if not (len(img.shape) == 3 and len(lbl.shape) == 2):
            return self.__getitem__(np.random.randint(0, self.__len__()))

        if self.augmentations is not None:
            img, lbl = self.augmentations(img, lbl)

        if self.is_transform:
            img, lbl = self.transform(img, lbl)

        return img, lbl

    def transform(self, img, lbl):
        img = img[:, :, ::-1]  # RGB -> BGR
        img = img.astype(np.float64)
        img -= self.mean
        img = m.imresize(img, (self.img_size[0], self.img_size[1]))
        # Resize scales images from 0 to 255, thus we need
        # to divide by 255.0
        img = img.astype(float) / 255.0
        # HWC -> CHW
        img = img.transpose(2, 0, 1)

        classes = np.unique(lbl)
        lbl = lbl.astype(float)
        lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), 'nearest', mode='F')
        lbl = lbl.astype(int)
        # nearest-neighbour resizing must not introduce or drop any class ids
        assert np.all(classes == np.unique(lbl))

        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()
        return img, lbl

    def color_map(self, N=256, normalized=False):
        """
        Return Color Map in PASCAL VOC format
        """

        def bitget(byteval, idx):
            return (byteval & (1 << idx)) != 0

        dtype = 'float32' if normalized else 'uint8'
        cmap = np.zeros((N, 3), dtype=dtype)
        for i in range(N):
            r = g = b = 0
            c = i
            for j in range(8):
                r = r | (bitget(c, 0) << (7 - j))
                g = g | (bitget(c, 1) << (7 - j))
                b = b | (bitget(c, 2) << (7 - j))
                c = c >> 3

            cmap[i] = np.array([r, g, b])

        cmap = cmap / 255.0 if normalized else cmap
        return cmap
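    # For reference, the first few colours this produces (uint8, not normalized) follow the
    # usual PASCAL VOC palette: class 0 -> (0, 0, 0), 1 -> (128, 0, 0), 2 -> (0, 128, 0),
    # 3 -> (128, 128, 0), 4 -> (0, 0, 128), ...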

    def decode_segmap(self, temp):
        # map each class id in `temp` to its RGB colour and scale to [0, 1] for display
        r = temp.copy()
        g = temp.copy()
        b = temp.copy()
        for l in range(0, self.n_classes):
            r[temp == l] = self.cmap[l, 0]
            g[temp == l] = self.cmap[l, 1]
            b[temp == l] = self.cmap[l, 2]

        rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
        rgb[:, :, 0] = r / 255.0
        rgb[:, :, 1] = g / 255.0
        rgb[:, :, 2] = b / 255.0
        return rgb

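# A minimal usage sketch (assuming the directory layout noted above; the path is illustrative):
#
#     train_set = SUNRGBDLoader('/path/to/SUNRGBD/', split="training", is_transform=True)
#     val_set = SUNRGBDLoader('/path/to/SUNRGBD/', split="val", is_transform=True)
#     train_loader = data.DataLoader(train_set, batch_size=4, shuffle=True)
#     val_loader = data.DataLoader(val_set, batch_size=4)
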
if __name__ == '__main__':
    import torchvision
    import matplotlib.pyplot as plt

    augmentations = Compose([Scale(512),
                             RandomRotate(10),
                             RandomHorizontallyFlip()])

    local_path = '/home/meet/datasets/SUNRGBD/'
    dst = SUNRGBDLoader(local_path, is_transform=True, augmentations=augmentations)
    bs = 4
    trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
    for i, data_samples in enumerate(trainloader):
        imgs, labels = data_samples
        imgs = imgs.numpy()[:, ::-1, :, :]  # BGR -> RGB for display
        imgs = np.transpose(imgs, [0, 2, 3, 1])
        f, axarr = plt.subplots(bs, 2)
        for j in range(bs):
            axarr[j][0].imshow(imgs[j])
            axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
        plt.show()
        a = input()
        if a == 'ex':
            break
        else:
            plt.close()