Skip to content

[ENH] xLSTMTime implementation #1709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Aug 6, 2025
Merged

[ENH] xLSTMTime implementation #1709

merged 38 commits into from
Aug 6, 2025

Conversation

phoeenniixx
Copy link
Member

@phoeenniixx phoeenniixx commented Nov 9, 2024

Description

This PR tries to implement xLSTMTime based on this paper

see also sktime issue #6793

Checklist

  • Linked issues (if existing)
  • Amended changelog for large changes (and added myself there as contributor)
  • Added/modified tests
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independent of commit, execute pre-commit run --all-files

@phoeenniixx phoeenniixx changed the title initial commit [ENH] xLSTMTime implementation Nov 9, 2024
Copy link

codecov bot commented Nov 9, 2024

Codecov Report

❌ Patch coverage is 95.65217% with 15 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@a88a404). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...orch_forecasting/layers/_recurrent/_slstm/layer.py 91.11% 4 Missing ⚠️
...orch_forecasting/layers/_recurrent/_mlstm/layer.py 93.18% 3 Missing ⚠️
...torch_forecasting/layers/_recurrent/_slstm/cell.py 95.08% 3 Missing ⚠️
pytorch_forecasting/models/xlstm/_xlstm.py 96.72% 2 Missing ⚠️
...torch_forecasting/layers/_recurrent/_mlstm/cell.py 98.27% 1 Missing ⚠️
...ch_forecasting/layers/_recurrent/_mlstm/network.py 93.75% 1 Missing ⚠️
...ch_forecasting/layers/_recurrent/_slstm/network.py 95.23% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1709   +/-   ##
=======================================
  Coverage        ?   87.39%           
=======================================
  Files           ?      113           
  Lines           ?     8419           
  Branches        ?        0           
=======================================
  Hits            ?     7358           
  Misses          ?     1061           
  Partials        ?        0           
Flag Coverage Δ
cpu 87.39% <95.65%> (?)
pytest 87.39% <95.65%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@phoeenniixx
Copy link
Member Author

hi @fkiraly, I am new to pytorch-forecasting and its tests and all, can you please tell me exactly what am I "missing"?

@phoeenniixx
Copy link
Member Author

Will these tests suffice @fkiraly?

Copy link
Collaborator

@benHeid benHeid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @phoeenniixx,
welcome to pytorch-forecasting and thank you for your pull request and contributing xlstm.
I added first comments about the BaseClass you used. Please change it to one of the BaseClasses (see the comment). Since I suppose that this will change your code a bit. I will wait with a complete review until you changed it.

return trend, seasonal


