[ENH] Add Tide to test framework of ptf-v2 #1889

Open · wants to merge 134 commits into base: main
Changes from all commits · 134 commits
b3644a6
test suite
fkiraly Feb 22, 2025
a1d64c6
Merge branch 'main' into test-suite
fkiraly Feb 22, 2025
4b2486e
skeleton
fkiraly Feb 22, 2025
02b0ce6
skeleton
fkiraly Feb 22, 2025
41cbf66
Update test_all_estimators.py
fkiraly Feb 23, 2025
cef62d3
Update _base_object.py
fkiraly Feb 23, 2025
bc2e93b
Update _lookup.py
fkiraly Feb 23, 2025
eee1c86
Update _lookup.py
fkiraly Feb 23, 2025
164fe0d
base metadatda
fkiraly Feb 23, 2025
20e88d0
registry
fkiraly Feb 23, 2025
318c1fb
fix private name
fkiraly Feb 23, 2025
012ab3d
Update _base_object.py
fkiraly Feb 23, 2025
86365a0
test failure
fkiraly Feb 23, 2025
f6dee46
Update test_all_estimators.py
fkiraly Feb 23, 2025
9b0e4ec
Update test_all_estimators.py
fkiraly Feb 23, 2025
7de5285
Update test_all_estimators.py
fkiraly Feb 23, 2025
57dfe3a
test folders
fkiraly Feb 23, 2025
c9f12db
Update test.yml
fkiraly Feb 23, 2025
fa8144e
test integration
fkiraly Feb 23, 2025
232a510
fixes
fkiraly Feb 23, 2025
1c8d4b5
Update _conftest.py
fkiraly Feb 23, 2025
f632e32
try scenarios
fkiraly Feb 23, 2025
252598d
D1, D2 layer commit
phoeenniixx Apr 6, 2025
d0d1c3e
remove one comment
phoeenniixx Apr 6, 2025
80e64d2
model layer commit
phoeenniixx Apr 6, 2025
6364780
update docstring
phoeenniixx Apr 6, 2025
82b3dc7
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 6, 2025
257183c
update data_module.py
phoeenniixx Apr 10, 2025
9cdcb19
update data_module.py
phoeenniixx Apr 10, 2025
a83bf32
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
ac56d4f
Add disclaimer
phoeenniixx Apr 10, 2025
0e7e36f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
4bfff21
update docstring
phoeenniixx Apr 11, 2025
ef98273
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 11, 2025
8a53ed6
Add tests for D1,D2 layer
phoeenniixx Apr 19, 2025
9f9df31
Merge branch 'main' into refactor-d1-d2
phoeenniixx Apr 19, 2025
cdecb77
Code quality
phoeenniixx Apr 19, 2025
86360fd
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 19, 2025
20aafb7
refactor file
fkiraly Apr 30, 2025
043820d
warning
fkiraly Apr 30, 2025
1720a15
linting
fkiraly May 1, 2025
af44474
move coercion to utils
fkiraly May 1, 2025
a3cb8b7
linting
fkiraly May 1, 2025
75d7fb5
Update _timeseries_v2.py
fkiraly May 1, 2025
1b946e6
Update __init__.py
fkiraly May 1, 2025
3edb08b
Update __init__.py
fkiraly May 1, 2025
a4bc9d8
Merge branch 'main' into pr/1811
fkiraly May 1, 2025
4c0d570
Merge branch 'pr/1811' into pr/1812
fkiraly May 1, 2025
ef37f55
Merge branch 'main' into test-suite
fkiraly May 1, 2025
a669134
Update _lookup.py
fkiraly May 4, 2025
d78bf5d
Update _lookup.py
fkiraly May 4, 2025
e350291
update tests
phoeenniixx May 11, 2025
f90c94f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 11, 2025
3099691
update tft_v2
phoeenniixx May 11, 2025
77cb979
warnings and init attr handling
fkiraly May 13, 2025
28df3c3
Merge branch 'refactor-d1-d2' of https://github.com/phoeenniixx/pytor…
fkiraly May 13, 2025
f8c94e6
simplify TimeSeries.__getitem__
fkiraly May 13, 2025
c289255
Update _timeseries_v2.py
fkiraly May 13, 2025
9467f38
Update data_module.py
fkiraly May 13, 2025
c3b40ad
backwards compat of private/public attrs
fkiraly May 13, 2025
c007310
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 13, 2025
2e25052
Merge branch 'main' into refactor-model
phoeenniixx May 13, 2025
38c28dc
add tests
phoeenniixx May 14, 2025
9d80eb8
add tests
phoeenniixx May 14, 2025
a8ccfe3
add tests
phoeenniixx May 14, 2025
f900ba5
add more docstrings
phoeenniixx May 14, 2025
ed1b799
add note about the commented out tests
phoeenniixx May 14, 2025
c947910
Merge branch 'main' into refactor-model
phoeenniixx May 16, 2025
c0ceb8a
add the commented out tests
phoeenniixx May 16, 2025
3828c26
remove note
phoeenniixx May 16, 2025
6d6d18e
Merge branch 'main' into refactor-model
phoeenniixx May 18, 2025
3144865
Merge branch 'test-suite' of https://github.com/sktime/pytorch-foreca…
phoeenniixx May 20, 2025
30b541b
make the modules private
phoeenniixx May 20, 2025
3f1e11f
Merge remote-tracking branch 'origin/refactor-model' into refactor-model
phoeenniixx May 20, 2025
5cc3ff1
initial commit
phoeenniixx May 20, 2025
1bcf181
Merge branch 'refactor-model' into test-framework
phoeenniixx May 20, 2025
f18e09d
add TFTMetadata class
phoeenniixx May 20, 2025
e1e360e
add TFTMetadata class
phoeenniixx May 20, 2025
168e16a
Merge branch 'main' into test-framework
phoeenniixx May 22, 2025
92c12bf
add TFT tests
phoeenniixx May 25, 2025
1d478d5
remove refactored TFT
phoeenniixx May 27, 2025
f9992f2
Merge branch 'main' into test-framework
phoeenniixx May 28, 2025
d049019
update test_all_estimators
phoeenniixx May 28, 2025
e72486b
linting
phoeenniixx May 28, 2025
7443b0b
Merge branch 'main' into test-framework
phoeenniixx May 29, 2025
a734f26
refactor
phoeenniixx May 29, 2025
7f466b2
Add more test_params
phoeenniixx May 29, 2025
0968452
Add metadata tests
phoeenniixx May 31, 2025
525bbb9
Merge branch 'main' into test-framework
phoeenniixx Jun 1, 2025
4267da6
Merge branch 'main' into test-framework
phoeenniixx Jun 1, 2025
4e8f863
add object-filter to ptf-v1
phoeenniixx Jun 1, 2025
c117092
Merge branch 'main' into test-framework
phoeenniixx Jun 5, 2025
f6d39fe
Merge branch 'main' into test-framework
phoeenniixx Jun 6, 2025
2c518ee
add new base classes
phoeenniixx Jun 6, 2025
7a5c58f
remove try block
phoeenniixx Jun 8, 2025
cb3e944
Merge branch 'main' into test-framework
phoeenniixx Jun 8, 2025
3b9de6d
add support for multiple datamodules
phoeenniixx Jun 9, 2025
032a7b0
typo
phoeenniixx Jun 9, 2025
4d9a19a
Merge branch 'main' into test-framework
phoeenniixx Jun 9, 2025
03c06e8
Merge branch 'main' into test-framework
phoeenniixx Jun 12, 2025
33ae311
add Tide
phoeenniixx Jun 12, 2025
8b0087e
linting
phoeenniixx Jun 12, 2025
0e1debd
Merge branch 'test-framework' into tide
phoeenniixx Jun 12, 2025
63f1eb7
softdep
phoeenniixx Jun 13, 2025
d328fae
Merge branch 'main' into test-framework
phoeenniixx Jun 13, 2025
62c3f83
Merge branch 'test-framework' into tide
phoeenniixx Jun 13, 2025
7dfba67
softdep
phoeenniixx Jun 13, 2025
f020229
add the error causing param
phoeenniixx Jun 13, 2025
43a837e
remove embs from params
phoeenniixx Jun 13, 2025
68df4b6
merge main
phoeenniixx Jun 13, 2025
57d635b
add pkg name to v2
phoeenniixx Jun 13, 2025
9798ff1
Merge branch 'test-framework' into tide
phoeenniixx Jun 13, 2025
1c88de0
add pkg name to v2
phoeenniixx Jun 13, 2025
8436793
Merge branch 'main' into tide
phoeenniixx Jun 14, 2025
6096b90
Merge branch 'main' into pr/1889
fkiraly Jun 15, 2025
0fbbf00
Update _tide_pkg.py
fkiraly Jun 15, 2025
ab94060
revert
fkiraly Jun 15, 2025
f4d4f37
Delete tft_v2_metadata.py
fkiraly Jun 15, 2025
d0b8677
revert
fkiraly Jun 15, 2025
3f89a45
revert
fkiraly Jun 15, 2025
ad00566
rename
fkiraly Jun 15, 2025
53131e4
Update tide_v2_pkg.py
fkiraly Jun 15, 2025
52fefc3
update tide.py
phoeenniixx Jun 17, 2025
2d20eb3
refactor
phoeenniixx Jun 17, 2025
5968af7
merge main
phoeenniixx Aug 7, 2025
8d94371
refactor code
phoeenniixx Aug 9, 2025
6cbb6aa
remove unused base class
phoeenniixx Aug 9, 2025
58f0b60
remove beautify string util
phoeenniixx Aug 9, 2025
89939e8
remove unused imports
phoeenniixx Aug 9, 2025
92e88ad
add docstrings
phoeenniixx Aug 12, 2025
c11fb4d
update docstrings
phoeenniixx Aug 12, 2025
d7eeeec
Merge branch 'main' into tide
phoeenniixx Aug 12, 2025
8f4a831
Merge branch 'main' into tide
phoeenniixx Aug 12, 2025
7619147
Merge branch 'main' into tide
phoeenniixx Aug 16, 2025
6 changes: 4 additions & 2 deletions pytorch_forecasting/data/data_module.py
@@ -430,8 +430,8 @@ def __getitem__(self, idx):
         encoder_indices = slice(start_idx, start_idx + enc_length)
         decoder_indices = slice(start_idx + enc_length, end_idx)

-        target_scale = data["target"][encoder_indices]
-        target_scale = target_scale[~torch.isnan(target_scale)].abs().mean()
+        target_past = data["target"][encoder_indices]

Collaborator:

why are we doing this? This seems like a significant change to the internal contract of TimeSeriesDataSet

@phoeenniixx (Member Author), Aug 17, 2025:

We discussed this some months back: there was no target_past in EncoderDecoderDataModule, and some models do require it. TSLibDataModule already implements it as history_target (see here). Tide also needs it, and since we were already computing target_past for the target_scale, I just renamed the variable and added it to the return.

(if you see line 509, we are still returning target_scale as well)

+        target_scale = target_past[~torch.isnan(target_past)].abs().mean()
         if torch.isnan(target_scale) or target_scale == 0:
             target_scale = torch.tensor(1.0)

@@ -503,6 +503,7 @@ def __getitem__(self, idx):
             "decoder_lengths": torch.tensor(pred_length),
             "decoder_target_lengths": torch.tensor(pred_length),
             "groups": data["group"],
+            "target_past": target_past,
             "encoder_time_idx": torch.arange(enc_length),
             "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length),
             "target_scale": target_scale,
@@ -713,6 +714,7 @@ def collate_fn(batch):
                 [x["decoder_target_lengths"] for x, _ in batch]
             ),
             "groups": torch.stack([x["groups"] for x, _ in batch]),
+            "target_past": torch.stack([x["target_past"] for x, _ in batch]),
             "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]),
             "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]),
             "target_scale": torch.stack([x["target_scale"] for x, _ in batch]),
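With this change, v2 models built on EncoderDecoderDataModule can read the raw encoder-window target straight from the batch. A minimal sketch of consuming the new key, with made-up shapes and random values; only the target_past and target_scale keys and their meaning come from the diff above:

import torch

# hypothetical batch shaped like the collate_fn output above:
# 4 samples, encoder length 24
batch = {
    "target_past": torch.randn(4, 24),    # past target values, [batch, enc_length]
    "target_scale": torch.rand(4) + 0.5,  # mean |target_past| per sample
}

# a model that needs the raw history (as Tide does) can use it directly,
# e.g. rescaling it before feeding its encoder
history = batch["target_past"] / batch["target_scale"].unsqueeze(-1)
print(history.shape)  # torch.Size([4, 24])
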
3 changes: 3 additions & 0 deletions pytorch_forecasting/layers/__init__.py
@@ -8,6 +8,7 @@
     TriangularCausalMask,
 )
 from pytorch_forecasting.layers._decomposition import SeriesDecomposition
+from pytorch_forecasting.layers._dsipts import ResidualBlock, embedding_cat_variables
 from pytorch_forecasting.layers._embeddings import (
     DataEmbedding_inverted,
     EnEmbedding,

@@ -48,4 +49,6 @@
     "sLSTMLayer",
     "sLSTMNetwork",
     "SeriesDecomposition",
+    "ResidualBlock",
+    "embedding_cat_variables",
 ]
4 changes: 4 additions & 0 deletions pytorch_forecasting/layers/_dsipts/__init__.py
@@ -0,0 +1,4 @@
from pytorch_forecasting.layers._dsipts._residual_block_dsipts import ResidualBlock
from pytorch_forecasting.layers._dsipts._sub_nn import embedding_cat_variables

__all__ = ["ResidualBlock", "embedding_cat_variables"]
50 changes: 50 additions & 0 deletions pytorch_forecasting/layers/_dsipts/_residual_block_dsipts.py
@@ -0,0 +1,50 @@
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(
        self, in_size: int, out_size: int, dropout_rate: float, activation_fun: str = ""
    ):
        """Residual block used as the basic layer of the architecture.

        An MLP with one hidden layer, an activation, and a skip connection.
        Dimensions are typically d_model, but explicit in_size and out_size
        let the block handle different dimensions at different stages of the network.

        Parameters
        ----------
        in_size: int
            input size
        out_size: int
            output size
        dropout_rate: float
            dropout rate
        activation_fun: str, optional
            activation function to use in the residual block; defaults to nn.ReLU.
        """  # noqa: E501
        import ast

        super().__init__()

        self.direct_linear = nn.Linear(in_size, out_size, bias=False)

        if activation_fun == "":
            self.act = nn.ReLU()
        else:
            activation = ast.literal_eval(activation_fun)
            self.act = activation()
        self.lin = nn.Linear(in_size, out_size)
        self.dropout = nn.Dropout(dropout_rate)

        self.final_norm = nn.LayerNorm(out_size)

    def forward(self, x, apply_final_norm=True):
        direct_x = self.direct_linear(x)  # skip connection

        x = self.dropout(self.lin(self.act(x)))  # hidden path

        out = x + direct_x
        if apply_final_norm:
            return self.final_norm(out)
        return out
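A minimal usage sketch of the block, assuming the re-export from pytorch_forecasting.layers added in this PR; sizes and input are illustrative:

import torch

from pytorch_forecasting.layers import ResidualBlock

block = ResidualBlock(in_size=16, out_size=32, dropout_rate=0.1)
x = torch.randn(8, 16)                    # [batch, in_size]
y = block(x)                              # [batch, out_size], LayerNorm applied
y_raw = block(x, apply_final_norm=False)  # same sum, skipping the final LayerNorm

With the default activation_fun="", the activation is nn.ReLU; any other value is parsed from the string via ast.literal_eval.
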
101 changes: 101 additions & 0 deletions pytorch_forecasting/layers/_dsipts/_sub_nn.py
@@ -0,0 +1,101 @@
from typing import Union

import torch
import torch.nn as nn


class embedding_cat_variables(nn.Module):
    # at the moment cat_past and cat_fut together
    def __init__(self, seq_len: int, lag: int, d_model: int, emb_dims: list, device):
        """Class for embedding categorical variables, adding 3 positional variables during forward.

        Parameters
        ----------
        seq_len: int
            length of the sequence (sum of past and future steps)
        lag: int
            number of future steps to be predicted
        d_model: int
            dimension of all variables after they are embedded
        emb_dims: list
            sizes of the embedding dictionaries, one entry per categorical variable
        device: torch.device
            device on which tensors are allocated
        """  # noqa: E501
        super().__init__()
        self.seq_len = seq_len
        self.lag = lag
        self.device = device
        # vocab sizes: categorical variables, then pos_seq, pos_fut, is_fut
        self.cat_embeds = emb_dims + [seq_len, lag + 1, 2]
        self.cat_n_embd = nn.ModuleList(
            [nn.Embedding(emb_dim, d_model) for emb_dim in self.cat_embeds]
        )

    def forward(
        self, x: Union[torch.Tensor, int], device: torch.device
    ) -> torch.Tensor:
        """All components of x are concatenated with 3 new variables for data augmentation, in the order:

        - pos_seq: assigns each step its time position
        - pos_fut: assigns each step its future position; 0 if it is a past step
        - is_fut: flags for each step whether it is a future (1) or past (0) step

        Parameters
        ----------
        x: torch.Tensor
            `[bs, seq_len, num_vars]`

        Returns
        -------
        torch.Tensor:
            `[bs, seq_len, num_vars+3, n_embd]`
        """  # noqa: E501
        if isinstance(x, int):
            # x is a batch size: embed only the positional variables
            no_emb = True
            B = x
        else:
            no_emb = False
            B, _, _ = x.shape

        pos_seq = self.get_pos_seq(bs=B).to(device)
        pos_fut = self.get_pos_fut(bs=B).to(device)
        is_fut = self.get_is_fut(bs=B).to(device)

        if no_emb:
            cat_vars = torch.cat((pos_seq, pos_fut, is_fut), dim=2)
        else:
            cat_vars = torch.cat((x, pos_seq, pos_fut, is_fut), dim=2)
        cat_vars = cat_vars.long()
        cat_n_embd = self.get_cat_n_embd(cat_vars)
        return cat_n_embd

    def get_pos_seq(self, bs):
        pos_seq = torch.arange(0, self.seq_len)
        pos_seq = pos_seq.repeat(bs, 1).unsqueeze(2).to(self.device)
        return pos_seq

    def get_pos_fut(self, bs):
        pos_fut = torch.cat(
            (
                torch.zeros((self.seq_len - self.lag), dtype=torch.long),
                torch.arange(1, self.lag + 1),
            )
        )
        pos_fut = pos_fut.repeat(bs, 1).unsqueeze(2).to(self.device)
        return pos_fut

    def get_is_fut(self, bs):
        is_fut = torch.cat(
            (
                torch.zeros((self.seq_len - self.lag), dtype=torch.long),
                torch.ones((self.lag), dtype=torch.long),
            )
        )
        is_fut = is_fut.repeat(bs, 1).unsqueeze(2).to(self.device)
        return is_fut

    def get_cat_n_embd(self, cat_vars):
        # embed each categorical column and stack along a new variable axis
        cat_n_embd = torch.Tensor().to(cat_vars.device)
        for index, layer in enumerate(self.cat_n_embd):
            emb = layer(cat_vars[:, :, index])
            cat_n_embd = torch.cat((cat_n_embd, emb.unsqueeze(2)), dim=2)
        return cat_n_embd
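A usage sketch for the embedding module, again assuming the new re-export from pytorch_forecasting.layers; batch size, sequence length, and category counts are illustrative:

import torch

from pytorch_forecasting.layers import embedding_cat_variables

seq_len, lag, d_model = 24, 6, 16
emb_dims = [5, 12]  # two categorical variables with 5 and 12 levels
device = torch.device("cpu")

emb = embedding_cat_variables(seq_len, lag, d_model, emb_dims, device)

# integer codes for the two categorical variables, [bs, seq_len, num_vars]
x = torch.stack(
    [torch.randint(0, 5, (8, seq_len)), torch.randint(0, 12, (8, seq_len))],
    dim=2,
)
out = emb(x, device)
print(out.shape)  # torch.Size([8, 24, 5, 16]), i.e. [bs, seq_len, num_vars + 3, d_model]
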
6 changes: 6 additions & 0 deletions pytorch_forecasting/models/tide/tide_dsipts/__init__.py
@@ -0,0 +1,6 @@
"""DSIPTS Tide Implementation for V2"""

from pytorch_forecasting.models.tide.tide_dsipts._tide_v2 import TIDE
from pytorch_forecasting.models.tide.tide_dsipts._tide_v2_pkg import TIDE_pkg_v2

__all__ = ["TIDE", "TIDE_pkg_v2"]
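Per the __init__ above, downstream code can import both the model and its package class directly:

from pytorch_forecasting.models.tide.tide_dsipts import TIDE, TIDE_pkg_v2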