Skip to content

Commit 98c1b6f

Browse files
committed
Rebase merge
1 parent 117d5db commit 98c1b6f

File tree

4 files changed

+94
-133
lines changed

4 files changed

+94
-133
lines changed

autosklearn/automl.py

Lines changed: 17 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def _model_predict(
199199
class AutoML(BaseEstimator):
200200
"""Base class for handling the AutoML procedure"""
201201

202+
_task_mapping: dict[str, int]
203+
202204
def __init__(
203205
self,
204206
time_left_for_this_task: int,
@@ -243,12 +245,10 @@ def __init__(
243245
)
244246

245247
# Validate dataset_compression and set its values
246-
self._dataset_compression: DatasetCompressionSpec | None
248+
self._dataset_compression: DatasetCompressionSpec | None = None
247249
if isinstance(dataset_compression, bool):
248250
if dataset_compression is True:
249251
self._dataset_compression = default_dataset_compression_arg
250-
else:
251-
self._dataset_compression = None
252252
else:
253253
self._dataset_compression = validate_dataset_compression_arg(
254254
dataset_compression,
@@ -307,18 +307,18 @@ def __init__(
307307

308308
# Create the backend
309309
self._backend: Backend = create(
310-
temporary_directory=temporary_directory,
310+
# TODO update backend as this does accept optional str
311+
temporary_directory=temporary_directory, # type: ignore
311312
output_directory=None,
312313
prefix="auto-sklearn",
313314
delete_output_folder_after_terminate=delete_tmp_folder_after_terminate,
314315
)
315316

316-
self._data_memory_limit = None # TODO: dead variable? Always None
317317
self._datamanager = None
318318
self._dataset_name = None
319319
self._feat_type = None
320320
self._logger: PicklableClientLogger | None = None
321-
self._task = None
321+
self._task: int | None = None
322322
self._label_num = None
323323
self._parser = None
324324
self._can_predict = False
@@ -330,10 +330,11 @@ def __init__(
330330
self.InputValidator: InputValidator | None = None
331331
self.configuration_space = None
332332

333-
# The ensemble performance history through time
334333
self._stopwatch = StopWatch()
335334
self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
336-
self.ensemble_performance_history = []
335+
336+
# The ensemble performance history through time
337+
self.ensemble_performance_history: list[dict[str, Any]] = []
337338

338339
# Num_run tells us how many runs have been launched. It can be seen as an
339340
# identifier for each configuration saved to disk
@@ -480,14 +481,6 @@ def _do_dummy_prediction(self) -> None:
480481

481482
return
482483

483-
@classmethod
484-
def _task_type_id(cls, task_type: str) -> int:
485-
raise NotImplementedError
486-
487-
@classmethod
488-
def _supports_task_type(cls, task_type: str) -> bool:
489-
raise NotImplementedError
490-
491484
def fit(
492485
self,
493486
X: SUPPORTED_FEAT_TYPES,
@@ -594,16 +587,12 @@ def fit(
594587
y = convert_if_sparse(y)
595588
y_test = convert_if_sparse(y_test) if y_test is not None else None
596589

597-
# Get the task if it doesn't exist
598590
if task is None:
599-
y_task = type_of_target(y)
600-
if not self._supports_task_type(y_task):
601-
raise ValueError(
602-
f"{self.__class__.__name__} does not support" f" task {y_task}"
603-
)
604-
self._task = self._task_type_id(y_task)
605-
else:
606-
self._task = task
591+
task = self._task_mapping.get(type_of_target(y), None)
592+
if task is None:
593+
raise ValueError(f"{self.__class__.__name__} does not support {task}")
594+
595+
self._task = task
607596

608597
# Assign a metric if it doesn't exist
609598
if self._metrics is None:
@@ -613,9 +602,6 @@ def fit(
613602
if dataset_name is None:
614603
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
615604

616-
# By default try to use the TCP logging port or get a new port
617-
self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
618-
619605
# Once we start the logging server, it starts in a new process
620606
# If an error occurs then we want to make sure that we exit cleanly
621607
# and shut it down, else it might hang
@@ -1272,11 +1258,9 @@ def fit_pipeline(
12721258
# Get the task if it doesn't exist
12731259
if task is None:
12741260
y_task = type_of_target(y)
1275-
if not self._supports_task_type(y_task):
1276-
raise ValueError(
1277-
f"{self.__class__.__name__} does not support" f" task {y_task}"
1278-
)
1279-
self._task = self._task_type_id(y_task)
1261+
self._task = self._task_mapping.get(y_task, None)
1262+
if self._task is None:
1263+
raise ValueError(f"{self.__class__.__name__} does not support {y_task}")
12801264
else:
12811265
self._task = task
12821266

@@ -2271,14 +2255,6 @@ class AutoMLClassifier(AutoML):
22712255
"binary": BINARY_CLASSIFICATION,
22722256
}
22732257

2274-
@classmethod
2275-
def _task_type_id(cls, task_type: str) -> int:
2276-
return cls._task_mapping[task_type]
2277-
2278-
@classmethod
2279-
def _supports_task_type(cls, task_type: str) -> bool:
2280-
return task_type in cls._task_mapping.keys()
2281-
22822258
def fit(
22832259
self,
22842260
X: SUPPORTED_FEAT_TYPES,
@@ -2361,14 +2337,6 @@ class AutoMLRegressor(AutoML):
23612337
"multiclass": REGRESSION,
23622338
}
23632339

2364-
@classmethod
2365-
def _task_type_id(cls, task_type: str) -> int:
2366-
return cls._task_mapping[task_type]
2367-
2368-
@classmethod
2369-
def _supports_task_type(cls, task_type: str) -> bool:
2370-
return task_type in cls._task_mapping.keys()
2371-
23722340
def fit(
23732341
self,
23742342
X: SUPPORTED_FEAT_TYPES,

autosklearn/estimators.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from abc import ABC, abstractmethod
3+
from abc import ABC
44
from typing import Any, Generic, Iterable, Sequence, TypeVar
55

66
import warnings
@@ -21,7 +21,7 @@
2121
)
2222
from sklearn.utils.multiclass import type_of_target
2323
from smac.runhistory.runhistory import RunInfo, RunValue
24-
from typing_extensions import Literal, TypeAlias
24+
from typing_extensions import Literal
2525

2626
from autosklearn.automl import AutoML, AutoMLClassifier, AutoMLRegressor
2727
from autosklearn.data.validation import convert_if_sparse
@@ -31,34 +31,23 @@
3131
from autosklearn.pipeline.base import BasePipeline
3232
from autosklearn.util.smac_wrap import SMACCallback
3333

34-
# Used to indicate what type the underlying AutoML instance is
35-
TAutoML = TypeVar("TAutoML", bound=AutoML)
36-
TParetoModel = TypeVar("TParetoModel", VotingClassifier, VotingRegressor)
37-
3834
# Used to return self and give correct type information from subclasses,
3935
# see `fit(self: Self) -> Self`
4036
Self = TypeVar("Self", bound="AutoSklearnEstimator")
4137

42-
ResampleOptions: TypeAlias = Literal[
43-
"holdout",
44-
"cv",
45-
"holdout-iterative-fit",
46-
"cv-iterative-fit",
47-
"partial-cv",
48-
]
49-
DisableEvaluatorOptions: TypeAlias = Literal["y_optimization", "model"]
50-
38+
# Used to indicate what type the underlying AutoML instance is
39+
TParetoModel = TypeVar("TParetoModel", VotingClassifier, VotingRegressor)
40+
TAutoML = TypeVar("TAutoML", bound=AutoML)
5141

52-
class AutoSklearnEstimator(ABC, Generic[TAutoML, TParetoModel], BaseEstimator):
5342

54-
# List of target types supported by the estimator class
55-
supported_target_types: list[str]
43+
class AutoSklearnEstimator(ABC, BaseEstimator, Generic[TAutoML, TParetoModel]):
5644

57-
# The automl class used by the estimator class
58-
_automl_class: type[TAutoML]
45+
supported_target_types: list[str] # Supported target types for the estimator
46+
_automl_class: type[TAutoML] # The automl class used by the estimator class
5947

6048
def __init__(
6149
self,
50+
*,
6251
time_left_for_this_task: int = 3600,
6352
per_run_time_limit: int | None = None, # TODO: allow percentage
6453
initial_configurations_via_metalearning: int = 25, # TODO validate
@@ -71,7 +60,9 @@ def __init__(
7160
memory_limit: int | None = 3072,
7261
include: dict[str, list[str]] | None = None,
7362
exclude: dict[str, list[str]] | None = None,
74-
resampling_strategy: ResampleOptions
63+
resampling_strategy: Literal[
64+
"holdout", "cv", "holdout-iterative-fit", "cv-iterative-fit", "partial-cv"
65+
]
7566
| BaseCrossValidator
7667
| _RepeatedSplits
7768
| BaseShuffleSplit = "holdout",
@@ -81,7 +72,7 @@ def __init__(
8172
n_jobs: int = 1,
8273
dask_client: dask.distributed.Client | None = None,
8374
disable_evaluator_output: bool
84-
| Sequence[DisableEvaluatorOptions] = False, # TODO fill in
75+
| Sequence[Literal["y_optimization", "model"]] = False, # TODO: fill in
8576
get_smac_object_callback: SMACCallback | None = None,
8677
smac_scenario_args: dict[str, Any] | None = None,
8778
logging_config: dict[str, Any] | None = None,
@@ -490,7 +481,7 @@ def __init__(
490481
self.allow_string_features = allow_string_features
491482

492483
# Cached
493-
self.automl_: AutoML | None = None
484+
self.automl_: TAutoML | None = None
494485

495486
# Handle the number of jobs and the time for them
496487
# Made private by `_n_jobs` to keep with sklearn compliance
@@ -504,21 +495,19 @@ def __init__(
504495
self.per_run_time_limit = self._n_jobs * self.time_left_for_this_task // 10
505496

506497
@property
507-
@abstractmethod
508498
def automl(self) -> TAutoML:
509499
"""Get the underlying Automl instance
510500
511501
Returns
512502
-------
513503
AutoML
514-
The underlying AutoML instanec
504+
The underlying AutoML instance
515505
"""
516506
if self.automl_ is not None:
517507
return self.automl_
518508

519509
initial_configs = self.initial_configurations_via_metalearning
520-
cls = self._get_automl_class()
521-
automl = cls(
510+
automl = self._automl_class(
522511
temporary_directory=self.tmp_folder,
523512
delete_tmp_folder_after_terminate=self.delete_tmp_folder_after_terminate,
524513
time_left_for_this_task=self.time_left_for_this_task,
@@ -568,16 +557,14 @@ def ensemble(self) -> AbstractEnsemble:
568557
NotFittedError
569558
If this estimator has not been fitted
570559
"""
571-
572-
def __getstate__(self) -> dict[str, Any]:
573-
# Cannot serialize a client!
574-
self.dask_client = None
575-
return self.__dict__
560+
# TODO
561+
raise NotImplementedError()
576562

577563
def fit(
578564
self: Self,
579565
X: np.ndarray | pd.DataFrame | list | spmatrix,
580566
y: np.ndarray | pd.DataFrame | pd.Series | list,
567+
*,
581568
X_test: np.ndarray | pd.DataFrame | list | spmatrix | None = None,
582569
y_test: np.ndarray | pd.DataFrame | pd.Series | list | None = None,
583570
feat_type: list[str] | None = None,
@@ -697,6 +684,7 @@ def fit_pipeline(
697684
self,
698685
X: np.ndarray | pd.DataFrame | list | spmatrix,
699686
y: np.ndarray | pd.DataFrame | pd.Series | list,
687+
*,
700688
config: Configuration | dict[str, Any],
701689
dataset_name: str | None = None,
702690
X_test: np.ndarray | pd.DataFrame | list | spmatrix | None = None,
@@ -767,6 +755,7 @@ def fit_pipeline(
767755
def fit_ensemble(
768756
self: Self,
769757
y: np.ndarray | pd.DataFrame | pd.Series | list,
758+
*,
770759
task: int | None = None,
771760
precision: Literal[16, 32, 64] = 32,
772761
dataset_name: str | None = None,
@@ -913,6 +902,7 @@ def refit(
913902
def predict(
914903
self,
915904
X: np.ndarray | pd.DataFrame | list | spmatrix,
905+
*,
916906
batch_size: int | None = None,
917907
n_jobs: int = 1,
918908
) -> np.ndarray:
@@ -1088,6 +1078,7 @@ def sprint_statistics(self) -> str:
10881078

10891079
def leaderboard(
10901080
self,
1081+
*,
10911082
detailed: bool = False,
10921083
ensemble_only: bool = True,
10931084
top_k: int | Literal["all"] = "all",
@@ -1501,6 +1492,7 @@ def get_configuration_space(
15011492
self,
15021493
X: np.ndarray | pd.DataFrame | list | spmatrix,
15031494
y: np.ndarray | pd.DataFrame | pd.Series | list,
1495+
*,
15041496
X_test: np.ndarray | pd.DataFrame | list | spmatrix | None = None,
15051497
y_test: np.ndarray | pd.DataFrame | pd.Series | list | None = None,
15061498
dataset_name: str | None = None,
@@ -1549,6 +1541,11 @@ def get_pareto_set(self) -> Sequence[TParetoModel]:
15491541
"""
15501542
return self.automl._load_pareto_set()
15511543

1544+
def __getstate__(self) -> dict[str, Any]:
1545+
# Cannot serialize a client!
1546+
self.dask_client = None
1547+
return self.__dict__
1548+
15521549
def __sklearn_is_fitted__(self) -> bool:
15531550
return self.automl_ is not None and self.automl.fitted
15541551

0 commit comments

Comments
 (0)