Notes on Matrix Calculus for Deep Learning

Based on this paper by Parr and Howard.

Deep learning is an exciting field that is having a great real-world impact. This article is a collection of notes based on ‘The Matrix Calculus You Need For Deep Learning’ by Terence Parr and Jeremy Howard.

Deep learning is all about linear algebra. It is the use of neural networks with many layers to solve complex problems. The model inputs, the neuron weights in multiple layers, the activation functions, etc. can all be defined as vectors. The operations/transforms needed to train or use a neural network are highly parallel in nature, applied to all the inputs simultaneously. Vector/matrix representations, and the linear algebra operations that can be used on them, lend themselves very well to the pipelined data-flow model of a neural network. The math becomes greatly simplified when the inputs, weights, and functions are treated as vectors and the flow of values can be treated as operations on matrices.

Deep learning is also all about differentiation! Calculating derivatives, i.e. having some way to measure the rate of change, is critical in the training phase to minimize the loss function. Starting with an arbitrary set of network model weights w, the goal is to arrive at an ‘optimal’ set of weights that reduces a given loss function. Nearly all neural networks use the backpropagation method to find such a set of weights. This process involves determining how a change in a weight value affects the output. Based on this, we may decide to increase or decrease the value of the weight by a proportional amount. Measuring how the output changes with respect to a change in weight is the same as calculating the (partial) derivative of the output w.r.t. the weight w. This process is repeated many times, for all the weights in all the layers, for all the training examples.
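To make the update step concrete, here is a minimal sketch for a single weight. The model (y_hat = w * x), the squared-error loss, the learning rate and the training values are all illustrative assumptions of mine, not something taken from the paper.

```python
# Gradient-descent sketch for one weight (all values are illustrative).
# Assumed model: y_hat = w * x, assumed loss: (y_hat - y)^2,
# so d(loss)/dw = 2 * (w * x - y) * x.

def loss(w, x, y):
    return (w * x - y) ** 2

def dloss_dw(w, x, y):
    # Partial derivative of the loss with respect to the weight w.
    return 2.0 * (w * x - y) * x

def train(w, x, y, learning_rate=0.1, steps=50):
    for _ in range(steps):
        # Move w against the slope, by an amount proportional to it.
        w -= learning_rate * dloss_dw(w, x, y)
    return w

w_final = train(w=0.0, x=2.0, y=6.0)
print(w_final, loss(w_final, 2.0, 6.0))
```

With these numbers the weight settles near 3, the value for which w * x matches the target. A real network repeats the same idea for every weight, with the derivatives supplied by backpropagation.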

Matrix calculus marries two fundamental branches of mathematics - linear algebra and calculus. A large majority of people have been introduced to linear algebra and calculus in isolation. These two topics are heavyweights in their own right. Not many undergraduate courses focus on matrix calculus. People usually rely on intuition to bridge the gaps in understanding while looking into concepts like backpropagation. The backpropagation step in most machine learning algorithms is all about calculating derivatives and updating values in vectors and matrices. Most machine learning frameworks do the heavy lifting themselves, and we never end up seeing the actual derivatives being calculated. However, it is always good to understand how this works internally, and it is essential if you are planning to be a serious practitioner or to develop an ML library from scratch.

While the paper is geared to DL practitioners and coders, it is mathematical in nature. It is really important to pay attention to notation to cement your understanding. I found it essential to pay particular attention to things like the shape of a vector (wide or tall), whether a variable is a scalar or a vector, and the dimensions of a matrix. Vectors are represented by bold letters. The untrained eye would probably not notice the difference between a bold f and an italic f, but it makes a great difference when trying to understand the equations. The same goes for the shape and orientation of vectors. I went down multiple rabbit holes trying to understand something, only to learn that my initial reading of the terms was incorrect.

One good thing is the manner in which functions (and the ways to calculate their derivatives) are introduced, from the simple to the more complex. First, we start with functions of a single parameter, represented by f(x). The function and the parameter x are scalars (represented in italics), and we can use the traditional derivative rules to find the derivative of f(x). Second, the kind of functions we’ll see often have many variables associated with them, of the form f(x,y,z). To calculate the derivatives of such functions, we use partial derivatives, which are calculated with respect to specific parameters. The branch of calculus dealing with such functions is multivariate calculus.

Grouping the input variables x, y, z as a vector in bold x, we can represent the scalar function of a vector of input parameters as f(x), with an italic f and a bold x. The calculus for this case is vector calculus, wherein the partial derivatives of f(x) are collected into a vector themselves and are amenable to various vector operations. Lastly, what will be most useful for deep learning is to represent multiple such functions at the same time. We use bold f(x), with both f and x in bold, to represent a set of scalar functions, each of the form f(x). The calculus for this case is the most general, namely matrix calculus.

To recap, f(x) with an italic x is a scalar function of a scalar variable (use simple derivative rules), f(x) with a bold x is a scalar function of the vector variable x (use vector calculus rules), and bold f(x) is a vector of many scalar-valued functions, each depending on a vector of inputs x (use matrix calculus rules). The paper demonstrates how to calculate derivatives of simple functions, and the relationships between the partial derivatives of multivariate calculus (∂/∂x), the gradient ∇f of vector calculus, and the Jacobian J of matrix calculus. Loosely put, ∇f(x) is the collection of the partial derivatives of f in vector form. The Jacobian of bold f(x) is basically a stack of the gradients of the individual scalar functions f(x), one per row.
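Written out, the relationship between the gradient and the Jacobian looks roughly like this. The sketch assumes n inputs and m scalar functions, with gradients laid out as rows (numerator layout); the index names n and m are my own labels.

```latex
% Gradient of one scalar function f of the vector x (a row of partials)
\[
\nabla f(\mathbf{x}) =
\begin{bmatrix}
\dfrac{\partial f}{\partial x_1} & \dfrac{\partial f}{\partial x_2} & \cdots & \dfrac{\partial f}{\partial x_n}
\end{bmatrix}
\]

% Jacobian of the vector of functions f = (f_1, ..., f_m): one gradient per row
\[
\mathbf{J} = \frac{\partial \mathbf{f}}{\partial \mathbf{x}} =
\begin{bmatrix}
\nabla f_1(\mathbf{x}) \\
\vdots \\
\nabla f_m(\mathbf{x})
\end{bmatrix}
=
\begin{bmatrix}
\dfrac{\partial f_1}{\partial x_1} & \cdots & \dfrac{\partial f_1}{\partial x_n} \\
\vdots & \ddots & \vdots \\
\dfrac{\partial f_m}{\partial x_1} & \cdots & \dfrac{\partial f_m}{\partial x_n}
\end{bmatrix}
\]
```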

In the process of calculating the partial derivatives, the paper makes a few assumptions. It is important to keep in mind the end product to which these concepts will eventually be applied, i.e. calculating the partial derivatives of the neuron output function (y = w · x + b) and the loss function. The paper does provide a glimpse of these by foreshadowing where they will be used. The first assumption is that the cardinality of vector x is equal to the number of scalar functions in f. This results in a nice square Jacobian. If you’re wondering why they need to be equal, consider the case where each input to a neuron xi is associated with a weight wi (the scalar function here is akin to xi*wi), so we have as many x’s as there are w’s.

Another important assumption concerns the element-wise diagonal property. Basically, the property states that the ith scalar function in f(x) is a function of (only) the ith term in vector x. Again, this makes more sense when you think of the common neuron model use case. The contribution of input xi is in proportion to a single parameter wi. Assuming the element-wise diagonal property turns the Jacobian (made square by the first assumption) into a diagonal matrix, with all the off-diagonal terms being zero.
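A quick numerical check can make the diagonal structure visible. The sketch below assumes the element-wise function f_i(x) = w_i * x_i and approximates the Jacobian with central finite differences; the helper numerical_jacobian and the sample values are hypothetical, written only for this illustration.

```python
def f(x, w):
    # Element-wise function: the i-th output depends only on x[i].
    return [wi * xi for wi, xi in zip(w, x)]

def numerical_jacobian(func, x, eps=1e-6):
    # Central finite differences; row i holds the partials of func(x)[i].
    n_out = len(func(x))
    jac = [[0.0] * len(x) for _ in range(n_out)]
    for j in range(len(x)):
        x_plus = list(x)
        x_plus[j] += eps
        x_minus = list(x)
        x_minus[j] -= eps
        f_plus, f_minus = func(x_plus), func(x_minus)
        for i in range(n_out):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return jac

w = [0.5, -1.0, 2.0]   # illustrative weights
x = [1.0, 2.0, 3.0]    # illustrative inputs
J = numerical_jacobian(lambda v: f(v, w), x)
for row in J:
    print([round(v, 6) for v in row])
```

The printed matrix is square because there are as many functions as inputs, and only its diagonal entries (approximately the weights) are non-zero, because each output ignores every input except its own.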

The next few sections of the paper explain the process of calculating derivatives for more complicated functions. There are a few ways in which functions can go from the simple to the complex. First, consider functions that are derived by applying element-wise binary operators to two vectors (of the same size, of course). These are functions of the form f(x,y) = x + y, or max(x, y). Note that x and y are vectors in this case. Next, there are scalar expansion functions, which are derived by multiplying a vector by a scalar or adding a scalar to it (this might remind some of us of broadcasting in NumPy). This operation involves ‘expanding’ the scalar to the same dimension as the vector and then performing the element-wise multiplication/addition operation. For example, y = x + b (where y and x are vectors) means that the scalar b is expanded to a vector b and added element-wise to x.
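For reference, the Jacobians for these two cases come out very simply. The summary below is a sketch in my own notation: ⊗ stands for element-wise multiplication, I is the identity matrix, and the vector of ones is the expanded form of the scalar. Note that in the first line y is the second input vector, while in the other two lines y is the output, following the example above.

```latex
% Element-wise binary operators on two vectors of the same size
\[
\frac{\partial (\mathbf{x} + \mathbf{y})}{\partial \mathbf{x}} = I,
\qquad
\frac{\partial (\mathbf{x} \otimes \mathbf{y})}{\partial \mathbf{x}} = \mathrm{diag}(\mathbf{y})
\]

% Scalar expansion, addition: y = x + b with b expanded to \vec{1} b
\[
\mathbf{y} = \mathbf{x} + b:
\qquad
\frac{\partial \mathbf{y}}{\partial \mathbf{x}} = I,
\qquad
\frac{\partial \mathbf{y}}{\partial b} = \vec{1}
\]

% Scalar expansion, multiplication: y = x b
\[
\mathbf{y} = \mathbf{x}\, b:
\qquad
\frac{\partial \mathbf{y}}{\partial \mathbf{x}} = I\, b,
\qquad
\frac{\partial \mathbf{y}}{\partial b} = \mathbf{x}
\]
```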

Third, consider functions that collapse the values in a vector to a single value. The most common example is calculating the loss of a neural network, usually of the form y = sum(f(x)). Here y is a scalar value obtained by adding up the elements of the vector f(x). The derivatives for these three cases are calculated in the paper. There are functions that can be more complex, for which the chain rule of derivatives is used. The paper describes the chain rule for simple scalar functions and gradually extends it all the way to the most general form of all, the vector chain rule.
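The sum reduction and the chain rule can be seen working together in a small sketch. It assumes the illustrative choice f_i(x) = sin(x_i^2), applies the scalar chain rule to each element, and compares the result against a finite-difference estimate; the helper names and input values are hypothetical.

```python
import math

def y(x):
    # Reduce the vector f(x) = [sin(x_i^2), ...] to one scalar by summing.
    return sum(math.sin(xi ** 2) for xi in x)

def grad_y(x):
    # Chain rule per element: d/dx_i sin(x_i^2) = cos(x_i^2) * 2 * x_i.
    # The sum contributes a factor of 1 for each element.
    return [math.cos(xi ** 2) * 2.0 * xi for xi in x]

def numerical_grad(func, x, eps=1e-6):
    # Central finite differences, used only as an independent check.
    grads = []
    for j in range(len(x)):
        x_plus = list(x)
        x_plus[j] += eps
        x_minus = list(x)
        x_minus[j] -= eps
        grads.append((func(x_plus) - func(x_minus)) / (2 * eps))
    return grads

x = [0.5, 1.0, 1.5]  # illustrative inputs
analytic = grad_y(x)
numeric = numerical_grad(y, x)
print(analytic)
print(numeric)
```

The two printed gradients should agree to several decimal places, which is a handy sanity check whenever a derivative has been worked out by hand.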