
SPLADE: the first search model to beat BM25


Chapters

0:00 Sparse and dense vector search
0:44 Comparing sparse vs. dense vectors
3:59 Using sparse and dense together
6:46 What is SPLADE?
9:06 Vocabulary mismatch problem
9:51 How SPLADE works (transformers 101)
14:28 Masked language modeling (MLM)
15:57 How SPLADE builds embeddings with MLM
17:35 Where SPLADE doesn't work so well
20:14 Implementing SPLADE in Python
20:38 SPLADE with PyTorch and Hugging Face
24:08 Using the Naver SPLADE library
27:11 What's next for vector search?

Whisper Transcript

00:00:00.000 | In information retrieval vector embeddings represent documents and queries in a numerical
00:00:08.080 | vector format. That means that we can take some text which could be web pages from the internet
00:00:15.200 | in the case of Google or maybe product descriptions in the case of Amazon and we can encode it using
00:00:22.080 | some sort of embedding method or model and we will get something that looks like this. So we have now
00:00:29.120 | represented our text in a vector space. Now there are different ways of doing this and sparse and
00:00:37.200 | dense vectors are two different forms of this representation each with their own pros and their
00:00:43.520 | own cons. Typically when we think of sparse vectors things like TF-IDF and BM25 they have very high
00:00:51.760 | dimensionality and they contain very few non-zero values. So the information within those vectors is
00:00:59.680 | very sparsely located and with these types of vectors we have decades of research looking at
00:01:07.520 | how they can be used and how they can be represented using compact data structures
00:01:14.000 | and there are naturally many very very efficient retrieval systems designed specifically for these
00:01:21.840 | vectors. On the other hand we have dense vectors. Dense vectors are lower dimensional but they are
00:01:28.720 | very information rich and that's because all of the information is compressed into this much smaller
00:01:35.200 | dimensional space, so we don't have the many zero values that we would get in a sparse vector,
00:01:41.040 | and hence all of the information is very densely packed and hence why we call them dense vectors.
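As a purely illustrative sketch of the two shapes just described (the dimensionalities and values below are placeholders, not outputs of any real model):

```python
import numpy as np

# Sparse: very high dimensional, almost entirely zeros.
vocab_size = 30_522                                  # e.g. a BERT-sized vocabulary
sparse_vec = np.zeros(vocab_size)
sparse_vec[[2040, 8817, 15023]] = [1.3, 0.4, 2.1]    # a handful of non-zero term weights

# Dense: low dimensional, every entry carries information.
dense_vec = np.random.randn(768)                     # e.g. a 768-d sentence embedding

print(int((sparse_vec != 0).sum()), "non-zero values out of", vocab_size)
print(dense_vec.shape)                               # (768,)
```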
00:01:46.560 | These types of vectors are typically built using neural network type architectures like
00:01:51.680 | transformers and through this they can represent more abstract information like the semantic
00:01:59.680 | meaning behind some text. When it comes to sparse embeddings the pros are typically faster retrieval,
00:02:06.240 | a good baseline performance, we don't need to do any model fine-tuning and we also get to do exact
00:02:14.240 | matching of terms. Whereas on the cons we have a few things as well so the performance cannot really
00:02:20.880 | be improved significantly over the baseline performance of these algorithms. They also suffer
00:02:26.400 | from something called the vocabulary mismatch problem and we'll talk about that in more detail
00:02:31.760 | later and it also doesn't align with the human-like thought of abstract concepts that I described
00:02:38.400 | earlier. Naturally we have a completely different set of pros and cons when it comes to dense
00:02:42.880 | vectors. On the pros we know that dense vectors can outperform sparse vectors with fine-tuning,
00:02:49.520 | we also know that using these we can search with human-like abstract concepts, we have great
00:02:55.840 | support of multi-modalities so we can search across text, images, audio etc and we can even
00:03:02.960 | do cross-modal search so we can go from text to image or image to text or whatever you can think
00:03:09.360 | of. But of course there's also the cons. We know that in order to outperform sparse vector embeddings
00:03:17.120 | or even get close to sparse vector embeddings in terms of performance we very often require training
00:03:23.520 | and training requires a lot of data, which is very difficult to find when we find ourselves in
00:03:29.280 | low resource scenarios. These models also do not generalize very well particularly when we are
00:03:34.880 | moving from one domain with very specific terminology to another domain with completely
00:03:40.160 | different terminology. These embeddings also require more compute and memory to build and store
00:03:46.560 | and search across than sparse methods. We do not get any exact match search and it's kind of hard
00:03:54.000 | to understand why we're getting results some of the time so it's not very interpretable. Ideally
00:04:00.080 | we want a way of getting the best of both worlds, we want the pros of dense and the pros of sparse
00:04:06.880 | and just we don't want any of these cons. But that's very hard to do. There have been some
00:04:13.040 | band-aid solutions. One of those is to perform two-stage retrieval. In this scenario we have
00:04:20.720 | two stages to retrieve and rank relevant documents for a given query. In the first stage our system
00:04:28.480 | would use a sparse retrieval method to search through and return relevant documents from a very
00:04:35.360 | large set of candidate documents. These are then passed on to the second stage which is a re-ranking
00:04:42.560 | stage and this uses a dense embedding model to re-rank from that smaller set of candidate
00:04:49.440 | documents which one it believes is the most relevant using its more human-like semantic
00:04:56.640 | comprehension of language. There are some benefits to this. First we can apply the sparse method to
00:05:02.480 | the full set of documents which makes it more efficient to actually search through those and
00:05:07.600 | then after that we can re-rank everything with our dense model which is naturally much slower but
00:05:13.600 | we're dealing with a smaller amount of data. Another benefit is that this re-ranking stage
00:05:19.280 | is detached from the retrieval system so we can modify one of those stages without affecting the
00:05:26.960 | other and this is particularly useful if we have multiple models that take for example the output
00:05:34.640 | of the sparse retrieval stage. So that's another thing to consider.
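A minimal sketch of that two-stage flow (retrieve with a sparse method, then re-rank the candidates with a dense model). The bm25_index and dense_model objects here are hypothetical placeholders, not a specific library API:

```python
def two_stage_search(query, bm25_index, dense_model, corpus, top_k=100, final_k=10):
    # Stage 1: fast sparse retrieval over the full corpus.
    candidate_ids = bm25_index.search(query, top_k=top_k)

    # Stage 2: re-rank only the candidates with the slower dense model.
    query_vec = dense_model.encode(query)
    scored = []
    for doc_id in candidate_ids:
        doc_vec = dense_model.encode(corpus[doc_id])
        scored.append((float(doc_vec @ query_vec), doc_id))  # dot-product relevance

    scored.sort(reverse=True)
    return [doc_id for _, doc_id in scored[:final_k]]
```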
00:05:40.640 | However, of course, this is not perfect. Two stages of retrieval and re-ranking can be slower than using a single-stage system
00:05:49.280 | that uses approximate nearest neighbor search algorithms and of course having two stages within
00:05:55.920 | the system is more complicated and there are naturally going to be many engineering challenges
00:06:02.240 | that come with that and we're also very reliant on that first stage retriever. If that first stage
00:06:09.120 | retriever doesn't perform very well then there's nothing we can do with the second stage re-ranking
00:06:16.000 | model because if it is just being given a load of rubbish results it's just going to re-rank
00:06:21.920 | rubbish results and the final result will still be rubbish. So those are the main problems with this, and
00:06:27.840 | ideally we want to solve that and we want to do that by improving single-stage systems. Now a lot
00:06:34.640 | of work has been put into improving single-stage retrieval systems. A big part of that research
00:06:40.240 | has been in building more robust and learnable sparse embedding models and one of the most
00:06:47.280 | promising models within this space is known as SPLADE. Now the idea behind the sparse lexical and
00:06:56.080 | expansion models is that a pre-trained model like BERT can identify connections between
00:07:02.560 | words and sub-words which we can call word pieces or terms and use that knowledge to enhance
00:07:09.520 | our sparse vector embeddings. This works in two ways it allows us to measure the relevance of
00:07:17.920 | different terms so the word 'the' will carry less significance in most cases than a less common word
00:07:26.400 | like orangutan. The second thing it helps us with is it enables learnable term expansion
00:07:34.080 | where term expansion is the inclusion of alternative but relevant terms beyond those
00:07:41.040 | that are found in the original sentence or sequence. Now it's very important to take note of
00:07:47.120 | the fact that I said learnable term expansion. The big advantage of SPLADE is not that it can do
00:07:53.280 | term expansion (that is something that has been done for a while) but that it can learn term expansions.
00:07:59.600 | In the past term expansion could be done with more traditional methods but it required
00:08:04.880 | rule-based logic and rule-based logic someone would have to write that and this is naturally
00:08:10.640 | time consuming and fundamentally limited because you can't write rules for every single scenario in
00:08:18.240 | human language. Now by using SPLADE we can simply learn these using a transformer model which is of
00:08:26.480 | course much more robust and much less time consuming for us. Now another benefit of using
00:08:32.960 | a context-aware transformer model like BERT is that it will modify these term expansions based on the
00:08:40.240 | context, based on the sentence that's being input. So it won't just always expand the word rainforest to
00:08:47.200 | the same few words; it will expand the word rainforest to many different words, and that
00:08:54.400 | entirely depends on the context, or the sentence, that it was fed in with, and this is one of the big
00:09:00.400 | benefits of attention models like transformers: they are very context aware. Now term expansion
00:09:07.840 | is crucial in minimizing a very key problem with sparse embedding methods and that is the vocabulary
00:09:16.320 | mismatch problem. Now the vocabulary mismatch problem is the very typical lack of overlap
00:09:23.040 | between a query and the documents that we are searching for. It's because we think of things in
00:09:30.480 | abstract ideas and concepts and we have many different words in order to explain the same
00:09:35.280 | thing it's very unlikely that the way that we describe something when we're searching for
00:09:39.600 | something contains the exact terms the exact words that this relevant information contains
00:09:47.600 | and this is just a side effect of the complexity of human language. Now let's move on to SPLADE
00:09:53.920 | and how SPLADE actually builds these sparse embeddings. Now it's actually relatively easy
00:10:00.880 | to grasp what is happening here. We first start with a transformer model like BERT. Now these
00:10:07.200 | transformer models use something called masked language modeling in order to perform their
00:10:12.960 | pre-training on a ton of text data. Not all transformer models use this but most do. Now if
00:10:18.880 | you're familiar with BERT and masked language modeling that's great; if not, we're going to just
00:10:24.400 | quickly break it down. So starting with BERT, it's a very popular transformer model and like all
00:10:31.280 | transformer models its core functionality is actually to create information rich token embeddings. Now
00:10:39.120 | what exactly does that mean? Well we start with some text like orangutans are native to the
00:10:46.400 | forests of Indonesia and Malaysia. With a transformer model like BERT we would begin by tokenizing that
00:10:52.800 | text into BERT specific sub-word or word level tokens and we can see that here. So using the
00:11:00.320 | HuggingFace transformers library we have this tokenizer object here. This is what is going to
00:11:06.320 | handle the tokenization of our text. So we have the same sentence I described before orangutans
00:11:12.160 | are native to the rainforests of Indonesia and Malaysia and we convert it into these tokens which
00:11:18.240 | is what you can see here. Now these are just the token IDs which are integer numbers but each one
00:11:24.160 | of these represents something within our text. So here for example this 2030 probably represents
00:11:31.840 | orangutan and the 5654 here maybe represents the S at the end of orangutan. They can be word level
00:11:40.240 | or sub-word level like that. Now these are just the numbers let's have a look down here and we can
00:11:45.600 | actually see how our words are broken up into these token IDs or tokens. So we convert those
00:11:53.440 | IDs back into human readable tokens and we can see, okay, we have this token here called the classifier
00:11:59.360 | (CLS) token, which is a special token used by BERT. We'll see it at the start of every sequence tokenized by
00:12:04.880 | BERT tokenizer and then we have orangutans. So it's actually split between four tokens and we
00:12:11.040 | can see the rest of the sentence there as well.
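Here is a minimal sketch of that tokenization round trip with the Hugging Face transformers library, assuming the standard bert-base-uncased tokenizer; the exact token IDs you get may differ from the ones read out above:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

text = "Orangutans are native to the rainforests of Indonesia and Malaysia."
input_ids = tokenizer(text)["input_ids"]               # integer token IDs
tokens = tokenizer.convert_ids_to_tokens(input_ids)    # human-readable word pieces

print(input_ids)
print(tokens)   # begins with [CLS], ends with [SEP], with sub-word pieces in between
```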
00:12:17.840 | Now why do we create these tokens and these token IDs? Well, that's because these token IDs are then mapped to what is called an embedding matrix. The
00:12:25.040 | embedding matrix is the first layer of our transformer model. Now in this embedding matrix
00:12:32.560 | we will find learned vector representations that literally represent the tokens that we fed in
00:12:39.680 | within a vector space. So the vector representation for the token rainforest will have a high
00:12:47.040 | proximity because it has a high semantic similarity to the vector representations for the token
00:12:54.000 | jungle or the token forest. Whereas it will be further away in that vector space from somewhat
00:13:01.280 | less related tokens like native or the. Now from here the token representations of our original
00:13:08.320 | text are going to go through several encoder blocks. These blocks encode more and more contextual
00:13:15.360 | information into each one of these token embeddings. So as we progress through all of these
00:13:22.160 | encoder blocks, the embeddings are basically going to be moved within that vector space in order to
00:13:28.720 | consider the meaning within the context of the sentence it appears in rather than just the meaning
00:13:34.160 | of the token by itself. And after all this progressive iteration of encoding more contextual
00:13:42.000 | information into our embeddings, we arrive at the transformer's output layer. Here we have our final
00:13:51.200 | information rich vector embeddings. Each embedding represents its original token, but obviously with that
00:13:56.400 | context encoded into it. This process is the core of BERT and every other transformer model. However
00:14:02.640 | the power of transformers comes from the considerable number of things for which these
00:14:09.360 | information rich embeddings can be used. Typically what will happen is we'll add a task-specific head
00:14:15.520 | onto the end of the transformer model that will transform these information rich embeddings or
00:14:21.840 | vector embeddings into something else like sentiment predictions or sparse vectors. The
00:14:28.880 | masked language modeling head is one of the most common of these task-specific heads because it
00:14:36.880 | is used for pre-training most transformer models. This works by taking an input sentence; again let's
00:14:43.440 | use the orangutans are native to the forests of Indonesia and Malaysia example again. We will
00:14:49.600 | tokenize this text and then mask a few of those tokens at random. This masked token sequence is
00:14:56.080 | then passed as input to BERT and at the other end we actually feed in the original unmasked sequence
00:15:02.960 | to the masked language modeling head, and what will happen is BERT and the masked language modeling head
00:15:09.040 | will have to adjust their internal weights in order to produce accurate predictions for the
00:15:16.240 | tokens that have been masked. For this to work the masked language modeling head contains 30,522
00:15:25.520 | outputs, which is the vocabulary size of the BERT base model. So that means we have an output
00:15:33.440 | for every possible token prediction, and the output as a whole acts
00:15:40.000 | as a probability distribution over this entire vocabulary, and the highest activation across that
00:15:46.080 | probability distribution represents the token that BERT and the masked language modeling head
00:15:51.440 | have predicted as being the token behind that masked token position.
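A minimal sketch of that masked-prediction step, again assuming the bert-base-uncased checkpoint; the word in the final comment is only what we would expect, not a guaranteed output:

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

text = "Orangutans are native to the [MASK] of Indonesia and Malaysia."
tokens = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    logits = model(**tokens).logits                 # shape (1, seq_len, 30522)

# find the [MASK] position and take the highest-scoring vocabulary token there
mask_pos = (tokens["input_ids"][0] == tokenizer.mask_token_id).nonzero()[0].item()
predicted_id = logits[0, mask_pos].argmax().item()
print(tokenizer.decode([predicted_id]))             # e.g. "forests"
```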
00:15:58.880 | Now at the same time we can think of this probability distribution as a representation of the words or tokens that are
00:16:07.680 | most relevant to a particular token within the context of the wider sentence. With that what we
00:16:14.560 | can do with SPLADE is take all of these distributions and aggregate them into a single
00:16:21.440 | distribution called the importance estimation. The importance estimation is actually the sparse
00:16:28.320 | vector produced by SPLADE and that is done using this equation here and this allows us to identify
00:16:34.800 | relevant tokens that do not exist in the original sequence.
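The equation referred to here is shown on screen in the video; for reference, the importance estimation in the original SPLADE formulation is usually written as follows (a reconstruction from the paper, not a transcription of the slide), where w_ij denotes the MLM logit for vocabulary token j at input position i of the sequence t:

```latex
w_j = \sum_{i \in t} \log\left(1 + \mathrm{ReLU}(w_{ij})\right)
```

Later SPLADE variants replace the sum over positions with a max, which is the max pooling mentioned further down.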
00:16:40.160 | For example, if we masked the word rainforest we might return high predictions for the words jungle, land and forest. These words and
00:16:48.800 | their associated probabilities would then be represented in the SPLADE built sparse vector
00:16:54.640 | and that doesn't mean we need to mask anything. The predictions will be made for each
00:17:00.400 | token whether it is masked or not. So in the end all we have to input is the unmasked sequence
00:17:07.920 | and what we will get is all of these probability distributions for similar words to whatever
00:17:12.560 | has been input, based on the sentence and the context. Now many transformer models are trained
00:17:19.600 | with masked language modeling, which means there are a huge number of models that already have these
00:17:25.280 | masked language modeling weights, and we can actually use that to fine-tune those models as SPLADE
00:17:31.360 | models and that's something that we will cover in another video. Now let's have a quick look at where
00:17:37.680 | SPLADE works kind of less well. So as we've seen SPLADE is a really good tool for minimizing the
00:17:44.800 | vocabulary mismatch problem however there are of course some drawbacks that we should consider.
00:17:49.680 | Compared to other sparse methods retrieval with SPLADE is very slow. There are three primary
00:17:56.320 | reasons for this. First the number of non-zero values in SPLADE query and document vectors
00:18:02.400 | is typically much greater than in traditional sparse vectors because of that term expansion
00:18:08.320 | and sparse retrieval systems are rarely optimized for this. Second the distribution of these non-zero
00:18:15.760 | values also deviates from the traditional distribution expected by most sparse retrieval
00:18:22.080 | systems again causing slowdowns and third SPLADE vectors are not natively supported by most sparse
00:18:31.760 | retrieval systems meaning that we have to perform multiple pre and post processing steps,
00:18:38.240 | weight discretization and other things in order to make it work, if it works at all, and again
00:18:45.520 | it's not optimized for that. Fortunately there are some solutions to all of these problems. For one
00:18:51.440 | the authors of SPLADE actually address this in a later paper that minimizes the number of non-zero
00:18:59.440 | values in the query vectors and they do that with two steps. First they improved the performance
00:19:06.960 | of the SPLADE document encodings using max pooling rather than the traditional pooling strategy and
00:19:14.720 | second they limited the term expansion to the document encodings only so they didn't do the
00:19:21.760 | query expansions and thanks to the improved document encoding performance dropping those
00:19:28.320 | query expansions still leaves us with better performance than the original SPLADE model.
00:19:34.800 | And then if we look at the final two problems so two and three these can both be solved by using
00:19:41.280 | the Pinecone vector database. Two is solved by Pinecone's retrieval engine being designed to be
00:19:48.000 | agnostic to data distribution and for number three Pinecone supports real valued sparse vectors
00:19:55.440 | meaning SPLADE vectors are supported natively without needing to do any of those weird things
00:20:03.040 | in pre-processing post-processing or discretization. Now with all of that I think we
00:20:10.880 | have covered everything we could possibly cover in order to understand SPLADE. Now let's have a
00:20:14.960 | look at how we would actually implement SPLADE in practice. Now we have two options for implementing
00:20:21.040 | SPLADE we can do directly with Hugging Face Transformers and PyTorch or with a high-level
00:20:27.280 | abstraction using the official SPLADE library. We'll take a look at doing both starting with
00:20:32.480 | the Hugging Face and PyTorch implementation just so we can understand how it actually works.
00:20:38.640 | Okay so first we start by just installing a few prerequisites so we have SPLADE, Transformers,
00:20:44.720 | and PyTorch. Then, once those are installed, what we need to do is initialize
00:20:50.960 | the tokenizer, which is very similar to the BERT tokenizer we initialized earlier, and the
00:20:57.280 | AutoModelForMaskedLM, so this is for masked language modeling. So we're going to be using the
00:21:02.720 | Naver SPLADE model here and we just initialize all of that.
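A minimal sketch of that initialization; the exact checkpoint name below is an assumption, so substitute whichever naver SPLADE model you are using:

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "naver/splade-cocondenser-ensembledistil"   # assumed SPLADE checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)
model.eval()                                            # inference only
```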
00:21:11.920 | Okay, and we have one pretty large chunk of text here. This is very domain specific, so it has a lot of very specific words in there
00:21:19.200 | that a typical dense embedding model would probably struggle with unless it has been fine-tuned on
00:21:25.200 | data containing these exact same terms. So we'll run that and what we do is we tokenize everything
00:21:31.520 | so that will give us our token IDs that you saw earlier and then we process those through our
00:21:37.200 | model to create our logits output, which is what we will see in a moment here. Okay, so as we saw
00:21:46.080 | before, each one of those logit vectors contains our probability distribution over the
00:21:53.760 | 30.5 thousand possible tokens from the vocabulary and we have 91 of those. Now the reason we have
00:22:01.440 | 91 of those is because from our tokens here we actually had 91 input tokens so if we have a look
00:22:09.840 | at tokens['input_ids'].shape we see that there were 91 input tokens in there so that will change depending
00:22:21.120 | on how many input tokens we have. Now from here what we're going to do is take these output logits
00:22:27.120 | and we want to transform them into a sparse vector. Now to do that we're going to be using
00:22:32.080 | the formula that you saw earlier to create the importance estimation and if we run that
00:22:39.680 | we'll get a single probability distribution which represents the actual sparse vector from
00:22:45.920 | SPLADE, and we can have a look at that vector, and we see it's mostly zeros in there, with only a few non-zero values.
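A minimal sketch of that step, applying the log(1 + ReLU(...)) importance estimation from earlier to the MLM logits; max pooling over the sequence is used here, which is an assumption about what the notebook does:

```python
text = "..."   # the long domain-specific passage used in the video (placeholder)

tokens = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**tokens).logits                       # (1, seq_len, 30522)

# zero out padding positions, then pool over the sequence dimension
weights = torch.log1p(torch.relu(logits)) * tokens["attention_mask"].unsqueeze(-1)
sparse_vec = weights.max(dim=1).values.squeeze()          # (30522,), mostly zeros

print(int((sparse_vec != 0).sum()), "non-zero dimensions")
```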
00:22:50.320 | So what I'm going to do now is, first (I want to just ignore this bit), we're
00:22:58.400 | going to come down to here and we're going to create a dictionary format of our sparse vector
00:23:05.280 | so we run this and there's a few things I want to look at here so number of non-zero values that we
00:23:12.480 | actually have is 174 and all of them are now contained within this sparse dictionary. Okay so
00:23:20.160 | these are the token IDs and these are the weights or the relevance of each one of those particular
00:23:28.160 | tokens. Now we can't read any of these token IDs so similar to before what we're going to do is
00:23:35.200 | convert those into actual human readable tokens so to do that we'll need to run this
00:23:40.880 | and then we come down here and we're going to convert them into a more readable format.
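A sketch of those two steps, building the {token ID: weight} dictionary and then mapping the IDs back to readable word pieces, continuing from the sparse_vec computed above:

```python
# keep only the non-zero dimensions of the sparse vector
idx = sparse_vec.nonzero().squeeze().tolist()
sparse_dict = {i: round(float(sparse_vec[i]), 2) for i in idx}

# map vocabulary IDs back to tokens so the weights are human readable
id2token = {i: t for t, i in tokenizer.get_vocab().items()}
readable = {id2token[i]: w for i, w in sparse_dict.items()}

# sort by weight to see which terms SPLADE thinks matter most
print(sorted(readable.items(), key=lambda x: x[1], reverse=True)[:10])
```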
00:23:46.960 | Okay we can see what it believes is important is all of these values so we've sorted everything
00:23:53.920 | here so that's why the numbers have changed here and we can see that most importantly it's seeing
00:23:59.200 | like programmed, death, cell, lattice, so a lot of very relevant words within that particular
00:24:08.000 | domain. Now if we come a little bit further down we can also see how to do that using
00:24:14.080 | the Naver SPLADE library. So for that we would have to pip install SPLADE we did that at the
00:24:19.760 | top of the notebook so we don't need to do it again. We're going to be using the max aggregation
00:24:25.280 | so this is using the max pooling method. Run this again using the same model ID here because it's
00:24:30.800 | also downloading the model from Hugging Face Transformers, and what we do is we set torch to
00:24:36.720 | no_grad, so this is saying we don't want to update any of the model weights because we're not doing
00:24:41.440 | fine-tuning, we're just performing inference, e.g. prediction, here, and we just pass into the Naver
00:24:46.880 | model our tokens which we built using the tokenizer earlier on. From there we need to extract
00:24:53.200 | the d_rep tensor and we'll squeeze that to remove one of the dimensions that is unnecessary and we
00:25:00.800 | can then have a look we have 30.5 thousand dimensions here so this is our probability
00:25:07.040 | distribution or importance estimation, and that is obviously our sparse vector.
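A sketch of the same encoding through the official SPLADE library; the import path, the agg keyword, and the d_rep output key below follow the library's published examples as best I recall them, so treat them as assumptions and check them against the installed version:

```python
import torch
from splade.models.transformer_rep import Splade

sparse_model = Splade(model_id, agg="max")   # max pooling aggregation
sparse_model.eval()

with torch.no_grad():                        # inference only, no fine-tuning
    rep = sparse_model(d_kwargs=tokens)["d_rep"].squeeze()

print(rep.shape)                             # expected: torch.Size([30522])
```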
00:25:14.720 | Now what we can do is actually use what we've done so far in order to compare different documents. So let's take a few
00:25:22.400 | of these. So we have programmed cell death... no, no, no, this is the original text, and then the ones below here
00:25:28.880 | are just me attempting to write something that is either relevant or not relevant that uses a
00:25:35.440 | similar type of language. So we can run that we'll encode everything we're going to use the PyTorch
00:25:43.840 | and Hugging Face Transformers method, but either way, both of these will produce the same
00:25:49.680 | result whether you use that or the actual SPLADE library, and what we'll get is three of these
00:25:56.080 | importance estimations, the SPLADE vectors, and then what we can do is calculate cosine
00:26:02.560 | similarity between them. So here I'm just going to initialize a zeros array that is just to store
00:26:07.920 | the similarity scores that we're going to create using this here. So we run that let's have a look
00:26:14.480 | at the similarity and we can see that obviously these in the diagonal here this is where we're
00:26:21.120 | comparing each of the vectors to itself so it scores pretty highly because obviously they're
00:26:27.200 | the same but then the ones that we see as being the most similar other than the you know themselves
00:26:33.040 | is sentence zero and sentence one so this one here if we come up to here so basically these
00:26:41.920 | two here are being viewed as the most similar and if we read those we can see that they are in fact
00:26:47.200 | much more similar they have a lot more overlap in terms of the terms but it's not just about
00:26:53.680 | the terms that we see here but also the terms that produce from the term expansion as well.
00:26:58.640 | So that's how we would compare everything; that's how we would actually use SPLADE to create
00:27:05.920 | embeddings and to actually compare those sparse vectors as well using cosine similarity.
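A sketch of that comparison, wrapping the earlier PyTorch steps in a small encode helper and filling a cosine similarity matrix; the three texts are placeholders for the passages used in the video:

```python
import numpy as np

def encode(text):
    toks = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**toks).logits
    w = torch.log1p(torch.relu(logits)) * toks["attention_mask"].unsqueeze(-1)
    return w.max(dim=1).values.squeeze().numpy()

texts = ["...", "...", "..."]                # original passage plus two variations
vecs = [encode(t) for t in texts]

sim = np.zeros((len(vecs), len(vecs)))
for i, a in enumerate(vecs):
    for j, b in enumerate(vecs):
        sim[i, j] = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

print(sim.round(2))                          # diagonal ~1.0; off-diagonals show overlap
```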
00:27:11.840 | Now that's it for this introduction to learned sparse embeddings with SPLADE. Now using SPLADE
00:27:18.000 | we can represent text with more efficient sparse vector embeddings that help us at the same time
00:27:26.560 | deal with the vocabulary mismatch problem whilst enabling exact matching and drawing from some of
00:27:33.120 | the other benefits of using sparse vectors. But of course there's still a lot to be done
00:27:38.640 | and there's more research and more efforts looking at how to mix both dense and sparse
00:27:46.560 | vector embeddings using things like hybrid search as well as things like SPLADE, and using both of
00:27:52.400 | those together we can actually get really cool results. So I think this is just one step towards
00:27:59.440 | making vector search and information retrieval way more accessible because we no longer need to
00:28:07.120 | fine-tune all these really big models in order to get the best possible performance but we can use
00:28:13.840 | things like hybrid search and things like SPLADE in order to really just improve our performance
00:28:19.840 | with very little effort which is a really good thing to see. But that's it for this video I hope
00:28:26.960 | everything we've been through is interesting and useful but for now that's it so thank you very
00:28:32.240 | much for watching and I'll see you again in the next one. Bye.