From fc93504781e4718400219206eb647d081bd714f7 Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Fri, 1 Aug 2025 01:33:30 +0530 Subject: [PATCH 1/6] alter output format to 3d --- pytorch_forecasting/models/timexer/_timexer.py | 6 ++++-- tests/test_models/test_timexer.py | 1 - 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/models/timexer/_timexer.py b/pytorch_forecasting/models/timexer/_timexer.py index a240ccc96..8b8f42739 100644 --- a/pytorch_forecasting/models/timexer/_timexer.py +++ b/pytorch_forecasting/models/timexer/_timexer.py @@ -486,9 +486,11 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: else: # point predictions. if len(target_indices) == 1: - prediction = prediction[..., 0] + prediction = prediction[..., 0].unsqueeze(-1) else: - prediction = [prediction[..., i] for i in target_indices] + prediction = [ + prediction[..., i].unsqueeze(-1) for i in target_indices + ] # noqa: E501 prediction = self.transform_output( prediction=prediction, target_scale=x["target_scale"] ) diff --git a/tests/test_models/test_timexer.py b/tests/test_models/test_timexer.py index 2f5518026..1db80dd66 100644 --- a/tests/test_models/test_timexer.py +++ b/tests/test_models/test_timexer.py @@ -366,7 +366,6 @@ def test_no_exogenous_variables(): ) assert isinstance(predictions.output, torch.Tensor) - assert predictions.output.ndim == 2 def test_with_exogenous_variables(tmp_path): From 7c21c698a8fc4b1199756fd45077b13cd33f743c Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Sat, 2 Aug 2025 22:04:23 +0530 Subject: [PATCH 2/6] revert unintended removal of assert in test_timexer.py --- tests/test_models/test_timexer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_models/test_timexer.py b/tests/test_models/test_timexer.py index 1db80dd66..2f5518026 100644 --- a/tests/test_models/test_timexer.py +++ b/tests/test_models/test_timexer.py @@ -366,6 +366,7 @@ def test_no_exogenous_variables(): ) assert isinstance(predictions.output, torch.Tensor) + assert predictions.output.ndim == 2 def test_with_exogenous_variables(tmp_path): From 6329b99595a3dfa0e862a2b7d2ff6ca64fd01f71 Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Sat, 9 Aug 2025 22:46:41 +0530 Subject: [PATCH 3/6] remove stepout for handling single and multi-target output and unify to a single handling block --- .../models/timexer/_timexer.py | 40 +++++-------------- .../models/timexer/sub_modules.py | 15 +++---- 2 files changed, 16 insertions(+), 39 deletions(-) diff --git a/pytorch_forecasting/models/timexer/_timexer.py b/pytorch_forecasting/models/timexer/_timexer.py index 8b8f42739..685b167d5 100644 --- a/pytorch_forecasting/models/timexer/_timexer.py +++ b/pytorch_forecasting/models/timexer/_timexer.py @@ -214,7 +214,7 @@ def __init__( if enc_in is None: self.enc_in = len(self.reals) - self.n_quantiles = None + self.n_quantiles = 1 if isinstance(loss, QuantileLoss): self.n_quantiles = len(loss.quantiles) @@ -353,10 +353,7 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: enc_out = enc_out.permute(0, 1, 3, 2) dec_out = self.head(enc_out) - if self.n_quantiles is not None: - dec_out = dec_out.permute(0, 2, 1, 3) - else: - dec_out = dec_out.permute(0, 2, 1) + dec_out = dec_out.permute(0, 2, 1, 3) return dec_out @@ -395,10 +392,7 @@ def _forecast_multi(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor] enc_out = enc_out.permute(0, 1, 3, 2) dec_out = self.head(enc_out) - if self.n_quantiles is not None: - dec_out = dec_out.permute(0, 2, 1, 3) - else: - dec_out = dec_out.permute(0, 2, 1) + dec_out = dec_out.permute(0, 2, 1, 3) return dec_out @@ -470,27 +464,15 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: if prediction.size(2) != len(target_positions): prediction = prediction[:, :, : len(target_positions)] - # In the case of a single target, the result will be a torch.Tensor - # with shape (batch_size, prediction_length) - # In the case of multiple targets, the result will be a list of "n_targets" - # tensors with shape (batch_size, prediction_length) - # If quantile predictions are used, the result will have an additional - # dimension for quantiles, resulting in a shape of - # (batch_size, prediction_length, n_quantiles) - if self.n_quantiles is not None: - # quantile predictions. - if len(target_indices) == 1: - prediction = prediction[..., 0, :] - else: - prediction = [prediction[..., i, :] for i in target_indices] + # output format is (batch_size, prediction_length, n_quantiles) + # in case of quantile loss, the output n_quantiles = self.n_quantiles + # which is the length of a list of float. In case of MAE, MSE, etc. + # n_quantiles = 1 and it mimics the behavior of a point prediction. + # for multi-target forecasting, the output is a list of tensors. + if len(target_positions) == 1: + prediction = prediction[..., 0, :] else: - # point predictions. - if len(target_indices) == 1: - prediction = prediction[..., 0].unsqueeze(-1) - else: - prediction = [ - prediction[..., i].unsqueeze(-1) for i in target_indices - ] # noqa: E501 + prediction = [prediction[..., i, :] for i in target_indices] prediction = self.transform_output( prediction=prediction, target_scale=x["target_scale"] ) diff --git a/pytorch_forecasting/models/timexer/sub_modules.py b/pytorch_forecasting/models/timexer/sub_modules.py index b0ba3b089..87dacb034 100644 --- a/pytorch_forecasting/models/timexer/sub_modules.py +++ b/pytorch_forecasting/models/timexer/sub_modules.py @@ -183,19 +183,15 @@ class FlattenHead(nn.Module): nf (int): Number of features in the last layer. target_window (int): Target window size. head_dropout (float): Dropout rate for the head. Defaults to 0. - n_quantiles (int, optional): Number of quantiles. Defaults to None.""" + n_quantiles (int, optional): Number of quantiles. Defaults to 1.""" - def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=None): + def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=1): super().__init__() self.n_vars = n_vars self.flatten = nn.Flatten(start_dim=-2) - self.linear = nn.Linear(nf, target_window) self.n_quantiles = n_quantiles - if self.n_quantiles is not None: - self.linear = nn.Linear(nf, target_window * n_quantiles) - else: - self.linear = nn.Linear(nf, target_window) + self.linear = nn.Linear(nf, target_window * n_quantiles) self.dropout = nn.Dropout(head_dropout) def forward(self, x): @@ -203,9 +199,8 @@ def forward(self, x): x = self.linear(x) x = self.dropout(x) - if self.n_quantiles is not None: - batch_size, n_vars = x.shape[0], x.shape[1] - x = x.reshape(batch_size, n_vars, -1, self.n_quantiles) + batch_size, n_vars = x.shape[0], x.shape[1] + x = x.reshape(batch_size, n_vars, -1, self.n_quantiles) return x From 51e912f204af181d00ce147dbd604726abc99bfb Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Sun, 17 Aug 2025 15:24:06 +0530 Subject: [PATCH 4/6] add test for expected shape from forward of model --- tests/test_models/test_timexer.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_models/test_timexer.py b/tests/test_models/test_timexer.py index 2f5518026..ca0b557dd 100644 --- a/tests/test_models/test_timexer.py +++ b/tests/test_models/test_timexer.py @@ -82,6 +82,31 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs) **kwargs, ) + x, y = next(iter(train_dataloader)) + output = net(x) + + # add test for the raw output shape from the model. + if isinstance(loss, QuantileLoss): + if len(net.target_positions) == 1: + # Single target case + assert output["prediction"].shape[2] == len(loss.quantiles) + else: + # Multiple target case + assert all(o.shape[2] == len(loss.quantiles) for o in output["prediction"]) + else: + if len(net.target_positions) == 1: + # Single target case + assert output["prediction"].shape[2] == 1, ( + "The output tensor should have a third dimension of size 1 for single", + "target.", + ) + else: + # Multiple target case + assert all(o.shape[2] == 1 for o in output["prediction"]), ( + "Each tensor in the output list should have a", + "third dimension of size 1.", + ) + try: trainer.fit( net, From 8acbcfae4237faa9e1912b400b3cf3f4391f9990 Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Sun, 17 Aug 2025 15:33:33 +0530 Subject: [PATCH 5/6] use helper function for checking model forward output shape --- tests/test_models/test_timexer.py | 53 +++++++++++++++++++------------ 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/tests/test_models/test_timexer.py b/tests/test_models/test_timexer.py index ca0b557dd..af81a71ad 100644 --- a/tests/test_models/test_timexer.py +++ b/tests/test_models/test_timexer.py @@ -17,6 +17,37 @@ from pytorch_forecasting.models import TimeXer +def check_model_output_shape(output, net, loss): + """ + Check the output shape of the model. + Args: + output: The output from the model. + net: The model instance. + loss: The loss function used in the model. + """ + # add test for the raw output shape from the model. + if isinstance(loss, QuantileLoss): + if len(net.target_positions) == 1: + # Single target case + assert output["prediction"].shape[2] == len(loss.quantiles) + else: + # Multiple target case + assert all(o.shape[2] == len(loss.quantiles) for o in output["prediction"]) + else: + if len(net.target_positions) == 1: + # Single target case + assert output["prediction"].shape[2] == 1, ( + "The output tensor should have a third dimension of size 1 for single", + "target.", + ) + else: + # Multiple target case + assert all(o.shape[2] == 1 for o in output["prediction"]), ( + "Each tensor in the output list should have a", + "third dimension of size 1.", + ) + + def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs): """ Integration test for the TimeXer model. @@ -85,27 +116,7 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs) x, y = next(iter(train_dataloader)) output = net(x) - # add test for the raw output shape from the model. - if isinstance(loss, QuantileLoss): - if len(net.target_positions) == 1: - # Single target case - assert output["prediction"].shape[2] == len(loss.quantiles) - else: - # Multiple target case - assert all(o.shape[2] == len(loss.quantiles) for o in output["prediction"]) - else: - if len(net.target_positions) == 1: - # Single target case - assert output["prediction"].shape[2] == 1, ( - "The output tensor should have a third dimension of size 1 for single", - "target.", - ) - else: - # Multiple target case - assert all(o.shape[2] == 1 for o in output["prediction"]), ( - "Each tensor in the output list should have a", - "third dimension of size 1.", - ) + check_model_output_shape(output, net, loss) try: trainer.fit( From 7a594433dd49f0db6e22b3ca6c6c072b29c3c27e Mon Sep 17 00:00:00 2001 From: PranavBhatP Date: Wed, 20 Aug 2025 15:13:41 +0530 Subject: [PATCH 6/6] revert tests for timexer to original version --- tests/test_models/test_timexer.py | 36 ------------------------------- 1 file changed, 36 deletions(-) diff --git a/tests/test_models/test_timexer.py b/tests/test_models/test_timexer.py index af81a71ad..2f5518026 100644 --- a/tests/test_models/test_timexer.py +++ b/tests/test_models/test_timexer.py @@ -17,37 +17,6 @@ from pytorch_forecasting.models import TimeXer -def check_model_output_shape(output, net, loss): - """ - Check the output shape of the model. - Args: - output: The output from the model. - net: The model instance. - loss: The loss function used in the model. - """ - # add test for the raw output shape from the model. - if isinstance(loss, QuantileLoss): - if len(net.target_positions) == 1: - # Single target case - assert output["prediction"].shape[2] == len(loss.quantiles) - else: - # Multiple target case - assert all(o.shape[2] == len(loss.quantiles) for o in output["prediction"]) - else: - if len(net.target_positions) == 1: - # Single target case - assert output["prediction"].shape[2] == 1, ( - "The output tensor should have a third dimension of size 1 for single", - "target.", - ) - else: - # Multiple target case - assert all(o.shape[2] == 1 for o in output["prediction"]), ( - "Each tensor in the output list should have a", - "third dimension of size 1.", - ) - - def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs): """ Integration test for the TimeXer model. @@ -113,11 +82,6 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs) **kwargs, ) - x, y = next(iter(train_dataloader)) - output = net(x) - - check_model_output_shape(output, net, loss) - try: trainer.fit( net,