@@ -234,27 +234,22 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
234
234
self ._check_resampling_strategy_args ()
235
235
236
236
labels_to_stratify = self .train_tensors [- 1 ] if self .is_stratify else None
237
+ kwargs = {}
238
+ kwargs .update (
239
+ random_state = self .random_state ,
240
+ shuffle = self .shuffle_split ,
241
+ indices = self ._get_indices (),
242
+ labels_to_stratify = labels_to_stratify
243
+ )
237
244
238
245
if isinstance (self .resampling_strategy , HoldoutValTypes ):
239
246
val_share = self .resampling_strategy_args .get ('val_share' , None )
247
+ return self .resampling_strategy (val_share = val_share , ** kwargs )
240
248
241
- return self .resampling_strategy (
242
- random_state = self .random_state ,
243
- val_share = val_share ,
244
- shuffle = self .shuffle_split ,
245
- indices = self ._get_indices (),
246
- labels_to_stratify = labels_to_stratify
247
- )
248
249
elif isinstance (self .resampling_strategy , CrossValTypes ):
249
250
num_splits = self .resampling_strategy_args .get ('num_splits' , None )
251
+ return self .resampling_strategy (num_splits = num_splits , ** kwargs )
250
252
251
- return self .resampling_strategy (
252
- random_state = self .random_state ,
253
- num_splits = num_splits ,
254
- shuffle = self .shuffle_split ,
255
- indices = self ._get_indices (),
256
- labels_to_stratify = labels_to_stratify
257
- )
258
253
else :
259
254
raise ValueError (f"Unsupported resampling strategy={ self .resampling_strategy } " )
260
255
0 commit comments