diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index 488a56f33..3ad6b4dcc 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -430,8 +430,8 @@ def __getitem__(self, idx):
             encoder_indices = slice(start_idx, start_idx + enc_length)
             decoder_indices = slice(start_idx + enc_length, end_idx)
 
-            target_scale = data["target"][encoder_indices]
-            target_scale = target_scale[~torch.isnan(target_scale)].abs().mean()
+            target_past = data["target"][encoder_indices]
+            target_scale = target_past[~torch.isnan(target_past)].abs().mean()
             if torch.isnan(target_scale) or target_scale == 0:
                 target_scale = torch.tensor(1.0)
 
@@ -503,6 +503,7 @@ def __getitem__(self, idx):
                 "decoder_lengths": torch.tensor(pred_length),
                 "decoder_target_lengths": torch.tensor(pred_length),
                 "groups": data["group"],
+                "target_past": target_past,
                 "encoder_time_idx": torch.arange(enc_length),
                 "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length),
                 "target_scale": target_scale,
@@ -713,6 +714,7 @@ def collate_fn(batch):
                 [x["decoder_target_lengths"] for x, _ in batch]
             ),
             "groups": torch.stack([x["groups"] for x, _ in batch]),
+            "target_past": torch.stack([x["target_past"] for x, _ in batch]),
             "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]),
             "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]),
             "target_scale": torch.stack([x["target_scale"] for x, _ in batch]),
diff --git a/pytorch_forecasting/layers/__init__.py b/pytorch_forecasting/layers/__init__.py
index ab9f2dad6..f91553423 100644
--- a/pytorch_forecasting/layers/__init__.py
+++ b/pytorch_forecasting/layers/__init__.py
@@ -8,6 +8,7 @@
     TriangularCausalMask,
 )
 from pytorch_forecasting.layers._decomposition import SeriesDecomposition
+from pytorch_forecasting.layers._dsipts import ResidualBlock, embedding_cat_variables
 from pytorch_forecasting.layers._embeddings import (
     DataEmbedding_inverted,
     EnEmbedding,
@@ -48,4 +49,6 @@
     "sLSTMLayer",
     "sLSTMNetwork",
     "SeriesDecomposition",
+    "ResidualBlock",
+    "embedding_cat_variables",
 ]
