Skip to content

Commit 15747a0

Browse files
authored
update dist api (#2444)
1 parent f49036b commit 15747a0

File tree

7 files changed

+16
-21
lines changed

7 files changed

+16
-21
lines changed

ppdet/engine/callbacks.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
import numpy as np
2424

2525
import paddle
26-
from paddle.distributed import ParallelEnv
26+
import paddle.distributed as dist
2727

2828
from ppdet.utils.checkpoint import save_model
2929
from ppdet.optimizer import ModelEMA
@@ -81,7 +81,7 @@ def __init__(self, model):
8181
super(LogPrinter, self).__init__(model)
8282

8383
def on_step_end(self, status):
84-
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
84+
if dist.get_world_size() < 2 or dist.get_rank() == 0:
8585
mode = status['mode']
8686
if mode == 'train':
8787
epoch_id = status['epoch_id']
@@ -129,7 +129,7 @@ def on_step_end(self, status):
129129
logger.info("Eval iter: {}".format(step_id))
130130

131131
def on_epoch_end(self, status):
132-
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
132+
if dist.get_world_size() < 2 or dist.get_rank() == 0:
133133
mode = status['mode']
134134
if mode == 'eval':
135135
sample_num = status['sample_num']
@@ -160,7 +160,7 @@ def on_epoch_end(self, status):
160160
epoch_id = status['epoch_id']
161161
weight = None
162162
save_name = None
163-
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
163+
if dist.get_world_size() < 2 or dist.get_rank() == 0:
164164
if mode == 'train':
165165
end_epoch = self.model.cfg.epoch
166166
if epoch_id % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
@@ -224,7 +224,7 @@ def __init__(self, model):
224224

225225
def on_step_end(self, status):
226226
mode = status['mode']
227-
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
227+
if dist.get_world_size() < 2 or dist.get_rank() == 0:
228228
if mode == 'train':
229229
training_staus = status['training_staus']
230230
for loss_name, loss_value in training_staus.get().items():
@@ -248,7 +248,7 @@ def on_step_end(self, status):
248248

249249
def on_epoch_end(self, status):
250250
mode = status['mode']
251-
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
251+
if dist.get_world_size() < 2 or dist.get_rank() == 0:
252252
if mode == 'eval':
253253
for metric in self.model._metrics:
254254
for key, map_value in metric.get_results().items():

ppdet/engine/env.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222

2323
import paddle
24-
from paddle.distributed import ParallelEnv, fleet
24+
from paddle.distributed import fleet
2525

2626
__all__ = ['init_parallel_env', 'set_random_seed', 'init_fleet_env']
2727

ppdet/engine/trainer.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,8 @@
2424
from PIL import Image
2525

2626
import paddle
27-
from paddle.distributed import ParallelEnv, fleet
27+
import paddle.distributed as dist
28+
from paddle.distributed import fleet
2829
from paddle import amp
2930
from paddle.static import InputSpec
3031

@@ -83,8 +84,8 @@ def __init__(self, cfg, mode='train'):
8384
self.optimizer = create('OptimizerBuilder')(self.lr,
8485
self.model.parameters())
8586

86-
self._nranks = ParallelEnv().nranks
87-
self._local_rank = ParallelEnv().local_rank
87+
self._nranks = dist.get_world_size()
88+
self._local_rank = dist.get_rank()
8889

8990
self.status = {}
9091

ppdet/utils/logger.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
import os
1818
import sys
1919

20-
from paddle.distributed import ParallelEnv
20+
import paddle.distributed as dist
2121

2222
__all__ = ['setup_logger']
2323

@@ -47,7 +47,7 @@ def setup_logger(name="ppdet", output=None):
4747
"[%(asctime)s] %(name)s %(levelname)s: %(message)s",
4848
datefmt="%m/%d %H:%M:%S")
4949
# stdout logging: master only
50-
local_rank = ParallelEnv().local_rank
50+
local_rank = dist.get_rank()
5151
if local_rank == 0:
5252
ch = logging.StreamHandler(stream=sys.stdout)
5353
ch.setLevel(logging.DEBUG)

tools/eval.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,6 @@
2727
warnings.filterwarnings('ignore')
2828

2929
import paddle
30-
from paddle.distributed import ParallelEnv
3130

3231
from ppdet.core.workspace import load_config, merge_config
3332
from ppdet.utils.check import check_gpu, check_version, check_config
@@ -115,8 +114,7 @@ def main():
115114
check_gpu(cfg.use_gpu)
116115
check_version()
117116

118-
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
119-
place = paddle.set_device(place)
117+
place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')
120118

121119
run(FLAGS, cfg)
122120

tools/infer.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,6 @@
2727
import glob
2828

2929
import paddle
30-
from paddle.distributed import ParallelEnv
3130
from ppdet.core.workspace import load_config, merge_config
3231
from ppdet.engine import Trainer
3332
from ppdet.utils.check import check_gpu, check_version, check_config
@@ -140,8 +139,7 @@ def main():
140139
check_gpu(cfg.use_gpu)
141140
check_version()
142141

143-
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
144-
place = paddle.set_device(place)
142+
place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')
145143
run(FLAGS, cfg)
146144

147145

tools/train.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929
import numpy as np
3030

3131
import paddle
32-
from paddle.distributed import ParallelEnv
3332

3433
from ppdet.core.workspace import load_config, merge_config, create
3534
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
@@ -122,8 +121,7 @@ def main():
122121
check.check_gpu(cfg.use_gpu)
123122
check.check_version()
124123

125-
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
126-
place = paddle.set_device(place)
124+
place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')
127125

128126
run(FLAGS, cfg)
129127

0 commit comments

Comments (0)