Skip to content

Commit ef6acf2

Browse files
authored
Adds more examples to customise AutoPyTorch. (#124)
* 3 examples plus doc update * Forgot the examples * Added example for resampling strategy * Update example workflow * Fixed bugs in example and resampling strategies * Addressed comments * Addressed comments * Addressed comments from shuhei, better documentation
1 parent 5c6ce0b commit ef6acf2

File tree

11 files changed

+446
-104
lines changed

11 files changed

+446
-104
lines changed

.github/workflows/examples.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ jobs:
3030
echo "::set-output name=BEFORE::$(git status --porcelain -b)"
3131
- name: Run tests
3232
run: |
33-
python examples/example_tabular_classification.py
34-
python examples/example_tabular_regression.py
33+
python examples/tabular/20_basics/example_tabular_classification.py
34+
python examples/tabular/20_basics/example_tabular_regression.py
35+
python examples/tabular/40_advanced/example_custom_configuration_space.py
36+
python examples/tabular/40_advanced/example_resampling_strategy.py
3537
python examples/example_image_classification.py

autoPyTorch/api/tabular_classification.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,27 +27,36 @@ class TabularClassificationTask(BaseTask):
2727
"""
2828
Tabular Classification API to the pipelines.
2929
Args:
30-
seed (int): seed to be used for reproducibility.
31-
n_jobs (int), (default=1): number of consecutive processes to spawn.
32-
logging_config (Optional[Dict]): specifies configuration
33-
for logging, if None, it is loaded from the logging.yaml
34-
ensemble_size (int), (default=50): Number of models added to the ensemble built by
30+
seed (int):
31+
seed to be used for reproducibility.
32+
n_jobs (int), (default=1):
33+
number of concurrent processes to spawn.
34+
logging_config (Optional[Dict]):
35+
specifies configuration for logging, if None, it is loaded from the logging.yaml
36+
ensemble_size (int), (default=50):
37+
Number of models added to the ensemble built by
3538
Ensemble selection from libraries of models.
3639
Models are drawn with replacement.
37-
ensemble_nbest (int), (default=50): only consider the ensemble_nbest
40+
ensemble_nbest (int), (default=50):
41+
only consider the ensemble_nbest
3842
models to build the ensemble
39-
max_models_on_disc (int), (default=50): maximum number of models saved to disc.
43+
max_models_on_disc (int), (default=50):
44+
maximum number of models saved to disc.
4045
Also, controls the size of the ensemble as any additional models will be deleted.
4146
Must be greater than or equal to 1.
42-
temporary_directory (str): folder to store configuration output and log file
43-
output_directory (str): folder to store predictions for optional test set
44-
delete_tmp_folder_after_terminate (bool): determines whether to delete the temporary directory,
45-
when finished
46-
include_components (Optional[Dict]): If None, all possible components are used.
47-
Otherwise specifies set of components to use.
48-
exclude_components (Optional[Dict]): If None, all possible components are used.
49-
Otherwise specifies set of components not to use. Incompatible with include
50-
components
47+
temporary_directory (str):
48+
folder to store configuration output and log file
49+
output_directory (str):
50+
folder to store predictions for optional test set
51+
delete_tmp_folder_after_terminate (bool):
52+
determines whether to delete the temporary directory, when finished
53+
include_components (Optional[Dict]):
54+
If None, all possible components are used. Otherwise
55+
specifies set of components to use.
56+
exclude_components (Optional[Dict]):
57+
If None, all possible components are used. Otherwise
58+
specifies set of components not to use. Incompatible
59+
with include_components
5160
"""
5261
def __init__(
5362
self,

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) ->
9797

9898
def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \
9999
-> Tuple[np.ndarray, np.ndarray]:
100-
train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=kwargs["stratify"])
100+
train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"])
101101
return train, val
102102

103103

autoPyTorch/utils/hyperparameter_search_space_update.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,25 @@
66
from autoPyTorch.pipeline.components.base_component import autoPyTorchComponent
77

88

