We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 598ca1b commit 43b22a9Copy full SHA for 43b22a9
ppfleetx/data/sampler/batch_sampler.py
@@ -136,13 +136,12 @@ def __iter__(self):
136
self.batch_size_times_rank_size = self.batch_size * self.nranks
137
138
num_samples = len(self.dataset)
139
- indices = np.arange(num_samples).tolist()
140
- indices += indices[:(self.total_size - len(indices))]
141
- assert len(indices) == self.total_size
142
-
143
batch_indices = []
144
for idx in range(self.consumed_samples, self.total_size):
145
- batch_indices.append(indices[idx])
+ if idx >= num_samples:
+ batch_indices.append(idx - num_samples)
+ else:
+ batch_indices.append(idx)
146
if len(batch_indices) == self.batch_size_times_rank_size:
147
start_idx, end_idx = self.get_start_end_idx()
148
yield batch_indices[start_idx:end_idx]
0 commit comments