diff --git a/pytorch_forecasting/models/timexer/_timexer.py b/pytorch_forecasting/models/timexer/_timexer.py index a240ccc96..685b167d5 100644 --- a/pytorch_forecasting/models/timexer/_timexer.py +++ b/pytorch_forecasting/models/timexer/_timexer.py @@ -214,7 +214,7 @@ def __init__( if enc_in is None: self.enc_in = len(self.reals) - self.n_quantiles = None + self.n_quantiles = 1 if isinstance(loss, QuantileLoss): self.n_quantiles = len(loss.quantiles) @@ -353,10 +353,7 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: enc_out = enc_out.permute(0, 1, 3, 2) dec_out = self.head(enc_out) - if self.n_quantiles is not None: - dec_out = dec_out.permute(0, 2, 1, 3) - else: - dec_out = dec_out.permute(0, 2, 1) + dec_out = dec_out.permute(0, 2, 1, 3) return dec_out @@ -395,10 +392,7 @@ def _forecast_multi(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor] enc_out = enc_out.permute(0, 1, 3, 2) dec_out = self.head(enc_out) - if self.n_quantiles is not None: - dec_out = dec_out.permute(0, 2, 1, 3) - else: - dec_out = dec_out.permute(0, 2, 1) + dec_out = dec_out.permute(0, 2, 1, 3) return dec_out @@ -470,25 +464,15 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: if prediction.size(2) != len(target_positions): prediction = prediction[:, :, : len(target_positions)] - # In the case of a single target, the result will be a torch.Tensor - # with shape (batch_size, prediction_length) - # In the case of multiple targets, the result will be a list of "n_targets" - # tensors with shape (batch_size, prediction_length) - # If quantile predictions are used, the result will have an additional - # dimension for quantiles, resulting in a shape of - # (batch_size, prediction_length, n_quantiles) - if self.n_quantiles is not None: - # quantile predictions. - if len(target_indices) == 1: - prediction = prediction[..., 0, :] - else: - prediction = [prediction[..., i, :] for i in target_indices] + # output format is (batch_size, prediction_length, n_quantiles) + # in case of quantile loss, the output n_quantiles = self.n_quantiles + # which is the length of a list of float. In case of MAE, MSE, etc. + # n_quantiles = 1 and it mimics the behavior of a point prediction. + # for multi-target forecasting, the output is a list of tensors. + if len(target_positions) == 1: + prediction = prediction[..., 0, :] else: - # point predictions. - if len(target_indices) == 1: - prediction = prediction[..., 0] - else: - prediction = [prediction[..., i] for i in target_indices] + prediction = [prediction[..., i, :] for i in target_indices] prediction = self.transform_output( prediction=prediction, target_scale=x["target_scale"] ) diff --git a/pytorch_forecasting/models/timexer/sub_modules.py b/pytorch_forecasting/models/timexer/sub_modules.py index b0ba3b089..87dacb034 100644 --- a/pytorch_forecasting/models/timexer/sub_modules.py +++ b/pytorch_forecasting/models/timexer/sub_modules.py @@ -183,19 +183,15 @@ class FlattenHead(nn.Module): nf (int): Number of features in the last layer. target_window (int): Target window size. head_dropout (float): Dropout rate for the head. Defaults to 0. - n_quantiles (int, optional): Number of quantiles. Defaults to None.""" + n_quantiles (int, optional): Number of quantiles. Defaults to 1.""" - def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=None): + def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=1): super().__init__() self.n_vars = n_vars self.flatten = nn.Flatten(start_dim=-2) - self.linear = nn.Linear(nf, target_window) self.n_quantiles = n_quantiles - if self.n_quantiles is not None: - self.linear = nn.Linear(nf, target_window * n_quantiles) - else: - self.linear = nn.Linear(nf, target_window) + self.linear = nn.Linear(nf, target_window * n_quantiles) self.dropout = nn.Dropout(head_dropout) def forward(self, x): @@ -203,9 +199,8 @@ def forward(self, x): x = self.linear(x) x = self.dropout(x) - if self.n_quantiles is not None: - batch_size, n_vars = x.shape[0], x.shape[1] - x = x.reshape(batch_size, n_vars, -1, self.n_quantiles) + batch_size, n_vars = x.shape[0], x.shape[1] + x = x.reshape(batch_size, n_vars, -1, self.n_quantiles) return x