Skip to content

Commit 2b72192

Browse files
authored
support to pass batch_size when create model (#2901)
1 parent d521203 commit 2b72192

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

paddlex/inference/models_new/base/predictor/basic_predictor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
model_dir: str,
4040
config: Dict[str, Any] = None,
4141
device: str = None,
42+
batch_size: int = 1,
4243
pp_option: PaddlePredictorOption = None,
4344
) -> None:
4445
"""Initializes the BasicPredictor.
@@ -47,6 +48,7 @@ def __init__(
4748
model_dir (str): The directory where the model files are stored.
4849
config (Dict[str, Any], optional): The configuration dictionary. Defaults to None.
4950
device (str, optional): The device to run the inference engine on. Defaults to None.
51+
batch_size (int, optional): The batch size to predict. Defaults to 1.
5052
pp_option (PaddlePredictorOption, optional): The inference engine options. Defaults to None.
5153
"""
5254
super().__init__(model_dir=model_dir, config=config)
@@ -63,6 +65,8 @@ def __init__(
6365
if trt_dynamic_shapes:
6466
pp_option.trt_dynamic_shapes = trt_dynamic_shapes
6567
self.pp_option = pp_option
68+
self.pp_option.batch_size = batch_size
69+
self.batch_sampler.batch_size = batch_size
6670

6771
logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
6872
self.benchmark = benchmark

0 commit comments

Comments
 (0)