Skip to content

Commit 0d0389f

Browse files
authored
[Other] Update example codes using download_model api (#486)
* update example codes using download_model api * add resource files to repo * add resource files to repo * fix * reduce unused lib * fix
1 parent b0a30a7 commit 0d0389f

File tree

6 files changed

+63
-12
lines changed

6 files changed

+63
-12
lines changed

examples/vision/detection/paddledetection/python/infer_ppyoloe.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1-
import fastdeploy as fd
21
import cv2
32
import os
43

4+
import fastdeploy as fd
5+
import fastdeploy.utils
6+
57

68
def parse_arguments():
79
import argparse
810
import ast
911
parser = argparse.ArgumentParser()
1012
parser.add_argument(
1113
"--model_dir",
12-
required=True,
14+
default=None,
1315
help="Path of PaddleDetection model directory")
1416
parser.add_argument(
15-
"--image", required=True, help="Path of test image file.")
17+
"--image", default=None, help="Path of test image file.")
1618
parser.add_argument(
1719
"--device",
1820
type=str,
@@ -39,17 +41,26 @@ def build_option(args):
3941

4042
args = parse_arguments()
4143

42-
model_file = os.path.join(args.model_dir, "model.pdmodel")
43-
params_file = os.path.join(args.model_dir, "model.pdiparams")
44-
config_file = os.path.join(args.model_dir, "infer_cfg.yml")
44+
if args.model_dir is None:
45+
model_dir = fd.download_model(name='ppyoloe_crn_l_300e_coco')
46+
else:
47+
model_dir = args.model_dir
48+
49+
model_file = os.path.join(model_dir, "model.pdmodel")
50+
params_file = os.path.join(model_dir, "model.pdiparams")
51+
config_file = os.path.join(model_dir, "infer_cfg.yml")
4552

4653
# 配置runtime,加载模型
4754
runtime_option = build_option(args)
4855
model = fd.vision.detection.PPYOLOE(
4956
model_file, params_file, config_file, runtime_option=runtime_option)
5057

5158
# 预测图片检测结果
52-
im = cv2.imread(args.image)
59+
if args.image is None:
60+
image = fd.utils.get_detection_test_image()
61+
else:
62+
image = args.image
63+
im = cv2.imread(image)
5364
result = model.predict(im.copy())
5465
print(result)
5566

examples/vision/detection/yolor/python/infer.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
import fastdeploy as fd
21
import cv2
32

3+
import fastdeploy as fd
4+
import fastdeploy.utils
5+
46

57
def parse_arguments():
68
import argparse
79
import ast
810
parser = argparse.ArgumentParser()
911
parser.add_argument(
10-
"--model", required=True, help="Path of yolor onnx model.")
12+
"--model", default=None, help="Path of yolor onnx model.")
1113
parser.add_argument(
12-
"--image", required=True, help="Path of test image file.")
14+
"--image", default=None, help="Path of test image file.")
1315
parser.add_argument(
1416
"--device",
1517
type=str,
@@ -39,10 +41,20 @@ def build_option(args):
3941

4042
# 配置runtime,加载模型
4143
runtime_option = build_option(args)
42-
model = fd.vision.detection.YOLOR(args.model, runtime_option=runtime_option)
44+
if args.model is None:
45+
model = fd.download_model(name='YOLOR-W6')
46+
else:
47+
model = args.model
48+
49+
model = fd.vision.detection.YOLOR(model, runtime_option=runtime_option)
4350

4451
# 预测图片检测结果
45-
im = cv2.imread(args.image)
52+
if args.image is None:
53+
image = fd.utils.get_detection_test_image()
54+
else:
55+
image = args.image
56+
57+
im = cv2.imread(image)
4658
result = model.predict(im.copy())
4759
print(result)
4860

python/fastdeploy/download.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,5 +245,6 @@ def download_model(name: str,
245245
except FileExistsError:
246246
pass
247247
print('Successfully download model at path: {}'.format(fullpath))
248+
return fullpath
248249
else:
249250
print('ERROR: Could not find a model named {}'.format(name))

python/fastdeploy/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
from .profile import profile
16+
from .example_resource import get_detection_test_image
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import fastdeploy.download as download
15+
import fastdeploy.utils.hub_env as hubenv
16+
17+
18+
def get_detection_test_image(path=None):
19+
if path is None:
20+
path = hubenv.RESOURCE_HOME
21+
fullpath = download.download(
22+
url='https://bj.bcebos.com/paddlehub/fastdeploy/example/detection_test_image.jpg',
23+
path=path)
24+
return fullpath

python/fastdeploy/utils/hub_env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,4 @@ def _get_sub_home(directory):
5252
HUB_HOME = _get_hub_home()
5353
MODEL_HOME = _get_sub_home('models')
5454
CONF_HOME = _get_sub_home('conf')
55+
RESOURCE_HOME = _get_sub_home('resources')

0 commit comments

Comments
 (0)