diff --git a/pytorch_forecasting/layers/_dsipts/__init__.py b/pytorch_forecasting/layers/_dsipts/__init__.py
new file mode 100644
index 000000000..7c01b2803
--- /dev/null
+++ b/pytorch_forecasting/layers/_dsipts/__init__.py
@@ -0,0 +1,4 @@
+from pytorch_forecasting.layers._dsipts._residual_block_dsipts import ResidualBlock
+from pytorch_forecasting.layers._dsipts._sub_nn import embedding_cat_variables
+
+__all__ = ["ResidualBlock", "embedding_cat_variables"]
diff --git a/pytorch_forecasting/layers/_dsipts/_residual_block_dsipts.py b/pytorch_forecasting/layers/_dsipts/_residual_block_dsipts.py
new file mode 100644
index 000000000..0c050017b
--- /dev/null
+++ b/pytorch_forecasting/layers/_dsipts/_residual_block_dsipts.py
@@ -0,0 +1,59 @@
+import torch.nn as nn
+
+
+class ResidualBlock(nn.Module):
+    def __init__(
+        self, in_size: int, out_size: int, dropout_rate: float, activation_fun: str = ""
+    ):
+        """Residual block used as the basic layer of the architecture.
+
+        A single linear layer applied after the activation and followed by dropout,
+        added to a bias-free linear projection of the input (skip connection).
+
+        in_size and out_size to handle dimensions at different stages of the NN
+
+        Parameters
+        ----------
+        in_size: int
+            input size
+        out_size: int
+            output size
+        dropout_rate: float
+            dropout
+        activation_fun: str, Optional
+            Name of a ``torch.nn`` activation class, e.g. ``"ReLU"``, ``"nn.GELU"`` or
+            ``"torch.nn.SiLU"``. Defaults to ``""``, which selects ``nn.ReLU``.
+
+        Raises
+        ------
+        ValueError
+            If ``activation_fun`` does not name an attribute of ``torch.nn``.
+        """  # noqa: E501
+        super().__init__()
+
+        self.direct_linear = nn.Linear(in_size, out_size, bias=False)
+
+        if activation_fun == "":
+            self.act = nn.ReLU()
+        else:
+            # Resolve the class by its final dotted component on ``torch.nn``
+            # (``ast.literal_eval`` cannot do this: it only parses literals).
+            activation = getattr(nn, activation_fun.rsplit(".", 1)[-1], None)
+            if activation is None:
+                raise ValueError(f"Unknown torch.nn activation: {activation_fun!r}")
+            self.act = activation()
+        self.lin = nn.Linear(in_size, out_size)
+        self.dropout = nn.Dropout(dropout_rate)
+
+        self.final_norm = nn.LayerNorm(out_size)
+
+    def forward(self, x, apply_final_norm=True):
+        """Apply the block; ``apply_final_norm`` toggles the closing LayerNorm."""
+        direct_x = self.direct_linear(x)
+
+        x = self.dropout(self.lin(self.act(x)))
+
+        out = x + direct_x
+        if apply_final_norm:
+            return self.final_norm(out)
+        return out
diff --git a/pytorch_forecasting/layers/_dsipts/_sub_nn.py b/pytorch_forecasting/layers/_dsipts/_sub_nn.py
new file mode 100644
index 000000000..e8549318b
--- /dev/null
+++ b/pytorch_forecasting/layers/_dsipts/_sub_nn.py
@@ -0,0 +1,101 @@
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+
+class embedding_cat_variables(nn.Module):
+    # at the moment cat_past and cat_fut together
+    def __init__(self, seq_len: int, lag: int, d_model: int, emb_dims: list, device):
+        """Class for embedding categorical variables, adding 3 positional variables during forward
+
+        Parameters
+        ----------
+        seq_len: int
+            length of the sequence (sum of past and future steps)
+        lag: int
+            number of future steps to be predicted
+        d_model: int
+            dimension of all variables after they are embedded
+        emb_dims: list
+            size of the dictionary for embedding. One dimension for each categorical variable
+        device : torch.device
+        """  # noqa: E501
+        super().__init__()
+        self.seq_len = seq_len
+        self.lag = lag
+        self.device = device
+        self.cat_embeds = emb_dims + [seq_len, lag + 1, 2]  #
+        self.cat_n_embd = nn.ModuleList(
+            [nn.Embedding(emb_dim, d_model) for emb_dim in self.cat_embeds]
+        )
+
+    def forward(
+        self, x: Union[torch.Tensor, int], device: torch.device
+    ) -> torch.Tensor:
+        """All components of x are concatenated with 3 new variables for data augmentation, in the order:
+
+        - pos_seq: assign at each step its time-position
+        - pos_fut: assign at each step its future position. 0 if it is a past step
+        - is_fut: explicit for each step if it is a future(1) or past one(0)
+
+        Parameters
+        ----------
+        x: torch.Tensor or int
+            `[bs, seq_len, num_vars]`, or just the batch size `bs` to embed positions only
+
+        Returns
+        -------
+        torch.Tensor:
+            `[bs, seq_len, num_vars+3, n_embd]`
+        """  # noqa: E501
+        if isinstance(x, int):
+            no_emb = True
+            B = x
+        else:
+            no_emb = False
+            B, _, _ = x.shape
+
+        pos_seq = self.get_pos_seq(bs=B).to(device)
+        pos_fut = self.get_pos_fut(bs=B).to(device)
+        is_fut = self.get_is_fut(bs=B).to(device)
+
+        if no_emb:
+            cat_vars = torch.cat((pos_seq, pos_fut, is_fut), dim=2)
+        else:
+            cat_vars = torch.cat((x, pos_seq, pos_fut, is_fut), dim=2)
+        cat_vars = cat_vars.long()
+        cat_n_embd = self.get_cat_n_embd(cat_vars)
+        return cat_n_embd
+
+    def get_pos_seq(self, bs):
+        pos_seq = torch.arange(0, self.seq_len)
+        pos_seq = pos_seq.repeat(bs, 1).unsqueeze(2).to(self.device)
+        return pos_seq
+
+    def get_pos_fut(self, bs):
+        pos_fut = torch.cat(
+            (
+                torch.zeros((self.seq_len - self.lag), dtype=torch.long),
+                torch.arange(1, self.lag + 1),
+            )
+        )
+        pos_fut = pos_fut.repeat(bs, 1).unsqueeze(2).to(self.device)
+        return pos_fut
+
+    def get_is_fut(self, bs):
+        is_fut = torch.cat(
+            (
+                torch.zeros((self.seq_len - self.lag), dtype=torch.long),
+                torch.ones((self.lag), dtype=torch.long),
+            )
+        )
+        is_fut = is_fut.repeat(bs, 1).unsqueeze(2).to(self.device)
+        return is_fut
+
+    def get_cat_n_embd(self, cat_vars):
+        # embed column ``i`` with layer ``i`` and stack along a new variable axis
+        embedded = [
+            layer(cat_vars[:, :, index]) for index, layer in enumerate(self.cat_n_embd)
+        ]
+        return torch.stack(embedded, dim=2)
diff --git a/pytorch_forecasting/models/tide/tide_dsipts/__init__.py b/pytorch_forecasting/models/tide/tide_dsipts/__init__.py
new file mode 100644
index 000000000..0dfd84c18
--- /dev/null
+++ b/pytorch_forecasting/models/tide/tide_dsipts/__init__.py
@@ -0,0 +1,6 @@
+"""DSIPTS Tide Implementation for V2"""
+
+from pytorch_forecasting.models.tide.tide_dsipts._tide_v2 import TIDE
+from pytorch_forecasting.models.tide.tide_dsipts._tide_v2_pkg import TIDE_pkg_v2
+
+__all__ = ["TIDE", "TIDE_pkg_v2"]
diff --git a/pytorch_forecasting/models/tide/tide_dsipts/_tide_v2.py b/pytorch_forecasting/models/tide/tide_dsipts/_tide_v2.py
new file mode 100644
index 000000000..c77d52e2d
--- /dev/null
+++ b/pytorch_forecasting/models/tide/tide_dsipts/_tide_v2.py
@@ -0,0 +1,362 @@
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from pytorch_forecasting.layers._dsipts import _sub_nn as sub_nn
+from pytorch_forecasting.layers._dsipts._residual_block_dsipts import ResidualBlock
+from pytorch_forecasting.models.base._base_model_v2 import BaseModel
+
+
+class TIDE(BaseModel):
+    """Long-term Forecasting with TiDE: Time-series Dense Encoder
+    https://arxiv.org/abs/2304.08424
+
+    This NN uses as subnet the ResidualBlocks, which is composed by skip connection and activation+dropout.
+    Every encoder and decoder head is composed by one Residual Block, like the temporal decoder and the feature projection for covariates.
+    """  # noqa: E501
+
+    @classmethod
+    def _pkg(cls):
+        """Package containing the model."""
+        from pytorch_forecasting.models.tide.tide_dsipts import TIDE_pkg_v2
+
+        return TIDE_pkg_v2
+
+    def __init__(
+        self,
+        metadata: dict,
+        loss: nn.Module,
+        hidden_size: int,
+        d_model: int,
+        n_add_enc: int,
+        n_add_dec: int,
+        dropout_rate: float,
+        activation: str = "",
+        embs: Union[list[int], None] = None,
+        persistence_weight: float = 0.0,
+        optim: Union[str, None] = None,
+        optim_config: Union[dict, None] = None,
+        scheduler_config: Union[dict, None] = None,
+        **kwargs,
+    ) -> None:
+        """Initialise the model.
+
+        Parameters
+        ----------
+        metadata : dict
+            Metadata for the model from ``EncoderDecoderDataModule``. This can include
+            information about the dataset, such as the number of time steps, number of
+            features, etc. It is used to initialize the model
+            and ensure it is compatible with the data being used.
+        loss : nn.Module
+            Loss function module (e.g., ``MSELoss``, ``QuantileLoss``).
+        hidden_size : int
+            Dimensionality of hidden layers in projections (R).
+        d_model : int
+            Dimensionality of model projections after feature projection (R̃).
+        n_add_enc : int
+            Encoder residual blocks in total; ``n_add_enc - 1`` follow the first one.
+        n_add_dec : int
+            Decoder residual blocks in total; ``n_add_dec - 1`` follow the first one.
+        dropout_rate : float
+            Dropout probability applied in residual blocks.
+        activation : str, optional
+            Name of a ``torch.nn`` activation class (e.g., ``"ReLU"``); ``""`` = ReLU.
+        embs : list of int or None, optional
+            Embedding dictionary sizes, one per categorical variable; None means none.
+        persistence_weight : float, optional
+            Weight for the persistence (autoregressive) component.
+        optim : str or None, optional
+            Name of optimizer (e.g., ``"adam"``), or None to use default.
+        optim_config : dict or None, optional
+            Optimizer configuration dictionary.
+        scheduler_config : dict or None, optional
+            Scheduler configuration dictionary.
+        **kwargs
+            Additional keyword arguments passed to `BaseModel`.
+
+        """
+
+        super().__init__(loss=loss)
+        self.save_hyperparameters(logger=False)
+
+        self.dropout = dropout_rate
+        self.persistence_weight = persistence_weight
+        self.optim = optim
+        self.optim_config = optim_config
+        self.scheduler_config = scheduler_config
+        self.loss = loss
+
+        self.hidden_size = hidden_size  # r
+        self.d_model = d_model  # r^tilda
+        self.past_steps = metadata["max_encoder_length"]  # lookback size
+        self.future_steps = metadata["max_prediction_length"]  # horizon size
+        self.past_channels = metadata["encoder_cont"]  # past_vars
+        self.future_channels = metadata["decoder_cont"]  # fut_vars
+        self.output_channels = metadata["target"]  # target_vars
+        self.mul = 1
+        self.use_quantiles = False
+        self.outLinear = nn.Linear(d_model, self.output_channels)
+
+        # for other numerical variables in the past
+        self.aux_past_channels = self.past_channels
+        self.linear_aux_past = nn.ModuleList(
+            [nn.Linear(1, self.hidden_size) for _ in range(self.aux_past_channels)]
+        )
+
+        # for numerical variables in the future
+        self.aux_fut_channels = self.future_channels
+        self.linear_aux_fut = nn.ModuleList(
+            [nn.Linear(1, self.hidden_size) for _ in range(self.aux_fut_channels)]
+        )
+
+        # embedding categorical for both past and future
+        self.seq_len = self.past_steps + self.future_steps
+        self.emb_cat_var = sub_nn.embedding_cat_variables(
+            self.seq_len, self.future_steps, hidden_size, embs or [], self.device
+        )
+
+        ## FEATURE PROJECTION
+        # past
+        if self.aux_past_channels > 0:
+            self.feat_proj_past = ResidualBlock(
+                2 * hidden_size, d_model, dropout_rate, activation
+            )
+        else:
+            self.feat_proj_past = ResidualBlock(
+                hidden_size, d_model, dropout_rate, activation
+            )
+        # future
+        if self.aux_fut_channels > 0:
+            self.feat_proj_fut = ResidualBlock(
+                2 * hidden_size, d_model, dropout_rate, activation
+            )
+        else:
+            self.feat_proj_fut = ResidualBlock(
+                hidden_size, d_model, dropout_rate, activation
+            )
+
+        # # ENCODER
+        self.enc_dim_input = (
+            self.past_steps * self.output_channels
+            + (self.past_steps + self.future_steps) * d_model
+        )
+        self.enc_dim_output = self.future_steps * d_model
+        self.first_encoder = ResidualBlock(
+            self.enc_dim_input, self.enc_dim_output, dropout_rate, activation
+        )
+        self.aux_encoder = nn.ModuleList(
+            [
+                ResidualBlock(
+                    self.enc_dim_output, self.enc_dim_output, dropout_rate, activation
+                )
+                for _ in range(1, n_add_enc)
+            ]
+        )
+
+        # # DECODER
+        self.first_decoder = ResidualBlock(
+            self.enc_dim_output, self.enc_dim_output, dropout_rate, activation
+        )
+        self.aux_decoder = nn.ModuleList(
+            [
+                ResidualBlock(
+                    self.enc_dim_output, self.enc_dim_output, dropout_rate, activation
+                )
+                for _ in range(1, n_add_dec)
+            ]
+        )
+
+        ## TEMPORAL DECODER
+        self.temporal_decoder = ResidualBlock(
+            2 * d_model, self.output_channels * self.mul, dropout_rate, activation
+        )
+
+        # linear for Y lookback
+        self.linear_target = nn.Linear(
+            self.past_steps * self.output_channels,
+            self.future_steps * self.output_channels * self.mul,
+        )
+
+    def forward(self, X: dict) -> dict:
+        """Forecast the target over the prediction horizon for one batch.
+
+        Parameters
+        ----------
+        X : dict
+            Batch features, or an ``(x, y)`` tuple whose first item is that dict.
+            ``target_past`` and ``encoder_cont`` are read directly; ``decoder_cont``,
+            ``encoder_cat`` and ``decoder_cat`` back the ``x_*`` aliases when absent.
+
+        Returns
+        -------
+        dict
+            ``{"prediction": tensor}`` of shape ``[B, future_steps, output_channels]``.
+        """  # noqa: E501
+        # Accept the feature dict or an ``(x, y)`` tuple. Work on a shallow copy so
+        # the alias keys added below do not leak into the caller's batch.
+        batch = dict(X[0] if isinstance(X, tuple) else X)
+
+        if "x_num_past" not in batch:
+            batch["x_num_past"] = batch["encoder_cont"]
+        if "x_num_future" not in batch:
+            batch["x_num_future"] = batch["decoder_cont"]
+        if "x_cat_past" not in batch and "encoder_cat" in batch:
+            batch["x_cat_past"] = batch["encoder_cat"]
+        if "x_cat_future" not in batch and "decoder_cat" in batch:
+            batch["x_cat_future"] = batch["decoder_cat"]
+
+        y_past = batch["target_past"]
+        B = y_past.shape[0]
+
+        # LOADING EMBEDDING CATEGORICAL VARIABLES
+        emb_cat_past, emb_cat_fut = self.cat_categorical_vars(batch)
+
+        emb_cat_past = torch.mean(emb_cat_past, dim=2)
+        emb_cat_fut = torch.mean(emb_cat_fut, dim=2)
+
+        ### LOADING PAST AND FUTURE NUMERICAL VARIABLES
+        # load in the model auxiliar numerical variables
+
+        if self.aux_past_channels > 0:  # if we have more numerical variables about past
+            aux_num_past = batch["x_num_past"].to(self.device)
+            assert self.aux_past_channels == aux_num_past.size(2), (
+                f"{self.aux_past_channels} LAYERS FOR PAST VARS AND "
+                f"{aux_num_past.size(2)} VARS"
+            )  # to check if we are using the expected number of variables about past
+            # concat all embedded vars and mean of them
+            aux_emb_num_past = torch.Tensor().to(self.device)
+            for i, layer in enumerate(self.linear_aux_past):
+                aux_emb_past = layer(aux_num_past[:, :, [i]]).unsqueeze(2)
+                aux_emb_num_past = torch.cat((aux_emb_num_past, aux_emb_past), dim=2)
+            aux_emb_num_past = torch.mean(aux_emb_num_past, dim=2)
+        else:
+            aux_emb_num_past = None  # non available vars
+
+        if (
+            self.aux_fut_channels > 0
+        ):  # if we have more numerical variables about future
+            # AUX means AUXILIARY variables
+            aux_num_fut = batch["x_num_future"].to(self.device)
+            assert self.aux_fut_channels == aux_num_fut.size(2), (
+                f"{self.aux_fut_channels} LAYERS FOR FUTURE VARS AND "
+                f"{aux_num_fut.size(2)} VARS"
+            )  # to check if we are using the expected number of variables about fut
+            # concat all embedded vars and mean of them
+            aux_emb_num_fut = torch.Tensor().to(self.device)
+            for j, layer in enumerate(self.linear_aux_fut):
+                aux_emb_fut = layer(aux_num_fut[:, :, [j]]).unsqueeze(2)
+                aux_emb_num_fut = torch.cat((aux_emb_num_fut, aux_emb_fut), dim=2)
+            aux_emb_num_fut = torch.mean(aux_emb_num_fut, dim=2)
+        else:
+            aux_emb_num_fut = None  # non available vars
+
+        # past^tilda
+        if self.aux_past_channels > 0:
+            emb_past = torch.cat(
+                (emb_cat_past, aux_emb_num_past), dim=2
+            )  # [B, L, 2R] #
+            proj_past = self.feat_proj_past(emb_past, True)  # [B, L, R^tilda] #
+        else:
+            proj_past = self.feat_proj_past(emb_cat_past, True)  # [B, L, R^tilda] #
+
+        # fut^tilda
+        if self.aux_fut_channels > 0:
+            emb_fut = torch.cat((emb_cat_fut, aux_emb_num_fut), dim=2)
+            # [B, H, 2R] #
+            proj_fut = self.feat_proj_fut(emb_fut, True)  # [B, H, R^tilda] #
+        else:
+            proj_fut = self.feat_proj_fut(emb_cat_fut, True)  # [B, H, R^tilda] #
+
+        concat = torch.cat(
+            (y_past.view(B, -1), proj_past.view(B, -1), proj_fut.view(B, -1)), dim=1
+        )  # [B, L*self.mul + (L+H)*R^tilda] #
+        dense_enc = self.first_encoder(concat)
+        for lay_enc in self.aux_encoder:
+            dense_enc = lay_enc(dense_enc)
+
+        dense_dec = self.first_decoder(dense_enc)
+        for lay_dec in self.aux_decoder:
+            dense_dec = lay_dec(dense_dec)
+
+        temp_dec_input = torch.cat(
+            (dense_dec.view(B, self.future_steps, self.d_model), proj_fut), dim=2
+        )
+        temp_dec_output = self.temporal_decoder(temp_dec_input, False)
+        temp_dec_output = temp_dec_output.view(
+            B, self.future_steps, self.output_channels
+        )
+
+        linear_regr = self.linear_target(y_past.view(B, -1))
+        linear_output = linear_regr.view(B, self.future_steps, self.output_channels)
+
+        output = temp_dec_output + linear_output
+        return {"prediction": output}
+
+    # function to concat embedded categorical variables
+    def cat_categorical_vars(self, batch: dict):
+        """Extracting categorical context about past and future
+
+        Parameters
+        ----------
+        batch: dict
+            dataloader batch
+
+        Returns
+        -------
+        tuple of (torch.Tensor, torch.Tensor):
+            cat_emb_past, cat_emb_fut
+        """
+        cat_past = batch.get(
+            "encoder_cat",
+            torch.empty(batch["encoder_cont"].shape[0], self.past_steps, 0),
+        ).to(self.device)
+        cat_fut = batch.get(
+            "decoder_cat",
+            torch.empty(batch["encoder_cont"].shape[0], self.future_steps, 0),
+        ).to(self.device)
+        # GET AVAILABLE CATEGORICAL CONTEXT
+        if "x_cat_past" in batch.keys():
+            cat_past = batch["x_cat_past"].to(self.device)
+        if "x_cat_future" in batch.keys():
+            cat_fut = batch["x_cat_future"].to(self.device)
+        # CONCAT THEM, according to self.emb_cat_var usage
+        cat_full = torch.cat((cat_past, cat_fut), dim=1)
+        emb_cat_full = self.emb_cat_var(cat_full, self.device)
+        cat_emb_past = emb_cat_full[:, : self.past_steps, :, :]
+        cat_emb_fut = emb_cat_full[:, -self.future_steps :, :, :]
+
+        return cat_emb_past, cat_emb_fut
+
+    # function to extract from batch['x_num_past'] all variables except the
+    # one autoregressive
+    def remove_var(
+        self, tensor: torch.Tensor, indexes_to_exclude: list, dimension: int
+    ) -> torch.Tensor:
+        """Function to remove variables from tensors in chosen dimension and position
+
+        Parameters
+        ----------
+        tensor: torch.Tensor
+            starting tensor
+        indexes_to_exclude: list
+            index of the chosen dimension we want to exclude
+        dimension: int
+            dimension of the tensor on which we want to work (not list of dims!!)
+
+        Returns
+        -------
+        torch.Tensor:
+            new tensor without the chosen variables
+        """  # noqa: E501
+
+        remaining_idx = torch.tensor(
+            [i for i in range(tensor.size(dimension)) if i not in indexes_to_exclude]
+        ).to(tensor.device)
+        # Select the desired sub-tensor
+        extracted_subtensors = torch.index_select(
+            tensor, dim=dimension, index=remaining_idx
+        )
+
+        return extracted_subtensors
diff --git a/pytorch_forecasting/models/tide/tide_dsipts/_tide_v2_pkg.py b/pytorch_forecasting/models/tide/tide_dsipts/_tide_v2_pkg.py
new file mode 100644
index 000000000..7952e2a65
--- /dev/null
+++ b/pytorch_forecasting/models/tide/tide_dsipts/_tide_v2_pkg.py
@@ -0,0 +1,140 @@
+"""TIDE package container."""
+
+from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2
+
+
+class TIDE_pkg_v2(_BasePtForecasterV2):
+    """TIDE package container."""
+
+    _tags = {
+        "info:name": "TIDE",
+        "authors": ["fbk_dsipts"],
+    }
+
+    @classmethod
+    def get_cls(cls):
+        """Get model class."""
+        from pytorch_forecasting.models.tide.tide_dsipts._tide_v2 import TIDE
+
+        return TIDE
+
+    @classmethod
+    def _get_test_datamodule_from(cls, trainer_kwargs):
+        """Create test dataloaders from trainer_kwargs - following v1 pattern."""
+        from pytorch_forecasting.data.data_module import (
+            EncoderDecoderTimeSeriesDataModule,
+        )
+        from pytorch_forecasting.tests._data_scenarios import (
+            data_with_covariates_v2,
+            make_datasets_v2,
+        )
+
+        data_with_covariates = data_with_covariates_v2()
+
+        data_loader_default_kwargs = dict(
+            target="target",
+            group_ids=["agency_encoded", "sku_encoded"],
+            add_relative_time_idx=True,
+        )
+
+        data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {})
+        data_loader_default_kwargs.update(data_loader_kwargs)
+
+        datasets_info = make_datasets_v2(
+            data_with_covariates, **data_loader_default_kwargs
+        )
+
+        training_dataset = datasets_info["training_dataset"]
+        validation_dataset = datasets_info["validation_dataset"]
+        training_max_time_idx = datasets_info["training_max_time_idx"]
+
+        max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4)
+        max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3)
+        add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True)
+        batch_size = data_loader_kwargs.get("batch_size", 2)
+
+        train_datamodule = EncoderDecoderTimeSeriesDataModule(
+            time_series_dataset=training_dataset,
+            max_encoder_length=max_encoder_length,
+            max_prediction_length=max_prediction_length,
+            add_relative_time_idx=add_relative_time_idx,
+            batch_size=batch_size,
+            train_val_test_split=(0.8, 0.2, 0.0),
+        )
+
+        val_datamodule = EncoderDecoderTimeSeriesDataModule(
+            time_series_dataset=validation_dataset,
+            max_encoder_length=max_encoder_length,
+            max_prediction_length=max_prediction_length,
+            min_prediction_idx=training_max_time_idx,
+            add_relative_time_idx=add_relative_time_idx,
+            batch_size=batch_size,
+            train_val_test_split=(0.0, 1.0, 0.0),
+        )
+
+        test_datamodule = EncoderDecoderTimeSeriesDataModule(
+            time_series_dataset=validation_dataset,
+            max_encoder_length=max_encoder_length,
+            max_prediction_length=max_prediction_length,
+            min_prediction_idx=training_max_time_idx,
+            add_relative_time_idx=add_relative_time_idx,
+            batch_size=1,
+            train_val_test_split=(0.0, 0.0, 1.0),
+        )
+
+        train_datamodule.setup("fit")
+        val_datamodule.setup("fit")
+        test_datamodule.setup("test")
+
+        train_dataloader = train_datamodule.train_dataloader()
+        val_dataloader = val_datamodule.val_dataloader()
+        test_dataloader = test_datamodule.test_dataloader()
+
+        return {
+            "train": train_dataloader,
+            "val": val_dataloader,
+            "test": test_dataloader,
+            "data_module": train_datamodule,
+        }
+
+    @classmethod
+    def get_test_train_params(cls):
+        """Return testing parameter settings for the trainer.
+
+        Returns
+        -------
+        params : dict or list of dict, default = {}
+            Parameters to create testing instances of the class
+            Each dict are parameters to construct an "interesting" test instance, i.e.,
+            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
+            `create_test_instance` uses the first (or only) dictionary in `params`
+        """
+        import torch.nn as nn
+
+        return [
+            dict(
+                hidden_size=16,
+                d_model=8,
+                n_add_enc=1,
+                n_add_dec=1,
+                dropout_rate=0.1,
+            ),
+            dict(
+                hidden_size=32,
+                d_model=16,
+                n_add_enc=2,
+                n_add_dec=2,
+                dropout_rate=0.2,
+                data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3),
+                loss=nn.MSELoss(),
+            ),
+            dict(
+                hidden_size=64,
+                d_model=32,
+                n_add_enc=3,
+                n_add_dec=2,
+                dropout_rate=0.1,
+                data_loader_kwargs=dict(max_encoder_length=4, max_prediction_length=2),
+                loss=nn.PoissonNLLLoss(),
+            ),
+        ]
diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py
index 8c61c5e7d..ba391944f 100644
--- a/pytorch_forecasting/tests/test_all_estimators_v2.py
+++ b/pytorch_forecasting/tests/test_all_estimators_v2.py
@@ -57,9 +57,12 @@ def _integration(
         metadata, dict
     ), f"Expected metadata to be dict, got {type(metadata)}"
 
+    # a test-parameter set may supply its own loss; default to MSE otherwise
+    loss = kwargs.pop("loss", nn.MSELoss())
+
     net = estimator_cls(
         metadata=metadata,
-        loss=nn.MSELoss(),
+        loss=loss,
         **kwargs,
     )
 