Neural Networks: Why do they work so well?

Isaiah Nields
Towards Data Science
7 min read · Apr 4, 2019

--

Prerequisites

This article requires knowledge of a few key mathematical objects and techniques. These requirements are very useful to know in general, so I recommend you look into them if you aren’t familiar. Click on any of the links for my favorite resource on the topic.

Linear Algebra

  1. Vectors
  2. Matrices

Calculus

  1. Partial Derivatives
  2. Chain Rule

The Basics: input & output

A neural network is a function. It takes some input and maps it to an output. More specifically, a neural network is a function meant to make predictions.

To get a better feel for this definition, let’s look at an example. Say we want to predict if you have heart disease based on some data we have about you: weight, height, age, resting heart rate, and cholesterol level.

To start, we can stack your features into a vector x like so:

x = [weight, height, age, resting heart rate, cholesterol]

To get a neural network’s prediction for the probability that you have heart disease, we feed x into the model. This outputs y-hat — the model’s prediction for p(heart disease). We can write this like so (where N is the neural network):

ŷ = N(x)

Here, y-hat is a continuous value between 0 and 1 as it is a probability. However, the true value y is binary (i.e. either 1 or 0) because you either have heart disease (1) or you don’t (0).
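To make this concrete, here’s a minimal NumPy sketch of the idea. The feature values and the predict_heart_disease function are purely illustrative stand-ins I made up, not a trained model; a real network’s weights would have to be learned from data.

    import numpy as np

    # Stack the features into a single input vector x (made-up values).
    x = np.array([70.0,   # weight (kg)
                  175.0,  # height (cm)
                  45.0,   # age (years)
                  72.0,   # resting heart rate (bpm)
                  190.0]) # cholesterol (mg/dL)

    def predict_heart_disease(x):
        """A stand-in for a neural network N: maps x to y-hat = p(heart disease)."""
        rng = np.random.default_rng(0)
        W = rng.normal(size=(1, 5))   # placeholder weights (a real model learns these)
        b = rng.normal(size=1)        # placeholder bias
        z = (W @ x + b).item()        # scalar pre-activation
        return 1 / (1 + np.exp(-z))   # sigmoid keeps y-hat between 0 and 1

    y_hat = predict_heart_disease(x)  # the model's prediction for p(heart disease)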

Ok! You now have a general idea of what a neural network does and why it might be useful. But, you still have no idea how a neural network makes a prediction for y. Let’s dig a little deeper.

Layers: 🍰

Note: in this section, I use an underscore followed by brackets to denote subscripts that aren’t available in Unicode. Here’s an example: a_[ℓ-1] means a with the subscript ℓ-1.

As previously stated, a neural network is a function. But, more specifically, it is a composition of functions. In deep learning speak, these functions are called layers. Later we’ll talk more about how these layers are composed, but for now, let’s focus on just one layer.

Because a layer is a function, it simply takes in an input and produces an output. We will call this function f_[ℓ] for layer ℓ. By convention, f_[ℓ] takes in an input a_[ℓ-1] and produces an output a_[ℓ]. Here, a_[ℓ] is called the activation for layer ℓ.

We’ve labeled f_[ℓ]’s inputs and outputs, but what specifically does the layer f_[ℓ] do? Well, f_[ℓ] maps a_[ℓ-1] → a_[ℓ] in three stages. Here is what happens in each:

  1. First, a weight matrix W applies a linear transformation → W a_[ℓ-1]
  2. Then, a bias vector b is added → W a_[ℓ-1] + b
  3. Finally, a nonlinear activation function σ is applied → σ(W a_[ℓ-1] + b)

This sequence is summarized by the equation below:

a_[ℓ] = σ(W a_[ℓ-1] + b)
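In code, a layer is only a few lines. Here’s a minimal sketch, assuming tanh for σ and random placeholder values for W and b (any nonlinearity and any weights follow the same recipe):

    import numpy as np

    def layer(a_prev, W, b, sigma=np.tanh):
        """One layer f_[l]: linear transformation, add bias, apply nonlinearity."""
        return sigma(W @ a_prev + b)

    rng = np.random.default_rng(0)
    W = rng.normal(size=(3, 2))      # weight matrix (here mapping R^2 -> R^3)
    b = rng.normal(size=3)           # bias vector
    a_prev = np.array([0.5, -1.0])   # input activation a_[l-1]
    a = layer(a_prev, W, b)          # output activation a_[l]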

Ok, so we can write down what a layer does. But I’m not sure notation alone gives us a true feel for what a layer is doing. Let’s try to build that intuition with a visualization of how f_[ℓ] transforms space.

In order to visualize f_[ℓ], we’re going to need to work with low (1, 2, or 3) dimensional inputs and outputs. For this visualization, we’ll have a 2D input and a 2D output. This means that the input a_[ℓ-1] ∈ ℝ² and the output a_[ℓ] ∈ ℝ². We can also say that f_[ℓ] maps ℝ² → ℝ². Visually, this means that every point in a 2D space is mapped to a new point in 2D space by f_[ℓ].

To plot how f_[ℓ] maneuvers each point in 2D space, we need to choose a function for σ, a matrix for W, and a vector for b. For this visualization, we’ll choose σ, W, and b like so:

Now that we have this concrete function, we can show how it affects every point in a 2D space. In order to see what’s happening, we’ll only show points that lie on a grid. With that in mind, let’s produce a visualization:

A visualization of a layer f_[ℓ] mapping ℝ² → ℝ².

Notice how the layer f_[ℓ] takes every point in ℝ² through the three different transformations. First, W stretches 2D space while keeping all lines parallel. Then, b shifts space away from the origin. And finally, σ smooshes space with no regard for keeping grid lines parallel. When f_[ℓ] completes its transformations, each input activation a_[ℓ-1] will have moved from its original position to its corresponding output activation a_[ℓ] according to f_[ℓ].
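If you’d like to reproduce a picture like this yourself, here’s a rough sketch of the recipe. The σ, W, and b below are my own stand-ins (tanh and hand-picked numbers), not necessarily the values used in the animation above:

    import numpy as np
    import matplotlib.pyplot as plt

    # Stand-in choices for the layer (assumptions, not the animation's exact values).
    sigma = np.tanh
    W = np.array([[1.5, 0.5],
                  [0.5, 1.5]])
    b = np.array([0.5, -0.5])

    # A grid of points covering part of R^2.
    xs, ys = np.meshgrid(np.linspace(-2, 2, 21), np.linspace(-2, 2, 21))
    grid = np.stack([xs.ravel(), ys.ravel()])   # shape (2, N)

    # The three stages: linear transformation, shift, nonlinearity.
    out = sigma(W @ grid + b[:, None])

    plt.scatter(grid[0], grid[1], s=4, label="input points a_[l-1]")
    plt.scatter(out[0], out[1], s=4, label="output points a_[l]")
    plt.legend()
    plt.show()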

Great, but again, why is this useful? Well, on its own, a single layer isn’t necessarily useful at all. Layers really only become useful when we compose them. Let’s talk more about how this is done.

Composition: y = f(f(f(x)))

As previously mentioned, we can construct a neural network by composing layers. Composing layers means that we feed the output activation of one layer into the input of the next layer.

Once we’ve wired up each layer, we can feed an input x (also called a_[0]) into the first layer f_[1]. The first layer f_[1] feeds its output activation a_[1] into f_[2], which feeds its activation a_[2] into f_[3], and so on, until finally the last layer of the network f_[L] is reached and an output y-hat (also called a_[L]) is produced as a prediction for y.

As an example of what this process looks like mathematically, here is a three-layer neural network N:

ŷ = N(x) = f_[3](f_[2](f_[1](x)))
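In code, composing layers is just nesting (or looping over) function calls. Here’s a sketch with three placeholder layers, each mapping ℝ² → ℝ²; the weights are random stand-ins, not anything learned:

    import numpy as np

    def layer(a_prev, W, b, sigma=np.tanh):
        return sigma(W @ a_prev + b)

    rng = np.random.default_rng(0)
    # Three layers, each mapping R^2 -> R^2 (placeholder weights and biases).
    params = [(rng.normal(size=(2, 2)), rng.normal(size=2)) for _ in range(3)]

    def N(x):
        """Three-layer network: y-hat = f_[3](f_[2](f_[1](x)))."""
        a = x                   # a_[0]
        for W, b in params:
            a = layer(a, W, b)  # a_[1], then a_[2], then a_[3]
        return a                # a_[L], i.e. y-hat

    y_hat = N(np.array([1.0, -1.0]))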

Great! So that sums up what a neural network is. But, it’s still not clear why a series of composed layers are useful for making predictions. To get a better intuition, let’s make another visualization.

Like the last visualization, we are going to work with inputs in ℝ² and outputs in ℝ². The key difference here is that instead of showing how one layer f_[ℓ] transforms ℝ², we will show how an entire neural network N transforms ℝ². To ensure that the whole process can be visualized, every layer of N (f_[1], f_[2], f_[3], …, f_[L]) will also map ℝ² → ℝ². This allows us to see what’s happening at every step of the neural network’s transformation.

For this visualization, I’m also going to plot two spirals belonging to either an orange or blue class. These spirals will help demonstrate the network’s usefulness. With all that in mind, here is what we come up with:

A visualization of a neural network separating 2 spirals by mapping ℝ² → ℝ².

Cool, right? Let’s unpack this a bit more.

First, notice how each layer is applied one after the other. The first layer f_[1] maps a_[0] → a_[1]. The second layer f_[2] maps a_[1] → a_[2]. The third layer f_[3] maps a_[2] → a_[3], and so on, until finally f_[L] maps a_[L-1] → a_[L] (i.e. y-hat). This just follows the basic definition of a neural network.

Next, notice how each layer of the network separates the orange and blue spirals little by little. At each stage, the current transformation builds on the progress of the previous transformation. Through this process, the network is forming its prediction for each of the orange and blue points. When the mapping finally ends, the input points x land on the model’s prediction for x’s class — either orange (-1, -1) or blue (1, 1). Here, the model seems to be doing a pretty decent job.
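One way to read that last step in code: call each point’s class by whichever target its output lands closer to. This is just my interpretation of the picture, assuming the orange (-1, -1) and blue (1, 1) targets described above:

    import numpy as np

    ORANGE, BLUE = np.array([-1.0, -1.0]), np.array([1.0, 1.0])

    def predicted_class(y_hat):
        """Assign the class whose target point y_hat landed nearer to."""
        if np.linalg.norm(y_hat - ORANGE) < np.linalg.norm(y_hat - BLUE):
            return "orange"
        return "blue"

    predicted_class(np.array([0.9, 0.8]))    # -> "blue"
    predicted_class(np.array([-0.7, -1.1]))  # -> "orange"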

But the reason a neural network is useful has nothing to do with this specific example. The power of the model lies in its flexibility. Because each layer ℓ can have its very own σ_[ℓ], W_[ℓ], and b_[ℓ], there are, in theory, infinitely many configurations the model can take on.

This means that a neural network can do a lot more than separate these orange and blue spirals. It can separate points arranged in all sorts of configurations, and it can do more than separate points: it can approximate almost any transformation (i.e. function). This makes the model applicable to a huge range of problems.

Here’s another instance of a neural network. This one separates three spirals.

A visualization of a neural network separating 3 spirals by mapping ℝ² → ℝ².

But hold on just a second. How in the world does a model know how to approximate a given function? Of course, it uses a sequence of layers. But how do the layers know how to make themselves useful for a given problem? More specifically, how does each layer in the neural net figure out how to set its parameters σ, W, and b to accomplish something useful?

I’ll save that question for the next article where we’ll begin to explore how a neural network learns its σs, Ws, and bs.

If you’d like to find out how I created the visualizations in this article, check out my machine learning models repository on GitHub.
