|
| 1 | +import argparse |
| 2 | +import base64 |
| 3 | +import os |
| 4 | +import time |
| 5 | +from typing import Union |
| 6 | + |
| 7 | +import cv2 |
| 8 | +import numpy as np |
| 9 | +import paddle |
| 10 | +import paddle.nn as nn |
| 11 | + |
| 12 | +from .swin2sr import Swin2SR |
| 13 | +from paddlehub.module.module import moduleinfo |
| 14 | +from paddlehub.module.module import runnable |
| 15 | +from paddlehub.module.module import serving |
| 16 | + |
| 17 | + |
| 18 | +def cv2_to_base64(image): |
| 19 | + data = cv2.imencode('.jpg', image)[1] |
| 20 | + return base64.b64encode(data.tobytes()).decode('utf8') |
| 21 | + |
| 22 | + |
| 23 | +def base64_to_cv2(b64str): |
| 24 | + data = base64.b64decode(b64str.encode('utf8')) |
| 25 | + data = np.frombuffer(data, np.uint8) |
| 26 | + data = cv2.imdecode(data, cv2.IMREAD_COLOR) |
| 27 | + return data |
| 28 | + |
| 29 | + |
| 30 | +@moduleinfo( |
| 31 | + name='swin2sr_real_sr_x4', |
| 32 | + version='1.0.0', |
| 33 | + type="CV/image_editing", |
| 34 | + author="", |
| 35 | + author_email="", |
| 36 | + summary="SwinV2 Transformer for Compressed Image Super-Resolution and Restoration.", |
| 37 | +) |
| 38 | +class SwinIRMRealSR(nn.Layer): |
| 39 | + |
| 40 | + def __init__(self): |
| 41 | + super(SwinIRMRealSR, self).__init__() |
| 42 | + self.default_pretrained_model_path = os.path.join(self.directory, |
| 43 | + 'Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pdparams') |
| 44 | + self.swin2sr = Swin2SR(upscale=4, |
| 45 | + in_chans=3, |
| 46 | + img_size=64, |
| 47 | + window_size=8, |
| 48 | + img_range=1., |
| 49 | + depths=[6, 6, 6, 6, 6, 6], |
| 50 | + embed_dim=180, |
| 51 | + num_heads=[6, 6, 6, 6, 6, 6], |
| 52 | + mlp_ratio=2, |
| 53 | + upsampler='nearest+conv', |
| 54 | + resi_connection='1conv') |
| 55 | + state_dict = paddle.load(self.default_pretrained_model_path) |
| 56 | + self.swin2sr.set_state_dict(state_dict) |
| 57 | + self.swin2sr.eval() |
| 58 | + |
| 59 | + def preprocess(self, img: np.ndarray) -> np.ndarray: |
| 60 | + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
| 61 | + img = img.transpose((2, 0, 1)) |
| 62 | + img = img / 255.0 |
| 63 | + return img.astype(np.float32) |
| 64 | + |
| 65 | + def postprocess(self, img: np.ndarray) -> np.ndarray: |
| 66 | + img = img.clip(0, 1) |
| 67 | + img = img * 255.0 |
| 68 | + img = img.transpose((1, 2, 0)) |
| 69 | + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) |
| 70 | + return img.astype(np.uint8) |
| 71 | + |
| 72 | + def real_sr(self, |
| 73 | + image: Union[str, np.ndarray], |
| 74 | + visualization: bool = True, |
| 75 | + output_dir: str = "swin2sr_real_sr_x4_output") -> np.ndarray: |
| 76 | + if isinstance(image, str): |
| 77 | + _, file_name = os.path.split(image) |
| 78 | + save_name, _ = os.path.splitext(file_name) |
| 79 | + save_name = save_name + '_' + str(int(time.time())) + '.jpg' |
| 80 | + image = cv2.imdecode(np.fromfile(image, dtype=np.uint8), cv2.IMREAD_COLOR) |
| 81 | + elif isinstance(image, np.ndarray): |
| 82 | + save_name = str(int(time.time())) + '.jpg' |
| 83 | + image = image |
| 84 | + else: |
| 85 | + raise Exception("image should be a str / np.ndarray") |
| 86 | + |
| 87 | + with paddle.no_grad(): |
| 88 | + img_input = self.preprocess(image) |
| 89 | + img_input = paddle.to_tensor(img_input[None, ...], dtype=paddle.float32) |
| 90 | + |
| 91 | + img_output = self.swin2sr(img_input) |
| 92 | + img_output = img_output.numpy()[0] |
| 93 | + img_output = self.postprocess(img_output) |
| 94 | + |
| 95 | + if visualization: |
| 96 | + if not os.path.isdir(output_dir): |
| 97 | + os.makedirs(output_dir) |
| 98 | + save_path = os.path.join(output_dir, save_name) |
| 99 | + cv2.imwrite(save_path, img_output) |
| 100 | + |
| 101 | + return img_output |
| 102 | + |
| 103 | + @runnable |
| 104 | + def run_cmd(self, argvs): |
| 105 | + """ |
| 106 | + Run as a command. |
| 107 | + """ |
| 108 | + self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name), |
| 109 | + prog='hub run {}'.format(self.name), |
| 110 | + usage='%(prog)s', |
| 111 | + add_help=True) |
| 112 | + self.parser.add_argument('--input_path', type=str, help="Path to image.") |
| 113 | + self.parser.add_argument('--output_dir', |
| 114 | + type=str, |
| 115 | + default='swin2sr_real_sr_x4_output', |
| 116 | + help="The directory to save output images.") |
| 117 | + args = self.parser.parse_args(argvs) |
| 118 | + self.real_sr(image=args.input_path, visualization=True, output_dir=args.output_dir) |
| 119 | + return 'Results are saved in %s' % args.output_dir |
| 120 | + |
| 121 | + @serving |
| 122 | + def serving_method(self, image, **kwargs): |
| 123 | + """ |
| 124 | + Run as a service. |
| 125 | + """ |
| 126 | + image = base64_to_cv2(image) |
| 127 | + img_output = self.real_sr(image=image, **kwargs) |
| 128 | + |
| 129 | + return cv2_to_base64(img_output) |
0 commit comments