Skip to content

Commit 64fd484

Browse files
authored
[medicalseg]Fix syncbn (#3109)
* fix_single_card_syncbn_convert
1 parent 57ad510 commit 64fd484

File tree

1 file changed

+6
-2
lines changed
  • contrib/MedicalSeg/medicalseg/cvlibs

1 file changed

+6
-2
lines changed

contrib/MedicalSeg/medicalseg/cvlibs/config.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import yaml
2222

2323
from medicalseg.cvlibs import manager
24-
from medicalseg.utils import logger
24+
from medicalseg.utils import logger, get_sys_env
2525

2626
# todo: check and edit the unnecessary components
2727

@@ -318,9 +318,13 @@ def model(self) -> paddle.nn.Layer:
318318

319319
if not self._model:
320320
self._model = self._load_object(model_cfg)
321-
if paddle.get_device() != 'cpu':
321+
322+
env_info = get_sys_env()
323+
if paddle.get_device() == 'gpu' and env_info['Paddle compiled with cuda'] \
324+
and env_info['GPUs used'] and paddle.distributed.ParallelEnv().nranks > 1:
322325
self._model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
323326
self._model)
327+
logger.info("Convert bn to sync_bn")
324328

325329
return self._model
326330

0 commit comments

Comments
 (0)