Skip to content

Commit f3f95e6

Browse files
committed
Add NYUv2 dataloader
1 parent b59de34 commit f3f95e6

File tree

1 file changed

+161
-0
lines changed

1 file changed

+161
-0
lines changed

ptsemseg/loader/nyuv2_loader.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import os
2+
import collections
3+
import torch
4+
import torchvision
5+
import numpy as np
6+
import scipy.misc as m
7+
import scipy.io as io
8+
import matplotlib.pyplot as plt
9+
10+
from torch.utils import data
11+
12+
from ptsemseg.utils import recursive_glob
13+
from ptsemseg.augmentations import *
14+
15+
16+
class NYUv2Loader(data.Dataset):
    """NYUv2 semantic-segmentation loader (13 classes + background).

    Expects the on-disk layout ``root/train/``, ``root/test/`` for RGB
    images and ``root/train_annot/``, ``root/test_annot/`` for labels.

    Download From (only 13 classes):
    test source: http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz
    train source: http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz
    test_labels source: https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz
    train_labels source: https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz
    """

    def __init__(self, root, split="training", is_transform=False, img_size=(480, 640), augmentations=None):
        """
        Args:
            root (str): dataset root directory.
            split (str): "training" or "val"; mapped to the on-disk
                'train' / 'test' directory names.
            is_transform (bool): if True, run :meth:`transform` on each sample.
            img_size (tuple or int): output (height, width); an int means square.
            augmentations (callable or None): applied to the raw (img, lbl) pair.
        """
        self.root = root
        self.is_transform = is_transform
        self.n_classes = 14  # 13 semantic classes + background/void
        self.augmentations = augmentations
        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        # Per-channel BGR mean (Caffe/VGG convention); subtracted in transform().
        self.mean = np.array([104.00699, 116.66877, 122.67892])
        self.files = collections.defaultdict(list)
        self.cmap = self.color_map(normalized=False)

        split_map = {"training": 'train',
                     "val": 'test', }
        self.split = split_map[split]

        # BUG FIX: the loop variable used to be named `split`, shadowing the
        # constructor parameter of the same name.
        for subset in ["train", "test"]:
            file_list = recursive_glob(rootdir=self.root + subset + '/', suffix='png')
            self.files[subset] = file_list

    def __len__(self):
        """Number of samples in the active split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load one (image, label) pair; apply augmentations/transform if set."""
        img_path = self.files[self.split][index].rstrip()
        # Image files end in a 4-digit index, e.g. nyu_rgb_0001.png; the label
        # file shares that index — TODO confirm against the extracted archives.
        img_number = img_path.split('_')[-1][:4]
        lbl_path = os.path.join(self.root, self.split + '_annot', 'new_nyu_class13_' + img_number + '.png')

        img = m.imread(img_path)
        img = np.array(img, dtype=np.uint8)

        lbl = m.imread(lbl_path)
        lbl = np.array(lbl, dtype=np.uint8)

        # Skip corrupt/grayscale samples by drawing a random replacement.
        if not (len(img.shape) == 3 and len(lbl.shape) == 2):
            return self.__getitem__(np.random.randint(0, self.__len__()))

        if self.augmentations is not None:
            img, lbl = self.augmentations(img, lbl)

        if self.is_transform:
            img, lbl = self.transform(img, lbl)

        return img, lbl

    def transform(self, img, lbl):
        """Normalize an (img, lbl) pair and convert both to torch tensors.

        Returns:
            (torch.FloatTensor CxHxW, torch.LongTensor HxW)
        """
        img = img[:, :, ::-1]  # RGB -> BGR, to match the BGR mean above
        img = img.astype(np.float64)
        img -= self.mean
        img = m.imresize(img, (self.img_size[0], self.img_size[1]))
        # Resize scales images from 0 to 255, thus we need
        # to divide by 255.0
        img = img.astype(float) / 255.0
        # HWC -> CHW
        img = img.transpose(2, 0, 1)

        classes = np.unique(lbl)
        lbl = lbl.astype(float)
        lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), 'nearest', mode='F')
        lbl = lbl.astype(int)
        # Nearest-neighbour resize must not invent new label values.
        # NOTE(review): `assert` is stripped under `python -O`; kept so callers
        # that expect AssertionError are unaffected.
        assert (np.all(classes == np.unique(lbl)))

        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()
        return img, lbl

    def color_map(self, N=256, normalized=False):
        """
        Return Color Map in PASCAL VOC format

        Args:
            N (int): number of colors to generate.
            normalized (bool): if True, return float32 values in [0, 1];
                otherwise uint8 values in [0, 255].

        Returns:
            np.ndarray of shape (N, 3).
        """

        def bitget(byteval, idx):
            # True iff bit `idx` of `byteval` is set.
            return ((byteval & (1 << idx)) != 0)

        dtype = 'float32' if normalized else 'uint8'
        cmap = np.zeros((N, 3), dtype=dtype)
        for i in range(N):
            r = g = b = 0
            c = i
            # Spread the low bits of the class index across the high bits of
            # each channel (standard VOC colormap construction).
            for j in range(8):
                r = r | (bitget(c, 0) << 7 - j)
                g = g | (bitget(c, 1) << 7 - j)
                b = b | (bitget(c, 2) << 7 - j)
                c = c >> 3

            cmap[i] = np.array([r, g, b])

        cmap = cmap / 255.0 if normalized else cmap
        return cmap

    def decode_segmap(self, temp):
        """Map an HxW label map to an HxW x3 float RGB image in [0, 1]."""
        r = temp.copy()
        g = temp.copy()
        b = temp.copy()
        for l in range(0, self.n_classes):
            r[temp == l] = self.cmap[l, 0]
            g[temp == l] = self.cmap[l, 1]
            b[temp == l] = self.cmap[l, 2]

        rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
        rgb[:, :, 0] = r / 255.0
        rgb[:, :, 1] = g / 255.0
        rgb[:, :, 2] = b / 255.0
        return rgb
134+
135+
136+
if __name__ == '__main__':
    # Demo: iterate the training split and visualize batches until the user
    # types 'ex'.
    import torchvision
    import matplotlib.pyplot as plt

    augmentations = Compose([Scale(512),
                             RandomRotate(10),
                             RandomHorizontallyFlip()])

    local_path = '/home/meet/datasets/NYUv2/'
    dst = NYUv2Loader(local_path, is_transform=True, augmentations=augmentations)
    bs = 4
    trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
    # BUG FIX: the loop variable used to be named `data`, shadowing the
    # imported `torch.utils.data` module.
    for i, batch in enumerate(trainloader):
        imgs, labels = batch
        imgs = imgs.numpy()[:, ::-1, :, :]  # BGR -> RGB for display
        imgs = np.transpose(imgs, [0, 2, 3, 1])  # NCHW -> NHWC
        f, axarr = plt.subplots(bs, 2)
        for j in range(bs):
            axarr[j][0].imshow(imgs[j])
            axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
        plt.show()
        # BUG FIX: raw_input() is Python 2 only and raises NameError on
        # Python 3; input() is the equivalent there.
        a = input()
        if a == 'ex':
            break
        else:
            plt.close()

0 commit comments

Comments
 (0)