
Commit 05172e1: Fix invariant check
Parent: 620d1c0

1 file changed: torchnlp/samplers/distributed_sampler.py (3 additions, 3 deletions)
@@ -23,9 +23,6 @@ def __init__(self, iterable, num_replicas=None, rank=None):
         self.num_replicas = num_replicas
         self.rank = rank
 
-        if self.rank >= self.num_replicas:
-            raise IndexError('`rank` must be smaller than the `num_replicas`.')
-
         if num_replicas is None or rank is None:  # pragma: no cover
             if not torch.distributed.is_initialized():
                 raise RuntimeError('Requires `torch.distributed` to be initialized.')
@@ -34,6 +31,9 @@ def __init__(self, iterable, num_replicas=None, rank=None):
                 torch.distributed.get_world_size() if num_replicas is None else num_replicas)
             self.rank = torch.distributed.get_rank() if rank is None else rank
 
+        if self.rank >= self.num_replicas:
+            raise IndexError('`rank` must be smaller than the `num_replicas`.')
+
     def __iter__(self):
         return iter(
             [e for i, e in enumerate(self.iterable) if (i - self.rank) % self.num_replicas == 0])
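
The change moves the `rank >= num_replicas` invariant below the block that auto-detects both values from `torch.distributed`, so the check runs against the resolved values rather than against arguments that may still be `None`. For context, the sketch below shows how the constructor reads after this commit; the lines that fall outside the diff context (the class declaration, the `self.iterable = iterable` assignment, and the opening of the `self.num_replicas = (...)` assignment) are assumptions reconstructed around the diff, not code copied from the repository.

import torch


class DistributedSampler:
    """Splits `iterable` across `num_replicas` processes; each process keeps
    every `num_replicas`-th element, offset by its `rank`."""

    def __init__(self, iterable, num_replicas=None, rank=None):
        self.iterable = iterable  # assumed: this line is outside the diff context
        self.num_replicas = num_replicas
        self.rank = rank

        if num_replicas is None or rank is None:  # pragma: no cover
            if not torch.distributed.is_initialized():
                raise RuntimeError('Requires `torch.distributed` to be initialized.')

            # assumed: the opening `self.num_replicas = (` is outside the diff context
            self.num_replicas = (
                torch.distributed.get_world_size() if num_replicas is None else num_replicas)
            self.rank = torch.distributed.get_rank() if rank is None else rank

        # Moved by this commit: the invariant now runs after `rank` and
        # `num_replicas` have been resolved, instead of before auto-detection.
        if self.rank >= self.num_replicas:
            raise IndexError('`rank` must be smaller than the `num_replicas`.')

    def __iter__(self):
        return iter(
            [e for i, e in enumerate(self.iterable) if (i - self.rank) % self.num_replicas == 0])


# Hypothetical usage: rank 1 of 4 replicas keeps indices 1, 5, 9 of a 10-element iterable.
sampler = DistributedSampler(range(10), num_replicas=4, rank=1)
assert list(sampler) == [1, 5, 9]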
