What is Retrieval-Augmented Generation (RAG)? -- Part 1: Context
This is the first part of a three-part blog.
Imagine you meet a person who is super talented in languages. This person can have a conversation with you, summarise a very long document in seconds, translate between many languages, and even entertain you with original but not-so-funny jokes. However, this person does not always have all the facts right. Therefore, you might hear some “confident bullshit” from this person when you ask about certain things.
Now, you want to have a meaningful discussion with this talkative person about a topic that is explained in a book or a Wikipedia page. What would you do?
One idea is to hand this person a document with all the info you need for the discussion, ask them to read it in a few seconds, and then discuss with you based only on the documented information.
This is the idea of retrieval-augmented generation (RAG). RAG is a popular technique for mitigating the hallucination problem of large language models (LLMs). As you can see from the analogy above, the idea is actually quite simple. But there are a few questions to answer in order to understand RAG a little more deeply:
- How does AI use documented information to refine its answer? (context)
- How does AI “read” that documented information? (embedding)
- How does AI store and find the relevant information to answer a question? (retrieval)
In this blog, I will try to explain the theory behind RAG, how to implement RAG, and some example applications. To manage this under the constraint of a baby’s nap time, please allow me to break the blog into three parts.
Providing context for LLMs
First, let’s talk a bit about how LLMs work. Essentially, AI or machine learning models are approximations of the following conditional distribution:
\[p_{\boldsymbol{\theta}}(\ \text{target to be predicted}\ |\ \text{observed data}\ ),\]where ${\boldsymbol{\theta}}$ is a set of parameters or hyper-parameters. For deep neural networks such as LLMs, ${\boldsymbol{\theta}}$ could consist of multiple billions of parameters. The optimal parameters ${\boldsymbol{\theta}}^*$ are obtained by minimising a loss function $\ell(\boldsymbol{\theta},\ \mathcal{D})$, where $\mathcal{D}$ is the training data.
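For LLMs, a common choice of loss (used in some variant by most of these models; the exact training objective is a per-model detail) is the next-word prediction, or cross-entropy, loss:
\[\ell(\boldsymbol{\theta},\ \mathcal{D}) = -\sum_{\text{texts in}\ \mathcal{D}}\ \sum_{t}\ \log p_{\boldsymbol{\theta}}(\ \text{word}_t\ |\ \text{word}_1, \dots, \text{word}_{t-1}\ ),\]which simply rewards the model for assigning a high probability to the word that actually comes next in the training texts.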
In the case of LLMs, the “observed data” is the text given to the model (the prompt) and the “target to be predicted” would be the next most likely words (the response):
\[p_{\boldsymbol{\theta}^*}(\ \text{response}\ |\ \underbrace{\text{question}}_\text{a simple prompt} ).\]So in a very simplified way, an LLM is just predicting the next words based on your input and has no guarantee of factual correctness (or an uncertainty measure). That is why sometimes we feel an LLM is confidently making up stuff, also known as the hallucination problem.
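We can actually peek at this conditional distribution. Below is a minimal sketch using the Hugging Face transformers library (an assumption on my part; the blog’s example later uses Llama 2 via LangChain) with GPT-2 as a small stand-in model, printing the most likely next tokens for a short prompt:

# A minimal sketch: inspect p(next token | prompt) with a small LLM (GPT-2 as a stand-in).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Uppsala University has"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits[0, -1]   # scores for the next token only
probs = torch.softmax(logits, dim=-1)        # the conditional distribution p(next token | prompt)

top = torch.topk(probs, k=5)
for p, idx in zip(top.values, top.indices):
    print(f"{tokenizer.decode(int(idx))!r}: {p.item():.3f}")   # the five most likely next tokens

The generated text is then obtained by repeatedly sampling (or picking the most likely) next token from this distribution; there is no built-in fact-checking step.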
One way to mitigate hallucination is to add context to your prompt, such that the above conditional distribution becomes
\[p_{\boldsymbol{\theta}^*}(\ \text{response}\ |\ \underbrace{\text{context},\ \text{system message},\ \text{question}}_\text{a more sophisticated prompt}\ ),\]where
- Context is the documented info, such as a Wikipedia page;
- System message is something like “Answer the questions based only on the provided context”;
- Question is a user input such as “How many nations are there at Uppsala University?”.
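To make this concrete, here is a rough sketch in plain Python (no framework yet; the context snippet and the exact wording are just placeholders for illustration) of what such a context-augmented prompt could look like; a LangChain version appears in the example below:

# A rough sketch of a context-augmented prompt: system message + context + question
# are concatenated into the single string that the LLM conditions on.
context = "Uppsala University has thirteen student nations ..."  # e.g. text from the Wikipedia page
system_message = "Answer the question based only on the provided context."
question = "How many nations are there at Uppsala University?"

prompt = f"{system_message}\n\nContext:\n{context}\n\nQuestion: {question}"
print(prompt)  # this whole string is the "observed data" from which the LLM predicts the response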
From first principles, the intuitive explanation is that providing context for an LLM to refine its output is equivalent to providing more information (hence reducing the uncertainty), such that the output is drawn from a probability distribution with a smaller variance. We should also note that an LLM is still just predicting the next words, conditioning on the provided context. At least I have not read any theoretical explanation of how an LLM comprehends the logic or the actual “meaning” of the text. But hey, maybe human intelligence isn’t that different from just “predicting the next words” based on past training and the current stimulation?
RAG and fine-tuning
Before going to the example, let’s clear up some common doubts about RAG versus fine-tuning.
Fine-tuning is essentially further optimising ${\boldsymbol{\theta}^*}$ to ${\boldsymbol{\theta}'}$ by using a tailored and typically smaller dataset $\mathcal{D}'$ for a specific domain. From a Bayesian point of view, fine-tuning is equivalent to updating the model parameters with a new dataset, while RAG is equivalent to updating the posterior belief by conditioning on the context. And we can of course combine both techniques by first fine-tuning the model and then applying RAG to update the posterior belief.
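In the notation above, this can be sketched (glossing over the details of the actual training procedure) as
\[\boldsymbol{\theta}' = \arg\min_{\boldsymbol{\theta}}\ \ell(\boldsymbol{\theta},\ \mathcal{D}'),\ \text{starting from}\ \boldsymbol{\theta}^*,\]so a fine-tuned model answers from $p_{\boldsymbol{\theta}'}$ instead of $p_{\boldsymbol{\theta}^*}$.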
To summarise, here are the Bayesian interpretations of fine-tuning and RAG:
- Fine-tuning: $p_{\boldsymbol{\theta}'}(\ \text{response}\ |\ \text{question}\ )$, i.e. updated parameters, same kind of prompt;
- RAG: $p_{\boldsymbol{\theta}^*}(\ \text{response}\ |\ \text{context},\ \text{question}\ )$, i.e. the original parameters, with extra conditioning on the context;
- Fine-tuning and RAG: $p_{\boldsymbol{\theta}'}(\ \text{response}\ |\ \text{context},\ \text{question}\ )$, i.e. both.
Example: How many nations are there at Uppsala University?
This question might appear strange to you: a nation at Uppsala University is a student society, not a country. This is actually a good example of the importance of context.
If we ask this question to Llama 2 (an open-source LLM developed by Meta) without any context, here is the answer:
>>> How many nations are there at Uppsala University?
Uppsala University is an international university with students and staff from all over the
world. According to the university's website, the following countries are represented at the
university:
1. Sweden: The majority of students and staff at Uppsala University are from Sweden.
2. Finland: Finland is the second-largest source country for international students at Uppsala
University.
3. Other European countries: Students from other European countries, such as Norway, Denmark,
Germany, France, and the UK, also attend Uppsala University.
4. Rest of the world: Students from countries outside Europe, including Africa, Asia,
Australia, and South America, are also represented at Uppsala University.
The exact number of nations represented at Uppsala University can vary depending on the
academic year and the specific program or department involved.[1][2
Clearly the LLM didn’t understand what a nation means in this question: it interpreted the word “nation” as a country. It even produced fake and incomplete reference numbers (“[1][2”) at the end – I suspect that a lot of the LLM’s training data for this type of question has reference numbers in the text. So the AI is basically generating something that merely looks like a good answer.
Now let’s apply RAG to the same Q&A task. We will ask the AI to answer the same question based on the Wikipedia page about Uppsala University (https://en.wikipedia.org/wiki/Uppsala_University).
For example, using LangChain, the context can be added as part of the prompt template:
from langchain_core.prompts import ChatPromptTemplate

# Prompt template with placeholders for the retrieved context and the user question
prompt = ChatPromptTemplate.from_template("""Answer the following question based only on the provided context:
<context>
{context}
</context>
Question: {input}""")
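To give an idea of how this template could be wired to a model, here is a minimal sketch assuming Llama 2 is served locally via Ollama (the model name and the hand-pasted context are assumptions for illustration, not necessarily the exact setup used for the answer below):

# A minimal sketch: combine the prompt template above with a locally served Llama 2 model.
# Assumes Ollama is running locally with the "llama2" model pulled.
from langchain_community.llms import Ollama

llm = Ollama(model="llama2")
chain = prompt | llm   # LangChain expression language: prompt -> model

# Here the context is supplied by hand; retrieving it automatically is the topic of the later parts.
context = "..."  # e.g. the relevant passage from the Wikipedia page about Uppsala University
answer = chain.invoke({"context": context,
                       "input": "How many nations are there at Uppsala University?"})
print(answer)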
We will explain in the later parts how the context is populated with the right information from the Wikipedia page about Uppsala University. For now, let’s just look at the AI’s response after adding the right context to the prompt:
There are 13 nations at Uppsala University in Sweden. The nations are:
1. Stockholms nation
2. Uplands nation
3. Gästrike-Hälsinge nation
4. Östgöta nation
5. Västgöta nation
6. Södermanlands-Nerikes nation
7. Västmanlands-Dala nation
8. Smålands nation
9. Göteborgs nation
10. Kalmar nation
11. Värmlands nation
12. Norrlands nation
13. Gotlands nation
Now the answer is much more to the point and accurate! We can verify that this answer is indeed consistent with the information from the Wikipedia page about Uppsala University (see the section “Student life – Nations and student union”).
References
[1] Lewis, Patrick, et al. “Retrieval-augmented generation for knowledge-intensive NLP tasks.” Advances in Neural Information Processing Systems 33 (2020): 9459–9474.
[2] LangChain Cookbook: RAG: https://python.langchain.com/docs/expression_language/cookbook/retrieval