
Reshaping the Model’s Memory without the Need for Retraining

Erasing any echo of problematic content a large language model has learned

| AI | LARGE LANGUAGE MODELS | MACHINE UNLEARNING |

Photo by Drew Saurus on Unsplash

"To forgive is wisdom, to forget is genius. " ― Joyce Cary

Large language models (LLMs) have taken the world by storm. In less than a year they have become ubiquitous and are now used by millions of people. These models are often trained on huge amounts of text (including problematic material and sensitive data). How do you make a model forget, when it could store the entirety of human knowledge?

To learn how to forget

Photo by Paul Pastourmatzis on Unsplash

LLMs stand as a testament to both our accomplishments and the challenges that lie ahead – source

LLMs have surprised both users and researchers with their ability to learn from huge amounts of text and to identify language patterns and cultural nuances. While they could be the basis for a new wave of applications and scientific discovery, they also have a dark side.

Huge corpora must be used to train these models. While it is true that more data generally means better LLM performance, collecting this data is expensive. To limit costs, indiscriminate scraping of data from the Internet is often used. These corpora therefore also contain extremely problematic material: copyrighted texts, toxic or malicious data, inaccurate or fake content, personal data, and more.

image source: here

Machine unlearning: The duty of forgetting

LLMs can store all of this information and leak it when queried. This opens up enormous ethical and even legal risks, and it has already led to lawsuits, public pressure, and legislative scrutiny.

To date, we know that fine-tuning can reinforce specific knowledge in a model. However, if we wanted a model to forget specific information, we would have to retrain it. The problem is that training an LLM costs millions of dollars and is time-intensive.

How do you get an LLM to forget?

In general, machine unlearning is an active field of research. Most studies focus on classification tasks, and only a few address generative AI or LLMs. LLMs are particularly problematic because it is difficult to trace where personal data (from chat history or training data) was acquired and in which parameters it is stored. Removing data from a trained model is extremely complex, as the model weights are a complex integration of the whole collection of training data.

An interesting approach that has recently been proposed is to fine-tune the model on the text we want it to forget while negating the loss function: in other words, the model is penalized whenever it correctly predicts the next token of the text we want it to forget.

image source: here
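
To make this concrete, here is a minimal sketch of this "reversed loss" (gradient ascent) fine-tuning, assuming a Hugging Face causal LM; the model name and the forget_texts list are placeholders, not the setup of any specific paper.

```python
# Minimal sketch of unlearning by gradient ascent: fine-tune on the text to forget,
# but flip the sign of the loss so the model is pushed away from predicting those tokens.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"   # placeholder baseline model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

forget_texts = ["My name is Salvatore and I work on ..."]   # text we want the model to forget

model.train()
for text in forget_texts:
    batch = tokenizer(text, return_tensors="pt")
    outputs = model(**batch, labels=batch["input_ids"])
    loss = -outputs.loss   # negated cross-entropy: penalize correct next-token predictions
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```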

As simple and effective as this approach seems, it has real limitations. For example, if the text we want to forget is my bio, "My name is Salvatore…", the model will forget not only "Salvatore" but also "my name is." In other words, the model ends up forgetting general knowledge about language itself.

Thus, we are interested in an approach that, instead of penalizing some text, shifts the model from predicting personal data to giving a generic answer (as if it had never encountered the personal data at all).

So we want a model that effectively forgets the problematic text but, at the same time, retains its skills and the rest of its knowledge.

How to forget Harry Potter

Photo by Dollar Gill on Unsplash

"It does not do to dwell on dreams and forget to live" – Albus Dumbledore in the Sorcerer’s Stone

Recently, an article addressed how to make a model forget an entire book without impacting the LLM's performance. The authors show how a model can forget the complex plot of Harry Potter while maintaining its performance on benchmark datasets.

Who’s Harry Potter? Approximate Unlearning in LLMs

Consider an LLM trained on a text dataset X, from which we want it to forget a subset Y. Through fine-tuning on Y we can obtain a model with reinforced knowledge of Y: an expert on subject Y. The traditional method would be to retrain the LLM on X − Y, but this would require a lot of time and computational resources.

We want a model that preserves its general knowledge and understanding of the language. Therefore, the authors decided to exploit the expert model to help an LLM to forget.

The first step is to understand what a generic prediction would be. For the authors of this paper, a generic prediction for a sentence such as "He looks at the scar on his ___" is obtained by comparing an expert model (which has a thorough understanding of what we want to forget) with the baseline model.

In simple words, the authors took an LLM (Llama 2-7B) as the baseline and fine-tuned it on Harry Potter (the expert model). The same prompt ("He looks at the scar on his ___") is then given to both models, and a vector v of next-token predictions (logits) is obtained from each; the generic prediction is:

v_generic = v_baseline − α · ReLU(v_expert − v_baseline)         eq. (1)

Using ReLU and the constant α allows us to extract only the predictions that are specific to the expert model. This ensures that the model does not forget the generic phrase "He looks at the scar on his" but only the book-specific completion "forehead" (i.e., where Harry Potter has his scar).
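
As a sanity check on eq. (1), here is a small sketch of how the two logit vectors could be combined; the function name and the toy numbers are purely illustrative.

```python
# Sketch of eq. (1): combine baseline and expert next-token logits into "generic" logits.
import torch

def generic_logits(v_baseline: torch.Tensor, v_expert: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # ReLU keeps only the tokens the expert rates MORE likely than the baseline
    # (the Harry Potter-specific predictions) and pushes them down.
    return v_baseline - alpha * torch.relu(v_expert - v_baseline)

# Toy 5-token vocabulary (values are illustrative only):
v_base = torch.tensor([2.0, 1.0, 0.5, 0.1, 0.0])
v_exp = torch.tensor([2.0, 1.0, 3.5, 0.1, 0.0])    # the expert boosts token 2 ("forehead")
print(generic_logits(v_base, v_exp, alpha=1.0))     # token 2 is suppressed from 0.5 to -2.5
```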

Is it enough?

No, because forgetting a book is not just forgetting the name of a protagonist or a specific term (also because, by varying the prompts, one could still access this knowledge). The idea is that the model should forget in a deeper way. For the authors, this can be achieved by breaking the links between the entities in the text.

For this reason, the authors used GPT-4 to extract the entities and expressions that are idiosyncratic to the book and translated them into generic alternatives. These are terms that remain consistent with the context but are not specific to the book, as you can see from the example:

image source: here

This serves to conceptually steer the model away from predicting Harry Potter-related content and toward more general text that is still consistent with the textual input.

Combining these two elements together, the process is divided into four steps:

  • We create a dictionary that maps the specific (idiosyncratic) elements of the text to generic translations.
  • We split the text into blocks (depending on the context length of the chosen LLM). For each block, we obtain the expert model's prediction on the original text and the baseline model's prediction on the mapped text.
  • We combine the predictions of these two models with the equation described above (eq. 1) and thus obtain the generic predictions.
  • In the last step, we fine-tune the baseline model using the original text as input and the generic labels as target tokens (a compressed code sketch of these steps follows the figure below).
image source: here
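
Below is a compressed sketch of these four steps, assuming baseline and expert are two already loaded causal language models sharing a tokenizer; the anchored_terms dictionary and helper names are illustrative, not the paper's actual artifacts.

```python
# Illustrative sketch of the four-step recipe (not the authors' released code).
import torch

# Step 1: dictionary mapping idiosyncratic terms to generic translations (toy example).
anchored_terms = {"Hogwarts": "the academy", "Quidditch": "the tournament"}

def translate(text: str) -> str:
    # Replace book-specific terms with their generic counterparts.
    for specific, generic in anchored_terms.items():
        text = text.replace(specific, generic)
    return text

@torch.no_grad()
def generic_target_tokens(block: str, alpha: float = 1.0) -> torch.Tensor:
    # Step 2: expert logits on the original block, baseline logits on the translated block.
    ids_orig = tokenizer(block, return_tensors="pt").input_ids
    ids_gen = tokenizer(translate(block), return_tensors="pt").input_ids
    v_expert = expert(ids_orig).logits
    v_baseline = baseline(ids_gen).logits
    # Step 3: combine with eq. (1); we naively truncate to a common length for simplicity.
    n = min(v_expert.shape[1], v_baseline.shape[1])
    v_generic = v_baseline[:, :n] - alpha * torch.relu(v_expert[:, :n] - v_baseline[:, :n])
    return v_generic.argmax(dim=-1)   # generic target tokens

# Step 4: fine-tune `baseline` with the original text as input and the generic
# tokens returned above as labels (standard cross-entropy training loop, omitted here).
```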

Did our model forget about magic?

Photo by Artem Maltsev on Unsplash

"The trick to forgetting the big picture is to look at everything close-up." – Chuck Palahniuk

The authors chose LLaMA-2 in the 7B version as the model, both because it was open-source and because it showed excellent capabilities despite its limited size. The training of the original model (pretraining phase with a huge corpus of text) required 184K GPU hours, while the forgetting process proposed by the authors requires only 1 GPU hour (thus definitely inexpensive in terms of resources and affordable for anyone).

META LLaMA 2.0: the most disruptive AInimal

The first step is to assess whether the model actually retained information about the Harry Potter books (e.g., "When Harry returned to class, he observed his best friends ___"). To check this, the authors created a series of textual prompts that the model had to complete based on its internal knowledge. In addition, they created prompts to check whether the model was familiar with what was described in the books (e.g., "Draft a brief narrative in the style of Harry Potter. Short story:"). As can be seen, the model that went through the forgetting process seems to no longer be able to recall elements from the book:

image source: here

The authors manually evaluated not only how the model completed the sentences but also the probabilities associated with specific tokens. For example, for the prompt "Harry Potter studies", the authors checked whether words like "magic" or "wizardry" were among the highest-probability next tokens.

The results show that the probability of these book-specific tokens decreases significantly with each fine-tuning step. The lower a token's probability, the less likely it is to be selected, even with different prompts. According to the authors, only 120 gradient-descent fine-tuning steps are needed for good results.

image source: here
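
A sketch of this kind of probability check with the Transformers library could look like the following; the prompt is the one from the article, while the candidate words and the naive way the first sub-token is looked up are simplifying assumptions that depend on the tokenizer.

```python
# Sketch: inspect the next-token probability of "magic"-like words after a prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def next_token_probs(model_id: str, prompt: str, candidates: list[str]) -> dict[str, float]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(ids).logits[0, -1]          # logits for the next token only
    probs = torch.softmax(logits, dim=-1)
    # Look up the first sub-token of each candidate word (tokenizer-dependent simplification).
    return {w: probs[tokenizer.encode(" " + w, add_special_tokens=False)[0]].item()
            for w in candidates}

# Compare the original and unlearned checkpoints on the same prompt.
print(next_token_probs("microsoft/Llama2-7b-WhoIsHarryPotter",
                       "Harry Potter studies", ["magic", "wizardry", "law"]))
```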

The model seems to have forgotten the book and provided generic answers. The question remains: has the forgetting process impacted the model’s general skills and knowledge?

For this, the authors used several benchmark datasets:

  • WinoGrande, a large-scale commonsense-reasoning benchmark of pronoun-resolution problems inspired by the original Winograd Schema Challenge (273 expert-crafted problems).
  • HellaSwag, a dataset of sentence completions that are trivial for humans but not for computers.
  • PIQA, a commonsense-reasoning dataset created to investigate the physical knowledge of existing LLMs.
  • BoolQ, a large yes/no question-answering dataset where the model is given a question and a context passage and has to provide an answer.
  • OpenBookQA, a question-answering dataset modeled after open-book exams for assessing human understanding of a subject.
  • ARC, a multiple-choice question-answering dataset containing questions from science exams.
example of questions in the datasets, adapted from the original articles ([here](https://arxiv.org/abs/1905.07830), [here](https://arxiv.org/abs/1809.02789), here, and here)

The results show that performance is only minimally impacted by the unlearning process. Obviously, a greater number of gradient steps decreases familiarity with the topic but also impacts performance more.

image source: here

However, this study has limitations:

  • There are occasional leaks (if you ask the model for the names of magic schools, it suggests Hogwarts). Since the authors only used the books as the target text (while there are also movies and theme parks dedicated to the world of Harry Potter), this could simply reflect Wikipedia-level knowledge rather than actual leaks.
  • Second, more sophisticated prompting techniques could lead the model to reveal information. It should therefore be tested with adversarial attacks or other prompting techniques.
  • The method uses GPT-4, and thus relies on its knowledge of Harry Potter; for content that GPT-4 does not know, this would not be possible.
  • The Harry Potter books have a universe rich in characters, peculiar expressions, and precise themes. While the method seems to work well with a fictional topic, other topics do not have such rich lexical content or are much more abstruse.

The authors, aware of these limitations, invite the community to try and test the model:

Recognizing the intrinsic limitations of automated benchmarks and internal evaluations, we believe that unlearning verification parallels endeavors like jailbreaking in adversarial nature. Therefore, we open-sourced the model, encouraging the broader community to challenge it, providing a more diverse and extensive set of tests to discern if any remnants of the targeted knowledge persist. (source)

The model is hosted on Hugging Face and is available here:

microsoft/Llama2-7b-WhoIsHarryPotter · Hugging Face
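
If you want to probe the released checkpoint yourself, a minimal query through the Transformers pipeline API might look like this (the prompt and generation settings are just an example):

```python
# Quick test of the released unlearned model via the text-generation pipeline.
from transformers import pipeline

generator = pipeline("text-generation", model="microsoft/Llama2-7b-WhoIsHarryPotter")
prompt = "When Harry returned to class, he observed his best friends"
print(generator(prompt, max_new_tokens=40)[0]["generated_text"])
```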

Parting thoughts

Photo by Saif71.com on Unsplash

"The advantage of a bad memory is that one enjoys several times the same good things for the first time." ― Friedrich Nietzsche

Forgetting something intentionally is a difficult challenge even for humans, and it is also difficult for LLMs. As the study of grokking showed, there is a difference between memorizing and learning.

Grokking: Learning Is Generalization and Not Memorization

Initial studies tried to make the model forget by eliminating what it memorized. This impacted its general knowledge and its understanding of language itself. This new study shows that it is not enough to focus on the key terms of a concept (for example, the main characters in Harry Potter); one must also target the concept itself (for example, the plot).

The authors show how the model loses familiarity with Harry Potter and at the same time maintains its performance in reasoning benchmarks. Although this method is not perfect and has only been tested on a limited case, it opens up some very interesting perspectives. Indeed, pretraining datasets are full of toxic comments, stereotypes, biases, and hateful speech. This is the first step in being able to allow a model to unlearn this content without re-training.

What do you think? Let me know in the comments.


If you have found this interesting:

You can look for my other articles, you can also subscribe to get notified when I publish articles, you can become a Medium member to access all its stories (affiliate links of the platform for which I get small revenues without cost to you) and you can also connect or reach me on LinkedIn.

Here is the link to my GitHub repository, where I am planning to collect code and many resources related to machine learning, Artificial Intelligence, and more.

GitHub – SalvatoreRa/tutorial: Tutorials on machine learning, artificial intelligence, data science…

or you may be interested in one of my recent articles:

Scaling Data, Scaling Bias: A Deep Dive into Hateful Content and Racial Bias in Generative AI

Tabula Rasa: Why Do Tree-Based Algorithms Outperform Neural Networks

