
RAG But Better: Rerankers with Cohere AI


Chapters

0:00 RAG and Rerankers
1:25 Problems of Retrieval Only
4:32 How Embedding Models Work
6:34 How Rerankers Work
8:20 Implementing Reranking in Python
13:11 Testing Retrieval without Reranking
15:21 Retrieval with Cohere Reranking
21:54 Tips for Reranking

Whisper Transcript

00:00:00.000 | Retrieval Augmented Generation, or RAG, has become a little bit of an overloaded term.
00:00:06.240 | It promises quite a lot, but when we actually start implementing it, especially when we're
00:00:12.000 | new to doing this stuff, the results are sometimes amazing, but more often than not, kind of
00:00:18.320 | not as good as what we were expecting.
00:00:21.140 | And that is because RAG, as with most tools, is very easy to get started with, but then
00:00:27.380 | it's very hard to actually get good at implementing.
00:00:30.640 | The truth is that there is a lot more to RAG than just putting documents into a vector
00:00:36.320 | database and then retrieving documents from that vector database and putting them into
00:00:40.560 | an LLM.
00:00:41.560 | In order to make the most out of RAG, you have to do a lot of other things as well.
00:00:46.080 | So that's why we're starting this series on how to do RAG better.
00:00:52.160 | In this first video, we're going to be looking at how to do re-ranking, which is probably
00:00:58.420 | the easiest and fastest way to make a RAG pipeline better.
00:01:03.320 | Now I'm going to be talking throughout this entire series within the context of RAG and
00:01:08.040 | LLMs, but in reality, this can be applied to retrieval as a whole.
00:01:12.620 | If you have a semantic search application, or maybe even recommendation systems, you
00:01:17.880 | can actually apply not all, but a lot of what we're going to be talking about throughout
00:01:22.260 | the series, including re-ranking, which we'll go through today.
00:01:25.740 | So before jumping into the solution of re-ranking, I'm going to talk a little bit about the problem
00:01:31.320 | that we face with just retrieval as a whole, and then specific to LLMs.
00:01:36.980 | So to begin with retrieval, to ensure fast search times, we use something called vector
00:01:42.540 | search.
00:01:43.540 | That is, we transform our text into vectors, place them all into a vector space, and then
00:01:47.780 | compare their proximity to what we call a query vector, which is just a vector version
00:01:52.680 | of some sort of query, and see which ones are the closest together, and we return them.
00:01:58.060 | Now for vector search to work, we need vectors, which are essentially just compressed representations
00:02:04.660 | of semantic meaning behind that text.
00:02:08.380 | Because we're compressing that information into a single vector, we will naturally lose
00:02:12.740 | some information, but that is the cost of vector search, and for the most part, it's
00:02:18.220 | definitely worth paying.
00:02:19.900 | Vector search can give us very good results.
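
(As a rough illustration of that proximity comparison, here is a minimal sketch using made-up three-dimensional vectors in place of real embeddings, which would be much higher-dimensional.)

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # dot product normalised by vector lengths -> closer to 1.0 means more similar
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# toy 3-d "embeddings"; a real model like ada-002 produces 1536-d vectors
doc_vectors = {
    "doc A": np.array([0.9, 0.1, 0.0]),
    "doc B": np.array([0.1, 0.8, 0.3]),
    "doc C": np.array([0.7, 0.3, 0.1]),
}
query_vector = np.array([0.8, 0.2, 0.0])

# rank the documents by proximity to the query vector, closest first
ranked = sorted(
    doc_vectors.items(),
    key=lambda item: cosine_similarity(query_vector, item[1]),
    reverse=True,
)
print(ranked)
```
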
00:02:22.420 | But what I tend to find with vector search and RAG with LLMs is that, okay, I get some
00:02:28.560 | good results at the top, but there's actually another result in, let's say, position 17,
00:02:33.260 | for example, that actually provides some very relevant context for the question that I have
00:02:38.300 | asked.
00:02:39.740 | So in this example, let's say this is position 17 down here.
00:02:44.540 | We have that relevant item, but what we would typically do when we're doing RAG with LLMs
00:02:50.320 | is we're returning the top three items.
00:02:54.140 | So we're missing out on these other relevant records down here.
00:02:58.260 | So what can we do?
00:03:00.860 | The simplest option is to just return everything, and send all of these into our LLM.
00:03:08.800 | So over here, we have our LLM.
00:03:11.780 | Now that's okay, but LLMs have limited context windows.
00:03:17.260 | So we're going to end up filling that context window very quickly if we just start returning
00:03:23.100 | everything.
00:03:24.100 | So we want to return a lot of records so that we have
00:03:28.300 | high retrieval recall, but then we want to limit the number of records we actually send
00:03:35.060 | to our LLM.
00:03:36.900 | And that's where re-ranking would come in.
00:03:40.840 | So by adding a re-ranker, we can still use all of those records, we still get to return
00:03:46.300 | all of these from our retrieval component, but then the records that we actually send
00:03:51.360 | to our LLM are just these here, these top three.
00:03:55.560 | And the re-ranker has gone ahead and handled the reordering of our records to get the most
00:04:01.640 | relevant items at the top, so we can then send all of that to our LLM.
00:04:08.400 | Now the question here is, is a re-ranker really going to help us here?
00:04:12.360 | Can we not just use a better retrieval model?
00:04:16.440 | And yes, we can use a better retrieval model, and that's something we'll be talking about
00:04:20.200 | in a future video.
00:04:22.480 | But there is a very good reason as to why a re-ranker can generally perform better than
00:04:29.280 | an encoder model or retrieval model.
00:04:32.560 | So let's talk about that very quickly.
00:04:35.320 | This is what an encoder model is doing.
00:04:37.080 | So this is an encoder/retriever.
00:04:39.680 | So this is like your ada-002.
00:04:43.760 | Now what it's doing is we have a transformer model.
00:04:48.600 | So these are the same transformer model.
00:04:52.160 | The reason that I've got two of them on the screen right now is because you use your first
00:04:57.680 | iteration or inference step of the transformer model to create your embedding for document
00:05:03.200 | A. And from that you get your vector A. So that is the compressed information that we
00:05:09.880 | can then take across to our vector database, which would kind of be like this point here.
00:05:15.920 | That's in our vector space.
00:05:17.880 | And then in another inference step, we're going to do the same for document B. We get
00:05:22.200 | vector B, and there we go.
00:05:24.560 | We have that in our vector search, and we can then compare the proximity of those two
00:05:28.960 | records to get the similarity.
00:05:32.480 | The metric that we'd be using here, the computation, would be either dot product or cosine in the
00:05:40.240 | case of ada-002.
00:05:42.400 | Now you have to consider that the computational complexity of something like cosine similarity
00:05:49.360 | is much simpler than one of these transformer inference steps.
00:05:55.240 | So the reason that we use this encoder architecture is that we can do all of the transformer inferences
00:06:03.160 | at the start, when we're building our index, that takes a long time because transformers
00:06:08.220 | are big, heavy things.
00:06:09.600 | They take a lot of computation.
00:06:11.320 | Whereas the cosine similarity step at the end, which we can run at the time when our
00:06:16.500 | user is making a query, is very fast.
00:06:19.640 | So it's kind of like we're doing the heavy part of the computation to compare documents
00:06:25.160 | at the very start of building the index.
00:06:28.880 | And that means we can do very quick, simple computations at user query time.
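
(To make that split concrete, here is a small sketch using an open-source bi-encoder from sentence-transformers rather than ada-002, which is API-only; the model name is just an example. The point is that the transformer runs over every document once at index time, and query time is a single encode plus cheap similarity math.)

```python
from sentence_transformers import SentenceTransformer, util

# bi-encoder: each piece of text gets its own, independent transformer inference
model = SentenceTransformer("all-MiniLM-L6-v2")  # example open-source model, not ada-002

# index time: run the transformer over every document once, store the vectors
documents = [
    "RLHF aligns language models with human preferences.",
    "Vector databases store dense embeddings for fast search.",
]
doc_embeddings = model.encode(documents)  # the slow, heavy part, done up front

# query time: one inference for the query, then only cheap cosine similarities
query_embedding = model.encode("Why use reinforcement learning from human feedback?")
print(util.cos_sim(query_embedding, doc_embeddings))  # higher score = closer in vector space
```
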
00:06:34.320 | And that is different to what we do with re-ranking.
00:06:37.840 | So here, this transformer is our re-ranker.
00:06:41.160 | And at query time, right, so let's say document A here, maybe that's our query.
00:06:49.940 | And document B is one of the documents in the database.
00:06:53.380 | We're saying to the transformer, okay, how similar are these two items?
00:06:59.520 | So to compare the similarity in this case, we're running an entire transformer inference
00:07:05.780 | step.
00:07:06.780 | And notice, because we're doing everything in a single transformer step, we're not losing
00:07:11.680 | as much information as we are with this one, where we're compressing everything into vectors.
00:07:16.880 | That means that theoretically, we lose less information, so we can get a more accurate
00:07:21.380 | similarity score here.
00:07:23.060 | But at the same time, it's way slower.
00:07:25.760 | So it's kind of like, you know, on one side, you have fast and, you know, relatively accurate.
00:07:33.320 | And then on this side, you have slow, but super accurate.
00:07:37.240 | So the idea with the sort of re-ranking approach to retrieval is that we use our retrieval
00:07:46.100 | encoder step to basically filter down the total number of documents to just, you know,
00:07:53.580 | in this example, let's say there's like 25 documents there.
00:07:57.680 | 25 documents is not too much.
00:07:59.880 | So feeding them into our re-ranker is actually going to be very fast.
00:08:03.200 | Whereas if we fed all documents into our re-ranker, we'd be waiting, I don't know, like a really
00:08:10.260 | long time, which we don't want to do.
00:08:12.600 | So instead, we filter down with the encoder, feed those results into the re-ranker, and then we'll get
00:08:17.920 | like three amazing results super quickly.
00:08:21.080 | So that is how the re-ranking approach works.
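
(For illustration, here is roughly what that retrieve-then-rerank shape looks like with an open-source cross-encoder from sentence-transformers. This is not the Cohere reranker used below, and the model name is just an example; the idea is the same, though: one full transformer pass per query-document pair, run only over the small shortlist that the fast vector search already produced.)

```python
from sentence_transformers import CrossEncoder

# cross-encoder / reranker: query and document are scored together in a single pass
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # example model

query = "Why would we want to use RLHF?"
# imagine these are the ~25 candidates the fast vector search already retrieved
shortlist = [
    "RLHF improves helpfulness and harmlessness in assistants.",
    "GPTQ is a post-training quantization method for LLMs.",
    "RLHF training also improves honesty.",
]

# one full transformer inference per (query, document) pair -- accurate but slow,
# which is why we only run it over the small shortlist, never the whole corpus
scores = reranker.predict([(query, doc) for doc in shortlist])
top3 = sorted(zip(shortlist, scores), key=lambda pair: pair[1], reverse=True)[:3]
print(top3)
```
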
00:08:27.180 | Let's see how we'd actually implement that in Python.
00:08:29.920 | Okay.
00:08:30.920 | So we're going to be working through this notebook here.
00:08:34.800 | We need Hugging Face datasets, that's going to be where we get our dataset from, OpenAI
00:08:41.240 | for creating our embeddings, Pinecone for storing those embeddings, and Cohere for our
00:08:46.920 | re-ranker.
00:08:47.920 | We're going to start by downloading our dataset, which is this AI arXiv dataset.
00:08:52.160 | It's pre-chunked, so I've already chunked it into chunks of around 300 tokens, I think, something
00:08:57.440 | like that.
00:08:59.400 | And it's basically just a dataset of arXiv papers.
00:09:04.180 | You can kind of see a few of them here that are related to LLMs.
00:09:07.660 | Essentially, I gathered it by taking some recent papers that are well-known, like the Llama
00:09:13.820 | 2 paper, GPT-4 paper, GPTQ, and so on, extracting what each of those
00:09:20.960 | was referencing, and extracting those papers too, and kind of just going in a loop through that.
00:09:27.460 | So yeah, we have a fair few records in there.
00:09:31.300 | It's not huge, but it's not small either, so 41,500 chunks, where each chunk is roughly
00:09:39.260 | this size.
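
(The loading step looks roughly like this; the dataset ID below is an assumption and may differ from the one used in the notebook.)

```python
from datasets import load_dataset

# dataset ID assumed -- swap in whichever pre-chunked arXiv dataset you are using
data = load_dataset("jamescalam/ai-arxiv-chunked", split="train")
print(data)      # expect on the order of 41.5k rows
print(data[0])   # each row is one ~300-token chunk plus paper metadata
```
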
00:09:42.020 | So I'm just going to reformat the data into the format we need.
00:09:45.160 | This is basically like a pinecone format.
00:09:48.020 | You have ID, text, which we're going to convert into embeddings, and metadata.
00:09:53.280 | We're not going to use metadata in this example, but it can be useful, and maybe it's something
00:09:58.460 | that we'll look at in a future video in this series as well.
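
(A sketch of that reformatting step; the source field names like doi, chunk-id, chunk and title are assumptions about the dataset schema, so adjust them to whatever your chunks actually contain.)

```python
# reshape each chunk into the {id, text, metadata} records we will send to Pinecone
# (field names below are assumptions about the dataset schema)
data = data.map(lambda x: {
    "id": f'{x["doi"]}-{x["chunk-id"]}',
    "text": x["chunk"],
    # keeping the chunk text in metadata lets us read it back out at query time
    "metadata": {"title": x["title"], "text": x["chunk"]},
})
# drop everything we don't need downstream
data = data.remove_columns(
    [c for c in data.column_names if c not in ("id", "text", "metadata")]
)
```
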
00:10:03.120 | So we need to define our embedding function.
00:10:07.020 | So we need to define that encoder model that we're going to be using.
00:10:10.900 | For that, I'm going to be using OpenAI.
00:10:13.100 | It's easy, ada-002, fairly good performance, although there are better models, and that's
00:10:18.140 | something we will also be talking about in the future.
00:10:21.280 | So I'm going to just run that, and I will need to enter my OpenAI API key.
00:10:27.740 | To get that, you need to head on over to platform.openai.com, and get your API key.
00:10:36.300 | I'm going to enter mine in here, and yeah.
00:10:40.460 | So with that, we should be able to initialize our embedding model, which we are doing here.
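
(Roughly what that looks like with the current openai Python client; the notebook may use an older SDK version, where the equivalent call is openai.Embedding.create, so treat the exact interface as an assumption.)

```python
import os
from getpass import getpass
from openai import OpenAI

# prompt for the key if it isn't already set in the environment
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") or getpass("OpenAI API key: ")
client = OpenAI()

embed_model = "text-embedding-ada-002"

def embed(texts: list[str]) -> list[list[float]]:
    # one request embeds a whole batch of chunks; ada-002 returns 1536-d vectors
    res = client.embeddings.create(model=embed_model, input=texts)
    return [record.embedding for record in res.data]
```
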
00:10:46.640 | I'm not going to go through all these functions, because I've done it a million times before.
00:10:52.280 | I think people are probably getting bored of that part of these videos.
00:10:57.620 | So I'm just going to run through those bits very quickly.
00:11:01.640 | I'm going to get my pinecone credentials, again, app.pinecone.io for those, and I will
00:11:09.000 | run that, enter my API key first, and then I want my Pinecone environment, which I find
00:11:17.240 | next to my API key in the console.
00:11:19.860 | So mine was this.
00:11:22.080 | Yours would probably be like gcp-starter or something along those lines.
00:11:26.280 | OK, cool.
00:11:27.280 | So here, I'm going to create an index if it doesn't already exist.
00:11:30.520 | My index does actually already exist, and I'm not going to recreate it, because it takes
00:11:35.160 | a little bit of time, or at least it did the other day when creating this.
00:11:39.120 | So you can see that I already have like the 41,000 records in there.
00:11:45.320 | If you're looking at that, you should probably see nothing in yours, unless you've just run
00:11:50.600 | this or you're connecting to an existing index.
00:11:55.320 | So this is the code I use to create my index, right?
00:11:59.720 | It's pretty straightforward.
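
(A sketch of that setup, assuming the older environment-based pinecone-client that the API key + environment flow implies; the index name here is made up, and newer serverless Pinecone clients use a different interface.)

```python
import pinecone

# older, environment-based pinecone-client; newer serverless clients differ
pinecone.init(
    api_key="YOUR_PINECONE_API_KEY",  # from app.pinecone.io
    environment="gcp-starter",        # shown next to the API key in the console
)

index_name = "rerankers"  # index name assumed
if index_name not in pinecone.list_indexes():
    # ada-002 embeddings are 1536-dimensional; cosine is the usual metric for them
    pinecone.create_index(index_name, dimension=1536, metric="cosine")

index = pinecone.Index(index_name)
print(index.describe_index_stats())  # ~41.5k vectors once the index is populated
```
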
00:12:01.060 | The one thing that is maybe a little more complicated, but it's not that complicated,
00:12:06.480 | is we're actually creating the embeddings here.
00:12:10.200 | So I think I defined an embedding function up here, actually, and I ended up not using
00:12:14.040 | it for some reason.
00:12:16.220 | Just ignore that.
00:12:17.400 | So in here, this is where we're doing our embeddings, but we're wrapping it within an
00:12:22.620 | exponential backoff function to avoid rate limit errors, which I was hitting a lot the
00:12:27.680 | other day.
00:12:28.760 | So essentially, it's going to try and embed.
00:12:32.020 | If it gets a rate limit error, it's going to wait.
00:12:36.440 | And it's going to keep doing that for a maximum of five retries.
00:12:39.520 | Hopefully, you shouldn't be hitting five retries.
00:12:43.200 | If so, there's probably something wrong.
00:12:45.620 | So yeah, you should be OK there.
00:12:49.200 | But if you are hitting those rate limit errors, you might be waiting a little bit of time
00:12:52.940 | for this to finish.
00:12:54.860 | If not, it should finish quite quickly.
00:12:56.880 | I was hitting tons of rate limit errors the other day, and I ended up-- this took like
00:13:02.400 | 40 minutes, I think.
00:13:04.200 | So yeah, just be aware of that.
00:13:07.240 | It's going to depend on the rate limits you have set on your OpenAI account.
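
(A sketch of that population loop, using tenacity for the exponential backoff and retry cap described above; the batch size and retry settings are assumptions.)

```python
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm.auto import tqdm

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
def embed_with_backoff(texts: list[str]) -> list[list[float]]:
    # retried with exponential backoff if OpenAI returns a rate limit error,
    # up to a maximum of five attempts
    return embed(texts)

batch_size = 100  # assumed batch size
for i in tqdm(range(0, len(data), batch_size)):
    batch = data[i:i + batch_size]            # dict of columns for this slice
    embeds = embed_with_backoff(batch["text"])
    # upsert (id, vector, metadata) tuples into the Pinecone index
    index.upsert(vectors=list(zip(batch["id"], embeds, batch["metadata"])))
```
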
00:13:11.240 | Now we want to test retrieval without Cohere's re-ranking model first.
00:13:16.340 | So I'm going to ask this question.
00:13:18.920 | So get docs.
00:13:19.920 | Yeah, I'm just querying.
00:13:20.920 | Again, I'm not going to go through everything.
00:13:23.560 | I'm just going to return, for now, the top three records.
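
(A sketch of what a get_docs helper like this typically looks like, assuming the chunk text was stored in the Pinecone metadata at upsert time; it maps each returned chunk's text to its retrieval position so we can see how reranking reorders things later.)

```python
def get_docs(query: str, top_k: int) -> dict:
    # embed the query with the same ada-002 model used for the documents
    xq = embed([query])[0]
    # pull the top_k nearest chunks from Pinecone
    res = index.query(vector=xq, top_k=top_k, include_metadata=True)
    # map chunk text -> retrieval position
    return {match["metadata"]["text"]: i for i, match in enumerate(res["matches"])}

query = (
    "Can you explain why we would want to do "
    "reinforcement learning with human feedback?"
)
docs = get_docs(query, top_k=3)
print("\n---\n".join(docs.keys()))
```
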
00:13:29.480 | So my question is, can you explain why we would want to do reinforcement learning with
00:13:33.800 | human feedback?
00:13:35.160 | That's what this is here.
00:13:36.880 | It's like a training method that is kind of like why ChatGPT was so good when it was released.
00:13:45.080 | So I kind of want-- OK, why would I want to do that?
00:13:48.240 | I think the first answer here-- and there's some-- the scraping that I did is not perfect,
00:13:55.640 | so I apologize for that.
00:13:57.920 | But for the most part, I think we can read it.
00:14:00.200 | So it's a powerful strategy for fine-tuning large language models, enabling significant
00:14:04.200 | improvements in their performance, iteratively aligning the model's responses more closely
00:14:08.760 | with human expectations and preferences.
00:14:14.040 | It can help fix issues of factuality, toxicity, and helpfulness that cannot be remedied by
00:14:19.720 | simply scaling up LMs.
00:14:22.680 | OK, so I think that's a good answer, like number one there.
00:14:28.720 | And then let's have a look at the second one-- increasingly popular technique for reducing
00:14:33.240 | harmful behaviors, OK, can significantly change metrics-- doesn't necessarily tell me any
00:14:42.320 | benefits there, OK?
00:14:44.360 | So the only relevant bit of information in this second sentence is increasingly popular
00:14:50.400 | technique for reducing harmful behaviors, OK?
00:14:53.840 | So just one little bit there.
00:14:56.480 | And then number three, I think-- like, I don't see anything in this that tells me why I should
00:15:03.220 | use RLHF.
00:15:06.560 | It's telling me about RLHF, but isn't telling me why I'd actually want to use it.
00:15:11.100 | So these results could be better, all right?
00:15:14.520 | So number one, good.
00:15:15.920 | Number two, it's kind of relevant.
00:15:18.480 | Number three, not so much.
00:15:20.480 | So can we get better than that?
00:15:23.520 | Yes, we can.
00:15:24.520 | We just need to use reranking.
00:15:25.940 | So I'm going to come down to here, and we're going to initialize our reranking model.
00:15:30.720 | So for that, we need another API key, which is Cohere's API key.
00:15:36.320 | This should be free.
00:15:37.320 | Like, the Pinecone and Cohere ones will be free.
00:15:39.800 | The OpenAI one, I think, you need to pay a little bit.
00:15:42.560 | So yeah, just be aware of that.
00:15:45.080 | But again, we'll be-- like I said, later on in this series, we'll be talking about other
00:15:51.240 | alternatives to OpenAI for embedding models, which may actually be a fair bit better.
00:15:56.800 | So I'm going to go to this website here, dashboard.cohere.com/api-keys.
00:16:03.860 | You will probably need to sign up, make an account, and do all of that.
00:16:09.140 | And then you will get to your Cohere dashboard, new trial key.
00:16:13.080 | I'm going to call it something-- I don't know-- demo generate trial key.
00:16:19.800 | OK, and I'm going to put it into here.
00:16:23.760 | Cool.
00:16:25.140 | So we now want to rerank stuff.
00:16:28.240 | Let's try.
00:16:29.240 | So I'm just going to rerun the last results, because I only got three here.
00:16:34.640 | I'm going to rerun it with 25.
00:16:38.240 | So yeah, we have many more now.
00:16:41.240 | And I'm just going to re-rank those 25.
00:16:42.960 | And I want to see what was re-ranked.
00:16:46.320 | I just want to compare those results.
00:16:48.920 | So when we re-rank stuff, we're going to return this Cohere responses re-rank result object.
00:16:54.840 | And we can access the text from those like this.
00:16:57.080 | OK, so you can see we kind of get this output there.
00:17:01.560 | And the way that I've set up the docs object that I returned from the last item here, you
00:17:09.760 | can see it's a dictionary, where the text maps to the position.
00:17:14.160 | The reason I've done that is so that I can just very quickly see what the reordered position
00:17:20.960 | after re-ranking is.
00:17:22.400 | So you can see that, OK, it's kept the zero position, like the top result.
00:17:27.440 | But then it's swapped out one and two for these two items here, OK?
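
(A sketch of that rerank call; the model name rerank-english-v2.0 and the exact shape of the response object vary between cohere SDK versions, so treat those details as assumptions. Using result.index to look back into the documents we passed in keeps the mapping to original positions explicit.)

```python
import cohere

co = cohere.Client("YOUR_COHERE_API_KEY")  # trial key from dashboard.cohere.com/api-keys

docs = get_docs(query, top_k=25)           # cast a wide net with plain retrieval
docs_list = list(docs.keys())

# rerank the 25 retrieved chunks against the query and keep only the best 3
rerank_docs = co.rerank(
    query=query,
    documents=docs_list,
    top_n=3,
    model="rerank-english-v2.0",           # model name assumed; check the current Cohere docs
)

for result in rerank_docs.results:
    text = docs_list[result.index]         # result.index points back into the documents we passed
    print(docs[text], "->", text[:80])     # original retrieval position vs. reranked order
```
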
00:17:33.080 | So I'm going to define this function here.
00:17:35.600 | It's basically just going to do everything we've just gone through.
00:17:38.680 | It's going to query, get those results.
00:17:41.200 | It's going to then re-rank everything.
00:17:42.680 | And it's just going to compare the results for us.
00:17:45.520 | So I'm going to set a top k of 25, so returning 25 records from our retrieval step.
00:17:52.720 | And then we're just going to return the top three from our re-ranking step.
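
(A sketch of what that compare helper might look like, built from the pieces above; it just prints the original top results next to the reranked ones so the reordering is easy to see.)

```python
def compare(query: str, top_k: int, top_n: int) -> None:
    # retrieve a wide candidate set, then let the reranker pick the best few
    docs = get_docs(query, top_k=top_k)
    docs_list = list(docs.keys())
    rerank_docs = co.rerank(
        query=query, documents=docs_list, top_n=top_n, model="rerank-english-v2.0"
    )
    print(f"Original top {top_n}:")
    for text in docs_list[:top_n]:
        print(f"  [{docs[text]}] {text[:80]}")
    print(f"Reranked top {top_n}:")
    for result in rerank_docs.results:
        text = docs_list[result.index]
        print(f"  [{docs[text]}] {text[:80]}")

compare(query, top_k=25, top_n=3)
```
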
00:17:57.600 | So I'm going to compare that query, so the RLHF query, OK?
00:18:02.980 | So zero has remained the same.
00:18:06.000 | One has been swapped for 23, and two has been swapped for 14.
00:18:10.640 | So this won't show us the first results here, because they haven't changed.
00:18:14.960 | So we're looking at these results first.
00:18:17.700 | So the original is what we went through before, where it has the one kind of useful bit of
00:18:24.300 | information, increasingly popular technique for reducing harmful behaviors in large language
00:18:28.660 | models.
00:18:30.000 | And then the rest wasn't really that relevant to our specific question, which is basically
00:18:34.640 | why would I want to use RLHF?
00:18:38.240 | Now having a look at 23, we've shown it's possible to use RLHF to train LLMs that act
00:18:47.400 | as helpful and harmless assistants, OK?
00:18:49.440 | So that's useful, OK?
00:18:51.240 | That's why we might want to use it.
00:18:53.040 | RLHF training also improves honesty, OK?
00:18:57.240 | That's another reason to use it.
00:18:58.880 | In other words, associated with aligning LLMs, RLHF improves helpfulness and harmlessness
00:19:04.360 | by a huge margin, OK?
00:19:06.860 | Another reason why we might want to use it.
00:19:08.760 | So OK, three good reasons already.
00:19:11.840 | Our alignment interventions actually enhance the capabilities of large models.
00:19:15.440 | And yes, I think that's another reason, combined with training for specialized skills without
00:19:20.280 | degradation in alignment or performance.
00:19:23.000 | Another reason why we should use it, right?
00:19:24.440 | So this here is talking about RLHF like the previous number two ranked context, but it's
00:19:32.600 | way more relevant to our specific question, which is that's why we use re-ranking models.
00:19:40.160 | Now let's have another look.
00:19:43.440 | So this is-- yeah, this one, there was nothing relevant, right?
00:19:47.560 | So this is the original.
00:19:49.140 | For our specific question, there wasn't anything relevant in here.
00:19:52.440 | The re-ranked one has this.
00:19:54.560 | Just one thing here is like the LLMs are actually reading all of this text, which is kind of
00:19:59.920 | impressive.
00:20:00.920 | I really struggle to, but anyway.
00:20:04.680 | So the model outputs safe responses.
00:20:07.080 | I think that's-- assuming it's talking about RLHF-- a good, helpful point.
00:20:12.520 | We switch entirely to RLHF to teach the model how to write more nuanced responses.
00:20:18.100 | So that's a good reason.
00:20:21.200 | Comprehensive tuning with RLHF has added the benefit that it may make the model more robust
00:20:27.360 | to jailbreak attempts.
00:20:28.680 | Another benefit.
00:20:29.880 | We can do RLHF by first collecting human preferences-- it's not relevant-- annotators
00:20:37.280 | write a prompt they believe can elicit safe behavior, and then compare multiple model
00:20:43.800 | responses to the prompts, selecting the responses that are safest according to a set of guidelines.
00:20:50.600 | We use the human preference data to train a safety reward model, and-- OK.
00:20:55.740 | So I think the relevant bits here are make the model more robust to jailbreak attempts,
00:21:02.960 | and teach the model how to write more nuanced responses.
00:21:05.820 | So those two are good.
00:21:07.120 | The rest of it isn't as relevant, but it's far more relevant than this one where it didn't
00:21:11.280 | tell us any benefits of using RLHF.
00:21:14.120 | Cool.
00:21:15.360 | Now let's try one more.
00:21:17.760 | So what is red teaming?
00:21:19.160 | It's like a safety or security testing thing that they apply to LLMs now.
00:21:24.080 | It's like stress testing for LLMs.
00:21:26.280 | You can see that it hasn't changed the top one again.
00:21:31.360 | And I think the responses here were generally not quite as obviously better with re-ranking,
00:21:37.520 | but still slightly better.
00:21:39.560 | What I will do is just kind of let you read those.
00:21:41.600 | So you have this one here.
00:21:43.360 | You can pause and read through if you want.
00:21:46.360 | And also this one as well.
00:21:48.960 | So again, you can pause and read through if you like.
00:21:53.280 | I'm not going to go through all those again.
00:21:55.140 | So that is re-ranking.
00:21:56.960 | I think it's pretty clear.
00:21:59.120 | It can help a lot.
00:22:00.120 | At least I have found it just, you know, I don't have any specific metrics on how much
00:22:05.240 | it helps, but just from using it in actual use cases, it helps quite a bit.
00:22:13.640 | So I hope this is something that you can also use to sort of improve your retrieval pipelines,
00:22:20.520 | particularly when you're using RAG and sending everything to LLMs.
00:22:24.040 | But you should also test it and make sure it is actually helping.
00:22:28.160 | So for example, if you're using kind of like an older re-ranking model,
00:22:34.920 | the chances are it won't actually be quite as good as some of the more recent and better
00:22:40.800 | encoder models.
00:22:42.120 | So you could actually degrade performance if you do that.
00:22:45.420 | So you always want to make sure that you're using kind of like state-of-the-art re-rankers
00:22:50.000 | alongside state-of-the-art encoders.
00:22:52.200 | And you should see an impact kind of similar to what we saw here with the RLHF question.
00:22:57.960 | But anyway, as I mentioned, this is like the first method I would use when trying to optimize
00:23:03.880 | an existing retrieval pipeline.
00:23:07.120 | And as you can see, super easy to implement, it's, you know, you don't really need to modify
00:23:12.120 | other parts of the pipeline.
00:23:13.440 | You just need to put this into the middle.
00:23:15.400 | So I'll leave it there for now.
00:23:17.720 | I hope this walkthrough has been useful and interesting.
00:23:22.240 | Thank you very much for watching, and I will see you again in the next one.
00:23:25.780 | [MUSIC PLAYING]