Skip to content

Commit c0aed4f

Browse files
authored
fix lapstyle runtime (#502)
1 parent 22208bf commit c0aed4f

File tree

4 files changed

+25
-33
lines changed

4 files changed

+25
-33
lines changed

applications/tools/lapstyle.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
if __name__ == "__main__":
1010
parser = argparse.ArgumentParser()
11-
parser.add_argument("--content_img", type=str, help="path to content image")
11+
parser.add_argument("--content_img_path",
12+
type=str,
13+
help="path to content image")
1214

1315
parser.add_argument("--output_path",
1416
type=str,
@@ -31,7 +33,7 @@
3133
parser.add_argument("--style_image_path",
3234
type=str,
3335
default=None,
34-
help="if weight_path is not None, path to style image")
36+
help="path to style image")
3537

3638
parser.add_argument("--cpu",
3739
dest="cpu",
@@ -45,6 +47,5 @@
4547

4648
predictor = LapStylePredictor(output=args.output_path,
4749
style=args.style,
48-
weight_path=args.weight_path,
49-
style_image_path=args.style_image_path)
50-
predictor.run(args.content_img)
50+
weight_path=args.weight_path)
51+
predictor.run(args.content_img_path, args.style_image_path)

docs/en_US/tutorials/lap_style.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,19 @@ Artistic style transfer aims at migrating the style from an example image to a c
1414

1515

1616
## 2 Quick experience
17+
Here four style images:
18+
| [StarryNew](https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png) | [Stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | [Ocean](https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png) | [Circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg)|
19+
1720
```
18-
python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG}
21+
python applications/tools/lapstyle.py --content_img_path ${PATH_OF_CONTENT_IMG} --style_image_path ${PATH_OF_STYLE_IMG}
1922
```
2023
### Parameters
2124

22-
- `--content_img (str)`: path to content image.
25+
- `--content_img_path (str)`: path to content image.
26+
- `--style_image_path (str)`: path to style image.
2327
- `--output_path (str)`: path to output image dir, default value:`output_dir`.
2428
- `--weight_path (str)`: path to model weight path, if `weight_path` is `None`, the pre-training model will be downloaded automatically, default value:`None`.
2529
- `--style (str)`: style of output image, if `weight_path` is `None`, `style` can be chosen in `starrynew`, `circuit`, `ocean` and `stars`, default value:`starrynew`.
26-
- `--style_image_path (str)`: path to style image, it need to input when `weight_path` is not `None`, default value:`None`.
2730

2831
## 3 How to use
2932

docs/zh_CN/tutorials/lap_style.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,21 @@ PaddleGAN为大家提供了四种不同艺术风格的预训练模型,风格
2323
| :----------------------------------------------------------: | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
2424
| <img src='https://user-images.githubusercontent.com/48054808/130388598-1e2b27e7-be66-49df-84d5-57b4dc7730d6.png' width='300'/> | <img src='https://user-images.githubusercontent.com/48054808/130388606-78a3a682-2ae4-4753-a07c-671a46930de8.png' width='300'/> | <img src='https://user-images.githubusercontent.com/48054808/130388615-b04197b3-2fdf-4494-ad17-490afe0fd1cd.png' width='300'/> | <img src='https://user-images.githubusercontent.com/48054808/130388623-2eec0cca-fee1-47f0-8398-cae0171aa7a5.png' width='300'/> | <img src='https://user-images.githubusercontent.com/48054808/130388624-f27d0712-ba71-42b2-ada4-44bf60e36512.png' width='300'/> |
2525

26+
4个风格图像下载地址如下:
27+
| [StarryNew](https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png) | [Stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | [Ocean](https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png) | [Circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg)|
28+
2629
只需运行下面的代码即可迁移至指定风格:
2730

2831
```
29-
python applications/tools/lapstyle.py --content_img ${PATH_OF_CONTENT_IMG}
32+
python applications/tools/lapstyle.py --content_img_path ${PATH_OF_CONTENT_IMG} --style_image_path ${PATH_OF_STYLE_IMG}
3033
```
3134
### **参数**
3235

33-
- `--content_img (str)`: 输入的内容图像路径。
36+
- `--content_img_path (str)`: 输入的内容图像路径。
37+
- `--style_image_path (str)`: 输入的风格图像路径。
3438
- `--output_path (str)`: 输出的图像路径,默认为`output_dir`
3539
- `--weight_path (str)`: 模型权重路径,设置`None`时会自行下载预训练模型,默认为`None`
3640
- `--style (str)`: 生成图像风格,当`weight_path``None`时,可以在`starrynew`, `circuit`, `ocean``stars`中选择,默认为`starrynew`
37-
- `--style_image_path (str)`: 输入的风格图像路径,当`weight_path`不为`None`时需要输入,默认为`None`
3841

3942
## 3. 模型训练
4043

ppgan/apps/lapstyle_predictor.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,9 @@
2828
from .base_predictor import BasePredictor
2929

3030
LapStyle_circuit_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_circuit.pdparams'
31-
circuit_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg'
3231
LapStyle_ocean_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_ocean.pdparams'
33-
ocean_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png'
3432
LapStyle_starrynew_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_starrynew.pdparams'
35-
starrynew_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png'
3633
LapStyle_stars_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_stars.pdparams'
37-
stars_IMG_URL = 'https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png'
3834

3935

4036
def img(img):
@@ -117,8 +113,7 @@ class LapStylePredictor(BasePredictor):
117113
def __init__(self,
118114
output='output_dir',
119115
style='starrynew',
120-
weight_path=None,
121-
style_image_path=None):
116+
weight_path=None):
122117
self.input = input
123118
self.output = os.path.join(output, 'LapStyle')
124119
if not os.path.exists(self.output):
@@ -127,31 +122,18 @@ def __init__(self,
127122
self.net_dec = DecoderNet()
128123
self.net_rev = RevisionNet()
129124
self.net_rev_2 = RevisionNet()
125+
130126
if weight_path is None:
131-
self.style_image_path = os.path.join(self.output, 'style.png')
132127
if style == 'starrynew':
133128
weight_path = get_path_from_url(LapStyle_starrynew_WEIGHT_URL)
134-
urllib.request.urlretrieve(starrynew_IMG_URL,
135-
filename=self.style_image_path)
136129
elif style == 'circuit':
137130
weight_path = get_path_from_url(LapStyle_circuit_WEIGHT_URL)
138-
urllib.request.urlretrieve(circuit_IMG_URL,
139-
filename=self.style_image_path)
140131
elif style == 'ocean':
141132
weight_path = get_path_from_url(LapStyle_ocean_WEIGHT_URL)
142-
urllib.request.urlretrieve(ocean_IMG_URL,
143-
filename=self.style_image_path)
144133
elif style == 'stars':
145134
weight_path = get_path_from_url(LapStyle_stars_WEIGHT_URL)
146-
urllib.request.urlretrieve(stars_IMG_URL,
147-
filename=self.style_image_path)
148135
else:
149136
raise Exception(f'has not implemented {style}.')
150-
else:
151-
if style_image_path is None:
152-
raise Exception('style_image_path can not be None.')
153-
else:
154-
self.style_image_path = style_image_path
155137
self.net_enc.set_dict(paddle.load(weight_path)['net_enc'])
156138
self.net_enc.eval()
157139
self.net_dec.set_dict(paddle.load(weight_path)['net_dec'])
@@ -161,12 +143,15 @@ def __init__(self,
161143
self.net_rev_2.set_dict(paddle.load(weight_path)['net_rev_2'])
162144
self.net_rev_2.eval()
163145

164-
def run(self, content_img_path):
146+
def run(self, content_img_path, style_image_path):
165147
content_img, style_img, h, w = img_read(content_img_path,
166-
self.style_image_path)
148+
style_image_path)
167149
content_img_visual = tensor2img(content_img, min_max=(0., 1.))
168150
content_img_visual = cv.cvtColor(content_img_visual, cv.COLOR_RGB2BGR)
169151
cv.imwrite(os.path.join(self.output, 'content.png'), content_img_visual)
152+
style_img_visual = tensor2img(style_img, min_max=(0., 1.))
153+
style_img_visual = cv.cvtColor(style_img_visual, cv.COLOR_RGB2BGR)
154+
cv.imwrite(os.path.join(self.output, 'style.png'), style_img_visual)
170155
pyr_ci = make_laplace_pyramid(content_img, 2)
171156
pyr_si = make_laplace_pyramid(style_img, 2)
172157
pyr_ci.append(content_img)

0 commit comments

Comments
 (0)