diff --git a/src/skmatter/_selection.py b/src/skmatter/_selection.py index 8c6406203..28219a1ef 100644 --- a/src/skmatter/_selection.py +++ b/src/skmatter/_selection.py @@ -228,8 +228,6 @@ def fit(self, X, y=None, warm_start=False): n_to_select_from = X.shape[self._axis] self.n_samples_in_, self.n_features_in_ = X.shape - self.n_samples_in_, self.n_features_in_ = X.shape - error_msg = ( "n_to_select must be either None, an " f"integer in [1, n_{self.selection_type}s] " @@ -428,7 +426,11 @@ def _continue_greedy_search(self, X, y, n_to_select): def _get_best_new_selection(self, scorer, X, y): scores = scorer(X, y) - max_score_idx = np.argmax(scores) + # Get the score argmax, but only for idxs not already selected + _tmp_scores = { + i: score for i, score in enumerate(scores) if i not in self.selected_idx_ + } + max_score_idx = max(_tmp_scores, key=_tmp_scores.get) if self.score_threshold is not None: if self.first_score_ is None: self.first_score_ = scores[max_score_idx] diff --git a/tests/test_feature_simple_cur.py b/tests/test_feature_simple_cur.py index 147a16fed..a6360aaeb 100644 --- a/tests/test_feature_simple_cur.py +++ b/tests/test_feature_simple_cur.py @@ -41,6 +41,22 @@ def test_non_it(self): self.assertTrue(np.allclose(selector.selected_idx_, ref_idx)) + def test_unique_selected_idx_zero_score(self): + """ + Tests that the selected idxs are unique, which may not be the + case when the score is numerically zero + """ + np.random.seed(0) + n_samples = 10 + n_features = 15 + X = np.random.rand(n_samples, n_features) + X[:, 3] = np.random.rand(10) * 1e-13 + X[:, 4] = np.random.rand(10) * 1e-13 + selector_problem = CUR(n_to_select=len(X.T)).fit(X) + assert len(selector_problem.selected_idx_) == len( + set(selector_problem.selected_idx_) + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_feature_simple_fps.py b/tests/test_feature_simple_fps.py index 8135bd819..dcd618945 100644 --- a/tests/test_feature_simple_fps.py +++ b/tests/test_feature_simple_fps.py @@ -84,6 +84,22 @@ def test_get_distances(self): selector = FPS(n_to_select=7) _ = selector.get_select_distance() + def test_unique_selected_idx_zero_score(self): + """ + Tests that the selected idxs are unique, which may not be the + case when the score is numerically zero + """ + np.random.seed(0) + n_samples = 10 + n_features = 15 + X = np.random.rand(n_samples, n_features) + X[:, 3] = np.random.rand(10) * 1e-13 + X[:, 4] = np.random.rand(10) * 1e-13 + selector_problem = FPS(n_to_select=len(X.T)).fit(X) + assert len(selector_problem.selected_idx_) == len( + set(selector_problem.selected_idx_) + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_sample_simple_cur.py b/tests/test_sample_simple_cur.py index 0969074a3..72fab4c3a 100644 --- a/tests/test_sample_simple_cur.py +++ b/tests/test_sample_simple_cur.py @@ -52,6 +52,23 @@ def test_non_it(self): self.assertTrue(np.allclose(selector.selected_idx_, ref_idx)) + def test_unique_selected_idx_zero_score(self): + """ + Tests that the selected idxs are unique, which may not be the + case when the score is numerically zero. + """ + np.random.seed(0) + n_samples = 10 + n_features = 15 + X = np.random.rand(n_samples, n_features) + X[4, :] = np.random.rand(15) * 1e-13 + X[5, :] = np.random.rand(15) * 1e-13 + X[6, :] = np.random.rand(15) * 1e-13 + selector_problem = CUR(n_to_select=len(X)).fit(X) + assert len(selector_problem.selected_idx_) == len( + set(selector_problem.selected_idx_) + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_sample_simple_fps.py b/tests/test_sample_simple_fps.py index 5f6f216bc..afac270e3 100644 --- a/tests/test_sample_simple_fps.py +++ b/tests/test_sample_simple_fps.py @@ -99,6 +99,23 @@ def test_threshold(self): self.assertEqual(len(selector.selected_idx_), 5) self.assertEqual(selector.selected_idx_.tolist(), self.idx[:5]) + def test_unique_selected_idx_zero_score(self): + """ + Tests that the selected idxs are unique, which may not be the + case when the score is numerically zero. + """ + np.random.seed(0) + n_samples = 10 + n_features = 15 + X = np.random.rand(n_samples, n_features) + X[4, :] = np.random.rand(15) * 1e-13 + X[5, :] = np.random.rand(15) * 1e-13 + X[6, :] = np.random.rand(15) * 1e-13 + selector_problem = FPS(n_to_select=len(X)).fit(X) + assert len(selector_problem.selected_idx_) == len( + set(selector_problem.selected_idx_) + ) + if __name__ == "__main__": unittest.main(verbosity=2)