Skip to content

Commit 0ab670c

Browse files
committed
Changed CrossValTypes to have val functions directly
1 parent d05696b commit 0ab670c

File tree

3 files changed

+122
-114
lines changed

3 files changed

+122
-114
lines changed

autoPyTorch/datasets/base_dataset.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@
1313

1414
from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
1515
from autoPyTorch.datasets.resampling_strategy import (
16-
CrossValFuncs,
1716
CrossValTypes,
1817
CrossValFunc,
1918
DEFAULT_RESAMPLING_PARAMETERS,
2019
HoldoutValTypes,
21-
HoldOutFuncs,
22-
HoldOutFunc
20+
HoldoutValFunc
2321
)
2422
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
2523

@@ -104,7 +102,7 @@ def __init__(
104102
type_check(train_tensors, val_tensors)
105103
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
106104
self.cross_validators: Dict[str, CrossValFunc] = {}
107-
self.holdout_validators: Dict[str, HoldOutFunc] = {}
105+
self.holdout_validators: Dict[str, HoldoutValFunc] = {}
108106
self.rng = np.random.RandomState(seed=seed)
109107
self.shuffle = shuffle
110108
self.resampling_strategy = resampling_strategy
@@ -125,8 +123,8 @@ def __init__(
125123
self.is_small_preprocess = True
126124

127125
# Make sure cross validation splits are created once
128-
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
129-
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
126+
self.cross_validators = CrossValTypes.get_validators(*CrossValTypes)
127+
self.holdout_validators = HoldoutValTypes.get_validators(*HoldoutValTypes)
130128
self.splits = self.get_splits_from_resampling_strategy()
131129

132130
# We also need to be able to transform the data, be it for pre-processing
Lines changed: 110 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from enum import IntEnum
1+
from enum import Enum
2+
from functools import partial
23
from typing import Any, Dict, List, Optional, Tuple, Union
34

45
import numpy as np
@@ -17,135 +18,59 @@
1718

1819
# Use callback protocol as workaround, since callable with function fields count 'self' as argument
1920
class CrossValFunc(Protocol):
21+
"""TODO: This class is not required anymore, because CrossValTypes class does not require get_validators()"""
2022
def __call__(self,
2123
num_splits: int,
2224
indices: np.ndarray,
2325
stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
2426
...
2527

2628

27-
class HoldOutFunc(Protocol):
29+
class HoldoutValFunc(Protocol):
2830
def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
2931
) -> Tuple[np.ndarray, np.ndarray]:
3032
...
3133

3234

33-
class CrossValTypes(IntEnum):
34-
"""The type of cross validation
35-
36-
This class is used to specify the cross validation function
37-
and is not supposed to be instantiated.
38-
39-
Examples: This class is supposed to be used as follows
40-
>>> cv_type = CrossValTypes.k_fold_cross_validation
41-
>>> print(cv_type.name)
42-
43-
k_fold_cross_validation
44-
45-
>>> for cross_val_type in CrossValTypes:
46-
print(cross_val_type.name, cross_val_type.value)
47-
48-
stratified_k_fold_cross_validation 1
49-
k_fold_cross_validation 2
50-
stratified_shuffle_split_cross_validation 3
51-
shuffle_split_cross_validation 4
52-
time_series_cross_validation 5
53-
"""
54-
stratified_k_fold_cross_validation = 1
55-
k_fold_cross_validation = 2
56-
stratified_shuffle_split_cross_validation = 3
57-
shuffle_split_cross_validation = 4
58-
time_series_cross_validation = 5
59-
60-
def is_stratified(self) -> bool:
61-
stratified = [self.stratified_k_fold_cross_validation,
62-
self.stratified_shuffle_split_cross_validation]
63-
return getattr(self, self.name) in stratified
64-
65-
66-
class HoldoutValTypes(IntEnum):
67-
"""TODO: change to enum using functools.partial"""
68-
"""The type of hold out validation (refer to CrossValTypes' doc-string)"""
69-
holdout_validation = 6
70-
stratified_holdout_validation = 7
71-
72-
def is_stratified(self) -> bool:
73-
stratified = [self.stratified_holdout_validation]
74-
return getattr(self, self.name) in stratified
75-
76-
77-
"""TODO: deprecate soon"""
78-
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]
79-
80-
"""TODO: deprecate soon"""
81-
DEFAULT_RESAMPLING_PARAMETERS = {
82-
HoldoutValTypes.holdout_validation: {
83-
'val_share': 0.33,
84-
},
85-
HoldoutValTypes.stratified_holdout_validation: {
86-
'val_share': 0.33,
87-
},
88-
CrossValTypes.k_fold_cross_validation: {
89-
'num_splits': 3,
90-
},
91-
CrossValTypes.stratified_k_fold_cross_validation: {
92-
'num_splits': 3,
93-
},
94-
CrossValTypes.shuffle_split_cross_validation: {
95-
'num_splits': 3,
96-
},
97-
CrossValTypes.time_series_cross_validation: {
98-
'num_splits': 3,
99-
},
100-
} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
101-
102-
103-
class HoldOutFuncs():
35+
class HoldoutValFuncs():
10436
@staticmethod
105-
def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
37+
def holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \
38+
-> Tuple[np.ndarray, np.ndarray]:
10639
train, val = train_test_split(indices, test_size=val_share, shuffle=False)
10740
return train, val
10841

10942
@staticmethod
110-
def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \
43+
def stratified_holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \
11144
-> Tuple[np.ndarray, np.ndarray]:
112-
train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=kwargs["stratify"])
45+
train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=stratify)
11346
return train, val
11447

