Skip to content

Commit 76b4826

Browse files
committed
[refactor] Gether kwargs for get splits for CV and Holdout
1 parent e37cfbc commit 76b4826

File tree

1 file changed

+9
-14
lines changed

1 file changed

+9
-14
lines changed

autoPyTorch/datasets/base_dataset.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -234,27 +234,22 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
234234
self._check_resampling_strategy_args()
235235

236236
labels_to_stratify = self.train_tensors[-1] if self.is_stratify else None
237+
kwargs = {}
238+
kwargs.update(
239+
random_state=self.random_state,
240+
shuffle=self.shuffle_split,
241+
indices=self._get_indices(),
242+
labels_to_stratify=labels_to_stratify
243+
)
237244

238245
if isinstance(self.resampling_strategy, HoldoutValTypes):
239246
val_share = self.resampling_strategy_args.get('val_share', None)
247+
return self.resampling_strategy(val_share=val_share, **kwargs)
240248

241-
return self.resampling_strategy(
242-
random_state=self.random_state,
243-
val_share=val_share,
244-
shuffle=self.shuffle_split,
245-
indices=self._get_indices(),
246-
labels_to_stratify=labels_to_stratify
247-
)
248249
elif isinstance(self.resampling_strategy, CrossValTypes):
249250
num_splits = self.resampling_strategy_args.get('num_splits', None)
251+
return self.resampling_strategy(num_splits=num_splits, **kwargs)
250252

251-
return self.resampling_strategy(
252-
random_state=self.random_state,
253-
num_splits=num_splits,
254-
shuffle=self.shuffle_split,
255-
indices=self._get_indices(),
256-
labels_to_stratify=labels_to_stratify
257-
)
258253
else:
259254
raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
260255

0 commit comments

Comments
 (0)