1
1
from abc import ABCMeta
2
- from typing import Any , Dict , List , Optional , Sequence , Tuple , Union , cast
2
+ from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
3
3
4
4
import numpy as np
5
5
13
13
14
14
from autoPyTorch .constants import CLASSIFICATION_OUTPUTS , STRING_TO_OUTPUT_TYPES
15
15
from autoPyTorch .datasets .resampling_strategy import (
16
- CROSS_VAL_FN ,
17
16
CrossValTypes ,
18
17
DEFAULT_RESAMPLING_PARAMETERS ,
19
- HOLDOUT_FN ,
20
- HoldoutValTypes ,
21
- get_cross_validators ,
22
- get_holdout_validators ,
23
- is_stratified ,
18
+ HoldoutValTypes
24
19
)
25
20
from autoPyTorch .utils .common import FitRequirement , hash_array_or_matrix
26
21
@@ -112,8 +107,6 @@ def __init__(
112
107
if not hasattr (train_tensors [0 ], 'shape' ):
113
108
type_check (train_tensors , val_tensors )
114
109
self .train_tensors , self .val_tensors , self .test_tensors = train_tensors , val_tensors , test_tensors
115
- self .cross_validators : Dict [str , CROSS_VAL_FN ] = {}
116
- self .holdout_validators : Dict [str , HOLDOUT_FN ] = {}
117
110
self .rng = np .random .RandomState (seed = seed )
118
111
self .shuffle = shuffle
119
112
self .resampling_strategy = resampling_strategy
@@ -133,9 +126,6 @@ def __init__(
133
126
# TODO: Look for a criteria to define small enough to preprocess
134
127
self .is_small_preprocess = True
135
128
136
- # Make sure cross validation splits are created once
137
- self .cross_validators = get_cross_validators (* CrossValTypes )
138
- self .holdout_validators = get_holdout_validators (* HoldoutValTypes )
139
129
self .splits = self .get_splits_from_resampling_strategy ()
140
130
141
131
# We also need to be able to transform the data, be it for pre-processing
@@ -203,106 +193,82 @@ def __len__(self) -> int:
203
193
def _get_indices (self ) -> np .ndarray :
204
194
return self .rng .permutation (len (self )) if self .shuffle else np .arange (len (self ))
205
195
196
def _process_resampling_strategy_args(self) -> None:
    """Validate and normalize ``self.resampling_strategy_args``.

    Checks that the configured resampling strategy is supported, fills in
    the strategy's default parameter ('val_share' for holdout,
    'num_splits' for cross validation) when the caller did not provide
    one, and sanity-checks the final values.

    Raises:
        ValueError: if the strategy is unsupported, if a pre-defined
            validation split conflicts with resampling, or if
            'val_share'/'num_splits' is out of range.
        TypeError: if ``resampling_strategy_args`` is neither a dict
            nor None.

    TODO: Refactor this function after introducing BaseDict
    """
    if not any(isinstance(self.resampling_strategy, val_type)
               for val_type in [HoldoutValTypes, CrossValTypes]):
        raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")

    # BUG FIX: the original tested the undefined attribute
    # `self.splitting_params`; the intent is to type-check the args
    # themselves before using them.
    if self.resampling_strategy_args is not None and \
            not isinstance(self.resampling_strategy_args, dict):
        raise TypeError("resampling_strategy_args must be dict or None,"
                        f" but got {type(self.resampling_strategy_args)}")

    if self.resampling_strategy_args is None:
        self.resampling_strategy_args = {}

    # Fill in defaults WITHOUT clobbering caller-supplied values
    # (the original unconditionally overwrote them with the defaults).
    if isinstance(self.resampling_strategy, HoldoutValTypes):
        self.resampling_strategy_args.setdefault(
            'val_share',
            DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get('val_share', None),
        )
    elif isinstance(self.resampling_strategy, CrossValTypes):
        # BUG FIX: the original referenced the undefined attribute
        # `self.splitting_type` here, so this branch always raised
        # AttributeError for cross-validation strategies.
        self.resampling_strategy_args.setdefault(
            'num_splits',
            DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get('num_splits', None),
        )

    # NOTE(review): this rejects resampling whenever a pre-defined
    # validation set was passed at construction time, for cross
    # validation too — confirm that is intended (the pre-refactor code
    # only raised for holdout).
    if self.val_tensors is not None:  # if we need it, we should share it with cross val as well
        raise ValueError('`val_share` specified, but the Dataset was'
                         ' a given a pre-defined split at initialization already.')

    val_share = self.resampling_strategy_args.get('val_share', None)
    num_splits = self.resampling_strategy_args.get('num_splits', None)

    if val_share is not None and (val_share < 0 or val_share > 1):
        raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")

    if num_splits is not None:
        # Check the type FIRST: comparing a non-int against 0 could
        # itself raise a confusing TypeError before the intended
        # ValueError (the original had these two checks reversed).
        if not isinstance(num_splits, int):
            raise ValueError(f"`num_splits` must be an integer, got {num_splits}.")
        elif num_splits <= 0:
            raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")
206
238
def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
    """
    Creates a set of splits based on a resampling strategy provided

    Returns
        (List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format

    Raises:
        ValueError: if ``self.resampling_strategy`` is neither a holdout
            nor a cross-validation type.
    """
    # check if the requirements are met and if we can get splits
    self._process_resampling_strategy_args()

    kwargs = {}
    if self.resampling_strategy.is_stratified():
        # we need additional information about the data for stratification;
        # the last tensor holds the targets.
        kwargs["stratify"] = self.train_tensors[-1]

    if isinstance(self.resampling_strategy, HoldoutValTypes):
        return self.resampling_strategy(
            val_share=self.resampling_strategy_args['val_share'],
            indices=self._get_indices(),
            **kwargs
        )
    elif isinstance(self.resampling_strategy, CrossValTypes):
        # BUG FIX: the original called `self.create_cross_val_splits`,
        # a helper removed in this same refactor, which would raise
        # AttributeError at runtime. Dispatch through the strategy
        # object itself, mirroring the holdout branch above.
        return self.resampling_strategy(
            num_splits=int(self.resampling_strategy_args['num_splits']),
            indices=self._get_indices(),
            **kwargs
        )
    else:
        raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
306
272
307
273
def get_dataset_for_training (self , split_id : int ) -> Tuple [Dataset , Dataset ]:
308
274
"""
0 commit comments