Skip to content
Open
38 changes: 11 additions & 27 deletions pytorch_forecasting/models/timexer/_timexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def __init__(
if enc_in is None:
self.enc_in = len(self.reals)

self.n_quantiles = None
self.n_quantiles = 1

if isinstance(loss, QuantileLoss):
self.n_quantiles = len(loss.quantiles)
Expand Down Expand Up @@ -353,10 +353,7 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
enc_out = enc_out.permute(0, 1, 3, 2)

dec_out = self.head(enc_out)
if self.n_quantiles is not None:
dec_out = dec_out.permute(0, 2, 1, 3)
else:
dec_out = dec_out.permute(0, 2, 1)
dec_out = dec_out.permute(0, 2, 1, 3)

return dec_out

Expand Down Expand Up @@ -395,10 +392,7 @@ def _forecast_multi(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]
enc_out = enc_out.permute(0, 1, 3, 2)

dec_out = self.head(enc_out)
if self.n_quantiles is not None:
dec_out = dec_out.permute(0, 2, 1, 3)
else:
dec_out = dec_out.permute(0, 2, 1)
dec_out = dec_out.permute(0, 2, 1, 3)

return dec_out

Expand Down Expand Up @@ -470,25 +464,15 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
if prediction.size(2) != len(target_positions):
prediction = prediction[:, :, : len(target_positions)]

# In the case of a single target, the result will be a torch.Tensor
# with shape (batch_size, prediction_length)
# In the case of multiple targets, the result will be a list of "n_targets"
# tensors with shape (batch_size, prediction_length)
# If quantile predictions are used, the result will have an additional
# dimension for quantiles, resulting in a shape of
# (batch_size, prediction_length, n_quantiles)
if self.n_quantiles is not None:
# quantile predictions.
if len(target_indices) == 1:
prediction = prediction[..., 0, :]
else:
prediction = [prediction[..., i, :] for i in target_indices]
# output format is (batch_size, prediction_length, n_quantiles)
# in case of quantile loss, the output n_quantiles = self.n_quantiles
# which is the length of a list of float. In case of MAE, MSE, etc.
# n_quantiles = 1 and it mimics the behavior of a point prediction.
# for multi-target forecasting, the output is a list of tensors.
if len(target_positions) == 1:
prediction = prediction[..., 0, :]
else:
# point predictions.
if len(target_indices) == 1:
prediction = prediction[..., 0]
else:
prediction = [prediction[..., i] for i in target_indices]
prediction = [prediction[..., i, :] for i in target_indices]
prediction = self.transform_output(
prediction=prediction, target_scale=x["target_scale"]
)
Expand Down
15 changes: 5 additions & 10 deletions pytorch_forecasting/models/timexer/sub_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,29 +183,24 @@ class FlattenHead(nn.Module):
nf (int): Number of features in the last layer.
target_window (int): Target window size.
head_dropout (float): Dropout rate for the head. Defaults to 0.
n_quantiles (int, optional): Number of quantiles. Defaults to None."""
n_quantiles (int, optional): Number of quantiles. Defaults to 1."""

def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=None):
def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=1):
super().__init__()
self.n_vars = n_vars
self.flatten = nn.Flatten(start_dim=-2)
self.linear = nn.Linear(nf, target_window)
self.n_quantiles = n_quantiles

if self.n_quantiles is not None:
self.linear = nn.Linear(nf, target_window * n_quantiles)
else:
self.linear = nn.Linear(nf, target_window)
self.linear = nn.Linear(nf, target_window * n_quantiles)
self.dropout = nn.Dropout(head_dropout)

def forward(self, x):
x = self.flatten(x)
x = self.linear(x)
x = self.dropout(x)

if self.n_quantiles is not None:
batch_size, n_vars = x.shape[0], x.shape[1]
x = x.reshape(batch_size, n_vars, -1, self.n_quantiles)
batch_size, n_vars = x.shape[0], x.shape[1]
x = x.reshape(batch_size, n_vars, -1, self.n_quantiles)
return x


Expand Down
Loading