The Current State of Continual Learning in AI

Why is ChatGPT only trained up until 2021?

Jon Flynn
Towards Data Science


Image generated by author using DALL-E 3

Knowledge prerequisites:

A couple of years ago, I learned the basics of deep learning through StatQuest videos, Lena Voita’s NLP blogs, and books like “Deep Learning for Coders” and “Talking Nets.” Now I want to understand the current state of continual learning in deep learning. I found that there is not much information that summarises this topic in simple terms; most of it requires sifting through expert research papers. This article is therefore intended for readers who have a basic understanding of deep learning, but who may not be experts and find the research difficult to read. It focuses on chatbots, so knowing the training stages of ChatGPT is also helpful.

Intro

ChatGPT telling the user it is only trained up until September 2021 (screenshot by author)

If large language models like ChatGPT could be continuously updated with new data, they would accelerate a wide range of tasks, from software development to legal processes to learning. It would also make articles like this one obsolete.

Continual learning is the ability to pause the model training process, save the model’s current state, and then later resume training on new data. The model should be able to generalise well to new data, while still maintaining its ability to generalise to old data. Refer to this paper for a more formal definition.

Presently, the industry trend for augmenting chatbots with more data is to use retrieval-augmented generation (RAG), which retrieves relevant documents by vector similarity and inserts them into the prompt, rather than continuing to train the LLM on new data. ChatGPT’s zero-shot learning capability, which allows it to answer questions about new, unseen data, makes this approach very appealing. For instance, you could teach it a new programming language with just a few prompts and then ask it questions about that language, although performance degrades somewhat as the number of input tokens grows. Continually training the model to answer questions about a new topic like this requires significant computing resources and, more importantly, a wide variety of data on that topic. Furthermore, if a topic has very low prevalence in the training set, the model will generalise poorly to it. E.g.: ask it about an unpopular public repo and it will know little about it and may hallucinate, despite having seen it at some point during training. Context windows (the number of tokens the model can take as input) are growing very quickly, making RAG even more attractive. Ideally, though, don’t we want one intelligent, all-knowing model, without the need for any external database?
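To make the retrieval half of RAG concrete, here is a minimal, dependency-free sketch. It rests on loud assumptions: word-count vectors stand in for a learned embedding model, a Python list stands in for the vector database, and `embed`, `retrieve` and `build_prompt` are hypothetical helpers, not any library’s API.

```python
import math
import re
from collections import Counter

def embed(text):
    """Toy 'embedding': word counts. A real system would use a learned embedding model."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[word] for word, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def retrieve(query, documents, k=1):
    """Return the k documents most similar to the query."""
    q = embed(query)
    return sorted(documents, key=lambda d: cosine(q, embed(d)), reverse=True)[:k]

def build_prompt(query, documents, k=1):
    """Prepend the retrieved context to the question (the prompt-engineering part)."""
    context = "\n".join(retrieve(query, documents, k))
    return f"Answer using only this context:\n{context}\n\nQuestion: {query}"

docs = [
    "Quill is a hypothetical programming language whose variables are immutable by default.",
    "The Eiffel Tower is located in Paris.",
]
print(build_prompt("Are variables in Quill immutable?", docs))
```

The model’s weights never change here: all the new knowledge arrives through the prompt, which is exactly why the approach is cheap and why it competes for space in the context window.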

Continual learning is an essential step towards AGI, and some doubt we will even be able to achieve it without significant changes to deep learning network architectures. In his book “A Thousand Brains”, Jeff Hawkins states that he does not think current ANNs are capable of effective continual learning, and believes future models will probably need to be architected more like the human brain, using his theory of reference frames in the cortical columns of the neocortex.

Continual Learning in the pre-training vs fine-tuning stages of language models

Earlier this year, a research paper called “LIMA: Less Is More for Alignment” was published. It introduced a chatbot that was not trained using Reinforcement Learning from Human Feedback (RLHF), but was instead fine-tuned on just 1,000 carefully curated question-and-answer samples. Surprisingly, the researchers reported that in 43% of cases, “the chatbot’s responses were on par with those of GPT-4”. I did not take an in-depth look at how these were evaluated, but it’s widely acknowledged that a substantial amount of a model’s knowledge and capability is acquired during the pre-training phase, and research like this further supports that.

Models like ChatGPT and Llama-chat have undergone extensive fine-tuning to generate more aligned and effective responses. OpenAI currently offers an API to further fine-tune a model, which takes Q&A data as input for further training. However, this should not be used to teach the model new knowledge, but rather to customise its tone and steerability. Fine-tuning a model in an attempt to teach it new knowledge can cause catastrophic forgetting, a problem where the model forgets what it has already learned. This article will go over some techniques that aim to mitigate this problem.

This also leads us to a couple of key questions about the feasibility and strategy of continual learning:

  • At which stage of development is it most beneficial and easiest to introduce continual learning?
  • Given that both fine-tuning and RLHF alter the entire model’s parameters, is it even possible to revert to the pre-training stage for further modification?

Note: I provide some PyTorch-like pseudocode for some of the papers discussed below. It has not been tested and may not work; it’s there to break the techniques down step by step and to translate any confusing math notation, to help the reader understand.

The 5 sub-categories of continual learning techniques

The comprehensive overview of continual learning paper states that training strategies for continual learning can be divided into 5 sub-categories:

  1. Regularisation-based approach: this approach adds constraints or penalties to the training process, typically as extra terms in the loss, to limit how much previously learned knowledge can change.
  2. Optimisation-based approach: this technique focuses on modifying the optimisation algorithm.
  3. Representation-based approach: this aims to learn a shared feature representation across different tasks, helping the model generalise better to new but related tasks.
  4. Replay-based approach: this involves storing some data or learned features from previous tasks and replaying them during training on new tasks to maintain performance on earlier learned tasks. In other words, mixing both the old and new datasets when training on new tasks.
  5. Architecture-based approach: in this approach, the network architecture is dynamically adjusted, often by growing or partitioning, delegating different parts of the network to different tasks.

1. Regularisation-based approaches

Soft Masking of Parameters

The following soft-masking techniques mask and adjust the gradients of each parameter during the training process. The optimisation-based approaches coming up also manipulate the gradients for continual learning. Remember the gradients aren’t just temporary numbers that appear and disappear during training; they’re signals that guide the evolution of the weights.

SPG

This paper proposes a technique named SPG (Soft-masking of Parameter-level Gradient flow) which aims to:

  1. Train the model on each task until convergence.
  2. After training, calculate the “importance” of each parameter for the task.
  3. Soft-mask parameters based on their accumulated importance, making important parameters less likely to change during the learning of new tasks.

Let’s break the approach down step by step:

1. Training the First Task

Train the model on the first task’s dataset as normal.

2. Calculate Parameter Importance for the First Task

After training on the first task is complete, we calculate the importance of each model parameter. The intuition here is simple: we use the gradient of each parameter to compute its importance. A larger gradient implies that a small change in that parameter will result in a larger change in the loss, meaning the model’s performance could vary more significantly; hence that parameter is important.

The gradients are also normalised, because gradients in the first layer could be small, while those in the last layer could be large. If you’re calculating importance based on these raw gradient values, parameters in the last layer would seem more important because of the scale of their gradients, not necessarily because they are genuinely more crucial for the task.

Equations for calculating the importance of the model parameters in SPG (section 3.1 of paper)

Let’s translate this calculation to PyTorch-like pseudocode:

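import torch

def compute_final_importance(model, loss_function, data_loader):
    # Get a single batch from the data loader (more of the task's data would give a less noisy estimate)
    inputs, labels = next(iter(data_loader))

    # Forward and backward pass to calculate the gradients for all parameters
    model.zero_grad()
    outputs = model(inputs)
    loss = loss_function(outputs, labels)
    loss.backward()

    importances = []

    # Calculate importance based on the gradients, one tensor per parameter (layer)
    for param in model.parameters():
        if param.grad is None:  # Gradients may be None for some unused parameters
            importances.append(torch.zeros_like(param))
            continue
        # Normalise within the layer so that layers with larger gradients don't dominate
        normalized_grad = (param.grad - param.grad.mean()) / (param.grad.std() + 1e-8)
        # Assumption: abs keeps importance in [0, 1), so (1 - importance) is a valid mask
        importance = torch.abs(torch.tanh(normalized_grad))
        importances.append(importance)

    # Parameters have different shapes, so return a list instead of stacking into one tensor
    return importances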

3. Accumulating Importance Across Tasks

The accumulated importance of each parameter across tasks is simply the maximum of its importance over all tasks seen so far (an element-wise max).
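Here is a tiny sketch of the accumulation and of how it feeds the soft mask, using plain Python lists in place of per-parameter tensors. The importance values are made up, and they are assumed to already lie in [0, 1].

```python
def accumulate_importance(accumulated, task_importance):
    """Element-wise max: a parameter stays as important as it has ever been."""
    if accumulated is None:
        return list(task_importance)
    return [max(a, t) for a, t in zip(accumulated, task_importance)]

def soft_mask(grads, accumulated):
    """Scale each gradient by (1 - importance) before the optimiser step."""
    return [g * (1.0 - imp) for g, imp in zip(grads, accumulated)]

acc = None
acc = accumulate_importance(acc, [0.9, 0.1, 0.0])  # importance after task 1
acc = accumulate_importance(acc, [0.2, 0.6, 0.0])  # importance after task 2
print(acc)  # [0.9, 0.6, 0.0]

# Parameter 0 (very important for task 1) barely moves; parameter 2 is free to change
print(soft_mask([1.0, 1.0, 1.0], acc))
```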

4. Training Subsequent Tasks, combined loss and the soft-masking mechanism:

When training on new tasks, the researchers use a combined loss function consisting of two parts. One is the standard loss function, used as normal on the new task and data; the second is an additional loss which involves putting the new data through the old model (the converged model checkpoint after the previous task) and summing up the logits produced. In classification networks, the logits are the raw, unnormalised predictions produced by the final layer, before going through something like a softmax function. This sum of logits serves as a form of loss. The rationale is that if the summed logits change significantly when the model parameters change, those parameters are crucial for the performance of the previously learned task.

The gradients generated from this additional loss serve as a guide during backpropagation, nudging the shared parameters to change in a direction that is less likely to harm performance on the first task. It therefore acts as a sort of penalty term to enforce that any updates made to the model do not lead to a significant loss of information related to previous tasks.

Train the model on the next task. Use a standard training loop, but modify the gradients during backpropagation based on their accumulated importance. This is the soft-masking mechanism:

import torch
import torch.nn as nn

# accumulated_importance: list of tensors, one per parameter, calculated at the end of each
# task (e.g. with compute_final_importance) and combined across tasks with an element-wise max
# Assumes model(x, head_id) returns the logits of the given task's classification head
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    for x, y in train_loader:
        optimizer.zero_grad()

        # Forward Pass: Calculate the loss for the current task using the proper loss function
        logits = model(x, task_id)
        loss_current_task = criterion(logits, y)

        # Forward Pass: Calculate the additional losses for previous tasks (CHI mechanism)
        # The previous tasks' heads are run on the model being trained: a separate, frozen
        # copy of the old model would contribute no gradient to the current parameters
        loss_previous_tasks = 0
        for prev_task_id in range(task_id):
            logits_prev = model(x, prev_task_id)
            loss_previous_tasks += logits_prev.sum()

        # Combine the losses
        combined_loss = loss_current_task + loss_previous_tasks

        # Backward Pass
        combined_loss.backward()

        # Soft-masking the gradients before taking an optimization step
        with torch.no_grad():
            for param, imp in zip(model.parameters(), accumulated_importance):
                if param.grad is not None:
                    param.grad *= (1 - imp)

        optimizer.step()

5. Soft-Masking Special Cases

  • Feature Extractor: Gradients of parameters in the shared feature extractor are modified based on their specific accumulated importance.
  • Classification Head: For the classification head, gradients are modified based on the average importance of the feature extractor.

Applying this to LLMs

Bear in mind, this paper does not experiment with a language model, but I assume that in a language model you could think of the transformer layers as analogous to the “feature extractor,” and the final output layer (which predicts the next word or token in the sequence) as the “classification head.”

Soft-masking applied to continual pre-training in a language model

Next we’ll go into a paper which applies similar soft-masking to the pre-training stage in language modelling.

This paper introduces a technique called DAS (Continual DA-pre-training of LMs with Soft-masking) for continual learning in the pre-training stage of a large language model. It applies a soft-masking technique similar to the one just discussed, along with a couple of other techniques, in an attempt to continue pre-training an LLM without running into catastrophic forgetting.

Let’s break it down step by step:

Initial Pre-training Phase

Pre-train the LLM like normal.

Further Pre-training on A New Domain

Prepare New Domain Data:

A new dataset from a different domain is prepared.

Calculating the importance of each neuron

SPG used gradients to determine the importance of each parameter, and then applied the calculated importance value to mask the gradient adjustments of parameters during training. This paper tries to determine the importance of each unit/neuron, rather than parameter, and then uses this in the same way by masking the gradient during training.

This paper uses two different methods to calculate the importance of neurons, depending on the task at hand. One, a gradient-based importance detection method (originally outlined in this paper), and two, a custom “proxy loss function”.

The first method is not used for the continual learning of the first new domain. Why? It needs data from the training dataset to work, and the authors state that users “don’t have access to the massive original pre-training dataset”, which is a fair assumption. The proxy loss function is used instead for the first phase of continual learning, and the gradient-based method is used for each subsequent phase.

The proxy loss function (“Proxy KL-divergence loss”):

I found this term confusing at first. It’s called a proxy because it stands in for the loss that the original gradient-based importance detection method is defined with: you run the network’s outputs through it and backpropagate to get the gradients of each neuron, from which importance is derived, just like in the SPG technique. It’s calculated as follows:

  • Take a subset of the new domain data we want to train on and feed it through the model twice to get two different representations. These representations will differ slightly because dropout in the Transformer architecture samples a different mask on each pass.
  • Compute the KL-divergence between these two representations.
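Here is a toy, dependency-free sketch of the proxy loss. Assumptions: a one-layer model with dropout stands in for the Transformer, and its outputs go through a softmax so the two representations are probability distributions that KL-divergence can compare. In DAS, this loss would then be backpropagated to get each unit’s gradient and importance; that step is omitted here.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def forward_with_dropout(x, weights, p_drop, rng):
    """Hypothetical one-layer model: inverted dropout on the input, a linear map, then softmax."""
    kept = [xi / (1 - p_drop) if rng.random() >= p_drop else 0.0 for xi in x]
    logits = [sum(w * k for w, k in zip(row, kept)) for row in weights]
    return softmax(logits)

rng = random.Random(0)
x = [0.5, -1.2, 0.3, 0.8]  # one made-up input from the new domain
weights = [[0.2, -0.1, 0.4, 0.3],
           [-0.3, 0.5, 0.1, -0.2],
           [0.1, 0.1, -0.4, 0.6]]

# Two passes over the same input: only the dropout masks differ
rep_a = forward_with_dropout(x, weights, p_drop=0.5, rng=rng)
rep_b = forward_with_dropout(x, weights, p_drop=0.5, rng=rng)
proxy_loss = kl_divergence(rep_a, rep_b)
print(proxy_loss)
```

No labels are needed at any point, which is what makes this usable right after pre-training, when the original dataset is out of reach.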

Modified Backpropagation Flow with Proxy and Combined Loss

  1. Forward Pass: Data goes through a forward pass in the neural network.
  2. Backpropagation:

Apply Proxy Loss for Gradient Adjustment: The proxy loss function’s unit-level importance is used to soft-mask the original gradients. This is expressed as:

adjusted_grad *= (1 − unit_level_importance)

Calculate Combined Loss (MLM + Contrastive Loss): Compute the combined loss using both MLM and contrastive loss.
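The difference from SPG is granularity: one importance value per unit instead of one per parameter. A minimal sketch of that unit-level soft mask with plain lists, assuming the rows of a layer’s weight gradient correspond to its output units (all numbers are made up):

```python
def soft_mask_units(weight_grad, unit_importance):
    """Scale every gradient in row i (the weights feeding unit i) by (1 - importance[i])."""
    return [[g * (1.0 - imp) for g in row]
            for row, imp in zip(weight_grad, unit_importance)]

# Hypothetical gradient of one layer's weight matrix: one row per output unit
weight_grad = [
    [0.5, -0.2, 0.1],
    [0.3, 0.4, -0.6],
]
unit_importance = [1.0, 0.25]  # unit 0 is fully protected, unit 1 is mostly free

masked = soft_mask_units(weight_grad, unit_importance)
print(masked)
```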

Further Pre-training on More Domains

  1. Direct Importance Calculation: For each new domain, the importance of each unit can now be calculated directly from the domain data using the gradient-based method outlined in equation 3 of the paper, so the proxy loss function is only needed once, for the first domain after the initial pre-training.
  2. The importance of neurons is updated incrementally as each new task is learned. This update is done using element-wise max. “Element-wise maximum (EMax) operation” refers to comparing two vectors element by element, and taking the maximum value for each corresponding element to create a new vector. E.g.: if you have two vectors A and B of the same length, the element-wise maximum will result in a new vector C where each element C[i] is the maximum between A[i] and B[i].

2. Optimisation-based approaches

We’ll look at the two techniques outlined in section 3.1 of the comprehensive survey paper.

Gradient Direction Preservation

The paper talks about manipulating the gradient-based optimisation process to make the gradient directions of new training samples close to those from old training samples. The formula

⟨ ∇θ Lₖ(θ; Dₖ), ∇θ Lₖ(θ; Mₜ) ⟩ ≥ 0

enforces that learning the new task should not increase the loss for the old tasks. Essentially, the gradients of the new task and the old tasks are encouraged to align.

Breaking down the formula: the dot product of the gradient of the loss on the new task (∇θ Lₖ(θ; Dₖ)) and the gradient of the loss on the stored old-task samples (∇θ Lₖ(θ; Mₜ)) should be non-negative. A non-negative dot product implies that the gradients for the old task and the new task point in generally the same direction, i.e. the angle between the two vectors is less than or equal to 90 degrees.
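To make the dot-product condition concrete, here is a dependency-free sketch with made-up 2-D gradient vectors; real gradients have millions of entries, but the geometry is the same.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def angle_degrees(u, v):
    """Angle between two gradient vectors, in degrees."""
    cos = dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))
    return math.degrees(math.acos(max(-1.0, min(1.0, cos))))

grad_new = [1.0, 2.0]      # hypothetical gradient on the new task's data Dₖ
grad_old = [2.0, -0.5]     # hypothetical gradient on the stored old samples Mₜ
grad_clash = [-1.0, -0.2]  # a gradient pointing against grad_new

for name, g in [("old", grad_old), ("clash", grad_clash)]:
    d = dot(grad_new, g)
    print(name, d, round(angle_degrees(grad_new, g), 1), "ok" if d >= 0 else "violates constraint")
```

The first pair has a dot product of 1.0 (angle under 90 degrees), so a step along `grad_new` should not increase the old-task loss to first order; the second pair has a negative dot product and would need its gradient modified.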

Forward/Backward Passes:

Forward Pass:

You would run your input data Dₖ for the new task and Mₜ for the old task through the same model to calculate the loss for each.

Backward Pass:

  1. Compute the gradients of the loss with respect to the network parameters for both the old and new task.
  2. Alignment Check: Compute the dot product of the two gradients. You’d then use this information to modify the gradients for the new task in such a way that the dot product is non-negative.
  3. Update Weights: Update the model parameters using these “aligned” gradients.

import torch

# Assumes model, criterion, optimizer, a new-task batch (D_k, y_k)
# and a batch of stored old-task samples (M_t, y_t) are already defined

def flat_grad(model):
    # Concatenate every parameter's gradient into one long vector
    return torch.cat([p.grad.reshape(-1) for p in model.parameters()])

# Forward and backward pass for the new task
optimizer.zero_grad()
loss_k = criterion(model(D_k), y_k)
loss_k.backward()
grad_k = flat_grad(model).clone()

# Forward and backward pass for the old task's stored samples
optimizer.zero_grad()
loss_t = criterion(model(M_t), y_t)
loss_t.backward()
grad_t = flat_grad(model).clone()

# Compute dot product and modify gradients if they don't align
dot_product = torch.dot(grad_k, grad_t)
if dot_product < 0:
    # The survey doesn't pin down one rule here; this projection is the one used by A-GEM:
    # remove the component of grad_k that points against grad_t, which brings
    # the dot product with grad_t to zero
    grad_k = grad_k - (dot_product / torch.dot(grad_t, grad_t)) * grad_t

# Write the (possibly projected) gradient back into .grad and update the parameters
index = 0
for p in model.parameters():
    num_params = p.numel()
    p.grad.copy_(grad_k[index: index + num_params].view_as(p))
    index += num_params

optimizer.step()

Gradient Direction Preservation without needing old training samples

The survey also highlights that gradient projection can be performed even without storing old samples. NCL (Natural continual learning, paper link) is the technique summarised here. Note that it can be categorised as both a regularisation-based and an optimisation-based approach.

Training process step by step:

Forward Pass:

You would run your new data through the network and calculate the loss as usual.

Backward Pass:

Objective: The aim is to minimise the task-specific loss ℓk(θ) while adhering to a distance constraint d(θ,θ+δ)≤r.

Algorithm step by step:

  1. As normal, compute the gradient of the loss with respect to the model parameters, ∇θ ℓk(θ).
  2. The δ is calculated using the update rule. This gives you the “suggested” changes to the model parameters θ based on the new task’s requirements.
  3. Then, you plug this δ into the distance formula: d(θ, θ + δ) = √(δᵀ Λₖ₋₁ δ). The constraint acts like a boundary around the current parameters θ, defined by the distance metric d(θ, θ + δ) and the radius r. I struggled to see why they called it a “radius”, and not just a “constraint number” or something. I think it’s because the researchers are visualising the gradients and training process in a high-dimensional space. When you apply a constraint based on the distance metric, you’re essentially defining a “sphere” around your current parameter values in that high-dimensional space (strictly an ellipsoid, because Λₖ₋₁ stretches some directions more than others). The “radius” r of this sphere sets a limit on how much the parameters can move while learning a new task.
  4. If the proposed δ would move θ too far according to this distance metric, i.e., beyond this boundary, you scale it down so that it stays within the allowable region defined by the radius r.

Let’s look at each bit more in-depth:

Update Rule: The update rule provides a direction in which θ should move.

NCL update rule from section 3.1 in the comprehensive overview of continual learning paper
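Written out, and assembled from the components broken down below, the update rule and the distance constraint read roughly as follows. The leading minus sign (a descent step) and the exact grouping are assumptions to check against section 3.1 of the survey.

```latex
\delta = -\lambda \left[ \Lambda_{k-1}^{-1} \nabla_\theta \ell_k(\theta) + (\theta - \mu_{k-1}) \right],
\qquad
d(\theta, \theta + \delta) = \sqrt{\delta^\top \Lambda_{k-1} \delta} \le r
```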

Breaking it down:

  • ∇θ ℓk(θ) represents the gradients for all parameters (θ) calculated by the loss function.
  • Parameter importance (Λₖ₋₁⁻¹): Λₖ₋₁ is a precision matrix, yet another way to capture the importance of parameters in the network; the update uses its inverse. More details below.
  • Regularisation Term (θ − μₖ₋₁): This term pulls the updated parameters closer to the optimal parameters μₖ₋₁ from the previous task. Like the earlier techniques, it acts as a regulariser to avoid deviating from what was already learned.
  • Learning Rate (λ)

Distance Constraint: Before applying this update, you’d usually check whether this change δ would violate the distance constraint d(θ,θ+δ)≤r. If it does, you’d typically scale down δ so that it satisfies the constraint.
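To see the update and the constraint working together, here is a dependency-free sketch. The big assumption is that Λₖ₋₁ is diagonal (one precision value per parameter), which turns the matrix inverse into a per-parameter division. `ncl_step` is a hypothetical helper and the numbers are made up.

```python
import math

def ncl_step(theta, grad, mu_prev, precision_diag, lr, radius):
    """One NCL-style step, assuming a diagonal precision matrix Λ_{k-1}.

    δ = -lr * (Λ^{-1} ∇ℓ + (θ - μ)); δ is shrunk if sqrt(δᵀ Λ δ) exceeds the radius.
    """
    delta = [-lr * (g / lam + (t - m))
             for g, lam, t, m in zip(grad, precision_diag, theta, mu_prev)]
    # Distance measured with the precision matrix as the metric
    dist = math.sqrt(sum(lam * d * d for lam, d in zip(precision_diag, delta)))
    if dist > radius:
        delta = [d * radius / dist for d in delta]  # scale back onto the boundary
    return [t + d for t, d in zip(theta, delta)]

theta = [0.5, -0.3]
mu_prev = [0.4, -0.3]         # parameters after the previous task
precision_diag = [10.0, 0.1]  # parameter 0 mattered a lot for old tasks, parameter 1 barely
grad = [1.0, 1.0]             # same gradient for both, to isolate the effect of Λ

loose = ncl_step(theta, grad, mu_prev, precision_diag, lr=0.1, radius=0.5)  # constraint inactive
tight = ncl_step(theta, grad, mu_prev, precision_diag, lr=0.1, radius=0.2)  # constraint active
print(loose, tight)
```

In the unconstrained step, the low-precision parameter moves about 50 times further than the high-precision one even though both see the same gradient; with the tighter radius, the whole step is shrunk so that it lands on the boundary.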

Precision matrix explanation: in the soft-masking methods, we saw importance calculated from the outputs of neurons or from their gradients. In this method, a precision matrix is used instead. This is a bit more complex, so I’ll attempt to explain it:

We first calculate a covariance matrix from the gradients of the network’s parameters. In the context of neural networks, the columns of the gradient matrix G correspond to the parameters (weights and biases) of the model, and each row represents the gradient vector for a single training example with respect to all of those parameters.

So, if you have a neural network with P parameters (this includes all the weights and biases from all layers), then each gradient vector will have P elements, one for each parameter. Therefore, G will be a matrix of shape N × P, where N is the number of training examples. (If you work per batch instead, each row holds the average gradient over one batch and N is the number of batches.)

When you calculate the covariance matrix Σ from G, the resulting matrix will have dimensions P × P. The diagonal entries Σᵢᵢ indicate the variance of the gradient with respect to the ith parameter, and the off-diagonal entries Σᵢⱼ indicate the covariance between the gradients with respect to the ith and jth parameters. This gives you an idea of how these parameters interact or co-vary during the training process. The inverse of a covariance matrix is a precision matrix, which is what we use to determine importance. (Strictly speaking, NCL-style methods approximate the precision Λ directly with the Fisher information, an average of gradient outer products, rather than by inverting a gradient covariance; the interpretation of the precision matrix is the same either way.)

Why the precision matrix over the covariance matrix? While a covariance matrix does capture how parameters interact with each other during training, it doesn’t specifically indicate how crucial each parameter is to the task at hand when all other parameters are considered. In contrast, the precision matrix allows us to assess the conditional independence (a concept from probability theory worth looking up) of parameters: a zero off-diagonal entry means two parameters are conditionally independent given all the others, and large values (in magnitude) indicate that knowing one parameter is highly informative about another, given all the other parameters. I’m not going to go into neural-network examples here, so get ChatGPT to generate some using a very small neural network to see how the values can be interpreted.
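Conditional independence is easier to see than to describe, so here is a small dependency-free example. It uses a hypothetical three-variable Gaussian rather than a neural network, and `invert` is a helper written for this sketch (in practice you would use a linear-algebra library).

```python
def invert(matrix):
    """Gauss-Jordan inversion with partial pivoting, for small matrices."""
    n = len(matrix)
    aug = [[float(v) for v in row] + [1.0 if i == j else 0.0 for j in range(n)]
           for i, row in enumerate(matrix)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        pivot_val = aug[col][col]
        aug[col] = [v / pivot_val for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [v - factor * p for v, p in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]

# Toy chain of three variables: x2 = x1 + noise, x3 = x2 + noise (all unit variance).
# x1 and x3 are correlated, but only *through* x2.
covariance = [
    [1.0, 1.0, 1.0],
    [1.0, 2.0, 2.0],
    [1.0, 2.0, 3.0],
]
precision = invert(covariance)
for row in precision:
    print([round(v, 6) for v in row])
# The exact inverse is [[2, -1, 0], [-1, 2, -1], [0, -1, 1]]:
# covariance[0][2] = 1, yet precision[0][2] = 0, because x1 and x3 are
# conditionally independent once x2 is known.
```

The covariance matrix says x1 and x3 move together; the precision matrix says that, once x2 is accounted for, x1 tells you nothing extra about x3.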

The methods we saw previously calculate importance for individual neurons or parameters, ignoring the relationships between them. The precision matrix, on the other hand, can capture these relationships. Like everything in deep learning, whether this is a better way to calculate the importance of a network’s parameters is an empirical question, and the answer could differ depending on the task and the scale of the network.

Algorithm step by step in PyTorch:

import torch

# Assumed to exist: model, data_loader, loss_function, num_epochs, and from the previous
# task: precision_matrix (Λ_{k-1}, shape P x P) and params_prev_task (μ_{k-1}, shape P)
lr = 0.01      # learning rate λ
radius = 0.1   # constraint radius r

for epoch in range(num_epochs):
    for data, target in data_loader:
        model.zero_grad()

        # Forward pass
        output = model(data)
        loss = loss_function(output, target)

        # Backward pass to get gradients for params, flattened into one vector
        loss.backward()
        grad = torch.cat([p.grad.reshape(-1) for p in model.parameters()])
        theta = torch.cat([p.detach().reshape(-1) for p in model.parameters()])

        # Compute δ using the NCL method: δ = -λ [Λ^(-1) ∇ℓ + (θ - µ)]
        # torch.linalg.solve computes Λ^(-1) ∇ℓ without forming the inverse explicitly
        delta = -lr * (torch.linalg.solve(precision_matrix, grad) + (theta - params_prev_task))

        # Check constraint: d(θ, θ + δ) = sqrt(δᵀ Λ δ) must not exceed the radius
        dist = torch.sqrt(delta @ precision_matrix @ delta)
        if dist > radius:
            delta = delta * (radius / dist)

        # Update model parameters (θ) using δ
        with torch.no_grad():
            idx = 0
            for p in model.parameters():
                length = p.numel()
                p.add_(delta[idx: idx + length].view_as(p))
                idx += length

# Update Λ and µ for the next task, probably going to be task-specific and non-trivial

3. Representation-based approach

Firstly, it’s important to note that pre-training an LLM so that it can be further fine-tuned on a downstream task is an example of continual learning in this sub-category. I think ChatGPT’s ability to reason about never-before-seen data is also an example of this approach. Although we technically call it zero-shot learning, and the term “continual learning” requires updating model parameters, it goes beyond anything we’ve seen before. As discussed in the introduction, prompt engineering could be the future of continual learning, instead of continually updating the parameters.

Below we’ll take a look at using knowledge distillation for continual learning. I’m not really sure which sub-category this falls under, but I’d guess it’s probably a mix between representation, architecture and replay approaches. Even though some of the techniques we’re reviewing may seem random and unproven at large scale, breakthroughs in this field are often unpredictable. Therefore, it’s important to maintain a broad perspective.

Knowledge Distillation for continual learning

You can transfer (or “distill”) the knowledge of one network into another network, and the second network does a reasonable job of approximating the function learned by the original network.

The distilled model (the student) is trained to mimic the output of the larger network (the teacher), instead of being trained on the raw data directly. For example, say you want to train a smaller student model to mimic a large pre-trained language model (the teacher). Run the original pre-training dataset through the teacher model to generate “soft targets”: probability distributions over potential outputs, i.e. next-word predictions. For instance, instead of just the label “cat”, the teacher might provide probabilities like 90% for “cat”, 5% for “kitten”, 3% for “feline”, etc.

This is usually done to transfer knowledge to much smaller models, and the student often retains much of the teacher’s performance despite its smaller size.
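A minimal, dependency-free sketch of soft targets and the distillation loss for a single next-word position follows. The logits are made up, and the temperature T is an extra knob from the standard distillation recipe rather than something described above; with T = 1 the soft targets are simply the teacher’s softmax outputs.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q): how far the student's distribution q is from the teacher's p."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

vocab = ["cat", "kitten", "feline", "dog"]
teacher_logits = [4.0, 1.1, 0.6, -1.0]  # hypothetical teacher outputs for one position
student_logits = [2.0, 1.5, 0.2, 0.1]   # hypothetical student outputs for the same position

T = 2.0  # temperature above 1 softens both distributions
soft_targets = softmax(teacher_logits, T)
student_probs = softmax(student_logits, T)
distill_loss = kl_divergence(soft_targets, student_probs)

for word, p in zip(vocab, soft_targets):
    print(f"{word}: {p:.2f}")
print("distillation loss:", distill_loss)
```

Training the student means minimising `distill_loss` over many positions, so it learns not just that “cat” is right but that “kitten” is a better wrong answer than “dog”.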

Let’s see how some researchers successfully applied this to an NER (named entity recognition) model. The training process is fairly straightforward:

Training process step by step

There are two primary methods outlined in the paper: AddNER and ExtendNER.

AddNER Model

Note: NER models take a sequence of tokens (usually a sentence) as input and output a probability distribution over the entity labels for each token. IOB tagging is commonly used for NER models: each token is labelled either ‘O’ or as the beginning (‘B-’) or inside (‘I-’) of an entity of type X. ‘O’ stands for ‘Outside’ and just means the current token doesn’t belong to any entity. Therefore, for n entity types, you will have 2n output neurons in the classification layer: n for the ‘B-’ tags (one for each entity type) and n for the ‘I-’ tags (again, one for each entity type). Add the ‘O’ label and you end up with 2n + 1 possible labels for each token. The final layer’s dimensions can be written as h × (2n + 1), where h is the size of the hidden layer’s output. Bear in mind, this only works for models where each token belongs to at most one entity; it can’t represent a token like “Apple” being tagged as both “FOOD” and “COMPANY”.
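A short sketch of the label set and the 2n + 1 count, with two hypothetical entity types and a made-up sentence:

```python
def iob_labels(entity_types):
    """'O' plus one B- and one I- tag per entity type: 2n + 1 labels."""
    labels = ["O"]
    for entity in entity_types:
        labels += [f"B-{entity}", f"I-{entity}"]
    return labels

labels = iob_labels(["PER", "LOC"])  # n = 2 hypothetical entity types
print(labels)       # ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']
print(len(labels))  # 5, i.e. 2 * 2 + 1

# One label per token
tokens = ["Ada", "Lovelace", "lived", "in", "London"]
tags = ["B-PER", "I-PER", "O", "O", "B-LOC"]
for token, tag in zip(tokens, tags):
    print(f"{token:10} {tag}")
```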

Architecture and teacher-student setup

The student model in this case is a copy of the teacher model, with an additional output classification layer for each new entity type that the model should learn. During training, the new output layer learns from the new annotated data, and the older layers are guided by the teacher model’s outputs to minimise forgetting.

After training, the old output layers are not discarded. Instead, the model uses the algorithm and heuristics described in the conflict resolver section (end of section 3.3) to combine these outputs into a single final prediction for each token in the sequence.

Diagram of the AddNER model from section 3.2 of the paper

Forward Pass

  1. Old Entity Types: The input sentence is passed through the teacher model to obtain probability distributions (the “soft targets” in this context) for the old entity types.
  2. New Entity Types: The same sentence is also passed through the new student model with additional output layers specific to the new entity types.

Backward Pass

Combined loss function:

  1. KD Loss: calculated by comparing how closely the output probabilities of the old entity types from the new model (student) match those from the old model (teacher). It uses KL-divergence to calculate this. It’s probably calculated token-by-token and then summed or averaged over all tokens in a sentence or batch, but I don’t think the paper goes into this.
  2. Cross-Entropy Loss: This is the usual loss function that compares the model’s predictions for the new entity types against the actual labels from the new dataset.
  3. Combining the two: these two losses are combined by taking a weighted sum of them. The weights are set by the hyperparameters alpha and beta, which are tuned like any other hyperparameter to improve performance in experiments.
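import torch
import torch.nn.functional as F

# Hyperparameters alpha and beta for weighting the two loss functions
alpha = 0.5
beta = 0.5

teacher_model.eval()  # the teacher is frozen

for epoch in range(num_epochs):
    for sentence, labels in D_new:
        optimizer.zero_grad()

        # Forward pass in teacher model for old entity types (no gradients needed)
        with torch.no_grad():
            teacher_logits_old = teacher_model(sentence)

        # Forward pass in student model for old and new entity types
        # Note: the new entity types must go through the new output layer (not shown in this pseudocode)
        student_logits_old, student_logits_new = student_model(sentence)

        # KD loss: KL(teacher || student) over the old entity types
        # F.kl_div expects log-probabilities as input and probabilities as target
        kd_loss = F.kl_div(
            F.log_softmax(student_logits_old, dim=-1),
            F.softmax(teacher_logits_old, dim=-1),
            reduction="batchmean",
        )

        # CE loss for new entity types: raw logits and class indices, flattened over tokens
        num_labels = student_logits_new.size(-1)
        ce_loss = F.cross_entropy(student_logits_new.reshape(-1, num_labels), labels.reshape(-1))

        # Combined loss
        total_loss = alpha * kd_loss + beta * ce_loss

        # Backward pass and update student model parameters
        total_loss.backward()
        optimizer.step()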

ExtendNER Model

Architecture and teacher-student setup

The ExtendNER model extends the output layer’s dimensions to accommodate new entity types, instead of adding new output layers. The paper explains quite simply what the dimensions should be:

“Assuming that Mi was able to recognize n entity types, its final layer can be considered as a matrix with dimension h×(2n+1). The output layer of Mi+1 will then be extended to be a matrix with dimension h × (2n + 2m + 1) in order to accommodate the new entity types.”

Diagram of the ExtendNER model from section 3.4 of the paper
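A minimal sketch of that extension with plain lists. Assumptions not taken from the paper: the new columns are appended after the old ones, biases are ignored, and the new weights are initialised from a small Gaussian. `extend_output_layer` is a hypothetical helper.

```python
import random

def extend_output_layer(weights, num_new_types, seed=0):
    """Grow an h x (2n + 1) output matrix to h x (2n + 2m + 1).

    Columns for the old labels are copied unchanged; 2m new columns
    (a B- and an I- tag per new entity type) are appended with small random weights.
    """
    rng = random.Random(seed)
    extra = 2 * num_new_types
    return [row + [rng.gauss(0.0, 0.02) for _ in range(extra)] for row in weights]

h, n, m = 4, 2, 1                             # hidden size, old types, new types
old_weights = [[0.1] * (2 * n + 1) for _ in range(h)]
new_weights = extend_output_layer(old_weights, m)
print(len(new_weights), len(new_weights[0]))  # 4 7
```

Keeping the old columns intact means the extended model starts out making exactly the teacher’s scores for the old labels.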

Forward Pass

Same as in AddNER, but with extended dimensions.

Backward Pass

The loss calculation uses either the KL-divergence loss or the cross-entropy loss, depending on the following:

  • When the NER category label y is “O” (from the IOB tagging schema), the KL divergence loss is used.
  • When the category label y is NOT “O”, the Cross-Entropy loss is used.
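Here is a dependency-free sketch of that per-token switch. It assumes the old labels occupy the first slots of the extended label set and that, for ‘O’ tokens, the student’s scores are restricted to the old labels and renormalised before comparing with the teacher; `token_loss` is a hypothetical helper, and the scores are made up.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def token_loss(student_logits, gold, teacher_probs, num_old_labels, o_index=0):
    """Per-token loss choice in an ExtendNER-style setup.

    student_logits: scores over the extended label set (old labels first).
    gold: index of the token's label in the new annotated data.
    teacher_probs: the teacher's distribution over the old labels.
    """
    if gold == o_index:
        # 'O' in the new data may hide an old entity type, so match the teacher (KL divergence).
        # Assumption: the student's scores are restricted to the old labels and renormalised.
        student_old = softmax(student_logits[:num_old_labels])
        return sum(p * math.log(p / q) for p, q in zip(teacher_probs, student_old) if p > 0)
    # Any other label comes from the new annotations: ordinary cross-entropy
    return -math.log(softmax(student_logits)[gold])

labels = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"]  # n = 2 old, m = 1 new
teacher_probs = softmax([0.5, 0.1, 0.1, 2.5, 0.2])  # teacher thinks this token is B-LOC
student_logits = [2.0, 0.1, 0.1, 1.5, 0.2, 0.3, 0.1]

print(token_loss(student_logits, labels.index("O"), teacher_probs, num_old_labels=5))      # KL term
print(token_loss(student_logits, labels.index("B-ORG"), teacher_probs, num_old_labels=5))  # CE term
```

The ‘O’ case matters because the new dataset is only annotated for the new entity types, so an old entity (like a location) would show up labelled ‘O’ and plain cross-entropy would teach the model to forget it.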

Final Prediction

The Viterbi algorithm is applied to decode the final entity types.
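The paper only says that Viterbi decoding is used. Here is a minimal sketch of what that can look like for IOB tags, assuming a hard transition constraint (an I-X tag may only follow B-X or I-X) and made-up per-token probabilities. Picking the most likely label per token can produce an invalid sequence, while Viterbi finds the best valid one.

```python
import math

def allowed(prev, curr):
    """IOB constraint: I-X may only follow B-X or I-X (prev is None at the first token)."""
    if curr.startswith("I-"):
        return prev is not None and prev != "O" and prev[2:] == curr[2:]
    return True

def viterbi(log_probs, labels):
    """Highest-scoring label sequence that respects the IOB constraint.

    log_probs[t][j] is the model's log-probability of labels[j] at token t.
    """
    neg_inf = float("-inf")
    n = len(labels)
    # score[j]: best total log-probability of any valid path ending in label j
    score = [log_probs[0][j] if allowed(None, labels[j]) else neg_inf for j in range(n)]
    backpointers = []
    for t in range(1, len(log_probs)):
        new_score, pointers = [], []
        for j in range(n):
            best_i, best = 0, neg_inf
            for i in range(n):
                if score[i] > best and allowed(labels[i], labels[j]):
                    best_i, best = i, score[i]
            new_score.append(best + log_probs[t][j])
            pointers.append(best_i)
        score = new_score
        backpointers.append(pointers)
    # Trace back from the best-scoring final label
    j = max(range(n), key=lambda k: score[k])
    path = [j]
    for pointers in reversed(backpointers):
        j = pointers[j]
        path.append(j)
    return [labels[k] for k in reversed(path)]

labels = ["O", "B-LOC", "I-LOC"]
tokens = ["in", "New", "York"]
probs = [[0.90, 0.05, 0.05],  # made-up per-token probabilities
         [0.30, 0.30, 0.40],
         [0.10, 0.10, 0.80]]
log_probs = [[math.log(p) for p in row] for row in probs]

greedy = [labels[row.index(max(row))] for row in probs]
print(greedy)                      # ['O', 'I-LOC', 'I-LOC'] (invalid: I-LOC after O)
print(viterbi(log_probs, labels))  # ['O', 'B-LOC', 'I-LOC']
```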

Both AddNER and ExtendNER performed well for continual learning, and their results did not differ much.

4. Replay-based approach

“Fine-tuned language models are continual learners”

paper link

The model in the paper is not a generic, single-task model like ChatGPT, trained just for conversational responses. Instead, it is fine-tuned on a sequence of specialised tasks, ranging from text simplification to Haiku generation. Each of these tasks has its own requirements, evaluation metrics and specialised training dataset.

The researchers mix parts of the old dataset into the new one, and achieve strong results by mixing in just 1% of the previous task’s dataset when fine-tuning on a new task. This is done sequentially across eight tasks. The model also performs well in zero-shot settings, meaning it generalises to tasks it hasn’t been trained on: for instance, it can generate a Haiku with the correct syllable count for an unseen topic. The researchers also report that their approach is task-order invariant, meaning the order in which the tasks are learned does not affect the model’s performance. The experiments find that the amount of old data mixed in doesn’t significantly affect performance on the main task; however, it does affect zero-shot performance. At 0% rehearsal, the model tends to forget the zero-shot tasks, while at 1% rehearsal it maintains its performance on them very well.
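The data mixing itself is simple enough to sketch. In this hypothetical version, each previous task's dataset contributes a random sample of 1% (the rehearsal rate) to the training mix for the next task. Whether only the immediately previous task or all previous tasks should be sampled is an assumption on my part; this sketch samples from all of them. Datasets are plain Python lists, the task names are placeholders, and the actual fine-tuning step is left out.

```python
import random


def build_training_mix(new_task_data, previous_tasks, rehearsal_rate=0.01, seed=0):
    """New task's data plus a random rehearsal sample from each earlier task."""
    rng = random.Random(seed)
    replay = []
    for old_data in previous_tasks:
        k = min(len(old_data), round(len(old_data) * rehearsal_rate))
        replay.extend(rng.sample(old_data, k))
    mix = list(new_task_data) + replay
    rng.shuffle(mix)  # interleave old and new examples
    return mix


def sequential_schedule(tasks, rehearsal_rate=0.01, seed=0):
    """Yield (task name, training mix) for each task in order."""
    seen = []
    for name, data in tasks:
        yield name, build_training_mix(data, seen, rehearsal_rate, seed)
        seen.append(data)


# Hypothetical toy datasets: integers stand in for training examples.
tasks = [("simplification", list(range(1000))),
         ("haiku", list(range(2000))),
         ("qa", list(range(500)))]
sizes = [(name, len(mix)) for name, mix in sequential_schedule(tasks)]
```

With a 1% rate, the third task trains on its own 500 examples plus 10 and 20 rehearsed examples from the first two tasks; setting `rehearsal_rate=0` corresponds to the 0% rehearsal setting that forgets the zero-shot tasks.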

This all seems positive: it suggests we can simply mix in 1% of the old dataset and continual learning is solved. Of course, whether this carries over to a chatbot like ChatGPT is an empirical question, and the results could be completely different. Even if, hypothetically, ChatGPT could be continually trained like this in the fine-tuning and RLHF stages, it would require an immense amount of labeled conversation data.

5. Architecture-based approach

I won’t go into any specific paper or implementation in detail here, but I will provide a brief overview of this approach and a couple of different techniques. I recommend reading section 4.5 of the comprehensive survey paper, which is also easier to read than the other sections.

  1. Parameter Allocation: Here, a subset of the network parameters is dedicated to each task. This can be done either by masking out irrelevant neurons or by explicitly identifying important ones for the current task.
  2. Modular Network: This involves using separate sub-networks or modules for each task.
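As a toy illustration of parameter allocation (my own sketch, not a specific method from the survey), each task below claims a disjoint subset of a shared weight vector. A binary mask hides every other task's parameters in the forward pass and blocks their updates, so training a new task cannot overwrite what earlier tasks learned. The class and method names are hypothetical, and the random choice of parameters stands in for the more careful importance-based selection described above.

```python
import random


class ParameterAllocator:
    """Shared weight vector where each task owns a disjoint subset of parameters."""

    def __init__(self, num_params, seed=0):
        self.weights = [0.0] * num_params
        self.owner = [None] * num_params  # task that owns each parameter
        self.rng = random.Random(seed)

    def allocate(self, task, k):
        """Give `task` k currently unused parameters."""
        free = [i for i, o in enumerate(self.owner) if o is None]
        if k > len(free):
            raise ValueError("network capacity exhausted")
        for i in self.rng.sample(free, k):
            self.owner[i] = task

    def mask(self, task):
        """1.0 for the task's own parameters, 0.0 for everything else."""
        return [1.0 if o == task else 0.0 for o in self.owner]

    def forward(self, task, x):
        """Toy linear model that only sees the task's own parameters."""
        return sum(w * m * xi for w, m, xi in zip(self.weights, self.mask(task), x))

    def sgd_step(self, task, grads, lr=0.1):
        """Gradient step that only touches the task's own parameters."""
        for i, (g, m) in enumerate(zip(grads, self.mask(task))):
            self.weights[i] -= lr * g * m


net = ParameterAllocator(num_params=10)
net.allocate("task_A", 4)
net.allocate("task_B", 4)
net.sgd_step("task_B", grads=[1.0] * 10)  # training task B leaves task A untouched
```

The obvious cost of this design is capacity: each new task consumes parameters that can never be reused, which is why practical methods try to allocate as few as possible per task.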

Sub-networks can be connected in various ways to form an ensemble or a more complex architecture. Below are a few common methods for connecting sub-networks:

Concatenation of Outputs:

In this approach, the outputs of multiple sub-networks are concatenated into a single tensor, which can then be passed through additional layers to produce the final output.

Voting Mechanism:

In some models, each sub-network casts a “vote” on the likely outcome, and the final decision is made by a majority or weighted vote. This has a biological inspiration, as it resembles the way different cortical columns in the neocortex are thought to cast votes.

Skip Connections:

Some architectures allow sub-networks to have skip connections to other parts of the model, allowing information to flow across modules.

Sequential:

In this case, the output of one sub-network serves as the input to the next.
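The four connection styles can be expressed as small combinators over sub-networks. In this hypothetical sketch a sub-network is any Python function: the concatenation, skip and sequential helpers work on lists of floats, while the voting helper expects each sub-network to return a class label. Real models would use tensors and learned layers, but the wiring is the same.

```python
from collections import Counter


def concatenate(subnets, head):
    """Concatenate all sub-network outputs, then pass them through a final head."""
    return lambda x: head([v for net in subnets for v in net(x)])


def vote(subnets, weights=None):
    """Each sub-network votes for a label; the largest (weighted) tally wins."""
    weights = weights or [1.0] * len(subnets)

    def f(x):
        tally = Counter()
        for net, w in zip(subnets, weights):
            tally[net(x)] += w
        return tally.most_common(1)[0][0]

    return f


def skip(subnet):
    """Residual-style skip connection: add the input back to the sub-network output."""
    return lambda x: [a + b for a, b in zip(x, subnet(x))]


def sequential(*subnets):
    """Feed each sub-network's output into the next."""
    def f(x):
        for net in subnets:
            x = net(x)
        return x

    return f


double = lambda x: [2 * v for v in x]
increment = lambda x: [v + 1 for v in x]
cat_net, dog_net = (lambda x: "cat"), (lambda x: "dog")
```

For example, `concatenate([double, increment], sum)([1, 2])` sums `[2, 4, 2, 3]`, and `vote([cat_net, dog_net, cat_net])` returns "cat" unless the "dog" voter is given a larger weight.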

Going back to chatbots, what I find particularly interesting is the possibility of creating such an architecture with two sub-networks. The first would be the pre-trained model, which holds the general “knowledge”. The second would hold the knowledge for aligning the model. Once the model is aligned, it would no longer need labeled conversational data; instead, it could be continually updated by training the pre-trained sub-network in an unsupervised way.

Conclusion

In conclusion, continual learning is a challenging and still poorly understood subfield of deep learning. Part of the reason is that we do not fully understand how the neurons in LLMs work; and, as outlined in the intro, it could also be that current network architectures, or deep learning in general, are simply not suited to it.

Last month I noticed that ChatGPT (GPT-4 only) had been updated: it now says “Since my training cutoff in January 2022”, so I wonder what the folks at OpenAI did to achieve this.

ChatGPT (GPT-4 variant) telling the user it is trained up until January 2022 (screenshot by author)
