Before we begin, this article is a follow-up to the article on building models with scikit-learn. I'll be reusing many concepts that I explained in that article, as well as quite a bit of code. While I will re-explain my code, I'll do so more briefly. If you'd like more detailed explanations, I strongly recommend reading my previous article on scikit-learn and the one on the introduction to Machine Learning. Now that the groundwork is laid, let's get started!
Cross-validation is an evaluation technique that involves training a model multiple times on the same dataset, changing the training and test sets each time (both drawn from different parts of the dataset). We then measure each trained model's test score and average them to assess the algorithm's performance. Between runs, the model is naturally reset, so it never gets to learn from the entire dataset at once.
The goal of these multiple training runs is to ensure that the method we used is sound, and that the data split between training and test sets hasn't unfairly favored one algorithm over another. In other words, the aim is to make sure the model has properly generalized the problem it's trying to solve, and that the performance measurements accurately reflect this.
How the dataset is split during cross-validation depends on the cross-validation strategy you choose. There are, in fact, several ways to split a dataset, and the choice of strategy depends on the dataset and what you want to accomplish. This article will cover the most common ones.
One clarification, though: as you may have already gathered, the purpose of cross-validation is not to improve our model's performance, but rather to measure its performance more reliably so we can compare it more precisely with other models.
As with the previous article, my dataset will be the Adult Census Income. This will allow us to build on the familiar ground from before, particularly regarding the model's objective. Indeed, the model will have the same goal: determining, based on the information available, whether a person's salary is above or below $50K per year.
For the model, we'll switch up the algorithm a bit. I'm going to use a Random Forest. Without going into much detail, a Random Forest is a collection of decision trees, each trained on a random sample of the data, and the model makes its decision by aggregating the trees' predictions (a majority vote, for classification). Here's the model that will serve as the foundation for cross-validation:
Base model for this article
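The original code block isn't reproduced in this version of the article, so here's a sketch of what it could look like; the file name, column names, and exact preprocessing choices are my assumptions.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer, make_column_selector as selector
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OrdinalEncoder, StandardScaler

# Load the Adult Census Income data (file name and column names assumed)
adult_census = pd.read_csv("adult-census.csv")
target = adult_census["class"]
data = adult_census.drop(columns=["class"])

# Normalize numerical columns, encode categorical ones
preprocessor = ColumnTransformer([
    ("num", StandardScaler(), selector(dtype_include="number")),
    ("cat",
     OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1),
     selector(dtype_exclude="number")),
])

model = make_pipeline(preprocessor, RandomForestClassifier())
```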
This model is a Random Forest wrapped in a pipeline that preprocesses both numerical and categorical data: the former is normalized and the latter is encoded.
In the first article, to train the model we used the train_test_split function, which splits the dataset into two sets.
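As a reminder, that single-split approach looked roughly like this (the 80/20 proportion and the seed are illustrative):

```python
from sklearn.model_selection import train_test_split

# A single random split into training and test sets (proportion and seed assumed)
data_train, data_test, target_train, target_test = train_test_split(
    data, target, test_size=0.2, random_state=42
)
model.fit(data_train, target_train)
print(model.score(data_test, target_test))
```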
This function randomly separates the dataset into training and validation sets. Although this split is random, it can still be biased, creating data imbalances that may either make our model less performant or, conversely, favor it. The purpose of the models we build is to generalize the rules for deducing the target from our inputs, and to adapt to new data. It's therefore important that performance measurements are as precise and neutral as possible, eliminating as many potential biases as we can.
The first cross-validation strategy we'll look at is the K-Fold strategy. The principle is simple: we divide the dataset into K parts, where K is an integer chosen in advance (typically between 5 and 10), and we train K models, each time using one part of the dataset as the test set and the remaining K − 1 parts as the training set. Each part serves as the test set exactly once.
Dataset split using KFold (K = 4)
How do we implement this with scikit-learn? The first thing to know is how to perform cross-validation. For this, there's the cross_validate function, which simplifies the process. Then, for KFold, there's the KFold class, which creates these different dataset segmentations.
By default, the cross_validate function returns 3 pieces of information per split: the training time (fit_time), the scoring time (score_time), and the test score (test_score).
Cross-validation with scikit-learn
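Since the original block isn't shown here, below is a minimal version; K = 5 and accuracy scoring are assumptions on my part.

```python
import pandas as pd
from sklearn.model_selection import KFold, cross_validate

# Split the dataset into 5 consecutive folds (K = 5 is an assumption)
cv = KFold(n_splits=5)

# Train and score one model instance per fold
cv_results = cross_validate(model, data, target, cv=cv, scoring="accuracy")

# Convert the results to a DataFrame for easier manipulation
cv_results = pd.DataFrame(cv_results)
print(cv_results)
```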
In the code above, you can see how cross-validation is performed. First, we create the strategy using, in this case, the KFold class. The n_splits parameter indicates the number of parts the dataset is divided into (what I've been calling K up to this point).
Next, we call the cross_validate function, which takes the model, data, and target as parameters. We then specify the cv parameter, which is the cross-validation strategy we want, and we can also specify the scoring metric using the scoring parameter. This last parameter can be a string predefined by scikit-learn or a custom scoring function.
We then retrieve the results, which I've converted here into a DataFrame for easier manipulation. This allows us to see the score of each trained model instance.
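As an aside, if you want a custom metric, scikit-learn's make_scorer can wrap any metric function into something cross_validate accepts; balanced accuracy below is purely an illustrative choice.

```python
from sklearn.metrics import balanced_accuracy_score, make_scorer

# Wrap a metric function into a scorer usable by cross_validate
custom_scorer = make_scorer(balanced_accuracy_score)
cv_results = cross_validate(model, data, target, cv=cv, scoring=custom_scorer)
```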
Pretty easy, right? Now that we've covered the basics, let's look at other strategies.
The Stratified K-Fold split works slightly differently from regular K-Fold, and it applies only to classification problems. The goal is still to create K parts of the dataset, but this time we want each part, each fold, to preserve the class distribution of the whole dataset. Let me explain with our dataset, the Adult Census Income, which contains data on roughly 32,000 American individuals. Among them, 76% earn less than $50K a year and 24% earn more. The Stratified K-Fold split ensures that each fold also contains 76% of individuals earning less than $50K and 24% earning more.
Dataset split using Stratified KFold (K = 4)
This guarantees that all data is properly represented in each fold. But it also assumes that the dataset accurately represents the data we'll encounter in production (which should be the case, but isn't always).
Let's move on to the implementation, and as always, scikit-learn has a class to help us: StratifiedKFold. Here's how it's used:
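Again, the original snippet isn't reproduced here, so this is a minimal sketch (5 folds assumed):

```python
from sklearn.model_selection import StratifiedKFold, cross_validate

# Each fold keeps the same <=50K / >50K proportions as the full dataset
cv = StratifiedKFold(n_splits=5)
cv_results = cross_validate(model, data, target, cv=cv, scoring="accuracy")

# Mean test score with the standard deviation as error margin
scores = cv_results["test_score"]
print(f"Accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
```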
As with the K-Fold split, StratifiedKFold takes an n_splits parameter that defines the number of folds we want to create. The rest of the code is similar.
Just a quick note on the print output: one way to present cross-validation results is to display the mean of the test scores with an error margin corresponding to the standard deviation of the results. This makes it easy to present and compare data.
Another thing I'd like to show you about the cross_validate function is that the cv parameter also accepts an integer. In that case, the split strategy used is Stratified K-Fold for classification problems and a simple K-Fold for regression problems.
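For our classifier, for example, the following call is equivalent to passing StratifiedKFold(n_splits=5):

```python
# For a classifier, cv=5 expands to StratifiedKFold(n_splits=5)
cv_results = cross_validate(model, data, target, cv=5, scoring="accuracy")
```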
The random split, as its name suggests, randomly separates the dataset into training and test sets for each model instance. You need to define two things: the number of splits you want to perform and the size of the test set.
Random split
In scikit-learn, the class that implements this is ShuffleSplit, and here's how it's used:
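Here's a minimal version of that code (the random seed is an assumption):

```python
from sklearn.model_selection import ShuffleSplit, cross_validate

# 10 independent random splits, each holding out 20% of the data (seed assumed)
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
cv_results = cross_validate(model, data, target, cv=cv, scoring="accuracy")
print(cv_results["test_score"])
```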
Here, we'll train 10 models with an 80/20 split between training and test data. The rest of the code remains the same.
Cross-validation is a process that allows for more precise and reliable comparison of different model architectures, by reducing biases that may be associated with splitting training and test sets. The main mechanism at work in this process is training multiple models with different training and test sets, even though they all come from the same dataset.
There are many possible strategies for splitting data during cross-validation. This article covered three of them, but scikit-learn offers many more; you can find them in its documentation.