Skip to content

Commit acd1615

Browse files
authored
Add BasicVSR++ (PaddlePaddle#383)
* add BasicVSR++
1 parent 8c7878d commit acd1615

File tree

9 files changed

+633
-7
lines changed

9 files changed

+633
-7
lines changed

configs/basicvsr++_reds.yaml

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
total_iters: 600000
2+
output_dir: output_dir
3+
find_unused_parameters: True
4+
checkpoints_dir: checkpoints
5+
use_dataset: True
6+
# tensor range for function tensor2img
7+
min_max:
8+
(0., 1.)
9+
10+
model:
11+
name: BasicVSRModel
12+
fix_iter: 5000
13+
lr_mult: 0.25
14+
generator:
15+
name: BasicVSRPlusPlus
16+
mid_channels: 64
17+
num_blocks: 7
18+
is_low_res_input: True
19+
pixel_criterion:
20+
name: CharbonnierLoss
21+
reduction: mean
22+
23+
dataset:
24+
train:
25+
name: RepeatDataset
26+
times: 1000
27+
num_workers: 4
28+
batch_size: 2 #4 gpus
29+
dataset:
30+
name: SRREDSMultipleGTDataset
31+
mode: train
32+
lq_folder: data/REDS/train_sharp_bicubic/X4
33+
gt_folder: data/REDS/train_sharp/X4
34+
crop_size: 256
35+
interval_list: [1]
36+
random_reverse: False
37+
number_frames: 30
38+
use_flip: True
39+
use_rot: True
40+
scale: 4
41+
val_partition: REDS4
42+
43+
test:
44+
name: SRREDSMultipleGTDataset
45+
mode: test
46+
lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
47+
gt_folder: data/REDS/REDS4_test_sharp/X4
48+
interval_list: [1]
49+
random_reverse: False
50+
number_frames: 100
51+
use_flip: False
52+
use_rot: False
53+
scale: 4
54+
val_partition: REDS4
55+
num_workers: 0
56+
batch_size: 1
57+
58+
lr_scheduler:
59+
name: CosineAnnealingRestartLR
60+
learning_rate: !!float 1e-4
61+
periods: [600000]
62+
restart_weights: [1]
63+
eta_min: !!float 1e-7
64+
65+
optimizer:
66+
name: Adam
67+
# add parameters of net_name to optim
68+
# name should in self.nets
69+
net_names:
70+
- generator
71+
beta1: 0.9
72+
beta2: 0.99
73+
74+
validate:
75+
interval: 5000
76+
save_img: false
77+
78+
metrics:
79+
psnr: # metric name, can be arbitrary
80+
name: PSNR
81+
crop_border: 0
82+
test_y_channel: False
83+
ssim:
84+
name: SSIM
85+
crop_border: 0
86+
test_y_channel: False
87+
88+
log_config:
89+
interval: 10
90+
visiual_interval: 500
91+
92+
snapshot_config:
93+
interval: 5000

configs/basicvsr_reds.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ min_max:
1010
model:
1111
name: BasicVSRModel
1212
fix_iter: 5000
13+
lr_mult: 0.125
1314
generator:
1415
name: BasicVSRNet
1516
mid_channels: 64

configs/iconvsr_reds.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ min_max:
1010
model:
1111
name: BasicVSRModel
1212
fix_iter: 5000
13+
lr_mult: 0.125
1314
generator:
1415
name: IconVSR
1516
mid_channels: 64

docs/en_US/tutorials/video_super_resolution.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
## 1.1 Principle
55

6-
Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low resolution (LR) images. The difference between them is that the video is composed of multiple frames, so the video super-resolution usually uses the information between frames to repair. Here we provide the video super-resolution model [EDVR](https://arxiv.org/pdf/1905.02716.pdf).
6+
Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low resolution (LR) images. The difference between them is that the video is composed of multiple frames, so the video super-resolution usually uses the information between frames to repair. Here we provide the video super-resolution model [EDVR](https://arxiv.org/pdf/1905.02716.pdf).[BasicVSR](https://arxiv.org/pdf/2012.02181.pdf),[IconVSR](https://arxiv.org/pdf/2012.02181.pdf),[BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf).
77

88
[EDVR](https://arxiv.org/pdf/1905.02716.pdf) wins the champions and outperforms the second place by a large margin in all four tracks in the NTIRE19 video restoration and enhancement challenges. The main difficulties of video super-resolution from two aspects: (1) how to align multiple frames given large motions, and (2) how to effectively fuse different frames with diverse motion and blur. First, to handle large motions, EDVR devise a Pyramid, Cascading and Deformable (PCD) alignment module, in which frame alignment is done at the feature level using deformable convolutions in a coarse-to-fine manner. Second, EDVR propose a Temporal and Spatial Attention (TSA) fusion module, in which attention is applied both temporally and spatially, so as to emphasize important features for subsequent restoration.
99

@@ -79,6 +79,7 @@ The metrics are PSNR / SSIM.
7979
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
8080
| BasicVSR_x4 | 31.4325 / 0.8913 |
8181
| IconVSR_x4 | 31.6882 / 0.8950 |
82+
| BasicVSR++_x4 | 32.4018 / 0.9071 |
8283

8384

8485
## 1.4 Model Download
@@ -92,6 +93,7 @@ The metrics are PSNR / SSIM.
9293
| EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
9394
| BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
9495
| IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
96+
| BasicVSR++_x4 | REDS | [BasicVSR++_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams)
9597

9698

9799

@@ -120,3 +122,14 @@ The metrics are PSNR / SSIM.
120122
year = {2021}
121123
}
122124
```
125+
126+
- 3. [BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment](https://arxiv.org/pdf/2104.13371v1.pdf)
127+
128+
```
129+
@article{chan2021basicvsr++,
130+
author = {Chan, Kelvin C.K. and Zhou, Shangchen and Xu, Xiangyu and Loy, Chen Change},
131+
title = {BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment},
132+
booktitle = {arXiv preprint arXiv:2104.13371},
133+
year = {2021}
134+
}
135+
```

docs/zh_CN/tutorials/video_super_resolution.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
## 1.1 原理介绍
55

6-
视频超分源于图像超分,其目的是从一个或多个低分辨率(LR)图像中恢复高分辨率(HR)图像。它们的区别也很明显,由于视频是由多个帧组成的,所以视频超分通常利用帧间的信息来进行修复。这里我们提供视频超分模型[EDVR](https://arxiv.org/pdf/1905.02716.pdf).
6+
视频超分源于图像超分,其目的是从一个或多个低分辨率(LR)图像中恢复高分辨率(HR)图像。它们的区别也很明显,由于视频是由多个帧组成的,所以视频超分通常利用帧间的信息来进行修复。这里我们提供视频超分模型[EDVR](https://arxiv.org/pdf/1905.02716.pdf),[BasicVSR](https://arxiv.org/pdf/2012.02181.pdf),[IconVSR](https://arxiv.org/pdf/2012.02181.pdf),[BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf).
77

88
[EDVR](https://arxiv.org/pdf/1905.02716.pdf)模型在NTIRE19视频恢复和增强挑战赛的四个赛道中都赢得了冠军,并以巨大的优势超过了第二名。视频超分的主要难点在于(1)如何在给定大运动的情况下对齐多个帧;(2)如何有效地融合具有不同运动和模糊的不同帧。首先,为了处理大的运动,EDVR模型设计了一个金字塔级联的可变形(PCD)对齐模块,在该模块中,从粗到精的可变形卷积被使用来进行特征级的帧对齐。其次,EDVR使用了时空注意力(TSA)融合模块,该模块在时间和空间上同时应用注意力机制,以强调后续恢复的重要特征。
99

@@ -75,6 +75,7 @@
7575
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
7676
| BasicVSR_x4 | 31.4325 / 0.8913 |
7777
| IconVSR_x4 | 31.6882 / 0.8950 |
78+
| BasicVSR++_x4 | 32.4018 / 0.9071 |
7879

7980
## 1.4 模型下载
8081
| 模型 | 数据集 | 下载地址 |
@@ -87,6 +88,7 @@
8788
| EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
8889
| BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
8990
| IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
91+
| BasicVSR++_x4 | REDS | [BasicVSR++_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams)
9092

9193

9294

@@ -113,3 +115,13 @@
113115
year = {2021}
114116
}
115117
```
118+
- 3. [BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment](https://arxiv.org/pdf/2104.13371v1.pdf)
119+
120+
```
121+
@article{chan2021basicvsr++,
122+
author = {Chan, Kelvin C.K. and Zhou, Shangchen and Xu, Xiangyu and Loy, Chen Change},
123+
title = {BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment},
124+
booktitle = {arXiv preprint arXiv:2104.13371},
125+
year = {2021}
126+
}
127+
```

ppgan/models/basicvsr_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class BasicVSRModel(BaseSRModel):
2929
3030
Paper: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021
3131
"""
32-
def __init__(self, generator, fix_iter, pixel_criterion=None):
32+
def __init__(self, generator, fix_iter, lr_mult, pixel_criterion=None):
3333
"""Initialize the BasicVSR class.
3434
3535
Args:
@@ -41,6 +41,7 @@ def __init__(self, generator, fix_iter, pixel_criterion=None):
4141
self.fix_iter = fix_iter
4242
self.current_iter = 1
4343
self.flag = True
44+
self.lr_mult = lr_mult
4445
init_basicvsr_weight(self.nets['generator'])
4546

4647
def setup_input(self, input):
@@ -65,7 +66,7 @@ def train_iter(self, optims=None):
6566
for name, param in self.nets['generator'].named_parameters():
6667
param.trainable = True
6768
if 'spynet' in name:
68-
param.optimize_attr['learning_rate'] = 0.125
69+
param.optimize_attr['learning_rate'] = self.lr_mult
6970
self.flag = False
7071
for net in self.nets.values():
7172
net.find_unused_parameters = False

ppgan/models/generators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@
3535
from .iconvsr import IconVSR
3636
from .gpen import GPEN
3737
from .pan import PAN
38+
from .basicvsr_plus_plus import BasicVSRPlusPlus

ppgan/models/generators/basicvsr.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
# Copyright (c) MMEditing Authors.
22

3-
import paddle
4-
53
import numpy as np
64

5+
import paddle
76
import paddle.nn as nn
87
import paddle.nn.functional as F
8+
from paddle.vision.ops import DeformConv2D
99
from ...utils.download import get_path_from_url
1010
from ...modules.init import kaiming_normal_, constant_
11-
1211
from .builder import GENERATORS
1312

1413

@@ -607,3 +606,69 @@ def forward(self, lrs):
607606
outputs[i] = out
608607

609608
return paddle.stack(outputs, axis=1)
609+
610+
611+
class SecondOrderDeformableAlignment(nn.Layer):
612+
"""Second-order deformable alignment module.
613+
Args:
614+
in_channels (int): Same as nn.Conv2d.
615+
out_channels (int): Same as nn.Conv2d.
616+
kernel_size (int or tuple[int]): Same as nn.Conv2d.
617+
stride (int or tuple[int]): Same as nn.Conv2d.
618+
padding (int or tuple[int]): Same as nn.Conv2d.
619+
dilation (int or tuple[int]): Same as nn.Conv2d.
620+
groups (int): Same as nn.Conv2d.
621+
deformable_groups (int).
622+
"""
623+
def __init__(self,
624+
in_channels=128,
625+
out_channels=64,
626+
kernel_size=3,
627+
stride=1,
628+
padding=1,
629+
dilation=1,
630+
groups=1,
631+
deformable_groups=16):
632+
super(SecondOrderDeformableAlignment, self).__init__()
633+
634+
self.conv_offset = nn.Sequential(
635+
nn.Conv2D(3 * out_channels + 4, out_channels, 3, 1, 1),
636+
nn.LeakyReLU(negative_slope=0.1),
637+
nn.Conv2D(out_channels, out_channels, 3, 1, 1),
638+
nn.LeakyReLU(negative_slope=0.1),
639+
nn.Conv2D(out_channels, out_channels, 3, 1, 1),
640+
nn.LeakyReLU(negative_slope=0.1),
641+
nn.Conv2D(out_channels, 27 * deformable_groups, 3, 1, 1),
642+
)
643+
self.dcn = DeformConv2D(in_channels,
644+
out_channels,
645+
kernel_size=kernel_size,
646+
stride=stride,
647+
padding=padding,
648+
dilation=dilation,
649+
deformable_groups=deformable_groups)
650+
self.init_offset()
651+
652+
def init_offset(self):
653+
constant_(self.conv_offset[-1].weight, 0)
654+
constant_(self.conv_offset[-1].bias, 0)
655+
656+
def forward(self, x, extra_feat, flow_1, flow_2):
657+
extra_feat = paddle.concat([extra_feat, flow_1, flow_2], axis=1)
658+
out = self.conv_offset(extra_feat)
659+
o1, o2, mask = paddle.chunk(out, 3, axis=1)
660+
661+
# offset
662+
offset = 10 * paddle.tanh(paddle.concat((o1, o2), axis=1))
663+
offset_1, offset_2 = paddle.chunk(offset, 2, axis=1)
664+
offset_1 = offset_1 + flow_1.flip(1).tile(
665+
[1, offset_1.shape[1] // 2, 1, 1])
666+
offset_2 = offset_2 + flow_2.flip(1).tile(
667+
[1, offset_2.shape[1] // 2, 1, 1])
668+
offset = paddle.concat([offset_1, offset_2], axis=1)
669+
670+
# mask
671+
mask = F.sigmoid(mask)
672+
673+
out = self.dcn(x, offset, mask)
674+
return out

0 commit comments

Comments
 (0)