SPLADE: the first search model to beat BM25
Chapters
0:00 Sparse and dense vector search
0:44 Comparing sparse vs. dense vectors
3:59 Using sparse and dense together
6:46 What is SPLADE?
9:06 Vocabulary mismatch problem
9:51 How SPLADE works (transformers 101)
14:28 Masked language modeling (MLM)
15:57 How SPLADE builds embeddings with MLM
17:35 Where SPLADE doesn't work so well
20:14 Implementing SPLADE in Python
20:38 SPLADE with PyTorch and Hugging Face
24:08 Using the Naver SPLADE library
27:11 What's next for vector search?
In information retrieval, vector embeddings represent documents and queries in a numerical vector format. That means we can take some text, which could be web pages from the internet in the case of Google or product descriptions in the case of Amazon, and encode it using some sort of embedding method or model, so that the text is now represented in a vector space. There are different ways of doing this, and sparse and dense vectors are two different forms of this representation, each with their own pros and cons.

Sparse vectors, like those produced by TF-IDF and BM25, have very high dimensionality and contain very few non-zero values, so the information within them is very sparsely located. With these types of vectors we have decades of research looking at how they can be used and how they can be represented using compact data structures, and there are many very efficient retrieval systems designed specifically for them. Dense vectors, on the other hand, are lower dimensional but very information rich, because all of the information is compressed into a much smaller dimensional space. We don't have the mostly-zero values we would get in a sparse vector; all of the information is densely packed, which is why we call them dense vectors. These types of vectors are typically built using neural network architectures like transformers, and through this they can represent more abstract information, like the semantic meaning behind some text.
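To make that difference concrete, here is a minimal sketch, with made-up values purely for illustration, of what the two representations look like as data structures:

```python
# A sparse vector: very high dimensional, usually stored as the positions and
# weights of its few non-zero entries (hypothetical values for illustration).
sparse_vector = {
    "dim": 30522,                    # e.g. a BERT-sized vocabulary
    "indices": [2030, 5654, 11937],  # positions of the non-zero values
    "values": [1.42, 0.87, 2.01],    # term weights at those positions
}

# A dense vector: lower dimensional, with (almost) every value non-zero
# (again, hypothetical values; real models output hundreds of floats).
dense_vector = [0.12, -0.53, 0.91, 0.07, -0.26]
```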
When it comes to sparse embeddings, the pros are typically faster retrieval, good baseline performance, no need for model fine-tuning, and exact matching of terms. On the cons side, the performance can't really be improved significantly over the baseline performance of these algorithms, they suffer from something called the vocabulary mismatch problem, which we'll talk about in more detail later, and they don't align with the human-like notion of abstract concepts I described earlier.

Naturally we have a completely different set of pros and cons for dense vectors. On the pros side, dense vectors can outperform sparse vectors with fine-tuning, we can search with human-like abstract concepts, and we have great support for multiple modalities, so we can search across text, images, audio and so on, and even do cross-modal search, going from text to image or image to text. But of course there are also cons. To outperform sparse vector embeddings, or even get close to them in performance, we very often require training, and training requires a lot of data, which is very difficult to find in low-resource scenarios. These models also do not generalize very well, particularly when moving from one domain with very specific terminology to another domain with completely different terminology. Dense embeddings also require more compute and memory to build, store and search across than sparse methods, we don't get any exact-match search, and it's often hard to understand why we're getting particular results, so they're not very interpretable.

Ideally we want the best of both worlds: the pros of dense and the pros of sparse, without any of the cons. But that's very hard to do.
There have been some band-aid solutions. One of those is to perform two-stage retrieval. In this scenario we have two stages to retrieve and rank relevant documents for a given query. In the first stage, the system uses a sparse retrieval method to search through a very large set of candidate documents and return the relevant ones. These are passed on to the second stage, a re-ranking stage, which uses a dense model to re-rank that smaller set of candidates and decide which it believes are the most relevant, using its more human-like semantic comprehension of language.

There are some benefits to this. First, only the sparse method is applied to the full set of documents, which makes it efficient to search through them, and then we re-rank with our dense model, which is naturally much slower but is now dealing with a much smaller amount of data. Another benefit is that the re-ranking stage is detached from the retrieval stage, so we can modify one without affecting the other, which is particularly useful if we have multiple models consuming the output of the sparse retrieval stage. So that's another thing to consider.
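As a rough sketch of that two-stage pattern, assuming, purely for illustration, the rank_bm25 library for the sparse stage and a sentence-transformers cross-encoder as the neural re-ranker (neither is prescribed by anything above):

```python
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder

corpus = ["first document ...", "second document ...", "third document ..."]
query = "example query"

# Stage 1: sparse retrieval over the full corpus
bm25 = BM25Okapi([doc.split() for doc in corpus])
candidates = bm25.get_top_n(query.split(), corpus, n=2)

# Stage 2: neural re-ranking over the much smaller candidate set
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
scores = reranker.predict([(query, doc) for doc in candidates])
reranked = [doc for _, doc in sorted(zip(scores, candidates), reverse=True)]
```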
However, this is of course not perfect. Two stages of retrieval and re-ranking can be slower than a single-stage system that uses approximate nearest neighbour search, having two stages is more complicated and brings many engineering challenges, and we're very reliant on that first-stage retriever. If the first-stage retriever doesn't perform well, there's nothing the second-stage re-ranking model can do: if it's given a load of rubbish results, it will just re-rank rubbish results, and the final result will still be rubbish. Those are the main problems, and ideally we want to solve them by improving single-stage systems.
Now a lot of work has been put into improving single-stage retrieval systems. A big part of that research has gone into building more robust, learnable sparse embedding models, and one of the most promising models in this space is known as SPLADE. The idea behind these sparse lexical and expansion models is that a pre-trained model like BERT can identify connections between words and sub-words, which we can call word pieces or terms, and use that knowledge to enhance our sparse vector embeddings.

This works in two ways. First, it allows us to measure the relevance of different terms, so the word 'the' will carry less significance in most cases than a less common word like 'orangutan'. Second, it enables learnable term expansion, where term expansion is the inclusion of alternative but relevant terms beyond those found in the original sequence. It's very important to note that I said learnable term expansion. The big advantage of SPLADE is not that it can do term expansion, which has been done for a while, but that it can learn term expansions. In the past, term expansion could be done with more traditional methods, but it required rule-based logic; someone had to write those rules, which is time consuming and fundamentally limited, because you can't write rules for every single scenario in human language. With SPLADE we can simply learn these expansions using a transformer model, which is much more robust and much less time consuming for us. Another benefit of using a context-aware transformer model like BERT is that it modifies the term expansions based on the context, the sentence being input. So it won't always expand the word 'rainforest' to the same three words; it will expand 'rainforest' to different terms depending entirely on the context, the sentence it was fed in with. This is one of the big benefits of attention-based models like transformers: they are very context aware.

Term expansion is crucial in minimizing a key problem with sparse embedding methods: the vocabulary mismatch problem. The vocabulary mismatch problem is the typical lack of overlap between a query and the documents we are searching for. Because we think in abstract ideas and concepts, and we have many different words to describe the same thing, it's very unlikely that the way we describe something when we're searching contains the exact terms, the exact words, that the relevant documents contain. That's just a side effect of the complexity of human language.
Now let's move on to how SPLADE actually builds these sparse embeddings. It's relatively easy to grasp what's happening here. We start with a transformer model like BERT. These transformer models use something called masked language modeling to perform their pre-training on huge amounts of text data. Not all transformer models use this, but most do. If you're familiar with BERT and masked language modeling, great; if not, we're going to quickly break it down.

Starting with BERT: it's a very popular transformer model, and like all transformer models its core functionality is to create information-rich token embeddings. What exactly does that mean? Well, we start with some text like 'Orangutans are native to the rainforests of Indonesia and Malaysia'. With a transformer model like BERT, we begin by tokenizing that text into BERT-specific sub-word or word-level tokens. Using the Hugging Face transformers library, we have a tokenizer object that handles the tokenization of our text. We take the same sentence I described before and convert it into token IDs, which are integer numbers, where each one represents something within our text. Here, for example, 2030 probably represents 'orangutan' and 5654 maybe represents the 's' at the end of 'orangutans'; tokens can be word level or sub-word level like that. Those are just the numbers, but we can convert the IDs back into human-readable tokens to see how our words are broken up. Doing that, we see a classifier token, [CLS], a special token used by BERT at the start of every sequence produced by a BERT tokenizer, then 'orangutans', which is actually split across four tokens, and we can see the rest of the sentence there as well.
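As a minimal sketch of that tokenization step with Hugging Face transformers (assuming the bert-base-uncased checkpoint):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

text = "Orangutans are native to the rainforests of Indonesia and Malaysia"
tokens = tokenizer(text, return_tensors="pt")

print(tokens["input_ids"][0])  # the integer token IDs
# map the IDs back to human-readable tokens, starting with [CLS]
print(tokenizer.convert_ids_to_tokens(tokens["input_ids"][0]))
```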
Now why do we create these tokens and token IDs? Because the token IDs are mapped into what is called an embedding matrix, which is the first layer of the transformer model. In the embedding matrix we find learned vector representations that literally represent the tokens we fed in within a vector space. So the vector representation for the token 'rainforest' will be close, because it has high semantic similarity, to the vector representations for the tokens 'jungle' or 'forest', whereas it will be further away in that vector space from less related tokens like 'native' or 'the'.

From here, the token representations of our original text go through several encoder blocks. These blocks encode more and more contextual information into each token embedding, so as we progress through them the embeddings are essentially moved within that vector space to reflect the meaning of each token within the context of the sentence it appears in, rather than just the meaning of the token by itself. After all this progressive encoding of contextual information, we arrive at the transformer's output layer. Here we have our final information-rich vector embeddings, one per input token, with that context encoded into each of them. This process is the core of BERT and every other transformer model. However, the power of transformers comes from the considerable number of things these information-rich embeddings can be used for. Typically we add a task-specific head onto the end of the transformer model that transforms these information-rich embeddings into something else, like sentiment predictions or sparse vectors.
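To see those raw output embeddings before any task-specific head is attached, a small illustrative sketch (again assuming bert-base-uncased and the tokens from the earlier snippet):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")
outputs = model(**tokens)

# one 768-dimensional contextual embedding per input token
print(outputs.last_hidden_state.shape)  # e.g. torch.Size([1, seq_len, 768])
```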
The masked language modeling (MLM) head is one of the most common of these task-specific heads, because it is used for pre-training most transformer models. It works by taking an input sentence, let's use 'Orangutans are native to the rainforests of Indonesia and Malaysia' again, tokenizing it, and masking a few of those tokens at random. The masked token sequence is passed as input to BERT, while the original unmasked sequence is used as the target for the MLM head, and BERT and the MLM head have to adjust their internal weights to produce accurate predictions for the tokens that were masked.

For this to work, the MLM head has 30,522 outputs, which is the vocabulary size of the BERT base model. That means there is an output for every possible token prediction, and the output as a whole acts as a probability distribution over the entire vocabulary: the highest activation in that distribution represents the token that BERT and the MLM head have predicted for that masked position. At the same time, we can think of this probability distribution as a representation of the words or tokens that are most relevant to a particular token within the context of the wider sentence. What SPLADE does is take all of these distributions and aggregate them into a single distribution called the importance estimation. The importance estimation is the sparse vector produced by SPLADE, and it is computed with the equation below.
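Roughly, the importance estimation from the SPLADE paper has this form, where w_ij is the MLM logit for vocabulary term j at input token position i and t is the set of input token positions:

$$w_j = \sum_{i \in t} \log\bigl(1 + \mathrm{ReLU}(w_{ij})\bigr)$$

Later SPLADE variants replace the sum over token positions with a max (max pooling), which is the aggregation used in the implementation code further down.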
This allows us to identify relevant tokens that do not exist in the original sequence. For example, if we masked the word 'rainforest', we might get high predictions for the words 'jungle', 'land' and 'forest', and those words with their associated probabilities would then be represented in the SPLADE-built sparse vector. That doesn't mean we need to mask anything, though: the predictions are made for every token, whether it is masked or not. So in the end all we have to input is the unmasked sequence, and what we get back is a probability distribution of relevant terms for every token that was input, based on the context of the sentence.
Many transformer models are trained with masked language modeling, which means there are a huge number of models that already have these MLM weights, and we can use that to fine-tune those models as SPLADE models, something we'll cover in another video.

Now let's have a quick look at where SPLADE works less well. As we've seen, SPLADE is a really good tool for minimizing the vocabulary mismatch problem, but there are of course some drawbacks we should consider.
Compared to other sparse methods, retrieval with SPLADE is very slow. There are three primary reasons for this. First, the number of non-zero values in SPLADE query and document vectors is typically much greater than in traditional sparse vectors, because of the term expansion, and sparse retrieval systems are rarely optimized for this. Second, the distribution of those non-zero values deviates from the distribution expected by most sparse retrieval systems, again causing slowdowns. Third, SPLADE vectors are not natively supported by most sparse retrieval systems, meaning we have to perform multiple pre- and post-processing steps, weight discretization and so on to make it work, if it works at all, and the systems are not optimized for that either.

Fortunately there are solutions to all of these problems. For the first, the authors of SPLADE addressed it in a later paper that minimizes the number of non-zero values in the query vectors, in two steps. First, they improved the performance of the SPLADE document encodings by using max pooling rather than the original pooling strategy, and second, they limited term expansion to the document encodings only, dropping query expansion. Thanks to the improved document encodings, dropping query expansion still leaves us with better performance than the original SPLADE model.
The final two problems, two and three, can both be solved by using the Pinecone vector database. Number two is solved because Pinecone's retrieval engine is designed to be agnostic to data distribution, and for number three, Pinecone supports real-valued sparse vectors, meaning SPLADE vectors are supported natively without any of that extra pre-processing, post-processing or discretization. With all of that, I think we've covered everything we need in order to understand SPLADE.
Now let's have a look at how we would actually implement SPLADE in practice. We have two options: we can work directly with Hugging Face transformers and PyTorch, or use a higher-level abstraction via the official SPLADE library. We'll take a look at both, starting with the Hugging Face and PyTorch implementation so we can understand how it actually works.
Okay, first we install a few prerequisites, the SPLADE library, transformers and PyTorch, and then we initialize the tokenizer, which is very similar to the BERT tokenizer we initialized earlier, and an AutoModelForMaskedLM, i.e. a masked language modeling model. We're going to be using the Naver SPLADE model here, so we initialize all of that. Then we have one pretty large chunk of text. It's very domain specific, with a lot of very specific terms that a typical dense embedding model would probably struggle with unless it had been fine-tuned on data containing those exact terms. We tokenize everything, which gives us the token IDs you saw earlier, and then we process those through the model to create our logits output.
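A minimal sketch of that setup, assuming the naver/splade-cocondenser-ensembledistil checkpoint as the Naver SPLADE model and a placeholder for the domain-specific passage:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_id = "naver/splade-cocondenser-ensembledistil"  # assumed SPLADE checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)

text = "..."  # the large, domain-specific chunk of text

# tokenize and run a forward pass to get the MLM logits
tokens = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**tokens).logits  # shape: (1, num_input_tokens, 30522)
```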
As we saw before, each of those logits contains a probability distribution over the roughly 30.5 thousand possible tokens in the vocabulary, and we have 91 of them. The reason we have 91 is that there were 91 input tokens: if we look at the shape of the input IDs (tokens.input_ids.shape), we see 91 entries, so that number changes depending on how many input tokens we have. From here we take these output logits and transform them into a sparse vector. To do that we use the formula from earlier to create the importance estimation, and running that gives us a single distribution which is the actual sparse vector from SPLADE. If we look at that vector, we see it's mostly zeros, with only a few non-zero values.
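A sketch of that importance estimation step in PyTorch, using the max-pooling aggregation over the sequence dimension (assuming logits and tokens from the previous snippet):

```python
# log-saturate the positive logits, zero out padding via the attention mask,
# then max-pool over the sequence dimension to get one weight per vocab term
vec = torch.max(
    torch.log(1 + torch.relu(logits)) * tokens["attention_mask"].unsqueeze(-1),
    dim=1,
)[0].squeeze()

print(vec.shape)  # torch.Size([30522]) - mostly zeros
```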
Next, we create a dictionary format of the sparse vector. Running this, there are a few things to look at: the number of non-zero values we actually have is 174, and all of them are now contained within this sparse dictionary, where the keys are the token IDs and the values are the weights, the relevance of each of those particular tokens. We can't read token IDs, so just like before we convert them into actual human-readable tokens and put them into a more readable format. Sorting by weight (which is why the numbers change here), we can see what the model believes is important: words like 'programmed', 'death', 'cell', 'lattice', a lot of very relevant words within that particular domain.
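That dictionary step, sketched out (assuming vec and tokenizer from the snippets above):

```python
# indices and weights of the non-zero dimensions
cols = vec.nonzero().squeeze().cpu().tolist()
weights = vec[cols].cpu().tolist()

# map vocabulary indices back to human-readable tokens
idx2token = {idx: token for token, idx in tokenizer.get_vocab().items()}
sparse_dict_tokens = {
    idx2token[idx]: round(weight, 2) for idx, weight in zip(cols, weights)
}

# sort by weight so the most important (and expanded) terms come first
sparse_dict_tokens = dict(
    sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True)
)
```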
A little further down, we can see how to do the same thing using the Naver SPLADE library. For that we would pip install SPLADE; we did that at the top of the notebook, so we don't need to do it again. We're using the max aggregation here, i.e. the max pooling method, and the same model ID as before, because the library also downloads the model from Hugging Face. We set torch.no_grad, which says we don't want to update any of the model weights, because we're not fine-tuning, just performing inference, i.e. prediction, and we pass our tokens, built with the tokenizer earlier on, into the Naver model. From there we extract the d_rep tensor and squeeze it to remove an unnecessary dimension, and we can see we have the roughly 30.5 thousand dimensions again. This is our probability distribution, or importance estimation, and that is our sparse vector.
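Roughly, the library version looks like this; a sketch based on the naver/splade repository, where the exact import path and argument names may differ slightly between versions:

```python
import torch
from splade.models.transformer_rep import Splade

sparse_model = Splade(model_id, agg="max")  # max-pooling aggregation
sparse_model.eval()

with torch.no_grad():
    # pass the same tokenized input; "d_rep" is the document representation
    sparse_emb = sparse_model(d_kwargs=tokens)["d_rep"].squeeze()

print(sparse_emb.shape)  # torch.Size([30522])
```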
What we can do now is use what we've built so far to compare different documents. So let's take a few: we have the original 'programmed cell death' text, and then below it a couple of passages I wrote that are either relevant or not relevant but use a similar type of language. We encode everything; I'm using the PyTorch and Hugging Face transformers method here, but either approach produces the same result, whether you use that or the actual SPLADE library. That gives us three of these importance estimations, the SPLADE vectors, and then we calculate the cosine similarity between them. I initialize a zeros array to store the similarity scores, compute them, and take a look. On the diagonal, where we compare each vector to itself, the scores are of course high because they're identical. Beyond that, the pair seen as most similar is sentence zero and sentence one, and if we read those we can see that they are in fact much more similar: they have a lot more overlap in their terms, and it's not just the terms we can see but also the terms produced by the term expansion. So that's how we'd use SPLADE to create embeddings and compare those sparse vectors using cosine similarity.
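A sketch of that comparison, assuming vecs is a list of the three importance-estimation vectors converted to NumPy arrays (e.g. with .numpy()):

```python
import numpy as np

# pairwise cosine similarity matrix between the SPLADE vectors
sim = np.zeros((len(vecs), len(vecs)))
for i, x in enumerate(vecs):
    for j, y in enumerate(vecs):
        sim[i, j] = np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

print(sim.round(2))  # diagonal is 1.0; off-diagonals compare different documents
```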
That's it for this introduction to learned sparse embeddings with SPLADE. Using SPLADE we can represent text with efficient sparse vector embeddings that help us deal with the vocabulary mismatch problem while still enabling exact matching and drawing on the other benefits of sparse vectors. Of course there's still a lot to be done, and there's more research looking at how to mix dense and sparse vector embeddings, using things like hybrid search as well as SPLADE; using both together we can get really good results. I think this is just one step towards making vector search and information retrieval far more accessible, because we no longer need to fine-tune really big models to get the best possible performance; we can use things like hybrid search and SPLADE to really improve our performance with very little effort, which is a great thing to see. But that's it for this video. I hope everything we've been through is interesting and useful. For now, thank you very much for watching, and I'll see you again in the next one. Bye.