Skip to content

Commit b91a372

Browse files
committed
ci: add test for double iterating into empty queue bug
1 parent 02311d0 commit b91a372

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import multiprocessing as mp
2+
import os
3+
from queue import Queue
4+
from typing import Iterator
5+
6+
import numpy as np
7+
from lightning import Trainer
8+
from lightning.pytorch.demos.boring_classes import BoringModel
9+
from torch.utils.data import DataLoader, IterableDataset
10+
11+
class QueueDataset(IterableDataset):
12+
def __init__(self, queue: Queue) -> None:
13+
super().__init__()
14+
self.queue = queue
15+
16+
def __iter__(self) -> Iterator:
17+
for _ in range(5):
18+
tensor, _ = self.queue.get(timeout=10)
19+
yield tensor
20+
21+
def create_queue():
22+
q = mp.Queue()
23+
arr = np.random.random([1, 32]).astype(np.float32)
24+
for ind in range(10):
25+
q.put((arr, ind))
26+
return q
27+
28+
def train_model(queue, maxEpochs, ckptPath):
29+
dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None, persistent_workers=True)
30+
trainer = Trainer(max_epochs=maxEpochs, enable_progress_bar=False, devices=1)
31+
trainer.fit(BoringModel(), dataloader)
32+
if os.path.exists(ckptPath):
33+
trainer.fit(BoringModel(), dataloader, ckpt_path=ckptPath)
34+
else:
35+
trainer.fit(BoringModel(), dataloader)
36+
trainer.save_checkpoint(ckptPath)
37+
return trainer
38+
39+
def test_training():
40+
queue = create_queue()
41+
42+
ckpt_path = "model.ckpt"
43+
trainer = train_model(queue, 1, ckpt_path)
44+
assert trainer is not None
45+
46+
assert os.path.exists(ckpt_path), "Checkpoint file wasn't created"
47+
48+
ckpt_size = os.path.getsize(ckpt_path)
49+
assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"
50+
51+
trainer = train_model(queue, 1, ckpt_path)
52+
assert trainer is not None
53+
54+
if __name__ == "__main__":
55+
test_training()

0 commit comments

Comments
 (0)