From 708a2c6dd48e3638cc02ff4ac666f75eb205eccf Mon Sep 17 00:00:00 2001 From: shiyutang <1574572981@qq.com> Date: Tue, 28 Mar 2023 19:11:37 +0800 Subject: [PATCH 1/2] fix_single_card_syncbn_convert --- contrib/MedicalSeg/medicalseg/cvlibs/config.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/contrib/MedicalSeg/medicalseg/cvlibs/config.py b/contrib/MedicalSeg/medicalseg/cvlibs/config.py index d80387e41b..85068f26c0 100644 --- a/contrib/MedicalSeg/medicalseg/cvlibs/config.py +++ b/contrib/MedicalSeg/medicalseg/cvlibs/config.py @@ -21,7 +21,7 @@ import yaml from medicalseg.cvlibs import manager -from medicalseg.utils import logger +from medicalseg.utils import logger, get_sys_env # todo: check and edit the unnecessary components @@ -318,7 +318,11 @@ def model(self) -> paddle.nn.Layer: if not self._model: self._model = self._load_object(model_cfg) - if paddle.get_device() != 'cpu': + + env_info = get_sys_env() + if paddle.get_device() == 'gpu' and env_info['Paddle compiled with cuda'] \ + and env_info['GPUs used'] and paddle.distributed.ParallelEnv().nranks > 1: + self._model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm( self._model) From f4b03e5a20e6dd2128575832f1cb2c3e23246196 Mon Sep 17 00:00:00 2001 From: shiyutang <1574572981@qq.com> Date: Tue, 28 Mar 2023 19:15:01 +0800 Subject: [PATCH 2/2] fix_single_card_syncbn_convert --- contrib/MedicalSeg/medicalseg/cvlibs/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contrib/MedicalSeg/medicalseg/cvlibs/config.py b/contrib/MedicalSeg/medicalseg/cvlibs/config.py index 85068f26c0..40221470a1 100644 --- a/contrib/MedicalSeg/medicalseg/cvlibs/config.py +++ b/contrib/MedicalSeg/medicalseg/cvlibs/config.py @@ -322,9 +322,9 @@ def model(self) -> paddle.nn.Layer: env_info = get_sys_env() if paddle.get_device() == 'gpu' and env_info['Paddle compiled with cuda'] \ and env_info['GPUs used'] and paddle.distributed.ParallelEnv().nranks > 1: - self._model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm( self._model) + logger.info("Convert bn to sync_bn") return self._model