diff --git a/src/weathergen/datasets/data_reader_base.py b/src/weathergen/datasets/data_reader_base.py
index 66c36c39d..e9ddcb255 100644
--- a/src/weathergen/datasets/data_reader_base.py
+++ b/src/weathergen/datasets/data_reader_base.py
@@ -188,6 +188,9 @@ def window(self, idx: TIndex) -> DTRange:
 
         return DTRange(t_start_win, t_end_win)
 
+    def get_n_steps(self, forecast_step: int) -> int:
+        return (int(self.t_window_len) * forecast_step) // int(self.t_window_step)
+
 
 @dataclass
 class ReaderData:
diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py
index 83adb8128..9a6c9cd84 100644
--- a/src/weathergen/datasets/multi_stream_data_sampler.py
+++ b/src/weathergen/datasets/multi_stream_data_sampler.py
@@ -243,7 +243,7 @@ def reset(self):
         # value in worker_workset()
         self.rng = np.random.default_rng(self.data_loader_rng_seed)
 
-        fsm = (
+        fsm: int = (
             self.forecast_steps[min(self.epoch, len(self.forecast_steps) - 1)]
             if self.forecast_policy != "random"
             else self.forecast_steps.max()
@@ -255,7 +255,7 @@ def reset(self):
         index_range = self.time_window_handler.get_index_range()
         idx_end = index_range.end
         # native length of datasets, independent of epoch length that has potentially been specified
-        forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs
+        forecast_len = self.time_window_handler.get_n_steps(fsm + 1)
         idx_end -= forecast_len + self.forecast_offset
         assert idx_end > 0, "dataset size too small for forecast range"
         self.perms = np.arange(index_range.start, idx_end)
@@ -309,16 +309,19 @@ def __iter__(self):
         # bidx is used to count the #batches that have been emitted
         # idx_raw is used to index into the dataset; the decoupling is needed
         # since there are empty batches
-        idx_raw = iter_start
-        for i, _bidx in enumerate(range(iter_start, iter_end, self.batch_size)):
-            # forecast_dt needs to be constant per batch (amortized through data parallel training)
-            forecast_dt = self.perms_forecast_dt[i]
-
+        idx_raw = iter_start  # start step index
+        assert (iter_end - iter_start) // self.batch_size == len(self.perms_forecast_dt)
+        # forecast_dt needs to be constant per batch (amortized through data parallel training)
+        for forecast_dt in self.perms_forecast_dt:  # bidx loop
             # use while loop due to the scattered nature of the data in time and to
             # ensure batches are not empty
             batch = []
             while len(batch) < self.batch_size:
-                idx: TIndex = self.perms[idx_raw % self.perms.shape[0]]
+                # TODO: identity? len(self.perms) is most likely longer than idx_raw,
+                # since it covers all dataset steps (minus a small adjustment),
+                # whereas iter_end - iter_start is smaller since it is only a subset
+                perm_idx = idx_raw % len(self.perms)
+                idx: TIndex = self.perms[perm_idx]
                 idx_raw += 1
 
                 time_win1 = self.time_window_handler.window(idx)
@@ -373,12 +376,10 @@ def __iter__(self):
                 for fstep in range(
                     self.forecast_offset, self.forecast_offset + forecast_dt + 1
                 ):
-                    step_forecast_dt = (
-                        idx + (self.forecast_delta_hrs * fstep) // self.step_hrs
-                    )
-                    time_win2 = self.time_window_handler.window(step_forecast_dt)
+                    forecast_idx = idx + self.time_window_handler.get_n_steps(fstep)
+                    time_win2 = self.time_window_handler.window(forecast_idx)
 
-                    rdata = ds.get_target(step_forecast_dt)
+                    rdata = ds.get_target(forecast_idx)
                     sample_is_empty = rdata.is_empty()
 
                     if sample_is_empty:
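
Note for reviewers: a minimal, self-contained sketch of the step arithmetic that the new get_n_steps() helper centralizes. This is illustration only, not part of the patch; the standalone signature and the concrete window values below are assumptions, not taken from the repository.

# Hypothetical standalone version of the helper added in data_reader_base.py:
# number of window-step indices spanned by `forecast_step` forecast windows.
def get_n_steps(t_window_len: int, t_window_step: int, forecast_step: int) -> int:
    return (t_window_len * forecast_step) // t_window_step

# Assumed values: 6 h windows advancing in 6 h steps -> one index per forecast step.
assert get_n_steps(6, 6, 4) == 4
# Assumed values: 12 h windows advancing in 6 h steps -> two indices per forecast step.
assert get_n_steps(12, 6, 4) == 8

This mirrors the inline expressions removed in multi_stream_data_sampler.py, (self.len_hrs * (fsm + 1)) // self.step_hrs and (self.forecast_delta_hrs * fstep) // self.step_hrs, under the assumption that t_window_len stands in for len_hrs / forecast_delta_hrs and t_window_step for step_hrs.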