Wanna know more about data science? Make sure to check out my events and my webinars What it's like to be a data scientist and What’s the best way to become a data scientist!

How much data do you need for machine learning?

Collecting data is not easy. A general saying in data mining is “the more data the better”, but sometimes collecting more data is expensive, or just not feasible. Therefore, it is very valuable to understand how the amount of data affects the performance of a model. If a model can learn well with 1000 data points, then why collect 1000 more?

Learning curves are a tool that helps us do exactly that.

Learning curves

Learning curves show you how the performance of a classifier changes as the number of training examples grows. Here is an example of a learning curve, taken from scikit-learn’s implementation.

[Figure: learning curve]

So, on this curve you can see both the training and the cross-validation score. The training score doesn’t change much as more examples are added. But the cross-validation score definitely does! You can see that once we reach about 1000-1200 examples there is very little change in performance.

So, what this tells us is that collecting more examples beyond the ones we currently have is probably not worthwhile.
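
One way to turn that visual judgement into a number is to look at how much the mean cross-validation score improves between consecutive training-set sizes. The sketch below is only an illustration: the helper name plateau_size, the tolerance of 0.01, and the example scores are all made up for this post and are not part of scikit-learn.

```python
def plateau_size(train_sizes, mean_cv_scores, tol=0.01):
    """Return the first training size after which every further gain is below tol.

    Hypothetical helper for illustration; returns None if the curve never flattens.
    """
    # Gain in score obtained by moving from size i to size i + 1
    gains = [b - a for a, b in zip(mean_cv_scores, mean_cv_scores[1:])]
    for i, size in enumerate(train_sizes[:-1]):
        # The curve is flat from here on if all remaining gains are below tol
        if all(g < tol for g in gains[i:]):
            return size
    return None


# Made-up scores that flatten out, similar in shape to a typical learning curve
sizes = [200, 600, 1000, 1400, 1800]
scores = [0.80, 0.90, 0.95, 0.955, 0.957]
print(plateau_size(sizes, scores))
```

The right tolerance depends on how much a small gain in score is worth to you compared with the cost of collecting more data.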

How can you use learning curves?

Learning curves are super easy to use through scikit-learn. Here is an example:

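from sklearn.model_selection import learning_curve
from sklearn.svm import SVC
from sklearn.datasets import load_digits
from matplotlib import pyplot as plt
import numpy as np

# Load the digits dataset and define the classifier
X, y = load_digits(return_X_y=True)
estimator = SVC(gamma=0.001)

# Score the model at several training-set sizes using 30-fold cross-validation
train_sizes, train_scores, test_scores, fit_times, _ = learning_curve(
    estimator, X, y, cv=30, return_times=True
)

# Average the training scores across folds and plot them against training-set size
plt.plot(train_sizes, np.mean(train_scores, axis=1))
plt.xlabel("Training examples")
plt.ylabel("Training score")
plt.show()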

This will plot a curve like the one below:

[Figure: learning curve produced with scikit-learn]

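Here we have used the default setting for the training-set sizes, which evaluates the model at 5 different sizes. The default is [0.1, 0.325, 0.55, 0.775, 1.0] (roughly [0.1, 0.33, 0.55, 0.78, 1.0]), where each value is a fraction of the largest training set that the cross-validation split allows. The number of cross-validation folds is set separately: the code above uses cv=30.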
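
To make the default sizes concrete: scikit-learn's documented default for train_sizes is np.linspace(0.1, 1.0, 5). The sketch below reproduces that arithmetic by hand and roughly mirrors how fractions are turned into example counts. The figure of 1000 available training examples is an assumption chosen only to keep the numbers readable.

```python
import numpy as np

# Default fractions used by learning_curve when train_sizes is not given
fractions = np.linspace(0.1, 1.0, 5)
print(fractions)

# Assumed size of the largest training fold (illustrative, not from the example above)
max_train = 1000

# Convert each fraction into an absolute number of training examples
absolute_sizes = (fractions * max_train).astype(int)
print(absolute_sizes)
```

If the default grid is too coarse, you can pass your own fractions or absolute counts through the train_sizes parameter of learning_curve.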

Note two things.

First, we have to call np.mean(train_scores,axis=1) before we plot. Why is this? If we print train_scores then we see that it is a matrix, with one row per training-set size and one column per cross-validation fold. So, we have to take the mean across all folds (axis=1) before we plot it.
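
A quick way to check the layout of the returned arrays is to print their shapes. This sketch uses the same digits data and classifier as the example, but assumes cv=3 instead of 30 purely so that it runs quickly; the variable names are illustrative.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)

# cv=3 is an assumption made here to keep the run short
train_sizes, train_scores, test_scores = learning_curve(
    SVC(gamma=0.001), X, y, cv=3
)

# One row per training-set size, one column per cross-validation fold
print(train_scores.shape)

# Averaging over axis=1 collapses the folds, leaving one score per size
mean_train = np.mean(train_scores, axis=1)
mean_test = np.mean(test_scores, axis=1)
# The standard deviation across folds shows how stable each estimate is
std_test = np.std(test_scores, axis=1)
print(mean_train.shape, mean_test.shape, std_test.shape)
```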

Second, note that the function also returns fit_times, because we passed return_times=True. This contains the time it took to fit the model for each training-set size and fold. As you can expect, the more data, the longer it takes to fit the model. If we take the mean fit time for each of the training-set sizes, we get:

array([0.8336064 , 0.96932203, 0.98043315, 0.9843597 , 0.98886064])
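
The same averaging can be applied to the timing arrays. The following is a minimal sketch, again assuming cv=3 for speed. The actual timings depend on your machine, so no output is shown.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)

# return_times=True adds two arrays: time spent fitting and time spent scoring
train_sizes, _, _, fit_times, score_times = learning_curve(
    SVC(gamma=0.001), X, y, cv=3, return_times=True  # cv=3 assumed for speed
)

# Same layout as the score arrays: rows are training sizes, columns are folds
mean_fit_times = np.mean(fit_times, axis=1)
for size, seconds in zip(train_sizes, mean_fit_times):
    print(f"{int(size):5d} examples: {seconds:.4f} s")
```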

On scikit-learn’s website you can also find a handy piece of code that will create a very nice plot for you.
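
As a rough idea of what such a plot involves, here is a minimal sketch that draws both mean curves with a shaded band of one standard deviation. It is not the code from scikit-learn's website. The function name plot_learning_curve and the synthetic scores used to exercise it are assumptions for illustration.

```python
import numpy as np
from matplotlib import pyplot as plt


def plot_learning_curve(train_sizes, train_scores, test_scores, ax=None):
    """Plot mean training and cross-validation scores with +/- 1 std bands.

    Hypothetical helper; expects score arrays shaped (n_sizes, n_folds).
    """
    if ax is None:
        _, ax = plt.subplots()
    for scores, label in [(train_scores, "Training score"),
                          (test_scores, "Cross-validation score")]:
        mean = np.mean(scores, axis=1)
        std = np.std(scores, axis=1)
        # Shaded band shows the spread across folds at each training size
        ax.fill_between(train_sizes, mean - std, mean + std, alpha=0.2)
        ax.plot(train_sizes, mean, marker="o", label=label)
    ax.set_xlabel("Training examples")
    ax.set_ylabel("Score")
    ax.legend()
    return ax


# Synthetic scores (3 sizes x 2 folds) just to show the call; not real results
sizes = np.array([100, 500, 1000])
train = np.array([[1.00, 0.99], [0.99, 0.99], [0.99, 0.98]])
test = np.array([[0.80, 0.84], [0.93, 0.95], [0.96, 0.97]])
ax = plot_learning_curve(sizes, train, test)
```

With real data, pass in the arrays returned by learning_curve and call plt.show() afterwards.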

Learning curves

So, next time you are not sure whether you need to find more data or not, you know which tool to use! Scikit-learn makes learning curves very easy to use, and they can help you make an objective cost-benefit analysis of how to proceed with data collection. Make sure to add this technique as a staple in your machine learning arsenal!

 




Dr. Stylianos Kampakis is the owner and author of The Data Scientist.