diff --git a/contrib/MedicalSeg/medicalseg/cvlibs/config.py b/contrib/MedicalSeg/medicalseg/cvlibs/config.py index d80387e41b..40221470a1 100644 --- a/contrib/MedicalSeg/medicalseg/cvlibs/config.py +++ b/contrib/MedicalSeg/medicalseg/cvlibs/config.py @@ -21,7 +21,7 @@ import yaml from medicalseg.cvlibs import manager -from medicalseg.utils import logger +from medicalseg.utils import logger, get_sys_env # todo: check and edit the unnecessary components @@ -318,9 +318,13 @@ def model(self) -> paddle.nn.Layer: if not self._model: self._model = self._load_object(model_cfg) - if paddle.get_device() != 'cpu': + + env_info = get_sys_env() + if paddle.get_device() == 'gpu' and env_info['Paddle compiled with cuda'] \ + and env_info['GPUs used'] and paddle.distributed.ParallelEnv().nranks > 1: self._model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm( self._model) + logger.info("Convert bn to sync_bn") return self._model