
Retrieval Augmented Generation – Intuitively and Exhaustively Explained

Making language models that can look stuff up


"Data Retriever" By Daniel Warfield using MidJourney. All images by the author unless otherwise specified.
"Data Retriever" By Daniel Warfield using MidJourney. All images by the author unless otherwise specified.

In this post we'll explore "retrieval augmented generation" (RAG), a strategy which allows us to expose up-to-date and relevant information to a large language model. We'll go over the theory, then imagine ourselves as restaurateurs; we'll implement a system allowing our customers to talk with AI about our menu, seasonal events, and general information.

The final result of the practical example, a chat bot which can serve specific information about our restaurant.

Who is this useful for? Anyone interested in natural language processing (NLP).

How advanced is this post? This is a very powerful, but very simple concept; great for beginners and experts alike.

Pre-requisites: Some cursory knowledge of large language models (LLMs) would be helpful, but is not required.

The Core of the Issue

LLMs are expensive to train; GPT-3 famously cost a cool $3.2M on compute resources alone. If we opened up a new restaurant and wanted to use an LLM to answer questions about a menu, it'd be cool if we didn't have to dish out millions of dollars every time we introduced a new seasonal salad. We could do a smaller training step (called fine-tuning) to try to get the model to learn a small amount of highly specific information, but this process can still cost hundreds to thousands of dollars.

Another problem with LLMs is their confidence; sometimes they say stuff that's flat out wrong with abject certainty (commonly referred to as hallucinating). As a result it can be difficult to discern where an LLM is getting its information from, and whether that information is accurate. If a customer with allergies asks if a dish contains tree nuts, it'd be cool if we could ensure our LLM uses accurate information so our patrons don't go into anaphylactic shock.

Attorney Steven A. Schwartz first landed himself in hot water through his use of ChatGPT, which resulted in six fake cases being cited in a legal brief. – A famous example of hallucination in action. source

Both the issue of updating information and using proper sources can be mitigated with RAG.

Retrieval Augmented Generation, In a Nutshell

In-context learning is the ability of an LLM to learn information not through training, but by receiving new information in a carefully formatted prompt. For example, say you wanted to ask an LLM for the punchline, and only the punchline, of a joke. Jokes come in a setup-punchline combo extremely often, and because LLMs are statistical models it can be difficult for them to break that prior knowledge.

An example of ChatGPT failing a task because of lack of context

One way we can solve this is by giving the model "context"; we can give it a sample in a cleverly formatted prompt such that the LLM gives us the right information.

An example of ChatGPT succeeding at the same task when more context is provided

This trait of LLMs has all sorts of cool uses. I've written an article on how this ability can be used to talk with an LLM about images, and how it can be used to extract information from conversations. In this article we'll leverage this ability by injecting information, chosen based on what the user asked about, into a carefully constructed prompt, providing the model with in-context knowledge.

A conceptual diagram of RAG. The prompt is used to retrieve information in a knowledge base, which is in turn used to augment the prompt. This augmented prompt is then fed into the model for generation.

The RAG process comes in three key parts:

  1. Retrieval: Based on the prompt, retrieve relevant knowledge from a knowledge base.
  2. Augmentation: Combine the retrieved information with the initial prompt.
  3. Generation: Pass the augmented prompt to a large language model, generating the final output.
An example of what a RAG prompt might look like.

Retrieval

The only really conceptually challenging part of RAG is retrieval: How do we know which documents are relevant to a given prompt?

There's a lot of ways this could be done. Naively, you could iterate through all your documents and ask an LLM "is this document relevant to the question?" You could pass both the document and the prompt to the LLM, ask the LLM if the document is relevant to the prompt, and use some query parser (I talk about those here) to get the LLM to give you a yes or no answer.
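
As a rough sketch of that naive approach (assuming a hypothetical `ask_llm` callable that sends a prompt string to whichever LLM you're using and returns its text response):

"""A sketch of naive, LLM-based retrieval.
`ask_llm` is a hypothetical stand-in for however you query your LLM.
"""

def naive_retrieval(prompt, documents, ask_llm):
    relevant = []
    for docname, doc in documents.items():
        question = (f'Is the following document relevant to the question "{prompt}"? '
                    f'Answer yes or no.\n\ndocument: {doc}')
        # treat any response containing "yes" as a positive match
        if 'yes' in ask_llm(question).lower():
            relevant.append(docname)
    return relevant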

Alternatively, for an application as simple as ours, we could just provide all the data. We'll probably only have a few documents we'll want to refer to: our restaurant's menu, events, maybe a document about the restaurant's history. We could inject all that data into every prompt, combined with the query from a user.

However, say we don’t just have a restaurant, but a restaurant chain. We’d have a vast amount of information our customers could ask about: dietary restrictions, when the company was founded, where the stores are located, famous people who’ve dined with us. We’d have an entire franchise’s worth of documents; too much data to just put it all in every query, and too much data to ask an LLM to iterate through all documents and tell us which ones are relevant.

We can use word vector embeddings to deal with this problem. With word vector embeddings we can quickly calculate the similarity of different documents and prompts. The next section will go over word vector embeddings in a nutshell, and the following section will detail how they can be used for retrieval within RAG.

Word Vector Embeddings in a Nutshell

This section is an excerpt from my article on transformers:

Transformers – Intuitively and Exhaustively Explained

In essence, a word vector embedding takes an individual word and translates it into a vector which somehow represents its meaning.

The job of a word to vector embedder: turn words into numbers which somehow capture their general meaning.

The details can vary from implementation to implementation, but the end result can be thought of as a "space of words", where the space obeys certain convenient relationships. Words are hard to do math on, but vectors which contain information about a word, and how they relate to other words, are significantly easier to do math on. This task of converting words to vectors is often referred to as an "embedding".

Word2Vec, a landmark paper in the natural language processing space, sought to create an embedding which obeyed certain useful characteristics. Essentially, the authors wanted to be able to do algebra with words, and created an embedding to facilitate that. With Word2Vec, you could embed the word "king", subtract the embedding for "man", add the embedding for "woman", and you would get a vector whose nearest neighbor was the embedding for "queen".

A conceptual demonstration of doing algebra on word embeddings. If you think of each of the points as a vector from the origin, if you subtracted the vector for "man" from the vector for "king", and added the vector for "woman", the resultant vector would be near the word queen. In actuality these embedding spaces are of much higher dimensions, and the measurement for "closeness" can be a bit less intuitive (like cosine similarity), but the intuition remains the same.

I’ll cover word embeddings more exhaustively in a future post, but for the purposes of this article they can be conceptualized as a Machine Learning model which has learned to group words as vectors in a meaningful way. With a word embedding you can start thinking of words in terms of distance. For instance, the distance between a prompt and a document. This idea of distance is what we’ll use to retrieve relevant documents.
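
For a concrete taste of that algebra, here's a minimal sketch using gensim's pre-trained GloVe vectors (the same library used in the practical example later in this article). With a small 25-dimensional model the analogy is fuzzier than with larger embeddings, but the idea holds:

"""A sketch of 'king - man + woman' style vector arithmetic using gensim."""

import gensim.downloader

vectors = gensim.downloader.load('glove-twitter-25')

# positive vectors are added, negative vectors are subtracted;
# most_similar returns the nearest neighbors of the resulting vector
print(vectors.most_similar(positive=['king', 'woman'], negative=['man'], topn=3))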

Using Word Embeddings For Retrieval

We know how to turn words into points in some high-dimensional space. How can we use those to figure out which documents are relevant to a given prompt? There's a lot of ways this can be done, and it's still an active area of research, but we'll consider a simple yet powerful approach: the Manhattan distance of the mean vector embedding.

The Mean Vector Embedding

We have a prompt which can be thought of as a list of words, and we have documents which can also be thought of as lists of words. We can summarize these lists of words by first embedding each word with Word2Vec, then calculating the average of all of the embeddings.

A conceptual diagram of calculating the mean vector of all embeddings in a sequence of words. Each index in the resultant vector is simply the average of the corresponding indices across all of the word embeddings.

Conceptually, because the word vector encodes the meaning of a word, the mean vector embedding calculates the average meaning of the entire phrase.

A conceptual diagram of calculating the mean vector of a prompt. A patron of our restaurant asks "When does the restaurant have live bands?" Each of these words is passed through an embedder (like Word2Vec), and then the mean of each of these vectors is calculated. This is done by calculating the average of each index. Conceptually, this can be thought of as calculating the average meaning of the entire phrase.

Manhattan Distance

Now that we've created a system which can summarize the meaning of a sequence of words down to a single vector, we can use this vector to compare how similar two sequences of words are. In this example we'll use the Manhattan distance, though many other distance measurements can be used.

A conceptual diagram of the Manhattan distance. On the left we find its namesake: instead of a traditional straight-line distance between two points, the Manhattan distance is the sum of the distances along the two axes, the y axis and the x axis. On the right you can see this concept illustrated in terms of comparing vectors. We find the distance between the vectors on an element-by-element basis, then sum those distances together to get the Manhattan distance. Conceptually, this method of distance calculation is best when different axes might represent fundamentally different things, which is a common intuition with vector embeddings.

Combining these two concepts, we can find the mean vector embedding of the prompt and of all documents, then use the Manhattan distance to sort the documents by distance from the prompt, a proxy for relatedness.

How the most relevant documents are found. The mean vector embedding is calculated for the prompt, as well as all documents. A distance is calculated between the prompt and all documents, allowing the retrieval system to prioritize which documents to include in augmentation.

And that's the essence of retrieval: you calculate a word vector embedding for all words in all pieces of text, then compute a mean vector which represents each piece of text. We can then use the Manhattan distance between mean vectors as a proxy for similarity.

In terms of actually deciding which documents to use, there's a lot of options. You could set a maximum distance threshold, beyond which a document counts as irrelevant, or you could always include the document with the minimum distance. The exact details depend on the needs of the application. To keep things simple, we'll always retrieve the document with the lowest distance to the prompt.
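
As a sketch of those two options (assuming mean-embedding and distance helpers like the `embed_sequence` and `calc_distance` functions implemented later in this article), threshold-based and closest-document selection might look like this:

"""Sketches of two document selection strategies.
`doc_embeddings` maps document names to their mean vector embeddings,
and `calc_distance` is a distance function like the one defined later on.
"""

def retrieve_within_threshold(prompt_embedding, doc_embeddings, calc_distance, max_dist=5.0):
    # keep every document closer to the prompt than an application-specific cutoff
    return [name for name, emb in doc_embeddings.items()
            if calc_distance(prompt_embedding, emb) <= max_dist]

def retrieve_closest(prompt_embedding, doc_embeddings, calc_distance):
    # keep only the single closest document (the strategy used in this article)
    return min(doc_embeddings, key=lambda name: calc_distance(prompt_embedding, doc_embeddings[name]))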

A Note on Vector Databases

Before I move onto augmentation and generation, a note.

In this article I wanted to focus on the concepts of RAG without going through the specifics of vector databases. They're a fascinating and incredibly powerful technology which I'll be building from scratch in a future post. If you're implementing RAG in a project, you'll probably want to use a vector database to achieve better query performance when calculating the distance between a prompt and a large number of documents.

Typically, RAG is achieved by hooking such a database up with LangChain, a workflow I'm planning on tackling in another future post.

Augmentation and Generation

Cool, so we're able to retrieve which documents are relevant to a user's prompt. How do we actually use them? This can be done with a prompt formatted for the specific application. For instance, we can declare the following format:

"Answer the customers prompt based on the folowing context:
==== context: {document title} ====
{document content}

...

prompt: {prompt}"

This format can then be used, along with whichever document was deemed useful, to augment the prompt. This augmented prompt can then be passed directly to the LLM to generate the final output.

RAG From Scratch

We’ve covered the theory; retrieval, augmentation, and generation. In order to further our understanding, we’ll implement RAG more or less from scratch. We’ll use a pre-trained word vector embedder and LLM, but we’ll do distance calculation and augmentation ourselves.

You can find the full code here:

MLWritingAndResearch/RAGFromScratch.ipynb at main · DanielWarfield1/MLWritingAndResearch

Downloading the Word to Vector Encoder

First we need to download a pre-trained encoder, which has learned the relationships between words and, as a result, knows which words belong in certain regions of space.

"""Downloading a word encoder.
I was going to use word2vect, but glove downloads way faster. For our purposes
they're conceptually identical
"""

import gensim.downloader

#doenloading encoder
word_encoder = gensim.downloader.load('glove-twitter-25')

#getting the embedding for a word
word_encoder['apple']
The embedding for the word "apple"
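
As a quick sanity check (this uses gensim's KeyedVectors API; it isn't part of the original notebook), we can ask the encoder which words it places closest to a given word. Semantically related words should have nearby vectors:

# inspecting the encoder: words with vectors near "apple"
print(word_encoder.most_similar('apple', topn=5))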

Embedding a Document or Prompt

Now that we have an encoder, we can calculate the mean of all embeddings in a given sequence of text, like a prompt or document, to embed the entire sequence as a single vector.

"""defining a function for embedding an entire document to a single mean vector
"""

import numpy as np

def embed_sequence(sequence):
    vects = word_encoder[sequence.split(' ')]
    return np.mean(vects, axis=0)

embed_sequence('its a sunny day today')
The mean vector of all embeddings in "it's a sunny day today"
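
One practical wrinkle worth noting: `word_encoder[...]` raises a `KeyError` for any word the encoder never saw during training. A slightly more defensive variant (a sketch, not part of the original notebook) skips out-of-vocabulary words, which costs us very little since we're averaging anyway:

"""A sketch of an out-of-vocabulary-tolerant version of embed_sequence."""

def embed_sequence_safe(sequence):
    # keep only the words the encoder actually knows
    words = [w for w in sequence.lower().split(' ') if w in word_encoder]
    if not words:
        raise ValueError('no words in the sequence are in the encoder vocabulary')
    return np.mean(word_encoder[words], axis=0)

embed_sequence_safe('its a sunny day today')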

Calculating Distance

We can use scipy's cdist function to calculate the Manhattan distance, which we use as a proxy for similarity.

"""Calculating distance between two embedding vectors
uses manhattan distance
"""

from scipy.spatial.distance import cdist

def calc_distance(embedding1, embedding2):
    return cdist(np.expand_dims(embedding1, axis=0), np.expand_dims(embedding2, axis=0), metric='cityblock')[0][0]

print('similar phrases:')
print(calc_distance(embed_sequence('sunny day today')
                  , embed_sequence('rainy morning presently')))

print('different phrases:')
print(calc_distance(embed_sequence('sunny day today')
                  , embed_sequence('perhaps reality is painful')))
The distance between similar and different phrases. Notice how the similar phrases don't actually have any of the same words, but have similar general meaning. Also, the last quote is from a book I'm reading called "Beyond Good and Evil". I'm not trying to be edgy. There's a part where Nietzsche talks about, perhaps, reality is painful in nature, and the strength of one's will is the capacity to observe it undiluted.
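
Manhattan distance is just one choice; cosine distance (related to the cosine similarity mentioned alongside the embedding-space figure earlier) is another common one. A minimal sketch, again using scipy, not part of the original notebook:

"""An alternative relatedness measure: cosine distance.
0 means the vectors point in the same direction, 2 means opposite directions.
"""

from scipy.spatial.distance import cosine

def calc_cosine_distance(embedding1, embedding2):
    return cosine(embedding1, embedding2)

print(calc_cosine_distance(embed_sequence('sunny day today'),
                           embed_sequence('rainy morning presently')))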

Defining Retrieval and Augmentation

Now that we’re calculating relevance, it might be useful to define a few documents.

"""Defining documents
for simplicities sake I only included words the embedder knows. You could just
parse out all the words the embedder doesn't know, though. After all, the retreival
is done on a mean of all embeddings, so a missing word or two is of little consequence
"""
documents = {"menu": "ratatouille is a stew thats twelve dollars and fifty cents also gazpacho is a salad thats thirteen dollars and ninety eight cents also hummus is a dip thats eight dollars and seventy five cents also meat sauce is a pasta dish thats twelve dollars also penne marinera is a pasta dish thats eleven dollars also shrimp and linguini is a pasta dish thats fifteen dollars",
             "events": "on thursday we have karaoke and on tuesdays we have trivia",
             "allergins": "the only item on the menu common allergen is hummus which contain pine nuts",
             "info": "the resteraunt was founded by two brothers in two thousand and three"}

Now we can define a function which uses our previous distance calculation to find which document is most relevant to a given prompt.

"""defining a function that retreives the most relevent document
"""

def retreive_relevent(prompt, documents=documents):
    min_dist = 1000000000
    r_docname = ""
    r_doc = ""

    for docname, doc in documents.items():
        dist = calc_distance(embed_sequence(prompt)
                           , embed_sequence(doc))

        if dist < min_dist:
            min_dist = dist
            r_docname = docname
            r_doc = doc

    return r_docname, r_doc

prompt = 'what pasta dishes do you have'
print(f'finding relevent doc for "{prompt}"')
print(retreive_relevent(prompt))
print('----')
prompt = 'what events do you guys do'
print(f'finding relevent doc for "{prompt}"')
print(retreive_relevent(prompt))
Note, this is just a proof of concept. One of the issues I faced was when the term "guys" showed up in the prompt, i.e. "what pasta dishes do you guys have". The info document states that the restaurant was founded by "two brothers", so the info document would come up instead of the menu. These types of quirks are the reality of the art.
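
One way to soften quirks like that (a sketch, not part of the original notebook) is to retrieve the top few documents instead of only the closest one, and let the LLM sort out which parts of the context it actually needs:

"""A sketch of retrieving the k closest documents instead of just one."""

def retrieve_top_k(prompt, documents=documents, k=2):
    prompt_embedding = embed_sequence(prompt)
    # sort documents by their distance to the prompt, nearest first
    ranked = sorted(documents.items(),
                    key=lambda item: calc_distance(prompt_embedding, embed_sequence(item[1])))
    return ranked[:k]

print(retrieve_top_k('what pasta dishes do you guys have'))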

Augmenting and Generating

Now we can put it all together. Get a query from a user, retrieve relevant documents, augment the prompt, and pass it to an LLM.

Augmentation might look something like this:

"""Defining retreival and augmentation
creating a function that does retreival and augmentation,
this can be passed straight to the model
"""
def retreive_and_agument(prompt, documents=documents):
    docname, doc = retreive_relevent(prompt, documents)
    return f"Answer the customers prompt based on the folowing documents:n==== document: {docname} ====n{doc}n====nnprompt: {prompt}nresponse:"

prompt = 'what events do you guys do'
print(f'prompt for "{prompt}":n')
print(retreive_and_agument(prompt))

And generation might look something like this:

"""Using RAG with OpenAI's gpt model
"""

import openai
openai.api_key = OPENAI_API_TOKEN  # your OpenAI API key, assumed to be defined elsewhere (e.g. earlier in the notebook)

prompts = ['what pasta dishes do you have', 'what events do you guys do', 'oh cool what is karaoke']

for prompt in prompts:

    ra_prompt = retrieve_and_augment(prompt)
    response = openai.Completion.create(model="gpt-3.5-turbo-instruct", prompt=ra_prompt, max_tokens=80).choices[0].text

    print(f'prompt: "{prompt}"')
    print(f'response: {response}')
Our custom RAG-enabled chatbot in action.
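
Note that the snippet above uses the legacy completions endpoint. On newer versions of the openai package (1.x), the equivalent call goes through the chat completions API instead; a rough sketch, assuming your key is set in the OPENAI_API_KEY environment variable:

"""The same generation step using the openai 1.x client (a sketch)."""

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

for prompt in prompts:
    ra_prompt = retrieve_and_augment(prompt)
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": ra_prompt}],
        max_tokens=80,
    )
    print(f'prompt: "{prompt}"')
    print(f'response: {response.choices[0].message.content}')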

Conclusion

And that’s it! In this post we went over how word vector embeddings play a key part in RAG and how embeddings can be manipulated to summarize a sequence of words. We went over using distance to get relevant information, then tied it all together with augmentation to query an LLM. In the end we created a chatbot that can leverage up-to-date information.

Follow For More!

I describe papers and concepts in the ML space, with an emphasis on practical and intuitive explanations. I plan on creating more posts on best practice RAG implementation techniques, and implementing a vector database from scratch. Stay tuned!



Attribution: All of the images in this document were created by Daniel Warfield, unless a source is otherwise provided. You can use any images in this post for your own non-commercial purposes, so long as you reference this article, https://danielwarfield.dev, or both.

