Skip to content

Commit dccdd22

Browse files
committed
init
1 parent cf0f2dd commit dccdd22

File tree

9 files changed

+811
-37
lines changed

9 files changed

+811
-37
lines changed

poetry.lock

Lines changed: 33 additions & 37 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ networkx = { version = "^3.0.0", optional = true }
6969
cpflows = { version = "^0.1.2", optional = true }
7070
fastapi = ">=0.80"
7171
pytorch-optimizer = "^2.5.1"
72+
loguru = "^0.7.2"
7273

7374
[tool.poetry.group.dev.dependencies]
7475
pydocstyle = "^6.1.1"

pytorch_forecasting/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
NHiTS,
4444
RecurrentNetwork,
4545
TemporalFusionTransformer,
46+
LSTMModel,
4647
get_rnn,
4748
)
4849
from pytorch_forecasting.utils import (
@@ -68,6 +69,7 @@
6869
"TemporalFusionTransformer",
6970
"NBeats",
7071
"NHiTS",
72+
"LSTMModel",
7173
"Baseline",
7274
"DeepAR",
7375
"BaseModel",

pytorch_forecasting/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn
1616
from pytorch_forecasting.models.rnn import RecurrentNetwork
1717
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
18+
from .lstm import LSTMModel
1819

1920
__all__ = [
2021
"NBeats",
@@ -32,4 +33,5 @@
3233
"GRU",
3334
"MultiEmbedding",
3435
"DecoderMLP",
36+
"LSTMModel",
3537
]
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
__all__ = ["AutoRegressiveBaseModel"]
2+
3+
from loguru import logger
4+
from typing import List, Union, Any, Sequence, Tuple, Dict, Callable
5+
6+
import torch
7+
from torch import Tensor
8+
9+
from pytorch_forecasting.metrics import MultiLoss, DistributionLoss
10+
from pytorch_forecasting.utils import to_list, apply_to_list
11+
from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel as AutoRegressiveBaseModel_
12+
13+
14+
class AutoRegressiveBaseModel(AutoRegressiveBaseModel_): # pylint: disable=abstract-method
15+
"""Basically AutoRegressiveBaseModel from `pytorch_forecasting` but fixed for multi-target. Worked for `LSTM`."""
16+
17+
def output_to_prediction(
18+
self,
19+
normalized_prediction_parameters: torch.Tensor,
20+
target_scale: Union[List[torch.Tensor], torch.Tensor],
21+
n_samples: int = 1,
22+
**kwargs: Any,
23+
) -> Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]:
24+
"""
25+
Convert network output to rescaled and normalized prediction.
26+
Function is typically not called directly but via :py:meth:`~decode_autoregressive`.
27+
Args:
28+
normalized_prediction_parameters (torch.Tensor): network prediction output
29+
target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale to rescale network output
30+
n_samples (int, optional): Number of samples to draw independently. Defaults to 1.
31+
**kwargs: extra arguments for dictionary passed to :py:meth:`~transform_output` method.
32+
Returns:
33+
Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]: tuple of rescaled prediction and
34+
normalized prediction (e.g. for input into next auto-regressive step)
35+
"""
36+
logger.trace(f"normalized_prediction_parameters={normalized_prediction_parameters.size()}")
37+
B = normalized_prediction_parameters.size(0)
38+
D = normalized_prediction_parameters.size(-1)
39+
single_prediction = to_list(normalized_prediction_parameters)[0].ndim == 2
40+
logger.trace(f"single_prediction={single_prediction}")
41+
if single_prediction: # add time dimension as it is expected
42+
normalized_prediction_parameters = apply_to_list(
43+
normalized_prediction_parameters, lambda x: x.unsqueeze(1)
44+
)
45+
# transform into real space
46+
prediction_parameters = self.transform_output(
47+
prediction=normalized_prediction_parameters, target_scale=target_scale, **kwargs
48+
)
49+
logger.trace(
50+
f"prediction_parameters ({len(prediction_parameters)}): {[p.size() for p in prediction_parameters]}"
51+
)
52+
# sample value(s) from distribution and select first sample
53+
if isinstance(self.loss, DistributionLoss) or (
54+
isinstance(self.loss, MultiLoss) and isinstance(self.loss[0], DistributionLoss)
55+
):
56+
if n_samples > 1:
57+
prediction_parameters = apply_to_list(
58+
prediction_parameters, lambda x: x.reshape(int(x.size(0) / n_samples), n_samples, -1)
59+
)
60+
prediction = self.loss.sample(prediction_parameters, 1)
61+
prediction = apply_to_list(prediction, lambda x: x.reshape(x.size(0) * n_samples, 1, -1))
62+
else:
63+
prediction = self.loss.sample(normalized_prediction_parameters, 1)
64+
else:
65+
prediction = prediction_parameters
66+
logger.trace(f"prediction ({len(prediction)}): {[p.size() for p in prediction]}")
67+
# normalize prediction prediction
68+
normalized_prediction = self.output_transformer.transform(prediction, target_scale=target_scale)
69+
if isinstance(normalized_prediction, list):
70+
logger.trace(f"normalized_prediction: {[p.size() for p in normalized_prediction]}")
71+
input_target = normalized_prediction[-1] # torch.cat(normalized_prediction, dim=-1) # dim=-1
72+
else:
73+
logger.trace(f"normalized_prediction: {normalized_prediction.size()}")
74+
input_target = normalized_prediction # set next input target to normalized prediction
75+
logger.trace(f"input_target: {input_target.size()}")
76+
assert input_target.size(0) == B
77+
assert input_target.size(-1) == D, f"{input_target.size()} but D={D}"
78+
# remove time dimension
79+
if single_prediction:
80+
prediction = apply_to_list(prediction, lambda x: x.squeeze(1))
81+
input_target = input_target.squeeze(1)
82+
logger.trace(f"input_target: {input_target.size()}")
83+
return prediction, input_target
84+
85+
def decode_autoregressive(
86+
self,
87+
decode_one: Callable,
88+
first_target: Union[List[torch.Tensor], torch.Tensor],
89+
first_hidden_state: Any,
90+
target_scale: Union[List[torch.Tensor], torch.Tensor],
91+
n_decoder_steps: int,
92+
n_samples: int = 1,
93+
**kwargs: Any,
94+
) -> Union[List[torch.Tensor], torch.Tensor]:
95+
"""
96+
Make predictions in auto-regressive manner. Supports only continuous targets.
97+
Args:
98+
decode_one (Callable): function that takes at least the following arguments:
99+
* ``idx`` (int): index of decoding step (from 0 to n_decoder_steps-1)
100+
* ``lagged_targets`` (List[torch.Tensor]): list of normalized targets.
101+
List is ``idx + 1`` elements long with the most recent entry at the end, i.e. ``previous_target = lagged_targets[-1]`` and in general ``lagged_targets[-lag]``.
102+
* ``hidden_state`` (Any): Current hidden state required for prediction. Keys are variable names. Only lags that are greater than ``idx`` are included.
103+
* additional arguments are not dynamic but can be passed via the ``**kwargs`` argument And returns tuple of (not rescaled) network prediction output and hidden state for next auto-regressive step.
104+
first_target (Union[List[torch.Tensor], torch.Tensor]): first target value to use for decoding
105+
first_hidden_state (Any): first hidden state used for decoding
106+
target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale as in ``x``
107+
n_decoder_steps (int): number of decoding/prediction steps
108+
n_samples (int): number of independent samples to draw from the distribution -
109+
only relevant for multivariate models. Defaults to 1.
110+
**kwargs: additional arguments that are passed to the decode_one function.
111+
Returns:
112+
Union[List[torch.Tensor], torch.Tensor]: re-scaled prediction
113+
"""
114+
# make predictions which are fed into next step
115+
output: List[Union[List[Tensor], Tensor]] = []
116+
current_hidden_state = first_hidden_state
117+
normalized_output = [first_target]
118+
for idx in range(n_decoder_steps):
119+
# get lagged targets
120+
current_target, current_hidden_state = decode_one(
121+
idx, lagged_targets=normalized_output, hidden_state=current_hidden_state, **kwargs
122+
)
123+
assert isinstance(current_target, Tensor)
124+
logger.trace(f"current_target: {current_target.size()}")
125+
# get prediction and its normalized version for the next step
126+
prediction, current_target = self.output_to_prediction(
127+
current_target, target_scale=target_scale, n_samples=n_samples
128+
)
129+
logger.trace(f"current_target: {current_target.size()}")
130+
if isinstance(prediction, Tensor):
131+
logger.trace(f"prediction ({type(prediction)}): {prediction.size()}")
132+
else:
133+
logger.trace(
134+
f"prediction ({type(prediction)}|{len(prediction)}): {[p.size() for p in prediction]}"
135+
)
136+
# save normalized output for lagged targets
137+
normalized_output.append(current_target)
138+
# set output to unnormalized samples, append each target as n_batch_samples x n_random_samples
139+
output.append(prediction)
140+
# Check things before finishing
141+
if isinstance(prediction, Tensor):
142+
logger.trace(f"output ({len(output)}): {[o.size() for o in output]}") # type: ignore
143+
else:
144+
logger.trace(f"output ({len(output)}): {[{len(o)} for o in output]}")
145+
if isinstance(self.hparams.target, str):
146+
# Here, output is List[Tensor]
147+
final_output = torch.stack(output, dim=1) # type: ignore
148+
logger.trace(f"final_output: {final_output.size()}")
149+
return final_output
150+
# For multi-targets: output is List[List[Tensor]]
151+
# final_output_multitarget = [
152+
# torch.stack([out[idx] for out in output], dim=1) for idx in range(len(self.target_positions))
153+
# ]
154+
# self.target_positions is always Tensor([0]), so len() of that is always 1...
155+
final_output_multitarget = torch.stack([out[0] for out in output], dim=1)
156+
if final_output_multitarget.dim() > 3:
157+
final_output_multitarget = final_output_multitarget.squeeze(2)
158+
if isinstance(final_output_multitarget, Tensor):
159+
logger.trace(f"final_output_multitarget: {final_output_multitarget.size()}")
160+
else:
161+
logger.trace(
162+
f"final_output_multitarget ({type(final_output_multitarget)}): {[o.size() for o in final_output_multitarget]}"
163+
)
164+
r = [final_output_multitarget[..., i] for i in range(final_output_multitarget.size(-1))]
165+
return r

0 commit comments

Comments
 (0)