Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"scipy >=1.8,<2.0",
"pandas >=1.3.0,<3.0.0",
"scikit-learn >=1.2,<2.0",
"scikit-base <0.13.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -102,7 +103,6 @@ dev = [
"pytest-dotenv>=0.5.2,<1.0.0",
"tensorboard>=2.12.1,<3.0.0",
"pandoc>=2.3,<3.0.0",
"scikit-base",
]

# docs - dependencies for building the documentation
Expand Down
8 changes: 3 additions & 5 deletions pytorch_forecasting/models/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from numpy import iterable
import pandas as pd
import scipy.stats
from skbase.utils.dependencies._dependencies import _check_soft_dependencies
import torch
import torch.nn as nn
from torch.nn.utils import rnn
Expand Down Expand Up @@ -61,10 +62,7 @@
to_list,
)
from pytorch_forecasting.utils._classproperty import classproperty
from pytorch_forecasting.utils._dependencies import (
_check_matplotlib,
_get_installed_packages,
)
from pytorch_forecasting.utils._dependencies import _check_matplotlib

# todo: compile models

Expand Down Expand Up @@ -1355,7 +1353,7 @@ def configure_optimizers(self):
Returns:
Tuple[List]: first entry is list of optimizers and second is list of schedulers
""" # noqa: E501
ptopt_in_env = "pytorch_optimizer" in _get_installed_packages()
ptopt_in_env = _check_soft_dependencies("pytorch_optimizer", severity="none")
# either set a schedule of lrs or find it dynamically
if self.hparams.optimizer_params is None:
optimizer_params = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from lightning.pytorch.tuner import Tuner
import numpy as np
import scipy._lib._util
from skbase.utils.dependencies._dependencies import _check_soft_dependencies
from torch.utils.data import DataLoader

from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.utils._dependencies import _get_installed_packages

optuna_logger = logging.getLogger("optuna")

Expand Down Expand Up @@ -108,9 +108,7 @@ def optimize_hyperparameters(
Returns:
optuna.Study: optuna study results
""" # noqa : E501
pkgs = _get_installed_packages()

if "optuna" not in pkgs or "statsmodels" not in pkgs:
if not _check_soft_dependencies(["optuna", "statsmodels"], severity="none"):
raise ImportError(
"optimize_hyperparameters requires optuna and statsmodels. "
"Please install these packages with `pip install optuna statsmodels`. "
Expand Down
6 changes: 1 addition & 5 deletions pytorch_forecasting/utils/_dependencies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
"""Utilities for managing dependencies."""

from pytorch_forecasting.utils._dependencies._dependencies import (
_check_matplotlib,
_get_installed_packages,
)
from pytorch_forecasting.utils._dependencies._dependencies import _check_matplotlib
from pytorch_forecasting.utils._dependencies._safe_import import _safe_import

__all__ = [
"_get_installed_packages",
"_check_matplotlib",
"_safe_import",
]
47 changes: 5 additions & 42 deletions pytorch_forecasting/utils/_dependencies/_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,9 @@
Copied from sktime/skbase.
"""

from functools import lru_cache
from skbase.utils.dependencies._dependencies import _check_soft_dependencies


@lru_cache
def _get_installed_packages_private():
"""Get a dictionary of installed packages and their versions.

Same as _get_installed_packages, but internal to avoid mutating the lru_cache
by accident.
"""
from importlib.metadata import distributions, version

dists = distributions()
package_names = {
dist.metadata["Name"]
for dist in dists
if dist.metadata and "Name" in dist.metadata
}
package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names}
# developer note:
# we cannot just use distributions naively,
# because the same top level package name may appear *twice*,
# e.g., in a situation where a virtual env overrides a base env,
# such as in deployment environments like databricks.
# the "version" contract ensures we always get the version that corresponds
# to the importable distribution, i.e., the top one in the sys.path.
return package_versions


def _get_installed_packages():
"""Get a dictionary of installed packages and their versions.

Returns
-------
dict : dictionary of installed packages and their versions
keys are PEP 440 compatible package names, values are package versions
MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
"""
return _get_installed_packages_private().copy()
__all__ = ["_check_soft_dependencies", "_check_matplotlib"]


def _check_matplotlib(ref="This feature", raise_error=True):
Expand All @@ -58,12 +22,11 @@ def _check_matplotlib(ref="This feature", raise_error=True):
-------
bool : whether matplotlib is installed
"""
pkgs = _get_installed_packages()

if raise_error and "matplotlib" not in pkgs:
matplotlib_present = _check_soft_dependencies("matplotlib", severity="none")
if raise_error and not matplotlib_present:
raise ImportError(
f"{ref} requires matplotlib."
" Please install matplotlib with `pip install matplotlib`."
)

return "matplotlib" in pkgs
return matplotlib_present
2 changes: 1 addition & 1 deletion pytorch_forecasting/utils/_dependencies/_safe_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import importlib
from unittest.mock import MagicMock

from pytorch_forecasting.utils._dependencies import _get_installed_packages
from skbase.utils.dependencies._dependencies import _get_installed_packages


def _safe_import(import_path, pkg_name=None):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
__author__ = ["jgyasu", "fkiraly"]

from pytorch_forecasting.utils._dependencies import (
_get_installed_packages,
_safe_import,
)


def test_import_present_module():
"""Test importing a dependency that is installed."""
result = _safe_import("pandas")
assert result is not None
assert "pandas" in _get_installed_packages()
from pytorch_forecasting.utils._dependencies import _safe_import


def test_import_missing_module():
Expand Down
2 changes: 1 addition & 1 deletion pytorch_forecasting/utils/_maint/_show_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _get_deps_info(deps=None, source="distributions"):
deps = ["pytorch-forecasting"]

if source == "distributions":
from pytorch_forecasting.utils._dependencies import _get_installed_packages
from skbase.utils.dependencies._dependencies import _get_installed_packages

KEY_ALIAS = {"sklearn": "scikit-learn", "skbase": "scikit-base"}

Expand Down
4 changes: 2 additions & 2 deletions tests/test_models/test_nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
import pytest
from skbase.utils.dependencies import _check_soft_dependencies

from pytorch_forecasting.models import NBeats
from pytorch_forecasting.utils._dependencies import _get_installed_packages


def test_integration(dataloaders_fixed_window_without_covariates, tmp_path):
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_pickle(model):


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
not _check_soft_dependencies("matplotlib", severity="none"),
reason="skip test if required package matplotlib not installed",
)
def test_interpretation(model, dataloaders_fixed_window_without_covariates):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_models/test_nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import numpy as np
import pandas as pd
import pytest
from skbase.utils.dependencies import _check_soft_dependencies

