Medical Search Engine with SPLADE + Sentence Transformers in Python
Chapters
0:00 Hybrid search for medical field
0:18 Hybrid search process
2:42 Prerequisites and Installs
3:26 PubMed QA data preprocessing step
8:25 Creating dense vectors with sentence-transformers
10:30 Creating sparse vector embeddings with SPLADE
18:12 Preparing sparse-dense format for Pinecone
21:02 Creating the Pinecone sparse-dense index
24:25 Making hybrid search queries
29:59 Final thoughts on sparse-dense with SPLADE
Today we're going to take a look at how to implement a hybrid search engine using both sparse and dense vector embeddings. 00:00:11.180 |
This is going to be very hands-on, so I'm going to outline the architecture of this 00:00:15.820 |
thing and then we'll jump straight into building it. 00:00:18.360 |
So as for the process that we're going to be going through here, we're going to start with our data. 00:00:25.420 |
So we have our data over here and it's just going to be paragraphs of text, loads of them. 00:00:32.760 |
What we're going to have to do is we're going to have to create a chunking mechanism. 00:00:37.400 |
So when we are encoding these chunks of text or just general information, we can't encode 00:00:44.980 |
too much information in any one go, depending on which models you're using. 00:00:50.900 |
So if you're using Cohere or OpenAI, the chunks that you can use will be much larger. 00:00:56.220 |
We're going to be using a sentence transformer on one side and then SPLADE on the other side. 00:01:01.900 |
SPLADE and the sentence transformer are somewhat limited. 00:01:04.520 |
You're definitely not going to go over around 500 tokens. 00:01:08.620 |
And with the dense vector embedding model, the sentence transformer, I believe it's a similar limit. 00:01:15.100 |
So you're thinking that's probably around a thousand characters at most. 00:01:21.980 |
So we're going to take that through a chunking mechanism to just break it apart into chunks of text. 00:01:28.900 |
And then those chunks of text, we're going to take them on one side, we're going to feed 00:01:34.380 |
them into SPLADE, which is going to create our sparse vectors. 00:01:42.020 |
And then on the other side, we have our sentence transformer. 00:01:55.780 |
With both of those, we're going to take them both and we're going to create with both of 00:02:00.260 |
them a sparse-dense vector, which we then take into Pinecone. 00:02:05.380 |
So that's our vector database where we're going to store everything. 00:02:08.460 |
And then what we can do is we can ask like a question, we ask a question up here. 00:02:13.820 |
That gets encoded again by SPLADE for the sparse embedding, and by our dense embedding 00:02:21.660 |
model. We put those both together to create a sparse-dense vector. 00:02:27.340 |
And we actually take it over here, feed it into Pinecone and get a ton of responses based 00:02:34.460 |
on both the sparse and the dense information. 00:02:44.380 |
The first thing: we're using Transformers and sentence transformers here. 00:02:47.620 |
That means we're going to be creating these embeddings on our local machine. 00:02:52.820 |
In this case, it's actually Colab, but I'll call it local rather than the alternative of an API service. 00:03:02.860 |
We need to change runtime type and we need to make sure that we're using a GPU here. 00:03:10.420 |
And these are just the libraries that we're going to be using. 00:03:12.940 |
So Hugging Face Datasets for our data, obviously Transformers and sentence transformers for our 00:03:20.060 |
dense encoder, the Pinecone client for our database, and also our SPLADE model. 00:03:27.060 |
So we're going to first load the PubMedQA dataset. 00:03:31.420 |
So this is a medical question answering dataset. 00:03:34.060 |
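(For reference, a rough sketch of the loading step; the pqa_labeled config and train split are my assumption of what's used here:)

```python
from datasets import load_dataset

# load the labeled PubMed QA subset; config and split names assumed
pubmed = load_dataset("pubmed_qa", "pqa_labeled", split="train")
print(pubmed)
```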
So with medical things, you'll find that there's a lot of kind of specific terminology, and it's 00:03:40.700 |
within that sort of domain that models like SPLADE, or just general sparse embedding models, tend to do well. 00:03:49.660 |
However, if you are able to train your sentence transformer, your dense embedding model, on 00:03:54.420 |
the dataset, then in that case, you can actually improve the performance to beyond that of the sparse model. 00:04:06.900 |
So we just have our ID here, and then we have this data, this context, and we have all these paragraphs. 00:04:14.140 |
And what we're going to need to do with these paragraphs is put them all together and then 00:04:17.820 |
chunk them into the smaller chunks that will fit into our models. 00:04:23.020 |
So I think we mentioned here, we're going to, yeah, into digestible chunks for our models. 00:04:28.980 |
We are going to be using BERT, which has, you know, the default BERT model has this 00:04:33.720 |
max sequence length of 512 tokens, which is fairly big, but your typical sentence transformer is more like 128 tokens. 00:04:43.320 |
So we're going to be pretty naive in our assumptions here, but we're going to just assume this 00:04:48.780 |
128 token limit, and we're actually going to assume that the average token length is about three characters, which gives us roughly 384 characters per chunk. 00:04:55.820 |
We should realistically actually create our tokens and then count the number of tokens 00:05:01.380 |
in there, but it's just more complicated logic, and I want this to be as simple as possible. 00:05:06.220 |
So we're just doing this for now, but if you're interested in that, let me know and I can 00:05:09.560 |
send you a link to some code that does actually do that. 00:05:13.300 |
So to create these chunks, we're going to create a processing function called chunker. 00:05:18.180 |
We're just going to feed in that list of contexts that we've got up here. 00:05:21.660 |
So this literally, this list here, and what it's going to do is join them and then split them into sentences. 00:05:29.220 |
So we're going to create our chunks at a sentence level. 00:05:33.300 |
So what we do is we loop through each sentence in here, and we say, okay, we keep adding to 00:05:40.060 |
the chunk here, and once the length of that exceeds our limit, which is here, the limit 00:05:48.060 |
is 384 characters, we will say, okay, we've got enough here, we're not going to go any higher, and we start the next chunk. 00:05:58.580 |
So here, what we're doing is, let's say we have four sentences, or no, five sentences in each chunk. 00:06:06.340 |
What we're going to do is so that we're not cutting off between sentences that are relevant, 00:06:12.780 |
like have some continuum logic between them, what we're going to do is between each of 00:06:16.820 |
our chunks, we're actually going to have some overlap. 00:06:19.500 |
So let's say we take sentences zero to four, and then what we're going to do for the next 00:06:24.980 |
chunk is take sentences two to seven, or something like that. 00:06:30.900 |
So there's always a bit of overlap between the chunks. 00:06:34.340 |
So once we are done, and we get to the end of our sentences, we might still have a smaller 00:06:40.340 |
chunk that's left over, so we just append that to our chunks list. 00:06:48.180 |
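(Putting that description together, a minimal chunker sketch might look like this; the split on ". ", the 384-character limit, and the two-sentence overlap are assumptions based on the description above, not the exact notebook code:)

```python
def chunker(contexts: list, limit: int = 384, overlap: int = 2) -> list:
    # join the paragraphs, then split back into individual sentences
    sentences = " ".join(contexts).split(". ")
    chunks = []
    chunk = []
    new_since_save = 0
    for sentence in sentences:
        chunk.append(sentence)
        new_since_save += 1
        # once the running chunk passes the character limit, save it and
        # carry the last `overlap` sentences over into the next chunk
        if len(". ".join(chunk)) >= limit:
            chunks.append(". ".join(chunk))
            chunk = chunk[-overlap:]
            new_since_save = 0
    # append any leftover sentences as a final, smaller chunk
    if new_since_save:
        chunks.append(". ".join(chunk))
    return chunks
```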
So let's run that, and we'll apply it to our first context. 00:06:52.500 |
All right, and then we get these smaller chunks now, okay? 00:06:59.060 |
We're not just splitting in the middle of a sentence. 00:07:01.220 |
But one thing you will also notice is, like here, it says, "the lace plant consists 00:07:06.340 |
of a latticework of," and we also have that here, right? 00:07:11.680 |
So we always have, like, basically half the chunk is overlapped. 00:07:19.300 |
Depending on what you're doing, you can minimize that. 00:07:22.740 |
You should realistically have some overlap there, so you're not cutting between sentences that are related. 00:07:36.540 |
So that's why we have those overlaps in there. 00:07:38.900 |
This is probably a more reasonable one, so you have all this, and then the overlap starts here. 00:07:47.180 |
So what we want to do is give each chunk a unique ID. 00:07:52.020 |
So we're using the pub_id here, followed by the chunk number, okay? 00:08:06.820 |
So we're going to go through the entire PubMed dataset here, we're going to get the context, 00:08:14.700 |
chunk it, and build the unique ID for each chunk. Again, we're using that PubMed ID and the chunk number. 00:08:26.060 |
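(As a sketch, that loop might look like this; the field names follow the PubMed QA dataset, but the exact loop is my reconstruction:)

```python
chunks = []
for record in pubmed:
    # each record holds a list of context paragraphs
    contexts = record["context"]["contexts"]
    for i, context in enumerate(chunker(contexts)):
        # unique ID: the PubMed ID followed by the chunk number
        chunks.append({"id": f'{record["pubid"]}-{i}', "context": context})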
Now what I want to do is move on to creating our vectors. 00:08:29.540 |
All right, so the first one I'm going to do is the dense vectors. 00:08:36.740 |
And the first thing we want to do is make sure that we're using CUDA if it's available, 00:08:41.500 |
otherwise you can use CPU, it's just going to be slower. 00:08:44.220 |
It's not going to be too slow, it's not a huge dataset that we're processing here, but 00:08:50.220 |
And the model that we're using is this base model that has been trained on MS MARCO, which is a large information retrieval dataset. 00:08:59.580 |
And specifically, so this is important, it has been trained to use dot product similarity. 00:09:05.300 |
And we need that for it to function with the sparse-dense vectors that we are putting into 00:09:13.460 |
Pinecone. So they're basically, they're compared in a dot product similarity space. 00:09:24.820 |
So we see the sentence transformer details here, and we can actually see here that the 00:09:29.140 |
max sequence length for this sentence transformer is 512 tokens. 00:09:33.180 |
So early on when we went for the 128 token limit, with this one, we can actually do 512. 00:09:43.020 |
So I think we set like 380 something for the character limit. 00:09:47.540 |
With this, we could actually set like 1,500, which is quite a bit more. 00:09:55.100 |
But anyway, we'll stick with what we have, because with a lot of sentence transformers, the limit is lower anyway. 00:10:05.340 |
So we have our dense model, we encode, and then we pass in our data. 00:10:10.900 |
And we'll get a, we'll see in a moment, 768 dimensional dense vector. 00:10:19.980 |
You can also see that in the model, get sentence embedding dimension here as well. 00:10:26.100 |
We'll need this when we're actually initializing our vector index later. 00:10:30.060 |
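(In code, the dense setup is roughly this; the exact MS MARCO model name is an assumption, any dot-product sentence transformer fits the description:)

```python
import torch
from sentence_transformers import SentenceTransformer

# use CUDA if available, otherwise fall back to (slower) CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# an MS MARCO model trained for dot-product similarity (name assumed)
dense_model = SentenceTransformer("msmarco-bert-base-dot-v5", device=device)

emb = dense_model.encode(chunks[0]["context"])
print(emb.shape)  # (768,)
print(dense_model.get_sentence_embedding_dimension())  # 768
```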
So moving on to the sparse vectors, we're using the SPLADE CoCondenser EnsembleDistil model. 00:10:37.600 |
So it's basically like an efficient SPLADE model. 00:10:51.840 |
So it's basically how it's creating its single vector from the many token-level vectors it initially produces. 00:10:59.900 |
And I created a video on SPLADE, so you can go and take a look at that if you're interested. 00:11:05.100 |
There'll be a link to that in the video at the top somewhere. 00:11:09.140 |
So it takes tokenized inputs that need to be built with a tokenizer initialized with the same model ID. 00:11:26.780 |
And then to create our sparse vectors, we do this. 00:11:29.140 |
So we're saying torch no grad, which basically means, like, don't calculate the gradients of the model. 00:11:37.380 |
And we only need that for training the model. 00:11:39.660 |
Right now we're just performing inference or prediction. 00:11:44.620 |
And what we do is we move the tokens to CUDA if we're using it. 00:11:51.280 |
So the reason we move to CUDA is because if we don't, the tokens feeding into the model 00:11:57.580 |
are on CPU and the model is on GPU, we're going to see an error. 00:12:01.780 |
So we need to make sure we include that in there. 00:12:04.420 |
And then here is the SPLADE vector representation output by the model. 00:12:10.580 |
And we use squeeze to reduce the dimensionality of that vector. 00:12:14.640 |
So initially it's like, I think it's like 30,000 comma one, the shape. 00:12:26.020 |
So that gives us this 30,522-dimensional vector, which is huge, right? 00:12:34.660 |
So that is actually the vocab size of the BERT model. 00:12:39.100 |
So every token that BERT recognizes is represented by one of these values. 00:12:45.540 |
And essentially we're creating a score for each one of those tokens through SPLADE. 00:12:58.060 |
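(Here's a hedged sketch of those steps: tokenize, run the masked-LM model under torch.no_grad(), apply the standard SPLADE log-saturated max pooling, then squeeze:)

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

sparse_model_id = "naver/splade-cocondenser-ensembledistil"
# the tokenizer must be initialized from the same model ID
tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)
sparse_model = AutoModelForMaskedLM.from_pretrained(sparse_model_id).to(device)

tokens = tokenizer(chunks[0]["context"], return_tensors="pt").to(device)
with torch.no_grad():  # inference only, so no gradients needed
    logits = sparse_model(**tokens).logits  # shape (1, seq_len, 30522)

# SPLADE pooling: log-saturated ReLU, max over the sequence dimension,
# then squeeze away the leading batch dimension
vec = torch.max(
    torch.log1p(torch.relu(logits)) * tokens.attention_mask.unsqueeze(-1),
    dim=1,
).values.squeeze()
print(vec.shape)  # torch.Size([30522])
```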
Now to create the data format that we'll be feeding into Pinecone, it's essentially going 00:13:05.460 |
to be like a dictionary mapping the positions of the nonzero values to the scores that they hold. 00:13:20.940 |
So here we can see we have 174 nonzero values; it says that here as well. 00:13:37.500 |
So we come up to here and we have our indices. 00:13:40.940 |
So at position number 1,000, the score of that token is this, right? 00:13:49.620 |
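(In code, that sparse format is just the non-zero positions and their scores; a minimal sketch:)

```python
# keep only the non-zero positions and their scores
indices = vec.nonzero().squeeze().cpu().tolist()
values = vec[indices].cpu().tolist()
sparse_values = {"indices": indices, "values": values}
print(len(indices))  # e.g. 174 non-zero tokens for this chunk
```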
And I think I have a little example of what that actually means here. 00:13:53.880 |
We don't need to do this for processing things by Pinecone. 00:13:57.480 |
We are just doing this so that we can understand what this actually means. 00:14:05.820 |
So like I said, all of those 30,522 values in that vector that was output by SPLADE represent tokens. 00:14:18.580 |
And in this, these tokens are just numbers, because that's what the transformer model SPLADE works with. 00:14:30.540 |
So this is mapping those positions, those integer values, to the actual text tokens that they represent. 00:14:41.300 |
And we process the dictionary that we just created up here through that. 00:15:03.180 |
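(That mapping can be sketched like this, using the tokenizer's vocabulary to translate positions back into readable tokens:)

```python
# invert the tokenizer vocab: vocabulary position -> token text
idx2token = {idx: token for token, idx in tokenizer.get_vocab().items()}

# human-readable view of the sparse vector, highest weights first
sparse_dict_tokens = {
    idx2token[idx]: round(weight, 2) for idx, weight in zip(indices, values)
}
sparse_dict_tokens = dict(
    sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True)
)
```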
Let's just have a look at what this is and then we'll see if it makes sense. 00:15:10.700 |
So programmed cell death is the regulated death of cells within an organism. 00:15:20.080 |
Lattice work of longitudinal and transverse veins, including areoles. 00:15:26.920 |
You know, I don't know what any of that means, but we can at least see that in this sparse 00:15:32.160 |
dictionary we have, so we have PC, which is, I think, coming from here. 00:15:44.540 |
We have this up here, Madagascar, I don't know where that's coming from. 00:15:55.120 |
And then I think we should also have some other words in here that are not actually 00:16:00.240 |
from this text, because what SPLADE does is it identifies the words that are in the text 00:16:05.680 |
already, or rather, within this, it identifies the most important words, okay? 00:16:11.040 |
So I would say it's probably got that right with, like, lace, programmed, the PC, and then 00:16:19.400 |
death, lattice, cell; all those are probably the most important words in here. 00:16:24.400 |
It's not giving us the word the, or the word within, right? 00:16:28.300 |
Because it doesn't view those as being what are important. 00:16:32.120 |
But if we go down, we'll probably see some words that are not actually in the text. 00:16:42.160 |
Because part of what SPLADE does is it expands, it does term expansion, which basically means 00:16:48.640 |
based on the words it sees, it adds other words that it doesn't see, but that we might 00:16:54.760 |
expect a document that we're searching for to contain. 00:16:59.640 |
So I think the word, okay, so the word die, I don't think is in here, right? 00:17:13.100 |
Lacy, it's probably not, so we have lace plant, all right, so lacy is in there, I don't know 00:17:19.920 |
if that is actually relevant, I don't understand any of what this says. 00:17:24.780 |
We have plant and plants, I wonder if both of those are in there. 00:17:29.540 |
So we've got plant, plant, okay, we don't have plants, right? 00:17:35.840 |
So imagine in your document that this, well, actually, this is a document. 00:17:40.980 |
Let's say in the query, the user is searching for programmed cell death in plants, or how do 00:17:50.660 |
plant cells die. They would have the terms die and plants in there, but they wouldn't have the term death. 00:17:57.220 |
So that's why the term expansion is really useful, because then you do have that term 00:18:02.220 |
overlap, which is what traditional sparse vector methods kind of lack, so like BM25. 00:18:09.820 |
They don't have that automatic term expansion. 00:18:12.740 |
So we create our sparse vectors; or rather, we have seen how to create our dense vectors and seen how to create our sparse vectors. 00:18:19.860 |
Now let's have a look at how we do this for everything. 00:18:23.820 |
So we're going to create a helper function called builder, which is first going to transform 00:18:29.500 |
a list of records from our data, so the context, into this format here. 00:18:35.580 |
So this is the format that we're going to be feeding into Pinecone, right? 00:18:38.940 |
So we have our ID, we have our dense vector here, we have our sparse vector in the dictionary 00:18:44.840 |
format that we saw already, and then we have this metadata. 00:18:48.300 |
Metadata is just additional information that we can attach to our vectors. 00:18:52.360 |
In this case, I'm going to include the text, like a human readable text. 00:19:02.460 |
This is just going to go through everything, right? 00:19:07.520 |
So we get our IDs from the records that we have there, so we have our IDs. 00:19:12.820 |
So records is just everything, I believe, yeah. 00:19:16.600 |
So records is everything now, so it's going to extract the IDs for everything, and 00:19:20.620 |
then it's going to extract the contexts, right? 00:19:24.200 |
So that's why we have the pub ID followed by the chunk number. 00:19:30.580 |
And then we have those kind of smaller chunks of text, a couple of sentences each. 00:19:36.480 |
And then from those chunks of text, what we're going to do is we're going to encode everything. 00:19:43.580 |
That creates our dense vectors, then we're going to create our sparse vectors, so we 00:19:47.620 |
get our, what is this bit, so input IDs, that's creating our tokens, and then we process our 00:19:54.120 |
tokens through the sparse model, the SPLADE model, okay? 00:19:59.380 |
Then what we do is we initialize an empty list, which is where we're going to store our formatted records. 00:20:05.500 |
And what we do is we go through the IDs, the dense vectors, the sparse vectors, and the 00:20:09.960 |
context that we've just created, and we create this format here, all right? 00:20:15.680 |
So this is for every record, we have this format: the ID, values, sparse values, and metadata. 00:20:26.400 |
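(A hedged sketch of builder, combining the dense and sparse steps from above; the names and details are my reconstruction, not the exact notebook code:)

```python
import torch

def builder(records: list) -> list:
    ids = [r["id"] for r in records]
    contexts = [r["context"] for r in records]
    # dense vectors from the sentence transformer
    dense_vecs = dense_model.encode(contexts).tolist()
    # sparse vectors from SPLADE, one row per record
    tokens = tokenizer(
        contexts, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        logits = sparse_model(**tokens).logits
    sparse_vecs = torch.max(
        torch.log1p(torch.relu(logits)) * tokens.attention_mask.unsqueeze(-1),
        dim=1,
    ).values
    upserts = []
    for _id, dense, sparse, context in zip(ids, dense_vecs, sparse_vecs, contexts):
        indices = sparse.nonzero().squeeze().cpu().tolist()
        upserts.append({
            "id": _id,
            "values": dense,
            "sparse_values": {
                "indices": indices,
                "values": sparse[indices].cpu().tolist(),
            },
            "metadata": {"context": context},
        })
    return upserts
```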
So with that, we'll run this cell, and let's try it with the first three records first, 00:20:35.680 |
So we'll just kind of loop through, there we go. 00:20:39.040 |
So we get these, there's a lot of numbers in there, but we have the metadata, and 00:20:43.240 |
if I come up to here, we have the, these are the values and the indices for our SPLADE 00:20:49.800 |
vector, right, the indices for the sparse values. 00:20:53.480 |
We have our dense values, our dense vector, which is very big. 00:21:02.760 |
So now what we want to do is initialize our connection to Pinecone using a free API key. 00:21:07.520 |
So for that, you will go here, it's actually app.pinecone.io, and you will end up on this page. 00:21:16.120 |
Initially, you go to API Keys, and you will have your API key here; it will probably say "default". 00:21:22.280 |
You click copy, paste that over here, and you just put it into your API key. 00:21:27.560 |
I've stored mine in a variable called your API key. 00:21:31.440 |
And then for your environment, you go back over to your console, and you just copy whatever 00:21:37.680 |
is shown there. So for me, us-east1-gcp; for yours, there's a good chance it'll be the same, but it may not be. 00:21:48.040 |
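(A sketch of that initialization, using the pinecone-client version from around when this was recorded; the key and environment here are placeholders:)

```python
import pinecone

# placeholders: copy your own key and environment from app.pinecone.io
pinecone.init(
    api_key="YOUR_API_KEY",
    environment="us-east1-gcp",
)
```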
So that just initializes our connection with Pinecone, and then what we want to do is actually create our vector index. 00:21:54.340 |
So we run this, there's a few things that are important here. 00:21:57.640 |
So the index name is not so important, you can kind of use whatever you want there, but the rest is important. 00:22:04.160 |
Dimensionality, so that is the 768 dimensions of the dense vector embedding model. 00:22:13.520 |
We have to use the dot product metric to use the sparse dense vectors. 00:22:18.800 |
And for the pod type, we must use either S1 or P1. 00:22:23.760 |
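(So index creation looks roughly like this, with those three important settings called out; the index name is my choice:)

```python
index_name = "pubmed-splade"  # the name itself is up to you

pinecone.create_index(
    index_name,
    dimension=768,        # must match the dense model's output dimension
    metric="dotproduct",  # required for sparse-dense vectors
    pod_type="s1",        # sparse-dense needs s1 or p1 pods
)
```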
So that will just create the index, and we can actually go to the console, we go to Indexes, 00:22:33.520 |
all right, so we have this pubmed-splade one in there now. 00:22:42.040 |
And what we then need to do is initialize the connection to our index. 00:22:47.360 |
For this, we can use either index or we can use gRPC index, which is just essentially 00:22:54.300 |
faster and also a little bit more reliable in terms of your connection. 00:23:02.120 |
The index one is still very stable and still very fast, but just not as good. 00:23:09.560 |
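(A sketch of the connection and the stats check:)

```python
# gRPC connection: faster and a bit more reliable than the plain Index
index = pinecone.GRPCIndex(index_name)
index.describe_index_stats()
```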
That will just give us some index statistics, of course, our index is completely empty right 00:23:14.400 |
now, and the dimensionality is what we set before, the 768. 00:23:20.480 |
Now to add some vectors, we just do this, so index upsert, and we pass in what we created 00:23:25.880 |
with Builder, because Builder is outputting the format that we need to add things to Pinecone. 00:23:31.440 |
Okay, so we can see that we upsert three items if we do that. 00:23:37.040 |
Upsert just means, like, insert; we inserted three items. 00:23:45.520 |
So you can also increase the batch size depending on what hardware you're using. 00:23:50.720 |
We'll stick with 64, which is pretty low, just to be safe depending on what you're using. 00:23:56.900 |
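(So the full indexing loop is roughly this; the batch size and progress bar are my choices:)

```python
from tqdm.auto import tqdm

batch_size = 64  # conservative; raise it if your hardware allows
for i in tqdm(range(0, len(chunks), batch_size)):
    batch = chunks[i:i + batch_size]
    index.upsert(builder(batch))
```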
And with this, it's not going to take long; we've got like a minute 20 here, so I'll skip ahead. 00:24:03.360 |
Okay, so that is complete, it took one and a half minutes. 00:24:07.440 |
And then what we want to do is we're just going to check that the number of upserted 00:24:11.440 |
records aligns with the length of our original data. 00:24:14.820 |
Okay, so here is our original data, and here's a number of items that are inside our index 00:24:23.200 |
So it looks like everything is in there, and we can move on to querying. 00:24:28.180 |
So our queries will need to contain both sparse and dense vectors, so we're going to use this helper function. 00:24:37.160 |
And what that will allow us to do is, it's just going to handle everything for us. 00:24:42.380 |
So we create our dense vectors, we then create our sparse dictionary and we just return those. 00:24:49.440 |
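(A sketch of that helper, reusing the dense and sparse encoders from above; the name encode_query is mine:)

```python
def encode_query(text: str):
    # dense vector from the sentence transformer
    dense = dense_model.encode(text).tolist()
    # sparse vector from SPLADE
    tokens = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = sparse_model(**tokens).logits
    vec = torch.max(
        torch.log1p(torch.relu(logits)) * tokens.attention_mask.unsqueeze(-1),
        dim=1,
    ).values.squeeze()
    indices = vec.nonzero().squeeze().cpu().tolist()
    sparse = {"indices": indices, "values": vec[indices].cpu().tolist()}
    return dense, sparse

dense, sparse = encode_query(
    "Can clinicians use the PHQ-9 to assess depression in people with vision loss?"
)
results = index.query(
    vector=dense, sparse_vector=sparse, top_k=3, include_metadata=True
)
```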
So we're going to start with: can clinicians use the PHQ-9 to assess depression in people with vision loss? 00:24:57.600 |
So we run this, and we see, straight away: to investigate whether the PHQ-9 has the 00:25:08.560 |
essential psychometric characteristics to measure depressive symptoms in people with visual impairment. 00:25:18.320 |
So you see that we have depressive symptoms, depression, vision loss, and visual impairment. 00:25:25.760 |
So it's not, the words don't align perfectly, right? 00:25:33.080 |
So my question here would be, what is doing this? 00:25:36.540 |
Is it the dense component, or is it the sparse component? 00:25:40.580 |
And actually, we'll see that it's kind of both. 00:25:43.760 |
But what I want to show you is that we can actually scale the dense versus sparse components. 00:25:51.720 |
So the way that we do this is that we use this hybrid scale function. 00:25:56.160 |
And what it's going to do is it's going to take an alpha value, where the alpha, when 00:26:00.600 |
it is equal to 1, it will maximize the dense vector, but it will basically make the sparse vector irrelevant. 00:26:10.780 |
If we use an alpha value of 0, it means the sparse vector is the only thing being used, 00:26:16.940 |
and the dense vector is completely irrelevant. 00:26:19.220 |
And if we just want an equal blend between the two of them, we use 0.5. 00:26:23.980 |
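(That weighting can be sketched as follows; the function body is my reconstruction of the convex-combination approach described here:)

```python
def hybrid_scale(dense, sparse, alpha: float):
    # alpha=1.0 -> pure dense, alpha=0.0 -> pure sparse, 0.5 -> equal blend
    if not 0 <= alpha <= 1:
        raise ValueError("alpha must be between 0 and 1")
    scaled_sparse = {
        "indices": sparse["indices"],
        "values": [v * (1 - alpha) for v in sparse["values"]],
    }
    scaled_dense = [v * alpha for v in dense]
    return scaled_dense, scaled_sparse

hdense, hsparse = hybrid_scale(dense, sparse, alpha=1.0)  # pure dense
results = index.query(
    vector=hdense, sparse_vector=hsparse, top_k=3, include_metadata=True
)
```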
So let's first try a pure dense search and see what happens. 00:26:34.400 |
And you see that we actually get the right answer up here straight away. 00:26:43.220 |
It's not that much different, but it's different. 00:26:46.940 |
So does that mean it's only the dense vector doing this? 00:26:50.780 |
Let's try an alpha value of 0.0, and we actually get the same answer at the top again. 00:27:01.340 |
Yeah, so with the dense embedding, I'm not sure if the performance on that is better 00:27:08.020 |
or not, but we do get slightly different results. 00:27:11.580 |
So when we have a mix of both, we actually get the same result there. 00:27:15.400 |
So let's try some other questions that maybe will help us get slightly different responses. 00:27:23.400 |
What is going on here is that both models are actually very good for this data set. 00:27:28.460 |
So we don't see that much difference when we try and vary them. 00:27:32.320 |
So does ibuprofen increase perioperative blood loss during hip arthroplasty? 00:27:42.140 |
This is a sparse search, and when we run it, we get the top result where the prior exposure 00:27:47.900 |
to non-steroidal (okay, this is ibuprofen from what I understand) anti-inflammatory drugs increases 00:27:56.540 |
this thing here, perioperative blood loss, associated with major orthopaedic surgery. 00:28:05.420 |
So I checked what this means, and this basically means a hip replacement, or sorry, no, this 00:28:11.620 |
means hip replacement, and the words, I think, cover both of them. 00:28:18.340 |
So this is like major surgery, and this is a hip replacement, which is major surgery. 00:28:23.540 |
That's what I understood, it could be completely wrong, but I'm not sure. 00:28:27.980 |
This one, and then they mentioned hip replacement here. 00:28:30.820 |
So I think this one is relevant, and this is using the pure sparse method, right? 00:28:37.400 |
And then we get this, and this actually does talk about ibuprofen and this sort of stuff, 00:28:43.380 |
but I don't know if that is as relevant; it doesn't mention the arthroplasty thing. 00:28:53.840 |
If we go pure dense, okay, we actually get the best answer at position number two, which is still good. 00:29:02.540 |
It's not that it's not performing well, that is a good performance, but it's not quite 00:29:07.540 |
as good as when we have the pure sparse, right? 00:29:11.380 |
So what we'll find, and I put a ton of example questions in here from this PubMed QA paper. 00:29:19.240 |
So you can try a few of these, but what we find is that some of them perform better with 00:29:24.520 |
sparse, some of them perform better with dense. 00:29:27.460 |
So a good approach to use here is to use a mix of both using the hybrid search. 00:29:32.620 |
So we set like alpha to 0.3, 0.5, whatever seems to work best overall, depending on your 00:29:41.140 |
particular use case, and overall we're going to get generally better performance. 00:29:46.960 |
Now, once you're done with all of this, if you've asked a couple more questions and so 00:29:50.700 |
on, what you need to do is just delete your index at the end to save resources, so that you're not using more than you need. 00:30:01.380 |
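(Cleanup is a one-liner in the same client:)

```python
# delete the index once you're done, to free up resources
pinecone.delete_index(index_name)
```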
We've just kind of quickly been through an example of actually using hybrid search in 00:30:07.300 |
Pinecone with SPLADE and a dense vector sentence transformer model. 00:30:16.500 |
What we'll find is that the performance of hybrid search versus just pure dense or pure sparse search is generally better. 00:30:27.440 |
So if you're able to implement this in your search applications, it's 100% worth doing. 00:30:37.440 |
So I hope this video has been interesting and useful. 00:30:42.600 |
So thank you very much for watching, and I will see you again in the next one.