Skip to content

Commit c910bf8

Browse files
authored
Merge pull request #5222 from WenmuZhou/update_whl_dygraph
rm model type check
2 parents 5a537bb + c0ce890 commit c910bf8

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

paddleocr.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@
4747
BASE_DIR = os.path.expanduser("~/.paddleocr/")
4848

4949
DEFAULT_OCR_MODEL_VERSION = 'PP-OCR'
50+
SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2']
5051
DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE'
52+
SUPPORT_STRUCTURE_MODEL_VERSION = ['STRUCTURE']
5153
MODEL_URLS = {
5254
'OCR': {
5355
'PP-OCRv2': {
@@ -190,6 +192,7 @@ def parse_args(mMain=True):
190192
parser.add_argument(
191193
"--ocr_version",
192194
type=str,
195+
choices=SUPPORT_OCR_MODEL_VERSION,
193196
default='PP-OCRv2',
194197
help='OCR Model version, the current model support list is as follows: '
195198
'1. PP-OCRv2 Support Chinese detection and recognition model. '
@@ -198,6 +201,7 @@ def parse_args(mMain=True):
198201
parser.add_argument(
199202
"--structure_version",
200203
type=str,
204+
choices=SUPPORT_STRUCTURE_MODEL_VERSION,
201205
default='STRUCTURE',
202206
help='Model version, the current model support list is as follows:'
203207
' 1. STRUCTURE Support en table structure model.')
@@ -257,26 +261,20 @@ def get_model_config(type, version, model_type, lang):
257261
DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION
258262
else:
259263
raise NotImplementedError
264+
260265
model_urls = MODEL_URLS[type]
261266
if version not in model_urls:
262-
logger.warning('version {} not in {}, auto switch to version {}'.format(
263-
version, model_urls.keys(), DEFAULT_MODEL_VERSION))
264267
version = DEFAULT_MODEL_VERSION
265268
if model_type not in model_urls[version]:
266269
if model_type in model_urls[DEFAULT_MODEL_VERSION]:
267-
logger.warning(
268-
'version {} not support {} models, auto switch to version {}'.
269-
format(version, model_type, DEFAULT_MODEL_VERSION))
270270
version = DEFAULT_MODEL_VERSION
271271
else:
272272
logger.error('{} models is not support, we only support {}'.format(
273273
model_type, model_urls[DEFAULT_MODEL_VERSION].keys()))
274274
sys.exit(-1)
275+
275276
if lang not in model_urls[version][model_type]:
276277
if lang in model_urls[DEFAULT_MODEL_VERSION][model_type]:
277-
logger.warning(
278-
'lang {} is not support in {}, auto switch to version {}'.
279-
format(lang, version, DEFAULT_MODEL_VERSION))
280278
version = DEFAULT_MODEL_VERSION
281279
else:
282280
logger.error(
@@ -296,6 +294,8 @@ def __init__(self, **kwargs):
296294
"""
297295
params = parse_args(mMain=False)
298296
params.__dict__.update(**kwargs)
297+
assert params.ocr_version in SUPPORT_OCR_MODEL_VERSION, "ocr_version must in {}, but get {}".format(
298+
SUPPORT_OCR_MODEL_VERSION, params.ocr_version)
299299
params.use_gpu = check_gpu(params.use_gpu)
300300

301301
if not params.show_log:
@@ -398,6 +398,8 @@ class PPStructure(OCRSystem):
398398
def __init__(self, **kwargs):
399399
params = parse_args(mMain=False)
400400
params.__dict__.update(**kwargs)
401+
assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format(
402+
SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version)
401403
params.use_gpu = check_gpu(params.use_gpu)
402404

403405
if not params.show_log:

0 commit comments

Comments
 (0)