Skip to content

Commit 1c72ba5

Browse files
ceriottm and jwa7 authored
Fixes bug in greedy selector when scores are degenerate (#265)
* Fix general bug in GreedySelector that would pick the same point if there are degeneracies
* Write failing tests for when there are non-unique selections for zero score
* Only select idxs that have been previously selected
* CUR feature selection test

---------

Co-authored-by: Joseph Abbott <joseph.william.abbott@gmail.com>
1 parent 575cb8f commit 1c72ba5

File tree

7 files changed: +84 -3 lines changed

src/skmatter/_selection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,6 @@ def fit(self, X, y=None, warm_start=False):
228228
n_to_select_from = X.shape[self._axis]
229229
self.n_samples_in_, self.n_features_in_ = X.shape
230230

231-
self.n_samples_in_, self.n_features_in_ = X.shape
232-
233231
error_msg = (
234232
"n_to_select must be either None, an "
235233
f"integer in [1, n_{self.selection_type}s] "
@@ -428,6 +426,8 @@ def _continue_greedy_search(self, X, y, n_to_select):
428426
def _get_best_new_selection(self, scorer, X, y):
429427
scores = scorer(X, y)
430428

429+
# Get the score argmax, but only for idxs not already selected
430+
scores[self.selected_idx_[: self.n_selected_]] = -np.inf
431431
max_score_idx = np.argmax(scores)
432432
if self.score_threshold is not None:
433433
if self.first_score_ is None:

tests/test_feature_pcov_cur.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def test_non_it(self):
3333
self.idx = [2, 8, 3, 6, 7, 9, 1, 0, 5]
3434
selector = PCovCUR(n_to_select=9, recompute_every=0)
3535
selector.fit(self.X, self.y)
36-
3736
self.assertTrue(np.allclose(selector.selected_idx_, self.idx))
3837

3938

tests/test_feature_simple_cur.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@ def test_non_it(self):
4141

4242
self.assertTrue(np.allclose(selector.selected_idx_, ref_idx))
4343

44+
def test_unique_selected_idx_zero_score(self):
45+
"""
46+
Tests that the selected idxs are unique, which may not be the
47+
case when the score is numerically zero
48+
"""
49+
np.random.seed(0)
50+
n_samples = 10
51+
n_features = 15
52+
X = np.random.rand(n_samples, n_features)
53+
X[:, 1] = X[:, 0]
54+
X[:, 2] = X[:, 0]
55+
selector_problem = CUR(n_to_select=len(X.T)).fit(X)
56+
assert len(selector_problem.selected_idx_) == len(
57+
set(selector_problem.selected_idx_)
58+
)
59+
4460

4561
if __name__ == "__main__":
4662
unittest.main(verbosity=2)

tests/test_feature_simple_fps.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,22 @@ def test_get_distances(self):
8484
selector = FPS(n_to_select=7)
8585
_ = selector.get_select_distance()
8686

87+
def test_unique_selected_idx_zero_score(self):
88+
"""
89+
Tests that the selected idxs are unique, which may not be the
90+
case when the score is numerically zero
91+
"""
92+
np.random.seed(0)
93+
n_samples = 10
94+
n_features = 15
95+
X = np.random.rand(n_samples, n_features)
96+
X[:, 1] = X[:, 0]
97+
X[:, 2] = X[:, 0]
98+
selector_problem = FPS(n_to_select=len(X.T)).fit(X)
99+
assert len(selector_problem.selected_idx_) == len(
100+
set(selector_problem.selected_idx_)
101+
)
102+
87103

88104
if __name__ == "__main__":
89105
unittest.main(verbosity=2)

tests/test_sample_simple_cur.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,23 @@ def test_non_it(self):
5252

5353
self.assertTrue(np.allclose(selector.selected_idx_, ref_idx))
5454

55+
def test_unique_selected_idx_zero_score(self):
56+
"""
57+
Tests that the selected idxs are unique, which may not be the
58+
case when the score is numerically zero.
59+
"""
60+
np.random.seed(0)
61+
n_samples = 10
62+
n_features = 15
63+
X = np.random.rand(n_samples, n_features)
64+
X[1] = X[0]
65+
X[2] = X[0]
66+
X[3] = X[0]
67+
selector_problem = CUR(n_to_select=len(X)).fit(X)
68+
assert len(selector_problem.selected_idx_) == len(
69+
set(selector_problem.selected_idx_)
70+
)
71+
5572

5673
if __name__ == "__main__":
5774
unittest.main(verbosity=2)

tests/test_sample_simple_fps.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,23 @@ def test_threshold(self):
9999
self.assertEqual(len(selector.selected_idx_), 5)
100100
self.assertEqual(selector.selected_idx_.tolist(), self.idx[:5])
101101

102+
def test_unique_selected_idx_zero_score(self):
103+
"""
104+
Tests that the selected idxs are unique, which may not be the
105+
case when the score is numerically zero.
106+
"""
107+
np.random.seed(0)
108+
n_samples = 10
109+
n_features = 15
110+
X = np.random.rand(n_samples, n_features)
111+
X[1] = X[0]
112+
X[2] = X[0]
113+
X[3] = X[0]
114+
selector_problem = FPS(n_to_select=len(X)).fit(X)
115+
assert len(selector_problem.selected_idx_) == len(
116+
set(selector_problem.selected_idx_)
117+
)
118+
102119

103120
if __name__ == "__main__":
104121
unittest.main(verbosity=2)

tests/test_voronoi_fps.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,22 @@ def test_score(self):
165165
)
166166
)
167167

168+
def test_unique_selected_idx_zero_score(self):
169+
"""
170+
Tests that the selected idxs are unique, which may not be the
171+
case when the score is numerically zero
172+
"""
173+
np.random.seed(0)
174+
n_samples = 10
175+
n_features = 15
176+
X = np.random.rand(n_samples, n_features)
177+
X[1] = X[0]
178+
X[2] = X[0]
179+
selector_problem = VoronoiFPS(n_to_select=n_samples, initialize=3).fit(X)
180+
assert len(selector_problem.selected_idx_) == len(
181+
set(selector_problem.selected_idx_)
182+
)
183+
168184

169185
if __name__ == "__main__":
170186
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)