k nearest neighbors computational complexity

Understanding the computational cost of the kNN algorithm, with case study examples

Jakub Adamczyk
Towards Data Science


Visualization of the kNN algorithm (source)

Algorithm introduction

kNN (k nearest neighbors) is one of the simplest ML algorithms, often taught as one of the first algorithms during introductory courses. It’s relatively simple but quite powerful, although little time is usually spent on understanding its computational complexity and practical issues. It can be used both for classification and regression with the same complexity, so for simplicity, we’ll consider the kNN classifier.

kNN is a lazy, instance-based algorithm — during prediction it searches for the nearest neighbors and takes their majority vote as the class predicted for the sample. A training phase may or may not exist at all, since in general we have 2 possibilities:

  1. Brute force method — calculate the distance from the new point to every point in the training data matrix X, sort the distances and take the k nearest, then do a majority vote. There is no need for separate training, so we only consider prediction complexity.
  2. Using a data structure — organize the training points from X into an auxiliary data structure for faster nearest neighbors lookup. This approach uses additional space and time (for creating the data structure during the training phase) in exchange for faster predictions.

We focus on the methods implemented in Scikit-learn, the most popular ML library for Python. It supports brute force, k-d tree and ball tree data structures. These are relatively simple, efficient and perfectly suited for the kNN algorithm. Construction of these trees stems from computational geometry, not from machine learning, and does not concern us that much, so I’ll cover it in less detail, more on the conceptual level. For more details on that, see links at the end of the article.
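For reference, this is how those three options are selected in Scikit-learn (a minimal sketch; the default algorithm="auto" lets the library pick a structure based on the data):

from sklearn.neighbors import KNeighborsClassifier

# The three search strategies discussed in this article:
knn_brute = KNeighborsClassifier(n_neighbors=3, algorithm="brute")
knn_kdtree = KNeighborsClassifier(n_neighbors=3, algorithm="kd_tree")
knn_balltree = KNeighborsClassifier(n_neighbors=3, algorithm="ball_tree")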

In all complexities below, the time of calculating a single distance is omitted where it is in most cases negligible compared to the rest of the algorithm. Additionally, we denote:

  • n: number of points in the training dataset
  • d: data dimensionality
  • k: number of neighbors that we consider for voting

Brute force method

Training time complexity: O(1)

Training space complexity: O(1)

Prediction time complexity: O(k * n * d)

Prediction space complexity: O(1)

The training phase technically does not exist, since all computation is done during prediction, so we have O(1) for both time and space.

The prediction phase is, as the method’s name suggests, a simple exhaustive search, which in pseudocode is:

Loop through all points k times:
1. Compute the distance between the currently classified sample and the
training points, remember the index of the element with the
smallest distance (ignore previously selected points)
2. Add the class at the found index to the counter
Return the class with the most votes as the prediction

This is a nested loop structure, where the outer loop takes k steps and the inner loop takes n steps. Updating the counter is O(1) and the final majority vote is O(number of classes), so both are dominated by the distance search. Additionally, we have to take into consideration the number of dimensions d, since more dimensions mean longer vectors to compute distances over. Therefore, we have O(n * k * d) time complexity.

As for space complexity, we need a small vector to count the votes for each class. It’s almost always very small and fixed, so we can treat it as O(1) space complexity.
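To make this concrete, here is a minimal NumPy sketch of brute-force kNN prediction. Instead of the literal “loop k times” pseudocode, it computes all n distances once and selects the k smallest with a partial sort, which is the usual way to implement it; the function name and the choice of Euclidean distance are mine, not Scikit-learn’s.

import numpy as np

def knn_predict_brute(X_train, y_train, x_query, k=3):
    # O(n * d): Euclidean distances from the query point to every training point
    distances = np.linalg.norm(X_train - x_query, axis=1)
    # O(n): indices of the k smallest distances (unordered partial selection)
    nearest = np.argpartition(distances, k)[:k]
    # Majority vote among the labels of the k nearest points
    labels, counts = np.unique(y_train[nearest], return_counts=True)
    return labels[np.argmax(counts)]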

k-d tree method

Training time complexity: O(d * n * log(n))

Training space complexity: O(d * n)

Prediction time complexity: O(k * log(n))

Prediction space complexity: O(1)

During the training phase, we have to construct the k-d tree. This data structure splits the k-dimensional space (here k means the k dimensions of the space, don’t confuse this with k as the number of nearest neighbors!) and allows faster search for nearest points, since we “know where to look” in that space. You may think of it as a generalization of the BST to many dimensions. It “cuts” space with axis-aligned cuts, dividing the points into groups in the children nodes.

Constructing the k-d tree is not a machine learning task itself, since it stems from the computational geometry domain, so we won’t cover it in detail, only on a conceptual level. The time complexity is usually O(d * n * log(n)), because insertion is O(log(n)) (similar to a regular BST) and we have n points from the training dataset, each with d dimensions. I assume an efficient implementation of the data structure, i.e. one that finds the optimal split point (the median in the given dimension) in O(n), which is possible with the median of medians algorithm. Space complexity is O(d * n) — note that it depends on the dimensionality d, which makes sense, since more dimensions correspond to more space divisions and larger trees (in addition to larger time complexity for the same reason).
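To give an intuition for the construction, here is a toy recursive median-split builder. It is purely illustrative and not how Scikit-learn implements its KDTree; it also uses a full sort instead of a linear-time median selection, so it is slightly slower than the bound above.

import numpy as np

def build_kdtree(points, depth=0):
    # Recursively split the points at the median along a cycling axis
    if len(points) == 0:
        return None
    axis = depth % points.shape[1]              # cycle through the d dimensions
    points = points[points[:, axis].argsort()]  # full sort; median of medians would be O(n)
    median = len(points) // 2
    return {
        "point": points[median],
        "axis": axis,
        "left": build_kdtree(points[:median], depth + 1),
        "right": build_kdtree(points[median + 1:], depth + 1),
    }

tree = build_kdtree(np.random.rand(1000, 3))    # 1000 points in 3 dimensions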

As for the prediction phase, the k-d tree structure naturally supports the “k nearest neighbors query” operation, which is exactly what we need for kNN. The simple approach is to just query k times, removing the point found each time — since a single query takes O(log(n)), it is O(k * log(n)) in total. But since the k-d tree already cuts the space during construction, after a single query we approximately know where to look — we can just search the “surroundings” of that point. Therefore, practical implementations of the k-d tree support querying for all k neighbors at once, with complexity O(sqrt(n) + k), which is much better for the larger dimensionalities that are very common in machine learning.
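In Scikit-learn, this “all k neighbors at once” query looks like the following sketch (the sizes are arbitrary):

import numpy as np
from sklearn.neighbors import KDTree

X = np.random.rand(10_000, 3)        # training points
tree = KDTree(X)                      # construction, roughly O(d * n * log(n))
query = np.random.rand(1, 3)          # a single new point (a 2D array is expected)
dist, ind = tree.query(query, k=3)    # distances and indices of the 3 nearest points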

The above complexities are the average ones, assuming the balanced k-d tree. The O(log(n)) times assumed above may degrade up to O(n) for unbalanced trees, but if the median is used during the tree construction, we should always get a tree with approximately O(log(n)) insertion/deletion/search complexity.

Ball tree method

Training time complexity: O(d * n * log(n))

Training space complexity: O(d * n)

Prediction time complexity: O(k * log(n))

Prediction space complexity: O(1)

The ball tree algorithm takes another approach to dividing the space where the training points lie. In contrast to k-d trees, which divide space with median-value “cuts”, the ball tree groups points into “balls” organized into a tree structure. They go from the largest (the root, with all points) to the smallest (the leaves, with only a few or even one point). This allows fast nearest neighbor lookup because nearby points end up in the same, or at least nearby, “balls”.

During the training phase, we only need to construct the ball tree. There are a few algorithms for constructing the ball tree, but the one most similar to the k-d tree (called the “k-d construction algorithm” for that reason) is O(d * n * log(n)), the same as the k-d tree.

Because of this similarity in tree building, the complexities of the prediction phase are also the same as for the k-d tree.
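In Scikit-learn, the ball tree exposes the same query interface as the k-d tree, so switching between them is a one-line change (a sketch):

import numpy as np
from sklearn.neighbors import BallTree

X = np.random.rand(10_000, 3)
tree = BallTree(X)                    # construction, roughly O(d * n * log(n))
dist, ind = tree.query(np.random.rand(1, 3), k=3)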

Choosing the method in practice

To summarize the complexities: brute force is the slowest in big O notation, while the k-d tree and the ball tree both have the same, lower complexity. How do we know which one to use, then?

To get the answer, we have to look at both training and prediction times, and that’s why I have provided both. The brute force algorithm has only one complexity, for prediction: O(k * n * d). The other algorithms need to create the data structure first, so for training and prediction together they get O(d * n * log(n) + k * log(n)), not taking into account the space complexity, which may also be important. Therefore, when the trees have to be constructed frequently, the training phase may outweigh their advantage of faster nearest neighbor lookup.

Should we use a k-d tree or a ball tree? It depends on the structure of the data — relatively uniform or “well behaved” data will make better use of the k-d tree, since the cuts of space will work well (near points will end up close in the leaves after all the cuts). For more clustered data the “balls” of the ball tree will reflect the structure better and therefore allow for faster nearest neighbor search. Fortunately, Scikit-learn supports the “auto” option, which will automatically infer the best data structure from the data.
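A quick way to check this on your own data is to time fitting and prediction separately for each option. The following is only a rough sketch on synthetic data; the exact numbers will depend on your hardware, n, d and k.

import time
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(42)
X = rng.random((60_000, 10))                 # synthetic "training" set
y = rng.integers(0, 10, size=60_000)
X_test = rng.random((1_000, 10))

for algorithm in ("brute", "kd_tree", "ball_tree"):
    clf = KNeighborsClassifier(n_neighbors=3, algorithm=algorithm)

    start = time.perf_counter()
    clf.fit(X, y)                            # tree construction happens here (a no-op for brute force)
    fit_time = time.perf_counter() - start

    start = time.perf_counter()
    clf.predict(X_test)                      # nearest neighbor search happens here
    predict_time = time.perf_counter() - start

    print(f"{algorithm:>9}: fit {fit_time:.3f}s, predict {predict_time:.3f}s")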

Let’s see this in practice with two case studies, which I’ve encountered during my studies and work.

Case study 1: classification

The more “traditional” application of kNN is the classification of data. It often involves quite a lot of points, e.g. MNIST has 60k training images and 10k test images. Classification is done offline, which means we first do the training phase, then just use its results during prediction. Therefore, if we want to construct the data structure, we only need to do so once. For each of the 10k test images, let’s compare the brute force approach (which calculates all the distances every time) with the k-d tree, for 3 neighbors and n = 60,000 training points:

Brute force (O(k * n)): 3 * 60,000 = 180,000

k-d tree (O(k * log(n))): 3 * log(60,000) ~ 3 * 16 = 48

Comparison: 48 / 180,000 ~ 0.00027

As you can see, the performance gain is huge! The data structure method uses only a tiny fraction of the brute force time. For most datasets this method is a clear winner.

Case study 2: real-time smart monitoring

Machine Learning is commonly used for image recognition, often using neural networks. It’s very useful for real-time applications, where it’s often integrated with cameras, alarms etc. The problem with neural networks is that they often detect the same object two or more times — even the best architectures, like YOLO, have this problem. We can actually solve this with a nearest neighbor search, using a simple approach:

  1. Calculate the center of each bounding box (rectangle)
  2. For each rectangle, search for its nearest neighbor (1NN)
  3. If points are closer than the selected threshold, merge them (they detect the same object)

The crucial part is searching for the closest center of another bounding box (point 2). Which algorithm should be used here? Typically we have only a few moving objects on camera, maybe up to 30–40. For such a small number of points, the speedup from using data structures for faster lookup is negligible. Each frame is a separate image, so if we wanted to construct a k-d tree, for example, we would have to do so for every frame, which may mean 30 times per second — a huge cost overall. Therefore, in such situations the simple brute force method works fastest and also has the smallest space requirement (which, with heavy neural networks or embedded CPUs in cameras, may be important).
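Here is a rough sketch of this deduplication, assuming boxes come as (x1, y1, x2, y2) rows and that “merging” simply drops the duplicate box; the threshold value is hypothetical and application-specific.

import numpy as np

def deduplicate_boxes(boxes, threshold=20.0):
    boxes = np.asarray(boxes, dtype=float)
    # Step 1: centers of the bounding boxes
    centers = np.column_stack([(boxes[:, 0] + boxes[:, 2]) / 2,
                               (boxes[:, 1] + boxes[:, 3]) / 2])
    keep = np.ones(len(boxes), dtype=bool)
    for i in range(len(boxes)):
        if not keep[i]:
            continue
        # Step 2: brute-force 1NN search; with a few dozen boxes this is trivially fast
        dists = np.linalg.norm(centers - centers[i], axis=1)
        dists[i] = np.inf
        j = int(np.argmin(dists))
        # Step 3: nearest center closer than the threshold -> same object, drop one box
        if keep[j] and dists[j] < threshold:
            keep[j] = False
    return boxes[keep]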

Summary

The kNN algorithm is a popular, easy and useful technique in Machine Learning, and I hope that after reading this article you understand its complexity and the real-world scenarios where and how you can use this method.

References:
