Commit 5fc682a

pre-commit
1 parent: dccdd22 · commit: 5fc682a

File tree

7 files changed: +54 additions, -37 deletions

pytorch_forecasting/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -38,12 +38,12 @@
     BaseModelWithCovariates,
     DecoderMLP,
     DeepAR,
+    LSTMModel,
     MultiEmbedding,
     NBeats,
     NHiTS,
     RecurrentNetwork,
     TemporalFusionTransformer,
-    LSTMModel,
     get_rnn,
 )
 from pytorch_forecasting.utils import (

pytorch_forecasting/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn
 from pytorch_forecasting.models.rnn import RecurrentNetwork
 from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
+
 from .lstm import LSTMModel

 __all__ = [

pytorch_forecasting/models/_base_autoregressive.py

Lines changed: 17 additions & 15 deletions
@@ -1,14 +1,14 @@
 __all__ = ["AutoRegressiveBaseModel"]

-from loguru import logger
-from typing import List, Union, Any, Sequence, Tuple, Dict, Callable
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

+from loguru import logger
 import torch
 from torch import Tensor

-from pytorch_forecasting.metrics import MultiLoss, DistributionLoss
-from pytorch_forecasting.utils import to_list, apply_to_list
+from pytorch_forecasting.metrics import DistributionLoss, MultiLoss
 from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel as AutoRegressiveBaseModel_
+from pytorch_forecasting.utils import apply_to_list, to_list


 class AutoRegressiveBaseModel(AutoRegressiveBaseModel_):  # pylint: disable=abstract-method
@@ -39,9 +39,7 @@ def output_to_prediction(
         single_prediction = to_list(normalized_prediction_parameters)[0].ndim == 2
         logger.trace(f"single_prediction={single_prediction}")
         if single_prediction:  # add time dimension as it is expected
-            normalized_prediction_parameters = apply_to_list(
-                normalized_prediction_parameters, lambda x: x.unsqueeze(1)
-            )
+            normalized_prediction_parameters = apply_to_list(normalized_prediction_parameters, lambda x: x.unsqueeze(1))
         # transform into real space
         prediction_parameters = self.transform_output(
             prediction=normalized_prediction_parameters, target_scale=target_scale, **kwargs
@@ -95,12 +93,17 @@ def decode_autoregressive(
         """
         Make predictions in auto-regressive manner. Supports only continuous targets.
         Args:
-            decode_one (Callable): function that takes at least the following arguments:
+            decode_one (Callable):
+                function that takes at least the following arguments:
                 * ``idx`` (int): index of decoding step (from 0 to n_decoder_steps-1)
                 * ``lagged_targets`` (List[torch.Tensor]): list of normalized targets.
-                  List is ``idx + 1`` elements long with the most recent entry at the end, i.e. ``previous_target = lagged_targets[-1]`` and in general ``lagged_targets[-lag]``.
-                * ``hidden_state`` (Any): Current hidden state required for prediction. Keys are variable names. Only lags that are greater than ``idx`` are included.
-                * additional arguments are not dynamic but can be passed via the ``**kwargs`` argument And returns tuple of (not rescaled) network prediction output and hidden state for next auto-regressive step.
+                  List is ``idx + 1`` elements long with the most recent entry at the end, i.e.
+                  ``previous_target = lagged_targets[-1]`` and in general ``lagged_targets[-lag]``.
+                * ``hidden_state`` (Any): Current hidden state required for prediction. Keys are variable
+                  names. Only lags that are greater than ``idx`` are included.
+                * additional arguments are not dynamic but can be passed via the ``**kwargs`` argument And
+                  returns tuple of (not rescaled) network prediction output and hidden state for next
+                  auto-regressive step.
             first_target (Union[List[torch.Tensor], torch.Tensor]): first target value to use for decoding
             first_hidden_state (Any): first hidden state used for decoding
             target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale as in ``x``
@@ -130,9 +133,7 @@ def decode_autoregressive(
             if isinstance(prediction, Tensor):
                 logger.trace(f"prediction ({type(prediction)}): {prediction.size()}")
             else:
-                logger.trace(
-                    f"prediction ({type(prediction)}|{len(prediction)}): {[p.size() for p in prediction]}"
-                )
+                logger.trace(f"prediction ({type(prediction)}|{len(prediction)}): {[p.size() for p in prediction]}")
             # save normalized output for lagged targets
             normalized_output.append(current_target)
             # set output to unnormalized samples, append each target as n_batch_samples x n_random_samples
@@ -159,7 +160,8 @@ def decode_autoregressive(
             logger.trace(f"final_output_multitarget: {final_output_multitarget.size()}")
         else:
             logger.trace(
-                f"final_output_multitarget ({type(final_output_multitarget)}): {[o.size() for o in final_output_multitarget]}"
+                f"final_output_multitarget ({type(final_output_multitarget)})"
+                f"{[o.size() for o in final_output_multitarget]}"
             )
         r = [final_output_multitarget[..., i] for i in range(final_output_multitarget.size(-1))]
         return r
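To make the documented ``decode_one`` contract concrete, here is a minimal sketch (not part of this commit) of such a callable; the LSTM and projection layers are illustrative assumptions, not the repository's implementation.

import torch
from torch import nn

# illustrative stand-ins for layers that would normally live on the model instance
lstm = nn.LSTM(input_size=1, hidden_size=10, batch_first=True)
output_layer = nn.Linear(10, 1)

def decode_one(idx, lagged_targets, hidden_state, **kwargs):
    # most recent normalized target, reshaped to (batch, time=1, features=1)
    x = lagged_targets[-1].view(-1, 1, 1)
    out, hidden_state = lstm(x, hidden_state)
    # return the (not rescaled) prediction for this step and the hidden state for the next one
    return output_layer(out).squeeze(1), hidden_state

# one decoding step on a batch of 4 zero-initialized targets
prediction, state = decode_one(0, [torch.zeros(4)], None)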

pytorch_forecasting/models/lstm.py

Lines changed: 5 additions & 4 deletions
@@ -1,10 +1,10 @@
 __all__ = ["LSTMModel"]

-from loguru import logger
-from typing import List, Union, Any, Sequence, Tuple, Dict
+from typing import Any, Dict, List, Sequence, Tuple, Union

+from loguru import logger
 import torch
-from torch import nn, Tensor
+from torch import Tensor, nn

 from pytorch_forecasting.metrics import MAE, Metric, MultiLoss
 from pytorch_forecasting.models.nn import LSTM
@@ -40,7 +40,8 @@ def __init__(
             input_size (int, optional):
                 Input size. Defaults to: inferred from `target`.
             loss (Metric):
-                Loss criterion. Can be different for each target in multi-target setting thanks to `MultiLoss`. Defaults to `MAE`.
+                Loss criterion. Can be different for each target in multi-target setting thanks to
+                `MultiLoss`. Defaults to `MAE`.
             **kwargs:
                 See :class:`pytorch_forecasting.models.base_model.AutoRegressiveBaseModel`.
         """

pytorch_forecasting/models/tuning.py

Lines changed: 25 additions & 14 deletions
@@ -4,17 +4,17 @@

 __all__ = ["optimize_hyperparameters"]

-from loguru import logger
 import copy
 import logging
 import os
-from typing import Any, Dict, Tuple, Union, Optional, Callable, Type, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union

 import lightning.pytorch as pl
 from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
 from lightning.pytorch.loggers import TensorBoardLogger
 from lightning.pytorch.tuner import Tuner
 from lightning.pytorch.tuner.lr_finder import _LRFinder
+from loguru import logger
 import numpy as np
 import optuna
 from optuna import Trial
@@ -24,7 +24,7 @@
 from torch import Tensor
 from torch.utils.data import DataLoader

-from pytorch_forecasting import TemporalFusionTransformer, BaseModel
+from pytorch_forecasting import BaseModel, TemporalFusionTransformer
 from pytorch_forecasting.data import TimeSeriesDataSet

 optuna_logger = logging.getLogger("optuna")
@@ -87,7 +87,8 @@ def optimize_hyperparameters(
     **kwargs: Any,
 ) -> optuna.Study:
     """
-    Optimize hyperparameters. Run hyperparameter optimization. Learning rate for is determined with the PyTorch Lightning learning rate finder.
+    Optimize hyperparameters. Run hyperparameter optimization. Learning rate for is determined with the
+    PyTorch Lightning learning rate finder.

     Args:
         train_dataloaders (DataLoader):
@@ -97,30 +98,37 @@ def optimize_hyperparameters(
         model_path (str):
             Folder to which model checkpoints are saved.
         monitor (str):
-            Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config, and reads this metric to score configuration. By default, the lower the better.
+            Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config, and
+            reads this metric to score configuration. By default, the lower the better.
         direction (str):
-            By default, direction is "minimize", meaning that lower values of the specified `monitor` are better. You can change this, e.g. to "maximize".
+            By default, direction is "minimize", meaning that lower values of the specified `monitor` are
+            better. You can change this, e.g. to "maximize".
         max_epochs (int, optional):
             Maximum number of epochs to run training. Defaults to 20.
         n_trials (int, optional):
             Number of hyperparameter trials to run. Defaults to 100.
         timeout (float, optional):
-            Time in seconds after which training is stopped regardless of number of epochs or validation metric. Defaults to 3600*8.0.
+            Time in seconds after which training is stopped regardless of number of epochs or validation
+            metric. Defaults to 3600*8.0.
         input_params (dict, optional):
-            A dictionary, where each `key` contains another dictionary with two keys: `"method"` and `"ranges"`. Example:
+            A dictionary, where each `key` contains another dictionary with two keys: `"method"` and
+            `"ranges"`. Example:
            >>> {"hidden_size": {
            >>>     "method": "suggest_int",
            >>>     "ranges": (16, 265),
            >>> }}
-            The method key has to be a method of the `optuna.Trial` object. The ranges key are the input ranges for the specified method.
+            The method key has to be a method of the `optuna.Trial` object. The ranges key are the input
+            ranges for the specified method.
         input_params_generator (Callable, optional):
-            A function with the following signature: `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]`, returning the parameter values to set up your model for the current trial/run.
+            A function with the following signature: `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]
+            `, returning the parameter values to set up your model for the current trial/run.
             Example:
            >>> def fn(trial: optuna.Trial, param_ranges: Tuple[int, int] = (16, 265)) -> Dict[str, Any]:
            >>>     param = trial.suggest_int("param", *param_ranges, log=True)
            >>>     model_params = {"param": param}
            >>>     return model_params
-            Then, when your model is created (before training it and report the metrics for the current combination of hyperparameters), these dictionary is used as follows:
+            Then, when your model is created (before training it and report the metrics for the current
+            combination of hyperparameters), these dictionary is used as follows:
            >>> model = YourModelClass.from_dataset(
            >>>     train_dataloaders.dataset,
            >>>     log_interval=-1,
@@ -133,7 +141,9 @@ def optimize_hyperparameters(
         use_learning_rate_finder (bool):
             If to use learning rate finder or optimize as part of hyperparameters. Defaults to True.
         trainer_kwargs (Dict[str, Any], optional):
-            Additional arguments to the `PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>` such as `limit_train_batches`. Defaults to {}.
+            Additional arguments to the
+            `PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`
+            such as `limit_train_batches`. Defaults to {}.
         log_dir (str, optional):
             Folder into which to log results for tensorboard. Defaults to "lightning_logs".
         study (optuna.Study, optional):
@@ -153,6 +163,9 @@ def optimize_hyperparameters(
     Returns:
         optuna.Study: optuna study results
     """
+    if generator_params is None:
+        generator_params = {}
+
     assert isinstance(train_dataloaders.dataset, TimeSeriesDataSet) and isinstance(
         val_dataloaders.dataset, TimeSeriesDataSet
     ), "Dataloaders must be built from TimeSeriesDataSet."
@@ -209,8 +222,6 @@ def objective(trial: optuna.Trial) -> float:
             except ValueError as ex:
                 raise ValueError(f"Error while calling {fn} for {key}.") from ex
         else:
-            if generator_params is None:
-                generator_params = {}
             params = input_params_generator(trial, **generator_params)
         kwargs.update(params)
         kwargs["loss"] = copy.deepcopy(loss)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 import os
 import sys

-import pandas as pd
 import numpy as np
+import pandas as pd
 import pytest

 sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../.."))) # isort:skip

tests/test_models/test_tuning.py

Lines changed: 4 additions & 2 deletions
@@ -1,7 +1,9 @@
-import pytest
-import sys, os
+import os
+import sys
 import typing as ty
+
 from loguru import logger
+import pytest

 from pytorch_forecasting import TimeSeriesDataSet
 from pytorch_forecasting.models import LSTMModel
