
Large language models (LLMs) can store an impressive amount of factual knowledge, but their capacity is limited by the number of parameters. Furthermore, frequently retraining an LLM is expensive, and stale training data can cause an LLM to produce out-of-date responses.
To tackle these problems, we can augment LLMs with external tools. In this article, I will share how to integrate an LLM with a retrieval component to enhance its performance.
Retrieval Augmentation (RA)
A retrieval component can provide the LLM with more up-to-date and precise knowledge. Given input x, we want to predict output p(y|x). From an external data source R, we retrieve a list of contexts z = (z_1, z_2, ..., z_n) relevant to x. We can then join x and z together and make full use of z's rich information to predict p(y|x, z). Besides, keeping R up to date is also much cheaper than retraining the LLM.
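To make the flow concrete, here is a minimal retrieve-then-read sketch; the retrieve and llm functions are placeholders for the retrieval index and the ChatGPT call we build in the rest of this article:
def answer(x: str, retrieve, llm, topk: int = 3) -> str:
    # Retrieve contexts z = (z_1, ..., z_n) relevant to the input x.
    z = retrieve(x, topk=topk)
    # Join x and z into a single prompt so the model can predict p(y|x, z).
    prompt = "Context: " + " ".join(z) + "\nQuestion: " + x + "\nAnswer:"
    return llm(prompt)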

QA Demo Using Wikipedia data + ChatGPT
In this demo, for a given question, we do the following steps:
- Retrieve Wikipedia documents related to the question.
- Provide both the question and the retrieved Wikipedia documents to ChatGPT.
We want to compare and see how the extra context affects ChatGPT’s responses.
Dataset
For the Wikipedia dataset, we can extract it from here. I use the "20220301.simple" subset, which contains more than 200k documents. Due to the context length limit, I only use the title and abstract parts. For each document, I also add a doc id for retrieval later. The data examples look like this:
{"title": "April", "doc": "April is the fourth month of the year in the Julian and Gregorian calendars, and comes between March and May. It is one of four months to have 30 days.", "id": 0}
{"title": "August", "doc": "August (Aug.) is the eighth month of the year in the Gregorian calendar, coming between July and September. It has 31 days. It is named after the Roman emperor Augustus Caesar.", "id": 1}
We combine the title and the abstract passage and prepare them for encoding.
with open(input_file, "r") as f:
    for line in f.readlines():
        try:
            example = json.loads(line.strip("\n"))
            # Concatenate the title and the abstract, separated by the tokenizer's sep token.
            self.id2text[example["id"]] = example.get("title", "") + self.tokenizer.sep_token + example.get("doc", "")
        except Exception as _:
            continue
        if len(self.id2text) >= self.max_index_count:
            break
Encoding
Next, we need a reliable embedding model to build our retrieval index. In this demo, I use the pre-trained multilingual-e5-large model with dim=1024 to encode the docs. For faster indexing and better storage efficiency, you can choose an embedding model with a smaller dimension.
My first choice of embedding model was the pre-trained ALBERT, but the quality of the results was poor. You should run a few test cases to make sure your index works reasonably well before moving to the next step. To pick a good embedding model for retrieval, you can check out this leaderboard.
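The encoding code below assumes the tokenizer, the model, and an average_pool helper are already set up. For reference, here is a minimal sketch following the multilingual-e5-large model card; in the original code, the tokenizer and model live on the indexing class as self.tokenizer and self.model:
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")  # self.tokenizer in the class
model = AutoModel.from_pretrained("intfloat/multilingual-e5-large")  # self.model in the class

def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Mean-pool the token embeddings, ignoring padded positions.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]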
@torch.no_grad()
def _get_batch_embedding(self, x: List[str]):
    '''
    Get embeddings of a single batch.

    Parameters
    ----------
    x: List of texts to encode
    '''
    batch_dict = self.tokenizer(x, max_length=512, padding=True, truncation=True, return_tensors='pt')
    outputs = self.model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    # normalize embeddings
    return F.normalize(embeddings, p=2, dim=1)

def _get_all_embeddings(self):
    '''
    Get embeddings of all data points.
    '''
    data_loader = DataLoader(list(self.id2text.values()), batch_size=8)
    embeddings = []
    for bs in tqdm(data_loader):
        embeddings += self._get_batch_embedding(bs)
    embeddings = [e.tolist() for e in embeddings]
    return embeddings
ANN Index
We now have our documents' embeddings and ids ready. The next step is to index them for fast retrieval. I use an HNSW index with cosine distance.
self.index = hnswlib.Index(space='cosine', dim=self.dim)
self.index.init_index(max_elements=self.max_index_count, ef_construction=200, M=16)
self.index.add_items(embeddings, ids)
self.index.set_ef(50)
print(f"Finish building ann index, took {time()-start:.2f}s")
self.index.save_index(self.index_file)  # so we don't need to do everything once again
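The next time around, instead of re-encoding everything, we can reload the persisted index. A minimal sketch, assuming the same dim and index_file as above:
index = hnswlib.Index(space='cosine', dim=self.dim)
index.load_index(self.index_file, max_elements=self.max_index_count)  # reload the saved HNSW index
index.set_ef(50)  # ef should be at least as large as the topk used at query time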
Given a question, we first query the retrieval index for relevant information. To avoid feeding in the wrong context, you can set a distance threshold here so that only sufficiently relevant documents are used:
def get_nn(self, text: List[str], topk: int = 1):
    embeddings = self._get_batch_embedding(text)
    labels, distances = self.index.knn_query(embeddings.detach().numpy(), k=topk)
    # map ids back to wiki passages
    nb_texts = [self.map_id_to_text(label) for label in labels]
    if self.debug:
        for i in range(len(text)):
            print(f"Query={text[i]}, neighbor_id={labels[i]}, neighbor={nb_texts[i]}, distances={distances[i]}")
    return nb_texts, labels, distances
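As a sketch of the distance threshold mentioned above, we can drop neighbors whose cosine distance exceeds a cutoff before building the prompt. The retriever name and the 0.2 cutoff are only illustrative assumptions; tune the threshold on your own data:
DISTANCE_THRESHOLD = 0.2  # assumed cutoff; cosine distance = 1 - cosine similarity

nb_texts, labels, distances = retriever.get_nn(["how to make cheese?"], topk=3)
relevant = [
    text
    for text, dist in zip(nb_texts[0], distances[0])
    if dist <= DISTANCE_THRESHOLD  # keep only sufficiently close documents
]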
ChatGPT API
So now our retrieval pipeline is ready! For the next step, let's prepare a prompt to ask ChatGPT. We prepare two prompt templates, shown below: one with only the question, and one with both the question and the relevant Wikipedia text.
The ‘question’ placeholder is the target question we want to ask ChatGPT, while ‘info’ is the Wikipedia document retrieved from our HNSW index.
NO_WIKI_PROMPT_TEMPLATE = """
Answer the following question:
Question: ```{question}```
Your response:
"""
HAVE_WIKI_PROMPT_TEMPLATE = """
You will be provided with the following information:
1. A question delimited with triple backticks.
2. Additional information that is related to the question.
Perform the following tasks:
1. Understand the provided information.
2. Use the provided information and answer the question.
Question: ```{question}```
Additional information: ```{info}```
Your response:
"""
After filling the placeholders with the question and the corresponding Wikipedia doc, we can send the request to the OpenAI ChatCompletion API.
def get_chat_completion(
    messages: List[dict],
    model: str = "gpt-3.5-turbo",
    max_retries: int = 3,
    debug: bool = False
):
    '''
    Gets a chat completion from the OpenAI API.

    Parameters
    ----------
    messages : List[dict]
        Input messages to use, e.g.: [{"role": "user", "content": "what is the meaning of BERT"}]
    model : str, optional
        The OpenAI model to use. Here we set the default value to "gpt-3.5-turbo".
    max_retries : int, optional
        The maximum number of retries to use. Defaults to 3.
    debug : bool
        Whether to log debug information.
    '''
    model_dict = {"model": model}
    error_msg = None
    error_type = None
    if debug:
        logging.warning(f"Sending chat with message={messages}, model_dict={model_dict}...")
    for _ in range(max_retries):
        try:
            completion = openai.ChatCompletion.create(
                temperature=0.0, messages=messages, **model_dict
            )
            return completion
        except Exception as e:
            error_msg = str(e)
            error_type = type(e).__name__
            sleep(3)
    print(
        f"Could not obtain the completion after {max_retries} retries: `{error_type} ::"
        f" {error_msg}`"
    )
I tried sending a few requests with and without Wikipedia documents to see the difference.
Example 1
Question: what does the cost of living include?
ChatGPT: The cost of living typically includes expenses such as housing, food, transportation, healthcare, utilities, and other basic necessities required to maintain a certain standard of living in a particular location. It may also include taxes, insurance, and other miscellaneous expenses.
Wiki info: [[‘Cost of livingCost of living is the amount of money it costs just to live in a certain place. It includes food, housing, etc.’]]
ChatGPT + wiki info: The cost of living includes the amount of money it costs to live in a certain place, which encompasses expenses such as food, housing, and other necessary expenses.
Example 2
Question: how to make cheese?
ChatGPT: Cheese can be made by curdling milk with the help of an acid or rennet, separating the curds from the whey, and then pressing and aging the curds to develop flavor and texture. The specific process and ingredients used can vary depending on the type of cheese being made.
Wiki info: [[‘CheeseCheese is a dairy product that is made from milk. There are many types of cheese, such as cheddar, Swiss, and provolone. ‘]]
ChatGPT + wiki info: To make cheese, you need milk as the main ingredient. Cheese is a dairy product that is made from milk. There are many types of cheese, such as cheddar, Swiss, and provolone.
Example 3
Question: can you explain the seq2seq model?
ChatGPT: The seq2seq (sequence-to-sequence) model is a type of neural network architecture used for tasks that involve generating a sequence of output values from a sequence of input values. It consists of two main components: an encoder and a decoder. The encoder takes in the input sequence and generates a fixed-length vector representation of it, which is then passed to the decoder. The decoder uses this vector to generate the output sequence, one element at a time. The model is commonly used in natural language processing tasks such as machine translation, text summarization, and speech recognition.
Wiki info: [[‘Computability theoryComputability theory is part of computer science. Scientists want to know what can be computed, and what can not. ‘]]
ChatGPT + wiki info: The seq2seq model, or sequence-to-sequence model, is a type of neural network architecture used for natural language processing tasks such as machine translation, text summarization, and conversational modeling. It consists of two recurrent neural networks (RNNs), an encoder and a decoder, that work together to transform an input sequence into an output sequence. The encoder processes the input sequence and produces a fixed-length vector representation, which is then fed into the decoder to generate the output sequence. This model has been successful in many NLP tasks and has become a popular choice for language generation. The provided additional information about computability theory is not directly related to the question about seq2seq model.
At first glance, we can see that the ANN index works quite well, but our dataset's coverage is quite limited. A better approach is to expand your retrieval using Google Search results. You can also update your retrieval index with relevant documents when doing QA on a specific domain.
In examples 1 and 2, where the Wikipedia passage seems related, ChatGPT's answer relies heavily on the additional information. However, the model may not be able to tell whether the additional information is actually helpful (example 2). To avoid bad cases like this, you can add a few examples to the prompt to show the model when and when not to use the extra information, for instance as sketched below.
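A hedged sketch of such a prompt; the wording and the template name are just an illustration:
HAVE_WIKI_PROMPT_TEMPLATE_V2 = """
You will be provided with a question delimited with triple backticks and additional information.
If the additional information is relevant to the question, use it to answer.
If it is irrelevant or insufficient, ignore it and answer from your own knowledge.
Example: for the question "can you explain the seq2seq model?", additional information about
computability theory is irrelevant and should be ignored.
Question: ```{question}```
Additional information: ```{info}```
Your response:
"""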
Example 3 is a different case: the retrieved Wikipedia text is completely irrelevant. Fortunately, the answer does not seem to be affected by the extra context.
You can find the code here. Hope you enjoyed the read 🙂