1
- from enum import IntEnum
1
+ from enum import Enum
2
+ from functools import partial
2
3
from typing import Any , Dict , List , Optional , Tuple , Union
3
4
4
5
import numpy as np
17
18
18
19
# Use callback protocol as workaround, since callable with function fields count 'self' as argument
19
20
class CrossValFunc(Protocol):
    """Structural type for a cross-validation split function.

    A conforming callable accepts ``num_splits``, ``indices`` and
    ``stratify`` and returns a list of ``(ndarray, ndarray)`` tuples.
    """
    # TODO: This class is not required anymore, because CrossValTypes class
    #       does not require get_validators()

    def __call__(
        self,
        num_splits: int,
        indices: np.ndarray,
        stratify: Optional[Any],
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        ...
25
27
26
28
27
- class HoldOutFunc (Protocol ):
29
class HoldoutValFunc(Protocol):
    """Structural type for a hold-out split function.

    A conforming callable accepts ``val_share``, ``indices`` and
    ``stratify`` and returns a single ``(ndarray, ndarray)`` tuple.
    """

    def __call__(
        self,
        val_share: float,
        indices: np.ndarray,
        stratify: Optional[Any],
    ) -> Tuple[np.ndarray, np.ndarray]:
        ...
31
33
32
34
33
- class CrossValTypes (IntEnum ):
34
- """The type of cross validation
35
-
36
- This class is used to specify the cross validation function
37
- and is not supposed to be instantiated.
38
-
39
- Examples: This class is supposed to be used as follows
40
- >>> cv_type = CrossValTypes.k_fold_cross_validation
41
- >>> print(cv_type.name)
42
-
43
- k_fold_cross_validation
44
-
45
- >>> for cross_val_type in CrossValTypes:
46
- print(cross_val_type.name, cross_val_type.value)
47
-
48
- stratified_k_fold_cross_validation 1
49
- k_fold_cross_validation 2
50
- stratified_shuffle_split_cross_validation 3
51
- shuffle_split_cross_validation 4
52
- time_series_cross_validation 5
53
- """
54
- stratified_k_fold_cross_validation = 1
55
- k_fold_cross_validation = 2
56
- stratified_shuffle_split_cross_validation = 3
57
- shuffle_split_cross_validation = 4
58
- time_series_cross_validation = 5
59
-
60
- def is_stratified (self ) -> bool :
61
- stratified = [self .stratified_k_fold_cross_validation ,
62
- self .stratified_shuffle_split_cross_validation ]
63
- return getattr (self , self .name ) in stratified
64
-
65
-
66
- class HoldoutValTypes (IntEnum ):
67
- """TODO: change to enum using functools.partial"""
68
- """The type of hold out validation (refer to CrossValTypes' doc-string)"""
69
- holdout_validation = 6
70
- stratified_holdout_validation = 7
71
-
72
- def is_stratified (self ) -> bool :
73
- stratified = [self .stratified_holdout_validation ]
74
- return getattr (self , self .name ) in stratified
75
-
76
-
77
- """TODO: deprecate soon"""
78
- RESAMPLING_STRATEGIES = [CrossValTypes , HoldoutValTypes ]
79
-
80
- """TODO: deprecate soon"""
81
- DEFAULT_RESAMPLING_PARAMETERS = {
82
- HoldoutValTypes .holdout_validation : {
83
- 'val_share' : 0.33 ,
84
- },
85
- HoldoutValTypes .stratified_holdout_validation : {
86
- 'val_share' : 0.33 ,
87
- },
88
- CrossValTypes .k_fold_cross_validation : {
89
- 'num_splits' : 3 ,
90
- },
91
- CrossValTypes .stratified_k_fold_cross_validation : {
92
- 'num_splits' : 3 ,
93
- },
94
- CrossValTypes .shuffle_split_cross_validation : {
95
- 'num_splits' : 3 ,
96
- },
97
- CrossValTypes .time_series_cross_validation : {
98
- 'num_splits' : 3 ,
99
- },
100
- } # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
101
-
102
-
103
- class HoldOutFuncs ():
35
class HoldoutValFuncs():
    """Namespace of hold-out split functions.

    Every function shares the ``(val_share, indices, stratify)`` signature so
    the functions can be used interchangeably, and returns the pair produced
    by ``train_test_split``.
    """

    @staticmethod
    def holdout_validation(
        val_share: float,
        indices: np.ndarray,
        stratify: Optional[Any] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Split ``indices`` into a train and a validation part without shuffling.

        Args:
            val_share: Forwarded to ``train_test_split`` as ``test_size``.
            indices: The indices to split.
            stratify: Unused; accepted only so that all hold-out functions
                share one signature.

        Returns:
            The ``(train, val)`` pair returned by ``train_test_split``.
        """
        train, val = train_test_split(indices, test_size=val_share, shuffle=False)
        return train, val

    @staticmethod
    def stratified_holdout_validation(
        val_share: float,
        indices: np.ndarray,
        stratify: Optional[Any] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Split ``indices`` into a train and a validation part, stratified by ``stratify``.

        Args:
            val_share: Forwarded to ``train_test_split`` as ``test_size``.
            indices: The indices to split.
            stratify: Labels forwarded to ``train_test_split`` as ``stratify``.
                When ``None`` the call degrades to an unshuffled split.

        Returns:
            The ``(train, val)`` pair returned by ``train_test_split``.
        """
        # scikit-learn rejects ``stratify`` combined with ``shuffle=False``
        # ("If shuffle=False then stratify must be None"), so shuffling is
        # enabled exactly when labels are supplied. The unstratified fallback
        # keeps the previous unshuffled behaviour.
        train, val = train_test_split(
            indices,
            test_size=val_share,
            shuffle=stratify is not None,
            stratify=stratify,
        )
        return train, val
114
47
115
- @classmethod
116
- def get_holdout_validators (cls , * holdout_val_types : Tuple [HoldoutValTypes ]) -> Dict [str , HoldOutFunc ]:
117
-
118
- holdout_validators = {
119
- holdout_val_type .name : getattr (cls , holdout_val_type .name )
120
- for holdout_val_type in holdout_val_types
121
- }
122
- return holdout_validators
123
-
124
48
125
49
class CrossValFuncs ():
126
50
@staticmethod
def shuffle_split_cross_validation(
    num_splits: int,
    indices: np.ndarray,
    stratify: Optional[Any] = None,
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Return the splits produced by ``ShuffleSplit`` over ``indices``.

    Args:
        num_splits: Forwarded to ``ShuffleSplit`` as ``n_splits``.
        indices: The indices handed to ``ShuffleSplit.split``.
        stratify: Unused; accepted only so that all cross-validation
            functions share one signature.

    Returns:
        The materialised list of pairs yielded by ``split``.
    """
    splitter = ShuffleSplit(n_splits=num_splits)
    return list(splitter.split(indices))
132
56
133
57
@staticmethod
def stratified_shuffle_split_cross_validation(
    num_splits: int,
    indices: np.ndarray,
    stratify: Optional[Any] = None,
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Return the splits produced by ``StratifiedShuffleSplit`` over ``indices``.

    Args:
        num_splits: Forwarded to ``StratifiedShuffleSplit`` as ``n_splits``.
        indices: The indices handed to ``split`` as its first argument.
        stratify: Handed to ``split`` as its second argument.

    Returns:
        The materialised list of pairs yielded by ``split``.
    """
    splitter = StratifiedShuffleSplit(n_splits=num_splits)
    return list(splitter.split(indices, stratify))
139
64
140
65
@staticmethod
def stratified_k_fold_cross_validation(
    num_splits: int,
    indices: np.ndarray,
    stratify: Optional[Any] = None,
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Return the splits produced by ``StratifiedKFold`` over ``indices``.

    Args:
        num_splits: Forwarded to ``StratifiedKFold`` as ``n_splits``.
        indices: The indices handed to ``split`` as its first argument.
        stratify: Handed to ``split`` as its second argument.

    Returns:
        The materialised list of pairs yielded by ``split``.
    """
    splitter = StratifiedKFold(n_splits=num_splits)
    return list(splitter.split(indices, stratify))
146
71
147
72
@staticmethod
148
- def k_fold_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
73
+ def k_fold_cross_validation (num_splits : int , indices : np .ndarray , stratify : Optional [ Any ] = None ) \
149
74
-> List [Tuple [np .ndarray , np .ndarray ]]:
150
75
"""
151
76
Standard k fold cross validation.
@@ -159,7 +84,7 @@ def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any)
159
84
return splits
160
85
161
86
@staticmethod
162
- def time_series_cross_validation (num_splits : int , indices : np .ndarray , ** kwargs : Any ) \
87
+ def time_series_cross_validation (num_splits : int , indices : np .ndarray , stratify : Optional [ Any ] = None ) \
163
88
-> List [Tuple [np .ndarray , np .ndarray ]]:
164
89
"""
165
90
Returns train and validation indices respecting the temporal ordering of the data.
@@ -176,10 +101,96 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs:
176
101
splits = list (cv .split (indices ))
177
102
return splits
178
103
179
- @classmethod
180
- def get_cross_validators (cls , * cross_val_types : CrossValTypes ) -> Dict [str , CrossValFunc ]:
181
- cross_validators = {
182
- cross_val_type .name : getattr (cls , cross_val_type .name )
183
- for cross_val_type in cross_val_types
184
- }
185
- return cross_validators
104
+
105
class CrossValTypes(Enum):
    """The type of cross validation

    This class is used to specify the cross validation function
    and is not supposed to be instantiated.

    Examples: This class is supposed to be used as follows
    >>> cv_type = CrossValTypes.k_fold_cross_validation
    >>> print(cv_type.name)
    k_fold_cross_validation

    >>> print(cv_type.value)
    functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)

    >>> for cross_val_type in CrossValTypes:
    ...     print(cross_val_type.name)
    stratified_k_fold_cross_validation
    k_fold_cross_validation
    stratified_shuffle_split_cross_validation
    shuffle_split_cross_validation
    time_series_cross_validation

    Additionally, CrossValTypes.<function> can be called directly.
    """
    # The functions are wrapped in ``partial`` because a plain function
    # assigned in an Enum body is a descriptor and would become a method of
    # the enum instead of a member.
    stratified_k_fold_cross_validation = partial(CrossValFuncs.stratified_k_fold_cross_validation)
    k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation)
    stratified_shuffle_split_cross_validation = partial(CrossValFuncs.stratified_shuffle_split_cross_validation)
    shuffle_split_cross_validation = partial(CrossValFuncs.shuffle_split_cross_validation)
    time_series_cross_validation = partial(CrossValFuncs.time_series_cross_validation)

    def is_stratified(self) -> bool:
        """Return True when this member is one of the stratified strategies."""
        cls = type(self)
        return self in (
            cls.stratified_k_fold_cross_validation,
            cls.stratified_shuffle_split_cross_validation,
        )

    def __call__(
        self,
        num_splits: int,
        indices: np.ndarray,
        stratify: Optional[Any] = None,
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Run the split function wrapped by this member.

        Args:
            num_splits: Forwarded to the wrapped function.
            indices: Forwarded to the wrapped function.
            stratify: Forwarded to the wrapped function.

        Returns:
            Whatever the wrapped function returns (a list of
            ``(ndarray, ndarray)`` tuples according to its annotations).
        """
        # TODO: test files
        # The result must be returned; otherwise calling a member yields None.
        return self.value(num_splits=num_splits, indices=indices, stratify=stratify)

    @staticmethod
    def get_validators(*choices: 'CrossValTypes') -> Dict[str, CrossValFunc]:
        """Map each given member's name to its wrapped split function.

        TODO: to be compatible, it is here now, but will be deprecated soon.
        """
        return {choice.name: choice.value for choice in choices}
152
+
153
+
154
class HoldoutValTypes(Enum):
    """The type of hold out validation.

    Each member wraps one function of ``HoldoutValFuncs`` and can be called
    directly with ``val_share``, ``indices`` and ``stratify``.
    """
    # The functions are wrapped in ``partial`` because a plain function
    # assigned in an Enum body is a descriptor and would become a method of
    # the enum instead of a member.
    holdout_validation = partial(HoldoutValFuncs.holdout_validation)
    stratified_holdout_validation = partial(HoldoutValFuncs.stratified_holdout_validation)

    def is_stratified(self) -> bool:
        """Return True when this member is the stratified hold-out strategy."""
        return self is type(self).stratified_holdout_validation

    def __call__(
        self,
        val_share: float,
        indices: np.ndarray,
        stratify: Optional[Any] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the split function wrapped by this member.

        Args:
            val_share: Forwarded to the wrapped function.
            indices: Forwarded to the wrapped function.
            stratify: Forwarded to the wrapped function.

        Returns:
            Whatever the wrapped function returns (an ``(ndarray, ndarray)``
            tuple according to its annotations).
        """
        # The result must be returned; otherwise calling a member yields None.
        return self.value(val_share=val_share, indices=indices, stratify=stratify)

    @staticmethod
    def get_validators(*choices: 'HoldoutValTypes') -> Dict[str, HoldoutValFunc]:
        """Map each given member's name to its wrapped split function.

        TODO: to be compatible, it is here now, but will be deprecated soon.
        """
        return {choice.name: choice.value for choice in choices}
171
+
172
+
173
+ """TODO: deprecate soon (Will rename CrossValTypes -> CrossValFunc)"""
174
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]

# TODO: deprecate soon
# Default keyword arguments per resampling strategy. Every member of both
# enums has an entry so that a lookup by strategy never raises KeyError.
DEFAULT_RESAMPLING_PARAMETERS = {
    HoldoutValTypes.holdout_validation: {
        'val_share': 0.33,
    },
    HoldoutValTypes.stratified_holdout_validation: {
        'val_share': 0.33,
    },
    CrossValTypes.k_fold_cross_validation: {
        'num_splits': 3,
    },
    CrossValTypes.stratified_k_fold_cross_validation: {
        'num_splits': 3,
    },
    CrossValTypes.shuffle_split_cross_validation: {
        'num_splits': 3,
    },
    # Previously missing; uses the same default as the other
    # cross-validation strategies.
    CrossValTypes.stratified_shuffle_split_cross_validation: {
        'num_splits': 3,
    },
    CrossValTypes.time_series_cross_validation: {
        'num_splits': 3,
    },
}  # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
0 commit comments