What if I told you database indexes could be learned?

Cody Marie Wild
Towards Data Science
5 min read · Dec 23, 2017


This paper is one that I unfortunately missed seeing presented at NIPS, but which has been getting quite a lot of attention in ML circles in the last few days. The authors, who count among their number Jeff Dean, a highly respected early Google employee, have one core point, which they reiterate throughout the paper: at their heart, database indexes are models. They may not (typically) be statistically learned, but they are structures that provide a (hopefully quite fast) mapping between an input (the key upon which the index is built) and an output (a position in memory). A B-Tree, which is a typical such structure for ordered data, even takes the form of, well, a tree, which is a core tool in the machine learning toolbox.

Building on this key intuition, the paper then asks: well, if these structures are just models, could statistical models that learn, and then leverage, the distribution of the data being indexed be better — smaller, more performant — than the indexes we currently use? Spoiler: the answer (at least for numeric data types) is yes.

For example, it could be the case that all records are the same length, and the position of each numeric key advances by 5 from the previous one: in that case, you could easily learn a linear regression mapping between key and position that would be much faster than a B-Tree, which has to work methodically, split by split. A B-Tree is efficient in the ultimate worst-case scenario, where the distribution of the keys' positions (the CDF you'd get if you lined them up in order) is truly arbitrary and assumed unknown.
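Here is a minimal sketch of that toy case in Python; the concrete keys, the +5 stride, and the helper name are illustrative inventions of mine, not taken from the paper:

```python
import numpy as np

# Hypothetical toy case: fixed-length records, so the k-th smallest key
# sits at offset 5 * k. Consecutive integer keys, purely for illustration.
keys = np.arange(1000, 2000)                 # sorted numeric keys
positions = 5 * np.arange(len(keys))         # each record occupies 5 slots

# Fit position ≈ a * key + b; for data this regular, the fit is exact.
a, b = np.polyfit(keys, positions, deg=1)

def predict_position(key):
    # O(1): two arithmetic ops instead of O(log n) B-Tree node traversals.
    return int(round(a * key + b))

assert predict_position(1500) == 2500        # key 1500 -> offset 5 * 500
```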

One intriguing fact about B-Trees is that they need to be rebalanced whenever new data is added, a process which the authors argue is analogous to retraining a model. So, for the purposes of comparison, they simply compare performance on the *training* set, since both a B-Tree and a candidate statistical model are only expected to work until new data arrives and they have to be retrained.

Architecture & Results

The authors started out by training a baseline model: a densely connected network with two hidden layers of 32 units each (sketched in code after the list below). This model had two major problems.

  1. Firstly, it was quite slow to generate its predictions for min/max search positions as a function of the key: it was initially trained in Tensorflow, which has high per-invocation overhead that isn't worth it for a model this small.
  2. Secondly, it just wasn't that accurate at the level of individual keys. While it did a good job of learning the overall shape of the cumulative key distribution, it avoided overfitting to small local variations in that CDF, and so became less accurate when you “zoomed in” on small regions of key space. Because of that, it didn't significantly speed up the lookup process over the baseline of just doing a full key scan. In the way they framed the problem, a simple model could quite easily reduce the expected search error from 100M down to around 10,000, but it would be difficult to reduce it all the way to the order of hundreds, due to the smoothness assumptions inherent in the model.
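For concreteness, here is what such a baseline might look like in TensorFlow. The layer sizes follow the paper's description, but the synthetic keys, the training setup, and all variable names are my own illustrative assumptions:

```python
import numpy as np
import tensorflow as tf

# Baseline from the paper's description: two hidden layers of 32 units.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1),               # predicted position of the key
])
model.compile(optimizer="adam", loss="mse")

# Train on (key, position) pairs from the sorted data itself -- in effect,
# the model learns the empirical CDF of the key distribution.
keys = np.sort(np.random.lognormal(size=100_000)).astype(np.float32).reshape(-1, 1)
positions = np.arange(len(keys), dtype=np.float32)
model.fit(keys, positions, epochs=3, batch_size=1024, verbose=0)

# The prediction is only a starting point: the actual record is then found
# by searching within [pred - max_err, pred + max_err] around it.
pred = int(model.predict(keys[42:43], verbose=0)[0, 0])
```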

Proposed Solutions

These problems were addressed with two key solutions: one an implementation detail, and one a theoretical innovation. The first is relatively straightforward (at least conceptually): build a framework where you can train models in Tensorflow, but evaluate them in C++ at inference time. This leads to dramatically faster performance on the previously-tested baseline model: down from 80,000ns to 30ns.
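The conceptual trick is that, once trained, a network this small is just a handful of matrix multiplies. As a Python stand-in for the paper's C++ evaluation path (the weights below are random placeholders; a real system would export them from the trained model):

```python
import numpy as np

# Placeholder weights standing in for an exported two-layer, 32-unit net.
rng = np.random.default_rng(0)
w1, b1 = rng.normal(size=(1, 32)), np.zeros(32)
w2, b2 = rng.normal(size=(32, 32)), np.zeros(32)
w3, b3 = rng.normal(size=(32, 1)), np.zeros(1)

def predict_position(key: float) -> int:
    # Inference as bare arithmetic: no graph, no session, no framework.
    h = np.maximum(np.array([[key]]) @ w1 + b1, 0.0)   # dense + ReLU
    h = np.maximum(h @ w2 + b2, 0.0)                   # dense + ReLU
    return int((h @ w3 + b3)[0, 0])                    # linear output
```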

The second, framed as a solution to the problem of “last mile” accuracy, is that of Recursive Models. In this framework, we start by training a “top level” model that outputs a prediction for the key's location. Then we divide the position space into, say, three parts, and learn a separate model for each subregion. So, for example, the top model might predict that the key 4560 maps to location 2000 in a 10,000-slot memory region. Observations would then be grouped together based on their predictions from the top model, and a new model trained specifically on, say, keys with predicted locations between 0 and 3,500.
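A minimal sketch of this two-stage idea, using plain linear models at both stages (the paper's top stage is a neural network; the class name, expert count, and routing details here are my own simplifications):

```python
import numpy as np

class TwoStageRMI:
    """Toy recursive model: a coarse top model routes keys to sub-models."""

    def __init__(self, keys, n_experts=3):
        keys = np.asarray(keys, dtype=np.float64)
        self.n = len(keys)
        self.n_experts = n_experts
        positions = np.arange(self.n, dtype=np.float64)
        # Stage 1: one coarse model over the whole key space.
        self.top = np.polyfit(keys, positions, deg=1)
        # Stage 2: group keys by the *top model's* prediction, then fit
        # one small model per predicted subregion.
        routed = self._route(keys)
        self.experts = []
        for i in range(n_experts):
            mask = routed == i
            self.experts.append(
                np.polyfit(keys[mask], positions[mask], deg=1)
                if mask.sum() >= 2 else self.top)

    def _route(self, keys):
        pred = np.polyval(self.top, keys)
        idx = np.floor(pred * self.n_experts / self.n).astype(int)
        return np.clip(idx, 0, self.n_experts - 1)

    def lookup(self, key):
        expert = self.experts[self._route(np.array([key]))[0]]
        return int(np.clip(np.polyval(expert, key), 0, self.n - 1))

keys = np.sort(np.random.lognormal(size=100_000))
rmi = TwoStageRMI(keys)
slot = rmi.lookup(keys[1234])   # starting point; search nearby for the exact record
```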

When this approach (hierarchical models, running on meaningfully more optimized code) was tried on data with numeric key values, the results were really impressive. Compared to B-Trees, the learned indexes achieved meaningful speedups, upwards of 60%, with much, much smaller models. Notably, the results below were obtained without a GPU, which suggests that if and when GPUs become standard in database hardware, this advantage could grow even larger.

The paper only shows results for numeric keys so far, but suggests that more complicated methods currently used for text (RNNs, character-level CNNs) could be slotted into the same general framework.

[Figure: results for latitude/longitude keys from the Maps dataset]

So, why is all of this interesting in the first place, beyond the pragmatic fact that it might lead to a new generation of Database Index design?

First off, I'll just admit that I have a special place in my heart for papers that reframe one intellectual area in the context of another. In addition to introducing some compelling ML ideas, this paper made me think more deeply and clearly about the mechanics of how indexes work, which before had always been something I generally understood but had never delved into deeply.

Secondly, there’s just the practical fact of it seeming meaningful as an example of machine learning models being used to optimize the kinds of lower-level systems on which they run. This is one of the first papers I recall seeing that uses machine learning to optimize the process of computing itself, but it seems quite unlikely that it will be the last.

