File tree Expand file tree Collapse file tree 1 file changed +4
-0
lines changed
paddlex/inference/models_new/base/predictor Expand file tree Collapse file tree 1 file changed +4
-0
lines changed Original file line number Diff line number Diff line change @@ -39,6 +39,7 @@ def __init__(
39
39
model_dir : str ,
40
40
config : Dict [str , Any ] = None ,
41
41
device : str = None ,
42
+ batch_size : int = 1 ,
42
43
pp_option : PaddlePredictorOption = None ,
43
44
) -> None :
44
45
"""Initializes the BasicPredictor.
@@ -47,6 +48,7 @@ def __init__(
47
48
model_dir (str): The directory where the model files are stored.
48
49
config (Dict[str, Any], optional): The configuration dictionary. Defaults to None.
49
50
device (str, optional): The device to run the inference engine on. Defaults to None.
51
+ batch_size (int, optional): The batch size to predict. Defaults to 1.
50
52
pp_option (PaddlePredictorOption, optional): The inference engine options. Defaults to None.
51
53
"""
52
54
super ().__init__ (model_dir = model_dir , config = config )
@@ -63,6 +65,8 @@ def __init__(
63
65
if trt_dynamic_shapes :
64
66
pp_option .trt_dynamic_shapes = trt_dynamic_shapes
65
67
self .pp_option = pp_option
68
+ self .pp_option .batch_size = batch_size
69
+ self .batch_sampler .batch_size = batch_size
66
70
67
71
logging .debug (f"{ self .__class__ .__name__ } : { self .model_dir } " )
68
72
self .benchmark = benchmark
You can’t perform that action at this time.
0 commit comments