
Medical Search Engine with SPLADE + Sentence Transformers in Python


Chapters

0:00 Hybrid search for medical field
0:18 Hybrid search process
2:42 Prerequisites and installs
3:26 PubMedQA data preprocessing step
8:25 Creating dense vectors with sentence-transformers
10:30 Creating sparse vector embeddings with SPLADE
18:12 Preparing sparse-dense format for Pinecone
21:02 Creating the Pinecone sparse-dense index
24:25 Making hybrid search queries
29:59 Final thoughts on sparse-dense with SPLADE

Transcript

Today we're going to take a look at how to implement a hybrid search engine using both SPLADE and a sentence transformer. This is going to be very hands-on, so I'm going to outline the architecture of this thing and then we'll jump straight into building it. As for the process we're going to go through, we start off with data, obviously.

So we have our data over here, and it's just going to be paragraphs of text, loads and loads of text. The first thing we need is a chunking mechanism, because when we encode text, we can't encode too much information in any one go, depending on which models we're using.

If you're using Cohere or OpenAI, the chunks you can use will be much larger. We're going to be using a sentence transformer on one side and SPLADE on the other, and both are somewhat limited: you're definitely not going to go over around 500 tokens.

And with the dense vector embedding model, the sentence transformer, I believe the limit is around 384 tokens. That's probably around a thousand characters at most, so a few sentences. So we take our data through a chunking mechanism that breaks it apart into chunks of text.

Then we take those chunks of text and, on one side, feed them into SPLADE, which creates our sparse vectors. On the other side, we have our sentence transformer. I can't remember which one we're using, so I'll just put ST for now.

That gives us our dense vectors. We take both of those and combine them into a sparse-dense vector, which we then store in Pinecone, the vector database where we keep everything. Then we can ask a question up here.

That question gets encoded again, both by SPLADE and by our dense embedding model. We put those together to create a sparse-dense query vector, feed it into Pinecone, and get back responses based on both the sparse and the dense information.

So that is what we're going to be building. Let's go ahead and build it. The first thing to note is that we're using Transformers and Sentence Transformers here, which means we'll be creating these embeddings on our local machine. In this case it's actually Colab, but I'll call it local as opposed to the alternative, which would be calling an API like OpenAI.

So we go to Runtime, Change runtime type, and make sure we're using a GPU. Save that and run. These are just the libraries we're going to be using: Hugging Face Datasets for our data, Transformers and Sentence Transformers for our dense encoder, the Pinecone client for our database, and also our SPLADE model.

Okay, cool. First, we load the PubMedQA dataset, which is a medical question answering dataset. With medical text, you'll find a lot of domain-specific terminology, and it's within that sort of domain that models like SPLADE, or sparse embedding models in general, tend to perform better.

However, if you are able to train your sentence transformer (your dense embedding model) on the dataset, you can usually improve its performance beyond that of a sparse embedding model. Let's have a look at what we have. We have our ID here, and then we have this context field containing all these paragraphs.

What we need to do with these paragraphs is join them together and then split them into smaller, digestible chunks that will fit into our models. We are going to be using BERT. The default BERT model has a max sequence length of 512 tokens, which is fairly big, but a typical sentence transformer actually limits this to 128.

We're going to be pretty naive in our assumptions here: we'll assume this 128-token limit, and we'll assume that the average token is three characters long. In reality it will vary. Realistically, we should tokenize the text and count the tokens, but that's more complicated logic, and I want to keep this as simple as possible.
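The character limit follows directly from those two assumptions. A minimal sketch of the arithmetic, assuming the rough three-characters-per-token average mentioned above (the helper name `char_limit` is hypothetical, not from the video):

```python
def char_limit(max_tokens: int, chars_per_token: int = 3) -> int:
    """Rough character budget for a chunk, assuming a fixed average token length."""
    return max_tokens * chars_per_token

# Typical sentence transformer limit of 128 tokens
print(char_limit(128))  # 384
# Full BERT-style limit of 512 tokens
print(char_limit(512))  # 1536
```

This is only an estimate; counting tokens with the model's own tokenizer gives the exact figure.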

So we're just doing this for now, but if you're interested, let me know and I can send you a link to some code that does that. To create these chunks, we create a processing function called chunker, and we feed in the list of contexts we've got up here.

It joins that list together and then splits it into sentences, so we create our chunks at the sentence level. We loop through each sentence and add it to the current chunk, and once the length of that chunk exceeds our limit, which is 384 characters (128 tokens × 3 characters), we say, okay, we've got enough, and add that chunk to our list of chunks.

Now, let's say we have five sentences in a single chunk. So that we're not cutting between sentences that are related, that have some logical continuity between them, we add some overlap between each of our chunks.

So let's say one chunk takes sentences zero to four; the next chunk might take sentences two to seven, or something like that. That way there's always a bit of overlap between the chunks. Once we get to the end of our sentences, we might still have a smaller chunk left over, so we just append that to our chunks list.
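The chunking logic just described can be sketched in plain Python. This is a minimal sketch, not the notebook's exact code: it assumes sentences can be split naively on `". "`, uses a character limit, and carries the last `overlap` sentences (a hypothetical parameter) into the next chunk.

```python
def chunker(contexts: list[str], limit: int = 384, overlap: int = 2) -> list[str]:
    """Join paragraphs, split them into sentences, and pack the sentences
    into overlapping chunks of roughly `limit` characters."""
    # Naive sentence split; a real sentence tokenizer would be more robust.
    sentences = [s.strip() for s in " ".join(contexts).split(". ") if s.strip()]

    def join(parts: list[str]) -> str:
        # Re-join sentences and make sure the chunk ends with a full stop.
        return ". ".join(parts).rstrip(".") + "."

    chunks: list[str] = []
    chunk: list[str] = []
    new_sentences = 0  # sentences added since the last chunk was emitted
    for sentence in sentences:
        chunk.append(sentence)
        new_sentences += 1
        # Emit once over the limit, but only if the chunk holds more than the overlap.
        if len(join(chunk)) > limit and len(chunk) > overlap:
            chunks.append(join(chunk))
            # Keep the last `overlap` sentences so neighbouring chunks share context.
            chunk = chunk[-overlap:] if overlap else []
            new_sentences = 0
    # Append leftover sentences that never reached the limit.
    if new_sentences:
        chunks.append(join(chunk))
    return chunks


chunks = chunker(
    ["First sentence here. Second sentence here.", "Third sentence here. Fourth one."],
    limit=40,
    overlap=1,
)
for c in chunks:
    print(c)
# First sentence here. Second sentence here.
# Second sentence here. Third sentence here.
# Third sentence here. Fourth one.
```

Tracking `new_sentences` avoids emitting a final chunk that would consist only of overlap already present in the previous chunk.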

So that's our chunking function. Let's run it and apply it to our first context. We get these smaller chunks, and they've been split between sentences; we're not just splitting in the middle of a sentence. But one thing you'll also notice is that here it says, "The leaves of the plant consist of a latticework of," and we also have that here.

So basically half of each chunk overlaps with the previous one, which means a lot of repetition. Depending on what you're doing, you can reduce that; for this example it's fine. Realistically, you should have some overlap so that you're not cutting between sentences that have a logical continuation.

We don't want to lose information, which is why we have those overlaps. This one is probably more reasonable: you have all of this, and then the overlap starts from around here. Cool. Next, we want to give each chunk a unique ID.

We use the pub_id followed by the chunk number. Then we do this for the full dataset: we go through the entire PubMedQA dataset, get each context, and create our chunks.

Again, the IDs use the PubMed ID and the chunk number. We run that and get a list of records. Now let's move on to creating our vectors, starting with the dense vectors.
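The ID scheme is simple enough to show in a few lines. A sketch with a hypothetical helper `make_records` and a made-up pub ID of `12345`; the record keys `id` and `context` are assumptions for illustration:

```python
def make_records(pub_id: int, chunks: list[str]) -> list[dict]:
    """Attach a unique '<pub_id>-<chunk number>' ID to every chunk of one article."""
    return [{"id": f"{pub_id}-{i}", "context": chunk} for i, chunk in enumerate(chunks)]


records = make_records(12345, ["chunk zero text.", "chunk one text."])
print([r["id"] for r in records])  # ['12345-0', '12345-1']
```

Because chunk numbers restart at zero for each article, prefixing with the pub ID is what keeps the IDs unique across the whole dataset.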

We're using a sentence transformer for this. The first thing we want to do is make sure we're using CUDA if it's available; otherwise you can use CPU, it's just going to be slower. It won't be too slow, since this isn't a huge dataset, but be aware of that.

The model we're using is a base model that has been trained on MS MARCO, an information retrieval dataset. Importantly, it has been trained to use dot product similarity, and we need that for it to work with the sparse-dense vectors we're putting into Pinecone.

Sparse-dense vectors are compared in a dot product similarity space, so that is important. We initialize the model on CUDA if we can. We can see the sentence transformer details here, and the max sequence length for this model is 512 tokens.
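One way to see why dot product matters: a sparse-dense score can be treated as the dense dot product plus the sparse dot product, which only combines cleanly if both parts live in a dot product space. A minimal plain-Python sketch, assuming sparse vectors are stored as `{"indices": [...], "values": [...]}` dictionaries; it illustrates the idea and is not Pinecone's internal implementation:

```python
def dense_dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def sparse_dot(a: dict, b: dict) -> float:
    # Only indices present in both sparse vectors contribute to the score.
    b_map = dict(zip(b["indices"], b["values"]))
    return sum(v * b_map.get(i, 0.0) for i, v in zip(a["indices"], a["values"]))


def hybrid_score(q_dense, q_sparse, d_dense, d_sparse) -> float:
    # Sum of the dense and sparse dot products.
    return dense_dot(q_dense, d_dense) + sparse_dot(q_sparse, d_sparse)


q_dense, d_dense = [0.5, 1.0, -0.5], [1.0, 2.0, 0.0]
q_sparse = {"indices": [10, 42], "values": [1.0, 0.5]}
d_sparse = {"indices": [42, 99], "values": [2.0, 3.0]}
print(hybrid_score(q_dense, q_sparse, d_dense, d_sparse))  # 3.5 (2.5 dense + 1.0 sparse)
```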

So earlier we went with a 128-token limit, but with this model we could actually use 512, which would let us increase the limit quite a bit. I think we set something like 384 for the character limit; with this model we could set it to around 1,500 (512 × 3 = 1,536), which is quite a bit more.

But we'll stick with what we have, because a lot of sentence transformers are restricted to that smaller size. Then we create an embedding like this: we call encode on our dense model and pass in our data. As we'll see in a moment, we get a 768-dimensional dense vector.

You can also get that number from the model's get_sentence_embedding_dimension method. This is important; we'll need it when we initialize our vector index later. Moving on to the sparse vectors, we're using the SPLADE CoCondenser-EnsembleDistil model, which is basically an efficient SPLADE model.

I think this all looks good. One thing: we move the model to CUDA if we can. The aggregation here is max, which is how it creates a single vector from the many token-level vectors it initially produces. I made a video on SPLADE, so you can take a look at that if you're interested.
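To make "max aggregation" concrete: SPLADE produces a logit over the whole vocabulary for every input token, applies log(1 + ReLU(x)) to each, and then takes the maximum over the input tokens for each vocabulary entry. Below is a toy sketch of just that pooling step, with made-up logits and a three-word vocabulary, written in plain Python rather than using the actual model:

```python
import math


def splade_max_pool(logits: list[list[float]], attention_mask: list[int]) -> list[float]:
    """Collapse per-token vocabulary logits into one vector via max pooling."""
    pooled = [0.0] * len(logits[0])
    for token_logits, mask in zip(logits, attention_mask):
        if not mask:  # skip padding tokens
            continue
        for j, logit in enumerate(token_logits):
            # log(1 + ReLU(x)) keeps weights non-negative and dampens large values
            weight = math.log1p(max(logit, 0.0))
            pooled[j] = max(pooled[j], weight)
    return pooled


# Two real tokens plus one padding token, over a 3-word toy vocabulary.
logits = [[1.0, -2.0, 0.0], [0.5, 3.0, -1.0], [9.0, 9.0, 9.0]]
vec = splade_max_pool(logits, attention_mask=[1, 1, 0])
print([round(v, 3) for v in vec])  # [0.693, 1.386, 0.0]
```

The ReLU is what produces exact zeros, which is why most of the real ~30,000 values end up being zero.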

There'll be a link to it at the top of the video somewhere. The model takes tokenized inputs, which need to be built with a tokenizer initialized with the same model ID, this model here. So we create our tokens like this, making sure to return PyTorch tensors.

Then, to create our sparse vectors, we do this. We use torch.no_grad(), which tells PyTorch not to calculate gradients, because that takes more time and we only need gradients for training. Right now we're just performing inference, so they're not needed.

We move the tokens to CUDA if we're using it, and then we feed them into the model. The reason we move them to CUDA is that if the tokens are on CPU while the model is on GPU, we'll get an error.

So we need to make sure we include that. Here are the SPLADE vector representations output by the model. We use squeeze to remove an unneeded dimension: initially the output has an extra dimension of size one (the batch dimension) alongside the roughly 30,000 vocabulary values, and we don't need it.

So we just remove it. That gives us a single vector, which is huge: about 30.5 thousand values (30,522). That is the vocabulary size of the BERT model, so every token that BERT recognizes is represented by one of these values, and through SPLADE we're creating a score for each one of those tokens.

Most of them are going to be zero; that's what makes it a sparse vector. The data format we'll be feeding into Pinecone is essentially a mapping from the positions of the nonzero values to the scores they were assigned. So what does that look like?

Let me show you. Here we can see we have 174 nonzero values, and we create this. This is a bit of a bad example, so let's come up to here, where we have our indices.

At position 1,000, the score of that token is this. I also have a little example of what that actually means. We don't need this step to process things with Pinecone; we're just doing it so we can understand what these values represent.
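Turning the full vocabulary-sized vector into that compact format just means keeping the positions and scores of the nonzero entries. A plain-Python sketch using a short made-up list in place of the real 30,522-value SPLADE output; the `indices`/`values` keys follow the sparse format used for Pinecone, while the function name `to_sparse` is hypothetical:

```python
def to_sparse(vec: list[float]) -> dict:
    """Keep only the nonzero entries of a full-length vector as indices/values."""
    indices = [i for i, v in enumerate(vec) if v != 0.0]
    return {"indices": indices, "values": [vec[i] for i in indices]}


vec = [0.0] * 10
vec[3], vec[7] = 1.25, 0.5
print(to_sparse(vec))  # {'indices': [3, 7], 'values': [1.25, 0.5]}
```

With 174 nonzero values out of 30,522, only about 0.6% of the entries need to be stored.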

So I'm going to create an index-to-token mapping. Like I said, each of those roughly 30.5 thousand values in the vector output by SPLADE refers to a particular token. Those tokens are just numbers, because that's what a transformer model like SPLADE reads, but we can't read them.

We need actual text. So this maps those positions, those integer values, to the text tokens they represent. We pass the dictionary we just created through that mapping, and we get this.
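The index-to-token lookup is just an inverted vocabulary. A sketch with a tiny hypothetical vocabulary standing in for the tokenizer's real one (with a Hugging Face tokenizer, the real token-to-ID dictionary comes from `tokenizer.get_vocab()`), which also sorts the terms by score so the most important ones come first:

```python
vocab = {"[PAD]": 0, "cell": 1, "death": 2, "lace": 3, "the": 4}  # toy token -> ID map
idx2token = {idx: token for token, idx in vocab.items()}

sparse = {"indices": [1, 2, 3], "values": [0.8, 1.9, 2.4]}  # made-up scores
sparse_tokens = {
    idx2token[i]: round(v, 2) for i, v in zip(sparse["indices"], sparse["values"])
}
# Sort so the highest-weighted terms come first.
ranked = sorted(sparse_tokens.items(), key=lambda kv: kv[1], reverse=True)
print(ranked)  # [('lace', 2.4), ('death', 1.9), ('cell', 0.8)]
```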

This is the text it came from, so let's have a look and see if it makes sense. "Programmed cell death is the regulated death of cells within an organism. The lace plant produces..." and so on, "...latticework of longitudinal and transverse veins enclosing areoles."

I don't know what any of that means, but we can at least see that in this sparse dictionary we have "pc", which I think is coming from here (PCD). It's not ideal, but it's fine. We have "lace", which is mentioned here, and "programmed" up here. There's "madagascar"; I don't know where that's coming from.

Then "death", and "d", which is the D at the end of "PCD". So we have all of these. What SPLADE does is identify the most important words within the text.

I'd say it's probably got that right with "lace", "programmed", "pc", and "d", and "death", "lattice", and "cell" are probably the most important words in here too. It's not giving us words like "the" or "within", because it doesn't view those as important.

But if we scroll down, we'll probably see some words that aren't actually in the text but are similar to words in it. That's because SPLADE also performs term expansion: based on the words it sees, it adds other words that it doesn't see but that we might expect a relevant document to contain.

For example, I don't think the word "die" is in the text, but if we come down here, it is in the sparse vector. "Regulated" is in the text. "Lacy" probably isn't; the text has "lace plant", but "lacy" is in the vector. I don't know if that is actually relevant; I don't understand much of what this says.

We have "plant" and "plants" in the vector; I wonder if both of those are in the text. We've got "plant" twice, but we don't have "plants". That could be useful. This is a document, so imagine that in the query, the user is searching for "programmed cell death in plants", or "how do plants die from PCD?"

The query would contain the terms "die" and "plants", but not "death" or "plant". That's why term expansion is really useful: it gives you term overlap, which traditional sparse vector methods like BM25 lack, because they don't have automatic term expansion.
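A toy illustration of why expansion helps: with exact-match terms only, a query for "plants die" shares nothing with a document containing "plant" and "death", but once expansion adds "plants" and "die" to the document's sparse vector, the dot product becomes nonzero. Terms are used as keys here for readability (real SPLADE vectors use vocabulary indices), and all the weights are made up:

```python
def sparse_dot(a: dict[str, float], b: dict[str, float]) -> float:
    # Only terms present in both vectors contribute to the score.
    return sum(w * b[t] for t, w in a.items() if t in b)


query = {"plants": 1.0, "die": 1.0}
doc_exact = {"plant": 1.5, "death": 1.7, "lace": 2.0}
# The same document after term expansion adds related words with lower weights.
doc_expanded = {**doc_exact, "plants": 0.5, "die": 1.0}

print(sparse_dot(query, doc_exact))     # 0
print(sparse_dot(query, doc_expanded))  # 1.5
```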

So we've seen how to create our dense vectors and our sparse vectors. Now let's do this for everything. We create a helper function called builder, which transforms a list of records from our data (the contexts) into this format here.

This is the format we'll be feeding into Pinecone: we have our ID, our dense vector, our sparse vector in the dictionary format we saw already, and this metadata. Metadata is just additional information that we can attach to our vectors.

In this case, I'm including the human-readable text. Let's go through the builder function. First, we get our IDs from the records.

Here, records is the full list of chunked records, so it extracts the IDs for everything, and then it extracts the contexts. That's why the ID is the pub ID followed by the chunk number. And the contexts are those smaller chunks of text, a couple of sentences each.

From those chunks of text, we encode everything, which creates our dense vectors. Then we create our sparse vectors: we tokenize the text to get the input IDs, and then process those tokens through the SPLADE model.

Then we initialize an empty list, which is where we store everything to add to Pinecone. We go through the IDs, the dense vectors, the sparse vectors, and the contexts we've just created, and build this format.

So for every record we have this format: ID, values, sparse values, and metadata, which is what I showed you just here. We'll run this cell and try it with the first three records first.
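The builder step can be sketched in plain Python as below. This is a sketch, not the notebook's code: the two encoders are passed in as functions so the snippet runs without any models, and the stand-in encoders in the usage example are hypothetical placeholders for the sentence transformer and SPLADE. The output keys (`id`, `values`, `sparse_values`, `metadata`) follow Pinecone's sparse-dense upsert format as described above; the record keys `id` and `context` are assumptions.

```python
from typing import Callable


def builder(
    records: list[dict],
    dense_encode: Callable[[list[str]], list[list[float]]],
    sparse_encode: Callable[[list[str]], list[dict]],
) -> list[dict]:
    """Turn chunk records into upsert-ready dictionaries."""
    ids = [r["id"] for r in records]
    contexts = [r["context"] for r in records]
    dense_vecs = dense_encode(contexts)    # one dense vector per chunk
    sparse_vecs = sparse_encode(contexts)  # one {"indices", "values"} dict per chunk
    upserts = []
    for _id, dense, sparse, context in zip(ids, dense_vecs, sparse_vecs, contexts):
        upserts.append({
            "id": _id,
            "values": dense,
            "sparse_values": sparse,
            "metadata": {"context": context},  # human-readable text
        })
    return upserts


# Hypothetical stand-in encoders so the sketch runs without any models.
fake_dense = lambda texts: [[float(len(t)), 1.0] for t in texts]
fake_sparse = lambda texts: [{"indices": [0], "values": [1.0]} for _ in texts]

records = [{"id": "1-0", "context": "Some text."}, {"id": "1-1", "context": "More."}]
print(builder(records, fake_dense, fake_sparse)[0])
```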

We get these. There are a lot of numbers in there, but we have the metadata; up here we have the values and indices for our SPLADE vector, the sparse values; and we have our dense values, our dense vector, which is very big.

And then we have the ID. Cool. Now we want to initialize our connection to Pinecone using a free API key. For that, go to app.pinecone.io and you'll end up on this page. Go to API Keys, and you'll have your API key there; it will probably be labeled "default".

Click copy, come back over here, and put it in as your API key; I've stored mine in a variable. Then, for your environment, go back to your console and copy whatever is in here. For me it's us-east1-gcp; there's a good chance yours is the same, but it may vary.

We run that, which initializes our connection with Pinecone, and then we create an index. There are a few important things here. The index name isn't so important; you can use whatever you want, but you do need to pass one.

The dimensionality is 768, the dimension of the dense vector embedding model (the dense model, not the SPLADE model). We have to use the dot product metric to use sparse-dense vectors, and for the pod type we must use either s1 or p1. That creates the index, and if we go to the console, open Indexes, and refresh, we should see it.

So we have this PubMed SPLADE index in there now. Next, we need to initialize the connection to our index. For this, we can use either Index or GRPCIndex; GRPCIndex is essentially faster and also a little more reliable in terms of your connection.

It holds a stable connection to Pinecone. The Index one is still very stable and very fast, just not quite as good. We run those. That gives us some index statistics; of course, our index is completely empty right now, and the dimensionality is what we set before, 768.

To add some vectors, we call index.upsert and pass in what we created with builder, because builder outputs the format we need to add things to Pinecone. We can see that we upserted three items. Upsert just means insert, or update if the ID already exists.

We can repeat that for the full index. You can increase the batch size depending on what hardware you're using; we'll stick with 64, which is fairly low, just to be safe. It won't take long (about a minute and 20 seconds here), so I'll skip ahead.

That's complete; it took about a minute and a half. Now we check that the number of upserted records matches the length of our original data: here is our original data, and here's the number of items now inside our index.
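The batching loop and the final count check can be sketched as follows. The `upsert` argument is a hypothetical stand-in (here, a list that collects the items) for the real index call, so the sketch runs without a Pinecone connection; only the batch size of 64 comes from the video.

```python
def upsert_in_batches(items: list, upsert, batch_size: int = 64) -> int:
    """Send items in fixed-size batches and return how many were sent."""
    sent = 0
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        upsert(batch)
        sent += len(batch)
    return sent


fake_index: list = []  # stand-in for the vector index
data = [{"id": str(i)} for i in range(150)]
total = upsert_in_batches(data, fake_index.extend)  # batches of 64, 64, 22
# Sanity check mirroring the one in the video: everything made it in.
assert total == len(fake_index) == len(data)
print(total)  # 150
```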

It looks like everything is in there, so we can move on to querying. Our queries need to contain both sparse and dense vectors, so we use this function called encode, which handles everything for us.

We create our dense vector, then create our sparse dictionary, and return both. We'll start with: "Can clinicians use the PHQ-9 to assess depression in people with vision loss?" We run this, and straight away we get a result that aims to investigate whether the PHQ-9 possesses the essential psychometric characteristics to measure depressive symptoms in people with visual impairment.

I'd say that is probably correct. We have "depressive symptoms" and "depression", "vision loss" and "visual impairment": the words don't align perfectly, but they have the same meaning. So my question is: what is doing this? Is it the dense component or the sparse component?

Actually, we'll see that it's kind of both. But what I want to show you is that we can scale the dense versus sparse components. We do this with the hybrid_scale function, which takes an alpha value. When alpha is equal to 1, the dense vector gets full weight and the sparse vector becomes completely irrelevant.

With an alpha of 0, the sparse vector is the only thing being used and the dense vector is completely irrelevant. If we want an equal blend of the two, we use 0.5. Let's first try a pure dense search and see what happens.
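A sketch of a hybrid_scale-style function implementing the alpha behaviour just described: the dense vector is multiplied by alpha and the sparse values by 1 - alpha, so alpha = 1 is pure dense, alpha = 0 is pure sparse, and 0.5 weights both equally. It is written from that description rather than copied from the notebook, and it assumes the `{"indices", "values"}` sparse format.

```python
def hybrid_scale(dense: list[float], sparse: dict, alpha: float):
    """Weight dense by alpha and sparse by (1 - alpha) for a convex blend."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be between 0 and 1")
    scaled_sparse = {
        "indices": sparse["indices"],
        "values": [v * (1 - alpha) for v in sparse["values"]],
    }
    scaled_dense = [v * alpha for v in dense]
    return scaled_dense, scaled_sparse


dense, sparse = [0.2, 0.4], {"indices": [5], "values": [2.0]}
print(hybrid_scale(dense, sparse, alpha=0.0))  # ([0.0, 0.0], {'indices': [5], 'values': [2.0]})
print(hybrid_scale(dense, sparse, alpha=1.0))  # ([0.2, 0.4], {'indices': [5], 'values': [0.0]})
```

Because only the query is scaled, the dot product score becomes alpha × (dense score) + (1 − alpha) × (sparse score), so you can try different alpha values without re-indexing any documents.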

I need to run this. We get the right answer at the top straight away. The score is different: this one is 181, whereas up here it was 203. It's not that much different, but it is different. So does that mean it's only the dense vector doing this?

Let's try an alpha value of 0.0, and we get the same answer at the top again. There is some variation in the results further down; with the dense embedding, I'm not sure whether the performance is better or not, but we do get slightly different results.

When we have a mix of both, we also get that same top result. So let's try some other questions that might give us more varied responses. What's going on here is that both models are actually very good on this dataset, so we don't see much difference when we vary them.

Next: "Does ibuprofen increase perioperative blood loss during hip arthroplasty?" This is a sparse search, and when we run it, we get a result that aims to determine whether prior exposure to non-steroidal anti-inflammatory drugs (which, from what I understand, include ibuprofen) increases perioperative blood loss associated with major orthopaedic surgery.

I checked what these terms mean: hip arthroplasty basically means a hip replacement, and a hip replacement is major orthopaedic surgery. That's what I understood; I could be completely wrong.

They also mention hip replacement here, so I think this result is relevant, and this is using the pure sparse method. Then we get this next result, which does talk about ibuprofen and so on, but it doesn't mention arthroplasty.

So I assume it's not as relevant. If we go pure dense, we get the best answer at position two, which is still good. It's not performing badly; that is a good performance, but it's not quite as good as pure sparse.

I put a ton of example questions in here from the PubMedQA paper, so you can try a few of these. What we find is that some of them perform better with sparse and some perform better with dense, so a good approach here is to use a mix of both with hybrid search.

So we set alpha to 0.3, 0.5, or whatever works best overall for your particular use case, and overall we'll get generally better performance. Once you're done and have asked a few more questions, delete your index at the end to save resources, so you're not using more than you need.

So that's it for this video. We've quickly been through an example of using hybrid search in Pinecone with SPLADE and a dense sentence transformer model, and I think the results are pretty good. This is just one example, but generally, hybrid search performs a lot better than pure dense or pure sparse search.

So if you're able to implement this in your search applications, it's 100% worth doing. For now, we'll leave it there. I hope this video has been interesting and useful. Thank you very much for watching, and I will see you again in the next one. Bye.