|
198 | 198 | "source": [
|
199 | 199 | "## Tuning using a grid-search\n",
|
200 | 200 | "\n",
|
201 | | - "In the previous exercise we used one `for` loop for each hyperparameter to\n", |
202 | | - "find the best combination over a fixed grid of values. `GridSearchCV` is a\n", |
203 | | - "scikit-learn class that implements a very similar logic with less repetitive\n", |
204 | | - "code.\n", |
| 201 | + "In the previous exercise (M3.01) we used two nested `for` loops (one for each\n", |
| 202 | + "hyperparameter) to test different combinations over a fixed grid of\n", |
| 203 | + "hyperparameter values. In each iteration of the loop, we used\n", |
| 204 | + "`cross_val_score` to compute the mean score (averaged across\n", |
| 205 | + "cross-validation splits), and compared those mean scores to select the best\n", |
| 206 | + "combination. `GridSearchCV` is a scikit-learn class that implements a very\n", |
| 207 | + "similar logic with less repetitive code. The suffix `CV` refers to the\n", |
| 208 | + "cross-validation it runs internally (instead of the `cross_val_score` we\n", |
| 209 | + "hard-coded).\n", |
| 210 | + "\n", |
| 211 | + "The `GridSearchCV` estimator takes a `param_grid` parameter which defines all\n", |
| 212 | + "hyperparameters and their associated values. The grid-search is in charge of\n", |
| 213 | + "creating all possible combinations and testing them.\n", |
| 214 | + "\n", |
| 215 | + "The number of combinations is equal to the product of the number of values to\n", |
| 216 | + "explore for each parameter. Thus, adding new parameters with their associated\n", |
| 217 | + "values to be explored rapidly becomes computationally expensive. Because of\n", |
| 218 | + "that, here we only explore the combination of the learning rate and the\n", |
| 219 | + "maximum number of leaf nodes, for a total of 4 x 3 = 12 combinations.\n", |
205 | 220 | "\n",
|
206 | | - "Let's see how to use the `GridSearchCV` estimator for doing such search. Since\n", |
207 | | - "the grid-search is costly, we only explore the combination learning-rate and\n", |
208 | | - "the maximum number of nodes." |
209 | | - ] |
210 | | - }, |
211 | | - { |
212 | | - "cell_type": "code", |
213 | | - "execution_count": null, |
214 | | - "metadata": {}, |
215 | | - "outputs": [], |
216 | | - "source": [ |
217 | 221 | "%%time\n",
|
218 | 222 | "from sklearn.model_selection import GridSearchCV\n",
|
219 | 223 | "\n",
|
220 | 224 | "param_grid = {\n",
|
221 | | - " \"classifier__learning_rate\": (0.01, 0.1, 1, 10),\n", |
222 | | - " \"classifier__max_leaf_nodes\": (3, 10, 30),\n", |
223 | | - "}\n", |
| 225 | + " \"classifier__learning_rate\": (0.01, 0.1, 1, 10), # 4 possible values\n", |
| 226 | + " \"classifier__max_leaf_nodes\": (3, 10, 30), # 3 possible values\n", |
| 227 | + "} # 12 unique combinations\n", |
224 | 228 | "model_grid_search = GridSearchCV(model, param_grid=param_grid, n_jobs=2, cv=2)\n",
|
225 | 229 | "model_grid_search.fit(data_train, target_train)"
|
226 | 230 | ]
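The `classifier__<parameter>` keys above assume that `model` is a scikit-learn `Pipeline` whose final step is named `classifier`; that pipeline, together with `data_train` and `target_train`, is built earlier in the notebook and does not appear in this hunk. A minimal sketch of what it could look like (the preprocessing choices below are assumptions, not the notebook's exact code):

```python
# Hypothetical reconstruction of the earlier notebook cells: a preprocessing +
# gradient-boosting pipeline whose last step is named "classifier", so that
# "classifier__learning_rate" and "classifier__max_leaf_nodes" are valid keys.
from sklearn.compose import make_column_selector, make_column_transformer
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder

categorical_preprocessor = OrdinalEncoder(
    handle_unknown="use_encoded_value", unknown_value=-1
)
preprocessor = make_column_transformer(
    (categorical_preprocessor, make_column_selector(dtype_include=object)),
    remainder="passthrough",
)
model = Pipeline(
    [
        ("preprocessor", preprocessor),
        ("classifier", HistGradientBoostingClassifier(random_state=42)),
    ]
)
```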
|
|
229 | 233 | "cell_type": "markdown",
|
230 | 234 | "metadata": {},
|
231 | 235 | "source": [
|
232 | | - "Finally, we check the accuracy of our model using the test set." |
| 236 | + "You can access the best combination of hyperparameters found by the grid\n", |
| 237 | + "search using the `best_params_` attribute." |
233 | 238 | ]
|
234 | 239 | },
|
235 | 240 | {
|
|
238 | 243 | "metadata": {},
|
239 | 244 | "outputs": [],
|
240 | 245 | "source": [
|
241 | | - "accuracy = model_grid_search.score(data_test, target_test)\n", |
242 | | - "print(\n", |
243 | | - " f\"The test accuracy score of the grid-searched pipeline is: {accuracy:.2f}\"\n", |
244 | | - ")" |
245 | | - ] |
246 | | - }, |
247 | | - { |
248 | | - "cell_type": "markdown", |
249 | | - "metadata": {}, |
250 | | - "source": [ |
251 | | - "<div class=\"admonition warning alert alert-danger\">\n", |
252 | | - "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Warning</p>\n", |
253 | | - "<p>Be aware that the evaluation should normally be performed through\n", |
254 | | - "cross-validation by providing <tt class=\"docutils literal\">model_grid_search</tt> as a model to the\n", |
255 | | - "<tt class=\"docutils literal\">cross_validate</tt> function.</p>\n", |
256 | | - "<p class=\"last\">Here, we used a single train-test split to evaluate <tt class=\"docutils literal\">model_grid_search</tt>. In\n", |
257 | | - "a future notebook will go into more detail about nested cross-validation, when\n", |
258 | | - "you use cross-validation both for hyperparameter tuning and model evaluation.</p>\n", |
259 | | - "</div>" |
| 246 | + "print(f\"The best set of parameters is: {model_grid_search.best_params_}\")" |
260 | 247 | ]
|
261 | 248 | },
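Besides `best_params_`, a fitted `GridSearchCV` also exposes `best_score_`, the mean cross-validated score obtained with that best combination. The line below is illustrative and not part of the notebook cell above:

```python
# Mean score of the best hyperparameter combination, averaged over the
# internal cross-validation splits (here cv=2).
print(f"Mean CV score of the best combination: {model_grid_search.best_score_:.3f}")
```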
|
262 | 249 | {
|
263 | 250 | "cell_type": "markdown",
|
264 | 251 | "metadata": {},
|
265 | 252 | "source": [
|
266 | | - "The `GridSearchCV` estimator takes a `param_grid` parameter which defines all\n", |
267 | | - "hyperparameters and their associated values. The grid-search is in charge\n", |
268 | | - "of creating all possible combinations and test them.\n", |
269 | | - "\n", |
270 | | - "The number of combinations are equal to the product of the number of values to\n", |
271 | | - "explore for each parameter (e.g. in our example 4 x 3 combinations). Thus,\n", |
272 | | - "adding new parameters with their associated values to be explored become\n", |
273 | | - "rapidly computationally expensive.\n", |
274 | | - "\n", |
275 | | - "Once the grid-search is fitted, it can be used as any other predictor by\n", |
276 | | - "calling `predict` and `predict_proba`. Internally, it uses the model with the\n", |
| 253 | + "Once the grid-search is fitted, it can be used like any other estimator, i.e. it\n", |
| 254 | + "has `predict` and `score` methods. Internally, it uses the model with the\n", |
277 | 255 | "best parameters found during `fit`.\n",
|
278 | 256 | "\n",
|
279 | | - "Get predictions for the 5 first samples using the estimator with the best\n", |
280 | | - "parameters." |
| 257 | + "Let's get the predictions for the first 5 samples using the estimator with the\n", |
| 258 | + "best parameters:" |
281 | 259 | ]
|
282 | 260 | },
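The code cell that follows is collapsed in this hunk; such a call typically looks like the sketch below (assuming `data_test` is a pandas DataFrame, as in the rest of the notebook):

```python
# Predict on the first 5 test samples; GridSearchCV delegates to the model
# refitted with the best hyperparameters found during `fit`.
model_grid_search.predict(data_test.iloc[0:5])
```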
|
283 | 261 | {
|
|
293 | 271 | "cell_type": "markdown",
|
294 | 272 | "metadata": {},
|
295 | 273 | "source": [
|
296 | | - "You can know about these parameters by looking at the `best_params_`\n", |
297 | | - "attribute." |
| 274 | + "Finally, we check the accuracy of our model using the test set." |
298 | 275 | ]
|
299 | 276 | },
|
300 | 277 | {
|
|
303 | 280 | "metadata": {},
|
304 | 281 | "outputs": [],
|
305 | 282 | "source": [
|
306 | | - "print(f\"The best set of parameters is: {model_grid_search.best_params_}\")" |
| 283 | + "accuracy = model_grid_search.score(data_test, target_test)\n", |
| 284 | + "print(\n", |
| 285 | + " f\"The test accuracy score of the grid-search pipeline is: {accuracy:.2f}\"\n", |
| 286 | + ")" |
307 | 287 | ]
|
308 | 288 | },
|
309 | 289 | {
|
310 | 290 | "cell_type": "markdown",
|
311 | 291 | "metadata": {},
|
312 | 292 | "source": [
|
313 | | - "The accuracy and the best parameters of the grid-searched pipeline are similar\n", |
| 293 | + "The accuracy and the best parameters of the grid-search pipeline are similar\n", |
314 | 294 | "to the ones we found in the previous exercise, where we searched the best\n",
|
315 | | - "parameters \"by hand\" through a double for loop.\n", |
| 295 | + "parameters \"by hand\" through a double `for` loop.\n", |
| 296 | + "\n", |
| 297 | + "## The need for a validation set\n", |
| 298 | + "\n", |
| 299 | + "In the previous section, the selection of the best hyperparameters was done\n", |
| 300 | + "using the train set, coming from the initial train-test split. Then, we\n", |
| 301 | + "evaluated the generalization performance of our tuned model on the left-out\n", |
| 302 | + "test set. This can be shown schematically as follows:\n", |
| 303 | + "\n", |
| 304 | + "\n", |
| 306 | + "\n", |
| 307 | + "<div class=\"admonition note alert alert-info\">\n", |
| 308 | + "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n", |
| 309 | + "<p>This figure shows the particular case of the <strong>K-fold</strong> cross-validation strategy\n", |
| 310 | + "using <tt class=\"docutils literal\">n_splits=5</tt> to further split the train set coming from a train-test\n", |
| 311 | + "split. For each cross-validation split, the procedure trains a model on all\n", |
| 312 | + "the red samples and evaluates the score of a given set of hyperparameters on the\n", |
| 313 | + "green samples. The best combination of hyperparameters <tt class=\"docutils literal\">best_params</tt> is selected\n", |
| 314 | + "based on those intermediate scores.</p>\n", |
| 315 | + "<p>Then a final model is refitted using <tt class=\"docutils literal\">best_params</tt> on the concatenation of the\n", |
| 316 | + "red and green samples and evaluated on the blue samples.</p>\n", |
| 317 | + "<p class=\"last\">The green samples are sometimes referred to as the <strong>validation set</strong> to\n", |
| 318 | + "differentiate them from the final test set in blue.</p>\n", |
| 319 | + "</div>\n", |
316 | 320 | "\n",
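Relating the note to the code: the internal splitting strategy of `GridSearchCV` is controlled by its `cv` parameter (the cell above used `cv=2` to keep the runtime low). A sketch of making the 5-fold strategy from the figure explicit; the `KFold` object and variable names here are illustrative, not the notebook's code:

```python
from sklearn.model_selection import GridSearchCV, KFold

# Explicit 5-fold internal cross-validation, matching the n_splits=5 case
# described in the note above.
inner_cv = KFold(n_splits=5, shuffle=True, random_state=0)
model_grid_search_5fold = GridSearchCV(
    model, param_grid=param_grid, n_jobs=2, cv=inner_cv
)
model_grid_search_5fold.fit(data_train, target_train)
```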
|
317 | 321 | "In addition, we can inspect all results which are stored in the attribute\n",
|
318 | 322 | "`cv_results_` of the grid-search. We filter some specific columns from these\n",
|
|