
The Math Behind Batch Normalization

Explore Batch Normalization, a cornerstone of neural networks, understand its mathematics and implement it from scratch.

Image generated by DALL-E

Batch Normalization is a key technique in neural networks as it standardizes the inputs to each layer. It tackles the problem of internal covariate shift, where the input distribution of each layer shifts during training, complicating the learning process and reducing efficiency. By normalizing these inputs, Batch Normalization helps networks train faster and more consistently. This method is vital for ensuring reliable performance across different network architectures and tasks, including image recognition and natural language processing. In this article, we’ll delve into the principles and math behind Batch Normalization and show you how to implement it in Python from scratch.


Index

1: Introduction

2: Overcoming Covariate Shift

3: Math and Mechanisms
∘ 3.1: Normalization Step
∘ 3.2: Scale and Shift Step
∘ 3.3: Flow of Batch Normalization
∘ 3.4: Activation Distribution

4: Application From Scratch in Python
∘ 4.1: Batch Normalization From Scratch
∘ 4.2: LSTM From Scratch
∘ 4.3: Data Preparation
∘ 4.4: Training

5: Benefits of Batch Normalization
∘ 5.1: Faster Convergence
∘ 5.2: Increased Stability
∘ 5.3: Reduction in the Need for Careful Initialization
∘ 5.4: Influence on Learning Dynamics

6: Challenges and Considerations
∘ 6.1: Dependency on Batch Size
∘ 6.2: Impact on Training Dynamics
∘ 6.3: Mitigation Strategies

7: Conclusion

1: Introduction

Batch Normalization, developed by Sergey Ioffe and Christian Szegedy in 2015, quickly became essential in Deep Learning. It tackles a common problem called internal covariate shift: the change in the distribution of network activations during training, which can slow down how quickly the network learns.

By normalizing the inputs for each layer, Batch Normalization helps make training deep networks both faster and more consistent. It also makes the network less sensitive to how it’s initially set up.

This article will break down the math and mechanisms of Batch Normalization, examining how it affects training and performance across different network architectures. We’ll cover its foundational concepts, how it’s implemented, and the pros and cons of using it.

2: Overcoming Covariate Shift

Image generated by DALL-E

Normalization is key in neural networks to overcome various training hurdles, especially internal covariate shift, which describes how the distribution of inputs to each network layer changes during training as the parameters of previous layers are adjusted. Such shifts can slow down training because each layer must constantly adapt to new data distributions. It’s like trying to hit a moving target, complicating the training process and often necessitating lower learning rates and meticulous parameter initialization to achieve model convergence.

In the context of neural networks, managing data flow over time presents significant challenges, such as controlling vanishing or exploding gradients. The issues of shifting data distributions further complicate this, potentially destabilizing the network over long training periods. While some deep learning models like LSTMs do employ mechanisms like forget gates to lessen some effects of input shifts, these are specifically tailored to sequence modeling and don’t tackle the fundamental issue of shifts within the network’s hidden layers.

Batch Normalization provides a broader solution that applies not only to LSTMs but to all types of neural network architectures. By normalizing the data at each layer to have a mean of zero and a variance of one, it stabilizes the input distribution across the training process. This consistency enables using higher learning rates, speeding up training without the risk of divergence, and lessening the network’s reliance on precise initial weight settings. This creates a more resilient learning environment where the initial setup is less critical to overall model performance.

Integrating Batch Normalization enhances the stability provided by mechanisms within neural network architectures and can be extended across different layers of a wide range of neural network models. This step is crucial for more efficient and stable neural network training, going beyond just recurrent neural network architectures like LSTMs.

3: Math and Mechanisms

Batch Normalization fundamentally alters the training process to improve convergence speeds and stabilize the learning landscape across network layers.

3.1: Normalization Step

The first step is the actual normalization of the inputs for each mini-batch. For a given feature in a layer, the normalization adjusts the activations so that they have zero mean and unit variance.

For a mini-batch B of size m, suppose you have the activations of a particular feature, denoted x1, x2, …, xm (where xi is the activation of the i-th instance in the batch). The mean μB and variance σ²B of the feature over the mini-batch are computed as:

Mini-Batch Normalization – Image by Author
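Written out, the formulas shown in the image are:

$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_i - \mu_B\right)^2$$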

Each activation xi is then normalized using the computed mean and variance:

Activation Output Normalization – Image by Author
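In formula form, the normalization step shown above is:

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$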

Here, ϵ is a small constant added to the variance to avoid division by zero, often referred to as a numerical stabilizer.

3.2: Scale and Shift Step

After normalization, while the data is standardized, the network might still benefit from adjusting these standardized values to better capture the underlying patterns in the data. This is where scaling and shifting come into play.

The normalized activations x̂i are transformed further using two new parameters: γ (gamma, a scale factor) and β (beta, a shift factor). These parameters are learned during the training process, just like the network’s weights. The transformation is defined by:
Scaling and Shifting Formula – Image by Author
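In symbols, the scale-and-shift step is:

$$y_i = \gamma \hat{x}_i + \beta$$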

In this equation, yi represents the output of the Batch Normalization layer, while γ and β are learned during training alongside the original model parameters. These parameters are crucial as they allow the model to undo the normalization effect if doing so best minimizes the loss.
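In fact, as noted in the original paper, if the network learns

$$\gamma = \sqrt{\sigma_B^2 + \epsilon}, \qquad \beta = \mu_B$$

the layer recovers the original activations exactly, so the normalization never reduces what the layer can represent.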

3.3: Flow of Batch Normalization

Let’s briefly introduce the code for Batch Normalization from scratch to produce this visualization. Don’t worry too much about understanding it yet, as we will walk through it later; it is also a more simplified version than the one we will build later:

import numpy as np
import matplotlib.pyplot as plt

class BatchNorm:
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size
        self.x = None
        # Learnable scale (gamma) and shift (beta) parameters
        self.gamma = np.ones((hidden_size, 1))
        self.beta = np.zeros((hidden_size, 1))

    def forward(self, x):
        self.x = x
        self.mu = np.mean(x, axis=0)
        self.var = np.var(x, axis=0)
        self.x_norm = (x - self.mu) / np.sqrt(self.var + 1e-6)
        out = self.gamma * self.x_norm + self.beta
        return out

    def plot_batch_norm(self):
        # Compute the histograms of the pre-normalized and post-normalized data
        pre_norm_hist, pre_norm_bins = np.histogram(self.x, bins=30)
        post_norm_hist, post_norm_bins = np.histogram(self.x_norm, bins=30)

        # Plot the pre-normalized data
        plt.hist(pre_norm_bins[:-1], pre_norm_bins, weights=pre_norm_hist, alpha=0.5, label='Pre-Normalization')

        # Plot the post-normalized data
        plt.hist(post_norm_bins[:-1], post_norm_bins, weights=post_norm_hist, alpha=0.5, label='Post-Normalization')

        # Add labels, a title, and a legend
        plt.xlabel('Value')
        plt.ylabel('Frequency')
        plt.title('Pre-Normalization vs. Post-Normalization')
        plt.legend()

        # Display the plot
        plt.show()

The class is initialized with a parameter hidden_size which specifies the size of the data it expects to normalize. It sets up two parameters, gamma and beta. gamma is initialized to an array of ones, and beta to an array of zeros, both having a length equal to hidden_size. These parameters are used to scale and shift the normalized data, respectively.

The forward method takes an input array x, which it then normalizes. The method calculates the mean (mu) and variance (var) of the input data x across axis 0 (typically the feature axis in a batch of data). The input data x is normalized using the formula (x - mu) / sqrt(var + 1e-6), where 1e-6 is a small constant added for numerical stability to avoid division by zero. The normalized data (x_norm) is then scaled and shifted using the gamma and beta parameters, resulting in the output out.

Finally, let’s define a method to visualize the effect of batch normalization by comparing histograms of the data before and after normalization. The plot includes labels for the axes (Value and Frequency), a title (Pre-Normalization vs. Post-Normalization), and a legend to identify the histograms.

# Create an instance of BatchNorm with a hypothetical hidden size, for example, 50.
bn_layer = BatchNorm(50)

# Pass some data from a mixture of Gaussians through the BatchNorm layer
# Generate data from two different normal distributions and combine them
data1 = np.random.normal(-2, 1, size=(500, 50))
data2 = np.random.normal(2, 1, size=(500, 50))
data = np.concatenate([data1, data2])

bn_layer.forward(data.T)

# Now, call the visualization methods
bn_layer.plot_batch_norm()

Let’s then use the code above to generate the plot, and we should get something similar to the plot below:

Pre-Normalization vs Post-Normalization chart – Image by Author

Here, the pre-normalization histogram shows the distribution of the data before batch normalization, and the post-normalization histogram shows the distribution of the data after batch normalization.

The data here is the concatenation of two normal distributions, resulting in a bimodal distribution. As you can see, Batch Normalization recenters and rescales the data and can make the distribution more Gaussian-like when the original data is close to Gaussian, but it does not guarantee a perfect Gaussian distribution: the skewness or kurtosis of the original data can still be present after Batch Normalization.

Also, remember that the purpose of Batch Normalization is not to make the data Gaussian, but to stabilize the learning process by ensuring that the inputs to each layer of the network have a consistent mean and variance. Even if the data does not become perfectly Gaussian, Batch Normalization can still be very beneficial for the training of deep neural networks.

3.4: Activation Distribution

Now let’s add a new method to the BatchNorm class, which will plot the activation distribution before and after normalization:

  def plot_activation_distribution(self, activation_function):
        # Compute the activation before batch normalization
        pre_norm_activation = activation_function(self.x)

        # Compute the activation after batch normalization
        post_norm_activation = activation_function(self.x_norm)

        # Compute the histograms of the pre-normalized and post-normalized activations
        pre_norm_hist, pre_norm_bins = np.histogram(pre_norm_activation, bins=30)
        post_norm_hist, post_norm_bins = np.histogram(post_norm_activation, bins=30)

        # Plot the pre-normalized activation
        plt.hist(pre_norm_bins[:-1], pre_norm_bins, weights=pre_norm_hist, alpha=0.5, label='Pre-Normalization')

        # Plot the post-normalized activation
        plt.hist(post_norm_bins[:-1], post_norm_bins, weights=post_norm_hist, alpha=0.5, label='Post-Normalization')

        # Add labels, a title, and a legend
        plt.xlabel('Activation Value')
        plt.ylabel('Frequency')
        plt.title('Activation Distribution: Pre-Normalization vs. Post-Normalization')
        plt.legend()

        # Display the plot
        plt.show()

pre_norm_activation = activation_function(self.x) computes the activation values by applying the provided activation_function – in this case tanh – to the layer’s input data self.x before any normalization has been applied.

post_norm_activation = activation_function(self.x_norm) computes the activations on self.x_norm, the data after batch normalization has been applied. This represents the data after it has been normalized to have specific statistical properties (like zero mean and unit variance).

pre_norm_hist, pre_norm_bins = np.histogram(pre_norm_activation, bins=30) computes a histogram of the pre-normalization activations. It organizes the activation values into 30 bins and counts how many activation values fall into each bin. post_norm_hist, post_norm_bins = np.histogram(post_norm_activation, bins=30) does the same for the post-normalization activations.

# Create an instance of BatchNorm with a hypothetical hidden size, for example, 50.
bn_layer = BatchNorm(50)

bn_layer.forward(data.T)

# Now, call the visualization methods
bn_layer.plot_activation_distribution(np.tanh)  # Use the tanh activation function

This will generate the following plot:

Activation Distribution Plot: Pre-Normalization vs. Post-Normalization – Image by Author

The pre-normalization activation values are spread across a range from approximately -1.00 to 1.00. The distribution is concentrated at the extremes, showing a higher frequency of activation values close to 1 and -1. This suggests that the tanh function is saturating due to the input data’s characteristics (high variance).

The post-normalization activation values have reduced variance and are spread more evenly across the range, resembling a more uniform distribution with far less saturation at the extremes.

4: Application From Scratch in Python

In this section, we will implement Batch Normalization from scratch in Python, leveraging an LSTM architecture – also built from scratch – which we covered in the previous article. Today we will focus on the Batch Normalization code, so the explanation of the LSTM itself will be minimal. However, if you are interested in learning about LSTMs, I highly suggest you take a look at this article:

The Math Behind LSTM

Moreover, to help you navigate through the code, take a look at this notebook, which contains all the code we will use for this implementation:

models-from-scratch-python/Batch Normalization/demo.ipynb at main ·…

4.1: Batch Normalization From Scratch

The BatchNorm class is designed to manage the normalization of activations within a network, and it includes methods for both forward and backward propagation:

4.1.1: Initialization (__init__)

class BatchNorm:
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size
        self.x = None
        self.gamma = np.ones((hidden_size, 1))
        self.beta = np.zeros((hidden_size, 1))

The constructor of the BatchNorm class initializes two parameters, gamma and beta, which are used to scale and shift the normalized data. gamma is initialized to ones, and beta to zeros, ensuring that at the beginning of training, these parameters do not alter the normalized activations.

4.1.2: Forward Pass

During the forward pass, the Batch Normalization layer normalizes its input before the activation function is applied. Here’s how the BatchNorm class handles forward propagation:

class BatchNorm:
    def forward(self, x):
        # Store the input and batch statistics for use in the backward pass
        self.x = x
        self.mu = np.mean(x, axis=0)
        self.var = np.var(x, axis=0)
        self.x_norm = (x - self.mu) / np.sqrt(self.var + 1e-8)
        out = self.gamma * self.x_norm + self.beta
        return out

For each batch of data x, the mean (self.mu) and variance (self.var) are computed. These statistics are used to normalize the data, ensuring each feature has zero mean and unit variance, which is crucial for maintaining consistent training behavior across different layers.

The input data x is normalized using the formula (x - self.mu) / np.sqrt(self.var + 1e-8). Here, 1e-8 is a small number added to prevent division by zero.

The normalized data self.x_norm is then scaled and shifted by gamma and beta. This step allows the network to undo the normalization if it finds that doing so minimizes the loss, providing flexibility in learning data distributions.

4.1.3: Backward Pass

The backward pass involves computing gradients for the input and the parameters (gamma and beta) to allow proper updates during training:

  def backward(self, dout):
        N = dout.shape[0]
        # Gradients with respect to the learnable scale and shift parameters
        dgamma = np.sum(dout * self.x_norm, axis=0)
        dbeta = np.sum(dout, axis=0)

        # Gradient with respect to the normalized input
        dx_norm = dout * self.gamma
        # Gradients with respect to the batch variance and mean
        dvar = np.sum(dx_norm * (self.x - self.mu) * -0.5 * (self.var + 1e-8)**-1.5, axis=0)
        dmu = np.sum(dx_norm * -1 / np.sqrt(self.var + 1e-8), axis=0) + dvar * np.mean(-2 * (self.x - self.mu), axis=0)
        # Gradient with respect to the original input x
        dx = dx_norm / np.sqrt(self.var + 1e-8) + dvar * 2 * (self.x - self.mu) / N + dmu / N
        return dx, dgamma, dbeta

The method receives dout, the gradient of the loss with respect to the output of the BatchNorm layer. It first calculates the gradients with respect to gamma and beta.

It then calculates the gradients with respect to the normalized data, which involves several steps, including computing the gradients for the variance (dvar) and the mean (dmu).

The gradients for the original input x are calculated using the chain rule, considering how each transformation during the forward pass affects the backpropagation.
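For reference, these are the gradient formulas from the original Batch Normalization paper that the code above implements:

$$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\,\gamma$$

$$\frac{\partial L}{\partial \sigma_B^2} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,(x_i - \mu_B)\cdot\left(-\tfrac{1}{2}\right)\left(\sigma_B^2 + \epsilon\right)^{-3/2}$$

$$\frac{\partial L}{\partial \mu_B} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma_B^2 + \epsilon}} + \frac{\partial L}{\partial \sigma_B^2}\cdot\frac{1}{m}\sum_{i=1}^{m} -2\left(x_i - \mu_B\right)$$

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma_B^2 + \epsilon}} + \frac{\partial L}{\partial \sigma_B^2}\cdot\frac{2\left(x_i - \mu_B\right)}{m} + \frac{\partial L}{\partial \mu_B}\cdot\frac{1}{m}$$

$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}\,\hat{x}_i, \qquad \frac{\partial L}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i}$$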

4.2: LSTM From Scratch

Let’s delve deeper into the implementation details and operational mechanics of integrating Batch Normalization into an LSTM architecture. This will involve a thorough breakdown of how Batch Normalization layers are defined, initialized, and used during the forward and backward propagation stages of LSTM processing.

An LSTM unit traditionally consists of three gates: the input gate (i), the forget gate (f), and the output gate (o), along with an additional component to calculate the cell state (c). Each gate in an LSTM has its own set of weights and biases, and operates on the input data (x_t) and the previous hidden state (h_{t-1}).
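For reference, with Batch Normalization inserted before each gate’s activation – exactly as the forward pass below does – the per-timestep update equations become:

$$f_t = \sigma\!\left(\mathrm{BN}\!\left(W_f\,[h_{t-1}, x_t] + b_f\right)\right), \qquad i_t = \sigma\!\left(\mathrm{BN}\!\left(W_i\,[h_{t-1}, x_t] + b_i\right)\right)$$

$$o_t = \sigma\!\left(\mathrm{BN}\!\left(W_o\,[h_{t-1}, x_t] + b_o\right)\right), \qquad \tilde{c}_t = \tanh\!\left(\mathrm{BN}\!\left(W_c\,[h_{t-1}, x_t] + b_c\right)\right)$$

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \qquad h_t = o_t \odot \tanh(c_t)$$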

4.2.1: Initialization

The LSTM class definition includes the instantiation of Batch Normalization layers for each gate within the LSTM. This modification ensures that the data passing through each gate is normalized, which can help manage internal covariate shift and improve training stability and speed. Here’s how these components are typically initialized:

class LSTM:
    def __init__(self, input_size, hidden_size, output_size, init_method='xavier'):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.weight_initializer = WeightInitializer(method=init_method)

        # Initialize weights for the gates
        self.wf = self.weight_initializer.initialize((hidden_size, hidden_size + input_size))
        self.wi = self.weight_initializer.initialize((hidden_size, hidden_size + input_size))
        self.wo = self.weight_initializer.initialize((hidden_size, hidden_size + input_size))
        self.wc = self.weight_initializer.initialize((hidden_size, hidden_size + input_size))
        # Initialize biases for the gates
        self.bf = np.zeros((hidden_size, 1))
        self.bi = np.zeros((hidden_size, 1))
        self.bo = np.zeros((hidden_size, 1))
        self.bc = np.zeros((hidden_size, 1))
        # Initialize Batch Normalization layers
        self.bn_f = BatchNorm(hidden_size)
        self.bn_i = BatchNorm(hidden_size)
        self.bn_o = BatchNorm(hidden_size)
        self.bn_c = BatchNorm(hidden_size)
        # Initialize output layer weights and bias (used by the forward pass below)
        self.why = self.weight_initializer.initialize((output_size, hidden_size))
        self.by = np.zeros((output_size, 1))

Each gate in the LSTM (forget, input, output, and cell candidate) has its weights and biases initialized, typically using techniques like Xavier initialization to help in the proper gradient flow at the beginning of training.
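The WeightInitializer helper lives in the linked notebook; here is a minimal, hypothetical sketch of what it might look like, assuming only the 'xavier' method used above (the real class may support more options):

class WeightInitializer:
    """Hypothetical sketch of the weight initializer assumed above."""
    def __init__(self, method='xavier'):
        self.method = method

    def initialize(self, shape):
        if self.method == 'xavier':
            # Xavier/Glorot initialization: scale by sqrt(2 / (fan_in + fan_out))
            fan_out, fan_in = shape
            return np.random.randn(*shape) * np.sqrt(2.0 / (fan_in + fan_out))
        # Fallback: small random values
        return np.random.randn(*shape) * 0.01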

A Batch Normalization layer is initialized for each gate, which will process the data immediately after the linear transformations and before the activation functions.

4.2.2: Forward Pass

The model processes inputs through time steps using a loop. For each time step:

  • Inputs at the current time step are combined with the previous hidden state.
  • Before applying the non-linear activation function (sigmoid or tanh), the linear combination of inputs and weights plus biases is normalized using the respective Batch Normalization layer. This ensures that the data fed into the activation functions is more stable.
  • The cell state and hidden state are updated based on the outputs from these gates.
  def forward(self, x):
        caches = []
        h_prev = np.zeros((self.hidden_size, 1))
        c_prev = np.zeros((self.hidden_size, 1))
        h = h_prev
        c = c_prev

        for t in range(x.shape[0]):
            x_t = x[t].reshape(-1, 1)
            combined = np.vstack((h_prev, x_t))

            f = self.sigmoid(self.bn_f.forward(np.dot(self.wf, combined) + self.bf))
            i = self.sigmoid(self.bn_i.forward(np.dot(self.wi, combined) + self.bi))
            o = self.sigmoid(self.bn_o.forward(np.dot(self.wo, combined) + self.bo))
            c_ = np.tanh(self.bn_c.forward(np.dot(self.wc, combined) + self.bc))

            c = f * c_prev + i * c_
            h = o * np.tanh(c)

            cache = (h_prev, c_prev, f, i, o, c_, x_t, combined, c, h)
            caches.append(cache)

            h_prev, c_prev = h, c

        y = np.dot(self.why, h) + self.by
        return y, caches

In the forward pass, the Batch Normalization process is integrated right after the linear combination of inputs and before applying the activation function:

      for t in range(x.shape[0]):
          ...
          f = self.sigmoid(self.bn_f.forward(np.dot(self.wf, combined) + self.bf))
          i = self.sigmoid(self.bn_i.forward(np.dot(self.wi, combined) + self.bi))
          o = self.sigmoid(self.bn_o.forward(np.dot(self.wo, combined) + self.bo))
          c_ = self.tanh(self.bn_c.forward(np.dot(self.wc, combined) + self.bc))
          ...

Each gate’s inputs are normalized, which helps in stabilizing the activations across the network, making the network less likely to get stuck during training or suffer from exploding gradients.

4.2.3: Backward Pass

The backward pass involves calculating gradients for all trainable parameters and propagating errors back through the network, a process known as Backpropagation Through Time (BPTT). The presence of Batch Normalization layers adds complexity to this process, as gradients must also be computed for the gamma and beta parameters of each normalization layer.

  def backward(self, dy, caches, clip_value=1.0):
        dWf, dWi, dWo, dWc = [np.zeros_like(w) for w in (self.wf, self.wi, self.wo, self.wc)]
        dbf, dbi, dbo, dbc = [np.zeros_like(b) for b in (self.bf, self.bi, self.bo, self.bc)]
        dWhy = np.zeros_like(self.why)
        dby = np.zeros_like(self.by)

        dgamma_f, dbeta_f = np.zeros_like(self.bn_f.gamma), np.zeros_like(self.bn_f.beta)
        dgamma_i, dbeta_i = np.zeros_like(self.bn_i.gamma), np.zeros_like(self.bn_i.beta)
        dgamma_o, dbeta_o = np.zeros_like(self.bn_o.gamma), np.zeros_like(self.bn_o.beta)
        dgamma_c, dbeta_c = np.zeros_like(self.bn_c.gamma), np.zeros_like(self.bn_c.beta)

        dy = dy.reshape(self.output_size, -1)
        dh_next = np.zeros((self.hidden_size, 1))
        dc_next = np.zeros_like(dh_next)

        for cache in reversed(caches):
            h_prev, c_prev, f, i, o, c_, x_t, combined, c, h = cache

            dh = np.dot(self.why.T, dy) + dh_next
            dc = dc_next + (dh * o * self.dtanh(np.tanh(c)))

            df = dc * c_prev * self.dsigmoid(f)
            di = dc * c_ * self.dsigmoid(i)
            do = dh * self.dtanh(np.tanh(c))
            dc_ = dc * i * self.dtanh(c_)

            df, dgamma_f_, dbeta_f_ = self.bn_f.backward(df)
            di, dgamma_i_, dbeta_i_ = self.bn_i.backward(di)
            do, dgamma_o_, dbeta_o_ = self.bn_o.backward(do)
            dc_, dgamma_c_, dbeta_c_ = self.bn_c.backward(dc_)

            dgamma_f += dgamma_f_
            dbeta_f += dbeta_f_
            dgamma_i += dgamma_i_
            dbeta_i += dbeta_i_
            dgamma_o += dgamma_o_
            dbeta_o += dbeta_o_
            dgamma_c += dgamma_c_
            dbeta_c += dbeta_c_

            dcombined_f = np.dot(self.wf.T, df)
            dcombined_i = np.dot(self.wi.T, di)
            dcombined_o = np.dot(self.wo.T, do)
            dcombined_c = np.dot(self.wc.T, dc_)

            dcombined = dcombined_f + dcombined_i + dcombined_o + dcombined_c
            dh_next = dcombined[:self.hidden_size]
            dc_next = f * dc

            dWf += np.dot(df, combined.T)
            dWi += np.dot(di, combined.T)
            dWo += np.dot(do, combined.T)
            dWc += np.dot(dc_, combined.T)

            dbf += df.sum(axis=1, keepdims=True)
            dbi += di.sum(axis=1, keepdims=True)
            dbo += do.sum(axis=1, keepdims=True)
            dbc += dc_.sum(axis=1, keepdims=True)

        dWhy += np.dot(dy, h.T)
        dby += dy

        gradients = (dWf, dWi, dWo, dWc, dbf, dbi, dbo, dbc, dWhy, dby, dgamma_f, dbeta_f, dgamma_i, dbeta_i, dgamma_o, dbeta_o, dgamma_c, dbeta_c)

        for i in range(len(gradients)):
            np.clip(gradients[i], -clip_value, clip_value, out=gradients[i])

        return gradients

The backward pass of an LSTM with Batch Normalization involves processing through the sequence in reverse, ensuring that the gradient information is accurately propagated back through all transformations.

  def backward(self, dy, caches, clip_value=1.0):
      ...

The backward method starts by initializing gradients for all weights, biases, and Batch Normalization parameters. This initialization sets up containers to accumulate gradients from each step in the sequence.

    dWf, dWi, dWo, dWc = [np.zeros_like(w) for w in (self.wf, self.wi, self.wo, self.wc)]
    dbf, dbi, dbo, dbc = [np.zeros_like(b) for b in (self.bf, self.bi, self.bo, self.bc)]
    dWhy = np.zeros_like(self.why)
    dby = np.zeros_like(self.by)

    dgamma_f, dbeta_f = np.zeros_like(self.bn_f.gamma), np.zeros_like(self.bn_f.beta)
    dgamma_i, dbeta_i = np.zeros_like(self.bn_i.gamma), np.zeros_like(self.bn_i.beta)
    dgamma_o, dbeta_o = np.zeros_like(self.bn_o.gamma), np.zeros_like(self.bn_o.beta)
    dgamma_c, dbeta_c = np.zeros_like(self.bn_c.gamma), np.zeros_like(self.bn_c.beta)

The gradients of the output dy are reshaped to match the output dimensions of the model. This reshaping is essential to ensure that matrix operations during gradient calculations are correctly aligned.

As the method loops through each timestep in reverse order (from the last to the first), it retrieves the cached data for that timestep. This data includes previous states and outputs from the forward pass, which are crucial for calculating gradients accurately.

    for cache in reversed(caches):
        h_prev, c_prev, f, i, o, c_, x_t, combined, c, h = cache

Within each loop iteration, the gradients for the hidden state and the cell state are updated. These calculations push gradients back through the gates using the chain rule, handling the complexities introduced by the Batch Normalization layers.

    dh = np.dot(self.why.T, dy) + dh_next
    dc = dc_next + (dh * o * self.dtanh(np.tanh(c)))

For each gate, the backward method of the respective BatchNorm class is invoked to calculate gradients for the inputs of the Batch Normalization layer, along with gamma and beta. These backward calls are structured as follows:

    df, dgamma_f_, dbeta_f_ = self.bn_f.backward(df)
    di, dgamma_i_, dbeta_i_ = self.bn_i.backward(di)
    do, dgamma_o_, dbeta_o_ = self.bn_o.backward(do)
    dc_, dgamma_c_, dbeta_c_ = self.bn_c.backward(dc_)

Each of these calls not only computes the gradient for the normalized data but also accumulates gradients for gamma and beta, which are parameters specific to the Batch Normalization process. These gradients (dgamma and dbeta) are essential for learning the optimal scale and shift to apply during the normalization step.

The gradients for the model’s weights and biases are updated by accumulating contributions from each timestep. This accumulation is crucial because the LSTM parameters influence the entire sequence, and therefore, the gradients from all timesteps must be considered.

After computing all necessary gradients, the method applies gradient clipping to avoid exploding gradients, a common issue in training RNNs. Clipping is performed before using these gradients for parameter updates:

    for i in range(len(gradients)):
        np.clip(gradients[i], -clip_value, clip_value, out=gradients[i])

4.2.4: Parameter Update Mechanism

Once gradients are computed and clipped, the update_params method applies these gradients to update the model’s parameters. This is where the actual learning step occurs, adjusting the weights and Batch Normalization parameters by a fraction of the gradients scaled by the learning rate.

  def update_params(self, grads, learning_rate):
        dWf, dWi, dWo, dWc, dbf, dbi, dbo, dbc, dWhy, dby, dgamma_f, dbeta_f, dgamma_i, dbeta_i, dgamma_o, dbeta_o, dgamma_c, dbeta_c = grads

        self.wf -= learning_rate * dWf
        self.wi -= learning_rate * dWi
        self.wo -= learning_rate * dWo
        self.wc -= learning_rate * dWc
        self.bf -= learning_rate * dbf

        self.bi -= learning_rate * dbi
        self.bo -= learning_rate * dbo
        self.bc -= learning_rate * dbc

        self.why -= learning_rate * dWhy
        self.by -= learning_rate * dby

        self.bn_f.gamma -= learning_rate * dgamma_f
        self.bn_f.beta -= learning_rate * dbeta_f
        self.bn_i.gamma -= learning_rate * dgamma_i
        self.bn_i.beta -= learning_rate * dbeta_i
        self.bn_o.gamma -= learning_rate * dgamma_o
        self.bn_o.beta -= learning_rate * dbeta_o
        self.bn_c.gamma -= learning_rate * dgamma_c
        self.bn_c.beta -= learning_rate * dbeta_c

The unpacking of gradients ensures each parameter is correctly identified and updated. For each parameter of the model and its Batch Normalization layers, the update rule subtracts the product of the learning rate and the respective gradient from the current parameter value:

self.wf -= learning_rate * dWf
self.bi -= learning_rate * dbi
self.bn_f.gamma -= learning_rate * dgamma_f
self.bn_f.beta -= learning_rate * dbeta_f

This method ensures that all parameters, including those for controlling the gates and those for scaling and shifting in the Batch Normalization process, are updated based on their respective gradients. This comprehensive update allows the LSTM to adapt not just to the error signal but also to the internal variability in data distribution, enhancing the model’s stability and performance over time.

4.3: Data Preparation

As we did in the previous article, we will train the LSTM on a sample of the Google stock dataset, retrieved from Kaggle (free for commercial use), spanning from January 1, 2005, to December 31, 2020, as set in the code below.

# Instantiate the dataset
dataset = TimeSeriesDataset('2005-01-01', '2020-12-31', train_size=0.7)
trainX, trainY, testX, testY = dataset.get_train_test()

# Plot the data
# Combine train and test data
combined = np.concatenate((trainY, testY))

# Reshape input to be [samples, time steps, features]
trainX = np.reshape(trainX, (trainX.shape[0], trainX.shape[1], 1))
testX = np.reshape(testX, (testX.shape[0], testX.shape[1], 1))

The TimeSeriesDataset class is used to load and prepare time series data for the LSTM. It normalizes the data, splits it into training and test sets, and reshapes it into the required format for LSTM processing.
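The actual TimeSeriesDataset class is in the linked notebook; as a rough, hypothetical sketch of what it might do (the CSV file name GOOG.csv and the Date/Close column names are assumptions), it could look like this:

import pandas as pd

class TimeSeriesDataset:
    """Hypothetical sketch: min-max scales a price series and builds look-back windows."""
    def __init__(self, start_date, end_date, train_size=0.7, look_back=1, csv_path='GOOG.csv'):
        df = pd.read_csv(csv_path, parse_dates=['Date'])
        df = df[(df['Date'] >= start_date) & (df['Date'] <= end_date)]
        prices = df['Close'].values.reshape(-1, 1)
        # Min-max normalization to [0, 1]
        self.min_, self.max_ = prices.min(), prices.max()
        scaled = (prices - self.min_) / (self.max_ - self.min_)
        # Chronological train/test split
        split = int(len(scaled) * train_size)
        self.train, self.test = scaled[:split], scaled[split:]
        self.look_back = look_back

    def _make_windows(self, series):
        # Each sample is a window of `look_back` values; the target is the next value
        X, y = [], []
        for i in range(len(series) - self.look_back):
            X.append(series[i:i + self.look_back, 0])
            y.append(series[i + self.look_back, 0])
        return np.array(X), np.array(y)

    def get_train_test(self):
        trainX, trainY = self._make_windows(self.train)
        testX, testY = self._make_windows(self.test)
        return trainX, trainY, testX, testY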

# Plot the data
plt.figure(figsize=(14, 5))
plt.plot(combined, label='Google Stock Price', linewidth=2, color='dodgerblue')
plt.title('Google Stock Price', fontsize=20)
plt.xlabel('Time', fontsize=16)
plt.ylabel('Normalized Stock Price', fontsize=16)
plt.grid(True)
plt.legend(fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()

Using the plotting code above, let’s take a look at the data we will use to train our model:

Google Stock Price – Image by Author


4.4: Training

look_back = 1  # Number of previous time steps to include in each sample
hidden_size = 256  # Number of LSTM units
output_size = 1  # Dimensionality of the output space

lstm = LSTM(input_size=1, hidden_size=hidden_size, output_size=output_size, init_method='xavier')

# Create and train the LSTM using LSTMTrainer
trainer = LSTMTrainer(lstm, learning_rate=1e-3, patience=50, verbose=True, delta=0.001)
trainer.train(trainX, trainY, testX, testY, epochs=10000, batch_size=32)

In the code above, we create an instance of the LSTMTrainer class and train the LSTM model. The trainer handles the training loop, applying forward and backward passes and updating the model parameters. This updated LSTM class contains our custom BatchNorm class, which handles the batch normalization.
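To make the flow concrete, here is a simplified sketch of what a trainer like LSTMTrainer.train might do internally, using only the forward, backward, and update_params methods defined above. The real trainer in the notebook also handles mini-batching, validation, and early stopping (the patience and delta arguments); this assumed version processes one sequence at a time:

def train_simple(lstm, trainX, trainY, epochs=100, learning_rate=1e-3):
    for epoch in range(epochs):
        total_loss = 0.0
        for x_seq, y_true in zip(trainX, trainY):
            # Forward pass through the LSTM with Batch Normalization
            y_pred, caches = lstm.forward(x_seq)
            # Mean squared error and its gradient with respect to the output
            total_loss += float(np.mean((y_pred - y_true) ** 2))
            dy = 2 * (y_pred - y_true)
            # Backpropagation Through Time, then parameter update
            grads = lstm.backward(dy, caches, clip_value=1.0)
            lstm.update_params(grads, learning_rate)
        print(f"Epoch {epoch + 1}: loss = {total_loss / len(trainX):.6f}")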

Now it’s your turn: go ahead and test the code. I suggest you play around with the batch size and further fine-tune the model to improve its performance!

5: Benefits of Batch Normalization

At this point, you may have a good understanding of what the benefits of Batch Normalization are, but let’s outline them formally:

5.1: Faster Convergence

Batch Normalization accelerates the training process by allowing the use of higher learning rates. It achieves this by stabilizing the distribution of the inputs to each layer throughout training. With inputs that have a fixed mean and variance, the network’s weights are less likely to diverge even with higher learning rate settings. This normalization reduces the number of epochs needed to reach convergence, effectively speeding up the training process considerably.

5.2: Increased Stability

By normalizing the inputs across mini-batches, Batch Normalization helps reduce internal covariate shift, the phenomenon where the distribution of layer inputs changes due to changes in network parameters during training. This shift often complicates training by requiring lower learning rates and careful parameter initialization. By ensuring consistent activation distributions, Batch Normalization adds a level of stability to the training process, making the network less sensitive to parameter scale and initialization.

5.3: Reduction in the Need for Careful Initialization

Before the advent of Batch Normalization, careful initialization of network weights was crucial to prevent activations from becoming too large or too small, which could cause vanishing or exploding gradients. This technique alleviates this issue by re-scaling activations, which means that the initial weights no longer need to be precisely tuned. This relaxation in the requirement for initialization precision allows for more flexibility in choosing weight initialization methods, simplifying the setup of neural network training significantly.

5.4: Influence on Learning Dynamics

Batch Normalization not only improves training metrics but also alters the learning dynamics of a neural network. It decouples the parameters of preceding layers from those of the succeeding layers, making each layer more independent of others in terms of learning updates. This independence can lead to better learning patterns as the layers do not need to compensate for the changing scales of inputs during the training process. Additionally, the scale and shift parameters learned alongside the normalized inputs can help the network better model complex relationships in the data by dynamically adjusting the range of activations throughout the network.

6: Challenges and Considerations

Despite its many benefits, Batch Normalization has its challenges and limitations, which need careful consideration.

6.1: Dependency on Batch Size

Batch Normalization’s effectiveness is closely tied to the size of the mini-batches used during training. With batch sizes that are too small, the estimates of the mean and variance used for normalization can become inaccurate, leading to unstable training behavior. Conversely, very large batches require significant memory overhead and can slow down the training process. Finding the right balance in batch size is crucial to leveraging the full potential of Batch Normalization.
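A quick illustration of why small batches are problematic (not from the article’s code, just a standalone numpy experiment): the per-batch mean is a noisy estimate of the true mean, and its noise shrinks roughly as 1/√(batch size).

rng = np.random.default_rng(0)
population = rng.normal(loc=0.0, scale=1.0, size=100_000)

# How noisy is the per-batch mean estimate for different batch sizes?
for batch_size in (4, 32, 256):
    means = [rng.choice(population, batch_size).mean() for _ in range(1000)]
    print(f"batch_size={batch_size:4d}  std of batch means = {np.std(means):.3f}")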

6.2: Impact on Training Dynamics

While Batch Normalization stabilizes training, it can also introduce a certain level of unpredictability in batch statistics, especially when applied to recurrent architectures like LSTMs or in situations where the batch size is small. This unpredictability can sometimes lead to degradation in model performance if not managed correctly.

6.3: Mitigation Strategies

To mitigate these challenges, one approach is to use sufficiently large mini-batches to ensure stable estimates of mean and variance. Alternatively, techniques such as Ghost Batch Normalization, which divides larger batches into smaller "ghost" batches, can be employed to maintain normalization effectiveness without the computational overhead of large batch sizes. Additionally, careful monitoring of training and validation performance should be conducted to adjust hyperparameters like learning rate and batch size dynamically based on the observed training behavior.
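To make the Ghost Batch Normalization idea concrete, here is a minimal sketch; ghost_batch_norm is a hypothetical helper, not part of the article’s code, and it only shows the core trick of computing statistics per sub-batch rather than over the full mini-batch:

def ghost_batch_norm(x, ghost_size, eps=1e-6):
    """Normalize each 'ghost' sub-batch of x independently (rows are samples)."""
    outputs = []
    for start in range(0, x.shape[0], ghost_size):
        chunk = x[start:start + ghost_size]
        # Statistics are computed per ghost batch, not over the whole mini-batch
        mu = chunk.mean(axis=0)
        var = chunk.var(axis=0)
        outputs.append((chunk - mu) / np.sqrt(var + eps))
    return np.concatenate(outputs, axis=0)

# Example: a batch of 256 samples with 50 features, split into ghost batches of 32
x = np.random.randn(256, 50)
x_gbn = ghost_batch_norm(x, ghost_size=32)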

7: Conclusion

Batch Normalization has made a big impact on deep learning. By tackling internal covariate shift, it has made training models more efficient and stable. This breakthrough has also opened up new possibilities for building and using models that handle complex and varied data or have many layers, which used to be really tough to manage.

Essentially, Batch Normalization is a key part of how neural networks learn today. It’s a powerful technique that improves training and performance across many different applications. Its flexibility encourages ongoing experimentation, paving the way for future advancements in artificial intelligence technology.

References

  1. Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. https://arxiv.org/abs/1502.03167
  2. Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. https://www.deeplearningbook.org/
  3. Karpathy, A. CS231n: Convolutional Neural Networks for Visual Recognition. http://cs231n.stanford.edu/
  4. Brownlee, J. (2019). How to Accelerate Learning of Deep Neural Networks With Batch Normalization. https://machinelearningmastery.com/batch-normalization-for-training-of-deep-neural-networks/

You’ve reached the end – well done! I hope you found this article insightful. If you liked it, please consider giving it a thumbs up and following for more content like this. I aim to demystify Machine Learning by recreating popular algorithms from the ground up and making them accessible to everyone. Stay tuned for more updates!

