
SPLADE: the first search model to beat BM25


Chapters

0:00 Sparse and dense vector search
0:44 Comparing sparse vs. dense vectors
3:59 Using sparse and dense together
6:46 What is SPLADE?
9:06 Vocabulary mismatch problem
9:51 How SPLADE works (transformers 101)
14:28 Masked language modeling (MLM)
15:57 How SPLADE builds embeddings with MLM
17:35 Where SPLADE doesn't work so well
20:14 Implementing SPLADE in Python
20:38 SPLADE with PyTorch and Hugging Face
24:08 Using the Naver SPLADE library
27:11 What's next for vector search?

Transcript

In information retrieval, vector embeddings represent documents and queries in a numerical vector format. That means we can take some text, which could be web pages from the internet in the case of Google or product descriptions in the case of Amazon, encode it using some sort of embedding method or model, and get back a vector of numbers.

So we have now represented our text in a vector space. There are different ways of doing this, and sparse and dense vectors are two different forms of this representation, each with their own pros and cons. Typically, when we think of sparse vectors, like those produced by TF-IDF and BM25, they have very high dimensionality and contain very few non-zero values.

So the information within those vectors is very sparsely located. For these types of vectors we have decades of research looking at how they can be used and how they can be stored using compact data structures, and naturally there are many very efficient retrieval systems designed specifically for them.
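To make the "compact data structures" point concrete, here is a minimal sketch (not any particular library's format) that stores a sparse vector as parallel lists of non-zero indices and values, and takes a dot product between two such vectors. The 30,522 dimension is BERT's vocabulary size, borrowed for scale; the indices and weights are made up.

```python
def to_sparse(full_vec):
    """Keep only the non-zero positions of a full-length vector."""
    indices = [i for i, v in enumerate(full_vec) if v != 0]
    values = [full_vec[i] for i in indices]
    return {"indices": indices, "values": values}


def sparse_dot(a, b):
    """Dot product that only visits dimensions that are non-zero in both vectors."""
    b_lookup = dict(zip(b["indices"], b["values"]))
    return sum(v * b_lookup[i] for i, v in zip(a["indices"], a["values"]) if i in b_lookup)


# Toy 30,522-dimensional vector with only three non-zero entries (made-up values).
full = [0.0] * 30522
full[101], full[2003], full[7000] = 1.2, 0.4, 0.9
query = to_sparse(full)
doc = {"indices": [101, 7000, 12000], "values": [0.5, 1.0, 2.0]}

print(len(query["indices"]))    # 3 stored entries instead of 30,522
print(sparse_dot(query, doc))   # 1.2*0.5 + 0.9*1.0 = 1.5
```

Only the shared dimensions (101 and 7000) contribute to the score, which is why sparse retrieval systems can skip most of the vocabulary entirely.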

On the other hand, we have dense vectors. Dense vectors are lower dimensional, but they are very information rich. That's because all of the information is compressed into a much smaller dimensional space, so we don't have the long runs of zero values that we would get in a sparse vector. All of the information is densely packed, which is why we call them dense vectors.

These types of vectors are typically built using neural network architectures like transformers, and through this they can represent more abstract information, like the semantic meaning behind some text. When it comes to sparse embeddings, the pros are typically faster retrieval, good baseline performance, no need for any model fine-tuning, and exact matching of terms.

On the cons side, we have a few things as well. The performance cannot really be improved significantly over the baseline performance of these algorithms. They also suffer from something called the vocabulary mismatch problem, which we'll talk about in more detail later. And they don't align with the human-like thinking in abstract concepts that I described earlier.
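As a rough illustration of both the exact-matching pro and the vocabulary mismatch con, here is a small BM25 scorer in plain Python. It assumes whitespace tokenization, the commonly used defaults k1 = 1.2 and b = 0.75, and a Lucene-style IDF; real search engines differ in tokenization and IDF details.

```python
import math
from collections import Counter


def bm25_scores(query, docs, k1=1.2, b=0.75):
    """Score every document against the query with BM25 (Lucene-style IDF)."""
    tokenized = [d.lower().split() for d in docs]
    n_docs = len(tokenized)
    avgdl = sum(len(d) for d in tokenized) / n_docs
    # Document frequency: how many documents contain each term.
    df = Counter(term for d in tokenized for term in set(d))
    scores = []
    for doc in tokenized:
        tf = Counter(doc)
        score = 0.0
        for term in query.lower().split():
            if term not in tf:
                continue  # exact match only: terms absent from the doc add nothing
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1)
            norm = tf[term] + k1 * (1 - b + b * len(doc) / avgdl)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


docs = [
    "orangutans are native to the rainforests of indonesia and malaysia",
    "the jungle is home to many great apes",
]
print(bm25_scores("orangutans rainforests", docs))  # only the first doc scores above 0
print(bm25_scores("ape jungle habitat", docs))      # first doc scores 0.0: no shared terms
```

The second query is clearly related to the orangutan document, yet it scores exactly zero there because none of its words (not even "ape" versus "apes") appear verbatim.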

Naturally, we have a completely different set of pros and cons when it comes to dense vectors. On the pros side, we know that dense vectors can outperform sparse vectors with fine-tuning, and that with them we can search using human-like abstract concepts. We also have great support for multi-modality, so we can search across text, images, audio, etc., and we can even do cross-modal search, going from text to image, image to text, or whatever you can think of.

But of course there are also cons. In order to outperform sparse vector embeddings, or even get close to them in terms of performance, we very often require training, and training requires a lot of data, which is very difficult to find when we find ourselves in low-resource scenarios.

These models also do not generalize very well, particularly when we move from one domain with very specific terminology to another domain with completely different terminology. Dense embeddings also require more compute and memory to build, store, and search across than sparse methods. We do not get any exact-match search, and it's often hard to understand why we're getting certain results, so they're not very interpretable.

Ideally, we want the best of both worlds: the pros of dense and the pros of sparse, without any of the cons. But that's very hard to do. There have been some band-aid solutions, and one of those is to perform two-stage retrieval.

In this scenario, we have two stages to retrieve and rank relevant documents for a given query. In the first stage, our system uses a sparse retrieval method to search through a very large set of candidate documents and return the relevant ones. These are then passed on to the second stage, a re-ranking stage, which uses a dense embedding model to re-order that smaller set of candidates according to which ones it believes are most relevant, using its more human-like semantic comprehension of language.
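The two-stage flow can be sketched with the scorers passed in as plain functions. Both scorers below are hypothetical stand-ins: a term-overlap count plays the sparse retriever, and a hand-written score table plays the dense re-ranking model.

```python
def two_stage_search(query, docs, sparse_score, dense_score, first_k=100, final_k=10):
    """Retrieve with a cheap sparse scorer, then re-rank the survivors with a dense scorer."""
    # Stage 1: score every document cheaply and keep only the top first_k candidates.
    ranked = sorted(range(len(docs)), key=lambda i: sparse_score(query, docs[i]), reverse=True)
    candidates = ranked[:first_k]
    # Stage 2: run the (expensive) dense scorer on the candidates only.
    reranked = sorted(candidates, key=lambda i: dense_score(query, docs[i]), reverse=True)
    return reranked[:final_k]


def term_overlap(query, doc):
    """Stand-in sparse scorer: number of shared whitespace tokens."""
    return len(set(query.split()) & set(doc.split()))


docs = [
    "apes eat fruit",
    "apes live in jungle trees",
    "orangutans in the rainforest",
    "fruit market prices",
]
# Hypothetical semantic relevance scores standing in for a dense re-ranking model.
fake_semantic = {docs[0]: 0.2, docs[1]: 0.9, docs[2]: 0.95, docs[3]: 0.1}


def fake_dense(query, doc):
    return fake_semantic[doc]


print(two_stage_search("apes jungle", docs, term_overlap, fake_dense, first_k=3, final_k=2))  # [2, 1]
print(two_stage_search("apes jungle", docs, term_overlap, fake_dense, first_k=2, final_k=2))  # [1, 0]
```

With first_k=3 the re-ranker promotes the orangutan document even though it shares no terms with the query; with first_k=2 that document never reaches stage two, which is exactly the dependence on the first-stage retriever discussed next.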

There are some benefits to this. First, we apply the sparse method to the full set of documents, which is the efficient way to search through all of them, and only then do we re-rank with our dense model, which is naturally much slower but now deals with a much smaller amount of data.

Another benefit is that the re-ranking stage is detached from the retrieval system, so we can modify one stage without affecting the other. This is particularly useful if we have multiple models that consume, for example, the output of the sparse retrieval stage.

However, this is of course not perfect. Two stages of retrieval and re-ranking can be slower than a single-stage system that uses approximate nearest neighbor search algorithms. Having two stages within the system is also more complicated, which naturally brings many engineering challenges, and we're very reliant on that first-stage retriever.

If the first-stage retriever doesn't perform well, there's nothing the second-stage re-ranking model can do: if it is given a load of rubbish results, it will just re-rank rubbish results, and the final result will still be rubbish. Those are the main problems with this approach, and ideally we want to solve them by improving single-stage systems.

A lot of work has been put into improving single-stage retrieval systems. A big part of that research has been in building more robust, learnable sparse embedding models, and one of the most promising models in this space is SPLADE. The idea behind SPLADE, the SParse Lexical AnD Expansion model, is that a pre-trained model like BERT can identify connections between words and sub-words (which we can call word pieces or terms) and use that knowledge to enhance our sparse vector embeddings.

This works in two ways. First, it allows us to measure the relevance of different terms, so the word "the" will carry less significance in most cases than a less common word like "orangutan". Second, it enables learnable term expansion, where term expansion is the inclusion of alternative but relevant terms beyond those found in the original sentence or sequence.

It's very important to note that I said learnable term expansion. The big advantage of SPLADE is not that it can do term expansion, which has been done for a while, but that it can learn term expansions. In the past, term expansion could be done with more traditional methods, but those required rule-based logic that someone had to write by hand. This is naturally time consuming and fundamentally limited, because you can't write rules for every single scenario in human language.

By using SPLADE, we can simply learn these expansions with a transformer model, which is much more robust and much less time consuming for us. Another benefit of using a context-aware transformer model like BERT is that it modifies these term expansions based on the context of the input sentence. It won't always expand the word "rainforest" to the same fixed set of words; which words "rainforest" expands to depends entirely on the sentence it was fed in with. This context awareness is one of the big benefits of attention-based models like transformers.

Term expansion is crucial in minimizing a key problem with sparse embedding methods: the vocabulary mismatch problem. This is the very typical lack of term overlap between a query and the documents we are searching for. Because we think in abstract ideas and concepts, and we have many different words to describe the same thing, it's very unlikely that the way we describe something when searching contains the exact words used in the relevant information. This is just a side effect of the complexity of human language.
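A toy example shows why expansion helps. The term weights below are hand-written and hypothetical, not real SPLADE outputs; the point is only that a dot product between {term: weight} vectors is zero until the document picks up some of the query's terms.

```python
def sparse_dot(a, b):
    """Dot product of two {term: weight} sparse vectors."""
    return sum(w * b[t] for t, w in a.items() if t in b)


# Hand-written, hypothetical weights; these are NOT real SPLADE outputs.
query_vec = {"ape": 1.0, "jungle": 1.0}
doc_vec = {"orangutans": 1.0, "rainforests": 1.0, "indonesia": 1.0}
print(sparse_dot(query_vec, doc_vec))  # 0: no shared terms, so no match at all

# The same document after (hypothetical) term expansion adds related terms.
doc_expanded = {**doc_vec, "ape": 0.6, "jungle": 0.8, "forest": 0.7}
print(sparse_dot(query_vec, doc_expanded))  # roughly 1.4 (0.6 + 0.8)
```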

Now let's move on to how SPLADE actually builds these sparse embeddings. It's relatively easy to grasp what is happening here. We start with a transformer model like BERT. These transformer models use something called masked language modeling to perform their pre-training on a ton of text data.

Not all transformer models use this, but most do. If you're familiar with BERT and masked language modeling, that's great; if not, we'll quickly break it down. Starting with BERT: it's a very popular transformer model, and like all transformer models its core functionality is to create information-rich token embeddings.

What exactly does that mean? We start with some text, like "Orangutans are native to the rainforests of Indonesia and Malaysia." With a transformer model like BERT, we begin by tokenizing that text into BERT-specific word or sub-word level tokens. Using the Hugging Face transformers library, we have a tokenizer object for this.

This is what handles the tokenization of our text. We take the same sentence I described before, "Orangutans are native to the rainforests of Indonesia and Malaysia", and convert it into tokens. What we get back are token IDs, which are integers, but each one represents something within our text.

We might guess, for example, that the 2030 here represents part of "orangutan" and the 5654 represents the "s" at the end of "orangutans"; tokens can be word level or sub-word level like that. These are just numbers, though, so let's look at how our words are actually broken up into tokens.

We convert those IDs back into human-readable tokens. First we have [CLS], the classifier token, a special token used by BERT that appears at the start of every sequence tokenized by the BERT tokenizer. Then we have "orangutans", which is actually split across four tokens, and we can see the rest of the sentence there as well.
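BERT's tokenizer is based on WordPiece. The sketch below shows the greedy longest-match-first idea behind splitting one word into several sub-word tokens, with continuation pieces marked by "##". The vocabulary is a tiny made-up one, so these splits are illustrative and not necessarily what the real BERT tokenizer produces; real tokenizers also handle punctuation, casing, and unknown characters more carefully.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single word."""
    pieces, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        # Try the longest remaining substring first, shrinking until it is in the vocab.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no valid split: the whole word becomes unknown
        pieces.append(piece)
        start = end
    return pieces


def tokenize(text, vocab):
    """Lowercase, split on whitespace, WordPiece each word, add [CLS]/[SEP]."""
    tokens = ["[CLS]"]
    for word in text.lower().split():
        tokens.extend(wordpiece(word, vocab))
    tokens.append("[SEP]")
    return tokens


# Toy vocabulary (hypothetical, far smaller than BERT's 30,522 entries).
vocab = {"or", "##ang", "##uta", "##ns", "are", "native", "to", "the", "rain", "##forest", "##s"}
print(tokenize("Orangutans are native to the rainforests", vocab))
```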

Why do we create these tokens and token IDs? Because the token IDs are then mapped to an embedding matrix, which is the first layer of our transformer model. In this embedding matrix we find learned vector representations that represent the tokens we fed in within a vector space.
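Mapping IDs to the embedding matrix is just a row lookup. In this NumPy sketch the matrix is random, so it carries none of the learned similarity structure a real model has, and the token IDs are toy values. Real BERT also adds position and segment embeddings to these token embeddings, which is left out here.

```python
import numpy as np

vocab_size, hidden_size = 30522, 768  # BERT-base sizes
rng = np.random.default_rng(0)
# Random stand-in for the learned embedding matrix: one row per vocabulary token.
embedding_matrix = rng.standard_normal((vocab_size, hidden_size), dtype=np.float32)

token_ids = np.array([101, 2030, 5654, 102])  # toy token IDs for a short sequence
token_embeddings = embedding_matrix[token_ids]  # row lookup: one vector per token
print(token_embeddings.shape)  # (4, 768)

# The lookup is equivalent to multiplying a one-hot vector by the matrix.
one_hot = np.zeros(vocab_size, dtype=np.float32)
one_hot[2030] = 1.0
print(np.allclose(one_hot @ embedding_matrix, embedding_matrix[2030]))  # True
```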

For example, the vector representation for the token "rainforest" will be close to the representations for tokens like "jungle" or "forest", because they are semantically similar, whereas it will be further away in that vector space from less related tokens like "native" or "the".

From here, the token representations of our original text go through several encoder blocks. These blocks encode more and more contextual information into each token embedding. As we progress through the encoder blocks, the embeddings are moved within the vector space so that they reflect the meaning of each token within the context of the sentence it appears in, rather than just the meaning of the token by itself.
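The mechanism that mixes context into each token embedding inside an encoder block is self-attention. Below is a single-head, scaled dot-product version in NumPy with random weights; a real BERT encoder block uses multiple heads plus residual connections, layer normalization, and a feed-forward layer, all omitted here.

```python
import numpy as np


def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over (seq_len, d) token embeddings."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])         # how strongly each token attends to each other token
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability before the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # each row is a distribution summing to 1
    return weights @ v, weights                     # each output row blends information from the whole sequence


rng = np.random.default_rng(0)
seq_len, d = 5, 16
x = rng.standard_normal((seq_len, d))
w_q, w_k, w_v = (rng.standard_normal((d, d)) for _ in range(3))
out, attn = self_attention(x, w_q, w_k, w_v)
print(out.shape, attn.sum(axis=-1))
```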

After this progressive encoding of contextual information, we arrive at the transformer's output layer. Here we have our final information-rich vector embeddings: each embedding represents the original token, but with its context encoded into it. This process is the core of BERT and every other transformer model.

The power of transformers comes from the considerable number of things these information-rich embeddings can be used for. Typically, we add a task-specific head onto the end of the transformer model that transforms these embeddings into something else, like sentiment predictions or sparse vectors.

The masked language modeling (MLM) head is one of the most common of these task-specific heads because it is used for pre-training most transformer models. It works by taking an input sentence; let's use the "Orangutans are native to the rainforests of Indonesia and Malaysia" example again. We tokenize this text and then mask a few of the tokens at random.

This masked token sequence is passed as input to BERT, and on the output side the original unmasked sequence is used as the target for the MLM head. During training, BERT and the MLM head adjust their internal weights so that they produce accurate predictions for the tokens that were masked.
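The random masking step can be sketched as follows. The 15% selection rate and the 80/10/10 split (replace with [MASK], replace with a random token, keep unchanged) follow BERT's published pre-training recipe; the ID 103 for [MASK] and 101/102 for [CLS]/[SEP] match the bert-base-uncased vocabulary, and -100 is the label value PyTorch's cross-entropy ignores by default. The input sequence itself is a toy.

```python
import numpy as np

MASK_ID = 103             # [MASK] in the bert-base-uncased vocabulary
SPECIAL_IDS = [101, 102]  # [CLS] and [SEP] are never masked
IGNORE = -100             # label value meaning "not part of the loss"


def mask_tokens(input_ids, vocab_size=30522, mask_prob=0.15, rng=None):
    """BERT-style random masking: returns (masked_ids, labels)."""
    rng = rng if rng is not None else np.random.default_rng()
    input_ids = np.asarray(input_ids)
    # Choose ~15% of the non-special positions as prediction targets.
    eligible = ~np.isin(input_ids, SPECIAL_IDS)
    selected = eligible & (rng.random(input_ids.shape) < mask_prob)
    labels = np.where(selected, input_ids, IGNORE)  # targets are the original tokens

    masked = input_ids.copy()
    roll = rng.random(input_ids.shape)
    masked[selected & (roll < 0.8)] = MASK_ID              # 80%: replace with [MASK]
    swap = selected & (roll >= 0.8) & (roll < 0.9)         # 10%: replace with a random token
    masked[swap] = rng.integers(0, vocab_size, size=int(swap.sum()))
    # The remaining 10% of selected positions keep their original token.
    return masked, labels


ids = [101] + list(range(2000, 2100)) + [102]  # toy sequence: [CLS] ... [SEP]
masked, labels = mask_tokens(ids, rng=np.random.default_rng(0))
print((labels != IGNORE).sum(), "positions selected for prediction")
```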

For this to work, the MLM head has 30,522 outputs, which is the vocabulary size of the BERT base model. That means we have one output for every possible token prediction, and the output as a whole acts as a probability distribution over the entire vocabulary. The highest activation in that distribution represents the token that BERT and the MLM head predict is behind the masked position.
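Turning one position's final hidden state into a distribution over the vocabulary looks roughly like this. It is simplified: the weights are random, and the real BERT MLM head first applies a dense layer, GELU activation, and LayerNorm, and ties its output weights to the input embedding matrix.

```python
import numpy as np

vocab_size, hidden_size = 30522, 768  # BERT-base sizes
rng = np.random.default_rng(0)
# Random stand-ins for the final hidden state at one masked position and the output layer.
hidden = rng.standard_normal(hidden_size, dtype=np.float32)
w_out = rng.standard_normal((hidden_size, vocab_size), dtype=np.float32) * np.float32(0.02)
b_out = np.zeros(vocab_size, dtype=np.float32)

logits = hidden @ w_out + b_out        # one score per vocabulary token
probs = np.exp(logits - logits.max())  # subtract the max for numerical stability
probs /= probs.sum()                   # softmax: a probability distribution over the vocab

predicted_id = int(np.argmax(probs))   # highest activation = predicted token ID
top5 = np.argsort(probs)[::-1][:5]     # the five most likely token IDs
print(predicted_id, top5)
```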

At the same time, we can think of this probability distribution as a representation of the words or tokens that are most relevant to a particular token within the context of the wider sentence. What SPLADE does is take all of these distributions and aggregate them into a single distribution called the importance estimation.

The importance estimation is the sparse vector produced by SPLADE. It is computed by passing each position's MLM output weights through a ReLU and a log saturation, log(1 + ReLU(w)), and then aggregating across all input positions. This allows us to identify relevant tokens that do not exist in the original sequence. For example, if we masked the word "rainforest", we might get high predictions for words like "jungle", "land", and "forest".
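For reference, the importance estimation from the original SPLADE paper can be written as below, where w_{ij} is the MLM output (logit) for vocabulary token j at input position i, and t is the set of input tokens. The ReLU discards negative activations and the log dampens very large ones. The later max-pooling variant mentioned further on replaces the sum with a max.

```latex
w_j = \sum_{i \in t} \log\bigl(1 + \mathrm{ReLU}(w_{ij})\bigr)
\qquad \text{(original SPLADE, sum pooling)}

w_j = \max_{i \in t} \log\bigl(1 + \mathrm{ReLU}(w_{ij})\bigr)
\qquad \text{(later variant, max pooling)}
```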

These words and their associated weights would then be represented in the sparse vector built by SPLADE. That doesn't mean we actually need to mask anything: the MLM head makes predictions for every token position, whether it is masked or not. So in the end, all we input is the unmasked sequence, and we get back a distribution for every position describing which words are relevant to that token in the context of the sentence.

Many transformer models are trained with masked language modeling, which means a huge number of models already have MLM head weights, and we can fine-tune those models as SPLADE models. That's something we will cover in another video.

Now let's look at where SPLADE works less well. As we've seen, SPLADE is a really good tool for minimizing the vocabulary mismatch problem, but there are of course some drawbacks to consider. Compared to other sparse methods, retrieval with SPLADE is very slow.

There are three primary reasons for this. First, the number of non-zero values in SPLADE query and document vectors is typically much greater than in traditional sparse vectors because of term expansion, and sparse retrieval systems are rarely optimized for this. Second, the distribution of these non-zero values deviates from the distribution expected by most sparse retrieval systems, again causing slowdowns. Third, SPLADE vectors are not natively supported by most sparse retrieval systems, which means we have to perform multiple pre- and post-processing steps, weight discretization, and other workarounds to make it work, if it works at all, and even then the system is not optimized for it.

Fortunately, there are solutions to these problems. For the first, the authors of SPLADE addressed it in a later paper that minimizes the number of non-zero values in the query vectors, in two steps. First, they improved the performance of SPLADE document encodings by using max pooling rather than the original sum pooling. Second, they limited term expansion to the document encodings only, dropping query expansion. Thanks to the improved document encodings, dropping query expansion still leaves us with better performance than the original SPLADE model.

The final two problems, two and three, can both be solved by using the Pinecone vector database. Problem two is solved by Pinecone's retrieval engine being designed to be agnostic to data distribution, and for problem three, Pinecone supports real-valued sparse vectors, meaning SPLADE vectors are supported natively without any of those pre-processing, post-processing, or discretization steps.

With all of that, I think we have covered everything we need to understand SPLADE. Now let's look at how we would actually implement it in practice. We have two options: we can do it directly with Hugging Face Transformers and PyTorch, or with a high-level abstraction using the official SPLADE library.

We'll look at both, starting with the Hugging Face and PyTorch implementation so we can understand how it actually works. First, we install a few prerequisites: SPLADE, Transformers, and PyTorch. Then we initialize the tokenizer, which is very similar to the BERT tokenizer we initialized earlier, and the model using AutoModelForMaskedLM, where MaskedLM stands for masked language modeling.

We're using the Naver SPLADE model here, and we initialize all of that. We also have one fairly large chunk of text. It is very domain specific, with a lot of specialized words that a typical dense embedding model would probably struggle with unless it had been fine-tuned on data containing those exact terms.

We tokenize everything, which gives us the token IDs you saw earlier, and then process those through the model to create our logits output. Each of those logits vectors contains a distribution over the roughly 30.5 thousand (30,522) possible tokens in the vocabulary, and we have 91 of them.

The reason we have 91 is that our text produced 91 input tokens: if we check tokens.input_ids.shape, we see 91 tokens in there. That number will change depending on how many input tokens we have.

From here, we take these output logits and transform them into a sparse vector. To do that, we use the formula you saw earlier to create the importance estimation. Running it gives us a single distribution, which is the actual sparse vector from SPLADE. If we look at that vector, we see it's mostly zeros, with only a few non-zero values.
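The video does this step in PyTorch; the same arithmetic can be sketched in NumPy. This assumes you already have the MLM logits for one sequence as an array of shape (seq_len, vocab_size), for example the .logits output of a Hugging Face AutoModelForMaskedLM for a single input converted to NumPy, plus the matching attention mask. The logits below are random and shifted negative so the result comes out sparse; they are not real model outputs. The pooling argument switches between the original sum and the later max aggregation.

```python
import numpy as np


def splade_vector(logits, attention_mask, pooling="max"):
    """Aggregate (seq_len, vocab_size) MLM logits into one vocab-sized SPLADE vector."""
    # log(1 + ReLU(w)): negative logits become 0, large ones are dampened.
    weights = np.log1p(np.maximum(logits, 0.0))
    # Zero out padding positions so they cannot contribute.
    weights = weights * attention_mask[:, None]
    if pooling == "max":  # max over positions (later SPLADE variant)
        return weights.max(axis=0)
    if pooling == "sum":  # sum over positions (original SPLADE formulation)
        return weights.sum(axis=0)
    raise ValueError(f"unknown pooling: {pooling!r}")


rng = np.random.default_rng(0)
seq_len, vocab_size = 91, 30522
# Random stand-in logits, mostly negative, so only a few weights survive the ReLU.
logits = rng.normal(loc=-6.0, scale=2.0, size=(seq_len, vocab_size))
attention_mask = np.ones(seq_len)

vec = splade_vector(logits, attention_mask)
print(vec.shape, np.count_nonzero(vec), "non-zero weights")
```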

Next, we create a dictionary format of our sparse vector. There are a few things to look at here: the number of non-zero values we actually have is 174, and all of them are now contained within this sparse dictionary.

The keys are token IDs and the values are the weights, or relevance, of each of those tokens. We can't read the token IDs, so, similar to before, we convert them into human-readable tokens and put the result into a more readable format.

Now we can see which tokens it believes are important. We've sorted everything by weight, which is why the order has changed, and the most important terms are things like "programmed", "death", "cell", and "lattice", which are highly relevant words within that particular domain.
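Converting the vector into a {token_id: weight} dictionary and then into sorted, readable tokens might look like the sketch below. The id_to_token mapping is an assumption: with a Hugging Face tokenizer you could build it as {i: tok for tok, i in tokenizer.get_vocab().items()}. The tiny vector and token names here are made up for the example.

```python
import numpy as np


def to_sparse_dict(vec):
    """Map each non-zero dimension (a vocabulary token ID) to its weight."""
    return {int(i): float(vec[i]) for i in np.nonzero(vec)[0]}


def readable(sparse_dict, id_to_token, top_k=None):
    """Replace token IDs with token strings and sort by weight, largest first."""
    named = [(id_to_token[i], w) for i, w in sparse_dict.items()]
    ranked = sorted(named, key=lambda pair: pair[1], reverse=True)
    return ranked if top_k is None else ranked[:top_k]


# Toy 10-dimensional "sparse vector" and a made-up ID-to-token mapping.
vec = np.zeros(10)
vec[2], vec[4], vec[7] = 0.5, 0.1, 1.25
id_to_token = {2: "cell", 4: "the", 7: "death"}

sparse = to_sparse_dict(vec)
print(len(sparse), "non-zero values")  # 3 non-zero values
print(readable(sparse, id_to_token))   # [('death', 1.25), ('cell', 0.5), ('the', 0.1)]
```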

A little further down, we can see how to do the same thing using the Naver SPLADE library. For that we need to pip install SPLADE, which we already did at the top of the notebook, so we don't need to do it again. We're using the max aggregation here, which is the max pooling method.

We use the same model ID here, because the library also downloads the model from Hugging Face. We wrap inference in torch.no_grad(), which tells PyTorch not to track gradients, since we're not fine-tuning the model weights, just performing inference (i.e., prediction). Then we pass the tokens we built with the tokenizer earlier into the Naver model.

From the output, we extract the d_rep tensor and squeeze it to remove an unnecessary dimension. Looking at it, we have roughly 30.5 thousand dimensions: this is our importance estimation, which is our sparse vector. We can now use what we've built so far to compare different documents.

Let's take a few texts. The first is the original text, and the ones below it are just me attempting to write something that is either relevant or not relevant using a similar type of language. We encode everything with the PyTorch and Hugging Face Transformers method, though both approaches produce the same result, whether you use that or the SPLADE library. What we get is three importance estimations, our SPLADE vectors, and we can then calculate the cosine similarity between them.

Here I initialize a zeros array to store the similarity scores, and then fill it in. Looking at the similarity matrix, the diagonal is where we compare each vector to itself, so those scores are at the maximum because the vectors are identical. Other than that, the most similar pair is sentence zero and sentence one. If we read those two, we can see they are in fact much more similar: they have a lot more term overlap, and that includes not just the terms we can see in the text but also the terms produced by term expansion.
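The zeros-array-and-loop approach might look like this in NumPy. The three short vectors are made-up stand-ins for SPLADE vectors (real ones have 30,522 dimensions). Because SPLADE weights come from log(1 + ReLU(w)), they are never negative, so the cosine scores fall between 0 and 1, and two vectors with no shared non-zero terms score exactly 0.

```python
import numpy as np


def cosine_similarity_matrix(vectors):
    """Pairwise cosine similarity, filled into a zeros array one pair at a time."""
    n = len(vectors)
    sim = np.zeros((n, n))  # one score per pair of documents
    for i in range(n):
        for j in range(n):
            a, b = vectors[i], vectors[j]
            sim[i, j] = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
    return sim


# Made-up, non-negative stand-ins for three SPLADE vectors.
vectors = np.array([
    [1.0, 2.0, 0.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0, 2.0, 0.0],
    [0.0, 0.0, 3.0, 1.0, 0.0, 1.0],
])
print(np.round(cosine_similarity_matrix(vectors), 3))
```

For many documents, the same matrix can be computed without loops by normalizing each row to unit length and taking unit @ unit.T.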

So that's how we use SPLADE to create embeddings and compare those sparse vectors using cosine similarity. That's it for this introduction to learned sparse embeddings with SPLADE. Using SPLADE, we can represent text with efficient sparse vector embeddings that help us deal with the vocabulary mismatch problem while enabling exact matching and drawing on the other benefits of sparse vectors.

But of course there's still a lot to be done, and there is more research and effort looking at how to mix dense and sparse vector embeddings using things like hybrid search. Using hybrid search and SPLADE together, we can get really cool results.

I think this is one step towards making vector search and information retrieval far more accessible, because we no longer need to fine-tune really big models to get the best possible performance; we can use things like hybrid search and SPLADE to improve performance with very little effort, which is a really good thing to see.

That's it for this video. I hope everything we've been through has been interesting and useful. Thank you very much for watching, and I'll see you again in the next one. Bye.