from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
from pytorch_forecasting.metrics import MQF2DistributionLoss, QuantileLoss
from pytorch_forecasting.metrics.distributions import (
ImplicitQuantileNetworkDistributionLoss,
)
from pytorch_forecasting.models import NHiTS
from pytorch_forecasting.utils._dependencies import _get_installed_packages


def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
Expand Down Expand Up @@ -96,7 +96,7 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
"implicit-quantiles",
]

if "cpflows" in _get_installed_packages():
if _check_soft_dependencies("cpflows", severity="none"):
LOADERS += ["multivariate-quantiles"]


Expand Down Expand Up @@ -158,7 +158,7 @@ def test_pickle(model):


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
not _check_soft_dependencies("matplotlib", severity="none"),
reason="skip test if required package matplotlib not installed",
)
def test_interpretation(model, dataloaders_with_covariates):
Expand Down
9 changes: 4 additions & 5 deletions tests/test_models/test_temporal_fusion_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import pandas as pd
import pytest
from skbase.utils.dependencies import _check_soft_dependencies
from test_models.conftest import make_dataloaders
import torch

Expand All @@ -29,7 +30,6 @@
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import (
optimize_hyperparameters,
)
from pytorch_forecasting.utils._dependencies import _get_installed_packages


def test_integration(multiple_dataloaders_with_covariates, tmp_path):
Expand Down Expand Up @@ -71,7 +71,7 @@ def test_distribution_loss(data_with_covariates, tmp_path):


@pytest.mark.skipif(
"cpflows" not in _get_installed_packages(),
not _check_soft_dependencies("cpflows", severity="none"),
reason="Test skipped if required package cpflows not available",
)
def test_mqf2_loss(data_with_covariates, tmp_path):
Expand Down Expand Up @@ -331,7 +331,7 @@ def test_predict_dependency(


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
not _check_soft_dependencies("matplotlib", severity="none"),
reason="skip test if required package matplotlib not installed",
)
def test_actual_vs_predicted_plot(model, dataloaders_with_covariates):
Expand Down Expand Up @@ -424,8 +424,7 @@ def test_prediction_with_dataframe(model, data_with_covariates):
SKIP_HYPEPARAM_TEST = (
sys.platform.startswith("win")
# Test skipped on Windows OS due to issues with ddp, see #1632"
or "optuna" not in _get_installed_packages()
or "statsmodels" not in _get_installed_packages()
or not _check_soft_dependencies(["optuna", "statsmodels"], severity="none")
# Test skipped if required package optuna or statsmodels not available
)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_models/test_tide.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import numpy as np
import pandas as pd
import pytest
from skbase.utils.dependencies import _check_soft_dependencies

from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
from pytorch_forecasting.metrics import SMAPE
from pytorch_forecasting.models import TiDEModel
from pytorch_forecasting.tests._conftest import make_dataloaders
from pytorch_forecasting.utils._dependencies import _get_installed_packages


def _integration(
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_pickle(model):


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
not _check_soft_dependencies("matplotlib", severity="none"),
reason="skip test if required package matplotlib not installed",
)
def test_prediction_visualization(model, dataloaders_with_covariates):
Expand Down
Loading