|
5 | 5 | from torch.utils import data as data |
6 | 6 |
|
7 | 7 | from basicsr.data.transforms import augment, paired_random_crop |
8 | | -from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor |
| 8 | +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor, ColorSpace |
9 | 9 | from basicsr.utils.flow_util import dequantize_flow |
10 | 10 | from basicsr.utils.registry import DATASET_REGISTRY |
11 | 11 |
|
@@ -182,12 +182,16 @@ def __getitem__(self, index): |
182 | 182 | else: |
183 | 183 | img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) |
184 | 184 |
|
185 | | - img_results = img2tensor(img_results) |
| 185 | + # color space transform |
| 186 | + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] |
| 187 | + |
| 188 | + # BGR to RGB, HWC to CHW, numpy to tensor |
| 189 | + img_results = img2tensor(img_results, color_space=color_space) |
186 | 190 | img_lqs = torch.stack(img_results[0:-1], dim=0) |
187 | 191 | img_gt = img_results[-1] |
188 | 192 |
|
189 | 193 | if self.flow_root is not None: |
190 | | - img_flows = img2tensor(img_flows) |
| 194 | + img_flows = img2tensor(img_flows, color_space=ColorSpace.RAW) |
191 | 195 | # add the zero center flow |
192 | 196 | img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0])) |
193 | 197 | img_flows = torch.stack(img_flows, dim=0) |
@@ -339,7 +343,11 @@ def __getitem__(self, index): |
339 | 343 | img_lqs.extend(img_gts) |
340 | 344 | img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot']) |
341 | 345 |
|
342 | | - img_results = img2tensor(img_results) |
| 346 | + # color space transform |
| 347 | + color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color'] |
| 348 | + |
| 349 | + # BGR to RGB, HWC to CHW, numpy to tensor |
| 350 | + img_results = img2tensor(img_results, color_space=color_space) |
343 | 351 | img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0) |
344 | 352 | img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0) |
345 | 353 |
|
|
0 commit comments