Make predictions in an auto-regressive manner. Supports only continuous targets.
Args:
    decode_one (Callable):
        function that takes at least the following arguments (a minimal sketch of such a
        function appears after this argument list):

        * ``idx`` (int): index of decoding step (from 0 to n_decoder_steps-1)
        * ``lagged_targets`` (List[torch.Tensor]): list of normalized targets.
          List is ``idx + 1`` elements long with the most recent entry at the end, i.e.
          ``previous_target = lagged_targets[-1]`` and in general ``lagged_targets[-lag]``.
        * ``hidden_state`` (Any): Current hidden state required for prediction. Keys are variable
          names. Only lags that are greater than ``idx`` are included.
        * additional arguments are not dynamic but can be passed via the ``**kwargs`` argument.

        The function returns a tuple of the (not rescaled) network prediction output and the
        hidden state for the next auto-regressive step.
    first_target (Union[List[torch.Tensor], torch.Tensor]): first target value to use for decoding
    first_hidden_state (Any): first hidden state used for decoding
    target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale as in ``x``
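A minimal sketch of such a ``decode_one`` callable, assuming a hypothetical single-layer LSTM
decoder with a linear output head. The module names ``rnn`` and ``output_layer`` and all sizes are
illustrative assumptions, not the library's actual implementation:

>>> from typing import Any, List, Tuple
>>> import torch
>>> import torch.nn as nn
>>>
>>> # hypothetical decoder modules; names and sizes are illustrative only
>>> rnn = nn.LSTM(input_size=1, hidden_size=16, batch_first=True)
>>> output_layer = nn.Linear(16, 1)
>>>
>>> def decode_one(idx: int, lagged_targets: List[torch.Tensor], hidden_state: Any, **kwargs: Any) -> Tuple[torch.Tensor, Any]:
>>>     # the most recent normalized prediction drives the next step
>>>     previous_target = lagged_targets[-1]
>>>     # reshape to (batch, seq_len=1, features=1) for the LSTM
>>>     x = previous_target.reshape(-1, 1, 1)
>>>     out, hidden_state = rnn(x, hidden_state)
>>>     # return the (not rescaled) prediction and the updated hidden state
>>>     return output_layer(out[:, 0]), hidden_state

On step ``idx``, older lags would be read via ``lagged_targets[-lag]``; this sketch uses only lag 1.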
Optimize hyperparameters. Run hyperparameter optimization; the learning rate is determined with
the PyTorch Lightning learning rate finder.
Args:
    train_dataloaders (DataLoader):
        ...
    model_path (str):
        Folder to which model checkpoints are saved.
    monitor (str):
        Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config
        and reads this metric to score the configuration. By default, the lower the better.
    direction (str):
        By default, direction is "minimize", meaning that lower values of the specified `monitor`
        are better. You can change this, e.g. to "maximize".
    max_epochs (int, optional):
        Maximum number of epochs to run training. Defaults to 20.
    n_trials (int, optional):
        Number of hyperparameter trials to run. Defaults to 100.
    timeout (float, optional):
        Time in seconds after which training is stopped regardless of number of epochs or
        validation metric. Defaults to 3600*8.0.
    input_params (dict, optional):
        A dictionary, where each `key` contains another dictionary with two keys: `"method"` and
        `"ranges"`. Example:
        >>> {"hidden_size": {
        >>>     "method": "suggest_int",
        >>>     "ranges": (16, 265),
        >>> }}
        The method key has to be a method of the `optuna.Trial` object. The ranges key holds the
        input ranges for the specified method (see the consumption sketch after this argument
        list).
    input_params_generator (Callable, optional):
        A function with the following signature:
        `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]`, returning the parameter
        values to set up your model for the current trial/run.
        Then, when your model is created (before training it and reporting the metrics for the
        current combination of hyperparameters), this dictionary is used as follows (a complete
        example generator appears after this argument list):
        >>> model = YourModelClass.from_dataset(
        >>>     train_dataloaders.dataset,
        >>>     log_interval=-1,
        >>>     ...
    use_learning_rate_finder (bool):
        Whether to use the learning rate finder or to optimize the learning rate as part of the
        hyperparameters. Defaults to True.
    trainer_kwargs (Dict[str, Any], optional):
        Additional arguments to the `PyTorch Lightning trainer
        <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`_ such as
        `limit_train_batches`. Defaults to {}.
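To clarify the shape of `input_params`, here is one plausible way a tuner could consume it:
dispatch each entry onto the corresponding `optuna.Trial` method by name and unpack the ranges as
positional arguments. This dispatch loop is an assumption about the internals, shown only for
illustration:

>>> import optuna
>>>
>>> input_params = {
>>>     "hidden_size": {"method": "suggest_int", "ranges": (16, 265)},
>>>     "dropout": {"method": "suggest_float", "ranges": (0.1, 0.3)},
>>> }
>>>
>>> def suggest_params(trial: optuna.Trial) -> dict:
>>>     params = {}
>>>     for name, spec in input_params.items():
>>>         suggest = getattr(trial, spec["method"])  # e.g. trial.suggest_int
>>>         params[name] = suggest(name, *spec["ranges"])
>>>     return params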
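Likewise, a hedged sketch of an `input_params_generator`: it matches the documented signature and
returns the keyword arguments later passed to `YourModelClass.from_dataset`. The parameter names
inside are hypothetical; use whatever your model accepts:

>>> from typing import Any, Dict
>>> import optuna
>>>
>>> def my_input_params_generator(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]:
>>>     # hypothetical model parameters for the current trial
>>>     return {
>>>         "hidden_size": trial.suggest_int("hidden_size", 16, 265),
>>>         "dropout": trial.suggest_float("dropout", 0.1, 0.3),
>>>     }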
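Finally, a usage sketch tying the arguments together. The import path below follows
pytorch-forecasting's tuning module and may differ in your version; `train_dataloader` and
`val_dataloader` are assumed to exist, and the returned object is assumed to be an `optuna` study:

>>> from pytorch_forecasting.models.temporal_fusion_transformer.tuning import (
>>>     optimize_hyperparameters,  # assumed import path
>>> )
>>>
>>> study = optimize_hyperparameters(
>>>     train_dataloader,                # assumed to exist
>>>     val_dataloader,                  # assumed to exist
>>>     model_path="optuna_checkpoints",
>>>     monitor="val_loss",              # metric read to score each HP config
>>>     direction="minimize",            # lower val_loss is better
>>>     max_epochs=20,
>>>     n_trials=100,
>>>     timeout=3600 * 8.0,              # stop after 8 hours regardless
>>>     use_learning_rate_finder=True,
>>>     trainer_kwargs={"limit_train_batches": 30},
>>> )
>>> print(study.best_trial.params)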