class xLSTMTime(nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the Base classes of pytorch-forecasting (BaseModelWithCovariates, etc.) depending on the properties of the forecaster.
The advantage of doing this is that it automatically comes with PyTorch lightning and thus less boilerplate is needed.

You might compare it with the NHITS implementation and check how it is implemented.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please ensure that the naming conventions of files are met. I.e., only lower case is allowed and use _ as a separator. between words. .../x_lstm_time/x_lstm_time.py

device: Optional[torch.device] = None,
):
"""
Initialize xLSTMTime model.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check where to put the reference to the paper that originally proposes xlstm.

@phoeenniixx
Copy link
Member Author

Thanks for the review @benHeid!
I will have to restructure a little ig, I will see and use appropriate base class, use it in main xLSTMTime class, rest will be left untouched? (wrt to baseclass atleast)
I will make the changes and get back to you in few days!
Thanks!

@phoeenniixx
Copy link
Member Author

Hi @benHeid, I need some help:

  • here I implemented xLSTMTime class using BaseModel as for now I think this is the best fitted class... what do you think?

  • Also, I made some changes in the forward function of the code where before it was accepting Tensor object, I changed it to Dict as I found out that the user mainly uses TimeSeriesDataSet and it returns a dict, please correct me if I am wrong here.

  • I am using the encoder_cont key of the dict as input x.

Please tell me if I am in a right direction

class xLSTMTime(BaseModel):

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        xlstm_type: Literal['slstm', 'mlstm'],
        num_layers: int = 1,
        decomposition_kernel: int = 25,
        input_projection_size: Optional[int] = None,
        dropout: float = 0.1,
        loss: Metric = SMAPE(),
        device: Optional[torch.device] = None,
        **kwargs
    ):
        super().__init__(loss=loss, **kwargs)

        if xlstm_type not in ['slstm', 'mlstm']:
            raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'")

        self.xlstm_type = xlstm_type
        self._device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self._device)

        self.decomposition = SeriesDecomposition(decomposition_kernel)
        self.batch_norm = nn.BatchNorm1d(hidden_size)

        self.input_projection_size = input_projection_size or hidden_size

        self.input_linear = None  

        if xlstm_type == 'mlstm':
            self.lstm = mLSTMNetwork(
                input_size=hidden_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                output_size=hidden_size,
                dropout=dropout,
                device=self.device
            )
        else:  # slstm
            self.lstm = sLSTMNetwork(
                input_size=hidden_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                output_size=hidden_size,
                dropout=dropout,
                device=self.device
            )

        self.output_linear = nn.Linear(hidden_size, output_size)
        self.instance_norm = nn.InstanceNorm1d(output_size)

    def forward(
        self,
        x: Dict[str, torch.Tensor],  
        hidden_states: Optional[
            Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
        ] = None
    ) -> Tuple[torch.Tensor, Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
   
        encoder_cont = x["encoder_cont"]
        batch_size, seq_len, n_features = encoder_cont.shape

        trend, seasonal = self.decomposition(encoder_cont)

        x = torch.cat([trend, seasonal], dim=-1)
        concatenated_features = x.shape[-1]

        if self.input_linear is None:
            self.input_linear = nn.Linear(concatenated_features, self.input_projection_size).to(self._device)

        x = self.input_linear(x)

        x = x.transpose(1, 2)  
        x = self.batch_norm(x)
        x = x.transpose(1, 2)  

        if hidden_states is None:
            hidden_states = self.lstm.init_hidden(batch_size)

        x = x.transpose(0, 1)
        output, hidden_states = self.lstm(x, *hidden_states)

        if isinstance(output, tuple):
            output = output[0]

        if output.dim() == 2:
            output = output.unsqueeze(0)
        output = self.output_linear(output)

        output = output.transpose(1, 2)
        output = self.instance_norm(output)
        output = output.transpose(1, 2)

        return output, hidden_states


    def predict(
            self,
            x: torch.Tensor,
            hidden_states: Optional[
                Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
            ] = None
    ) -> torch.Tensor:

        output, _ = self.forward(x, hidden_states)
        return output

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y[0] if isinstance(y, tuple) else y 

        y_pred, _ = self(x)

        if y_pred.ndim == 3 and y_pred.size(0) == 1:
            y_pred = y_pred.squeeze(0)  
        loss = self.loss(y_pred, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y[0] if isinstance(y, tuple) else y 

        y_pred, _ = self(x)

        if y_pred.ndim == 3 and y_pred.size(0) == 1:
            y_pred = y_pred.squeeze(0)  
        loss = self.loss(y_pred, y)
        self.log("val_loss", loss)

        return loss




    def test_step(self, batch, batch_idx):
        x, y = batch
        y = y[0] if isinstance(y, tuple) else y 

        y_pred, _ = self(x)

        if y_pred.ndim == 3 and y_pred.size(0) == 1:
            y_pred = y_pred.squeeze(0)  
        loss = self.loss(y_pred, y)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=10)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

@phoeenniixx
Copy link
Member Author

Also, Do we need to change the baseclass of just xLSTMTime only or mLSTMNetwork and sLSTMNetwork should also be changed?
(Although I think they are just a part of this main class so they could inherit from nn.Module without any problem?)

@benHeid
Copy link
Collaborator

benHeid commented Dec 10, 2024

  • here I implemented xLSTMTime class using BaseModel as for now I think this is the best fitted class... what do you think?

Mhm. if the implementation does not support any exogenous features than either BaseModel or AutoRegressiveBaseModel. I would assume that the ladder is probably the better fit.

  • Also, I made some changes in the forward function of the code where before it was accepting Tensor object, I changed it to Dict as I found out that the user mainly uses TimeSeriesDataSet and it returns a dict, please correct me if I am wrong here.

I agree that a dict should be used here.

  • I am using the encoder_cont key of the dict as input x.

Yes that is the target time series.

Please tell me if I am in a right direction

You might check the RNN implementation. Since this is also inheriting from an Autoregressive model and probably the most similar of the implemented models.
I would suggest that you check carefully, if you really need to implement the step / training_step method etc. or if is sufficient to use the inherited methods from the base class.

But I think you are in the right direction.

@phoeenniixx
Copy link
Member Author

phoeenniixx commented Dec 12, 2024

Hi @benHeid, I have updated the implementation using AutoRegressiveBaseModel, please review it. Also, I have not changed or added the tests (they are failing due to some changes in input and output format) as I saw that for other modules, there is a specific "trend" of writing the tests and I might need some help with that. Can you please provide me a brief about them, like what specific tests should I add etc.

I can add the docstrings in subsequent commits once I am sure that this is what we want.

@benHeid
Copy link
Collaborator

benHeid commented Dec 24, 2024

Sorry for my late response. Please ensure that the linting tests are green. Probably running the pre commit hooks locally should make it.

Regarding the failing tests, you might check how the output currently looks like by manually executing the xLSTM. You might then see what the issue is.

@fkiraly do we have any guides for pytorch-forecasting on how to write tests?

@phoeenniixx
Copy link
Member Author

Thanks for the reply @benHeid, actually the reason the tests are failing is: earlier I was using tensors, tuple etc and now TimeSeriesDataset is being used that uses a dict, that is the reason the tests are failing, I can correct those but I didn't do that because I noticed that for other models, they just use functions like test_integration etc. To write those functions, I first need to understand the input like dataloaders, dataset that is entered in these functions, like which data we are using here, the labels etc. is that data any arbitrary data or some pre-defined dataset?
Like look into this function from test_models.test_rnn_model,py:

def _integration(
    data_with_covariates, tmp_path, cell_type="LSTM", data_loader_kwargs={}, clip_target: bool = False, **kwargs
):
    data_with_covariates = data_with_covariates.copy()
    if clip_target:
        data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0)
    else:
        data_with_covariates["target"] = data_with_covariates["volume"]
    data_loader_default_kwargs = dict(
        target="target",
        time_varying_known_reals=["price_actual"],
        time_varying_unknown_reals=["target"],
        static_categoricals=["agency"],
        add_relative_time_idx=True,
    )
    data_loader_default_kwargs.update(data_loader_kwargs)
    dataloaders_with_covariates = make_dataloaders(data_with_covariates, **data_loader_default_kwargs)
    train_dataloader = dataloaders_with_covariates["train"]
    val_dataloader = dataloaders_with_covariates["val"]
    test_dataloader = dataloaders_with_covariates["test"]

    early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")

    logger = TensorBoardLogger(tmp_path)
    trainer = pl.Trainer(
        max_epochs=3,
        gradient_clip_val=0.1,
        callbacks=[early_stop_callback],
        enable_checkpointing=True,
        default_root_dir=tmp_path,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        logger=logger,
    )

    net = RecurrentNetwork.from_dataset(
        train_dataloader.dataset,
        cell_type=cell_type,
        learning_rate=0.15,
        log_gradient_flow=True,
        log_interval=1000,
        hidden_size=5,
        **kwargs,
    )
    net.size()
    try:
        trainer.fit(
            net,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )
        test_outputs = trainer.test(net, dataloaders=test_dataloader)
        assert len(test_outputs) > 0
        # check loading
        net = RecurrentNetwork.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

        # check prediction
        net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
    finally:
        shutil.rmtree(tmp_path, ignore_errors=True)

    net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)

Here they are using keys like "volume", and this is for data_with_covariates but I am not using the covariate base class that i can use directly this code and modify it to my requirements. I want to understand how this whole thing works and then I can write the test...

@phoeenniixx
Copy link
Member Author

for now I am just removing the test file and updating the code as required

Copy link
Collaborator

@fkiraly fkiraly left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor things before a more thorough review:

  • can you kindly add tests for some basic use cases?
  • can you make sure nothing except imports are in the __init__ files? Similar to the recent change sin the repo.


__all__ = [
"FullAttention",
"TriangularCausalMask",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does this line get removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't see any imports for TriangularCausalMask

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually found it in layers._attention._full_attention, it is not imported even in the __init__ of layers._attention. I will add it to both the locations. At first, I thought it didnt exist 😅

Copy link
Collaborator

@fkiraly fkiraly left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

Minor requests related to docs:

  • class docstring should be in the class, not in __init__
  • the model should also be added to the model overview

@phoeenniixx
Copy link
Member Author

phoeenniixx commented Jul 31, 2025

the model should also be added to the model overview

I have seen some of the models having this docstring in __init__ rather than the class, so should we move those docstrings as well to the class? (obv in some other PR)

@phoeenniixx phoeenniixx requested a review from fkiraly July 31, 2025 19:15
@phoeenniixx
Copy link
Member Author

phoeenniixx commented Aug 1, 2025

Hi @fkiraly, before we close this, I have one doubt:
Why do the models have so much of if-else conditions in the __init__ of classes (see DeepAR, DecoderMLP etc)? And many of these variables in the init are not even used afterwards. Is this a design choice? I think we should avoid such things in v2 - keeping unnecessary params in __init__?

I avoided a similar design here in xlstm, should i add these conditions here as well?

@fkiraly
Copy link
Collaborator

fkiraly commented Aug 2, 2025

Hi @fkiraly, before we close this, I have one doubt: Why do the models have so much of if-else conditions in the __init__ of classes (see DeepAR, DecoderMLP etc)? And many of these variables in the init are not even used afterwards. Is this a design choice? I think we should avoid such things in v2 - keeping unnecessary params in __init__?

I think this design has to do with the fact that the models are not properly getting the metadata from the TimeSeriesDataSet - the translation is done in imperative fashion in from_dataset.

What I am a bit surprised about - why do you not need these in __init__?

@fkiraly
Copy link
Collaborator

fkiraly commented Aug 2, 2025

I see, the tests only construct via from_dataset. The __init__ is not actually tested - missing that there might be a problem.

@fkiraly
Copy link
Collaborator

fkiraly commented Aug 2, 2025

This PR is related: apparently not all models were properly tested or initializable via __init__: #1837

Do you know what the implicit contract is for __init__?

@fkiraly
Copy link
Collaborator

fkiraly commented Aug 2, 2025

minor comment regarding module structure - I would make it similar to other modules in naming:

  • call the module xlstm
  • the internal python module should be _xlstm.py etc

@fkiraly fkiraly moved this from PR under review to PR in progress in May - Sep 2025 mentee projects Aug 4, 2025
@fkiraly fkiraly moved this from PR in progress to PR under review in May - Sep 2025 mentee projects Aug 5, 2025
@phoeenniixx
Copy link
Member Author

phoeenniixx commented Aug 5, 2025

Hi @fkiraly, are there any other changes I need to make?

@@ -31,6 +31,7 @@ and you should take into account. Here is an overview over the pros and cons of
:py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3
:py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4
:py:class:`~pytorch_forecasting.model.tide.TiDEModel`, "x", "x", "x", "", "", "", "", "x", "", 3
:py:class:`~pytorch_forecasting.models.x_lstm_time.xLSTMTime`, "x", "x", "x", "", "", "", "", "x", "", 3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is incorrect now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry I forgot to change here

Copy link
Collaborator

@fkiraly fkiraly left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only minimal change requests.

  • doclink is now broken due to rename
  • could you move the lstm layers into a _recurrent folder, i.e., layers._recurrent._mlstm etc?

@fkiraly fkiraly moved this from PR under review to PR in progress in May - Sep 2025 mentee projects Aug 6, 2025
@phoeenniixx phoeenniixx requested a review from fkiraly August 6, 2025 19:40
Copy link
Collaborator

@fkiraly fkiraly left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

@fkiraly fkiraly merged commit 3093b9f into sktime:main Aug 6, 2025
35 checks passed
@github-project-automation github-project-automation bot moved this from PR in progress to Done in May - Sep 2025 mentee projects Aug 6, 2025
@github-project-automation github-project-automation bot moved this from PR in progress to Done in Dec 2024 - Mar 2025 mentee projects Aug 6, 2025
@phoeenniixx phoeenniixx deleted the xLSTMTime branch August 7, 2025 17:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request new network
Development

Successfully merging this pull request may close these issues.

3 participants