Skip to content

Commit a7e8a7f

Browse files
committed
[refactor] Remove get_cross_validators and get_holdout_validators
Since each split function can now be called directly from CrossValTypes and HoldoutValTypes, these two functions were removed.
1 parent ef6acf2 commit a7e8a7f

File tree

3 files changed

+298
-226
lines changed

3 files changed

+298
-226
lines changed

autoPyTorch/datasets/base_dataset.py

Lines changed: 62 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABCMeta
2-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
2+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
33

44
import numpy as np
55

@@ -13,14 +13,9 @@
1313

1414
from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
1515
from autoPyTorch.datasets.resampling_strategy import (
16-
CROSS_VAL_FN,
1716
CrossValTypes,
1817
DEFAULT_RESAMPLING_PARAMETERS,
19-
HOLDOUT_FN,
20-
HoldoutValTypes,
21-
get_cross_validators,
22-
get_holdout_validators,
23-
is_stratified,
18+
HoldoutValTypes
2419
)
2520
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
2621

@@ -112,8 +107,6 @@ def __init__(
112107
if not hasattr(train_tensors[0], 'shape'):
113108
type_check(train_tensors, val_tensors)
114109
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
115-
self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
116-
self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
117110
self.rng = np.random.RandomState(seed=seed)
118111
self.shuffle = shuffle
119112
self.resampling_strategy = resampling_strategy
@@ -133,9 +126,6 @@ def __init__(
133126
# TODO: Look for a criteria to define small enough to preprocess
134127
self.is_small_preprocess = True
135128

136-
# Make sure cross validation splits are created once
137-
self.cross_validators = get_cross_validators(*CrossValTypes)
138-
self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
139129
self.splits = self.get_splits_from_resampling_strategy()
140130

141131
# We also need to be able to transform the data, be it for pre-processing
@@ -203,106 +193,82 @@ def __len__(self) -> int:
203193
def _get_indices(self) -> np.ndarray:
204194
return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))
205195

196+
def _process_resampling_strategy_args(self) -> None:
197+
"""TODO: Refactor this function after introducing BaseDict"""
198+
199+
if not any(isinstance(self.resampling_strategy, val_type)
200+
for val_type in [HoldoutValTypes, CrossValTypes]):
201+
raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")
202+
203+
if self.splitting_params is not None and \
204+
not isinstance(self.resampling_strategy_args, dict):
205+
206+
raise TypeError("resampling_strategy_args must be dict or None,"
207+
f" but got {type(self.resampling_strategy_args)}")
208+
209+
if self.resampling_strategy_args is None:
210+
self.resampling_strategy_args = {}
211+
212+
if isinstance(self.resampling_strategy, HoldoutValTypes):
213+
val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
214+
'val_share', None)
215+
self.resampling_strategy_args['val_share'] = val_share
216+
elif isinstance(self.splitting_type, CrossValTypes):
217+
num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
218+
'num_splits', None)
219+
self.resampling_strategy_args['num_splits'] = num_splits
220+
221+
"""Comment: Do we need this raise Error?"""
222+
if self.val_tensors is not None: # if we need it, we should share it with cross val as well
223+
raise ValueError('`val_share` specified, but the Dataset was'
224+
' given a pre-defined split at initialization already.')
225+
226+
val_share = self.resampling_strategy_args.get('val_share', None)
227+
num_splits = self.resampling_strategy_args.get('num_splits', None)
228+
229+
if val_share is not None and (val_share < 0 or val_share > 1):
230+
raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
231+
232+
if num_splits is not None:
233+
if num_splits <= 0:
234+
raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")
235+
elif not isinstance(num_splits, int):
236+
raise ValueError(f"`num_splits` must be an integer, got {num_splits}.")
237+
206238
def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
207239
"""
208240
Creates a set of splits based on a resampling strategy provided
209241
210242
Returns
211243
(List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
212244
"""
213-
splits = []
214-
if isinstance(self.resampling_strategy, HoldoutValTypes):
215-
val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
216-
'val_share', None)
217-
if self.resampling_strategy_args is not None:
218-
val_share = self.resampling_strategy_args.get('val_share', val_share)
219-
splits.append(
220-
self.create_holdout_val_split(
221-
holdout_val_type=self.resampling_strategy,
222-
val_share=val_share,
223-
)
224-
)
225-
elif isinstance(self.resampling_strategy, CrossValTypes):
226-
num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
227-
'num_splits', None)
228-
if self.resampling_strategy_args is not None:
229-
num_splits = self.resampling_strategy_args.get('num_splits', num_splits)
230-
# Create the split if it was not created before
231-
splits.extend(
232-
self.create_cross_val_splits(
233-
cross_val_type=self.resampling_strategy,
234-
num_splits=cast(int, num_splits),
235-
)
236-
)
237-
else:
238-
raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
239-
return splits
240245

