Skip to content

Commit 15b3df1

Browse files
committed
MNT test fastcan is sklearn estimator
1 parent 365f326 commit 15b3df1

File tree

4 files changed

+560
-560
lines changed

4 files changed

+560
-560
lines changed

fastcan/_fastcan.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ class FastCan(SelectorMixin, BaseEstimator):
2626
2727
Parameters
2828
----------
29-
n_features_to_select : int
29+
n_features_to_select : int, default=1
3030
The parameter is the absolute number of features to select.
3131
3232
inclusive_indices : array-like of shape (n_inclusions,), default=None
3333
The indices of the prerequisite features.
3434
35-
eta : bool, default=None
35+
eta : bool, default=False
3636
Whether to use eta-cosine method.
3737
3838
tol : float, default=0.01
@@ -94,16 +94,16 @@ class FastCan(SelectorMixin, BaseEstimator):
9494
Interval(Integral, 1, None, closed="left"),
9595
],
9696
"inclusive_indices": [None, "array-like"],
97-
"eta": ["boolean", None],
97+
"eta": ["boolean"],
9898
"tol": [Interval(Real, 0, None, closed="neither")],
9999
"verbose": ["verbose"],
100100
}
101101

102102
def __init__(
103103
self,
104-
n_features_to_select,
104+
n_features_to_select=1,
105105
inclusive_indices=None,
106-
eta=None,
106+
eta=False,
107107
tol=0.01,
108108
verbose=1,
109109
):
@@ -132,8 +132,8 @@ def fit(self, X, y):
132132
"""
133133
self._validate_params()
134134
# X y
135-
check_X_params = {"order": "F"}
136-
check_y_params = {"ensure_2d": False, "order": "F"}
135+
check_X_params = {"order": "F", "dtype": float}
136+
check_y_params = {"ensure_2d": False, "order": "F", "dtype": float}
137137
X, y = self._validate_data(
138138
X=X,
139139
y=y,
@@ -147,9 +147,9 @@ def fit(self, X, y):
147147

148148
# inclusive_indices
149149
if self.inclusive_indices is None:
150-
self.inclusive_indices = np.zeros(0, dtype=int)
150+
inclusive_indices = np.zeros(0, dtype=int)
151151
else:
152-
self.inclusive_indices = check_array(
152+
inclusive_indices = check_array(
153153
self.inclusive_indices,
154154
ensure_2d=False,
155155
dtype=int,
@@ -165,18 +165,12 @@ def fit(self, X, y):
165165
f"must be <= n_features {n_features}."
166166
)
167167

168-
if self.inclusive_indices.shape[0] >= n_features:
168+
if inclusive_indices.shape[0] >= n_features:
169169
raise ValueError(
170-
f"n_inclusions {self.inclusive_indices.shape[0]} must "
170+
f"n_inclusions {inclusive_indices.shape[0]} must "
171171
f"be < n_features {n_features}."
172172
)
173173

174-
# Method determination
175-
if self.eta is None:
176-
if n_samples > n_features + n_outputs:
177-
self.eta = True
178-
else:
179-
self.eta = False
180174
if n_samples < n_features + n_outputs and self.eta:
181175
raise ValueError(
182176
"`eta` cannot be True, when n_samples < n_features+n_outputs."
@@ -197,7 +191,7 @@ def fit(self, X, y):
197191
y_transformed = y - y.mean(0)
198192

199193
mask, indices, scores = self._prepare_data(
200-
self.inclusive_indices,
194+
inclusive_indices,
201195
)
202196
n_threads = _openmp_effective_n_threads()
203197
_forward_search(

0 commit comments

Comments
 (0)