From 917b21739355b69da33428dd8974805f9be831eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Fri, 25 Jul 2025 22:00:37 +0200 Subject: [PATCH 1/5] skbase dep --- pyproject.toml | 2 +- .../models/base/_base_model.py | 8 ++-- .../temporal_fusion_transformer/tuning.py | 6 +-- .../utils/_dependencies/_dependencies.py | 38 +------------------ .../utils/_dependencies/_safe_import.py | 2 +- .../_dependencies/tests/test_safe_import.py | 7 ++-- .../utils/_maint/_show_versions.py | 2 +- tests/test_models/test_nbeats.py | 4 +- tests/test_models/test_nhits.py | 6 +-- .../test_temporal_fusion_transformer.py | 9 ++--- tests/test_models/test_tide.py | 4 +- 11 files changed, 24 insertions(+), 64 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e65dfd740..89b2dee13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "scipy >=1.8,<2.0", "pandas >=1.3.0,<3.0.0", "scikit-learn >=1.2,<2.0", + "scikit-base <0.13.0", ] [project.optional-dependencies] @@ -102,7 +103,6 @@ dev = [ "pytest-dotenv>=0.5.2,<1.0.0", "tensorboard>=2.12.1,<3.0.0", "pandoc>=2.3,<3.0.0", - "scikit-base", ] # docs - dependencies for building the documentation diff --git a/pytorch_forecasting/models/base/_base_model.py b/pytorch_forecasting/models/base/_base_model.py index 9488f8ac8..338614d40 100644 --- a/pytorch_forecasting/models/base/_base_model.py +++ b/pytorch_forecasting/models/base/_base_model.py @@ -20,6 +20,7 @@ from numpy import iterable import pandas as pd import scipy.stats +from skbase.utils._dependencies import _check_soft_dependencies import torch import torch.nn as nn from torch.nn.utils import rnn @@ -59,10 +60,7 @@ to_list, ) from pytorch_forecasting.utils._classproperty import classproperty -from pytorch_forecasting.utils._dependencies import ( - _check_matplotlib, - _get_installed_packages, -) +from pytorch_forecasting.utils._dependencies import _check_matplotlib # todo: compile models @@ -1351,7 +1349,7 @@ def configure_optimizers(self): Returns: Tuple[List]: first entry is list of optimizers and second is list of schedulers """ # noqa: E501 - ptopt_in_env = "pytorch_optimizer" in _get_installed_packages() + ptopt_in_env = _check_soft_dependencies("pytorch_optimizer", severity="none") # either set a schedule of lrs or find it dynamically if self.hparams.optimizer_params is None: optimizer_params = {} diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py index 6acb76203..9f428151e 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py @@ -13,12 +13,12 @@ from lightning.pytorch.tuner import Tuner import numpy as np import scipy._lib._util +from skbase.utils._dependencies import _check_soft_dependencies from torch.utils.data import DataLoader from pytorch_forecasting import TemporalFusionTransformer from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.metrics import QuantileLoss -from pytorch_forecasting.utils._dependencies import _get_installed_packages optuna_logger = logging.getLogger("optuna") @@ -108,9 +108,7 @@ def optimize_hyperparameters( Returns: optuna.Study: optuna study results """ # noqa : E501 - pkgs = _get_installed_packages() - - if "optuna" not in pkgs or "statsmodels" not in pkgs: + if not _check_soft_dependencies(["optuna", "statsmodels"], severity="none"): raise ImportError( "optimize_hyperparameters requires optuna and statsmodels. " "Please install these packages with `pip install optuna statsmodels`. " diff --git a/pytorch_forecasting/utils/_dependencies/_dependencies.py b/pytorch_forecasting/utils/_dependencies/_dependencies.py index a8158e681..fe50b86c4 100644 --- a/pytorch_forecasting/utils/_dependencies/_dependencies.py +++ b/pytorch_forecasting/utils/_dependencies/_dependencies.py @@ -5,39 +5,7 @@ from functools import lru_cache - -@lru_cache -def _get_installed_packages_private(): - """Get a dictionary of installed packages and their versions. - - Same as _get_installed_packages, but internal to avoid mutating the lru_cache - by accident. - """ - from importlib.metadata import distributions, version - - dists = distributions() - package_names = {dist.metadata["Name"] for dist in dists} - package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names} - # developer note: - # we cannot just use distributions naively, - # because the same top level package name may appear *twice*, - # e.g., in a situation where a virtual env overrides a base env, - # such as in deployment environments like databricks. - # the "version" contract ensures we always get the version that corresponds - # to the importable distribution, i.e., the top one in the sys.path. - return package_versions - - -def _get_installed_packages(): - """Get a dictionary of installed packages and their versions. - - Returns - ------- - dict : dictionary of installed packages and their versions - keys are PEP 440 compatible package names, values are package versions - MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3" - """ - return _get_installed_packages_private().copy() +from skbase.utils._dependencies import _check_soft_dependencies def _check_matplotlib(ref="This feature", raise_error=True): @@ -54,9 +22,7 @@ def _check_matplotlib(ref="This feature", raise_error=True): ------- bool : whether matplotlib is installed """ - pkgs = _get_installed_packages() - - if raise_error and "matplotlib" not in pkgs: + if raise_error and not _check_soft_dependencies("matplotlib", severity="none"): raise ImportError( f"{ref} requires matplotlib." " Please install matplotlib with `pip install matplotlib`." diff --git a/pytorch_forecasting/utils/_dependencies/_safe_import.py b/pytorch_forecasting/utils/_dependencies/_safe_import.py index f11313dd4..0c7d80a17 100644 --- a/pytorch_forecasting/utils/_dependencies/_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/_safe_import.py @@ -8,7 +8,7 @@ import importlib from unittest.mock import MagicMock -from pytorch_forecasting.utils._dependencies import _get_installed_packages +from skbase.utils.dependencies._dependencies import _get_installed_packages def _safe_import(import_path, pkg_name=None): diff --git a/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py b/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py index e0e7e7ecb..eff80f29b 100644 --- a/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py @@ -1,9 +1,8 @@ __author__ = ["jgyasu", "fkiraly"] -from pytorch_forecasting.utils._dependencies import ( - _get_installed_packages, - _safe_import, -) +from skbase.utils.dependencies._dependencies import _get_installed_packages + +from pytorch_forecasting.utils._dependencies import _safe_import def test_import_present_module(): diff --git a/pytorch_forecasting/utils/_maint/_show_versions.py b/pytorch_forecasting/utils/_maint/_show_versions.py index 39f4c61bc..b9b31d907 100644 --- a/pytorch_forecasting/utils/_maint/_show_versions.py +++ b/pytorch_forecasting/utils/_maint/_show_versions.py @@ -82,7 +82,7 @@ def _get_deps_info(deps=None, source="distributions"): deps = ["pytorch-forecasting"] if source == "distributions": - from pytorch_forecasting.utils._dependencies import _get_installed_packages + from skbase.utils.dependencies._dependencies import _get_installed_packages KEY_ALIAS = {"sklearn": "scikit-learn", "skbase": "scikit-base"} diff --git a/tests/test_models/test_nbeats.py b/tests/test_models/test_nbeats.py index c3379fbf1..2eb43249e 100644 --- a/tests/test_models/test_nbeats.py +++ b/tests/test_models/test_nbeats.py @@ -5,9 +5,9 @@ from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger import pytest +from skbase.utils.dependencies import _check_soft_dependencies from pytorch_forecasting.models import NBeats -from pytorch_forecasting.utils._dependencies import _get_installed_packages def test_integration(dataloaders_fixed_window_without_covariates, tmp_path): @@ -90,7 +90,7 @@ def test_pickle(model): @pytest.mark.skipif( - "matplotlib" not in _get_installed_packages(), + not _check_soft_dependencies("matplotlib", severity="none"), reason="skip test if required package matplotlib not installed", ) def test_interpretation(model, dataloaders_fixed_window_without_covariates): diff --git a/tests/test_models/test_nhits.py b/tests/test_models/test_nhits.py index a79e7a93f..2260a47f1 100644 --- a/tests/test_models/test_nhits.py +++ b/tests/test_models/test_nhits.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd import pytest +from skbase.utils.dependencies import _check_soft_dependencies from pytorch_forecasting.data.timeseries import TimeSeriesDataSet from pytorch_forecasting.metrics import MQF2DistributionLoss, QuantileLoss @@ -14,7 +15,6 @@ ImplicitQuantileNetworkDistributionLoss, ) from pytorch_forecasting.models import NHiTS -from pytorch_forecasting.utils._dependencies import _get_installed_packages def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs): @@ -96,7 +96,7 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs): "implicit-quantiles", ] -if "cpflows" in _get_installed_packages(): +if _check_soft_dependencies("cpflows", severity="none"): LOADERS += ["multivariate-quantiles"] @@ -158,7 +158,7 @@ def test_pickle(model): @pytest.mark.skipif( - "matplotlib" not in _get_installed_packages(), + not _check_soft_dependencies("matplotlib", severity="none"), reason="skip test if required package matplotlib not installed", ) def test_interpretation(model, dataloaders_with_covariates): diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index f0eab8671..a6b187bfb 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd import pytest +from skbase.utils.dependencies import _check_soft_dependencies import torch from pytorch_forecasting import Baseline, TimeSeriesDataSet @@ -27,7 +28,6 @@ from pytorch_forecasting.models.temporal_fusion_transformer.tuning import ( optimize_hyperparameters, ) -from pytorch_forecasting.utils._dependencies import _get_installed_packages if sys.version.startswith("3.6"): # python 3.6 does not have nullcontext from contextlib import contextmanager @@ -81,7 +81,7 @@ def test_distribution_loss(data_with_covariates, tmp_path): @pytest.mark.skipif( - "cpflows" not in _get_installed_packages(), + not _check_soft_dependencies("cpflows", severity="none"), reason="Test skipped if required package cpflows not available", ) def test_mqf2_loss(data_with_covariates, tmp_path): @@ -341,7 +341,7 @@ def test_predict_dependency( @pytest.mark.skipif( - "matplotlib" not in _get_installed_packages(), + not _check_soft_dependencies("matplotlib", severity="none"), reason="skip test if required package matplotlib not installed", ) def test_actual_vs_predicted_plot(model, dataloaders_with_covariates): @@ -434,8 +434,7 @@ def test_prediction_with_dataframe(model, data_with_covariates): SKIP_HYPEPARAM_TEST = ( sys.platform.startswith("win") # Test skipped on Windows OS due to issues with ddp, see #1632" - or "optuna" not in _get_installed_packages() - or "statsmodels" not in _get_installed_packages() + or not _check_soft_dependencies(["optuna", "statsmodels"], severity="none") # Test skipped if required package optuna or statsmodels not available ) diff --git a/tests/test_models/test_tide.py b/tests/test_models/test_tide.py index 3b73ba380..f4514d92c 100644 --- a/tests/test_models/test_tide.py +++ b/tests/test_models/test_tide.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd import pytest +from skbase.utils.dependencies import _check_soft_dependencies from pytorch_forecasting.data.timeseries import TimeSeriesDataSet from pytorch_forecasting.metrics import SMAPE from pytorch_forecasting.models import TiDEModel from pytorch_forecasting.tests._conftest import make_dataloaders -from pytorch_forecasting.utils._dependencies import _get_installed_packages def _integration( @@ -192,7 +192,7 @@ def test_pickle(model): @pytest.mark.skipif( - "matplotlib" not in _get_installed_packages(), + _check_soft_dependencies("matplotlib", severity="none"), reason="skip test if required package matplotlib not installed", ) def test_prediction_visualization(model, dataloaders_with_covariates): From 14578eb3927f18a14ea65632c4bbad09c457e3de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 26 Jul 2025 21:18:06 +0200 Subject: [PATCH 2/5] Update _dependencies.py --- pytorch_forecasting/utils/_dependencies/_dependencies.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/utils/_dependencies/_dependencies.py b/pytorch_forecasting/utils/_dependencies/_dependencies.py index fe50b86c4..2bf12f4ee 100644 --- a/pytorch_forecasting/utils/_dependencies/_dependencies.py +++ b/pytorch_forecasting/utils/_dependencies/_dependencies.py @@ -22,10 +22,11 @@ def _check_matplotlib(ref="This feature", raise_error=True): ------- bool : whether matplotlib is installed """ - if raise_error and not _check_soft_dependencies("matplotlib", severity="none"): + matplotlib_present = _check_soft_dependencies("matplotlib", severity="none") + if raise_error and not matplotlib_present: raise ImportError( f"{ref} requires matplotlib." " Please install matplotlib with `pip install matplotlib`." ) - return "matplotlib" in pkgs + return matplotlib_present From 98eddb583ff8f173e2f3608e1f53bdfa45d6b9c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 27 Jul 2025 08:27:30 +0200 Subject: [PATCH 3/5] fix imports --- pytorch_forecasting/models/base/_base_model.py | 2 +- .../models/temporal_fusion_transformer/tuning.py | 2 +- pytorch_forecasting/utils/_dependencies/_dependencies.py | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_forecasting/models/base/_base_model.py b/pytorch_forecasting/models/base/_base_model.py index 338614d40..8a1fe68a7 100644 --- a/pytorch_forecasting/models/base/_base_model.py +++ b/pytorch_forecasting/models/base/_base_model.py @@ -20,7 +20,7 @@ from numpy import iterable import pandas as pd import scipy.stats -from skbase.utils._dependencies import _check_soft_dependencies +from skbase.utils.dependencies._dependencies import _check_soft_dependencies import torch import torch.nn as nn from torch.nn.utils import rnn diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py index 9f428151e..e62ca18ac 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py @@ -13,7 +13,7 @@ from lightning.pytorch.tuner import Tuner import numpy as np import scipy._lib._util -from skbase.utils._dependencies import _check_soft_dependencies +from skbase.utils.dependencies._dependencies import _check_soft_dependencies from torch.utils.data import DataLoader from pytorch_forecasting import TemporalFusionTransformer diff --git a/pytorch_forecasting/utils/_dependencies/_dependencies.py b/pytorch_forecasting/utils/_dependencies/_dependencies.py index 2bf12f4ee..1059efa5e 100644 --- a/pytorch_forecasting/utils/_dependencies/_dependencies.py +++ b/pytorch_forecasting/utils/_dependencies/_dependencies.py @@ -3,9 +3,7 @@ Copied from sktime/skbase. """ -from functools import lru_cache - -from skbase.utils._dependencies import _check_soft_dependencies +from skbase.utils.dependencies._dependencies import _check_soft_dependencies def _check_matplotlib(ref="This feature", raise_error=True): From 87bcd5fe939f70501b5b749abf53ed38641a0560 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 27 Jul 2025 15:17:41 +0200 Subject: [PATCH 4/5] remove imports --- pytorch_forecasting/utils/_dependencies/__init__.py | 6 +----- .../utils/_dependencies/tests/test_safe_import.py | 9 --------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/pytorch_forecasting/utils/_dependencies/__init__.py b/pytorch_forecasting/utils/_dependencies/__init__.py index fbf75137b..c198be6c1 100644 --- a/pytorch_forecasting/utils/_dependencies/__init__.py +++ b/pytorch_forecasting/utils/_dependencies/__init__.py @@ -1,13 +1,9 @@ """Utilities for managing dependencies.""" -from pytorch_forecasting.utils._dependencies._dependencies import ( - _check_matplotlib, - _get_installed_packages, -) +from pytorch_forecasting.utils._dependencies._dependencies import _check_matplotlib from pytorch_forecasting.utils._dependencies._safe_import import _safe_import __all__ = [ - "_get_installed_packages", "_check_matplotlib", "_safe_import", ] diff --git a/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py b/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py index eff80f29b..bc98d7c27 100644 --- a/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py @@ -1,17 +1,8 @@ __author__ = ["jgyasu", "fkiraly"] -from skbase.utils.dependencies._dependencies import _get_installed_packages - from pytorch_forecasting.utils._dependencies import _safe_import -def test_import_present_module(): - """Test importing a dependency that is installed.""" - result = _safe_import("pandas") - assert result is not None - assert "pandas" in _get_installed_packages() - - def test_import_missing_module(): """Test importing a dependency that is not installed.""" result = _safe_import("nonexistent_module") From 1c8eff69ff2ef8da4a345e7fcecbadcd37cd3c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 29 Jul 2025 12:44:50 +0200 Subject: [PATCH 5/5] Update test_tide.py --- tests/test_models/test_tide.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models/test_tide.py b/tests/test_models/test_tide.py index f4514d92c..4e7a6eddc 100644 --- a/tests/test_models/test_tide.py +++ b/tests/test_models/test_tide.py @@ -192,7 +192,7 @@ def test_pickle(model): @pytest.mark.skipif( - _check_soft_dependencies("matplotlib", severity="none"), + not _check_soft_dependencies("matplotlib", severity="none"), reason="skip test if required package matplotlib not installed", ) def test_prediction_visualization(model, dataloaders_with_covariates):