Skip to content

Commit c303716

Browse files
committed
Add SUNRGBD dataloader
1 parent f3f95e6 commit c303716

File tree

1 file changed

+168
-0
lines changed

1 file changed

+168
-0
lines changed

ptsemseg/loader/sunrgbd_loader.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import os
2+
import collections
3+
import torch
4+
import torchvision
5+
import numpy as np
6+
import scipy.misc as m
7+
import scipy.io as io
8+
import matplotlib.pyplot as plt
9+
10+
from torch.utils import data
11+
12+
from ptsemseg.utils import recursive_glob
13+
from ptsemseg.augmentations import *
14+
15+
16+
class SUNRGBDLoader(data.Dataset):
    """SUNRGBD semantic-segmentation dataset loader.

    Download from:
        test source: http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-test_images.tgz
        train source: http://www.doc.ic.ac.uk/~ahanda/SUNRGBD-train_images.tgz
        test and train labels source:
            https://github.com/ankurhanda/sunrgbd-meta-data/raw/master/sunrgbd_train_test_labels.tar.gz
            (original author's note: first 5050 in this is test,
            later 5051 is train)

    Layout read from ``root``:
        ``train/`` and ``test/``                         -- ``jpg`` images
        ``annotations/train/`` and ``annotations/test/`` -- ``png`` labels

    Images and labels are paired purely by their position in the two
    sorted file lists, so both directories of a split must contain the
    same number of files in matching order.
    """

    def __init__(self, root, split="training", is_transform=False,
                 img_size=(480, 640), augmentations=None):
        """Index the image and annotation files of every split.

        Args:
            root: Dataset root directory (with or without a trailing
                path separator).
            split: ``"training"`` (reads the ``train`` folders) or
                ``"val"`` (reads the ``test`` folders). Any other value
                raises ``KeyError``.
            is_transform: When true, ``__getitem__`` passes each sample
                through :meth:`transform`.
            img_size: Output size used by :meth:`transform`; either a
                two-element tuple/list or a single int used for both
                dimensions.
            augmentations: Optional callable ``(img, lbl) -> (img, lbl)``
                applied before :meth:`transform`.
        """
        self.root = root
        self.is_transform = is_transform
        self.n_classes = 38
        self.augmentations = augmentations
        # Accept lists as well as tuples; a bare scalar means a square size.
        if isinstance(img_size, (tuple, list)):
            self.img_size = tuple(img_size)
        else:
            self.img_size = (img_size, img_size)
        # Per-channel mean, subtracted in transform() after the channel
        # axis has been reversed.
        self.mean = np.array([104.00699, 116.66877, 122.67892])
        self.files = collections.defaultdict(list)
        self.anno_files = collections.defaultdict(list)
        self.cmap = self.color_map(normalized=False)

        # Public split names -> on-disk folder names.
        split_map = {"training": 'train',
                     "val": 'test'}
        self.split = split_map[split]

        # os.path.join keeps the lookup working whether or not ``root``
        # ends with a separator; the trailing '' preserves the trailing
        # separator the original concatenation produced.
        for subset in ["train", "test"]:
            image_dir = os.path.join(self.root, subset, '')
            self.files[subset] = sorted(
                recursive_glob(rootdir=image_dir, suffix='jpg'))

            label_dir = os.path.join(self.root, 'annotations', subset, '')
            self.anno_files[subset] = sorted(
                recursive_glob(rootdir=label_dir, suffix='png'))

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load the image/label pair at ``index``.

        If the image is not 3-D or the label is not 2-D, the sample is
        skipped and a randomly chosen one is returned instead.

        Returns:
            ``(img, lbl)`` -- after the optional augmentations and, when
            ``is_transform`` is set, after :meth:`transform`.
        """
        img_path = self.files[self.split][index].rstrip()
        lbl_path = self.anno_files[self.split][index].rstrip()

        img = np.array(m.imread(img_path), dtype=np.uint8)
        lbl = np.array(m.imread(lbl_path), dtype=np.uint8)

        # Unexpected shapes (e.g. a single-channel image) are replaced by
        # a random other sample rather than raising.
        if img.ndim != 3 or lbl.ndim != 2:
            return self[np.random.randint(0, len(self))]

        if self.augmentations is not None:
            img, lbl = self.augmentations(img, lbl)

        if self.is_transform:
            img, lbl = self.transform(img, lbl)

        return img, lbl

    def transform(self, img, lbl):
        """Convert a numpy image/label pair into network-ready tensors.

        Args:
            img: H x W x C image array.
            lbl: H x W label array.

        Returns:
            ``(img, lbl)`` where ``img`` is a float tensor in C x H x W
            order and ``lbl`` is a long tensor, both resized to
            ``self.img_size``.

        Raises:
            AssertionError: If resizing changed the set of label values.
        """
        size = (self.img_size[0], self.img_size[1])

        img = img[:, :, ::-1]  # reverse the channel axis
        img = img.astype(np.float64)
        img -= self.mean
        # NOTE(review): imresize rescales its input to 0..255 (hence the
        # division below), so the mean subtraction above presumably does
        # not survive as an absolute offset. Left unchanged to keep the
        # numerics existing models were trained with -- confirm intent.
        img = m.imresize(img, size)
        # Resize scales images from 0 to 255, thus we need
        # to divide by 255.0
        img = img.astype(float) / 255.0
        # HWC -> CHW
        img = img.transpose(2, 0, 1)

        classes = np.unique(lbl)
        lbl = lbl.astype(float)
        # Nearest-neighbour resize in float mode so label ids are not blended.
        lbl = m.imresize(lbl, size, 'nearest', mode='F')
        lbl = lbl.astype(int)
        # array_equal also copes with differing lengths (a class that
        # vanished during resizing), where ``==`` would not compare
        # elementwise.
        assert np.array_equal(classes, np.unique(lbl)), \
            "resizing the label map changed the set of classes"

        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()
        return img, lbl

    def color_map(self, N=256, normalized=False):
        """Return the color map in PASCAL VOC format.

        Args:
            N: Number of entries to generate.
            normalized: When true, return ``float32`` values divided by
                255; otherwise return ``uint8`` values.

        Returns:
            An ``N x 3`` array of RGB colors, one row per label index.
        """

        def bitget(byteval, idx):
            return ((byteval & (1 << idx)) != 0)

        dtype = 'float32' if normalized else 'uint8'
        cmap = np.zeros((N, 3), dtype=dtype)
        for i in range(N):
            r = g = b = 0
            c = i
            # Each pass takes the three lowest bits of ``c`` (one per
            # channel) and places them from the most significant bit down.
            for j in range(8):
                shift = 7 - j
                r = r | (bitget(c, 0) << shift)
                g = g | (bitget(c, 1) << shift)
                b = b | (bitget(c, 2) << shift)
                c = c >> 3

            cmap[i] = np.array([r, g, b])

        cmap = cmap / 255.0 if normalized else cmap
        return cmap

    def decode_segmap(self, temp):
        """Turn a 2-D label map into an RGB image using ``self.cmap``.

        Labels in ``range(self.n_classes)`` are replaced by their color
        map entry; any other value is left as-is in all three channels.

        Args:
            temp: 2-D array of label indices.

        Returns:
            An ``H x W x 3`` float array with every channel divided by 255.
        """
        rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
        for channel in range(3):
            plane = temp.copy()
            for label in range(0, self.n_classes):
                # Mask on the untouched ``temp`` so already-written color
                # values are never mistaken for label ids.
                plane[temp == label] = self.cmap[label, channel]
            rgb[:, :, channel] = plane / 255.0
        return rgb
141+
142+
143+
if __name__ == '__main__':
    # Interactive visual check: show each batch's images next to their
    # decoded label maps, then wait for a line of input. Typing 'ex'
    # stops; anything else closes the figure and moves to the next batch.
    import matplotlib.pyplot as plt

    # raw_input only exists on Python 2; fall back to input on Python 3.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    augmentations = Compose([Scale(512),
                             RandomRotate(10),
                             RandomHorizontallyFlip()])

    local_path = '/home/meet/datasets/SUNRGBD/'
    dst = SUNRGBDLoader(local_path, is_transform=True, augmentations=augmentations)
    bs = 4
    trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
    # Unpack the batch directly so the ``data`` module name is not shadowed.
    for imgs, labels in trainloader:
        # Reverse the channel axis, then move channels last for imshow.
        imgs = imgs.numpy()[:, ::-1, :, :]
        imgs = np.transpose(imgs, [0, 2, 3, 1])
        labels = labels.numpy()
        # squeeze=False keeps axarr 2-D whatever the value of ``bs``.
        f, axarr = plt.subplots(bs, 2, squeeze=False)
        # The final batch may hold fewer than ``bs`` samples.
        for j in range(imgs.shape[0]):
            axarr[j][0].imshow(imgs[j])
            axarr[j][1].imshow(dst.decode_segmap(labels[j]))
        plt.show()
        a = read_line()
        if a == 'ex':
            break
        else:
            plt.close()

0 commit comments

Comments
 (0)