1
1
from __future__ import annotations
2
2
3
- from abc import ABC , abstractmethod
3
+ from abc import ABC
4
4
from typing import Any , Generic , Iterable , Sequence , TypeVar
5
5
6
6
import warnings
21
21
)
22
22
from sklearn .utils .multiclass import type_of_target
23
23
from smac .runhistory .runhistory import RunInfo , RunValue
24
- from typing_extensions import Literal , TypeAlias
24
+ from typing_extensions import Literal
25
25
26
26
from autosklearn .automl import AutoML , AutoMLClassifier , AutoMLRegressor
27
27
from autosklearn .data .validation import convert_if_sparse
31
31
from autosklearn .pipeline .base import BasePipeline
32
32
from autosklearn .util .smac_wrap import SMACCallback
33
33
34
- # Used to indicate what type the underlying AutoML instance is
35
- TAutoML = TypeVar ("TAutoML" , bound = AutoML )
36
- TParetoModel = TypeVar ("TParetoModel" , VotingClassifier , VotingRegressor )
37
-
38
34
# Used to return self and give correct type information from subclasses,
39
35
# see `fit(self: Self) -> Self`
40
36
Self = TypeVar ("Self" , bound = "AutoSklearnEstimator" )
41
37
42
- ResampleOptions : TypeAlias = Literal [
43
- "holdout" ,
44
- "cv" ,
45
- "holdout-iterative-fit" ,
46
- "cv-iterative-fit" ,
47
- "partial-cv" ,
48
- ]
49
- DisableEvaluatorOptions : TypeAlias = Literal ["y_optimization" , "model" ]
50
-
38
+ # Used to indicate what type the underlying AutoML instance is
39
+ TParetoModel = TypeVar ("TParetoModel" , VotingClassifier , VotingRegressor )
40
+ TAutoML = TypeVar ("TAutoML" , bound = AutoML )
51
41
52
- class AutoSklearnEstimator (ABC , Generic [TAutoML , TParetoModel ], BaseEstimator ):
53
42
54
- # List of target types supported by the estimator class
55
- supported_target_types : list [str ]
43
+ class AutoSklearnEstimator (ABC , BaseEstimator , Generic [TAutoML , TParetoModel ]):
56
44
57
- # The automl class used by the estimator class
58
- _automl_class : type [TAutoML ]
45
+ supported_target_types : list [ str ] # Supported target types for the estimator
46
+ _automl_class : type [TAutoML ] # The automl class used by the estimator class
59
47
60
48
def __init__ (
61
49
self ,
50
+ * ,
62
51
time_left_for_this_task : int = 3600 ,
63
52
per_run_time_limit : int | None = None , # TODO: allow percentage
64
53
initial_configurations_via_metalearning : int = 25 , # TODO validate
@@ -71,7 +60,9 @@ def __init__(
71
60
memory_limit : int | None = 3072 ,
72
61
include : dict [str , list [str ]] | None = None ,
73
62
exclude : dict [str , list [str ]] | None = None ,
74
- resampling_strategy : ResampleOptions
63
+ resampling_strategy : Literal [
64
+ "holdout" , "cv" , "holdout-iterative-fit" , "cv-iterative-fit" , "partial-cv"
65
+ ]
75
66
| BaseCrossValidator
76
67
| _RepeatedSplits
77
68
| BaseShuffleSplit = "holdout" ,
@@ -81,7 +72,7 @@ def __init__(
81
72
n_jobs : int = 1 ,
82
73
dask_client : dask .distributed .Client | None = None ,
83
74
disable_evaluator_output : bool
84
- | Sequence [DisableEvaluatorOptions ] = False , # TODO fill in
75
+ | Sequence [Literal [ "y_optimization" , "model" ]] = False , # TODO: fill in
85
76
get_smac_object_callback : SMACCallback | None = None ,
86
77
smac_scenario_args : dict [str , Any ] | None = None ,
87
78
logging_config : dict [str , Any ] | None = None ,
@@ -490,7 +481,7 @@ def __init__(
490
481
self .allow_string_features = allow_string_features
491
482
492
483
# Cached
493
- self .automl_ : AutoML | None = None
484
+ self .automl_ : TAutoML | None = None
494
485
495
486
# Handle the number of jobs and the time for them
496
487
# Made private by `_n_jobs` to keep with sklearn compliance
@@ -504,21 +495,19 @@ def __init__(
504
495
self .per_run_time_limit = self ._n_jobs * self .time_left_for_this_task // 10
505
496
506
497
@property
507
- @abstractmethod
508
498
def automl (self ) -> TAutoML :
509
499
"""Get the underlying Automl instance
510
500
511
501
Returns
512
502
-------
513
503
AutoML
514
- The underlying AutoML instanec
504
+ The underlying AutoML instance
515
505
"""
516
506
if self .automl_ is not None :
517
507
return self .automl_
518
508
519
509
initial_configs = self .initial_configurations_via_metalearning
520
- cls = self ._get_automl_class ()
521
- automl = cls (
510
+ automl = self ._automl_class (
522
511
temporary_directory = self .tmp_folder ,
523
512
delete_tmp_folder_after_terminate = self .delete_tmp_folder_after_terminate ,
524
513
time_left_for_this_task = self .time_left_for_this_task ,
@@ -568,16 +557,14 @@ def ensemble(self) -> AbstractEnsemble:
568
557
NotFittedError
569
558
If this estimator has not been fitted
570
559
"""
571
-
572
- def __getstate__ (self ) -> dict [str , Any ]:
573
- # Cannot serialize a client!
574
- self .dask_client = None
575
- return self .__dict__
560
+ # TODO
561
+ raise NotImplementedError ()
576
562
577
563
def fit (
578
564
self : Self ,
579
565
X : np .ndarray | pd .DataFrame | list | spmatrix ,
580
566
y : np .ndarray | pd .DataFrame | pd .Series | list ,
567
+ * ,
581
568
X_test : np .ndarray | pd .DataFrame | list | spmatrix | None = None ,
582
569
y_test : np .ndarray | pd .DataFrame | pd .Series | list | None = None ,
583
570
feat_type : list [str ] | None = None ,
@@ -697,6 +684,7 @@ def fit_pipeline(
697
684
self ,
698
685
X : np .ndarray | pd .DataFrame | list | spmatrix ,
699
686
y : np .ndarray | pd .DataFrame | pd .Series | list ,
687
+ * ,
700
688
config : Configuration | dict [str , Any ],
701
689
dataset_name : str | None = None ,
702
690
X_test : np .ndarray | pd .DataFrame | list | spmatrix | None = None ,
@@ -767,6 +755,7 @@ def fit_pipeline(
767
755
def fit_ensemble (
768
756
self : Self ,
769
757
y : np .ndarray | pd .DataFrame | pd .Series | list ,
758
+ * ,
770
759
task : int | None = None ,
771
760
precision : Literal [16 , 32 , 64 ] = 32 ,
772
761
dataset_name : str | None = None ,
@@ -913,6 +902,7 @@ def refit(
913
902
def predict (
914
903
self ,
915
904
X : np .ndarray | pd .DataFrame | list | spmatrix ,
905
+ * ,
916
906
batch_size : int | None = None ,
917
907
n_jobs : int = 1 ,
918
908
) -> np .ndarray :
@@ -1088,6 +1078,7 @@ def sprint_statistics(self) -> str:
1088
1078
1089
1079
def leaderboard (
1090
1080
self ,
1081
+ * ,
1091
1082
detailed : bool = False ,
1092
1083
ensemble_only : bool = True ,
1093
1084
top_k : int | Literal ["all" ] = "all" ,
@@ -1501,6 +1492,7 @@ def get_configuration_space(
1501
1492
self ,
1502
1493
X : np .ndarray | pd .DataFrame | list | spmatrix ,
1503
1494
y : np .ndarray | pd .DataFrame | pd .Series | list ,
1495
+ * ,
1504
1496
X_test : np .ndarray | pd .DataFrame | list | spmatrix | None = None ,
1505
1497
y_test : np .ndarray | pd .DataFrame | pd .Series | list | None = None ,
1506
1498
dataset_name : str | None = None ,
@@ -1549,6 +1541,11 @@ def get_pareto_set(self) -> Sequence[TParetoModel]:
1549
1541
"""
1550
1542
return self .automl ._load_pareto_set ()
1551
1543
1544
+ def __getstate__ (self ) -> dict [str , Any ]:
1545
+ # Cannot serialize a client!
1546
+ self .dask_client = None
1547
+ return self .__dict__
1548
+
1552
1549
def __sklearn_is_fitted__ (self ) -> bool :
1553
1550
return self .automl_ is not None and self .automl .fitted
1554
1551
0 commit comments