Skip to content

Commit 9343b0a

Browse files
committed
[fix]: dataloader of list dataset has length.
1 parent aab487a commit 9343b0a

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

pgl/utils/data/dataloader.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""dataloader
1616
"""
1717
import warnings
18+
import time
1819
import numpy as np
1920
from collections import namedtuple
2021

@@ -104,15 +105,8 @@ def __init__(self,
104105
raise ValueError("num_workers(default: 1) should be larger than 0, " \
105106
"but got [num_workers=%s] < 1." % self.num_workers)
106107

107-
def __len__(self):
108-
if not isinstance(self.dataset, StreamDataset):
109-
return len(self.sampler)
110-
else:
111-
raise "StreamDataset has no length"
112-
113-
def __iter__(self):
114-
# generate an iterable sequence that produces batch data without repetition
115108
if isinstance(self.dataset, StreamDataset): # for stream data
109+
# generate an iterable sequence that produces batch data without repetition
116110
self.sampler = StreamSampler(
117111
self.dataset,
118112
batch_size=self.batch_size,
@@ -124,6 +118,16 @@ def __iter__(self):
124118
drop_last=self.drop_last,
125119
shuffle=self.shuffle)
126120

121+
def __len__(self):
122+
if not isinstance(self.dataset, StreamDataset):
123+
return len(self.sampler)
124+
else:
125+
raise "StreamDataset has no length"
126+
127+
def __iter__(self):
128+
# the random seed stays fixed when using multiprocessing,
129+
# so set the seed explicitly on every iteration
130+
np.random.seed()
127131
if self.num_workers == 1:
128132
r = paddle.reader.buffered(_DataLoaderIter(self, 0), self.buf_size)
129133
else:

pgl/utils/data/sampler.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,13 @@ def __init__(self, dataset, batch_size=1, drop_last=False, shuffle=False):
2626
self.drop_last = drop_last
2727
self.shuffle = shuffle
2828

29-
length = len(self.dataset)
30-
self.perm = np.arange(0, length)
31-
32-
# shuffle once, when the Sampler is created
29+
def __iter__(self):
30+
perm = np.arange(0, len(self.dataset))
3331
if self.shuffle:
34-
seed = int(float(time.time()) * 1000) % 10000007
35-
np.random.seed(seed)
36-
np.random.shuffle(self.perm)
32+
np.random.shuffle(perm)
3733

38-
def __iter__(self):
3934
batch = []
40-
for idx in self.perm:
35+
for idx in perm:
4136
batch.append(idx)
4237
if len(batch) == self.batch_size:
4338
yield batch

0 commit comments

Comments
 (0)