115-
@classmethod
116-
def get_holdout_validators(cls, *holdout_val_types: Tuple[HoldoutValTypes]) -> Dict[str, HoldOutFunc]:
117-
118-
holdout_validators = {
119-
holdout_val_type.name: getattr(cls, holdout_val_type.name)
120-
for holdout_val_type in holdout_val_types
121-
}
122-
return holdout_validators
123-
12448

12549
class CrossValFuncs():
12650
@staticmethod
127-
def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
51+
def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
12852
-> List[Tuple[np.ndarray, np.ndarray]]:
12953
cv = ShuffleSplit(n_splits=num_splits)
13054
splits = list(cv.split(indices))
13155
return splits
13256

13357
@staticmethod
134-
def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
58+
def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray,
59+
stratify: Optional[Any] = None) \
13560
-> List[Tuple[np.ndarray, np.ndarray]]:
13661
cv = StratifiedShuffleSplit(n_splits=num_splits)
137-
splits = list(cv.split(indices, kwargs["stratify"]))
62+
splits = list(cv.split(indices, stratify))
13863
return splits
13964

14065
@staticmethod
141-
def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
66+
def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
14267
-> List[Tuple[np.ndarray, np.ndarray]]:
14368
cv = StratifiedKFold(n_splits=num_splits)
144-
splits = list(cv.split(indices, kwargs["stratify"]))
69+
splits = list(cv.split(indices, stratify))
14570
return splits
14671

14772
@staticmethod
148-
def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
73+
def k_fold_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
14974
-> List[Tuple[np.ndarray, np.ndarray]]:
15075
"""
15176
Standard k fold cross validation.
@@ -159,7 +84,7 @@ def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any)
15984
return splits
16085

16186
@staticmethod
162-
def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
87+
def time_series_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
16388
-> List[Tuple[np.ndarray, np.ndarray]]:
16489
"""
16590
Returns train and validation indices respecting the temporal ordering of the data.
@@ -176,10 +101,96 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs:
176101
splits = list(cv.split(indices))
177102
return splits
178103