241-
def create_cross_val_splits(
242-
self,
243-
cross_val_type: CrossValTypes,
244-
num_splits: int
245-
) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]:
246-
"""
247-
This function creates the cross validation split for the given task.
246+
# check if the requirements are met and if we can get splits
247+
self._process_resampling_strategy_args()
248248

249-
It is done once per dataset to have comparable results among pipelines
250-
Args:
251-
cross_val_type (CrossValTypes):
252-
num_splits (int): number of splits to be created
253-
254-
Returns:
255-
(List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]):
256-
list containing 'num_splits' splits.
257-
"""
258-
# Create just the split once
259-
# This is gonna be called multiple times, because the current dataset
260-
# is being used for multiple pipelines. That is, to be efficient with memory
261-
# we dump the dataset to memory and read it on a need basis. So this function
262-
# should be robust against multiple calls, and it does so by remembering the splits
263-
if not isinstance(cross_val_type, CrossValTypes):
264-
raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
265249
kwargs = {}
266-
if is_stratified(cross_val_type):
250+
if self.resampling_strategy.is_stratified():
267251
# we need additional information about the data for stratification
268252
kwargs["stratify"] = self.train_tensors[-1]
269-
splits = self.cross_validators[cross_val_type.name](
270-
num_splits, self._get_indices(), **kwargs)
271-
return splits
272253

273-
def create_holdout_val_split(
274-
self,
275-
holdout_val_type: HoldoutValTypes,
276-
val_share: float,
277-
) -> Tuple[np.ndarray, np.ndarray]:
278-
"""
279-
This function creates the holdout split for the given task.
254+
if isinstance(self.resampling_strategy, HoldoutValTypes):
255+
val_share = self.resampling_strategy_args['val_share']
280256

281-
It is done once per dataset to have comparable results among pipelines
282-
Args:
283-
holdout_val_type (HoldoutValTypes):
284-
val_share (float): share of the validation data
257+
return self.resampling_strategy(
258+
val_share=val_share,
259+
indices=self._get_indices(),
260+
**kwargs
261+
)
262+
elif isinstance(self.resampling_strategy, CrossValTypes):
263+
num_splits = self.resampling_strategy_args['num_splits']
285264

286-
Returns:
287-
(Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices)
288-
"""
289-
if holdout_val_type is None:
290-
raise ValueError(
291-
'`val_share` specified, but `holdout_val_type` not specified.'
265+
return self.create_cross_val_splits(
266+
num_splits=int(num_splits),
267+
indices=self._get_indices(),
268+
**kwargs
292269
)
293-
if self.val_tensors is not None:
294-
raise ValueError(
295-
'`val_share` specified, but the Dataset was a given a pre-defined split at initialization already.')
296-
if val_share < 0 or val_share > 1:
297-
raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
298-
if not isinstance(holdout_val_type, HoldoutValTypes):
299-
raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
300-
kwargs = {}
301-
if is_stratified(holdout_val_type):
302-
# we need additional information about the data for stratification
303-
kwargs["stratify"] = self.train_tensors[-1]
304-
train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
305-
return train, val
270+
else:
271+
raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
306272

307273
def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
308274
"""

0 commit comments

Comments
 (0)