Commit 9d4d61c

MTN Fix the parallel plots
1 parent 0f2ac12 commit 9d4d61c

8 files changed: +67 -63 lines changed

notebooks/datasets_adult_census.ipynb

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@
 " dimensions=plot_list,\n",
 " )\n",
 ")\n",
-"fig.show()"
+"fig.show(renderer=\"notebook\")"
 ]
 },
 {
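For readers unfamiliar with the change, here is a minimal, self-contained sketch (not part of the commit; the toy columns and values are invented) of why pinning Plotly's renderer matters: `fig.show()` lets Plotly guess a renderer from the environment, which can silently produce no output in converted notebooks, while `renderer="notebook"` embeds the figure explicitly.

```python
# Minimal sketch with invented data: build a Parcoords figure like the one
# in datasets_adult_census and pin the renderer instead of relying on
# Plotly's automatic detection.
import plotly.graph_objects as go

fig = go.Figure(
    go.Parcoords(
        dimensions=[
            dict(label="age", values=[25, 40, 33]),  # toy values
            dict(label="hours-per-week", values=[40, 50, 20]),
        ]
    )
)
# "notebook" forces the classic Jupyter renderer so the figure is embedded
# in the notebook output rather than depending on environment detection.
fig.show(renderer="notebook")
```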

notebooks/linear_models_feature_engineering_classification.ipynb

Lines changed: 1 addition & 1 deletion
@@ -641,7 +641,7 @@
 "- Transformers such as `KBinsDiscretizer` and `SplineTransformer` can be used\n",
 " to engineer non-linear features independently for each original feature.\n",
 "- As a result, these transformers cannot capture interactions between the\n",
-" orignal features (and then would fail on the XOR classification task).\n",
+" original features (and then would fail on the XOR classification task).\n",
 "- Despite this limitation they already augment the expressivity of the\n",
 " pipeline, which can be sufficient for some datasets.\n",
 "- They also favor axis-aligned decision boundaries, in particular in the low\n",

notebooks/parameter_tuning_grid_search.ipynb

Lines changed: 60 additions & 56 deletions
@@ -198,29 +198,33 @@
 "source": [
 "## Tuning using a grid-search\n",
 "\n",
-"In the previous exercise we used one `for` loop for each hyperparameter to\n",
-"find the best combination over a fixed grid of values. `GridSearchCV` is a\n",
-"scikit-learn class that implements a very similar logic with less repetitive\n",
-"code.\n",
+"In the previous exercise (M3.01) we used two nested `for` loops (one for each\n",
+"hyperparameter) to test different combinations over a fixed grid of\n",
+"hyperparameter values. In each iteration of the loop, we used\n",
+"`cross_val_score` to compute the mean score (as averaged across\n",
+"cross-validation splits), and compared those mean scores to select the best\n",
+"combination. `GridSearchCV` is a scikit-learn class that implements a very\n",
+"similar logic with less repetitive code. The suffix `CV` refers to the\n",
+"cross-validation it runs internally (instead of the `cross_val_score` we\n",
+"\"hard\" coded).\n",
+"\n",
+"The `GridSearchCV` estimator takes a `param_grid` parameter which defines all\n",
+"hyperparameters and their associated values. The grid-search is in charge of\n",
+"creating all possible combinations and testing them.\n",
+"\n",
+"The number of combinations is equal to the product of the number of values to\n",
+"explore for each parameter. Thus, adding new parameters with their associated\n",
+"values to be explored rapidly becomes computationally expensive. Because of\n",
+"that, here we only explore the combination learning-rate and the maximum\n",
+"number of nodes for a total of 4 x 3 = 12 combinations.\n",
 "\n",
-"Let's see how to use the `GridSearchCV` estimator for doing such search. Since\n",
-"the grid-search is costly, we only explore the combination learning-rate and\n",
-"the maximum number of nodes."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
 "%%time\n",
 "from sklearn.model_selection import GridSearchCV\n",
 "\n",
 "param_grid = {\n",
-" \"classifier__learning_rate\": (0.01, 0.1, 1, 10),\n",
-" \"classifier__max_leaf_nodes\": (3, 10, 30),\n",
-"}\n",
+" \"classifier__learning_rate\": (0.01, 0.1, 1, 10), # 4 possible values\n",
+" \"classifier__max_leaf_nodes\": (3, 10, 30), # 3 possible values\n",
+"} # 12 unique combinations\n",
 "model_grid_search = GridSearchCV(model, param_grid=param_grid, n_jobs=2, cv=2)\n",
 "model_grid_search.fit(data_train, target_train)"
 ]
@@ -229,7 +233,8 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Finally, we check the accuracy of our model using the test set."
+"You can access the best combination of hyperparameters found by the grid\n",
+"search using the `best_params_` attribute."
 ]
 },
 {
@@ -238,46 +243,19 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"accuracy = model_grid_search.score(data_test, target_test)\n",
-"print(\n",
-" f\"The test accuracy score of the grid-searched pipeline is: {accuracy:.2f}\"\n",
-")"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"<div class=\"admonition warning alert alert-danger\">\n",
-"<p class=\"first admonition-title\" style=\"font-weight: bold;\">Warning</p>\n",
-"<p>Be aware that the evaluation should normally be performed through\n",
-"cross-validation by providing <tt class=\"docutils literal\">model_grid_search</tt> as a model to the\n",
-"<tt class=\"docutils literal\">cross_validate</tt> function.</p>\n",
-"<p class=\"last\">Here, we used a single train-test split to evaluate <tt class=\"docutils literal\">model_grid_search</tt>. In\n",
-"a future notebook will go into more detail about nested cross-validation, when\n",
-"you use cross-validation both for hyperparameter tuning and model evaluation.</p>\n",
-"</div>"
+"print(f\"The best set of parameters is: {model_grid_search.best_params_}\")"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"The `GridSearchCV` estimator takes a `param_grid` parameter which defines all\n",
-"hyperparameters and their associated values. The grid-search is in charge\n",
-"of creating all possible combinations and test them.\n",
-"\n",
-"The number of combinations are equal to the product of the number of values to\n",
-"explore for each parameter (e.g. in our example 4 x 3 combinations). Thus,\n",
-"adding new parameters with their associated values to be explored become\n",
-"rapidly computationally expensive.\n",
-"\n",
-"Once the grid-search is fitted, it can be used as any other predictor by\n",
-"calling `predict` and `predict_proba`. Internally, it uses the model with the\n",
+"Once the grid-search is fitted, it can be used as any other estimator, i.e. it\n",
+"has `predict` and `score` methods. Internally, it uses the model with the\n",
 "best parameters found during `fit`.\n",
 "\n",
-"Get predictions for the 5 first samples using the estimator with the best\n",
-"parameters."
+"Let's get the predictions for the 5 first samples using the estimator with the\n",
+"best parameters:"
 ]
 },
 {
@@ -293,8 +271,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"You can know about these parameters by looking at the `best_params_`\n",
-"attribute."
+"Finally, we check the accuracy of our model using the test set."
 ]
 },
 {
@@ -303,16 +280,43 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(f\"The best set of parameters is: {model_grid_search.best_params_}\")"
+"accuracy = model_grid_search.score(data_test, target_test)\n",
+"print(\n",
+" f\"The test accuracy score of the grid-search pipeline is: {accuracy:.2f}\"\n",
+")"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"The accuracy and the best parameters of the grid-searched pipeline are similar\n",
+"The accuracy and the best parameters of the grid-search pipeline are similar\n",
 "to the ones we found in the previous exercise, where we searched the best\n",
-"parameters \"by hand\" through a double for loop.\n",
+"parameters \"by hand\" through a double `for` loop.\n",
+"\n",
+"## The need for a validation set\n",
+"\n",
+"In the previous section, the selection of the best hyperparameters was done\n",
+"using the train set, coming from the initial train-test split. Then, we\n",
+"evaluated the generalization performance of our tuned model on the left out\n",
+"test set. This can be shown schematically as follows:\n",
+"\n",
+"![Cross-validation tuning\n",
+"diagram](../figures/cross_validation_train_test_diagram.png)\n",
+"\n",
+"<div class=\"admonition note alert alert-info\">\n",
+"<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n",
+"<p>This figure shows the particular case of <strong>K-fold</strong> cross-validation strategy\n",
+"using <tt class=\"docutils literal\">n_splits=5</tt> to further split the train set coming from a train-test\n",
+"split. For each cross-validation split, the procedure trains a model on all\n",
+"the red samples, evaluates the score of a given set of hyperparameters on the\n",
+"green samples. The best combination of hyperparameters <tt class=\"docutils literal\">best_params</tt> is selected\n",
+"based on those intermediate scores.</p>\n",
+"<p>Then a final model is refitted using <tt class=\"docutils literal\">best_params</tt> on the concatenation of the\n",
+"red and green samples and evaluated on the blue samples.</p>\n",
+"<p class=\"last\">The green samples are sometimes referred as the <strong>validation set</strong> to\n",
+"differentiate them from the final test set in blue.</p>\n",
+"</div>\n",
 "\n",
 "In addition, we can inspect all results which are stored in the attribute\n",
 "`cv_results_` of the grid-search. We filter some specific columns from these\n",

notebooks/parameter_tuning_parallel_plot.ipynb

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@
 " color=\"mean_test_score\",\n",
 " color_continuous_scale=px.colors.sequential.Viridis,\n",
 ")\n",
-"fig.show()"
+"fig.show(renderer=\"notebook\")"
 ]
 },
 {
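For context, the call being patched looks roughly like this sketch (the `cv_results` rows are invented): `px.parallel_coordinates` draws one line per hyperparameter combination, colored by its mean test score.

```python
# Sketch with invented grid-search results, mirroring the patched call.
import pandas as pd
import plotly.express as px

# Stand-in for a few rows extracted from a grid-search's cv_results_.
cv_results = pd.DataFrame({
    "learning_rate": [0.01, 0.1, 1, 10],
    "max_leaf_nodes": [3, 10, 30, 10],
    "mean_test_score": [0.79, 0.87, 0.85, 0.50],
})

fig = px.parallel_coordinates(
    cv_results,
    color="mean_test_score",
    color_continuous_scale=px.colors.sequential.Viridis,
)
fig.show(renderer="notebook")  # pin the renderer, as in this commit
```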

notebooks/parameter_tuning_sol_03.ipynb

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@
 " dimensions=[\"n_neighbors\", \"centering\", \"scaling\", \"mean test score\"],\n",
 " color_continuous_scale=px.colors.diverging.Tealrose,\n",
 ")\n",
-"fig.show()"
+"fig.show(renderer=\"notebook\")"
 ]
 },
 {

python_scripts/datasets_adult_census.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def generate_dict(col):
 dimensions=plot_list,
 )
 )
-fig.show()
+fig.show(renderer="notebook")
 
 # %% [markdown]
 # The `Parcoords` plot is quite similar to the parallel coordinates plot that we

python_scripts/parameter_tuning_parallel_plot.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def shorten_param(param_name):
 color="mean_test_score",
 color_continuous_scale=px.colors.sequential.Viridis,
 )
-fig.show()
+fig.show(renderer="notebook")
 
 # %% [markdown]
 # ```{note}

python_scripts/parameter_tuning_sol_03.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@
 dimensions=["n_neighbors", "centering", "scaling", "mean test score"],
 color_continuous_scale=px.colors.diverging.Tealrose,
 )
-fig.show()
+fig.show(renderer="notebook")
 
 # %% [markdown] tags=["solution"]
 # We recall that it is possible to select a range of results by clicking and
