Skip to content

Commit beec7ed

Browse files
authored
add swin2sr_real_sr_x4 (#2085)
1 parent 36ce478 commit beec7ed

File tree

4 files changed

+1418
-0
lines changed

4 files changed

+1418
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# swin2sr_real_sr_x4
2+
3+
|模型名称|swin2sr_real_sr_x4|
4+
| :--- | :---: |
5+
|类别|图像-图像编辑|
6+
|网络|Swin2SR|
7+
|数据集|DIV2K / Flickr2K|
8+
|是否支持Fine-tuning||
9+
|模型大小|68.4MB|
10+
|指标|-|
11+
|最新更新日期|2022-10-25|
12+
13+
14+
## 一、模型基本信息
15+
16+
- ### 应用效果展示
17+
18+
- 网络结构:
19+
<p align="center">
20+
<img src="https://ai-studio-static-online.cdn.bcebos.com/884d4d4472b44bf1879606374ed64a7e8d2fec0bcf034285a5cecfc582e8cd65" hspace='10'/> <br />
21+
</p>
22+
23+
- 样例结果示例:
24+
<p align="center">
25+
<img src="https://ai-studio-static-online.cdn.bcebos.com/c5517af6c3f944c4b281aedc417a4f8c02c0a969d0dd494c9106c4ff2709fc2f" hspace='10'/>
26+
<img src="https://ai-studio-static-online.cdn.bcebos.com/183c5821029f45bbb78d1700ab8297baabba15f82ab4467e88414bbed056ccf0" hspace='10'/>
27+
</p>
28+
29+
- ### 模型介绍
30+
31+
- Swin2SR 是一个基于 Swin Transformer v2 的图像超分辨率模型。swin2sr_real_sr_x4 是基于 Swin2SR 的 4 倍现实图像超分辨率模型。
32+
33+
34+
35+
## 二、安装
36+
37+
- ### 1、环境依赖
38+
39+
- paddlepaddle >= 2.0.0
40+
41+
- paddlehub >= 2.0.0
42+
43+
- ### 2.安装
44+
45+
- ```shell
46+
$ hub install swin2sr_real_sr_x4
47+
```
48+
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
49+
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
50+
51+
## 三、模型API预测
52+
- ### 1、命令行预测
53+
54+
```shell
55+
$ hub run swin2sr_real_sr_x4 \
56+
--input_path "/PATH/TO/IMAGE" \
57+
--output_dir "swin2sr_real_sr_x4_output"
58+
```
59+
60+
- ### 2、预测代码示例
61+
62+
```python
63+
import paddlehub as hub
64+
import cv2
65+
66+
module = hub.Module(name="swin2sr_real_sr_x4")
67+
result = module.real_sr(
68+
image=cv2.imread('/PATH/TO/IMAGE'),
69+
visualization=True,
70+
output_dir='swin2sr_real_sr_x4_output'
71+
)
72+
```
73+
74+
- ### 3、API
75+
76+
```python
77+
def real_sr(
78+
image: Union[str, numpy.ndarray],
79+
visualization: bool = True,
80+
output_dir: str = "swin2sr_real_sr_x4_output"
81+
) -> numpy.ndarray
82+
```
83+
84+
- 超分辨率 API
85+
86+
- **参数**
87+
88+
* image (Union\[str, numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
89+
* visualization (bool): 是否将识别结果保存为图片文件;
90+
* output\_dir (str): 保存处理结果的文件目录。
91+
92+
- **返回**
93+
94+
* res (numpy.ndarray): 图像超分辨率结果 (BGR);
95+
96+
## 四、服务部署
97+
98+
- PaddleHub Serving 可以部署一个图像超分辨率的在线服务。
99+
100+
- ### 第一步:启动PaddleHub Serving
101+
102+
- 运行启动命令:
103+
104+
```shell
105+
$ hub serving start -m swin2sr_real_sr_x4
106+
```
107+
108+
- 这样就完成了一个图像超分辨率服务化API的部署,默认端口号为8866。
109+
110+
- ### 第二步:发送预测请求
111+
112+
- 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
113+
114+
```python
115+
import requests
116+
import json
117+
import base64
118+
119+
import cv2
120+
import numpy as np
121+
122+
def cv2_to_base64(image):
123+
data = cv2.imencode('.jpg', image)[1]
124+
return base64.b64encode(data.tobytes()).decode('utf8')
125+
126+
def base64_to_cv2(b64str):
127+
data = base64.b64decode(b64str.encode('utf8'))
128+
data = np.frombuffer(data, np.uint8)
129+
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
130+
return data
131+
132+
# 发送HTTP请求
133+
org_im = cv2.imread('/PATH/TO/IMAGE')
134+
data = {
135+
'image': cv2_to_base64(org_im)
136+
}
137+
headers = {"Content-type": "application/json"}
138+
url = "http://127.0.0.1:8866/predict/swin2sr_real_sr_x4"
139+
r = requests.post(url=url, headers=headers, data=json.dumps(data))
140+
141+
# 结果转换
142+
results = r.json()['results']
143+
results = base64_to_cv2(results)
144+
145+
# 保存结果
146+
cv2.imwrite('output.jpg', results)
147+
```
148+
149+
## 五、参考资料
150+
151+
* 论文:[Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration](https://arxiv.org/abs/2209.11345)
152+
153+
* 官方实现:[mv-lab/swin2sr](https://github.com/mv-lab/swin2sr/)
154+
155+
## 六、更新历史
156+
157+
* 1.0.0
158+
159+
初始发布
160+
161+
```shell
162+
$ hub install swin2sr_real_sr_x4==1.0.0
163+
```
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import argparse
2+
import base64
3+
import os
4+
import time
5+
from typing import Union
6+
7+
import cv2
8+
import numpy as np
9+
import paddle
10+
import paddle.nn as nn
11+
12+
from .swin2sr import Swin2SR
13+
from paddlehub.module.module import moduleinfo
14+
from paddlehub.module.module import runnable
15+
from paddlehub.module.module import serving
16+
17+
18+
def cv2_to_base64(image):
19+
data = cv2.imencode('.jpg', image)[1]
20+
return base64.b64encode(data.tobytes()).decode('utf8')
21+
22+
23+
def base64_to_cv2(b64str):
24+
data = base64.b64decode(b64str.encode('utf8'))
25+
data = np.frombuffer(data, np.uint8)
26+
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
27+
return data
28+
29+
30+
@moduleinfo(
31+
name='swin2sr_real_sr_x4',
32+
version='1.0.0',
33+
type="CV/image_editing",
34+
author="",
35+
author_email="",
36+
summary="SwinV2 Transformer for Compressed Image Super-Resolution and Restoration.",
37+
)
38+
class SwinIRMRealSR(nn.Layer):
39+
40+
def __init__(self):
41+
super(SwinIRMRealSR, self).__init__()
42+
self.default_pretrained_model_path = os.path.join(self.directory,
43+
'Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pdparams')
44+
self.swin2sr = Swin2SR(upscale=4,
45+
in_chans=3,
46+
img_size=64,
47+
window_size=8,
48+
img_range=1.,
49+
depths=[6, 6, 6, 6, 6, 6],
50+
embed_dim=180,
51+
num_heads=[6, 6, 6, 6, 6, 6],
52+
mlp_ratio=2,
53+
upsampler='nearest+conv',
54+
resi_connection='1conv')
55+
state_dict = paddle.load(self.default_pretrained_model_path)
56+
self.swin2sr.set_state_dict(state_dict)
57+
self.swin2sr.eval()
58+
59+
def preprocess(self, img: np.ndarray) -> np.ndarray:
60+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
61+
img = img.transpose((2, 0, 1))
62+
img = img / 255.0
63+
return img.astype(np.float32)
64+
65+
def postprocess(self, img: np.ndarray) -> np.ndarray:
66+
img = img.clip(0, 1)
67+
img = img * 255.0
68+
img = img.transpose((1, 2, 0))
69+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
70+
return img.astype(np.uint8)
71+
72+
def real_sr(self,
73+
image: Union[str, np.ndarray],
74+
visualization: bool = True,
75+
output_dir: str = "swin2sr_real_sr_x4_output") -> np.ndarray:
76+
if isinstance(image, str):
77+
_, file_name = os.path.split(image)
78+
save_name, _ = os.path.splitext(file_name)
79+
save_name = save_name + '_' + str(int(time.time())) + '.jpg'
80+
image = cv2.imdecode(np.fromfile(image, dtype=np.uint8), cv2.IMREAD_COLOR)
81+
elif isinstance(image, np.ndarray):
82+
save_name = str(int(time.time())) + '.jpg'
83+
image = image
84+
else:
85+
raise Exception("image should be a str / np.ndarray")
86+
87+
with paddle.no_grad():
88+
img_input = self.preprocess(image)
89+
img_input = paddle.to_tensor(img_input[None, ...], dtype=paddle.float32)
90+
91+
img_output = self.swin2sr(img_input)
92+
img_output = img_output.numpy()[0]
93+
img_output = self.postprocess(img_output)
94+
95+
if visualization:
96+
if not os.path.isdir(output_dir):
97+
os.makedirs(output_dir)
98+
save_path = os.path.join(output_dir, save_name)
99+
cv2.imwrite(save_path, img_output)
100+
101+
return img_output
102+
103+
@runnable
104+
def run_cmd(self, argvs):
105+
"""
106+
Run as a command.
107+
"""
108+
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
109+
prog='hub run {}'.format(self.name),
110+
usage='%(prog)s',
111+
add_help=True)
112+
self.parser.add_argument('--input_path', type=str, help="Path to image.")
113+
self.parser.add_argument('--output_dir',
114+
type=str,
115+
default='swin2sr_real_sr_x4_output',
116+
help="The directory to save output images.")
117+
args = self.parser.parse_args(argvs)
118+
self.real_sr(image=args.input_path, visualization=True, output_dir=args.output_dir)
119+
return 'Results are saved in %s' % args.output_dir
120+
121+
@serving
122+
def serving_method(self, image, **kwargs):
123+
"""
124+
Run as a service.
125+
"""
126+
image = base64_to_cv2(image)
127+
img_output = self.real_sr(image=image, **kwargs)
128+
129+
return cv2_to_base64(img_output)

0 commit comments

Comments
 (0)