Skip to content

Commit 9d9f69f

Browse files
authored
fix sr docs and add div2k process script (#154)
1 parent 95e5f4f commit 9d9f69f

9 files changed

+367
-64
lines changed

configs/drn_psnr_x4_div2k.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ dataset:
5353
keys: [image, image, image]
5454
test:
5555
name: SRDataset
56-
gt_folder: data/DIV2K/val_set14/Set14
57-
lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
56+
gt_folder: data/Set14/GTmod12
57+
lq_folder: data/Set14/LRbicx4
5858
scale: 4
5959
preprocess:
6060
- name: LoadImageFromFile

configs/esrgan_psnr_x4_div2k.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ model:
1818
dataset:
1919
train:
2020
name: SRDataset
21-
gt_folder: data/DIV2K/DIV2K_train_HR_sub
22-
lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
21+
gt_folder: data/DIV2K/DIV2K_train_HR
22+
lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4
2323
num_workers: 4
2424
batch_size: 16
2525
scale: 4
@@ -49,8 +49,8 @@ dataset:
4949
keys: [image, image]
5050
test:
5151
name: SRDataset
52-
gt_folder: data/DIV2K/val_set14/Set14
53-
lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
52+
gt_folder: data/Set14/GTmod12
53+
lq_folder: data/Set14/LRbicx4
5454
scale: 4
5555
preprocess:
5656
- name: LoadImageFromFile

configs/esrgan_x4_div2k.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ dataset:
6565
keys: [image, image]
6666
test:
6767
name: SRDataset
68-
gt_folder: data/DIV2K/val_set14/Set14
69-
lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
68+
gt_folder: data/Set14/GTmod12
69+
lq_folder: data/Set14/LRbicx4
7070
scale: 4
7171
preprocess:
7272
- name: LoadImageFromFile

configs/lesrcnn_psnr_x4_div2k.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ dataset:
4545
keys: [image, image]
4646
test:
4747
name: SRDataset
48-
gt_folder: data/DIV2K/val_set14/Set14
49-
lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
48+
gt_folder: data/Set14/GTmod12
49+
lq_folder: data/Set14/LRbicx4
5050
scale: 4
5151
preprocess:
5252
- name: LoadImageFromFile

configs/realsr_bicubic_noise_x4_df2k.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ dataset:
6969
keys: [image]
7070
test:
7171
name: SRDataset
72-
gt_folder: data/DIV2K/val_set14/Set14
73-
lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
72+
gt_folder: data/Set14/GTmod12
73+
lq_folder: data/Set14/LRbicx4
7474
scale: 4
7575
preprocess:
7676
- name: LoadImageFromFile

configs/realsr_kernel_noise_x4_dped.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ dataset:
6969
keys: [image]
7070
test:
7171
name: SRDataset
72-
gt_folder: data/DIV2K/val_set14/Set14
73-
lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
72+
gt_folder: data/Set14/GTmod12
73+
lq_folder: data/Set14/LRbicx4
7474
scale: 4
7575
preprocess:
7676
- name: LoadImageFromFile

data/process_div2k_data.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
import os
2+
import re
3+
import sys
4+
import cv2
5+
import argparse
6+
import numpy as np
7+
import os.path as osp
8+
9+
from time import time
10+
from multiprocessing import Pool
11+
from shutil import get_terminal_size
12+
from ppgan.datasets.base_dataset import scandir
13+
14+
15+
class Timer:
16+
"""A flexible Timer class."""
17+
def __init__(self, start=True, print_tmpl=None):
18+
self._is_running = False
19+
self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
20+
if start:
21+
self.start()
22+
23+
@property
24+
def is_running(self):
25+
"""bool: indicate whether the timer is running"""
26+
return self._is_running
27+
28+
def __enter__(self):
29+
self.start()
30+
return self
31+
32+
def __exit__(self, type, value, traceback):
33+
print(self.print_tmpl.format(self.since_last_check()))
34+
self._is_running = False
35+
36+
def start(self):
37+
"""Start the timer."""
38+
if not self._is_running:
39+
self._t_start = time()
40+
self._is_running = True
41+
self._t_last = time()
42+
43+
def since_start(self):
44+
"""Total time since the timer is started.
45+
46+
Returns (float): Time in seconds.
47+
"""
48+
if not self._is_running:
49+
raise ValueError('timer is not running')
50+
self._t_last = time()
51+
return self._t_last - self._t_start
52+
53+
def since_last_check(self):
54+
"""Time since the last checking.
55+
56+
Either :func:`since_start` or :func:`since_last_check` is a checking
57+
operation.
58+
59+
Returns (float): Time in seconds.
60+
"""
61+
if not self._is_running:
62+
raise ValueError('timer is not running')
63+
dur = time() - self._t_last
64+
self._t_last = time()
65+
return dur
66+
67+
68+
class ProgressBar:
69+
"""A progress bar which can print the progress."""
70+
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
71+
self.task_num = task_num
72+
self.bar_width = bar_width
73+
self.completed = 0
74+
self.file = file
75+
if start:
76+
self.start()
77+
78+
@property
79+
def terminal_width(self):
80+
width, _ = get_terminal_size()
81+
return width
82+
83+
def start(self):
84+
if self.task_num > 0:
85+
self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
86+
'elapsed: 0s, ETA:')
87+
else:
88+
self.file.write('completed: 0, elapsed: 0s')
89+
self.file.flush()
90+
self.timer = Timer()
91+
92+
def update(self, num_tasks=1):
93+
assert num_tasks > 0
94+
self.completed += num_tasks
95+
elapsed = self.timer.since_start()
96+
if elapsed > 0:
97+
fps = self.completed / elapsed
98+
else:
99+
fps = float('inf')
100+
if self.task_num > 0:
101+
percentage = self.completed / float(self.task_num)
102+
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
103+
msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
104+
f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
105+
f'ETA: {eta:5}s'
106+
107+
bar_width = min(self.bar_width,
108+
int(self.terminal_width - len(msg)) + 2,
109+
int(self.terminal_width * 0.6))
110+
bar_width = max(2, bar_width)
111+
mark_width = int(bar_width * percentage)
112+
bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
113+
self.file.write(msg.format(bar_chars))
114+
else:
115+
self.file.write(
116+
f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
117+
f' {fps:.1f} tasks/s')
118+
self.file.flush()
119+
120+
121+
def main_extract_subimages(args):
122+
"""A multi-thread tool to crop large images to sub-images for faster IO.
123+
124+
It is used for DIV2K dataset.
125+
126+
args (dict): Configuration dict. It contains:
127+
n_thread (int): Thread number.
128+
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
129+
A higher value means a smaller size and longer compression time.
130+
Use 0 for faster CPU decompression. Default: 3, same in cv2.
131+
132+
input_folder (str): Path to the input folder.
133+
save_folder (str): Path to save folder.
134+
crop_size (int): Crop size.
135+
step (int): Step for overlapped sliding window.
136+
thresh_size (int): Threshold size. Patches whose size is lower
137+
than thresh_size will be dropped.
138+
139+
Usage:
140+
For each folder, run this script.
141+
Typically, there are four folders to be processed for DIV2K dataset.
142+
DIV2K_train_HR
143+
DIV2K_train_LR_bicubic/X2
144+
DIV2K_train_LR_bicubic/X3
145+
DIV2K_train_LR_bicubic/X4
146+
After process, each sub_folder should have the same number of
147+
subimages.
148+
Remember to modify opt configurations according to your settings.
149+
"""
150+
151+
opt = {}
152+
opt['n_thread'] = args.n_thread
153+
opt['compression_level'] = args.compression_level
154+
155+
# HR images
156+
opt['input_folder'] = osp.join(args.data_root, 'DIV2K_train_HR')
157+
opt['save_folder'] = osp.join(args.data_root, 'DIV2K_train_HR_sub')
158+
opt['crop_size'] = args.crop_size
159+
opt['step'] = args.step
160+
opt['thresh_size'] = args.thresh_size
161+
extract_subimages(opt)
162+
163+
for scale in [2, 3, 4]:
164+
opt['input_folder'] = osp.join(args.data_root,
165+
f'DIV2K_train_LR_bicubic/X{scale}')
166+
opt['save_folder'] = osp.join(args.data_root,
167+
f'DIV2K_train_LR_bicubic/X{scale}_sub')
168+
opt['crop_size'] = args.crop_size // scale
169+
opt['step'] = args.step // scale
170+
opt['thresh_size'] = args.thresh_size // scale
171+
extract_subimages(opt)
172+
173+
174+
def extract_subimages(opt):
175+
"""Crop images to subimages.
176+
177+
Args:
178+
opt (dict): Configuration dict. It contains:
179+
input_folder (str): Path to the input folder.
180+
save_folder (str): Path to save folder.
181+
n_thread (int): Thread number.
182+
"""
183+
input_folder = opt['input_folder']
184+
save_folder = opt['save_folder']
185+
if not osp.exists(save_folder):
186+
os.makedirs(save_folder)
187+
print(f'mkdir {save_folder} ...')
188+
else:
189+
print(f'Folder {save_folder} already exists. Exit.')
190+
sys.exit(1)
191+
192+
img_list = list(scandir(input_folder))
193+
img_list = [osp.join(input_folder, v) for v in img_list]
194+
195+
prog_bar = ProgressBar(len(img_list))
196+
pool = Pool(opt['n_thread'])
197+
for path in img_list:
198+
pool.apply_async(worker,
199+
args=(path, opt),
200+
callback=lambda arg: prog_bar.update())
201+
pool.close()
202+
pool.join()
203+
print('All processes done.')
204+
205+
206+
def worker(path, opt):
207+
"""Worker for each process.
208+
209+
Args:
210+
path (str): Image path.
211+
opt (dict): Configuration dict. It contains:
212+
crop_size (int): Crop size.
213+
step (int): Step for overlapped sliding window.
214+
thresh_size (int): Threshold size. Patches whose size is smaller
215+
than thresh_size will be dropped.
216+
save_folder (str): Path to save folder.
217+
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
218+
219+
Returns:
220+
process_info (str): Process information displayed in progress bar.
221+
"""
222+
crop_size = opt['crop_size']
223+
step = opt['step']
224+
thresh_size = opt['thresh_size']
225+
img_name, extension = osp.splitext(osp.basename(path))
226+
227+
# remove the x2, x3, x4 and x8 in the filename for DIV2K
228+
img_name = re.sub('x[2348]', '', img_name)
229+
230+
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
231+
232+
if img.ndim == 2 or img.ndim == 3:
233+
h, w = img.shape[:2]
234+
else:
235+
raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')
236+
237+
h_space = np.arange(0, h - crop_size + 1, step)
238+
if h - (h_space[-1] + crop_size) > thresh_size:
239+
h_space = np.append(h_space, h - crop_size)
240+
w_space = np.arange(0, w - crop_size + 1, step)
241+
if w - (w_space[-1] + crop_size) > thresh_size:
242+
w_space = np.append(w_space, w - crop_size)
243+
244+
index = 0
245+
for x in h_space:
246+
for y in w_space:
247+
index += 1
248+
cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
249+
cv2.imwrite(
250+
osp.join(opt['save_folder'],
251+
f'{img_name}_s{index:03d}{extension}'), cropped_img,
252+
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
253+
process_info = f'Processing {img_name} ...'
254+
return process_info
255+
256+
257+
def parse_args():
258+
parser = argparse.ArgumentParser(
259+
description='Prepare DIV2K dataset',
260+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
261+
parser.add_argument('--data-root', help='dataset root')
262+
parser.add_argument('--crop-size',
263+
nargs='?',
264+
default=480,
265+
help='cropped size for HR images')
266+
parser.add_argument('--step',
267+
nargs='?',
268+
default=240,
269+
help='step size for HR images')
270+
parser.add_argument('--thresh-size',
271+
nargs='?',
272+
default=0,
273+
help='threshold size for HR images')
274+
parser.add_argument('--compression-level',
275+
nargs='?',
276+
default=3,
277+
help='compression level when save png images')
278+
parser.add_argument('--n-thread',
279+
nargs='?',
280+
default=20,
281+
help='thread number when using multiprocessing')
282+
283+
args = parser.parse_args()
284+
return args
285+
286+
287+
if __name__ == '__main__':
288+
args = parse_args()
289+
# extract subimages
290+
main_extract_subimages(args)

0 commit comments

Comments
 (0)