From d58f90e3e36eac9576ba2b6eca35790c8be60463 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Sat, 20 Sep 2025 10:32:01 +0530 Subject: [PATCH] [ENH] Added metadata Testing for models and data modules for ptf-v2 --- .../models/base/_base_object.py | 53 +++++++++++++++++++ .../models/dlinear/_dlinear_pkg_v2.py | 11 +++- .../models/samformer/_samformer_v2_pkg.py | 11 +++- .../_tft_pkg_v2.py | 7 +++ .../models/tide/_tide_dsipts/_tide_v2_pkg.py | 11 +++- .../models/timexer/_timexer_pkg_v2.py | 11 +++- .../tests/test_all_estimators_v2.py | 15 ++++++ 7 files changed, 115 insertions(+), 4 deletions(-) diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 6aba7802e..17b81c579 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -118,3 +118,56 @@ class _BasePtForecasterV2(_BasePtForecaster_Common): _tags = { "object_type": "forecaster_pytorch_v2", } + + +class _EncoderDecoderConfigBase(_BasePtForecasterV2): + def _check_metadata(self, metadata): + assert isinstance(metadata, dict) + required_keys = [ + "encoder_cat", + "encoder_cont", + "decoder_cat", + "decoder_cont", + "target", + "max_encoder_length", + "min_encoder_length", + "max_prediction_length", + "min_prediction_length", + "static_categorical_features", + "static_continuous_features", + ] + + for key in required_keys: + assert key in metadata, f"Key {key} missing in metadata" + + assert metadata["encoder_cat"] >= 0 + assert metadata["encoder_cont"] >= 0 + assert metadata["decoder_cat"] >= 0 + assert metadata["decoder_cont"] >= 0 + assert metadata["target"] > 0 + + +class _TSlibConfigBase(_BasePtForecasterV2): + def _check_metadata(self, metadata): + assert isinstance(metadata, dict) + required_keys = [ + "feature_names", + "feature_indices", + "n_features", + "context_length", + "prediction_length", + "freq", + "features", + ] + + for key in required_keys: + assert key in metadata, f"Key {key} missing in metadata" + + assert ( + metadata["n_features"] + == len(metadata["feature_names"]) + == len(metadata["feature_indices"]) + ) + assert metadata["context_length"] > 0 + assert metadata["prediction_length"] > 0 + assert metadata["freq"] is not None diff --git a/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py b/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py index bf4fffce5..81ea7349e 100644 --- a/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py +++ b/pytorch_forecasting/models/dlinear/_dlinear_pkg_v2.py @@ -2,7 +2,10 @@ Packages container for DLinear model. """ -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.models.base._base_object import ( + _BasePtForecasterV2, + _TSlibConfigBase, +) class DLinear_pkg_v2(_BasePtForecasterV2): @@ -125,3 +128,9 @@ def get_test_train_params(cls): logging_metrics=[SMAPE()], ), ] + + +class DLinear_pkg_v2_metadata(_TSlibConfigBase): + @classmethod + def _check_metadata_dlinear(self, metadata): + super()._check_metadata(metadata) diff --git a/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py b/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py index 36db9340a..80c517a39 100644 --- a/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py +++ b/pytorch_forecasting/models/samformer/_samformer_v2_pkg.py @@ -2,7 +2,10 @@ Samformer package container. """ -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.models.base._base_object import ( + _BasePtForecasterV2, + _EncoderDecoderConfigBase, +) class Samformer_pkg_v2(_BasePtForecasterV2): @@ -134,3 +137,9 @@ def get_test_train_params(cls): "use_revin": False, }, ] + + +class Samformer_pkg_v2_metadata(_EncoderDecoderConfigBase): + @classmethod + def _check_metadata_samformer(self, metadata): + super()._check_metadata(metadata) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py index 8c95daa6b..a57674bde 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py @@ -1,6 +1,7 @@ """TFT package container.""" from pytorch_forecasting.models.base import _BasePtForecasterV2 +from pytorch_forecasting.models.base._base_object import _EncoderDecoderConfigBase class TFT_pkg_v2(_BasePtForecasterV2): @@ -137,3 +138,9 @@ def get_test_train_params(cls): ), dict(attention_head_size=2), ] + + +class TFT_pkg_v2_metadata(_EncoderDecoderConfigBase): + @classmethod + def _check_metadata_tft(self, metadata): + super()._check_metadata(metadata) diff --git a/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py b/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py index d3cf70454..ed9c50216 100644 --- a/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py +++ b/pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py @@ -1,6 +1,9 @@ """TIDE package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.models.base._base_object import ( + _BasePtForecasterV2, + _EncoderDecoderConfigBase, +) class TIDE_pkg_v2(_BasePtForecasterV2): @@ -138,3 +141,9 @@ def get_test_train_params(cls): loss=MAPE(), ), ] + + +class TIDE_pkg_v2_metadata(_EncoderDecoderConfigBase): + @classmethod + def _check_metadata_tide(self, metadata): + super()._check_metadata(metadata) diff --git a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py index a0e4b8aa7..bd593384e 100644 --- a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py +++ b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py @@ -2,7 +2,10 @@ Metadata container for TimeXer v2. """ -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.models.base._base_object import ( + _BasePtForecasterV2, + _TSlibConfigBase, +) class TimeXer_pkg_v2(_BasePtForecasterV2): @@ -163,3 +166,9 @@ def get_test_train_params(cls): loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]), ), ] + + +class TimeXer_pkg_v2_metadata(_TSlibConfigBase): + @classmethod + def _check_metadata_timexer(self, metadata): + super()._check_metadata(metadata) diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index 9c28c5d0a..2b6151d92 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -135,3 +135,18 @@ def test_pkg_linkage(self, object_pkg, object_class): f"{object_class.__name__}_pkg." ) assert object_pkg.__name__ == object_class.__name__ + "_pkg_v2", msg + + def test_d2_metadata(self, object_pkg, trainer_kwargs): + object_class = object_pkg.get_cls() + dataloaders = object_pkg._get_test_datamodule_from(trainer_kwargs) + data_module = dataloaders.get("data_module") + metadata = data_module.metadata + + model_kwargs = dict(trainer_kwargs) + model_kwargs.pop("data_loader_kwargs", None) + + model_name = object_class.__name__ + + check_method_name = f"_check_metadata_{model_name.lower()}" + if hasattr(object_pkg, check_method_name): + getattr(object_pkg, check_method_name)(metadata)