RAG But Better: Rerankers with Cohere AI
Chapters
0:00 RAG and Rerankers
1:25 Problems of Retrieval Only
4:32 How Embedding Models Work
6:34 How Rerankers Work
8:20 Implementing Reranking in Python
13:11 Testing Retrieval without Reranking
15:21 Retrieval with Cohere Reranking
21:54 Tips for Reranking
00:00:00.000 |
Retrieval Augmented Generation, or RAG, has become a little bit of an overloaded term. 00:00:06.240 |
It promises quite a lot, but when we actually start implementing it, especially when we're 00:00:12.000 |
new to doing this stuff, the results are sometimes amazing, but more often than not, kind of disappointing. 00:00:21.140 |
And that is because RAG, as with most tools, is very easy to get started with, but then 00:00:27.380 |
it's very hard to actually get good at implementing. 00:00:30.640 |
The truth is that there is a lot more to RAG than just putting documents into a vector 00:00:36.320 |
database and then retrieving documents from that vector database and putting them into an LLM. 00:00:41.560 |
In order to make the most out of RAG, you have to do a lot of other things as well. 00:00:46.080 |
So that's why we're starting this series on how to do RAG better. 00:00:52.160 |
In this first video, we're going to be looking at how to do re-ranking, which is probably 00:00:58.420 |
the easiest and fastest way to make a RAG pipeline better. 00:01:03.320 |
Now I'm going to be talking throughout this entire series within the context of RAG and 00:01:08.040 |
LLMs, but in reality, this can be applied to retrieval as a whole. 00:01:12.620 |
If you have a semantic search application, or maybe even recommendation systems, you 00:01:17.880 |
can actually apply not all, but a lot of what we're going to be talking about throughout 00:01:22.260 |
the series, including re-ranking, which we'll go through today. 00:01:25.740 |
So before jumping into the solution of re-ranking, I'm going to talk a little bit about the problem 00:01:31.320 |
that we face with just retrieval as a whole, and then specific to LLMs. 00:01:36.980 |
So to begin with retrieval, to ensure fast search times, we use something called vector search. 00:01:43.540 |
That is, we transform our text into vectors, place them all into a vector space, and then 00:01:47.780 |
compare their proximity to what we call a query vector, which is just a vector version 00:01:52.680 |
of some sort of query, and see which ones are the closest together, and we return them. 00:01:58.060 |
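(To make that concrete, here's a toy numpy sketch of the proximity comparison being described; the random vectors stand in for real embeddings.)

```python
import numpy as np

# toy stand-ins: in practice these vectors come from an embedding
# model and live in a vector database
rng = np.random.default_rng(0)
doc_vectors = rng.standard_normal((5, 1536))
query_vector = rng.standard_normal(1536)

# cosine similarity between the query vector and every document vector
sims = doc_vectors @ query_vector / (
    np.linalg.norm(doc_vectors, axis=1) * np.linalg.norm(query_vector)
)

# the closest documents are the ones we return
top_3 = np.argsort(-sims)[:3]
print(top_3, sims[top_3])
```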
Now for vector search to work, we need vectors, which are essentially just compressed representations of the information in our text. 00:02:08.380 |
Because we're compressing that information into a single vector, we will naturally lose 00:02:12.740 |
some information, but that is the cost of vector search, and for the most part, it's a cost worth paying. 00:02:22.420 |
But what I tend to find with vector search and RAG with LLMs is that, okay, I get some 00:02:28.560 |
good results at the top, but there's actually another result in, let's say, position 17, 00:02:33.260 |
for example, that actually provides some very relevant context for the question that I have asked. 00:02:39.740 |
So in this example, let's say this is position 17 down here. 00:02:44.540 |
We have that relevant item, but what we would typically do when we're doing RAG with LLMs is return just the top few records, say the top three. 00:02:54.140 |
So we're missing out on these other relevant records down here. 00:03:00.860 |
Now, there are a few ways to handle this. The simplest is simply to just return everything, and send all of these into our LLM. 00:03:11.780 |
Now that's okay, but LLMs have limited context windows. 00:03:17.260 |
So we're going to end up filling that context window very quickly if we just start returning everything. 00:03:24.100 |
So we want to return a lot of records so that we have 00:03:28.300 |
high retrieval recall, but then we want to limit the number of records we actually send to our LLM. 00:03:40.840 |
So by adding a re-ranker, we can still use all of those records, we still get to return 00:03:46.300 |
all of these from our retrieval component, but then the records that we actually send 00:03:51.360 |
to our LLM are just these here, these top three. 00:03:55.560 |
And the re-ranker has gone ahead and handled the reordering of our records to get the most 00:04:01.640 |
relevant items at the top, so we can then send all of that to our LLM. 00:04:08.400 |
Now the question here is, is a re-ranker really going to help us here? 00:04:12.360 |
Can we not just use a better retrieval model? 00:04:16.440 |
And yes, we can use a better retrieval model, and that's something we'll be talking about later in this series. 00:04:22.480 |
But there is a very good reason as to why a re-ranker can generally perform better than an embedding model alone. 00:04:43.760 |
To see why, consider what an embedding model is actually doing: we have a transformer model. 00:04:52.160 |
The reason that I've got two of them on the screen right now is because you use your first 00:04:57.680 |
iteration or inference step of the transformer model to create your embedding for document 00:05:03.200 |
A. And from that you get your vector A. So that is the compressed information that we 00:05:09.880 |
can then take across to our vector database, which would kind of be like this point here. 00:05:17.880 |
And then in another inference step, we're going to do the same for document B. We get vector B. 00:05:24.560 |
We have that in our vector search, and we can then compare the proximity of those two vectors. 00:05:32.480 |
The metric that we'd be using here, the computation, would be either dot product or cosine similarity in most cases. 00:05:42.400 |
Now you have to consider that the computational complexity of something like cosine similarity 00:05:49.360 |
is much simpler than one of these transformer inference steps. 00:05:55.240 |
So the reason that we use this encoder architecture is that we can do all of the transformer inferences 00:06:03.160 |
at the start, when we're building our index. That takes a long time, because transformer inference is expensive. 00:06:11.320 |
Whereas the cosine similarity step at the end, which we run at the time when our user makes a query, is very fast. 00:06:19.640 |
So it's kind of like we're doing the heavy part of the computation to compare documents ahead of time. 00:06:28.880 |
And that means we can do very quick, simple computations at user query time. 00:06:34.320 |
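As a rough sketch of that cost split, with a stub standing in for the transformer call:

```python
import numpy as np

rng = np.random.default_rng(0)

def embed(text: str) -> np.ndarray:
    # stand-in for the expensive transformer inference step;
    # in practice this calls an embedding model
    return rng.standard_normal(1536)

documents = ["doc A", "doc B", "doc C"]

# index time (slow, done once): one transformer inference per document
doc_vectors = np.stack([embed(d) for d in documents])

# query time (fast): a single inference for the query,
# then cheap dot products against the precomputed vectors
query_vector = embed("some user query")
scores = doc_vectors @ query_vector
print(scores.argsort()[::-1])  # documents ranked by proximity to the query
```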
And that is different to what we do with re-ranking. 00:06:41.160 |
And at query time, right, so let's say document A here, maybe that's our query. 00:06:49.940 |
And document B is one of the documents in the database. 00:06:53.380 |
We're saying to the transformer, okay, how similar are these two items? 00:06:59.520 |
So to compare the similarity in this case, we're running an entire transformer inference step. 00:07:06.780 |
And notice, because we're doing everything in a single transformer step, we're not losing 00:07:11.680 |
as much information as we are with this one, where we're compressing everything into vectors. 00:07:16.880 |
That means that theoretically, we lose less information, so we can get a more accurate similarity score. 00:07:25.760 |
So it's kind of like, on one side, you have fast and relatively accurate. 00:07:33.320 |
And then on this side, you have slow, but super accurate. 00:07:37.240 |
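For reference, a minimal open-source cross-encoder sketch using the sentence-transformers library; Cohere's hosted re-ranker, used later in the video, plays the same role as a managed service:

```python
from sentence_transformers import CrossEncoder

# a small public cross-encoder checkpoint trained for passage ranking
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "why would we want to use RLHF?"
docs = [
    "RLHF aligns model behaviour with human preferences.",
    "Vector databases store dense embeddings for fast search.",
]

# one full transformer inference per (query, document) pair
scores = model.predict([(query, d) for d in docs])
print(scores)  # higher score = more relevant
```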
So the idea with the re-ranking approach to retrieval is that we use our retrieval 00:07:46.100 |
encoder step to basically filter down the total number of documents to just a small set, 00:07:53.580 |
let's say 25 documents in this example. 00:07:59.880 |
So feeding them into our re-ranker is actually going to be very fast. 00:08:03.200 |
Whereas if we fed all documents into our re-ranker, we'd be waiting, I don't know, a really long time. 00:08:12.600 |
So instead, we filter down with the encoder, feed those documents into the re-ranker, and then we'll get our final results much faster. 00:08:21.080 |
So that is how the re-ranking approach works. 00:08:27.180 |
Let's see how we'd actually implement that in Python. 00:08:30.920 |
So we're going to be working through this notebook here. 00:08:34.800 |
We need Hugging Face datasets, that's going to be where we get our dataset from, OpenAI 00:08:41.240 |
for creating our embeddings, Pinecone for storing those embeddings, and Cohere for our re-ranking model. 00:08:47.920 |
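The install step would look something like this; package names reflect the pinecone-client v2 era the video uses, and newer Pinecone clients ship under a different package name:

```python
!pip install -qU datasets openai pinecone-client cohere
```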
We're going to start by downloading our dataset, which is this AI arXiv dataset. 00:08:52.160 |
It's pre-chunked, so I've already split it into chunks of roughly 300 tokens, I think, something like that. 00:08:59.400 |
And it's basically just a dataset of arXiv papers. 00:09:04.180 |
You can kind of see a few of them here that are related to LLMs. 00:09:07.660 |
Essentially, I gathered it by taking some recent papers that are well-known, like the Llama 00:09:13.820 |
2 paper, the GPT-4 paper, GPTQ, and so on, extracting the papers that each of those references, 00:09:20.960 |
and then extracting those papers too, kind of just going in a loop through that. 00:09:27.460 |
So yeah, we have a fair few records in there. 00:09:31.300 |
It's not huge, but it's not small either: 41,500 chunks, where each chunk is roughly 300 tokens. 00:09:42.020 |
So I'm just going to reformat the data into the format we need. 00:09:48.020 |
You have ID, text, which we're going to convert into embeddings, and metadata. 00:09:53.280 |
We're not going to use metadata in this example, but it can be useful, and maybe it's something 00:09:58.460 |
that we'll look at in a future video in this series as well. 00:10:07.020 |
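A sketch of what that load-and-reformat step might look like; the dataset id and field names here are best guesses rather than confirmed from the video, so adjust them to whatever your chunked dataset exposes:

```python
from datasets import load_dataset

# hypothetical dataset id for the pre-chunked arXiv data
data = load_dataset("jamescalam/ai-arxiv-chunked", split="train")

# reformat into the id / text / metadata structure described above;
# keeping the text in metadata lets us read it back after retrieval
docs = [
    {
        "id": f"{row['doi']}-{i}",  # hypothetical id scheme
        "text": row["chunk"],       # the text we will embed
        "metadata": {"title": row["title"], "text": row["chunk"]},
    }
    for i, row in enumerate(data)
]
```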
So we need to define that encoder model that we're going to be using. 00:10:13.100 |
It's ada-002: fairly good performance, although there are better models, and that's 00:10:18.140 |
something we will also be talking about in the future. 00:10:21.280 |
So I'm going to just run that, and I will need to enter my OpenAI API key. 00:10:27.740 |
To get that, you need to head on over to platform.openai.com, and get your API key. 00:10:40.460 |
So with that, we should be able to initialize our embedding model, which we are doing here. 00:10:46.640 |
I'm not going to go through all these functions, because I've done it a million times before. 00:10:52.280 |
I think people are probably getting bored of that part of these videos. 00:10:57.620 |
So I'm just going to run through those bits very quickly. 00:11:01.640 |
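For reference, a minimal embedding sketch; the video predates the v1 openai SDK, so this uses the current client instead:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

embed_model = "text-embedding-ada-002"

def embed(texts: list[str]) -> list[list[float]]:
    # one API call can embed a whole batch of chunks at once
    res = client.embeddings.create(model=embed_model, input=texts)
    return [record.embedding for record in res.data]
```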
I'm going to get my Pinecone credentials, again, app.pinecone.io for those, and I will 00:11:09.000 |
run that, enter my API key first, and then I want my Pinecone environment, which I find next to the API key in the console. 00:11:22.080 |
Yours would probably be like gcp-starter or something along those lines. 00:11:27.280 |
So here, I'm going to create an index if it doesn't already exist. 00:11:30.520 |
My index does actually already exist, and I'm not going to recreate it, because it takes 00:11:35.160 |
a little bit of time, or at least it did the other day when creating this. 00:11:39.120 |
So you can see that I already have like the 41,000 records in there. 00:11:45.320 |
If you're looking at that, you should probably see nothing in yours, unless you've just run 00:11:50.600 |
this or you're connecting to an existing index. 00:11:55.320 |
So this is the code I use to create my index, right? 00:12:01.060 |
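That index-creation code would look roughly like this, in the pinecone-client v2 style that matches the API key + environment setup shown here; the index name is hypothetical, and newer Pinecone clients initialise differently:

```python
import pinecone

# v2-style initialisation with API key + environment
pinecone.init(api_key="YOUR_API_KEY", environment="gcp-starter")

index_name = "rerankers"  # hypothetical index name

if index_name not in pinecone.list_indexes():
    pinecone.create_index(
        index_name,
        dimension=1536,   # ada-002 embeddings are 1536-dimensional
        metric="cosine",
    )

index = pinecone.Index(index_name)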
The one thing that is maybe a little more complicated, but it's not that complicated, 00:12:06.480 |
is we're actually creating the embeddings here. 00:12:10.200 |
So I think I defined an embedding function up here, actually, and I ended up not using it. 00:12:17.400 |
So in here, this is where we're doing our embeddings, but we're wrapping it within an 00:12:22.620 |
exponential backoff function to avoid rate limit errors, which I was hitting a lot the other day. 00:12:32.020 |
If it gets a rate limit error, it's going to wait. 00:12:36.440 |
And it's going to keep doing that for a maximum of five retries. 00:12:39.520 |
Hopefully, you shouldn't be hitting five retries. 00:12:49.200 |
But if you are hitting those rate limit errors, you might be waiting a little bit of time for this to finish. 00:12:56.880 |
I was hitting tons of rate limit errors the other day, and this ended up taking a long time to run. 00:13:07.240 |
It's going to depend on the rate limits you have set on your OpenAI account. 00:13:11.240 |
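A minimal sketch of that exponential backoff wrapper, assuming the embed() sketch from earlier:

```python
import time
from openai import RateLimitError

def embed_with_backoff(texts, max_retries=5):
    # retry with exponentially growing waits whenever we hit a rate limit
    for attempt in range(max_retries):
        try:
            return embed(texts)  # the embed() sketch from earlier
        except RateLimitError:
            time.sleep(2 ** attempt)  # wait 1s, 2s, 4s, 8s, 16s
    raise RuntimeError("still rate limited after max retries")
```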
Now we want to test retrieval without Cohere's re-ranking model first. 00:13:20.920 |
Again, I'm not going to go through everything. 00:13:23.560 |
I'm just going to return, for now, the top three records. 00:13:29.480 |
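A sketch of that retrieval step, reusing the embed() and index objects from the earlier sketches and assuming the chunk text was stored in each record's metadata at upsert time; get_docs is a hypothetical helper name:

```python
def get_docs(query: str, top_k: int) -> list[str]:
    # embed the query, then pull the nearest chunks out of Pinecone
    xq = embed([query])[0]
    res = index.query(vector=xq, top_k=top_k, include_metadata=True)
    return [match["metadata"]["text"] for match in res["matches"]]

docs = get_docs(
    "can you explain why we would want to do reinforcement learning "
    "with human feedback?",
    top_k=3,
)
```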
So my question is, can you explain why we would want to do reinforcement learning with human feedback? 00:13:36.880 |
It's the training method that's a big part of why ChatGPT was so good when it was released. 00:13:45.080 |
So I kind of want-- OK, why would I want to do that? 00:13:48.240 |
I think the first answer here-- and the scraping that I did is not perfect, so some of the text is a bit messy. 00:13:57.920 |
But for the most part, I think we can read it. 00:14:00.200 |
So it's a powerful strategy for fine-tuning large language models, enabling significant 00:14:04.200 |
improvements in their performance, iteratively aligning the model's responses more closely with human preferences. 00:14:14.040 |
It can help fix issues of factuality, toxicity, and helpfulness that cannot be remedied by supervised fine-tuning alone. 00:14:22.680 |
OK, so I think that's a good answer, like number one there. 00:14:28.720 |
And then let's have a look at the second one-- increasingly popular technique for reducing 00:14:33.240 |
harmful behaviors, OK, can significantly change metrics-- that doesn't necessarily tell me anything useful. 00:14:44.360 |
So the only relevant bit of information in this second sentence is increasingly popular 00:14:50.400 |
technique for reducing harmful behaviors, OK? 00:14:56.480 |
And then number three, I think-- like, I don't see anything in this that tells me why I should use RLHF. 00:15:06.560 |
It's telling me about RLHF, but isn't telling me why I'd actually want to use it. 00:15:25.940 |
So I'm going to come down to here, and we're going to initialize our reranking model. 00:15:30.720 |
So for that, we need another API key, which is Cohere's API key. 00:15:37.320 |
Like, the Pinecone and Cohere ones will be free. 00:15:39.800 |
The OpenAI one, I think, you need to pay a little bit. 00:15:45.080 |
But again, we'll be-- like I said, later on in this series, we'll be talking about other 00:15:51.240 |
alternatives to OpenAI for embedding models, which may actually be a fair bit better. 00:15:56.800 |
So I'm going to go to this website here, dashboard.cohere.com/api-keys. 00:16:03.860 |
You will probably need to sign up, make an account, and do all of that. 00:16:09.140 |
And then you will get to your Cohere dashboard, new trial key. 00:16:13.080 |
I'm going to call it something-- I don't know-- demo-- and generate the trial key. 00:16:29.240 |
So I'm just going to rerun the last results, because I only got three here. 00:16:48.920 |
So when we re-rank stuff, we're going to get back this Cohere rerank results object. 00:16:54.840 |
And we can access the text from those like this. 00:16:57.080 |
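Roughly, the rerank call and result access look like this; exact attribute names vary a little between cohere SDK versions, so treat this as a sketch:

```python
import cohere

co = cohere.Client("YOUR_COHERE_API_KEY")

# re-rank the texts we got back from the retrieval step
rerank_docs = co.rerank(
    query="can you explain why we would want to do RLHF?",
    documents=docs,              # the list of texts from get_docs()
    top_n=3,
    model="rerank-english-v2.0",
)

# each result carries the position of the document in the original list
for r in rerank_docs.results:
    print(r.index, docs[r.index][:100])
```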
OK, so you can see we kind of get this output there. 00:17:01.560 |
And the way that I've set up the docs object that I returned from the last step here, you 00:17:09.760 |
can see it's a dictionary, where the text maps to the position. 00:17:14.160 |
The reason I've done that is so that I can just very quickly see what the reordered positions are. 00:17:22.400 |
So you can see that, OK, it's kept the zero position, like the top result. 00:17:27.440 |
But then it's swapped out one and two for these two items here, OK? 00:17:35.600 |
Now I've put together a helper function; it's basically just going to do everything we've just gone through, 00:17:42.680 |
and it's just going to compare the results for us. 00:17:45.520 |
So I'm going to set a top k of 25, so returning 25 records from our retrieval step. 00:17:52.720 |
And then we're just going to return the top three from our re-ranking step. 00:17:57.600 |
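A sketch of that comparison helper, reusing the get_docs and co objects from the earlier sketches; compare is a hypothetical name:

```python
def compare(query: str, top_k: int = 25, top_n: int = 3):
    # retrieval step: cast a wide net for high recall
    docs = get_docs(query, top_k=top_k)
    # re-ranking step: keep only the most relevant few for the LLM
    reranked = co.rerank(
        query=query, documents=docs, top_n=top_n,
        model="rerank-english-v2.0",
    )
    for i, r in enumerate(reranked.results):
        print(f"rank {i}: was retrieval position {r.index}")
        print(docs[r.index][:200], "\n")

compare("can you explain why we would want to do RLHF?")
```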
So I'm going to compare that query, so the RLHF query, OK? 00:18:06.000 |
One has been swapped for 23, and two has been swapped for 14. 00:18:10.640 |
So this won't show us the first results here, because they haven't changed. 00:18:17.700 |
So the original is what we went through before, where it has the one kind of useful bit of 00:18:24.300 |
information, increasingly popular technique for reducing harmful behaviors in large language models. 00:18:30.000 |
And then the rest wasn't really that relevant to our specific question, which is basically why we'd want to use RLHF. 00:18:38.240 |
Now having a look at 23, we've shown it's possible to use RLHF to train LLMs that act as helpful and harmless assistants. 00:18:58.880 |
In other words, talking about aligning LLMs, RLHF improves helpfulness and harmlessness. 00:19:11.840 |
Our alignment interventions actually enhance the capabilities of large models. 00:19:15.440 |
And yes, I think that's another reason, combined with training for specialized skills without degrading alignment or performance. 00:19:24.440 |
So this here is talking about RLHF like the previous number two ranked context, but it's 00:19:32.600 |
way more relevant to our specific question. And that's why we use re-ranking models. 00:19:43.440 |
So this is-- yeah, this one, there was nothing relevant, right? 00:19:49.140 |
For our specific question, there wasn't anything relevant in here. 00:19:54.560 |
Just one thing to note here: the LLMs are actually reading all of this text, which is worth remembering. 00:20:04.680 |
So, the model outputs safe responses. 00:20:07.080 |
I think that's helpful, assuming it's talking about RLHF. 00:20:12.520 |
We switch entirely to RLHF to teach the model how to write more nuanced responses. 00:20:21.200 |
Comprehensive tuning with RLHF has the added benefit that it may make the model more robust to jailbreak attempts. 00:20:29.880 |
They do the RLHF by first collecting human preference data: annotators 00:20:37.280 |
write a prompt they believe can elicit unsafe behavior, and then compare multiple model 00:20:43.800 |
responses to the prompts, selecting the responses that are safest according to a set of guidelines. 00:20:50.600 |
We use the human preference data to train a safety reward model, and-- OK. 00:20:55.740 |
So I think the relevant bits here are make the model more robust to jailbreak attempts, 00:21:02.960 |
and teach the model how to write more nuanced responses. 00:21:07.120 |
The rest of it isn't as relevant, but it's far more relevant than the earlier result, which didn't 00:21:19.160 |
tell us anything useful. Jailbreaking, by the way, is like a safety or security testing thing that they apply to LLMs now. 00:21:26.280 |
Now for this next question, you can see that it hasn't changed the top one again. 00:21:31.360 |
And I think the responses here were generally not quite as obviously better with re-ranking as in the first example. 00:21:39.560 |
What I will do is just kind of let you read those. 00:21:48.960 |
So again, you can pause and read through if you like. 00:22:00.120 |
At least I have found it helps; you know, I don't have any specific metrics on how much 00:22:05.240 |
it helps, but just from using it in actual use cases, it helps quite a bit. 00:22:13.640 |
So I hope this is something that you can also use to sort of improve your retrieval pipelines, 00:22:20.520 |
particularly when you're using RAG and sending everything to LLMs. 00:22:24.040 |
But you should also test it and make sure it is actually helping. 00:22:28.160 |
So for example, if you're using kind of like an older re-ranking model, 00:22:34.920 |
the chances are it won't actually be quite as good as some of the more recent and better models. 00:22:42.120 |
So you could actually degrade performance if you do that. 00:22:45.420 |
So you always want to make sure that you're using kind of like state-of-the-art re-rankers where you can. 00:22:52.200 |
And you should see an impact kind of similar to what we saw here with the RLHF question. 00:22:57.960 |
But anyway, as I mentioned, this is like the first method I would use when trying to optimize a RAG pipeline. 00:23:07.120 |
And as you can see, it's super easy to implement; you don't really need to modify much else in your pipeline. 00:23:17.720 |
I hope this walkthrough has been useful and interesting. 00:23:22.240 |
Thank you very much for watching, and I will see you again in the next one.