__all__ = ["AutoRegressiveBaseModel"]

from loguru import logger
from typing import List, Union, Any, Tuple, Callable

import torch
from torch import Tensor

from pytorch_forecasting.metrics import MultiLoss, DistributionLoss
from pytorch_forecasting.utils import to_list, apply_to_list
from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel as AutoRegressiveBaseModel_


class AutoRegressiveBaseModel(AutoRegressiveBaseModel_):  # pylint: disable=abstract-method
    """AutoRegressiveBaseModel from `pytorch_forecasting`, patched to work with multiple targets. Tested with `LSTM`."""

    def output_to_prediction(
        self,
        normalized_prediction_parameters: torch.Tensor,
        target_scale: Union[List[torch.Tensor], torch.Tensor],
        n_samples: int = 1,
        **kwargs: Any,
    ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]:
        """
        Convert network output to rescaled and normalized prediction.

        Function is typically not called directly but via :py:meth:`~decode_autoregressive`.

        Args:
            normalized_prediction_parameters (torch.Tensor): network prediction output
            target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale to rescale network output
            n_samples (int, optional): number of samples to draw independently. Defaults to 1.
            **kwargs: extra arguments passed to the :py:meth:`~transform_output` method.

        Returns:
            Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]: tuple of rescaled prediction and
                normalized prediction (e.g. for input into next auto-regressive step)
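
        Example:
            A shape-only sketch (``decoder_output`` and ``x`` are illustrative names, not attributes
            of this class): for a single decoding step the network emits ``(batch_size, n_parameters)``;
            this method adds the expected time axis, rescales via :py:meth:`~transform_output`, and also
            returns the re-normalized value that feeds the next step::

                params = decoder_output[:, -1]  # (batch_size, n_parameters)
                prediction, next_input = self.output_to_prediction(
                    params, target_scale=x["target_scale"]
                )
                # ``next_input`` becomes ``lagged_targets[-1]`` of the next decoding step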
| 35 | + """ |
| 36 | + logger.trace(f"normalized_prediction_parameters={normalized_prediction_parameters.size()}") |
| 37 | + B = normalized_prediction_parameters.size(0) |
| 38 | + D = normalized_prediction_parameters.size(-1) |
| 39 | + single_prediction = to_list(normalized_prediction_parameters)[0].ndim == 2 |
| 40 | + logger.trace(f"single_prediction={single_prediction}") |
| 41 | + if single_prediction: # add time dimension as it is expected |
| 42 | + normalized_prediction_parameters = apply_to_list( |
| 43 | + normalized_prediction_parameters, lambda x: x.unsqueeze(1) |
| 44 | + ) |
| 45 | + # transform into real space |
| 46 | + prediction_parameters = self.transform_output( |
| 47 | + prediction=normalized_prediction_parameters, target_scale=target_scale, **kwargs |
| 48 | + ) |
| 49 | + logger.trace( |
| 50 | + f"prediction_parameters ({len(prediction_parameters)}): {[p.size() for p in prediction_parameters]}" |
| 51 | + ) |
        # sample value(s) from distribution and select first sample
        if isinstance(self.loss, DistributionLoss) or (
            isinstance(self.loss, MultiLoss) and isinstance(self.loss[0], DistributionLoss)
        ):
            if n_samples > 1:
                # the batch dimension was expanded to batch_size * n_samples; fold the
                # samples into their own dimension, sample once per group, then unfold
                prediction_parameters = apply_to_list(
                    prediction_parameters, lambda x: x.reshape(int(x.size(0) / n_samples), n_samples, -1)
                )
                prediction = self.loss.sample(prediction_parameters, 1)
                prediction = apply_to_list(prediction, lambda x: x.reshape(x.size(0) * n_samples, 1, -1))
            else:
                prediction = self.loss.sample(normalized_prediction_parameters, 1)
        else:
            prediction = prediction_parameters
| 66 | + logger.trace(f"prediction ({len(prediction)}): {[p.size() for p in prediction]}") |
| 67 | + # normalize prediction prediction |
| 68 | + normalized_prediction = self.output_transformer.transform(prediction, target_scale=target_scale) |
| 69 | + if isinstance(normalized_prediction, list): |
| 70 | + logger.trace(f"normalized_prediction: {[p.size() for p in normalized_prediction]}") |
| 71 | + input_target = normalized_prediction[-1] # torch.cat(normalized_prediction, dim=-1) # dim=-1 |
| 72 | + else: |
| 73 | + logger.trace(f"normalized_prediction: {normalized_prediction.size()}") |
| 74 | + input_target = normalized_prediction # set next input target to normalized prediction |
| 75 | + logger.trace(f"input_target: {input_target.size()}") |
| 76 | + assert input_target.size(0) == B |
| 77 | + assert input_target.size(-1) == D, f"{input_target.size()} but D={D}" |
| 78 | + # remove time dimension |
| 79 | + if single_prediction: |
| 80 | + prediction = apply_to_list(prediction, lambda x: x.squeeze(1)) |
| 81 | + input_target = input_target.squeeze(1) |
| 82 | + logger.trace(f"input_target: {input_target.size()}") |
| 83 | + return prediction, input_target |
| 84 | + |
    def decode_autoregressive(
        self,
        decode_one: Callable,
        first_target: Union[List[torch.Tensor], torch.Tensor],
        first_hidden_state: Any,
        target_scale: Union[List[torch.Tensor], torch.Tensor],
        n_decoder_steps: int,
        n_samples: int = 1,
        **kwargs: Any,
    ) -> Union[List[torch.Tensor], torch.Tensor]:
| 95 | + """ |
| 96 | + Make predictions in auto-regressive manner. Supports only continuous targets. |
| 97 | + Args: |
| 98 | + decode_one (Callable): function that takes at least the following arguments: |
| 99 | + * ``idx`` (int): index of decoding step (from 0 to n_decoder_steps-1) |
| 100 | + * ``lagged_targets`` (List[torch.Tensor]): list of normalized targets. |
| 101 | + List is ``idx + 1`` elements long with the most recent entry at the end, i.e. ``previous_target = lagged_targets[-1]`` and in general ``lagged_targets[-lag]``. |
| 102 | + * ``hidden_state`` (Any): Current hidden state required for prediction. Keys are variable names. Only lags that are greater than ``idx`` are included. |
| 103 | + * additional arguments are not dynamic but can be passed via the ``**kwargs`` argument And returns tuple of (not rescaled) network prediction output and hidden state for next auto-regressive step. |
| 104 | + first_target (Union[List[torch.Tensor], torch.Tensor]): first target value to use for decoding |
| 105 | + first_hidden_state (Any): first hidden state used for decoding |
| 106 | + target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale as in ``x`` |
| 107 | + n_decoder_steps (int): number of decoding/prediction steps |
| 108 | + n_samples (int): number of independent samples to draw from the distribution - |
| 109 | + only relevant for multivariate models. Defaults to 1. |
| 110 | + **kwargs: additional arguments that are passed to the decode_one function. |
| 111 | + Returns: |
| 112 | + Union[List[torch.Tensor], torch.Tensor]: re-scaled prediction |
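
        Example:
            A minimal sketch of a compliant ``decode_one`` closure. ``input_vector``, ``target_pos``,
            ``self.rnn`` and ``self.output_layer`` are assumed attributes of a hypothetical LSTM
            subclass, not part of this base class::

                def decode_one(idx, lagged_targets, hidden_state):
                    x = input_vector[:, [idx]]  # one time step: (batch, 1, n_features)
                    x[:, 0, target_pos] = lagged_targets[-1]  # overwrite with last prediction
                    out, hidden_state = self.rnn(x, hidden_state)
                    return self.output_layer(out)[:, 0], hidden_state

                output = self.decode_autoregressive(
                    decode_one,
                    first_target=input_vector[:, 0, target_pos],
                    first_hidden_state=hidden_state,
                    target_scale=target_scale,
                    n_decoder_steps=n_decoder_steps,
                )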
| 113 | + """ |
        # make predictions which are fed into the next step
        output: List[Union[List[Tensor], Tensor]] = []
        current_hidden_state = first_hidden_state
        normalized_output = [first_target]
        for idx in range(n_decoder_steps):
            # get network output from the lagged targets collected so far
            current_target, current_hidden_state = decode_one(
                idx, lagged_targets=normalized_output, hidden_state=current_hidden_state, **kwargs
            )
            assert isinstance(current_target, Tensor)
            logger.trace(f"current_target: {current_target.size()}")
            # get prediction and its normalized version for the next step
            prediction, current_target = self.output_to_prediction(
                current_target, target_scale=target_scale, n_samples=n_samples
            )
            logger.trace(f"current_target: {current_target.size()}")
            if isinstance(prediction, Tensor):
                logger.trace(f"prediction ({type(prediction)}): {prediction.size()}")
            else:
                logger.trace(
                    f"prediction ({type(prediction)}|{len(prediction)}): {[p.size() for p in prediction]}"
                )
            # save normalized output for lagged targets
            normalized_output.append(current_target)
            # set output to unnormalized samples; append each target as n_batch_samples x n_random_samples
            output.append(prediction)
        # check shapes before finishing
        if isinstance(prediction, Tensor):
            logger.trace(f"output ({len(output)}): {[o.size() for o in output]}")  # type: ignore
        else:
            logger.trace(f"output ({len(output)}): {[len(o) for o in output]}")
        if isinstance(self.hparams.target, str):
            # single target: output is List[Tensor]
            final_output = torch.stack(output, dim=1)  # type: ignore
            logger.trace(f"final_output: {final_output.size()}")
            return final_output
        # multi-target: output is List[List[Tensor]]
        # final_output_multitarget = [
        #     torch.stack([out[idx] for out in output], dim=1) for idx in range(len(self.target_positions))
        # ]
        # self.target_positions is always Tensor([0]), so its len() is always 1...
        final_output_multitarget = torch.stack([out[0] for out in output], dim=1)
        if final_output_multitarget.dim() > 3:
            final_output_multitarget = final_output_multitarget.squeeze(2)
        if isinstance(final_output_multitarget, Tensor):
            logger.trace(f"final_output_multitarget: {final_output_multitarget.size()}")
        else:
            logger.trace(
                f"final_output_multitarget ({type(final_output_multitarget)}):"
                f" {[o.size() for o in final_output_multitarget]}"
            )
        # split the stacked tensor back into one tensor per target
        return [final_output_multitarget[..., i] for i in range(final_output_multitarget.size(-1))]
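

if __name__ == "__main__":
    # Shape-only sanity check of the multi-target stacking performed at the end of
    # ``decode_autoregressive`` (illustrative, with made-up sizes): each decoding step
    # yields a list holding one (batch, n_targets) tensor, the steps are stacked along
    # dim=1, and the last dimension is split into one tensor per target.
    batch_size, n_steps, n_targets = 4, 6, 2
    output = [[torch.randn(batch_size, n_targets)] for _ in range(n_steps)]
    final = torch.stack([out[0] for out in output], dim=1)  # (batch, steps, targets)
    per_target = [final[..., i] for i in range(final.size(-1))]
    assert len(per_target) == n_targets
    assert per_target[0].shape == (batch_size, n_steps)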