Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/weathergen/datasets/data_reader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ def window(self, idx: TIndex) -> DTRange:

return DTRange(t_start_win, t_end_win)

def get_n_steps(self, forecast_step: int) -> int:
    """Return the number of index steps spanned by *forecast_step* windows.

    Computed as floor(t_window_len * forecast_step / t_window_step), i.e. the
    total window length covered by the requested forecast steps, expressed in
    units of the window stride.
    """
    window_len = int(self.t_window_len)
    window_step = int(self.t_window_step)
    return (window_len * forecast_step) // window_step


@dataclass
class ReaderData:
Expand Down
27 changes: 14 additions & 13 deletions src/weathergen/datasets/multi_stream_data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def reset(self):
# value in worker_workset()
self.rng = np.random.default_rng(self.data_loader_rng_seed)

fsm = (
fsm: int = (
self.forecast_steps[min(self.epoch, len(self.forecast_steps) - 1)]
if self.forecast_policy != "random"
else self.forecast_steps.max()
Expand All @@ -255,7 +255,7 @@ def reset(self):
index_range = self.time_window_handler.get_index_range()
idx_end = index_range.end
# native length of datasets, independent of epoch length that has potentially been specified
forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs
forecast_len = self.time_window_handler.get_n_steps(fsm + 1)
idx_end -= forecast_len + self.forecast_offset
assert idx_end > 0, "dataset size too small for forecast range"
self.perms = np.arange(index_range.start, idx_end)
Expand Down Expand Up @@ -309,16 +309,19 @@ def __iter__(self):
# bidx is used to count the #batches that have been emitted
# idx_raw is used to index into the dataset; the decoupling is needed
# since there are empty batches
idx_raw = iter_start
for i, _bidx in enumerate(range(iter_start, iter_end, self.batch_size)):
# forecast_dt needs to be constant per batch (amortized through data parallel training)
forecast_dt = self.perms_forecast_dt[i]

idx_raw = iter_start # start step index
assert (iter_end - iter_start) // self.batch_size == len(self.perms_forecast_dt)
# forecast_dt needs to be constant per batch (amortized through data parallel training)
for forecast_dt in self.perms_forecast_dt: # bidx loop
# use while loop due to the scattered nature of the data in time and to
# ensure batches are not empty
batch = []
while len(batch) < self.batch_size:
idx: TIndex = self.perms[idx_raw % self.perms.shape[0]]
# TODO: is the modulo ever needed? len(self.perms) is most likely longer than
# idx_raw, since it covers all dataset steps (minus a small adjustment),
# whereas iter_end - iter_start is only a subset.
perm_idx = idx_raw % len(self.perms)
idx: TIndex = self.perms[perm_idx]
idx_raw += 1

time_win1 = self.time_window_handler.window(idx)
Expand Down Expand Up @@ -373,12 +376,10 @@ def __iter__(self):
for fstep in range(
self.forecast_offset, self.forecast_offset + forecast_dt + 1
):
step_forecast_dt = (
idx + (self.forecast_delta_hrs * fstep) // self.step_hrs
)
time_win2 = self.time_window_handler.window(step_forecast_dt)
forecast_idx = idx + self.time_window_handler.get_n_steps(fstep)
time_win2 = self.time_window_handler.window(forecast_idx)

rdata = ds.get_target(step_forecast_dt)
rdata = ds.get_target(forecast_idx)

sample_is_empty = rdata.is_empty()
if sample_is_empty:
Expand Down
Loading