Relationships and machine learning: What do they have in common?

[Spoiler alert] The answer is overfitting

Fede Urena
Towards Data Science

--

Photo by Geran de Klerk on Unsplash

I have always thought of overfitting as the machine learning equivalent of a couple that has been together for a LONG time.

Maybe it’s a bit strong to kick-start an article with that statement, so if it sounds weird, check out the couple of zebras in the image above. Let’s call them Dick and Jane and assume they’ve been a couple for a while now. So long, in fact, that each knows the other down to the smallest detail: how it talks, how it thinks, and even what happens just before a fight breaks out (crucial information for self-preservation, by the way). So Dick and Jane get comfortable and start thinking that this relationship thing is easy. They have become experts in the art of “reading” zebra facial expressions, interpreting words in zebra language, and even body language, such as tail-wagging; they know when the other zebra is happy, sad or mad, and even what to say or do to get their way. Dick and Jane are masters of the relationship game, right?

Wrong. The proof? After many years together, they break up and it’s time to get back into the dating game (for some of you zebras out there, dating is such a serious thing it can hardly be called a “game”). But our little zebra friends find out that it’s actually pretty hard. Maybe they’re not as good at dating and relationships as they thought they were. Well, guess what? They probably aren’t (and neither are you). So why did Dick and Jane ever think they were the masters of relationships?

Because they were doing some good old overfitting (relationship style)! Yep, they built a relationship-learning model based on a single zebra and thought it would work with all zebras. Dick, for one, thought that because his nice little tricks worked with Jane, they would also work on his next partner, but then he found out that different zebras are actually, well, different from each other. Surprised? Welcome to the dangerous world of overfitting.

Overfitting in machine learning

So, enough with the animal kingdom for now. Overfitting usually occurs in supervised learning, when your model latches onto the quirks of a specific sample of instances instead of patterns that generalize. Basically, your model “memorizes” the attributes and target values of the individuals in the sample, and when you ask for a prediction for one of those same individuals, it simply spits out the previously memorized value of the target variable and calls it a “prediction”. Sound familiar? It should, because this is exactly what Excel’s VLOOKUP formula does: you feed it a value, it looks for that value’s row in a table, and then it returns the value of another column for that same row. And I think we can all agree that the VLOOKUP formula is not a machine learning model, not even in its wildest dreams.
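To make this concrete, here’s a minimal sketch in Python with scikit-learn (the toy data and every name in it are mine, invented purely for illustration). We fit an unconstrained decision tree on pure noise and then “evaluate” it on the very same data:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Toy data: 100 instances, 5 random features, random 0/1 labels.
# There is no real pattern here -- nothing generalizable to learn.
rng = np.random.default_rng(42)
X = rng.normal(size=(100, 5))
y = rng.integers(0, 2, size=100)

# An unconstrained decision tree happily memorizes all 100 instances...
model = DecisionTreeClassifier().fit(X, y)

# ...so "predicting" on the very same data looks perfect.
print(model.score(X, y))  # 1.0 -- a lookup table with extra steps
```

A perfect score on data that contains no pattern at all. That’s not learning; that’s a lookup table.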

This is like getting to know your romantic partner so well, you know that when he/she raises the left eyebrow, a fight is about to begin. Well, that works fine for this specific person, but perhaps when you have a new partner, you’ll find out that for this new person, eyebrow-raising is actually a sign of “let’s-get-romantic” (if you know what I mean).

As silly as this sounds, this is a common mistake in data science. You build a model and fall in love with it because it predicts every single instance of your training set perfectly. You pitch it to your boss, organize a team meeting to present your big idea, and 30 seconds into the Q&A session, your experienced colleague from IT shoots it down because it doesn’t generalize to other cases. So now you’re just standing there, looking like a fool in front of everybody…

In all seriousness, companies sometimes spend millions of dollars building models that simply memorize datasets and call the result a “predictive model”. Basically, you just spent a bunch of money on something you could have done in about 30 seconds (remember that VLOOKUP formula?).

So, how do we avoid the dangers of overfitting?

Glad you asked. Here are three general ideas to keep in mind while waging war on overfitting.

Training/testing set splitting

Photo by Antoine Dautry on Unsplash

This is a classic. I’m not going to go too deep into it, because it has been explained countless times in textbooks and YouTube videos by people far funnier and more empathetic than I am. Let me just leave you with this idea: if you were a third-grade math teacher, you wouldn’t use the same exercises you gave your students as practice for the actual test. Why not? Well, because it would be far too easy to just memorize the answers in the back of the practice sheet and reproduce them on the test, so there would be close to zero math learning going on at all. If you understand this, you understand the importance of train/test splits in machine learning model building. The lesson here is simple: never test a model with the same data you used to build it; otherwise it will cheat on you, just like your students.
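Want to see the cheating get caught? Here’s a hedged continuation of the toy sketch from earlier (the data is still invented noise, so all the names and numbers are mine). We hold 20% of the instances back, and the model suddenly looks a lot less impressive:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

# Same kind of patternless toy data as before.
rng = np.random.default_rng(42)
X = rng.normal(size=(100, 5))
y = rng.integers(0, 2, size=100)

# Hold out 20% of the instances; the model never sees them while training.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

model = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

print("Train accuracy:", model.score(X_train, y_train))  # 1.0 -- memorized
print("Test accuracy:", model.score(X_test, y_test))     # ~0.5 -- busted
```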

Cross validation

Photo by Mockup Graphics on Unsplash

Let’s say you like cooking and you come up with a brand-new recipe for an amazing banana bread the world has never tasted before. You want a way to objectively determine whether your recipe is any good, or if it’s just your magic cooking skills that make it taste great (what I would call cooking overfitting). So you gather ten cooking experts who want to go along for the ride. You give the recipe to eight of them, who prepare the bread, and then have the remaining two try it and score it on a scale from 1 to 10. You repeat the experiment with the same ten experts, but this time you have a different pair try it and grade it. You do this three more times and obtain five different grades for your banana bread, each from a different pair of cooking experts. To get the final grade of how good your recipe is, you average the five grades, and you’re done.

This, in a nutshell, is cross validation. In order to know how good a machine learning model really is, you split your data into training and testing sets several times, always keeping the same proportion (80/20 in this case) but rotating the instances on which you test for accuracy. In a ten-instance data set, this is how it looks:

Image by the author

On the left, the complete dataset (each banana is an instance); on the right, the same dataset after five different train/test splits, called folds: in each split, the yellow bananas are the instances used to build the model and the green ones are the instances we will use to test it. This is where the name “k-fold cross validation” comes from. In this case it’s a 5-fold cross validation, since we made five different splits, but 10-fold cross validation is also pretty common.

Now, all that’s left to do is test the accuracy of the model on each fold, calculate some aggregate measure of overall accuracy, and we’re done. This overall score should give us a reliable estimate of our model’s ability to predict the target variable on data it hasn’t seen, and it should also give us a good hint about whether the model is overfitting. Keep in mind: always find objective ways to measure your model’s performance and its tendency to overfit.
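In code, the whole banana bread experiment fits in a few lines. This is again just a sketch on invented noise data, using scikit-learn’s cross_val_score:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score

# The same patternless toy data as in the earlier sketches.
rng = np.random.default_rng(42)
X = rng.normal(size=(100, 5))
y = rng.integers(0, 2, size=100)

# 5-fold CV: five 80/20 splits, rotating which fifth plays the test set.
scores = cross_val_score(DecisionTreeClassifier(random_state=42), X, y, cv=5)

print("Per-fold accuracy:", scores)        # five grades, one per pair of tasters
print("Average accuracy:", scores.mean())  # the recipe's final grade
```

Each entry in scores is one pair of tasters; the mean is your recipe’s final grade.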

Complexity control

Image by the author

Yeah, that’s me, trying to become a songwriting hero, which never happened by the way. Man, that thing is hard. You know when you’re writing a song and it’s so plain and simple that no one is interested in listening to it? So you make it a little more complex and interesting, and people start liking it; you make it even more complex, and everybody still likes it. And then you start thinking, “Man, I’m a good songwriter. The more complex the song, the better it will turn out.”

Only it doesn’t really work like that. A point comes when your song is so complex that it only appeals to, say, music experts, and you end up losing most of your listeners (a good example of an inverted-U curve). Keeping things on the right side of that curve, my friend, is called complexity control, and it applies to machine learning models too.

Basically, a machine learning model that is too simple will not capture the patterns in the data very well. Hence, it will make very vague predictions of your target variable, possibly with very low accuracy. Like you did with your song, you would probably increase its complexity (perhaps by adding more explanatory variables, in the case of a linear regression, for example) to make the predictions more accurate. However, again like your song, this only works up to a certain point, after which… well, unsurprisingly, your model starts overfitting the training data and its performance on new data decreases. So, in short, don’t get too caught up in complexity, or it will drive all your fans away from you and back to that Bruno Mars song (which is definitely better than yours anyway).
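If you want to watch the inverted U appear, here’s one last sketch. This time the synthetic data carries a real (but noisy) pattern, and we turn a single complexity knob, the maximum depth of a decision tree; every number below is my own arbitrary choice, not a recipe:

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score

# Synthetic data with a real (but noisy) pattern this time.
X, y = make_classification(n_samples=300, n_features=20, n_informative=5,
                           flip_y=0.2, random_state=42)

# Sweep model complexity (tree depth) and track cross-validated accuracy.
for depth in [1, 2, 4, 8, 16, None]:  # None = grow the tree without limit
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    score = cross_val_score(model, X, y, cv=5).mean()
    print(f"max_depth={depth}: mean CV accuracy = {score:.2f}")
```

With noisy data like this, accuracy typically rises with depth, peaks somewhere in the middle, and then slides back down as the tree starts memorizing noise.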

These are only a few very general ideas out of many out there. The point here is to give you an overall awareness of the dangers of overfitting and some clues as to where to look for solutions. For those of you interested in going deeper, there are MANY great YouTube channels with amazing and relevant content. You’ll find a couple here and here.

Zebras, revisited

Remember our friends, Dick and Jane? I’d like to tell you that they got back together, but the truth is they didn’t. Actually, they did something better: they each developed smarter relationship-learning models that apply broadly to other zebras, which has allowed them not only to find new partners but also to significantly increase the number of zebra specimens in the African savannah. I guess we’re all winners, aren’t we?

--

Born in Costa Rica, living in France. I write about business, strategy, and value exchange in general. And every now and then I tell a good joke.