9-
class HyperparameterSearchSpaceUpdate():
9+
class HyperparameterSearchSpaceUpdate:
10+
"""
11+
Allows specifying an update to the search space of a
12+
particular hyperparameter.
13+
14+
Args:
15+
node_name (str):
16+
The name of the node in the pipeline
17+
hyperparameter (str):
18+
The name of the hyperparameter
19+
value_range (Union[List, Tuple]):
20+
In case of categorical hyperparameter, defines the new categorical choices.
21+
In case of numerical hyperparameter, defines the new range
22+
in the form of (LOWER, UPPER)
23+
default_value (Union[int, float, str]):
24+
New default value for the hyperparameter
25+
log (bool) (default=False):
26+
In case of numerical hyperparameters, whether to sample on a log scale
27+
"""
1028
def __init__(self, node_name: str, hyperparameter: str, value_range: Union[List, Tuple],
1129
default_value: Union[int, float, str], log: bool = False) -> None:
1230
self.node_name = node_name
@@ -16,6 +34,15 @@ def __init__(self, node_name: str, hyperparameter: str, value_range: Union[List,
1634
self.default_value = default_value
1735

1836
def apply(self, pipeline: List[Tuple[str, Union[autoPyTorchComponent, autoPyTorchChoice]]]) -> None:
37+
"""
38+
Applies the update to the appropriate hyperparameter of the pipeline
39+
Args:
40+
pipeline (List[Tuple[str, Union[autoPyTorchComponent, autoPyTorchChoice]]]):
41+
The named steps of the current autoPyTorch pipeline
42+
43+
Returns:
44+
None
45+
"""
1946
[node[1]._apply_search_space_update(name=self.hyperparameter,
2047
new_value_range=self.value_range,
2148
log=self.log,
@@ -29,30 +56,69 @@ def __str__(self) -> str:
2956
(" log" if self.log else ""))
3057

3158

32-
class HyperparameterSearchSpaceUpdates():
59+
class HyperparameterSearchSpaceUpdates:
60+
""" Contains a collection of HyperparameterSearchSpaceUpdate """
3361
def __init__(self, updates: Optional[List[HyperparameterSearchSpaceUpdate]] = None) -> None:
3462
self.updates = updates if updates is not None else []
3563

3664
def apply(self, pipeline: List[Tuple[str, Union[autoPyTorchComponent, autoPyTorchChoice]]]) -> None:
65+
"""
66+
Iteratively applies updates to the pipeline
67+
68+
Args:
69+
pipeline (List[Tuple[str, Union[autoPyTorchComponent, autoPyTorchChoice]]]):
70+
The named steps of the current autoPyTorch pipeline
71+
72+
Returns:
73+
None
74+
"""
3775
for update in self.updates:
3876
update.apply(pipeline)
3977

4078
def append(self, node_name: str, hyperparameter: str, value_range: Union[List, Tuple],
4179
default_value: Union[int, float, str], log: bool = False) -> None:
80+
"""
81+
Add a new update
82+
83+
Args:
84+
node_name (str):
85+
The name of the node in the pipeline
86+
hyperparameter (str):
87+
The name of the hyperparameter
88+
value_range (Union[List, Tuple]):
89+
In case of categorical hyperparameter, defines the new categorical choices.
90+
In case of numerical hyperparameter, defines the new range
91+
in the form of (LOWER, UPPER)
92+
default_value (Union[int, float, str]):
93+
New default value for the hyperparameter
94+
log (bool) (default=False):
95+
In case of numerical hyperparameters, whether to sample on a log scale
96+
97+
Returns:
98+
None
99+
"""
42100
self.updates.append(HyperparameterSearchSpaceUpdate(node_name=node_name,
43101
hyperparameter=hyperparameter,
44102
value_range=value_range,
45103
default_value=default_value,
46104
log=log))
47105

48106
def save_as_file(self, path: str) -> None:
107+
"""
108+
Save the updates as a file to reuse later
109+
110+
Args:
111+
path (str): path of the file
112+
113+
Returns:
114+
None
115+
"""
49116
with open(path, "w") as f:
50-
with open(path, "w") as f:
51-
for update in self.updates:
52-
print(update.node_name, update.hyperparameter, # noqa: T001
53-
str(update.value_range), "'{}'".format(update.default_value)
54-
if isinstance(update.default_value, str) else update.default_value,
55-
(" log" if update.log else ""), file=f)
117+
for update in self.updates:
118+
print(update.node_name, update.hyperparameter, # noqa: T001
119+
str(update.value_range), "'{}'".format(update.default_value)
120+
if isinstance(update.default_value, str) else update.default_value,
121+
(" log" if update.log else ""), file=f)
56122

57123

58124
def parse_hyperparameter_search_space_updates(updates_file: Optional[str]

docs/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@
6868

6969
sphinx_gallery_conf = {
7070
# path to the examples
71-
'examples_dirs': '../examples',
71+
'examples_dirs': ['../examples/tabular/20_basics', '../examples/tabular/40_advanced'],
7272
# path where to save gallery generated examples
73-
'gallery_dirs': 'examples',
73+
'gallery_dirs': ['basics_tabular', 'advanced_tabular'],
7474
#TODO: fix back/forward references for the examples.
7575
#'doc_module': ('autoPyTorch'),
7676
#'reference_url': {

examples/tabular/20_basics/README.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
.. _examples_tabular_basics:
2+
3+
4+
==============================
5+
Basic Tabular Dataset Examples
6+
==============================
7+
8+
Basic examples for using *Auto-PyTorch* on tabular datasets

examples/example_tabular_classification.py renamed to examples/tabular/20_basics/example_tabular_classification.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,10 @@
2222
import sklearn.model_selection
2323

2424
from autoPyTorch.api.tabular_classification import TabularClassificationTask
25-
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
26-
27-
28-
def get_search_space_updates():
29-
"""
30-
Search space updates to the task can be added using HyperparameterSearchSpaceUpdates
31-
Returns:
32-
HyperparameterSearchSpaceUpdates
33-
"""
34-
updates = HyperparameterSearchSpaceUpdates()
35-
updates.append(node_name="data_loader",
36-
hyperparameter="batch_size",
37-
value_range=[16, 512],
38-
default_value=32)
39-
updates.append(node_name="lr_scheduler",
40-
hyperparameter="CosineAnnealingLR:T_max",
41-
value_range=[50, 60],
42-
default_value=55)
43-
updates.append(node_name='network_backbone',
44-
hyperparameter='ResNetBackbone:dropout',
45-
value_range=[0, 0.5],
46-
default_value=0.2)
47-
return updates
4825

4926

5027
if __name__ == '__main__':
28+
5129
############################################################################
5230
# Data Loading
5331
# ============
@@ -62,16 +40,23 @@ def get_search_space_updates():
6240
# Build and fit a classifier
6341
# ==========================
6442
api = TabularClassificationTask(
65-
delete_tmp_folder_after_terminate=False,
66-
search_space_updates=get_search_space_updates()
43+
temporary_directory='./tmp/autoPyTorch_example_tmp_01',
44+
output_directory='./tmp/autoPyTorch_example_out_01',
45+
# To keep the logs of the run, set the next two options to False
46+
delete_tmp_folder_after_terminate=True,
47+
delete_output_folder_after_terminate=True
6748
)
49+
50+
############################################################################
51+
# Search for an ensemble of machine learning algorithms
52+
# =====================================================
6853
api.search(
6954
X_train=X_train,
7055
y_train=y_train,
7156
X_test=X_test.copy(),
7257
y_test=y_test.copy(),
7358
optimize_metric='accuracy',
74-
total_walltime_limit=500,
59+
total_walltime_limit=300,
7560
func_eval_time_limit=50
7661
)
7762

@@ -82,4 +67,5 @@ def get_search_space_updates():
8267
y_pred = api.predict(X_test)
8368
score = api.score(y_pred, y_test)
8469
print(score)
70+
# Print the final ensemble built by AutoPyTorch
8571
print(api.show_models())

examples/example_tabular_regression.py renamed to examples/tabular/20_basics/example_tabular_regression.py

Lines changed: 18 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,15 @@
33
Tabular Regression
44
======================
55
6-
The following example shows how to fit a sample classification model
6+
The following example shows how to fit a sample regression model
77
with AutoPyTorch
88
"""
99
import os
1010
import tempfile as tmp
11-
import typing
1211
import warnings
1312

14-
from sklearn.datasets import make_regression
15-
16-
from autoPyTorch.data.tabular_feature_validator import TabularFeatureValidator
13+
import sklearn.datasets
14+
import sklearn.model_selection
1715

1816
os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
1917
os.environ['OMP_NUM_THREADS'] = '1'
@@ -23,54 +21,16 @@
2321
warnings.simplefilter(action='ignore', category=UserWarning)
2422
warnings.simplefilter(action='ignore', category=FutureWarning)
2523

26-
from sklearn import model_selection, preprocessing
27-
2824
from autoPyTorch.api.tabular_regression import TabularRegressionTask
29-
from autoPyTorch.datasets.tabular_dataset import TabularDataset
30-
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
31-
32-
33-
def get_search_space_updates():
34-
"""
35-
Search space updates to the task can be added using HyperparameterSearchSpaceUpdates
36-
Returns:
37-
HyperparameterSearchSpaceUpdates
38-
"""
39-
updates = HyperparameterSearchSpaceUpdates()
40-
updates.append(node_name="data_loader",
41-
hyperparameter="batch_size",
42-
value_range=[16, 512],
43-
default_value=32)
44-
updates.append(node_name="lr_scheduler",
45-
hyperparameter="CosineAnnealingLR:T_max",
46-
value_range=[50, 60],
47-
default_value=55)
48-
updates.append(node_name='network_backbone',
49-
hyperparameter='ResNetBackbone:dropout',
50-
value_range=[0, 0.5],
51-
default_value=0.2)
52-
return updates
5325

5426

5527
if __name__ == '__main__':
28+
5629
############################################################################
5730
# Data Loading
5831
# ============
59-
60-
# Get the training data for tabular regression
61-
# X, y = datasets.fetch_openml(name="cholesterol", return_X_y=True)
62-
63-
# Use dummy data for now since there are problems with categorical columns
64-
X, y = make_regression(
65-
n_samples=5000,
66-
n_features=4,
67-
n_informative=3,
68-
n_targets=1,
69-
shuffle=True,
70-
random_state=0
71-
)
72-
73-
X_train, X_test, y_train, y_test = model_selection.train_test_split(
32+
X, y = sklearn.datasets.fetch_openml(name='boston', return_X_y=True, as_frame=True)
33+
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
7434
X,
7535
y,
7636
random_state=1,
@@ -89,16 +49,23 @@ def get_search_space_updates():
8949
# Build and fit a regressor
9050
# ==========================
9151
api = TabularRegressionTask(
92-
delete_tmp_folder_after_terminate=False,
93-
search_space_updates=get_search_space_updates()
52+
temporary_directory='./tmp/autoPyTorch_example_tmp_02',
53+
output_directory='./tmp/autoPyTorch_example_out_02',
54+
# To keep the logs of the run, set the next two options to False
55+
delete_tmp_folder_after_terminate=True,
56+
delete_output_folder_after_terminate=True
9457
)
58+
59+
############################################################################
60+
# Search for an ensemble of machine learning algorithms
61+
# =====================================================
9562
api.search(
9663
X_train=X_train,
9764
y_train=y_train_scaled,
9865
X_test=X_test.copy(),
9966
y_test=y_test_scaled.copy(),
10067
optimize_metric='r2',
101-
total_walltime_limit=500,
68+
total_walltime_limit=300,
10269
func_eval_time_limit=50,
10370
traditional_per_total_budget=0
10471
)
@@ -114,3 +81,5 @@ def get_search_space_updates():
11481
score = api.score(y_pred, y_test)
11582

11683
print(score)
84+
# Print the final ensemble built by AutoPyTorch
85+
print(api.show_models())

0 commit comments

Comments
 (0)