
Commit d5b4809

Still working on dataloader
1 parent 2fd7362 commit d5b4809

3 files changed (+29 -11 lines)

config/generative_config.yaml

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 files:
-  - "train-00000-of-00645-b66ac786bf6fb553.parquet"
+  - test_data/test.parquet
+  #- "train-00000-of-00645-b66ac786bf6fb553.parquet"
 mlp:
   periodicity: null
   rescale_output: False
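
For reference, the files list is consumed by the PickAPic loader changed below, which reads each path with pd.read_parquet. A quick sanity check for the new test parquet might look like this sketch; the caption, jpg_0, and jpg_1 column names are assumptions taken from the dataloader code, not guaranteed by the config:

import pandas as pd

# Sketch: confirm the test parquet has the columns the dataloader expects.
# The column names below are assumed from PickAPic, not stated in the config.
data = pd.read_parquet("test_data/test.parquet")
assert {"caption", "jpg_0", "jpg_1"}.issubset(data.columns)
print(len(data), "rows:", list(data.columns))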

examples/text_to_image.py

Lines changed: 6 additions & 2 deletions
@@ -7,7 +7,7 @@
 from pytorch_lightning import Trainer
 import matplotlib.pyplot as plt
 from high_order_implicit_representation.networks import GenNet
-from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 from high_order_implicit_representation.rendering import Text2ImageSampler
 from high_order_implicit_representation.single_image_dataset import (
     image_to_dataset,
@@ -41,11 +41,15 @@ def run_implicit_images(cfg: DictConfig):
         filename=full_path[0], batch_size=cfg.batch_size
     )
     lr_monitor = LearningRateMonitor(logging_interval="epoch")
+    checkpoint = ModelCheckpoint(
+        save_top_k=-1,  # Save all checkpoints
+        every_n_train_steps=50000,  # Save a checkpoint every 50000 steps
+    )
     trainer = Trainer(
         max_epochs=cfg.max_epochs,
         devices=cfg.gpus,
         accelerator=cfg.accelerator,
-        callbacks=[lr_monitor],
+        callbacks=[lr_monitor, checkpoint],
     )
     model = GenNet(cfg)
     trainer.fit(model, datamodule=data_module)
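
With save_top_k=-1 the callback keeps every checkpoint it writes rather than pruning to the best k, and every_n_train_steps counts training steps rather than epochs. A standalone sketch of the same wiring; the dirpath and filename arguments here are illustrative additions, not part of the commit:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Sketch of the callback setup above, with assumed output paths.
checkpoint = ModelCheckpoint(
    dirpath="checkpoints/",      # assumed; the commit keeps the default location
    filename="gen-{step}",       # assumed; names each snapshot by global step
    save_top_k=-1,               # keep every checkpoint, never prune
    every_n_train_steps=50000,   # write a checkpoint every 50000 training steps
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")
trainer = Trainer(max_epochs=1, callbacks=[lr_monitor, checkpoint])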

high_order_implicit_representation/single_image_dataset.py

Lines changed: 21 additions & 8 deletions
@@ -333,8 +333,9 @@ def test_dataloader(self) -> DataLoader:
 class PickAPic:
     def __init__(self, files: list[str]):
         self.files = files
+        self._generator = self.data_generator()
 
-    def __call__(self):
+    def data_generator(self):
         for file in self.files:
             data = pd.read_parquet(file)
 
@@ -344,30 +345,42 @@ def __call__(self):
                 img = Image.open(io.BytesIO(jpg_0))
                 arr = np.copy(np.asarray(img))
                 yield caption, torch.from_numpy(arr)
+
                 jpg_1 = row["jpg_1"]
                 img = Image.open(io.BytesIO(jpg_1))
                 arr = np.copy(np.asarray(img))
                 yield caption, torch.from_numpy(arr)
 
+    def __call__(self):
+        return self._generator
 
 class Text2ImageDataset(Dataset):
     def __init__(self, filenames: List[str]):
         super().__init__()
         self.dataset = PickAPic(files=filenames)
         self.sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
         self.generator = self.gen_data()
+        self._length = 0  # int(1e6)
+        self.count = 0
 
     def __len__(self):
-        return int(1e6)
+        return self._length or int(1e12)
 
     def gen_data(self):
 
-        caption, image = next(self.dataset())
-        caption_embedding = self.sentence_model.encode(caption)
-        flattened_image, flattened_position, image = simple_image_to_dataset(image)
-
-        for index, rgb in enumerate(flattened_image):
-            yield caption_embedding, flattened_position[index], rgb
+        for batch in self.dataset():
+            print('batch', batch)
+            caption, image = batch
+            caption_embedding = self.sentence_model.encode(caption)
+            print('next image')
+            flattened_image, flattened_position, image = simple_image_to_dataset(image)
+            if self.count == 0:
+                self._length += len(flattened_image)
+
+            for index, rgb in enumerate(flattened_image):
+                yield caption_embedding, flattened_position[index], rgb
+
+            self.count += 1
 
     def __getitem__(self, idx):
         # I'm totally ignoring the index
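
The PickAPic change is the core of this commit: __init__ now caches a single generator and __call__ returns that same object, so the loop in gen_data resumes the stream where it left off instead of restarting from the first parquet file on every call, while __len__ reports a large placeholder until the first image has been counted. A minimal sketch of that persistent-generator pattern with toy data:

# Sketch of the persistent-generator pattern used by PickAPic, on toy data.
class Stream:
    def __init__(self, items):
        self._generator = self.data_generator(items)

    def data_generator(self, items):
        for item in items:
            yield item

    def __call__(self):
        # Hand back the one cached generator so iteration resumes
        # where it left off rather than restarting.
        return self._generator

stream = Stream([1, 2, 3, 4])
print([next(stream()) for _ in range(2)])  # [1, 2]
print(list(stream()))                      # [3, 4]: same, partly consumed generator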
