Understanding What We Lose

How We Tackle Catastrophic Forgetting in LLMs

Matt Tengtrakool
Towards Data Science


Figure 1: The shared experience of forgetting. Image generated by DALL·E, developed by OpenAI.

Forgetting is an intrinsic part of the human experience. We all misplace our keys, stumble on a familiar name, or draw a blank on what we had for dinner a couple of nights ago. But this apparent lapse in our memory isn’t necessarily a failing. Rather, it highlights a sophisticated cognitive mechanism that enables our brains to prioritize, sift through, and manage a deluge of information. Forgetting, paradoxically, is a testament to our ability to learn and remember.

Just as people forget, so do machine learning models, Large Language Models in particular. These models learn by adjusting internal parameters in response to the data they are exposed to. If new data conflicts with what the model has previously learned, it can overwrite or dampen the old information. Even data that agrees with prior training can nudge otherwise well-tuned weights in the wrong direction. This phenomenon, known as “catastrophic forgetting,” is a significant challenge in training stable and versatile artificial intelligence systems.

The Mechanics of Forgetting in LLMs

At its core, an LLM’s memory lies in its weights. Each weight constitutes a dimension in the network’s high-dimensional weight space. As the learning process unfolds, the network navigates this space, guided by gradient descent, in a quest to minimize the loss function.

This loss function, usually a form of cross-entropy loss for classification tasks in LLMs, compares the model’s output distribution to the target distribution. Mathematically, for a target distribution y and model output ŷ, the cross-entropy loss can be expressed as:

L(y, ŷ) = −Σᵢ yᵢ log(ŷᵢ)

During training, the network tweaks its weights to minimize this loss. The central factor governing how much a weight changes on each step is the learning rate. In the stochastic gradient descent update rule:

θ ← θ − η ∇θ L(θ)

η is the learning rate. The choice of this learning rate can be tricky and has direct implications for catastrophic forgetting. If η is high, the model is highly plastic and can rapidly learn new tasks, but it risks losing prior knowledge. A small η preserves old knowledge but might compromise the learning of new tasks.
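
To ground these two formulas, here is a minimal PyTorch sketch of a single training step; the linear layer, the dimensions, and the random batch are toy stand-ins for an LLM’s output head, and the optimizer’s lr argument plays the role of η.

```python
import torch
import torch.nn.functional as F

# Toy stand-in for an LLM's output head: 128-dim hidden states -> 1000-token vocabulary.
model = torch.nn.Linear(128, 1000)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)  # lr is the learning rate η

hidden_states = torch.randn(32, 128)      # a batch of 32 hidden states
targets = torch.randint(0, 1000, (32,))   # the "correct" next tokens y

logits = model(hidden_states)             # model output ŷ (unnormalized)
loss = F.cross_entropy(logits, targets)   # L(y, ŷ) = −Σ yᵢ log(ŷᵢ)

optimizer.zero_grad()
loss.backward()                           # compute ∇θ L(θ)
optimizer.step()                          # θ ← θ − η ∇θ L(θ)
```

Raising lr makes each step larger and the model more plastic; lowering it makes updates, and therefore forgetting, more conservative.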

Moreover, the complexity rises when we realize that weight updates are not independent. Adjusting a weight associated with one feature may inadvertently affect the performance of other features, leading to a complex, tangled web of dependencies.

We must also consider the order in which tasks or data are introduced during training. Presenting tasks sequentially can let later tasks dominate, biasing the model toward whatever it learned most recently, a direct manifestation of catastrophic forgetting.

Strategies to Counter Catastrophic Forgetting

We want our LLMs to remember far more than we can ourselves, so we strive to build systems that manage memory efficiently without being confined to our biological limits. In the quest to combat catastrophic forgetting in LLMs, researchers have developed several innovative strategies. Three of the most prominent are Elastic Weight Consolidation, Progressive Neural Networks, and Optimized Fixed Expansion Layers. Each technique takes a distinct mathematical approach to mitigating the forgetting problem.

Elastic Weight Consolidation (EWC): Remembering the Importance of Each Weight

EWC is inspired by neuroscience and Bayesian inference, and it aims to quantify the importance of each weight to the tasks the model has previously learned. The fundamental idea is that the weights critical to prior tasks should be altered less when new data is encountered.

Figure 2: EWC schematic in parameter space. Source: https://www.pnas.org/doi/full/10.1073/pnas.1611835114

In Figure 2, we can clearly see the pivotal role that Elastic Weight Consolidation plays in preventing catastrophic forgetting when we train on task B, without losing the knowledge we’ve gained from task A. This diagram shows parameter space, with the grey areas signifying optimal performance for task A, and cream-colored regions indicating good performance for task B. After we’ve learned task A, our parameter values are labeled as θ*A.

If we concentrate only on task B and take steps in the direction of its gradient (as shown by the blue arrow), we’ll minimize the loss for task B, but potentially wipe out our knowledge of task A — this is the problem of catastrophic forgetting. On the other hand, if we constrain all weights with the same coefficient (as illustrated by the green arrow), we place a harsh restriction that lets us retain our memory of task A, but makes learning task B difficult.

This is where EWC steps in — it finds the sweet spot by identifying a solution for task B (indicated by the red arrow) that doesn’t drastically impact our knowledge of task A. It accomplishes this by specifically determining the importance of each weight in relation to task A.

EWC introduces a quadratic penalty to the loss function, constraining the modification of important weights. This penalty term is proportional to the square of the difference between the current and initial weight values, scaled by an importance factor. This importance factor, calculated from the Fisher Information Matrix, serves as a heuristic for a weight’s significance to the previously learned tasks.

In Elastic Weight Consolidation, a neural network is first trained on Task A, after which the Fisher Information Matrix (FIM) is computed and saved along with the learned weights. When training the network on Task B, EWC modifies the loss function to include a penalty term, computed using the saved FIM and weights, which discourages drastic changes to the weights critical for Task A, thus balancing learning the new task with preserving knowledge from the previous task. The quadratic nature of the penalty ensures that larger deviations from the initial weights incur a higher penalty. By assigning greater penalties to weights that contribute more to prior tasks, EWC aims to retain their learned knowledge while accommodating new information.
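
As a rough illustration of that recipe, here is a hedged PyTorch sketch; estimate_fisher, ewc_penalty, and lambda_ewc are illustrative names, and the diagonal Fisher estimate (averaged squared gradients over Task A batches) is a common simplification rather than a faithful reproduction of the original implementation.

```python
import torch

def estimate_fisher(model, task_a_loader, loss_fn):
    """Diagonal Fisher estimate: average squared gradients over Task A data."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    for inputs, targets in task_a_loader:
        model.zero_grad()
        loss_fn(model(inputs), targets).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
    return {n: f / len(task_a_loader) for n, f in fisher.items()}

def ewc_penalty(model, fisher, theta_star, lambda_ewc=1000.0):
    """Quadratic EWC penalty: (λ/2) Σᵢ Fᵢ (θᵢ − θ*A,ᵢ)²."""
    penalty = 0.0
    for n, p in model.named_parameters():
        penalty = penalty + (fisher[n] * (p - theta_star[n]) ** 2).sum()
    return 0.5 * lambda_ewc * penalty

# After training on Task A: snapshot the weights and the Fisher information.
#   theta_star = {n: p.detach().clone() for n, p in model.named_parameters()}
#   fisher = estimate_fisher(model, task_a_loader, loss_fn)
# While training on Task B, add the penalty to the usual loss:
#   total_loss = task_b_loss + ewc_penalty(model, fisher, theta_star)
#   total_loss.backward()
```

Weights with large Fisher values are expensive to move, so the optimizer naturally routes learning for Task B through parameters that mattered little for Task A.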

Progressive Neural Networks (ProgNet): Building Neural Network Towers

ProgNets introduce a new architecture that allows the network to expand when encountering new tasks. Instead of altering the weights of a single network, it adds a new network (or column) for each task, stacking these columns akin to building a tower. Each new column is connected to all the previously added columns but not the other way around, preserving the knowledge in the older columns.

Behind ProgNet, each task is learned by a separate column, and the output is a function of the inputs from all previous and current columns. The weights of previous columns remain frozen, preventing any catastrophic forgetting, while the weights of the new column are trained normally.

Figure 3: A block-based ProgNet model. Source: https://arxiv.org/abs/1606.04671

Imagine Progressive Neural Networks as a constellation of separate processing units, each able to discern and harness the most pertinent inputs for the task it is assigned. Consider the example in Figure 3, where output₃ not only interacts with its directly connected hidden layer, h₂, but also interfaces with the h₂ layers of prior columns, modifying their outputs through its own lateral parameters. The output₃ unit evaluates the available signals and strategically omits inputs that are unnecessary. For instance, if h₂¹ encapsulates all the needed information, output₃ may choose to neglect the rest; if both h₂² and h₂³ carry valuable information, output₃ could preferentially focus on these while ignoring h₂¹. These lateral connections let the network manage the flow of information across tasks while excluding irrelevant data.
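
A drastically simplified PyTorch sketch of this column-stacking idea might look like the following; ProgressiveColumn, its single hidden layer, and the toy dimensions are assumptions for illustration rather than the block architecture of the original paper.

```python
import torch
import torch.nn as nn

class ProgressiveColumn(nn.Module):
    """One column of a (much simplified) progressive network: later columns
    receive lateral inputs from the frozen hidden layers of earlier columns."""

    def __init__(self, in_dim, hidden_dim, out_dim, n_prev_columns=0):
        super().__init__()
        self.hidden = nn.Linear(in_dim, hidden_dim)
        # One lateral adapter per previous column's hidden activation.
        self.laterals = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(n_prev_columns)]
        )
        self.out = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, prev_hiddens=()):
        h = torch.relu(self.hidden(x))
        # Add lateral contributions from earlier (frozen) columns.
        for lateral, h_prev in zip(self.laterals, prev_hiddens):
            h = h + lateral(h_prev)
        return self.out(h), h

# Column 1 is trained on task A, then frozen.
col_a = ProgressiveColumn(in_dim=16, hidden_dim=32, out_dim=4)
for p in col_a.parameters():
    p.requires_grad = False

# Column 2 is added for task B and trained normally,
# reading task A's hidden features through its lateral connections.
col_b = ProgressiveColumn(in_dim=16, hidden_dim=32, out_dim=4, n_prev_columns=1)

x = torch.randn(8, 16)
_, h_a = col_a(x)               # frozen features from column A
logits_b, _ = col_b(x, (h_a,))  # task B prediction draws on both columns
```

Because column A’s parameters are frozen, training column B cannot overwrite task A, while the lateral adapters still let task B reuse task A’s features.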

Optimized Fixed Expansion Layers (OFELs): A New Room for Each Task

The concept behind OFELs is like building a new room in a house for each new family member. In the context of neural networks, OFELs add a new layer for each task the LLM encounters. This layer expansion allows the network to accommodate new information without disrupting what it has already learned.

Figure 4: OFEL diagram. Source: https://www.mdpi.com/2073-4425/10/7/553

OFELs involve modifying the architecture of the network itself. For each new task, a new layer is added to the neural network instead of retraining the entire network. This modification encapsulates the knowledge required for the new task within that specific layer, minimizing the impact on the pre-existing weights of the old layers. The new layer’s output can be written as:

h = g(W_old · x_old + W_new · x_new)

where g is the activation function. The architecture of OFELs is designed such that it allows for the inclusion of a new layer dedicated to the new task, which means that the network can process new inputs (x_new) independently of the old inputs (x_old). In essence, while the equation presents a comprehensive view of the underlying process in the architecture, during inference or prediction for a new task, we would typically use only x_new and not require x_old.

By selectively optimizing the new layers, OFELs strike a delicate balance between acquiring knowledge related to the new task and preserving the previously learned information. This meticulous optimization process allows the model to adapt to novel challenges while retaining its ability to leverage prior knowledge, ultimately facilitating more robust and versatile learning.
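
Under those assumptions, a hypothetical PyTorch sketch of the layer-expansion idea could look like this; ExpandableNet, add_task, and the dimensions are illustrative names rather than an official OFEL implementation, and the trunk is assumed to be already trained before new task layers are added.

```python
import torch
import torch.nn as nn

class ExpandableNet(nn.Module):
    """Hypothetical sketch: a frozen, already-trained trunk plus one
    freshly optimized expansion layer per task."""

    def __init__(self, in_dim=16, trunk_dim=64):
        super().__init__()
        self.trunk_dim = trunk_dim
        self.trunk = nn.Sequential(nn.Linear(in_dim, trunk_dim), nn.ReLU())
        self.task_layers = nn.ModuleList()   # one expansion layer per task

    def add_task(self, out_dim):
        # Freeze everything learned so far (trunk and earlier task layers) ...
        for p in self.parameters():
            p.requires_grad = False
        # ... then add a fresh, trainable layer dedicated to the new task.
        layer = nn.Linear(self.trunk_dim, out_dim)
        self.task_layers.append(layer)
        return layer

    def forward(self, x, task_id):
        return self.task_layers[task_id](self.trunk(x))

net = ExpandableNet()
layer_a = net.add_task(out_dim=4)   # expansion layer for task A
layer_b = net.add_task(out_dim=7)   # expansion layer for task B; older weights stay frozen
logits_b = net(torch.randn(8, 16), task_id=1)
```

Only the newest layer receives gradient updates, so the weights serving earlier tasks, and the knowledge they encode, are left untouched.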

Forward Learnings

Forgetting — whether in humans or LLMs — is a fascinating paradox. On one hand, it can be an obstacle to continuous learning and adaptability. On the other, it’s an inherent part of how our brains and AI models manage and prioritize information. Strategies to counter catastrophic forgetting — Elastic Weight Consolidation, Progressive Neural Networks, and Optimized Fixed Expansion Layers — provide insightful yet diverse methodologies to preserve the retention capabilities of Large Language Models. Each offering distinct solutions, they reflect the resourcefulness and adaptability that the field of artificial intelligence must consistently embody. However, it is crucial to understand that the problem of catastrophic forgetting is not fully solved; there are still untapped avenues in this area demanding rigorous exploration, innovation, and creativity.

Addressing the challenge of catastrophic forgetting propels us not just towards more efficient AI systems, but towards a deeper understanding of learning and forgetting, cognitive processes shared by humans and machines alike. It therefore becomes an imperative for researchers, practitioners, and anyone fascinated by the workings of intelligence to contribute to this ongoing dialogue. The quest to tame catastrophic forgetting is not merely an academic pursuit, but a journey that promises to reshape our understanding of intelligence itself.
