
Commit 53af79a

support different color spaces
1 parent ce5c55a · commit 53af79a

16 files changed: +170, -63 lines


basicsr/data/data_util.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from basicsr.utils import img2tensor, scandir


-def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
+def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False, color_space='rgb'):
     """Read a sequence of images from a given folder path.

     Args:
@@ -30,7 +30,7 @@ def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):

     if require_mod_crop:
         imgs = [mod_crop(img, scale) for img in imgs]
-    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+    imgs = img2tensor(imgs, color_space=color_space, float32=True)
     imgs = torch.stack(imgs, dim=0)

     if return_imgname:
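
The img2tensor helper itself lives in basicsr/utils and is not part of the hunks shown here. For orientation only, a minimal sketch of how a color_space argument could subsume the old bgr2rgb flag; the ColorSpace members RAW/RGB/Y and the string aliases are assumptions inferred from the call sites in this commit, not the actual implementation.

# Sketch only -- not the real basicsr.utils.img2tensor.
from enum import Enum

import cv2
import torch

from basicsr.utils import bgr2ycbcr


class ColorSpace(Enum):
    RAW = 'raw'  # keep channels exactly as loaded (e.g. optical-flow maps)
    RGB = 'rgb'  # swap OpenCV's BGR order to RGB
    Y = 'y'      # keep only the luma channel (the old color: y behaviour)


def img2tensor(imgs, color_space=ColorSpace.RGB, float32=True):
    """Convert HWC BGR numpy image(s) to CHW torch tensor(s)."""
    # accept either an enum member or the plain string stored in opt['color']
    space = color_space if isinstance(color_space, ColorSpace) else ColorSpace(color_space)

    def _totensor(img):
        if space is ColorSpace.RGB:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        elif space is ColorSpace.Y:
            img = bgr2ycbcr(img, y_only=True)[..., None]
        # ColorSpace.RAW: no conversion at all
        tensor = torch.from_numpy(img.transpose(2, 0, 1))
        return tensor.float() if float32 else tensor

    return [_totensor(img) for img in imgs] if isinstance(imgs, list) else _totensor(imgs)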

basicsr/data/ffhq_dataset.py

Lines changed: 4 additions & 2 deletions
@@ -5,7 +5,7 @@
 from torchvision.transforms.functional import normalize

 from basicsr.data.transforms import augment
-from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor
 from basicsr.utils.registry import DATASET_REGISTRY


@@ -70,8 +70,10 @@ def __getitem__(self, index):

         # random horizontal flip
         img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
+        # color space transform
+        color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
         # BGR to RGB, HWC to CHW, numpy to tensor
-        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
+        img_gt = img2tensor(img_gt, color_space=color_space, float32=True)
         # normalize
         normalize(img_gt, self.mean, self.std, inplace=True)
         return {'gt': img_gt, 'gt_path': gt_path}
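
The datasets pick the target space up from the existing color key in their options, defaulting to RGB when the key is absent. A hypothetical options dict for FFHQDataset (keys and paths are placeholders; the single-element mean/std are there to match the 1-channel luma tensor that color: y would produce):

opt = {
    'name': 'FFHQ',
    'type': 'FFHQDataset',
    'dataroot_gt': 'datasets/ffhq/ffhq_256.lmdb',  # placeholder path
    'io_backend': {'type': 'lmdb'},
    'use_hflip': True,
    'color': 'y',   # drop this key to keep the default RGB behaviour
    'mean': [0.5],  # one value per channel of the resulting tensor
    'std': [0.5],
}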

basicsr/data/paired_image_dataset.py

Lines changed: 5 additions & 7 deletions
@@ -3,7 +3,7 @@

 from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
 from basicsr.data.transforms import augment, paired_random_crop
-from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
+from basicsr.utils import ColorSpace, FileClient, imfrombytes, img2tensor
 from basicsr.utils.registry import DATASET_REGISTRY


@@ -83,18 +83,16 @@ def __getitem__(self, index):
             # flip, rotation
             img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

-        # color space transform
-        if 'color' in self.opt and self.opt['color'] == 'y':
-            img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
-            img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]
-
         # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
         # TODO: It is better to update the datasets, rather than force to crop
         if self.opt['phase'] != 'train':
            img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

+        # color space transform
+        color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
+
         # BGR to RGB, HWC to CHW, numpy to tensor
-        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+        img_gt, img_lq = img2tensor([img_gt, img_lq], color_space=color_space, float32=True)
         # normalize
         if self.mean is not None or self.std is not None:
             normalize(img_lq, self.mean, self.std, inplace=True)
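
Note that the y-only conversion used to run before the validation-time GT crop and now runs after it; the order does not matter because the conversion is per-pixel. A standalone sanity check of that commutation (illustrative only, not part of the commit):

import numpy as np

from basicsr.utils import bgr2ycbcr

img = np.random.rand(64, 48, 3).astype(np.float32)  # fake BGR image in [0, 1]
crop = (slice(0, 32), slice(0, 24))

y_then_crop = bgr2ycbcr(img, y_only=True)[crop]  # convert first, crop second
crop_then_y = bgr2ycbcr(img[crop], y_only=True)  # crop first, convert second
assert np.allclose(y_then_crop, crop_then_y)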

basicsr/data/realesrgan_dataset.py

Lines changed: 4 additions & 2 deletions
@@ -10,7 +10,7 @@

 from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
 from basicsr.data.transforms import augment
-from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor
 from basicsr.utils.registry import DATASET_REGISTRY


@@ -181,8 +181,10 @@ def __getitem__(self, index):
         else:
             sinc_kernel = self.pulse_tensor

+        # color space transform
+        color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
         # BGR to RGB, HWC to CHW, numpy to tensor
-        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
+        img_gt = img2tensor([img_gt], color_space=color_space, float32=True)[0]
         kernel = torch.FloatTensor(kernel)
         kernel2 = torch.FloatTensor(kernel2)

basicsr/data/realesrgan_paired_dataset.py

Lines changed: 5 additions & 2 deletions
@@ -4,7 +4,7 @@

 from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
 from basicsr.data.transforms import augment, paired_random_crop
-from basicsr.utils import FileClient, imfrombytes, img2tensor
+from basicsr.utils import FileClient, imfrombytes, img2tensor, ColorSpace
 from basicsr.utils.registry import DATASET_REGISTRY


@@ -93,8 +93,11 @@ def __getitem__(self, index):
             # flip, rotation
             img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

+        # color space transform
+        color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
+
         # BGR to RGB, HWC to CHW, numpy to tensor
-        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+        img_gt, img_lq = img2tensor([img_gt, img_lq], color_space=color_space, float32=True)
         # normalize
         if self.mean is not None or self.std is not None:
             normalize(img_lq, self.mean, self.std, inplace=True)

basicsr/data/reds_dataset.py

Lines changed: 12 additions & 4 deletions
@@ -5,7 +5,7 @@
 from torch.utils import data as data

 from basicsr.data.transforms import augment, paired_random_crop
-from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor, ColorSpace
 from basicsr.utils.flow_util import dequantize_flow
 from basicsr.utils.registry import DATASET_REGISTRY

@@ -182,12 +182,16 @@ def __getitem__(self, index):
         else:
             img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

-        img_results = img2tensor(img_results)
+        # color space transform
+        color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_results = img2tensor(img_results, color_space=color_space)
         img_lqs = torch.stack(img_results[0:-1], dim=0)
         img_gt = img_results[-1]

         if self.flow_root is not None:
-            img_flows = img2tensor(img_flows)
+            img_flows = img2tensor(img_flows, color_space=ColorSpace.RAW)
             # add the zero center flow
             img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
             img_flows = torch.stack(img_flows, dim=0)
@@ -339,7 +343,11 @@ def __getitem__(self, index):
         img_lqs.extend(img_gts)
         img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

-        img_results = img2tensor(img_results)
+        # color space transform
+        color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_results = img2tensor(img_results, color_space=color_space)
         img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
         img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)
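
The flow maps are tensorised with ColorSpace.RAW because dequantize_flow produces 2-channel (dx, dy) arrays rather than BGR images, so a BGR-to-RGB swap would be meaningless there (and cv2 would reject the 2-channel input anyway). On dummy data, a raw tensorisation amounts to nothing more than an HWC-to-CHW transpose:

import numpy as np
import torch

flow = np.random.randn(64, 64, 2).astype(np.float32)      # (h, w, 2), like dequantize_flow output
flow_t = torch.from_numpy(flow.transpose(2, 0, 1).copy())  # channel order left untouched
print(flow_t.shape)  # torch.Size([2, 64, 64])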

basicsr/data/single_image_dataset.py

Lines changed: 3 additions & 4 deletions
@@ -3,7 +3,7 @@
 from torchvision.transforms.functional import normalize

 from basicsr.data.data_util import paths_from_lmdb
-from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
+from basicsr.utils import ColorSpace, FileClient, imfrombytes, img2tensor, scandir
 from basicsr.utils.registry import DATASET_REGISTRY


@@ -54,11 +54,10 @@ def __getitem__(self, index):
         img_lq = imfrombytes(img_bytes, float32=True)

         # color space transform
-        if 'color' in self.opt and self.opt['color'] == 'y':
-            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
+        color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

         # BGR to RGB, HWC to CHW, numpy to tensor
-        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
+        img_lq = img2tensor(img_lq, color_space=color_space, float32=True)
         # normalize
         if self.mean is not None or self.std is not None:
             normalize(img_lq, self.mean, self.std, inplace=True)

basicsr/data/vimeo90k_dataset.py

Lines changed: 11 additions & 3 deletions
@@ -4,7 +4,7 @@
 from torch.utils import data as data

 from basicsr.data.transforms import augment, paired_random_crop
-from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor, ColorSpace
 from basicsr.utils.registry import DATASET_REGISTRY


@@ -120,7 +120,11 @@ def __getitem__(self, index):
         img_lqs.append(img_gt)
         img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

-        img_results = img2tensor(img_results)
+        # color space transform
+        color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_results = img2tensor(img_results, color_space=color_space)
         img_lqs = torch.stack(img_results[0:-1], dim=0)
         img_gt = img_results[-1]

@@ -182,7 +186,11 @@ def __getitem__(self, index):
         img_lqs.extend(img_gts)
         img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

-        img_results = img2tensor(img_results)
+        # color space transform
+        color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_results = img2tensor(img_results, color_space=color_space)
         img_lqs = torch.stack(img_results[:7], dim=0)
         img_gts = torch.stack(img_results[7:], dim=0)

basicsr/metrics/test_metrics/test_psnr_ssim.py

Lines changed: 3 additions & 3 deletions
@@ -3,7 +3,7 @@

 from basicsr.metrics import calculate_psnr, calculate_ssim
 from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt
-from basicsr.utils import img2tensor
+from basicsr.utils import img2tensor, ColorSpace


 def test(img_path, img_path2, crop_border, test_y_channel=False):
@@ -16,8 +16,8 @@ def test(img_path, img_path2, crop_border, test_y_channel=False):
     print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')

     # --------------------- PyTorch (CPU) ---------------------
-    img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
-    img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
+    img = img2tensor(img / 255., color_space=ColorSpace.RGB, float32=True).unsqueeze_(0)
+    img2 = img2tensor(img2 / 255., color_space=ColorSpace.RGB, float32=True).unsqueeze_(0)

     psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
     ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
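
Passing ColorSpace.RGB reproduces the old bgr2rgb=True path, which matters here because calculate_psnr_pt and calculate_ssim_pt assume RGB channel order when test_y_channel=True (they derive the Y channel internally), so feeding them raw BGR tensors would apply the wrong coefficients. A minimal smoke test on random data (illustrative only):

import torch

from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt

a = torch.rand(1, 3, 64, 64)  # NCHW, RGB order, values in [0, 1]
b = a.clone()
print(calculate_psnr_pt(a, b, crop_border=0, test_y_channel=True))  # very large PSNR for identical inputs
print(calculate_ssim_pt(a, b, crop_border=0, test_y_channel=True))  # SSIM of ~1 for identical inputs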

basicsr/models/sr_model.py

Lines changed: 3 additions & 2 deletions
@@ -183,6 +183,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):

     def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
         dataset_name = dataloader.dataset.opt['name']
+        color_space = dataloader.dataset.opt['color'] if 'color' in dataloader.dataset.opt else 'rgb'
         with_metrics = self.opt['val'].get('metrics') is not None
         use_pbar = self.opt['val'].get('pbar', False)

@@ -205,10 +206,10 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
             self.test()

             visuals = self.get_current_visuals()
-            sr_img = tensor2img([visuals['result']])
+            sr_img = tensor2img([visuals['result']], color_space=color_space)
             metric_data['img'] = sr_img
             if 'gt' in visuals:
-                gt_img = tensor2img([visuals['gt']])
+                gt_img = tensor2img([visuals['gt']], color_space=color_space)
                 metric_data['img2'] = gt_img
                 del self.gt
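
tensor2img gains a matching color_space keyword (its implementation is in basicsr/utils, outside the hunks shown here); conceptually it only has to undo whatever img2tensor did, so the metrics and cv2.imwrite keep receiving the uint8 layout they expect. A hypothetical sketch for the single-tensor case, using the plain string values the validation code passes:

import cv2
import numpy as np


def tensor2img_sketch(tensor, color_space='rgb'):
    """(1, C, H, W) float tensor in [0, 1] -> HWC uint8 numpy image (BGR when the space is 'rgb')."""
    img = tensor.detach()[0].clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
    if color_space == 'rgb':
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # undo the BGR->RGB swap so cv2.imwrite gets BGR
    # 'y' and 'raw' tensors are written out as-is
    return (img * 255.0).round().astype(np.uint8)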
