
Memoization in Python

Introduction to Memoization


Memoization is a term introduced by Donald Michie in 1968, which comes from the Latin word memorandum (to be remembered). Memoization is a method used in computer science to speed up calculations by storing (remembering) the results of past calculations. If a function is called repeatedly with the same parameters, we can look up the stored result instead of repeating the calculation. In this post, we will use memoization to find terms in the Fibonacci sequence.

Let’s get started!

First, let’s define a recursive function that we can use to display the first n terms in the Fibonacci sequence. If you are unfamiliar with recursion, check out this article: Recursion in Python.

As a reminder, the Fibonacci sequence is defined such that each number is the sum of the two previous numbers. For example, the first 6 terms in the Fibonacci sequence are 1, 1, 2, 3, 5, 8. We can define the recursive function as follows:

def fibonacci(input_value):
    if input_value == 1:
        return 1
    elif input_value == 2:
        return 1
    elif input_value > 2:
        return fibonacci(input_value - 1) + fibonacci(input_value - 2)

Here we specify the base cases: if the input value is equal to 1 or 2, the function returns 1. If the input value is greater than 2, it returns the sum of two recursive calls, one for each of the previous two Fibonacci values. (Inputs below 1 are not handled, so the function simply returns None for them.)

Now, let’s print the first 10 terms:

for i in range(1, 11):
    print(f"fib({i}) = {fibonacci(i)}")

This seems to run fine. Now, let’s try displaying the first 200 terms:

for i in range(1, 201):
    print(f"fib({i}) = {fibonacci(i)}")

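What we’ll find is that after fib(20) or so, each new term takes noticeably longer than the one before it, and the loop gets nowhere near fib(200) in a reasonable amount of time. This is because each call recomputes the same smaller terms from scratch, so the amount of repeated work grows exponentially with the input value.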
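To put a number on the repeated work, we can count how many times the function is called. The sketch below is an instrumented copy of the naive function, assuming positive integer inputs; the names fibonacci_counted, call_count, and calls_needed are hypothetical and exist only for this illustration.

```python
# Instrumented copy of the naive recursive function (assumes input_value >= 1).
# The global counter is a hypothetical addition used only to measure the work done.
call_count = 0


def fibonacci_counted(input_value):
    global call_count
    call_count += 1  # count every call, including the recursive ones
    if input_value in (1, 2):
        return 1
    return fibonacci_counted(input_value - 1) + fibonacci_counted(input_value - 2)


def calls_needed(n):
    """Return how many calls it takes to compute fib(n) from scratch."""
    global call_count
    call_count = 0
    fibonacci_counted(n)
    return call_count


for n in (10, 20, 25):
    print(n, calls_needed(n))
# 10 109
# 20 13529
# 25 150049
```

The count works out to 2*fib(n) - 1 calls, so each additional term costs roughly 1.6 times as much as the previous one, and fib(200) alone would need more than 10^41 calls.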

Consider how the recursive function is calculating each term:

fib(1) = 1

fib(2) = 1

fib(3) = fib(1) + fib(2) = 2

fib(4) = fib(3) + fib(2) = 3

fib(5) = fib(4) + fib(3) = 5

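Notice that computing fib(5) recomputes fib(4) and fib(3) from scratch, even though the loop already computed both in earlier iterations. Worse, fib(4) itself calls fib(3) again, so fib(3) is evaluated twice within a single call to fib(5). If we had a way of storing these values once they are calculated, we could avoid the repeated work. This is the motivation for memoization.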
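A small tracer makes the repetition visible. This sketch (fibonacci_traced and evaluations are hypothetical names, and positive integer inputs are assumed) records every argument the naive recursion is called with while computing fib(5):

```python
from collections import Counter

# Hypothetical tracer: counts how often the naive recursion is called
# with each argument (assumes input_value >= 1).
evaluations = Counter()


def fibonacci_traced(input_value):
    evaluations[input_value] += 1
    if input_value in (1, 2):
        return 1
    return fibonacci_traced(input_value - 1) + fibonacci_traced(input_value - 2)


fibonacci_traced(5)
for argument in sorted(evaluations, reverse=True):
    print(f"fib({argument}) evaluated {evaluations[argument]} time(s)")
# fib(5) evaluated 1 time(s)
# fib(4) evaluated 1 time(s)
# fib(3) evaluated 2 time(s)
# fib(2) evaluated 3 time(s)
# fib(1) evaluated 2 time(s)
```

Even for this small input, fib(3) is computed twice and fib(2) three times, and the duplication compounds as the input grows.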

Let’s now walk through the steps of implementing memoization. First, let’s initialize a dictionary that will serve as our cache, mapping each input value to its Fibonacci number:

fibonacci_cache = {}

Next, we will define our memoization function. It begins by checking whether the input, which serves as the dictionary key, is already in the dictionary. If the key is present, we return the stored value instead of recomputing it:

def fibonacci_memo(input_value):
    if input_value in fibonacci_cache:
        return fibonacci_cache[input_value]

Next, we define the base cases, which correspond to the first two values. If the input value is 1 or 2, we set the value to 1:

def fibonacci_memo(input_value):
    ...
    if input_value == 1:
        value = 1
    elif input_value == 2:
        value = 1

Next, we consider the recursive cases. If the input is greater than 2, we set the value equal to the sum of the previous two terms:

def fibonacci_memo(input_value):
    ...
    elif input_value > 2:
        value = fibonacci_memo(input_value - 1) + fibonacci_memo(input_value - 2)

Finally, we store the value in our dictionary and return it:

def fibonacci_memo(input_value):
    ...
    fibonacci_cache[input_value] = value
    return value

The full function is:

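def fibonacci_memo(input_value):
    # Return the cached result if this term was already computed
    if input_value in fibonacci_cache:
        return fibonacci_cache[input_value]
    # Otherwise compute it: base cases first, then the recursive case
    if input_value == 1:
        value = 1
    elif input_value == 2:
        value = 1
    elif input_value > 2:
        value = fibonacci_memo(input_value - 1) + fibonacci_memo(input_value - 2)
    else:
        # Without this branch, inputs below 1 would leave value undefined
        raise ValueError("input_value must be a positive integer")
    # Store the value so later calls can look it up
    fibonacci_cache[input_value] = value
    return value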

Now, let’s try displaying the first 200 terms with our new function:

for i in range(1, 201):
    print(f"fib({i}) = {fibonacci_memo(i)}")

Running the script, we now reach the 200th term in the sequence almost instantly: each term is computed once, stored, and then looked up whenever it is needed again.
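To check that the speedup comes from doing less work, we can count calls again. This is a sketch of an instrumented copy of fibonacci_memo; memo_cache, memo_calls, and memo_calls_needed are hypothetical names, and the cache is cleared before each measurement so every input starts from scratch:

```python
# Instrumented copy of fibonacci_memo with its own cache and a
# hypothetical call counter.
memo_cache = {}
memo_calls = 0


def fibonacci_memo_counted(input_value):
    global memo_calls
    memo_calls += 1  # count every call, whether or not it hits the cache
    if input_value in memo_cache:
        return memo_cache[input_value]
    if input_value < 1:
        raise ValueError("input_value must be a positive integer")
    if input_value in (1, 2):
        value = 1
    else:
        value = fibonacci_memo_counted(input_value - 1) + fibonacci_memo_counted(input_value - 2)
    memo_cache[input_value] = value
    return value


def memo_calls_needed(n):
    """Return how many calls it takes to compute fib(n) from an empty cache."""
    global memo_calls
    memo_cache.clear()
    memo_calls = 0
    fibonacci_memo_counted(n)
    return memo_calls


for n in (10, 20, 25):
    print(n, memo_calls_needed(n))
# 10 17
# 20 37
# 25 47
```

Starting from an empty cache, computing fib(n) takes 2n - 3 calls for n >= 2: n calls that each compute a new term, plus n - 3 calls answered from the cache. The growth is linear rather than exponential. One caveat: on an empty cache the recursion still goes about n levels deep, so calling the function directly with an input in the thousands can hit Python's default recursion limit (typically 1000, see sys.getrecursionlimit()). The loop in this post avoids that by filling the cache in increasing order.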

There is a simpler way to implement memoization using less code. Let’s consider our original recursive function:

def fibonacci(input_value):
    if input_value == 1:
        return 1
    elif input_value == 2:
        return 1
    elif input_value > 2:
        return fibonacci(input_value - 1) + fibonacci(input_value - 2)

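We can import a decorator called ‘lru_cache’ from the ‘functools’ module that caches a function’s return values, keyed by its arguments. The name stands for "least recently used cache": once the cache holds maxsize entries, the entry that has gone unused the longest is discarded to make room for a new one. Using this decorator, we can achieve the same performance as our ‘fibonacci_memo’ function: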
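The "least recently used" part only matters once the cache is full. The sketch below uses a hypothetical square function with a deliberately tiny maxsize=2 to make eviction visible through cache_info(), a method that lru_cache attaches to the decorated function:

```python
from functools import lru_cache


@lru_cache(maxsize=2)  # deliberately tiny so eviction is easy to observe
def square(x):
    return x * x


square(1)  # miss: cache holds 1
square(2)  # miss: cache holds 1, 2
square(3)  # miss: cache is full, so 1 (least recently used) is evicted
square(1)  # miss again: 1 was evicted; now 2 is the one evicted
square(3)  # hit: 3 is still cached
print(square.cache_info())
# CacheInfo(hits=1, misses=4, maxsize=2, currsize=2)
```

For the Fibonacci example, maxsize=1000 is larger than the 200 terms we compute, so nothing is ever evicted there.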

from functools import lru_cache

@lru_cache(maxsize=1000)
def fibonacci(input_value):
    if input_value == 1:
        return 1
    elif input_value == 2:
        return 1
    elif input_value > 2:
        return fibonacci(input_value - 1) + fibonacci(input_value - 2)

for i in range(1, 201):
    print(f"fib({i}) = {fibonacci(i)}")

We see that the decorated function computes the first 200 terms just as quickly, with only two extra lines of code. I’ll stop here, but I encourage you to play around with the code yourself.
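If you want to confirm that the cache is doing the work, cache_info() reports hits and misses, and cache_clear() empties the cache and resets those statistics. This sketch reruns the 200-term loop without printing and, as an assumption that differs from the code above, uses maxsize=None for an unbounded cache (equivalent to functools.cache in Python 3.9+):

```python
from functools import lru_cache


# maxsize=None: unbounded cache (functools.cache does the same in Python 3.9+)
@lru_cache(maxsize=None)
def fibonacci(input_value):
    if input_value == 1:
        return 1
    elif input_value == 2:
        return 1
    elif input_value > 2:
        return fibonacci(input_value - 1) + fibonacci(input_value - 2)


for i in range(1, 201):
    fibonacci(i)

print(fibonacci.cache_info())
# CacheInfo(hits=396, misses=200, maxsize=None, currsize=200)

fibonacci.cache_clear()  # empty the cache, e.g. before timing again
print(fibonacci.cache_info())
# CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
```

Each of the 200 terms is computed exactly once (200 misses), and every term from fib(3) on finds its two predecessors already cached (2 × 198 = 396 hits).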

CONCLUSIONS

To summarize, in this post we discussed the memoization method in Python. First, we showed how the naive implementation of a recursive function becomes very slow as we calculate more terms of the Fibonacci sequence, because it recomputes the same values over and over. We then defined a new method that stores the values we have already calculated in a dictionary, which leads to a significant speedup. Finally, we discussed the ‘lru_cache’ decorator, which gave us similar performance to our ‘fibonacci_memo’ method with less code. If you’re interested in learning more about memoization, I encourage you to check out Socratica’s YouTube tutorials. I hope you found this post useful and interesting. The code in this post is available on GitHub. Thank you for reading!

