
The K-Nearest Neighbors Algorithm, Explained in Simple Terms

"Tell me who your neighbors are and I will tell you who you are."

The K-Nearest Neighbors Algorithm for Beginners

If you’ve hung out with a data scientist or you’re learning data science, chances are you’ve heard about the k-Nearest Neighbors algorithm (or k-NN, for short).

The first time I heard this term at a networking event, the following image popped into my head:

Photo by Yarne Fiten on Unsplash.

Sadly, the k-NN algorithm offers little information about your neighbors, unless, of course, you think of them as data points.

What is the k-NN algorithm?

Put simply, k-NN is a machine learning classification algorithm that predicts the category a new observation belongs to, based on its similarity to the data points around it.

To make this more concrete, imagine that you’re the chief data scientist for a start-up mobile network called Mobile, Inc. You have a dataset that contains information about your users. Among other variables, you know each user’s age, income (in thousands of USD per year), and subscription plan (Basic, Plus, or Premium). The first 10 rows of your dataset are displayed below:

First 10 rows of your customer data set. Image by author.

One way to visualize how two numerical variables (age and income) are related to a categorical variable (subscription plan) is to use a color-coded scatter plot, as shown below.

Image by author.

Even though the number of data points is very small, some patterns already seem to be emerging in the data: there are regions, or clusters, dominated by a single class.

Suppose that we have a new potential customer who is 36 years old and earns 124,000 USD per year. Using the k-NN algorithm, we want to predict her subscription plan based on her nearest neighbors.

New data point displayed in gray. Image by author.

If we set k = 1, the 1-NN algorithm will find the single closest neighbor. Because this neighbor is a "Plus" customer (blue), the algorithm will predict that the new customer will also become a "Plus" customer.

1-NN algorithm. Image by author.

However, this would be a poor choice for k. We can clearly see that this "Plus" neighbor is an outlier in that region. Most of its neighbors are, in fact, "Premium" customers. What if instead of choosing k=1, we set k=6?

6-NN algorithm. Image by author.

Using the 6-NN algorithm, we find that 5 out of the 6 nearest neighbors belong to the class "Premium." As a result, our machine learning algorithm predicts that the new customer will choose the "Premium" plan.
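
As a side note, this is roughly what the prediction would look like in code. The sketch below uses scikit-learn’s KNeighborsClassifier; the ages, incomes, and plan labels are made-up illustrative values, not the actual Mobile, Inc. data.

```python
# Minimal sketch with scikit-learn. The data below is illustrative only --
# it is NOT the real customer dataset from the article.
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Each row: [age, income in thousands of USD]; labels are subscription plans.
X = np.array([
    [25,  40], [31,  55], [47,  90], [52, 130],
    [38, 120], [29,  70], [44, 140], [35, 110],
])
y = np.array(["Basic", "Basic", "Plus", "Premium",
              "Premium", "Plus", "Premium", "Premium"])

new_customer = np.array([[36, 124]])  # 36 years old, 124,000 USD per year

# With k = 1, the prediction follows the single closest neighbor.
knn_1 = KNeighborsClassifier(n_neighbors=1).fit(X, y)
print(knn_1.predict(new_customer))

# With k = 6, the prediction follows the majority vote of the 6 nearest neighbors.
knn_6 = KNeighborsClassifier(n_neighbors=6).fit(X, y)
print(knn_6.predict(new_customer))
```

In practice, because age and income live on very different scales, the features would usually be standardized before computing distances; otherwise income would dominate the distance calculation.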

Within seconds, sales representatives at Mobile, Inc. set in motion a marketing plan to highlight the benefits of the Premium plan to the new customer. Within minutes, the customer becomes another red dot in the scatter plot.

I could have predicted that. Do we really need a machine learning algorithm?

In the previous example, you may have been thinking: I could have guessed the customer’s subscription plan by looking at the scatter plot! Do we really need a machine learning algorithm for this?

However, imagine that instead of having two variables for every customer, you had more than a dozen relevant features (number of kids, credit score, occupation, ZIP code, gender, etc). And instead of having a few dozen data points, you had several thousand data points!

A human would no longer find this classifying task easy.

In fact, it’s extremely hard to wrap your head around high-dimensional space. Our visualization powers start to fail us past 3-dimensional space.

A scatter plot in 3-dimensional space. Photo source.

For a decent computer, however, running a k-NN classification algorithm on a medium-sized data set feels like a walk in the cyber-park.

How does the k-NN algorithm actually work?

By now you probably have an intuition of what k-NN does. But how does it work? The k-NN algorithm boils down to two simple concepts:

  • the Euclidean distance between two points in n-dimensional space
  • the statistical mode of a set of values

To understand this, let’s make a quick stop at high school algebra. Suppose that we have two points in a coordinate plane. We can calculate the distance between these points using the Euclidean distance formula:

Euclidean distance between two points. Image by author.

This formula can be intuitively derived by imagining a right triangle and applying the Pythagorean theorem to find the length of the hypotenuse. For example, given points A and B, we can calculate the distance between them, as illustrated below:

Euclidean distance between two points. Image by author.

More generally, for any two points p and q in n-dimensional space, we can calculate the distance between them using the following formula (where the subscript i refers to the i-th feature of points p and q).

Image by author.
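
In code, this formula is a one-liner. Here is a minimal sketch (the function name euclidean_distance is my own) that works for points with any number of features:

```python
import numpy as np

def euclidean_distance(p, q):
    """Euclidean distance between two points in n-dimensional space."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return np.sqrt(np.sum((p - q) ** 2))

# Example in 2-D: the same calculation you would do with the Pythagorean theorem.
print(euclidean_distance([1, 2], [4, 6]))  # 5.0
```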

Once the k-NN algorithm finds the k nearest neighbors, it chooses the mode (the class label that appears most frequently) among those neighbors. The table below shows how a 6-NN algorithm would record the distances to the 6 nearest points.

6-NN with recorded distances. Image by author.
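
Putting the two concepts together, a bare-bones version of the classifier fits in a few lines. This is only an illustrative, from-scratch sketch (the function knn_predict and its argument names are my own), not a production implementation:

```python
from collections import Counter
import numpy as np

def knn_predict(X_train, y_train, new_point, k=6):
    """Predict the class of new_point from its k nearest neighbors."""
    # 1. Euclidean distance from new_point to every training point.
    distances = np.sqrt(np.sum((np.asarray(X_train) - np.asarray(new_point)) ** 2, axis=1))
    # 2. Indices of the k smallest distances.
    nearest = np.argsort(distances)[:k]
    # 3. Mode: the most frequent class label among those k neighbors.
    return Counter(np.asarray(y_train)[nearest]).most_common(1)[0][0]

# Example call with the illustrative X and y defined earlier:
# knn_predict(X, y, [36, 124], k=6)
```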

The beauty of k-NN is that it runs on two concepts that remain relevant in n-dimensional space: the Euclidean distance and the statistical mode. Therefore, in theory, we can apply the k-NN algorithm to a dataset with any number of features.

However, given computational constraints, k-NN becomes less effective as the number of features in a data set increases (this falls outside of the scope of this article, but if you’re interested in learning more about this, let me know in the comments below).

How do we choose the best value for k?

As you learned earlier, the choice of k can dramatically change the algorithm’s predictive power. If k is too small, the classification algorithm will probably be poor because it’ll be sensitive to noise (outliers in the data).

On the other hand, if the k value is too large, the region of interest will be too large to find granular patterns in the data.

Clearly there’s a sweet spot. But how do we find it?

The short answer is that we can run k-NN with different values of k and test how well each one predicts a held-out subset of the data. Since we already have the labels for this subset, we can compare the k-NN predictions to the actual values in our data set. By trying different values of k, we can select the one with the highest accuracy rate.
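
A minimal sketch of this search, assuming X holds the feature rows and y the plan labels (for example, the illustrative arrays from the earlier sketch), might look like this:

```python
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# Hold out 30% of the labeled data to evaluate each candidate k.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

accuracies = {}
for k in range(1, 16):  # k must not exceed the number of training samples
    model = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)
    accuracies[k] = accuracy_score(y_test, model.predict(X_test))

best_k = max(accuracies, key=accuracies.get)
print(best_k, accuracies[best_k])
```

In practice, cross-validation over several splits gives a more reliable estimate of accuracy than a single train/test split.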

Model accuracy as a function of k. Image by author.

It appears that k = 7 is our best choice for predicting customers’ subscription plans based on our current user data set.

Summary

Whether you’re a beginner data scientist or you’re simply curious about machine learning algorithms, I hope this article gave you a better understanding of what k-NN is and how it works.

The two key takeaways from this article are:

  • k-NN is a very powerful classification algorithm that can help make predictions about new data based on its similarity to existing data points
  • the choice of k matters significantly: your k-NN algorithm should find granular patterns in the data but not be overly sensitive to noise

No matter how busy we are, perhaps we should remember to be friendly to our nearest neighbors. They might give us useful information about who we are.


Thank you for reading. In these articles, I try to explain Data Science concepts in a way that feels accessible and intuitive.

If this article was helpful to you, or you have any suggestions for new topics to cover, please comment below!

