47
47
BASE_DIR = os .path .expanduser ("~/.paddleocr/" )
48
48
49
49
DEFAULT_OCR_MODEL_VERSION = 'PP-OCR'
50
+ SUPPORT_OCR_MODEL_VERSION = ['PP-OCR' , 'PP-OCRv2' ]
50
51
DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE'
52
+ SUPPORT_STRUCTURE_MODEL_VERSION = ['STRUCTURE' ]
51
53
MODEL_URLS = {
52
54
'OCR' : {
53
55
'PP-OCRv2' : {
@@ -190,6 +192,7 @@ def parse_args(mMain=True):
190
192
parser .add_argument (
191
193
"--ocr_version" ,
192
194
type = str ,
195
+ choices = SUPPORT_OCR_MODEL_VERSION ,
193
196
default = 'PP-OCRv2' ,
194
197
help = 'OCR Model version, the current model support list is as follows: '
195
198
'1. PP-OCRv2 Support Chinese detection and recognition model. '
@@ -198,6 +201,7 @@ def parse_args(mMain=True):
198
201
parser .add_argument (
199
202
"--structure_version" ,
200
203
type = str ,
204
+ choices = SUPPORT_STRUCTURE_MODEL_VERSION ,
201
205
default = 'STRUCTURE' ,
202
206
help = 'Model version, the current model support list is as follows:'
203
207
' 1. STRUCTURE Support en table structure model.' )
@@ -257,26 +261,20 @@ def get_model_config(type, version, model_type, lang):
257
261
DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION
258
262
else :
259
263
raise NotImplementedError
264
+
260
265
model_urls = MODEL_URLS [type ]
261
266
if version not in model_urls :
262
- logger .warning ('version {} not in {}, auto switch to version {}' .format (
263
- version , model_urls .keys (), DEFAULT_MODEL_VERSION ))
264
267
version = DEFAULT_MODEL_VERSION
265
268
if model_type not in model_urls [version ]:
266
269
if model_type in model_urls [DEFAULT_MODEL_VERSION ]:
267
- logger .warning (
268
- 'version {} not support {} models, auto switch to version {}' .
269
- format (version , model_type , DEFAULT_MODEL_VERSION ))
270
270
version = DEFAULT_MODEL_VERSION
271
271
else :
272
272
logger .error ('{} models is not support, we only support {}' .format (
273
273
model_type , model_urls [DEFAULT_MODEL_VERSION ].keys ()))
274
274
sys .exit (- 1 )
275
+
275
276
if lang not in model_urls [version ][model_type ]:
276
277
if lang in model_urls [DEFAULT_MODEL_VERSION ][model_type ]:
277
- logger .warning (
278
- 'lang {} is not support in {}, auto switch to version {}' .
279
- format (lang , version , DEFAULT_MODEL_VERSION ))
280
278
version = DEFAULT_MODEL_VERSION
281
279
else :
282
280
logger .error (
@@ -296,6 +294,8 @@ def __init__(self, **kwargs):
296
294
"""
297
295
params = parse_args (mMain = False )
298
296
params .__dict__ .update (** kwargs )
297
+ assert params .ocr_version in SUPPORT_OCR_MODEL_VERSION , "ocr_version must in {}, but get {}" .format (
298
+ SUPPORT_OCR_MODEL_VERSION , params .ocr_version )
299
299
params .use_gpu = check_gpu (params .use_gpu )
300
300
301
301
if not params .show_log :
@@ -398,6 +398,8 @@ class PPStructure(OCRSystem):
398
398
def __init__ (self , ** kwargs ):
399
399
params = parse_args (mMain = False )
400
400
params .__dict__ .update (** kwargs )
401
+ assert params .structure_version in SUPPORT_STRUCTURE_MODEL_VERSION , "structure_version must in {}, but get {}" .format (
402
+ SUPPORT_STRUCTURE_MODEL_VERSION , params .structure_version )
401
403
params .use_gpu = check_gpu (params .use_gpu )
402
404
403
405
if not params .show_log :
0 commit comments