179-
@classmethod
180-
def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]:
181-
cross_validators = {
182-
cross_val_type.name: getattr(cls, cross_val_type.name)
183-
for cross_val_type in cross_val_types
184-
}
185-
return cross_validators
104+
105+
class CrossValTypes(Enum):
106+
"""The type of cross validation
107+
108+
This class is used to specify the cross validation function
109+
and is not supposed to be instantiated.
110+
111+
Examples: This class is supposed to be used as follows
112+
>>> cv_type = CrossValTypes.k_fold_cross_validation
113+
>>> print(cv_type.name)
114+
115+
k_fold_cross_validation
116+
117+
>>> print(cv_type.value)
118+
119+
functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)
120+
121+
>>> for cross_val_type in CrossValTypes:
122+
print(cross_val_type.name)
123+
124+
stratified_k_fold_cross_validation
125+
k_fold_cross_validation
126+
stratified_shuffle_split_cross_validation
127+
shuffle_split_cross_validation
128+
time_series_cross_validation
129+
130+
Additionally, CrossValTypes.<function> can be called directly.
131+
"""
132+
stratified_k_fold_cross_validation = partial(CrossValFuncs.stratified_k_fold_cross_validation)
133+
k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation)
134+
stratified_shuffle_split_cross_validation = partial(CrossValFuncs.stratified_shuffle_split_cross_validation)
135+
shuffle_split_cross_validation = partial(CrossValFuncs.shuffle_split_cross_validation)
136+
time_series_cross_validation = partial(CrossValFuncs.time_series_cross_validation)
137+
138+
def is_stratified(self) -> bool:
139+
stratified = [self.stratified_k_fold_cross_validation,
140+
self.stratified_shuffle_split_cross_validation]
141+
return getattr(self, self.name) in stratified
142+
143+
def __call__(self, num_splits: int, indices: np.ndarray, stratify: Optional[Any]
144+
) -> List[Tuple[np.ndarray, np.ndarray]]:
145+
"""TODO: doc-string and test files"""
146+
return self.value(num_splits=num_splits, indices=indices, stratify=stratify)
147+
148+
@staticmethod
149+
def get_validators(*choices: 'CrossValTypes') -> Dict[str, CrossValFunc]:
150+
"""TODO: to be compatible, it is here now, but will be deprecated soon."""
151+
return {choice.name: choice.value for choice in choices}
152+
153+
154+
class HoldoutValTypes(Enum):
155+
"""The type of hold out validation (refer to CrossValTypes' doc-string)"""
156+
holdout_validation = partial(HoldoutValFuncs.holdout_validation)
157+
stratified_holdout_validation = partial(HoldoutValFuncs.stratified_holdout_validation)
158+
159+
def is_stratified(self) -> bool:
160+
stratified = [self.stratified_holdout_validation]
161+
return getattr(self, self.name) in stratified
162+
163+
def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
164+
) -> Tuple[np.ndarray, np.ndarray]:
165+
return self.value(val_share=val_share, indices=indices, stratify=stratify)
166+
167+
@staticmethod
168+
def get_validators(*choices: 'HoldoutValTypes') -> Dict[str, HoldoutValFunc]:
169+
"""TODO: to be compatible, it is here now, but will be deprecated soon."""
170+
return {choice.name: choice.value for choice in choices}
171+
172+
173+
"""TODO: deprecate soon (Will rename CrossValTypes -> CrossValFunc)"""
174+
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]
175+
176+
"""TODO: deprecate soon"""
177+
DEFAULT_RESAMPLING_PARAMETERS = {
178+
HoldoutValTypes.holdout_validation: {
179+
'val_share': 0.33,
180+
},
181+
HoldoutValTypes.stratified_holdout_validation: {
182+
'val_share': 0.33,
183+
},
184+
CrossValTypes.k_fold_cross_validation: {
185+
'num_splits': 3,
186+
},
187+
CrossValTypes.stratified_k_fold_cross_validation: {
188+
'num_splits': 3,
189+
},
190+
CrossValTypes.shuffle_split_cross_validation: {
191+
'num_splits': 3,
192+
},
193+
CrossValTypes.time_series_cross_validation: {
194+
'num_splits': 3,
195+
},
196+
} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]

autoPyTorch/datasets/time_series_dataset.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from autoPyTorch.datasets.base_dataset import BaseDataset
88
from autoPyTorch.datasets.resampling_strategy import (
99
CrossValTypes,
10-
HoldoutValTypes,
11-
CrossValFuncs,
12-
HoldOutFuncs
10+
HoldoutValTypes
1311
)
1412

1513
TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray] # currently only numpy arrays are supported
@@ -60,8 +58,9 @@ def __init__(self,
6058
train_transforms=train_transforms,
6159
val_transforms=val_transforms,
6260
)
63-
self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation)
64-
self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation)
61+
"""Comment: Do we really need those two? They are already defined in BaseDataset"""
62+
self.cross_validators = CrossValTypes.get_validators(CrossValTypes.time_series_cross_validation)
63+
self.holdout_validators = HoldoutValTypes.get_validators(HoldoutValTypes.holdout_validation)
6564

6665

6766
def _check_time_series_forecasting_inputs(target_variables: Tuple[int],
@@ -117,13 +116,13 @@ def __init__(self,
117116
val=val,
118117
task_type="time_series_classification")
119118
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
120-
self.cross_validators = CrossValFuncs.get_cross_validators(
119+
self.cross_validators = CrossValTypes.get_validators(
121120
CrossValTypes.stratified_k_fold_cross_validation,
122121
CrossValTypes.k_fold_cross_validation,
123122
CrossValTypes.shuffle_split_cross_validation,
124123
CrossValTypes.stratified_shuffle_split_cross_validation
125124
)
126-
self.holdout_validators = HoldOutFuncs.get_holdout_validators(
125+
self.holdout_validators = HoldoutValTypes.get_validators(
127126
HoldoutValTypes.holdout_validation,
128127
HoldoutValTypes.stratified_holdout_validation
129128
)
@@ -135,11 +134,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.
135134
val=val,
136135
task_type="time_series_regression")
137136
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
138-
self.cross_validators = CrossValFuncs.get_cross_validators(
137+
self.cross_validators = CrossValTypes.get_validators(
139138
CrossValTypes.k_fold_cross_validation,
140139
CrossValTypes.shuffle_split_cross_validation
141140
)
142-
self.holdout_validators = HoldOutFuncs.get_holdout_validators(
141+
self.holdout_validators = HoldoutValTypes.get_validators(
143142
HoldoutValTypes.holdout_validation
144143
)
145144

0 commit comments

Comments
 (0)