
KAN: Why and How Does It Work? A Deep Dive

Can we discover new physics with KAN?

Can a Neural Net Discover New Physics? (Generated with DALL-E 2 by Author)

Last week, while we were at an AI & Physics conference (EuCAIFCon), a lot of the discussion was about foundation models and whether it is possible to discover potentially new laws of physics using AI, and lo and behold: the KAN paper¹ recently came out on arXiv, discussing possibilities of discovering/rediscovering physical and mathematical models using neural networks.

I got some time over the weekend to go through parts of this fascinating paper, and here we will take a deep dive into the math, formulations, and ideas behind constructing KANs. We will also install KAN to check one of the many amazing results presented in the paper by the MIT researchers.

We will go through the details to get accustomed to the mathematical notation and concepts, and to appreciate why this network is causing such a big fuss!

Let’s begin.


KA Representation Theorem:

The KAN, or Kolmogorov-Arnold Network, is named after the famous mathematicians Kolmogorov and Arnold and is based on their representation theorem. So first, let's take a few steps back to understand this theorem.

The KA theorem grew out of these mathematicians' work on Hilbert's 13th problem: David Hilbert asked whether the solution of the general 7th-degree equation can be expressed as a composition of algebraic functions of two arguments. What does that mean?

Say we have a 7th-order polynomial as below:

Eq. 1: Equation with higher-order polynomials
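Written out, the reduced equation usually quoted for Hilbert's 13th problem is:

x^7 + a\,x^3 + b\,x^2 + c\,x + 1 = 0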

Hilbert asked whether its solution, x, considered as a function of the three variables a, b, c, can be expressed as the composition (h = g ∘ f, i.e. h(x) = g(f(x))) of a finite number of two-variable functions.

The KA representation theorem states that for any continuous function f: [0,1]ᵈ → R, there exist continuous univariate functions g_q and ψ_{p,q} such that:

Eq. 2: KA Representation Theorem
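Written out with the notation above, the statement is:

f(x_1, \dots, x_d) = \sum_{q=1}^{2d+1} g_q\!\left( \sum_{p=1}^{d} \psi_{p,q}(x_p) \right)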

This means that the (2d+1)(d+1) univariate functions g_q and ψ_{p,q} are enough for an exact representation of a d-variate function⁴. The underlying point here is that the only truly multivariate function is addition, since every other function can be written using univariate functions and sums¹.

Bear in mind that g_q and ψ_{p,q} are both univariate functions. So, any continuous function of several variables can be expressed as a composition of univariate functions and addition. Wow! I didn’t know that. Cool! Let’s proceed.


Thinking About KA Theorem & MLPs:

The authors state in the paper that these 1D functions (defined above) can be non-smooth, even fractal, so they may not be learnable in practice. To build KANs, the authors go beyond the strict definition of the KA representation theorem, but first, let's think about MLPs.

Multi-Layer Perceptrons (MLPs) are based on the universal approximation theorem, which states that any continuous function f: [0, 1]ᵈ → [0, 1] can be approximated arbitrarily well by a neural network (weights, biases, and non-linear activation functions) with at least one hidden layer and a finite number of weights. During backpropagation, the network learns to optimize the weights and biases so it acts as a function approximator, while the activation functions remain fixed.
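For contrast, here is a minimal sketch of such a fixed-activation MLP in PyTorch (the layer sizes are arbitrary choices, purely for illustration):

import torch
import torch.nn as nn

# A plain MLP: the weights and biases of the Linear layers are learnable,
# while the ReLU non-linearity sits on the nodes and stays fixed.
mlp = nn.Sequential(
    nn.Linear(2, 16),   # learnable linear transformation
    nn.ReLU(),          # fixed activation function
    nn.Linear(16, 1),
)

x = torch.rand(8, 2)    # a batch of 8 two-dimensional inputs
y = mlp(x)              # forward pass, output shape (8, 1)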

Now, can we build a neural-net architecture based on the KA representation theorem discussed above?

If we think about a supervised learning problem where, given {x_i, y_i} pairs, we want to find a function f such that y_i ≈ f(x_i), then the KA representation theorem tells us that all we need to find are the univariate functions in Eq. 2 (g_q and ψ_{p,q}).

Here, the authors argue that since we only need to learn univariate functions, we can parametrize each 1D function as a B-spline curve (see below), with learnable coefficients of local B-spline basis functions. This leads to the prototype KAN, illustrated in Figure 1 (the shallow model, panel (b)): with input dimension n = 2, it appears as a two-layer neural network with activation functions placed on edges instead of nodes (simple summation is performed on the nodes), and with width 2n+1 in the middle layer. The construction of the network, i.e. the number of activations, nodes, etc., will become clear soon. Given the definition in Eq. 2, the original KA representation theorem can be thought of as having depth 2, with the middle layer having (2d+1) terms (sorry 😢, in my notes I initially started with d as the summation index; 2n+1 and 2d+1 are the same here).

Fig. 1: Thinking & Constructing KAN: Taken from 'KAN: Kolmogorov–Arnold Networks' by Z. Liu et al., [arXiv: 2404.19756]

B-splines: We can think of a B-spline as a combination of flexible bands controlled by a number of points (control points), creating smooth curves. A slightly more mathematical definition: a B-spline of order p+1 is a collection of piecewise polynomial functions B_{i,p} of degree p in a variable t. The values of t where the polynomial pieces meet are known as knots.

Once again, B-splines are built from piecewise polynomials (basis functions) and the order of a B-spline is one more than the degree of its basis polynomials. For example, a quadratic B-spline has polynomials of degree 2 and is of order 3. This is actually what was demonstrated and used in the KAN paper.
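For a concrete feel, here is a minimal sketch with scipy (the knot vector and coefficients below are arbitrary, purely for illustration): a degree-2, order-3 B-spline, exactly the quadratic case described above.

import numpy as np
from scipy.interpolate import BSpline

degree = 2                                               # quadratic pieces, order = degree + 1 = 3
knots = np.array([0., 0., 0., 1., 2., 3., 4., 4., 4.])   # clamped knot vector; interior knots at 1, 2, 3
coeffs = np.array([0.0, 1.0, -1.0, 2.0, 0.5, 1.5])       # len(knots) - degree - 1 = 6 coefficients

spline = BSpline(knots, coeffs, degree)                  # the piecewise-quadratic curve
t = np.linspace(0.0, 4.0, 9)
print(spline(t))                                         # smooth values; polynomial pieces join at the knots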


Constructing KAN Layer:

It was already mentioned that the 2-layer network representing the original KA representation theorem is too simple to approximate any function arbitrarily well. How do we make the KAN wider and deeper?

Here the authors present an excellent analogy between KANs and MLPs to go deeper. First, we need to see what a KAN layer is and how to stack KAN layers on top of each other to build a deep neural net.

First of all, one can express the KA representation in matrix form as:

Eq. 3: Thinking about Eq. 2 in matrix form
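Roughly, reusing the g_q, ψ_{p,q} notation of Eq. 2 (each row of Φ_in acts on the components of x, and its outputs are summed at a node), the matrix form looks like:

\mathbf{\Phi}_{\mathrm{in}} =
\begin{pmatrix}
\psi_{1,1}(\cdot) & \cdots & \psi_{d,1}(\cdot)\\
\vdots & & \vdots\\
\psi_{1,2d+1}(\cdot) & \cdots & \psi_{d,2d+1}(\cdot)
\end{pmatrix},
\qquad
\mathbf{\Phi}_{\mathrm{out}} =
\begin{pmatrix} g_1(\cdot) & \cdots & g_{2d+1}(\cdot) \end{pmatrix},
\qquad
f(\mathbf{x}) = \big(\mathbf{\Phi}_{\mathrm{out}} \circ \mathbf{\Phi}_{\mathrm{in}}\big)(\mathbf{x})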

A KAN layer with n_in-dimensional inputs and n_out-dimensional outputs can be defined as a matrix of 1D functions:

Eq. 4: Setting the dimension of the matrix
Eq. 4: Setting the dimension of the matrix

In the Kolmogorov-Arnold theorem (Eq. 2), the inner functions form a KAN layer with n_in = n and n_out = 2n+1, and the outer functions form a KAN layer with n_in = 2n+1 and n_out = 1. At this stage, we can see that the KA representation can be thought of as a composition of two KAN layers. Let’s try to get accustomed to the notation when we stack more KAN layers.

We can use the example figure that the authors presented to discuss the network dimensions and more:

Fig. 2: Thinking About KAN layer: Taken from 'KAN: Kolmogorov–Arnold Networks' by Z. Liu et al., [arXiv: 2404.19756]

The authors denote by n_i the number of nodes in the i-th layer of the KAN, and the i-th neuron in the l-th layer is denoted by (l, i), with the activation value of this neuron given by x_{l,i}. We think of the activation functions as learnable functions residing on the edges of the network graph, while the nodes represent the summation operation. So between the 1st layer (layer 0) and the 2nd layer (layer 1), we see there are 10 activation functions, denoted ϕ_{0,1,1}, ϕ_{0,1,2}, … . The number of activation functions is governed by the number of nodes in layers 0 and 1.

This is where we can clearly see the distinction between MLPs & KANs. KANs have activation functions on edges, but MLPs have activation functions on nodes.

In the 0th layer we have two nodes, x_{0,1} and x_{0,2}, and in the first layer we have 5, so the number of activation functions is n_l × n_{l+1} = 2 × 5 = 10.

Here n_l and n_{l+1} are determined by the input and output dimensions of the inner function defined in Eq. 4. We started with two inputs, n_in = 2, so n_out has to be 2n+1 = 5. This in turn determines the number of activation functions in the hidden layer.

If we continue with the number of nodes n_1 = 5 and n_2 = 1 (= n_out), it makes sense that the number of activations at that layer is 5. This is the outer function. To repeat, the KA representation is composed of two KAN layers.


Matrix Form of KAN Layer:

We can now move on to writing the activations. The activation function that connects the nodes in layers l and l+1 is denoted by ϕ_{l, j, i}, where i indexes the neuron in layer l and j the neuron in layer l+1.

So the learnable activation function between layers l and l+1 is:

Eq. 5: Learnable Activation Functions at the Edges of KAN

We can check this again against Fig. 2:

Checking understanding for counting the number of nodes by comparing with Fig. 2.

We denote the input (pre-activation) of ϕ_{l, j, i} as x_{l, i}; after applying the activation, the post-activation is:

Notations for pre- and post-activation.

The activation value of the (l+1, j) neuron is simply the sum of all incoming post-activations.

Using all these, we can define the learnable transformation matrix of activations as:

Eq. 6: Complete set of learnable activation functions at different layers of KAN

Using this we can also write the transformation rule:

Eq. 7: Complete transformation rule: Pre and post-activation given the activation matrix.

We can check our understanding once again by comparing with Fig. 2:

Eq. 8: Check whether we understand the dimension of the transformation matrix, given Fig. 2!

Indeed, we have 5 outputs: x_{1,1}, x_{1,2}, x_{1,3}, x_{1,4}, x_{1,5}.

Once we have the transformation matrix ready, we can simply compose the layers (stack them) to go deeper, as below:

Eq. 9: Compose a KAN by stacking several KAN layers

At this point we can also appreciate that all the operations are differentiable (assuming the 1D functions are as well), so gradients can flow through the network, i.e. we can do backpropagation!
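To make the "learnable functions on edges, sums on nodes" picture concrete, here is a minimal, hypothetical sketch in PyTorch. Note that it parametrizes each edge function as a learnable cubic polynomial purely for brevity, not as the B-spline + SiLU combination the paper actually uses (discussed below):

import torch
import torch.nn as nn

class ToyKANLayer(nn.Module):
    """Edge-wise learnable 1D functions, summed at the output nodes.
    Each edge function is a learnable cubic polynomial here, only as a stand-in
    for the paper's spline parametrization."""
    def __init__(self, n_in, n_out, degree=3):
        super().__init__()
        # one set of polynomial coefficients per (output, input) edge
        self.coeffs = nn.Parameter(0.1 * torch.randn(n_out, n_in, degree + 1))
        self.degree = degree

    def forward(self, x):                                       # x: (batch, n_in)
        # powers of the inputs: (batch, n_in, degree + 1)
        powers = torch.stack([x ** p for p in range(self.degree + 1)], dim=-1)
        # phi_{j,i}(x_i) for every edge: (batch, n_out, n_in)
        edge_acts = torch.einsum('bip,oip->boi', powers, self.coeffs)
        # the nodes simply sum the incoming post-activations: (batch, n_out)
        return edge_acts.sum(dim=-1)

# stack two layers to mimic the [2, 5, 1] shape of the KA representation
net = nn.Sequential(ToyKANLayer(2, 5), ToyKANLayer(5, 1))
out = net(torch.rand(8, 2))                                     # (8, 1); fully differentiable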

One can also make a comparison with the MLP layer, where the weight matrices (linear transformations) and activation functions (non-linearities) are kept separate:

Eq. 10: Comparing the KAN with MLP; Weights (linear) and activations (non-linearity).
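Roughly, for a three-layer network, Eq. 10 expresses:

\mathrm{MLP}(\mathbf{x}) = (W_3 \circ \sigma \circ W_2 \circ \sigma \circ W_1)(\mathbf{x}),
\qquad
\mathrm{KAN}(\mathbf{x}) = (\mathbf{\Phi}_3 \circ \mathbf{\Phi}_2 \circ \mathbf{\Phi}_1)(\mathbf{x})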

The values in the weight matrices get updated, but once defined, the activation function is fixed in the MLP. This is the pivotal difference between an MLP and a KAN layer: in a KAN, the activation functions themselves are learnable.

Since everything in a KAN now boils down to the activation functions, the authors define how to construct them.


Learnable Activations

To construct the activation function ϕ(x), the authors propose to take a basis function b(⋅) and a spline function and combine them as below:

Eq. 11: Learnable activations as a linear combination of basis and spline function.

The authors take the basis function to be SiLU:

Eq. 12: Choice of basis function b(⋅)

The spline function is a linear superposition of B-splines:

Eq. 13: Spline function as linear combination of B-splines.
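Putting Eqs. 11-13 together, the learnable activation is roughly:

\phi(x) = w\,\big(b(x) + \mathrm{spline}(x)\big),
\qquad
b(x) = \mathrm{silu}(x) = \frac{x}{1 + e^{-x}},
\qquad
\mathrm{spline}(x) = \sum_i c_i\, B_i(x)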

If we look back at the second figure, we see that it is a linear combination of B-splines with k = 3, i.e. the order is 3, so the degree of the polynomials in the B-spline is 2. One advantage of defining a spline like this is the possibility of making it arbitrarily smooth by adding more pieces. This is also shown in Fig. 2, where the authors increase the number of intervals at which the different polynomials are joined from 7 to 12.

The weights of the B-splines, i.e. the c_i's, are trainable, and the authors argue that the only purpose of the factor w in Eq. 11 is to give better control over the overall magnitude of the activation function.


Number of Parameters in MLP & KAN:

Here the authors also note that, in general, KAN training is slower than MLP training. To understand this, we can simply count the number of parameters, assuming a network of depth L in which every layer has the same number of nodes n_i = N, with each spline of order k (usually k = 3) defined on G intervals. This gives:

Number of parameters in KAN compared to MLP
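The scaling the authors quote is roughly:

\#\,\mathrm{params}_{\mathrm{KAN}} \sim O\!\big(N^2 L\,(G + k)\big) \sim O(N^2 L\, G),
\qquad
\#\,\mathrm{params}_{\mathrm{MLP}} \sim O(N^2 L)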

However, KANs usually require a much smaller width N than MLPs, and KANs are interpretable; we will see an example that the authors presented. Another reason the authors highlight for why training a KAN is slower than an MLP is that, since the activation functions are learnable, it is not possible to leverage 'batch computation', i.e. passing large amounts of data through the same function. This is not an issue in MLPs, because the activations stay fixed throughout training and testing.


Outro:

There are still so many intricate details left in this paper, but one thing that stands out personally for me is the interpretability of KANs. The authors show that KANs can "discover" everything from simple division laws to non-trivial relations in knot theory. This could lead to further applications of KANs in foundation models for AI & Science.

KANs could also become more ‘attractive’ than symbolic regression, as the authors suggest; they give the example of learning the very wiggly Bessel function of order 20, J_{20}(x), with a KAN, which is practically impossible with symbolic regression without prior knowledge of that special function (in this case, the Bessel function) itself.


Example of ‘Discovery’ Through KAN

Out of the many examples the authors presented, I liked a relatively easy but fascinating ‘auto-discoverable’ property of KANs. As physicists, we always love these types of examples. Say we start with the relativistic velocity-addition formula:

Adding two relativistic velocities
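In units where c = 1 (the convention used in the code later on), this reads:

v_{\mathrm{rel}} = \frac{v_1 + v_2}{1 + v_1 v_2}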

One way to think about the depth of a KAN is to consider every layer as discovering one mathematical operation. Looking at the formula above, we first think about multiplication: we need two layers for multiplication, since, as the authors show, the learned activation functions would be linear and quadratic:

Learning multiplication with a KAN as a combination of linear and quadratic functions
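One way to see why linear followed by quadratic activations over two layers suffice for multiplication is the identity:

v_1 v_2 = \frac{(v_1 + v_2)^2 - (v_1 - v_2)^2}{4}

The first layer only needs the linear combinations v_1 + v_2 and v_1 − v_2, and the second layer squares and rescales them before the summation node.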

Inverting (1 + v_1 v_2) would use one more layer, and multiplying (v_1 + v_2) by 1/(1 + v_1 v_2) would require another two layers; five in total.

But the researchers found that the ‘auto-discovered’ KAN is only 2 layers deep, which in hindsight can be explained via the rapidity trick.

In relativity, we can simplify the transformation rules via the rapidity trick; one can define the rapidity as:

Rapidity trick in relativity
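In formulas, the rapidity θ of a velocity v (with c = 1) is defined by:

v = \tanh(\theta) \quad\Longleftrightarrow\quad \theta = \operatorname{arctanh}(v)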

We can use the tanh addition formula:

The tanh addition rule
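Explicitly:

\tanh(\theta_1 + \theta_2) = \frac{\tanh\theta_1 + \tanh\theta_2}{1 + \tanh\theta_1 \tanh\theta_2}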

Using this, we can see that:

Relativistic addition, simplified using the rapidity trick
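Putting the pieces together:

v_1 \oplus v_2 = \frac{v_1 + v_2}{1 + v_1 v_2}
= \tanh\!\big(\operatorname{arctanh} v_1 + \operatorname{arctanh} v_2\big)

So one layer of arctanh-like activations (summed at the node), followed by one tanh-like activation, is all that is needed.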

Now it perfectly makes sense to have only two layers. If we didn’t know the rapidity trick, trying to make sense of this 2-layer auto-discovered KAN could guide us to the trick. Can we use KANs, as in this example, to discover/rediscover some fundamental physics laws? ❣ ❣ ❣ ❣


Loading & Running KAN:

Let’s just go through the example above; the authors also present it in the paper, and the code is available on GitHub. We will run this on Colab.

We start by installing pykan and loading the necessary libraries.

!pip install pykan

from kan import KAN, create_dataset
import torch

We create the dataset, which contains two sets of velocities, split into train and test sets, along with the corresponding regression target, i.e. the label f(v_1, v_2). For sanity, we can check the distribution of the input velocities and the corresponding added relativistic velocity using matplotlib.

f = lambda x: (x[:,[0]]+x[:,[1]])/(1+x[:,[0]]*x[:,[1]]) 
# dataset creation where x[:, [0]] represents v1, x[:, [1]]: v2
dataset = create_dataset(f, n_var=2, ranges=[-0.9,0.9])
#plot the distribution 
import matplotlib.pyplot as plt

### check train and test input distribution

fig = plt.figure(figsize=(10, 5))
fig.add_subplot(131)
plt.hist(dataset['train_input'][:, 0], bins=20, alpha=0.7, 
         label=r'$v_1$-train', color='orange')
plt.hist(dataset['train_input'][:, 1], bins=20, alpha=0.7, 
         label=r'$v_2$-train', histtype='step')
plt.legend(fontsize=12)
fig.add_subplot(132)
plt.hist(dataset['test_input'][:, 0], bins=20, alpha=0.7, 
         label=r'$v_1$-test', color='orange')
plt.hist(dataset['test_input'][:, 1], bins=20, alpha=0.7, 
         label=r'$v_2$-test', histtype='step')
plt.legend(fontsize=12)
fig.add_subplot(133)
plt.hist(dataset['train_label'].numpy(), bins=20, alpha=0.7,
         label=r'$\frac{v_1+v_2}{1+v_1 v_2}$-train', color='orange')
plt.hist(dataset['test_label'].numpy(), bins=20, alpha=0.7,
         label=r'$\frac{v_1+v_2}{1+v_1 v_2}$-test', histtype='step')
plt.legend(fontsize=12)
plt.tight_layout()
plt.show()

We get these histograms of our data and label.

Fig. 3: Randomly distributed velocities and the corresponding (relativistically) added value. Source: Author’s Notebook

To avoid CPU/GPU confusion, we can explicitly set the device to ‘cuda’ within Colab for both data and model:

### let's try this explicitly anyway
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = KAN(width=[2, 1, 1], grid=10, k=3, seed=0, device=device)
dataset = create_dataset(f, n_var=2, device=device)

We train the model with the LBFGS optimizer, and already at this point model.plot() reveals something very interesting:

model.train(dataset, opt="LBFGS", steps=25, device=device)
model.plot(beta=10)
Fig. 4: 2-layered KAN with learned activations

The activations in the first layer already look like arctanh, and the activation in the second layer looks like tanh. That’s pretty cool!

Asking the model to suggest a symbolic representation for the activations in the first layer confirms what we see:

model.suggest_symbolic(0, 1, 0)
>>> function , r2
arctanh , 0.9986623525619507
tan , 0.9961022138595581
arcsin , 0.968244731426239
...

Similarly for the second layer:

model.suggest_symbolic(1, 0, 0)

>>> function , r2
tanh , 0.9995558857917786
arctan , 0.995667040348053
gaussian , 0.9793974757194519
...

We indeed get tanh as the best suggestion for the symbolic function. That’s so cool!!
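As a quick numerical sanity check, independent of pykan, we can verify that the arctanh → sum → tanh composition the network suggests really does reproduce f (a minimal sketch with numpy):

import numpy as np

rng = np.random.default_rng(0)
v1, v2 = rng.uniform(-0.9, 0.9, size=(2, 1000))           # random velocities in (-0.9, 0.9), c = 1

direct = (v1 + v2) / (1 + v1 * v2)                        # relativistic addition formula
via_rapidity = np.tanh(np.arctanh(v1) + np.arctanh(v2))   # the 2-layer KAN structure

print(np.max(np.abs(direct - via_rapidity)))              # ~1e-16: identical up to floating-point error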


As a researcher in physics, I am so excited to see how researchers in fundamental AI, physics, and math will merge KANs and MLPs, or modify KANs to make them faster, better, and possibly even more interpretable (if that’s actually possible). The possibility of discovering/rediscovering physics laws, maybe in astrophysics or cosmology, is another aspect that needs exploring with KANs. We can get mildly excited now!


References:

[1] ‘KAN: Kolmogorov–Arnold Networks’: Ziming Liu et al., arXiv: 2404.19756.

[2] Notebook for Notes Used Here: My GitHub.

[3] ‘Deep networks and the Kolmogorov–Arnold theorem’: H. Montanelli.

[4] ‘The Kolmogorov–Arnold representation theorem revisited’; J. Schmidt-Hieber.


If you’re interested in further fundamental machine learning concepts and more, you can consider joining Medium using My Link. You won’t pay anything extra but I’ll get a tiny commission. Appreciate you all!!

