@@ -2730,7 +2730,7 @@ MatrixXd APLRRegressor::get_unique_term_affiliation_shape(const std::string &uni
2730
2730
{
2731
2731
split_points_in_each_predictor[i] = compute_split_points (base_predictors_in_each_unique_term_affiliation[unique_term_affiliation_index][i], relevant_term_indexes);
2732
2732
2733
- if (num_predictors_used_in_the_affiliation > 1 && additional_points > 0 )
2733
+ if (num_predictors_used_in_the_affiliation > 1 && additional_points > 0 && !split_points_in_each_predictor[i]. empty () )
2734
2734
{
2735
2735
double min_val = *std::min_element (split_points_in_each_predictor[i].begin (), split_points_in_each_predictor[i].end ());
2736
2736
double max_val = *std::max_element (split_points_in_each_predictor[i].begin (), split_points_in_each_predictor[i].end ());
@@ -2741,9 +2741,9 @@ MatrixXd APLRRegressor::get_unique_term_affiliation_shape(const std::string &uni
2741
2741
double val = min_val + (max_val - min_val) * j / (additional_points + 1 );
2742
2742
interpolated.push_back (val);
2743
2743
}
2744
+ split_points_in_each_predictor[i].reserve (split_points_in_each_predictor[i].size () + additional_points);
2744
2745
split_points_in_each_predictor[i].insert (split_points_in_each_predictor[i].end (), interpolated.begin (), interpolated.end ());
2745
- std::sort (split_points_in_each_predictor[i].begin (), split_points_in_each_predictor[i].end ());
2746
- split_points_in_each_predictor[i].erase (std::unique (split_points_in_each_predictor[i].begin (), split_points_in_each_predictor[i].end ()), split_points_in_each_predictor[i].end ());
2746
+ split_points_in_each_predictor[i] = remove_duplicate_elements_from_vector (split_points_in_each_predictor[i]);
2747
2747
}
2748
2748
}
2749
2749
@@ -2762,6 +2762,7 @@ MatrixXd APLRRegressor::get_unique_term_affiliation_shape(const std::string &uni
2762
2762
{
2763
2763
size_t current_num_observations = split_points.size ();
2764
2764
size_t num_observations_to_keep = std::round (factor * std::sqrt (current_num_observations));
2765
+ num_observations_to_keep = std::max<size_t >(1 , num_observations_to_keep);
2765
2766
if (current_num_observations > num_observations_to_keep)
2766
2767
{
2767
2768
std::shuffle (split_points.begin (), split_points.end (), seed);
0 commit comments