Ideal generative AI versus reality
Foundation LLMs have been trained on enormous amounts of text, and their chatbot counterparts can hold seemingly intelligent conversations and perform specific tasks on request. Access to comprehensive information has been democratized: no more hunting for the right search keywords or picking which sites to read. However, LLMs are prone to rambling and tend to give the statistically most probable response, often the one you'd want to hear (sycophancy), a tendency that stems from how these models are built and trained. As a result, information extracted from an LLM's built-in knowledge is not always trustworthy.
Chat LLMs are infamous for making up citations to scientific papers or court cases that don't exist. In one case, lawyers filing a suit against an airline submitted a brief citing court cases, generated by ChatGPT, that never actually existed. A 2023 study reported that when ChatGPT was prompted to include citations, it provided references that actually exist only 14% of the time. Fabricating sources, rambling, and delivering inaccuracies to satisfy the prompt are collectively dubbed hallucination, a huge obstacle to overcome before AI is fully adopted and trusted by the masses.
One counter to LLMs making up bogus sources or inaccuracies is retrieval-augmented generation, or RAG. RAG can not only reduce the tendency of LLMs to hallucinate but also offers several other advantages.
These advantages include access to an updated knowledge base, specialization (e.g. by providing private data sources), empowering models with information beyond what is stored in the parametric memory (allowing for smaller models), and the potential to follow up with more data from legitimate references.
What is RAG (Retrieval Augmented Generation)?
Retrieval-Augmented Generation (RAG) is a deep learning architecture that pairs an LLM or other transformer network with a retriever: relevant documents or snippets are retrieved and added to the model's context window, giving the model additional information to draw on when generating a response. A typical RAG system has two main modules: retrieval and generation.
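To make the two modules concrete, here is a deliberately tiny Python sketch. It is not how Lewis et al. implement RAG: the word-overlap scorer stands in for dense-vector retrieval, and fake_generate is a hypothetical placeholder for an LLM call. It only shows the data flow from query, to retrieved snippets, to an augmented prompt.

```python
# Toy two-module RAG pipeline: a retrieval step followed by a generation step.
# CORPUS, the word-overlap scorer, and fake_generate are hypothetical
# stand-ins; a real system would use dense embeddings and an actual LLM.
import re
from typing import List, Set

CORPUS = [
    "Prometheus was a Great Basin bristlecone pine in Nevada.",
    "FAISS is a library for efficient vector similarity search.",
    "BART is a sequence-to-sequence transformer model.",
]

def words(text: str) -> Set[str]:
    """Lowercase the text and return its set of alphanumeric words."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def retrieve(query: str, corpus: List[str], k: int = 1) -> List[str]:
    """Retrieval module: rank documents by word overlap with the query."""
    q = words(query)
    ranked = sorted(corpus, key=lambda doc: len(q & words(doc)), reverse=True)
    return ranked[:k]

def build_prompt(query: str, docs: List[str]) -> str:
    """Place the retrieved snippets in the context window ahead of the question."""
    context = "\n".join(f"- {doc}" for doc in docs)
    return f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"

def fake_generate(prompt: str) -> str:
    """Generation module placeholder: a real system would call an LLM here."""
    return "<answer conditioned on: " + prompt.splitlines()[1] + ">"

query = "Which library does vector similarity search?"
prompt = build_prompt(query, retrieve(query, CORPUS, k=1))
print(prompt)
print(fake_generate(prompt))
```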
The main reference for RAG is a paper by Lewis et al. from Facebook. In the paper, the authors use a pair of BERT-based encoders, one for queries and one for documents, to embed text as vectors. These embeddings are then used to identify the top-k (typically 5 or 10) documents via maximum inner product search (MIPS). As the name suggests, MIPS ranks documents by the inner (dot) product between the query's vector and the document vectors pre-computed and stored in a vector index, which serves as the model's external, non-parametric memory.
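The MIPS step can be sketched as a brute-force search over a handful of made-up embeddings. This is an illustration under simplifying assumptions: real indexes (the paper's setup uses FAISS) hold millions of vectors and typically rely on approximate search, whereas this version scores every document.

```python
# Brute-force maximum inner product search (MIPS) over toy embeddings.
# The 3-dimensional vectors below are invented for illustration.
from typing import List, Tuple

def dot(u: List[float], v: List[float]) -> float:
    """Inner (dot) product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def mips_top_k(query: List[float], doc_vecs: List[List[float]],
               k: int) -> List[Tuple[int, float]]:
    """Return (doc_index, score) pairs for the k highest inner products."""
    scores = [(i, dot(query, d)) for i, d in enumerate(doc_vecs)]
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return scores[:k]

# Pre-computed document embeddings (the non-parametric memory).
doc_embeddings = [
    [0.9, 0.1, 0.0],
    [0.2, 0.8, 0.1],
    [0.5, 0.5, 0.5],
    [0.0, 0.1, 0.9],
]
query_embedding = [1.0, 0.2, 0.0]

print(mips_top_k(query_embedding, doc_embeddings, k=2))
```

Here documents 0 and 2 come out on top, with inner products of about 0.92 and 0.60.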
As described by Lewis et al., RAG was designed to make LLMs better at knowledge-intensive tasks which “humans could not reasonably be expected to perform without access to an external knowledge source”. Consider the difference between an open-book and a closed-book exam and you'll have a good sense of how RAG can supplement LLM-based systems.
RAG with the Hugging Face 🤗 Library
Lewis et al. open-sourced their RAG models on the Hugging Face Hub, so we can experiment with the same models used in the paper. A new Python 3.8 virtual environment created with virtualenv is recommended.
virtualenv my_env --python=python3.8
source my_env/bin/activate
After activating the environment, we can install dependencies using pip: transformers and datasets from Hugging Face, the FAISS library from Facebook that RAG uses for vector search, and PyTorch for use as a backend.
pip install transformers
pip install datasets
pip install faiss-cpu==1.8.0
# see https://pytorch.org/get-started/locally/ to
# match the PyTorch build to your system
pip install torch
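Before moving on, a quick check that the packages import cleanly can save time. This sketch uses only the standard library's importlib.util.find_spec; note that the faiss-cpu package is imported under the module name faiss.

```python
# Sanity check that the packages installed above can be found by Python.
import importlib.util

REQUIRED = ["transformers", "datasets", "faiss", "torch"]

def check_installed(modules):
    """Map each module name to True if Python can locate it, else False."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}

for name, ok in check_installed(REQUIRED).items():
    print(f"{name:12s} {'ok' if ok else 'MISSING'}")
```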
Lewis et al. implemented two different versions of RAG: rag-sequence and rag-token. Rag-sequence uses the same retrieved document to augment the generation of an entire sequence whereas rag-token can use different snippets for each token. Both versions use the same Hugging Face classes for tokenization and retrieval, and the API is much the same, but each version has a unique class for generation. These classes are imported from the transformers library.
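The difference between the two variants comes down to where the model marginalizes over the retrieved documents. Following the formulas in Lewis et al., RAG-Sequence sums over documents once per complete answer, while RAG-Token sums over documents at every token. The probabilities in this toy sketch are invented purely to show the arithmetic.

```python
# Toy comparison of how RAG-Sequence and RAG-Token marginalize over
# retrieved documents z (formulas as given by Lewis et al.):
#   RAG-Sequence: p(y|x) = sum_z p(z|x) * prod_i p(y_i | x, z, y_<i)
#   RAG-Token:    p(y|x) = prod_i sum_z p(z|x) * p(y_i | x, z, y_<i)
# All probabilities below are made up.
from math import prod  # available in Python 3.8+

doc_probs = [0.6, 0.4]   # p(z|x) for two retrieved documents
token_probs = [          # p(y_i | x, z, y_<i) for a two-token answer
    [0.9, 0.2],          # per-token probabilities given document 0
    [0.1, 0.8],          # per-token probabilities given document 1
]

def rag_sequence(doc_probs, token_probs):
    """One document conditions the whole sequence; marginalize at the end."""
    return sum(pz * prod(toks) for pz, toks in zip(doc_probs, token_probs))

def rag_token(doc_probs, token_probs):
    """Marginalize over documents separately at every token position."""
    n_tokens = len(token_probs[0])
    return prod(
        sum(pz * toks[i] for pz, toks in zip(doc_probs, token_probs))
        for i in range(n_tokens)
    )

print(f"RAG-Sequence p(y|x) = {rag_sequence(doc_probs, token_probs):.3f}")
print(f"RAG-Token    p(y|x) = {rag_token(doc_probs, token_probs):.3f}")
```

With these numbers, RAG-Sequence gives 0.6·0.9·0.2 + 0.4·0.1·0.8 = 0.140, while RAG-Token gives (0.6·0.9 + 0.4·0.1)·(0.6·0.2 + 0.4·0.8) = 0.58·0.44 ≈ 0.255, because the first token is best supported by document 0 and the second by document 1.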
from transformers import RagTokenizer, RagRetriever
from transformers import RagTokenForGeneration
from transformers import RagSequenceForGeneration
The first time a RagRetriever is instantiated with the full “wiki_dpr” dataset (the default), it will start a substantial download (about 300 GB). If you have a large data drive and want Hugging Face to use it instead of the default cache folder in your home directory, you can set the environment variable HF_DATASETS_CACHE.
# in the shell:
export HF_DATASETS_CACHE="/path/to/data/drive"
# ^^ add to your ~/.bashrc file if you want to set the variable
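If you would rather set the cache location from Python, for example in a notebook, a minimal sketch follows. We assume the variable has to be set before datasets (or transformers) is first imported, since the library reads it at import time; /path/to/data/drive is a placeholder.

```python
# Set the Hugging Face datasets cache location from Python.
# Assumption: this must run before `datasets` or `transformers` is imported.
import os

os.environ["HF_DATASETS_CACHE"] = "/path/to/data/drive"  # placeholder path

# Import the Hugging Face libraries only after this point, e.g.:
# from transformers import RagRetriever
print(os.environ["HF_DATASETS_CACHE"])
```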
Make sure the code works before downloading the full wiki_dpr dataset: to postpone the big download until you're ready, pass use_dummy_dataset=True when instantiating the retriever. You'll also instantiate a tokenizer to convert strings to integer indices (corresponding to tokens in a vocabulary) and vice versa; the sequence and token versions of RAG use the same tokenizer. RAG-Sequence (rag-sequence) and RAG-Token (rag-token) each come in fine-tuned (e.g. rag-token-nq) and base (e.g. rag-token-base) versions.
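To illustrate what converting strings to integer indices and back means, here is a toy word-level tokenizer with a hypothetical six-entry vocabulary. The real RAG tokenizers use subword vocabularies (a BERT-style one for the question encoder and BART's for the generator), so treat this only as a picture of the round trip.

```python
# Toy tokenizer: strings -> integer indices -> strings.
# The vocabulary is hypothetical; real tokenizers use subword units.
from typing import Dict, List

class ToyTokenizer:
    def __init__(self, vocab: List[str]):
        self.token_to_id: Dict[str, int] = {tok: i for i, tok in enumerate(vocab)}
        self.id_to_token: Dict[int, str] = dict(enumerate(vocab))
        self.unk_id = self.token_to_id["[UNK]"]  # index for unknown words

    def encode(self, text: str) -> List[int]:
        """Lowercase, split on whitespace, and look up each word's index."""
        return [self.token_to_id.get(tok, self.unk_id) for tok in text.lower().split()]

    def decode(self, ids: List[int]) -> str:
        """Map indices back to words and join them with spaces."""
        return " ".join(self.id_to_token[i] for i in ids)

tok = ToyTokenizer(["[UNK]", "what", "is", "the", "oldest", "tree"])
ids = tok.encode("What is the oldest tree")
print(ids)              # [1, 2, 3, 4, 5]
print(tok.decode(ids))  # what is the oldest tree
```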
# One tokenizer serves both RAG variants
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

# Full wiki_dpr index (large download on first use);
# set use_dummy_dataset=True while you are still testing
token_retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq",
    index_name="compressed",
    use_dummy_dataset=False)
seq_retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="compressed",
    use_dummy_dataset=False)
# Small dummy index, handy for checking that the pipeline runs
dummy_retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="exact",
    use_dummy_dataset=True)

# Each variant has its own generation class
token_model = RagTokenForGeneration.from_pretrained(
    "facebook/rag-token-nq", retriever=token_retriever)
seq_model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=seq_retriever)
dummy_model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=dummy_retriever)
Once your models are instantiated, you can provide a query, tokenize it, and pass it to the model's generate method. We'll compare results from rag-sequence, rag-token, and RAG using a retriever with the dummy version of the wiki_dpr dataset. Note that these RAG models are case-insensitive.
query = "what is the name of the oldest tree on Earth?"
# Tokenize the question into PyTorch tensors
input_dict = tokenizer(query, return_tensors="pt")
input_ids = input_dict["input_ids"]

# Each model retrieves its own documents and generates an answer
token_generated = token_model.generate(input_ids=input_ids)
token_decoded = tokenizer.batch_decode(
    token_generated, skip_special_tokens=True)

seq_generated = seq_model.generate(input_ids=input_ids)
seq_decoded = tokenizer.batch_decode(
    seq_generated, skip_special_tokens=True)

dummy_generated = dummy_model.generate(input_ids=input_ids)
dummy_decoded = tokenizer.batch_decode(
    dummy_generated, skip_special_tokens=True)

print(f"answers to query '{query}': ")
print(f"\t rag-sequence-nq: {seq_decoded[0]},"
      f" rag-token-nq: {token_decoded[0]},"
      f" rag (dummy): {dummy_decoded[0]}")
>> answers to query 'what is the name of the oldest tree on Earth?':
>> rag-sequence-nq: prometheus, rag-token-nq: prometheus, rag (dummy): 4862
Both models backed by the full wiki_dpr index name Prometheus, which was the oldest known non-clonal tree until 2012, with its innermost extant rings exceeding 4862 years of age. The model using the dummy dataset instead answers “4862”, which is not the name of a tree. In general, rag-token is correct more often than rag-sequence (though both are often correct), and rag-sequence is more often right than RAG using a retriever with the dummy dataset.
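One way to make “correct” precise is normalized exact match, the metric Lewis et al. report for open-domain question answering. The sketch below uses the common SQuAD-style normalization (lowercasing, stripping punctuation and English articles, collapsing whitespace); the reference answer “Prometheus” is our assumption for this query.

```python
# Normalized exact-match (EM) scoring of model answers against a reference.
import re
import string

PUNCTUATION = set(string.punctuation)

def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction: str, reference: str) -> bool:
    """True if the normalized prediction equals the normalized reference."""
    return normalize_answer(prediction) == normalize_answer(reference)

answers = {
    "rag-sequence-nq": "prometheus",
    "rag-token-nq": "prometheus",
    "rag (dummy)": "4862",
}
for model_name, answer in answers.items():
    print(f"{model_name}: exact match = {exact_match(answer, 'Prometheus')}")
```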
You may wonder what sort of context the retriever provides. To find out, we can break the generation process into its individual steps. Using the seq_retriever and seq_model instantiated above, we ask the same question, “what is the name of the oldest tree on Earth?”
import torch

query = "what is the name of the oldest tree on Earth?"
inputs = tokenizer(query, return_tensors="pt")
input_ids = inputs["input_ids"]

# 1. Encode the question with the DPR question encoder
question_hidden_states = seq_model.question_encoder(input_ids)[0]

# 2. Retrieve the top documents for that question embedding
docs_dict = seq_retriever(
    input_ids.numpy(),
    question_hidden_states.detach().numpy(),
    return_tensors="pt")

# 3. Score each retrieved document by its inner product with the question
doc_scores = torch.bmm(
    question_hidden_states.unsqueeze(1),
    docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)

# 4. Generate an answer from the retrieved contexts and their scores
generated = seq_model.generate(
    context_input_ids=docs_dict["context_input_ids"],
    context_attention_mask=docs_dict["context_attention_mask"],
    doc_scores=doc_scores)
generated_string = tokenizer.batch_decode(
    generated, skip_special_tokens=True)

# Decode the retrieved contexts (padding is dropped as a special token)
contexts = tokenizer.batch_decode(
    docs_dict["context_input_ids"], skip_special_tokens=True)
best_context = contexts[doc_scores.argmax().item()]
Printing best_context shows the highest-scoring passage the retriever found:
print(f" based on the retrieved context"\
f":\n\n\t {best_context}: \n")
based on the retrieved context:
Prometheus (tree) / In a clonal organism, however, the individual clonal stems are not nearly so old, and no part of the organism is particularly old at any given time. Until 2012, Prometheus was thus the oldest “non-clonal” organism yet discovered, with its innermost, extant rings exceeding 4862 years of age. In the 1950s dendrochronologists were making active efforts to find the oldest living tree species in order to use the analysis of the rings for various research purposes, such as the evaluation of former climates, the dating of archaeological ruins, and addressing the basic scientific question of maximum potential lifespan. Bristlecone pines // what is the name of the oldest tree on earth?
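The retrieved context above follows a recognizable pattern: document title, passage text, then the original question. The sketch below reproduces that layout under the assumption that it matches RagConfig's default title_sep (" / ") and doc_sep (" // "); it is an approximation for illustration, not the library's code.

```python
# Approximate layout of one generator input, as seen in the output above:
#   "<title> / <passage text> // <question>"
TITLE_SEP = " / "   # assumed to mirror RagConfig's default title_sep
DOC_SEP = " // "    # assumed to mirror RagConfig's default doc_sep

def build_context(title: str, passage: str, question: str) -> str:
    """Concatenate title, passage, and question into a single string."""
    return (title + TITLE_SEP + passage + DOC_SEP + question).replace("  ", " ")

ctx = build_context(
    "Prometheus (tree)",
    "Until 2012, Prometheus was thus the oldest non-clonal organism yet discovered.",
    "what is the name of the oldest tree on earth?",
)
print(ctx)
```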
We can also print the generated answer stored in generated_string:
print(f" rag-sequence-nq answers '{query}'"
      f" with '{generated_string[0]}'")
>> rag-sequence-nq answers 'what is the name of the oldest tree on Earth?' with ' Prometheus'
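The doc_scores computed earlier are raw inner products. In RAG they are turned into a probability distribution over the retrieved documents, p(z|x), with a softmax. This sketch uses hypothetical scores for five documents; because softmax preserves ordering, the most probable document is the same one picked by doc_scores.argmax() when selecting best_context.

```python
# Turn retriever scores into a distribution over documents, p(z|x).
# The scores are made-up stand-ins for one query's doc_scores row.
import math
from typing import List

def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

doc_scores = [78.2, 75.9, 74.1, 73.8, 72.5]  # hypothetical values
p_z_given_x = softmax(doc_scores)
best = max(range(len(p_z_given_x)), key=p_z_given_x.__getitem__)
print([round(p, 3) for p in p_z_given_x], "best doc:", best)
```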
What Can You Do with RAG?
In the last year and a half, there has been a veritable explosion in LLMs and LLM tools. The BART-large generator used by Lewis et al. has only 400 million parameters, a far cry from the current crop of LLMs, whose “lite” variants typically start in the billion-parameter range. Also, many models being trained, merged, and fine-tuned today are multimodal, combining text inputs and outputs with images or other tokenized data sources. Combining RAG with other tools can build complex capabilities, but the underlying models won't be immune to common LLM shortcomings. The problems of sycophancy, hallucination, and reliability in LLMs all remain and risk growing along with LLM use.
The most obvious applications for RAG are variations on conversational semantic search, but they could also extend to multimodal inputs, or to image generation as part of the output. For example, RAG over a domain-specific corpus can turn software documentation into something you can chat with. Or RAG could be used to keep interactive notes in a literature review for a research project or thesis.
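For documentation chat, a usual first step is to split long pages into passages small enough to embed and retrieve individually, much like the 100-word Wikipedia passages in wiki_dpr. The sketch below splits on words with overlapping windows; the chunk size and overlap values are illustrative assumptions, not recommendations from the RAG paper.

```python
# Split a document into overlapping word windows for indexing.
from typing import List

def chunk_words(text: str, chunk_size: int = 100, overlap: int = 20) -> List[str]:
    """Split text into chunks of chunk_size words, each sharing
    `overlap` words with the previous chunk."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):  # last window reached the end
            break
    return chunks

doc = " ".join(f"w{i}" for i in range(10))
print(chunk_words(doc, chunk_size=4, overlap=1))
# ['w0 w1 w2 w3', 'w3 w4 w5 w6', 'w6 w7 w8 w9']
```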
By incorporating ‘chain-of-thought’ reasoning, you could take a more agentic approach, empowering your models to query a RAG system repeatedly and assemble more complex lines of inquiry or reasoning.
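A minimal sketch of such an agentic loop: retrieve, let the model decide whether a follow-up lookup is needed, and repeat up to a step limit. Everything here is hypothetical; toy_retrieve stands in for a RAG retriever and toy_next_query for an LLM's reasoning step, so only the control flow is meant to carry over.

```python
# Hypothetical multi-step (multi-hop) retrieval loop.
from typing import Dict, List, Optional

KNOWLEDGE: Dict[str, str] = {
    "oldest tree": "Prometheus was a Great Basin bristlecone pine.",
    "bristlecone pine": "Bristlecone pines grow in the western United States.",
}

def toy_retrieve(query: str) -> str:
    """Stand-in for a RAG retriever: exact-key lookup in a tiny store."""
    return KNOWLEDGE.get(query, "")

def toy_next_query(evidence: List[str]) -> Optional[str]:
    """Stand-in for the LLM choosing a follow-up query (None means done)."""
    if len(evidence) == 1 and "bristlecone pine" in evidence[-1]:
        return "bristlecone pine"
    return None

def multi_hop(first_query: str, max_steps: int = 3) -> List[str]:
    """Alternate retrieval and 'reasoning' until done or out of steps."""
    evidence: List[str] = []
    query: Optional[str] = first_query
    for _ in range(max_steps):
        if query is None:
            break
        evidence.append(toy_retrieve(query))
        query = toy_next_query(evidence)
    return evidence

print(multi_hop("oldest tree"))
```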
Keep in mind that RAG does not solve the common LLM pitfalls (hallucination, sycophancy, etc.); it only helps alleviate them and guide your LLM toward more grounded, niche responses. The outcomes that ultimately matter are specific to your use case, the information you feed your model, and how the model is fine-tuned.