Introducing K-FAC
A Second-Order Optimization Method for Large-Scale Deep Learning
In this article, I summarize Kronecker-factored Approximate Curvature (K-FAC) (Martens & Grosse, 2015), one of the most efficient second-order optimization methods for deep learning.
Overview
The heavy computation of curvature has limited the application of second-order optimization methods in deep learning. Kronecker-factored Approximate Curvature (K-FAC), a second-order optimization method for deep learning proposed by James Martens and Roger Grosse of the University of Toronto at ICML 2015, approximates the curvature by Kronecker factorization and reduces the computational complexity of the parameter updates. Thanks to efficient second-order methods such as K-FAC, ML researchers are now beginning to revisit the fast convergence of second-order methods as a way to reduce the training time of deep learning.
Natural Gradient Descent
Natural Gradient Descent (NGD) is an optimization method proposed by Shun-ichi Amari in 1998, based on information geometry. NGD captures the loss landscape accurately by using the Fisher information matrix (FIM) as the curvature of the loss function, and it converges faster in terms of 'iterations' than simple first-order methods such as stochastic gradient descent. Hence one can see NGD as an efficient realization of second-order optimization.
The FIM of a probabilistic model that outputs the conditional probability p(y|x; θ) of y given x is defined as follows:

F(θ) = E[ ∇θ log p(y|x; θ) ∇θ log p(y|x; θ)ᵀ ]
This definition, which takes the expectation over the data, is called the Empirical Fisher (with a mini-batch, you average over the samples in the batch to get the FIM). In an image classification task, since the loss function is usually the mean of the negative log-likelihood, one can see the FIM as an approximation of the curvature of the loss function. For the negative log-likelihood loss E(θ), the FIM and the Hessian are related by:

F(θ) ≈ ∇²E(θ),  where E(θ) = −(1/N) Σₙ log p(yₙ|xₙ; θ)
The update rule of NGD is:

θ(t+1) = θ(t) − η F(θ(t))⁻¹ ∇E(θ(t))

where η is the learning rate.
Here the inverse of the FIM is applied to the gradient of the loss; the gradient preconditioned by the FIM is called the natural gradient. For a model with N parameters, the FIM is an N × N matrix, and neural networks used in deep learning tend to have a massive number of parameters (e.g., 60 million in AlexNet for ImageNet classification), so inverting the FIM is intractable. This has limited the application of NGD to deep learning.
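To make the update rule concrete, here is a minimal NumPy sketch of NGD with the empirical Fisher on a toy logistic-regression problem (the data, learning rate, and damping term are hypothetical choices for illustration; damping keeps the FIM invertible):

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 100, 5
X = rng.standard_normal((N, D))
y = (X @ rng.standard_normal(D) > 0).astype(float)  # separable toy labels

theta = np.zeros(D)
eta, damping = 0.3, 1e-2

for _ in range(100):
    p = 1.0 / (1.0 + np.exp(-X @ theta))       # model's predicted probabilities
    per_sample = (p - y)[:, None] * X          # per-sample gradients of -log p(y|x)
    grad = per_sample.mean(axis=0)             # gradient of the loss E(theta)
    # Empirical FIM: average outer product of per-sample gradients (plus damping).
    F = per_sample.T @ per_sample / N + damping * np.eye(D)
    theta -= eta * np.linalg.solve(F, grad)    # natural gradient step
```

Even in this tiny sketch, each step solves a D × D linear system; that per-step cost is exactly what becomes intractable when D is in the millions.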
Natural gradient approximation methods
In recent years, several works have proposed methods that approximate (or avoid) inverting the FIM, and deep learning researchers have revisited the 'fast convergence' of NGD.
Roughly speaking, there are three kinds of approximation methods (I referred to this article for the categorization):
1. Approximate the Fisher information matrix (so that the inverse is easy to compute).
2. Reparameterize so that the FIM becomes closer to the identity matrix.
3. Approximate the natural gradient directly.
Approximation using Kronecker-factorization (K-FAC)
One can regard K-FAC as a natural gradient approximation method of the first kind: it approximates the Fisher information matrix so that its inverse is easy to compute. Among the natural gradient approximation methods, it is arguably the most efficient one grounded in a clear mathematical principle.
First, K-FAC block-diagonalizes the FIM, where each diagonal block corresponds to the parameters of one layer of the neural network. For example, K-FAC approximates the FIM of a three-layer network as a block-diagonal matrix with three blocks.
Next, K-FAC approximates each block with the Kronecker product of two matrices (called Kronecker factorization).
Finally, K-FAC uses a critical property of the Kronecker product:

(A ⊗ B)⁻¹ = A⁻¹ ⊗ B⁻¹
In short, K-FAC approximates the inverse of the FIM as a block-diagonal matrix where each diagonal block is an inverse of tiny Kronecker factors (compared to FIM).
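This inverse property is easy to check numerically; the matrices below are arbitrary positive-definite stand-ins for the Kronecker factors:

```python
import numpy as np

rng = np.random.default_rng(0)
# Two small symmetric positive-definite "Kronecker factors".
A = rng.standard_normal((3, 3)); A = A @ A.T + 3 * np.eye(3)
G = rng.standard_normal((2, 2)); G = G @ G.T + 2 * np.eye(2)

# Inverting the 6x6 Kronecker product directly vs. inverting the factors separately.
direct = np.linalg.inv(np.kron(A, G))
factored = np.kron(np.linalg.inv(A), np.linalg.inv(G))
assert np.allclose(direct, factored)  # (A ⊗ G)^-1 = A^-1 ⊗ G^-1
```

Inverting the two small factors costs far less than inverting their (much larger) Kronecker product, which is the whole point of the factorization.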
To clarify the size of the Kronecker factors (i.e., how much the cost of inversion is reduced), I explain the mechanism of Kronecker factorization. Taking a fully connected layer as an example, you can see how to factorize the diagonal block of the FIM corresponding to this layer (for convenience, called a Fisher block). The Fisher block of the i-th layer is expressed as

Fᵢ = E[ ∇ᵢ log p(y|x; θ) (∇ᵢ log p(y|x; θ))ᵀ ]

(the notation of the expectation is simplified), where ∇ᵢ is the gradient with respect to the parameters of the i-th layer. By backpropagation, the efficient method of computing gradients in a deep neural network, the gradient of the log-likelihood (for each sample) is represented as the Kronecker product of two vectors:

∇ᵢ log p(y|x; θ) = aᵢ₋₁ ⊗ gᵢ

where aᵢ₋₁ is the input (activation) to the i-th layer and gᵢ is the gradient of the log-likelihood with respect to the layer's output.
Using this relationship, the Fisher block can be rewritten as an "expectation of a Kronecker product":

Fᵢ = E[ (aᵢ₋₁ ⊗ gᵢ)(aᵢ₋₁ ⊗ gᵢ)ᵀ ] = E[ aᵢ₋₁aᵢ₋₁ᵀ ⊗ gᵢgᵢᵀ ]
K-FAC approximates this "expectation of a Kronecker product" by the "Kronecker product of expectations" (Kronecker factorization):

Fᵢ ≈ E[ aᵢ₋₁aᵢ₋₁ᵀ ] ⊗ E[ gᵢgᵢᵀ ]
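The per-sample identities used in this rewriting, vec(g aᵀ) = a ⊗ g (column-major vectorization) and (a ⊗ g)(a ⊗ g)ᵀ = aaᵀ ⊗ ggᵀ, can be verified numerically; the dimensions below are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 3, 2
a = rng.standard_normal(d_in)    # layer input (activation)
g = rng.standard_normal(d_out)   # gradient of log-likelihood w.r.t. layer output

grad_W = np.outer(g, a)                       # per-sample weight gradient, d_out x d_in
vec_grad = grad_W.flatten(order="F")          # column-major vectorization
assert np.allclose(vec_grad, np.kron(a, g))   # vec(g a^T) = a ⊗ g

# The per-sample Fisher block is exactly a Kronecker product of outer products:
lhs = np.outer(vec_grad, vec_grad)
rhs = np.kron(np.outer(a, a), np.outer(g, g))
assert np.allclose(lhs, rhs)
```

Note that these per-sample identities are exact; the approximation enters only when the expectation of the Kronecker product is replaced by the Kronecker product of the expectations.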
As explained above, Kronecker factorization significantly reduces the computational complexity of inverting a Fisher block. You can see this effect clearly by taking AlexNet, an architecture frequently used in image classification, as an example. The following figure compares the matrix sizes for all layers and for the final (fully connected) layer of AlexNet for ImageNet (1,000-class classification).
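A quick back-of-the-envelope calculation for that final layer (assuming the standard 4096-input, 1000-output fully connected layer and ignoring the bias for simplicity) illustrates the reduction:

```python
# Size of the Fisher block for AlexNet's final fully connected layer
# (4096 inputs -> 1000 outputs; bias ignored for simplicity).
d_in, d_out = 4096, 1000
n_params = d_in * d_out            # 4,096,000 parameters in this layer
fisher_block = n_params ** 2       # entries in the full Fisher block
factors = d_in ** 2 + d_out ** 2   # entries in the two Kronecker factors

print(f"Fisher block:      {fisher_block:,} entries")
print(f"Kronecker factors: {factors:,} entries")
print(f"reduction:        ~{fisher_block // factors:,}x")
```

The full Fisher block for this single layer would have about 1.7 × 10¹³ entries, while the two Kronecker factors together have under 18 million.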
To summarize so far, one can say that K-FAC is a natural gradient approximation method that performs the following three procedures.
1. Block-diagonalize the Fisher information matrix (each diagonal block corresponds to one layer), ignoring the correlations between parameters across layers. (There is also a variant using a block tri-diagonal approximation, which keeps the correlations between the parameters of adjacent layers.)
2. Kronecker-factorize each diagonal block (Fisher block), ignoring the correlation between the "input" and the "output gradient" of each layer.
3. Using approximations 1 and 2, compute the inverse of the Fisher information matrix efficiently to obtain the natural gradient.
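Putting the pieces together, here is a minimal NumPy sketch of one K-FAC preconditioning step for a single fully connected layer (the batch statistics, dimensions, and damping term are hypothetical; a real implementation amortizes the factor inversions across iterations):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, batch = 4, 3, 64

# Per-layer batch statistics (random stand-ins for illustration).
a = rng.standard_normal((batch, d_in))    # layer inputs (activations)
g = rng.standard_normal((batch, d_out))   # back-propagated output gradients
grad_W = g.T @ a / batch                  # d_out x d_in weight gradient

damping = 1e-3
A = a.T @ a / batch + damping * np.eye(d_in)    # input factor    E[a a^T]
G = g.T @ g / batch + damping * np.eye(d_out)   # gradient factor E[g g^T]

# K-FAC natural gradient: (A ⊗ G)^-1 vec(grad_W) = vec(G^-1 grad_W A^-1),
# computed with two small solves instead of one huge one.
nat_grad = np.linalg.solve(G, grad_W) @ np.linalg.inv(A)

# Sanity check against the explicit Kronecker-factored Fisher (column-major vec).
vec = grad_W.flatten(order="F")
explicit = np.linalg.solve(np.kron(A, G), vec).reshape(d_out, d_in, order="F")
assert np.allclose(nat_grad, explicit)
```

The update touches only a d_in × d_in and a d_out × d_out matrix per layer, never the full (d_in·d_out) × (d_in·d_out) Fisher block.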
Here I introduced K-FAC for a fully connected layer; for a convolution layer, a further approximation is applied on top of step 2 (refer to the paper).
Finally, I show the effectiveness of K-FAC on a classification task (10 classes) on the CIFAR-10 image dataset. The figure below compares the training curves of stochastic gradient descent (SGD), natural gradient descent without any approximation (NGD), and K-FAC.
You can see that NGD converges faster than SGD in terms of the "number of iterations," but since the computation time per iteration of NGD is heavy, it is slower than SGD in "elapsed time." K-FAC, on the other hand, reproduces the training curve of NGD well in the "number of iterations" and is also faster than SGD in "elapsed time."
This fast convergence motivates natural gradient approximation methods such as K-FAC, but K-FAC had seen limited application to large-scale deep learning such as ImageNet, and its effectiveness against SGD in that setting had not been verified before.
Applications of K-FAC
- Recurrent Neural Networks (RNN)
  Kronecker-factored Curvature Approximations for Recurrent Neural Networks,
  James Martens, Jimmy Ba, Matt Johnson,
  ICLR 2018.
- Reinforcement Learning
  An Empirical Analysis of Proximal Policy Optimization with Kronecker-factored Natural Gradients,
  Jiaming Song, Yuhuai Wu,
  arXiv:1801.05566 [cs.AI], Jan 2018.
- Bayesian Deep Learning
  Noisy Natural Gradient as Variational Inference,
  Guodong Zhang, Shengyang Sun, David Duvenaud, Roger Grosse,
  arXiv:1712.02390 [cs.LG], Dec 2017.
Implementations of K-FAC
- TensorFlow
  https://github.com/tensorflow/kfac
- PyTorch
  https://github.com/yaroslavvb/kfac_pytorch
  (introduced in this post)
- Chainer
  https://github.com/tyohei/chainerkfac
Conclusion
In this article, I gave an overview of K-FAC, one of the approximation methods for natural gradient descent, sacrificing mathematical rigor in favor of intuitive understanding.