
Understanding DeepMind and Strassen algorithms

An introduction to the matrix multiplication problem, with applications in Python and JAX

DeepMind recently published an interesting paper that employed deep reinforcement learning to find new matrix multiplication algorithms [1]. One of the paper's objectives is to lower the computational complexity of matrix multiplication. The article has raised a lot of comments and questions about matrix multiplication – as you can see from Demis Hassabis' tweet.

Matrix multiplication is an intense area of research in mathematics [2–10]. Although matrix multiplication is a simple problem, its computational implementation has some hurdles to overcome. If we consider only square matrices, the first idea is to compute the product as a triple for-loop:
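A minimal sketch of that triple loop (the function name naive_matmul is mine, not from the repo) could look like this:

```python
import numpy as np

def naive_matmul(A, B):
    """Naive O(n³) matrix multiplication via a triple for-loop."""
    n = A.shape[0]
    C = np.zeros((n, n))
    for i in range(n):              # rows of A
        for j in range(n):          # columns of B
            for k in range(n):      # accumulate the inner product
                C[i, j] += A[i, k] * B[k, j]
    return C
```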

Such a simple calculation has a computational complexity of O(n³). This means that the time needed to run the computation grows as the third power of the matrix size. This is a hurdle to tackle, as in AI and ML we deal with huge matrices at every step of a model – neural networks are tons of matrix multiplications! Thus, given constant computational power, we need more and more time to run all our AI calculations!

DeepMind has taken a concrete step forward on the matrix multiplication problem. However, before digging into this paper, let's have a look at the matrix multiplication problem and at the algorithms that can help us lower its computational cost. In particular, we'll take a look at Strassen's algorithm and we'll then implement it in Python and JAX.

Strassen, Coppersmith & Winograd, Williams & Alman

_Remember that for the rest of this article the matrix size will be N >> 1000. All the algorithms are applied to block matrices._

Each element of the product matrix C is obtained by summing the products of the elements of a row of A with those of the corresponding column of B – fig.2.
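In formulas, each element of C is simply

C_{ij} = \sum_{k=1}^{n} A_{ik}\, B_{kj}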

As we saw in the introduction, the computational complexity of the standard matrix multiplication is O(n³). In 1969 Volker Strassen, a German mathematician, smashed the O(n³) barrier, reducing the multiplication of 2×2 block matrices to 7 multiplications and 18 additions and reaching a complexity of O(n²·⁸⁰⁸) [8]. If we consider a set of matrices A, B and C, as in fig.3, Strassen derived the following algorithm:
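For reference, the classical formulation on 2×2 block matrices reads (using the conventional M naming for the seven products):

M_1 = (A_{11} + A_{22})(B_{11} + B_{22})
M_2 = (A_{21} + A_{22})\,B_{11}
M_3 = A_{11}\,(B_{12} - B_{22})
M_4 = A_{22}\,(B_{21} - B_{11})
M_5 = (A_{11} + A_{12})\,B_{22}
M_6 = (A_{21} - A_{11})(B_{11} + B_{12})
M_7 = (A_{12} - A_{22})(B_{21} + B_{22})

C_{11} = M_1 + M_4 - M_5 + M_7
C_{12} = M_3 + M_5
C_{21} = M_2 + M_4
C_{22} = M_1 - M_2 + M_3 + M_6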

It’s worth noticing a few things about this algorithm:

  • the algorithm works recursively on block matrices
  • it’s easy to prove the complexity O(n²·⁸⁰⁸). Given the matrix size n and the 7 multiplications, it follows that:
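In short: each recursion level replaces one multiplication of blocks of size n with 7 multiplications of blocks of size n/2, while the additions cost only O(n²) per level and do not change the exponent. The number of multiplications therefore satisfies

M(n) = 7\,M(n/2) \;\Rightarrow\; M(n) = 7^{\log_2 n} = n^{\log_2 7} \approx n^{2.808}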

All the steps in fig.4 are polynomials, therefore the matrix multiplication can be treated as a polynomial problem.

  • Computationally, Strassen’s algorithm is unstable with floating-point numbers [14]. The numerical instability is caused by rounding the sub-matrix results; as the calculation progresses, the error accumulates into a severe loss of accuracy.

From these points we can translate matrix multiplication into a polynomial problem. Each operation in fig.4 can be written as a linear combination; for example, step I is:

Here α and β are the linear combinations of the elements of matrices A and B, respectively, while H denotes a one-hot-encoded matrix for the addition/subtraction operations. It’s then possible to define the elements of the product matrix C as a linear combination and write:
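Written out explicitly (using u, v and w for the combination coefficients, as DeepMind does – they play the same role as α, β and H here), the intermediate products and the elements of C read

m_r = \Big(\sum_i u_{ri}\, a_i\Big)\Big(\sum_j v_{rj}\, b_j\Big), \qquad c_k = \sum_{r=1}^{R} w_{kr}\, m_r

where a, b and c are the flattened entries of A, B and C, and R = 7 for Strassen.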

As you can see, the entire algorithm has been reduced to a linear combination. In particular, the left-hand side of the equation in fig.7 can be denoted by the matrix sizes m, n and p – which means a multiplication between an m×n matrix and an n×p matrix:

For Strassen, <m, n, p> is <2, 2, 2>. Fig.8 describes the matrix multiplication as a linear combination, or a tensor – that’s why the Strassen algorithm is sometimes called "tensoring". The a, b and c elements in fig.8 form a triad. Following DeepMind’s paper convention, the triad can be expressed as:
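In other words, the matrix multiplication tensor is decomposed into a sum of R rank-one terms (triads), each built from three vectors a, b and c:

\mathcal{T}_{m,n,p} = \sum_{r=1}^{R} a^{(r)} \otimes b^{(r)} \otimes c^{(r)}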

This triad establishes the objective of finding the best algorithm for minimizing the computational complexity of matrix multiplication. Indeed, the minimal number of triads defines the minimal number of operations needed to compute the product matrix. This minimal number is the tensor’s rank R(t). Working on the tensor rank will help us find new and more efficient matrix multiplication algorithms – and that’s exactly what the DeepMind team has done.

Starting from Strassen’s work, between 1969 and today there has been a continual stream of new algorithms attacking the matrix multiplication complexity problem (tab.1).

How did Pan, Bini, Schönhage and all the other researchers get to these brilliant results? One way to solve a computer science problem is to start from the definition of an algebraic problem P. For matrix multiplication, for example, the algebraic problem P can be: "find a mathematical model to evaluate a set of polynomials". From here, scientists start to reduce the problem and "convert" it into a matrix multiplication problem – here is a good explanation of this approach. In a nutshell, scientists were able to prove theorems, as well as steps, that decompose the polynomial evaluation into a matrix multiplication algorithm. Eventually, they all obtained theoretical algorithms that can be more powerful than Strassen’s (tab.1).

However, these theoretical algorithms can’t be coded up unless some heavy and strong mathematical assumptions and restrictions are imposed, which can affect the algorithm’s efficiency.

Let’s now see how powerful Strassen’s algorithm is and how we can implement it in Python and JAX.

A 0.20 improvement in the exponent: don’t joke! Does Strassen’s algorithm effectively improve matrix multiplication?

Here is the repo with all the code that follows. I ran these tests on a 2019 MacBook Pro, 2.6 GHz 6-core Intel Core i7, 16 GB 2667 MHz DDR4 memory.

Python: numpy.matmul vs Strassen

In the main code we can follow these steps:

  • we are going to create 2 square matrices A and B, initialised with random integers
  • we are going to test the algorithms for different matrices’ sizes: 128, 256, 512, 768, 1024, 1280, 2048
  • For each size we run numpy.matmul and Strassen’s algorithm three times. At each run we record the running time in a list. From this list we extract the average time and the standard deviation to compare both methods (fig.10) – a minimal sketch of this driver follows below
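Here is such a sketch (the helper name benchmark is mine; the repo's script differs in details):

```python
import time
import numpy as np

def benchmark(fn, A, B, n_runs=3):
    """Time fn(A, B) over n_runs runs; return mean and standard deviation."""
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn(A, B)
        timings.append(time.perf_counter() - start)
    return np.mean(timings), np.std(timings)

for size in (128, 256, 512, 768, 1024, 1280, 2048):
    A = np.random.randint(0, 10, (size, size))
    B = np.random.randint(0, 10, (size, size))
    # swap np.matmul for the recursive strassen function (sketched below)
    # to compare the two methods
    mean, std = benchmark(np.matmul, A, B)
    print(f"numpy.matmul {size}x{size}: {mean:.3f} +/- {std:.3f} s")
```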

The core part of the script is the recursive strassen function (a minimal sketch of it follows after this list):

  • Firstly, we check the input matrix dimension. If the dimension is below a given threshold (or not divisible by 2), we compute the remaining product with standard numpy, as this won’t influence the final computational cost
  • For each input matrix the top-left, top-right, bottom-left and bottom-right sub-matrices are extracted. In my code I propose a naive and simple solution, so everybody can understand what’s going on. To further test and understand the block-matrix creation, try to manually compute the indices for a small matrix (e.g. 12×12).
  • In the final step, the product matrix is reconstructed from all the computed sub-elements (C11, C12, C21 and C22 in fig.4)
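A minimal sketch of such a recursive function, assuming square inputs (the threshold value is arbitrary and the repo's version differs in details):

```python
import numpy as np

def strassen(A, B, threshold=128):
    """Recursive Strassen multiplication for square matrices."""
    n = A.shape[0]
    # fall back to numpy below the threshold or for odd sizes:
    # these small products don't affect the overall cost
    if n <= threshold or n % 2 != 0:
        return np.matmul(A, B)
    h = n // 2
    # extract the four blocks of each input matrix
    A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
    # the seven recursive block products
    M1 = strassen(A11 + A22, B11 + B22, threshold)
    M2 = strassen(A21 + A22, B11, threshold)
    M3 = strassen(A11, B12 - B22, threshold)
    M4 = strassen(A22, B21 - B11, threshold)
    M5 = strassen(A11 + A12, B22, threshold)
    M6 = strassen(A21 - A11, B11 + B12, threshold)
    M7 = strassen(A12 - A22, B21 + B22, threshold)
    # reassemble the product matrix from the four output blocks
    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6
    return np.block([[C11, C12], [C21, C22]])
```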

Fig.10 compares the standard numpy.matmul and the Strassen algorithm. As you can see, for matrix sizes below 2000 Strassen is outperformed by the standard matrix multiplication. The real improvement shows up on bigger matrices. Strassen completes the multiplication of a 2048×2048 matrix in 8.16 +/- 1.56 s, while the standard method requires 63.89 +/- 2.48 s. Doubling the matrix size to 4096 rows and columns, Strassen runs in 31.57 +/- 1.01 s, while the standard matrix multiplication takes 454.37 +/- 6.27 s.

JAX implementation: on DeepMind’s giant shoulders

According to the equation in fig.9, we can further decompose the Strassen algorithm into a tensor form. The tensors u, v and w can then be applied to the matrices’ blocks to obtain the final product matrix. C.H. Huang, J. R. Johnson, and R. W. Johnson published a short paper showing how to derive the tensor version of Strassen [18], followed by a later formulation [19] in which Strassen’s u, v and w tensors are written out explicitly. For the detailed calculations you can check [18], while fig.12 reports the tensor values.

This is a good starting point for working with JAX and comparing Strassen with the standard jax.numpy.matmul. For the JAX script I have closely followed DeepMind’s implementation.

The script deals with 4×4 block matrices. The core function, f, runs the Strassen method. In this case, all the A and B block matrices are multiplied by the u and v tensors. The result is then multiplied by the tensor w, obtaining the final product (fig.13) – a minimal sketch of this step is given below. Given JAX’s powerful performance, the algorithm was tested on the following matrix dimensions: 8192, 10240, 12288, 14336, 16384, 18432, 20480
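Here is a minimal 2×2-block sketch of the idea (my own encoding of u, v and w from the classical Strassen formulas; DeepMind's script works with 4×4 blocks and its own tensor ordering):

```python
import jax.numpy as jnp
from jax import jit

# Strassen factors for <2,2,2>: rows of u and v give the 7 linear combinations
# of the A and B blocks, rows of w recombine the 7 products into the 4 blocks
# of C (blocks ordered [11, 12, 21, 22]).
u = jnp.array([[1, 0, 0, 1], [0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 0, 1],
               [1, 1, 0, 0], [-1, 0, 1, 0], [0, 1, 0, -1]], dtype=jnp.float32)
v = jnp.array([[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, -1], [-1, 0, 1, 0],
               [0, 0, 0, 1], [1, 1, 0, 0], [0, 0, 1, 1]], dtype=jnp.float32)
w = jnp.array([[1, 0, 0, 1, -1, 0, 1],
               [0, 0, 1, 0, 1, 0, 0],
               [0, 1, 0, 1, 0, 0, 0],
               [1, -1, 1, 0, 0, 1, 0]], dtype=jnp.float32)

@jit
def f(a_blocks, b_blocks):
    """a_blocks, b_blocks: (4, n, n) arrays holding the blocks of A and B."""
    alpha = jnp.einsum('ri,ijk->rjk', u, a_blocks)   # 7 combinations of A blocks
    beta = jnp.einsum('ri,ijk->rjk', v, b_blocks)    # 7 combinations of B blocks
    m = jnp.einsum('rij,rjk->rik', alpha, beta)      # the 7 block products
    return jnp.einsum('cr,rij->cij', w, m)           # the 4 blocks of C
```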

Finally, in the very last step, the product matrix is reconstructed by concatenating and reshaping the block outputs of the f function (fig.14), for example as in the snippet below.
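Using the sketch above, splitting the inputs and reassembling the result could look like this (again a sketch, assuming the hypothetical f function defined earlier):

```python
import numpy as np
import jax.numpy as jnp

n = 4096  # keep it modest on a laptop; the article's tests go up to 20480
A = np.random.rand(n, n).astype(np.float32)
B = np.random.rand(n, n).astype(np.float32)
h = n // 2

# split A and B into (4, n/2, n/2) stacks of blocks ...
a_blocks = jnp.stack([A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]])
b_blocks = jnp.stack([B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]])

# ... apply the tensor-form Strassen and glue the blocks back together
c = f(a_blocks, b_blocks)
C = jnp.block([[c[0], c[1]], [c[2], c[3]]])
```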

Fig.15 compares JAX’s numpy matrix multiplication with the Strassen implementation. As you can see, JAX is very powerful, as an 8192×8192 matrix multiplication runs in 12 s on average. For dimensions under 12000×12000 there is no real improvement, and JAX’s standard method takes an average computational time of 60 s on my laptop – while other things were running. Above that size we can see an impressive 20% improvement. For example, for 18432×18432 and 20480×20480 matrices, the Strassen algorithm runs in 142.31 +/- 14.41 s and 186.07 +/- 12.80 s, respectively – and this was done on a CPU. A good exercise would be to try this code with the device_put option and run it on a Colab GPU. I am sure you’ll be flabbergasted!

Conclusions

Today we took a small step toward a complete understanding of DeepMind’s publication "Discovering faster matrix multiplication algorithms with Reinforcement Learning" [1]. This paper proposes new ways to tackle the matrix multiplication problem using deep reinforcement learning.

In this first article, we started to scratch the surface of matrix multiplication. We learned what the computational cost of this operation is and we looked at the Strassen algorithm.

From there we described how the Strassen algorithm is constructed and what its mathematical implications are. Since its publication, researchers have found better and better solutions to the matrix multiplication problem. However, not all of these methods can be implemented in code.

Finally, we played a bit with Python and JAX to find out how powerful the algorithm is. We learned that Strassen is a great tool when we have to deal with very big matrices. We also saw the power of JAX in handling big matrix multiplications and how easy it is to implement such a solution, without using GPUs or further memory options.

In the next article, we’ll dive into more details of DeepMind’s paper. In particular, we’ll tackle the deep reinforcement learning algorithm, as well as the paper’s findings. Then, we’ll implement the new DeepMind algorithms and run them in JAX on a GPU instance.

I hope you enjoyed this article 🙂 and thanks for reading it.

Support my writing:

Join Medium with my referral link – Stefano Bosisio

Please, feel free to send me an email for questions or comments at: [email protected] or directly here in Medium.

Bibliography

  1. Fawzi, Alhussein, et al. "Discovering faster matrix multiplication algorithms with reinforcement learning." _Nature_ 610.7930 (2022): 47–53.
  2. Bläser, Markus. "Fast matrix multiplication." Theory of Computing (2013): 1–60.
  3. Bini, Dario. "O(n²·⁷⁷⁹⁹) complexity for n×n approximate matrix multiplication." (1979).
  4. Coppersmith, Don, and Shmuel Winograd. "On the asymptotic complexity of matrix multiplication." SIAM Journal on Computing 11.3 (1982): 472–492.
  5. Coppersmith, Don, and Shmuel Winograd. "Matrix multiplication via arithmetic progressions." Proceedings of the nineteenth annual ACM symposium on Theory of computing. 1987.
  6. de Groote, Hans F. "On varieties of optimal algorithms for the computation of bilinear mappings II. Optimal algorithms for 2×2-matrix multiplication." Theoretical Computer Science 7.2 (1978): 127–148.
  7. Schönhage, Arnold. "A lower bound for the length of addition chains." Theoretical Computer Science 1.1 (1975): 1–12.
  8. Strassen, Volker. "Gaussian elimination is not optimal." Numerische mathematik 13.4 (1969): 354–356.
  9. Winograd, Shmuel. "On multiplication of 2×2 matrices." Linear algebra and its applications 4.4 (1971): 381–388.
  10. Gentleman, W. Morven. "Matrix multiplication and fast Fourier transforms." The Bell System Technical Journal 47.6 (1968): 1099–1103.
  11. Alman, Josh, and Virginia Vassilevska Williams. "A refined laser method and faster matrix multiplication." Proceedings of the 2021 ACM-SIAM Symposium on Discrete Algorithms (SODA). Society for Industrial and Applied Mathematics, 2021.
  12. Le Gall, François. "Powers of tensors and fast matrix multiplication." Proceedings of the 39th international symposium on symbolic and algebraic computation. 2014.
  13. Williams, Virginia Vassilevska. "Multiplying matrices faster than Coppersmith-Winograd." Proceedings of the forty-fourth annual ACM symposium on Theory of computing. 2012.
  14. Bailey, David H., King Lee, and Horst D. Simon. "Using Strassen’s algorithm to accelerate the solution of linear systems." The Journal of Supercomputing 4.4 (1991): 357–371.
  15. Pan, V. Ya. "Strassen’s algorithm is not optimal trilinear technique of aggregating, uniting and canceling for constructing fast algorithms for matrix operations." 19th Annual Symposium on Foundations of Computer Science (sfcs 1978). IEEE, 1978.
  16. Schönhage, Arnold. "Partial and total matrix multiplication." SIAM Journal on Computing 10.3 (1981): 434–455.
  17. Davie, Alexander Munro, and Andrew James Stothers. "Improved bound for complexity of matrix multiplication." Proceedings of the Royal Society of Edinburgh Section A: Mathematics 143.2 (2013): 351–369.
  18. Huang, C-H., Jeremy R. Johnson, and Rodney W. Johnson. "A tensor product formulation of Strassen’s matrix multiplication algorithm." Applied Mathematics Letters 3.3 (1990): 67–71.
  19. Kumar, Bharat, et al. "A tensor product formulation of Strassen’s matrix multiplication algorithm with memory reduction." _Scientific Programming_ 4.4 (1995): 275–289.
