Skip to content

Commit 43b22a9

Browse files
authored
update (#1040)
1 parent 598ca1b commit 43b22a9

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

ppfleetx/data/sampler/batch_sampler.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,12 @@ def __iter__(self):
136136
self.batch_size_times_rank_size = self.batch_size * self.nranks
137137

138138
num_samples = len(self.dataset)
139-
indices = np.arange(num_samples).tolist()
140-
indices += indices[:(self.total_size - len(indices))]
141-
assert len(indices) == self.total_size
142-
143139
batch_indices = []
144140
for idx in range(self.consumed_samples, self.total_size):
145-
batch_indices.append(indices[idx])
141+
if idx >= num_samples:
142+
batch_indices.append(idx - num_samples)
143+
else:
144+
batch_indices.append(idx)
146145
if len(batch_indices) == self.batch_size_times_rank_size:
147146
start_idx, end_idx = self.get_start_end_idx()
148147
yield batch_indices[start_idx:end_idx]

0 commit comments

Comments
 (0)