
Is GPL the Future of Sentence Transformers? | Generative Pseudo-Labeling Deep Dive


Chapters

0:00 Intro
1:08 Semantic Web and Other Uses
4:36 Why GPL?
7:31 How GPL Works
10:37 Query Generation
12:08 CORD-19 Dataset and Download
13:27 Query Generation Code
21:53 Query Generation is Not Perfect
22:39 Negative Mining
26:28 Negative Mining Implementation
27:21 Negative Mining Code
35:19 Pseudo-Labeling
35:55 Pseudo-Labeling Code
37:01 Importance of Pseudo-Labeling
41:20 Margin MSE Loss
43:40 MarginMSE Fine-tune Code
46:30 Choosing Number of Steps
48:54 Fast Evaluation
51:43 What's Next for Sentence Transformers?

Whisper Transcript | Transcript Only Page

00:00:00.720 | Today we're going to dive into what I think is probably one of the most interesting techniques
00:00:06.720 | in semantic search for a very long time. That technique is called generative pseudo-labelling
00:00:13.120 | or GPL. It's from Nils Reimers' team at the UKP Lab and this is primarily the work of Kexin Wang.
00:00:22.880 | Now Kexin Wang also produced the TSDAE paper maybe around a year or so ago as well, and TSDAE
00:00:34.240 | is the unsupervised training method for sentence transformers, and GPL
00:00:41.440 | takes that and puts it up another notch. So we're going to be using GPL to take unstructured
00:00:50.880 | data and generate structured data from it for training sentence transformers. Now before we
00:00:59.520 | really dive deep into GPL, I want to give a little bit of background and set the scene of where we
00:01:06.640 | might want to use this sort of technique. So in 1999 there was a concept known as semantic web
00:01:16.080 | described by the creator of the World Wide Web, Tim Berners-Lee. Now Tim Berners-Lee had this
00:01:22.960 | sort of dream of the web in the way that we know it today but where you have machines roaming the
00:01:32.240 | web and being able to just understand everything. But when you start to work in more niche domains,
00:01:43.520 | it gets hard to actually find the labeled data you need for that, and pretty much impossible most of the time.
00:01:48.720 | So for example let's say you're in the finance industry, you have these internal documents,
00:01:56.000 | fairly technical financial documents and you want to build a question answering system
00:02:01.840 | where people or staff can ask a question in natural language and return answers from those
00:02:08.080 | documents. It's going to be hard for you to find a model, you do have financial Q&A models but
00:02:15.760 | they've been fine-tuned on like personal finance questions on Reddit and okay kind of similar but
00:02:22.800 | not really. Like personal finance and technical finance documentation in a finance company,
00:02:29.600 | very different. So it's very hard to actually find a data set that is going to satisfy what you need
00:02:37.120 | in that scenario because it's fairly niche. And it's not even that niche,
00:02:40.560 | there's a lot of companies that need to do that sort of thing. Another example is a project I
00:02:45.760 | worked on was for the Dhivehi language. We wanted to create a set of language models for the Dhivehi
00:02:52.400 | language which is the national language of the Maldives. Now there's not that many Dhivehi speakers
00:02:58.480 | worldwide so it's pretty niche. So it's very hard to find labeled data and what we ended up
00:03:06.400 | doing was actually using unsupervised methods. We used TSDAE for the actual sentence transformer model.
00:03:13.840 | What we have there is another use case where it's fairly niche and it's hard to find labeled
00:03:20.560 | data for. But there is unlabeled data, unstructured data in both of those scenarios. In your finance
00:03:27.680 | company you have the documents there, they're just not labeled in a way that
00:03:33.680 | you can train a model with. Or in the Dhivehi language example, there are many web
00:03:40.160 | pages in Dhivehi that we can scrape data from but they're not labeled. So for these more niche topics
00:03:49.360 | which cover a vast majority of use cases, you really either need to spend a lot of money
00:03:57.200 | labeling your data or we need to find something else that allows us to either synthetically
00:04:02.880 | generate data or just train on the unlabeled data. Now training on the unlabeled data
00:04:10.320 | doesn't really work. It works to an extent, TSDAE showed that it can, but you're looking at very
00:04:19.120 | generic semantic similarity and it's not perfect. GPL again, it's not perfect but it's definitely
00:04:30.160 | much better. So back to the semantic web example, in the case of the semantic web, the majority of
00:04:39.200 | the semantic web is going to be these niche topics that we don't have labeled data for.
00:04:45.200 | GPL is not a solution, it's not going to create the semantic web or anything like that.
00:04:50.880 | But it does allow us to actually target those previously inaccessible domains that actually
00:04:59.520 | begin producing models that can intelligently comprehend the meaning behind the language in
00:05:07.840 | those particular domains which I think is super fascinating and the fact that you can actually do
00:05:13.520 | this is so cool. Now there is a lot to it. This video I'm sure will be quite long but we really
00:05:21.760 | are going to go into the details and I think by the end of this video you should know everything
00:05:26.960 | you need to know about GPL and you will certainly be able to apply it in practice.
00:05:31.680 | So we've introduced where we might want to use it. I just want to now have a look at sort of
00:05:40.400 | like an overview of GPL and where we are going to use it in this video. So GPL is, you can use
00:05:47.840 | it in two ways. You can use it to fine tune a pre-trained model. So when I say pre-trained,
00:05:55.200 | I mean a transformer model that hasn't been fine-tuned specifically for semantic search.
00:06:00.400 | So BERT base cased, for example, would be a pre-trained model from the Hugging Face Hub. You take
00:06:05.440 | that and you can use GPL to fine-tune it for semantic search. You'll get okay results with
00:06:10.640 | that. The better way of using GPL is for domain adaptation. So domain adaptation is where you already
00:06:19.120 | have a semantic search model. So for example, what we are going to look at, we have a model that has
00:06:25.040 | been trained on MS MARCO data from 2018. This model, if you give it some questions about COVID-19,
00:06:36.400 | because 2018 was pre-COVID, it really kind of struggles with even simple questions. You ask
00:06:44.560 | it a simple question about COVID-19 and it will start returning you answers about things like Ebola
00:06:51.040 | and flu rather than COVID-19 because it's confused. It's never seen that before.
00:06:57.280 | So this is a typical issue with sentence transformers. They're very brittle
00:07:04.960 | and they don't adapt to new domains very well whatsoever.
00:07:08.800 | So GPL is best used in these sorts of scenarios where we want to adapt a model to a particular
00:07:17.440 | domain. Now, as I said, our example is going to be COVID. We're going to teach this model
00:07:24.640 | to understand COVID even though it's never seen anything about COVID before.
00:07:30.880 | So let's have a look at how GPL actually works. So at a very high level, GPL consists of three data
00:07:38.800 | preparation steps followed by one fine-tuning step. So the data preparation steps are the steps
00:07:46.240 | that produce the synthetic data set that we then use in that fine-tuning step. So that fine-tuning
00:07:53.440 | step is actually a common supervised learning method. The unsupervised, if you want to call it
00:07:59.920 | that, part of this method is that we are actually generating the text or the data that we'll be
00:08:09.360 | using automatically. We don't need to label it ourselves. So these three data preparation steps
00:08:16.160 | are the key to GPL. They are query generation, which is where we create queries from passages,
00:08:26.400 | so like paragraphs of text. The second is negative mining. So that is an information
00:08:34.160 | retrieval step where we retrieve passages similar to our positive passage, which is the passage from
00:08:41.920 | before, that do not actually answer our query, or we assume they do not answer our query.
00:08:47.600 | And then the pseudo-labeling step kind of cleans up the data from those two steps,
00:08:55.280 | from those two earlier parts, and uses a cross-encoder model to identify the similarity
00:09:01.200 | scores to assign to our passages and query pairs that we've generated already.
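Before walking through that process, here is a rough, high-level sketch of the whole pipeline in Python. Every name in it is a placeholder rather than real code; the real implementations follow later in the video.

```python
# High-level outline of GPL data preparation (placeholder names only).
def gpl_prepare(passages, query_generator, retriever, cross_encoder, queries_per_passage=3):
    training_data = []
    for positive in passages:                                             # unlabeled, in-domain text
        for query in query_generator(positive, n=queries_per_passage):    # step 1: query generation
            negative = retriever.mine_negative(query, exclude=positive)   # step 2: negative mining
            # step 3: pseudo-labeling, the cross-encoder margin becomes the training label
            margin = cross_encoder(query, positive) - cross_encoder(query, negative)
            training_data.append((query, positive, negative, margin))
    return training_data   # later used with Margin MSE loss to fine-tune the bi-encoder
```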
00:09:09.200 | So you can see that process here. We start with P+, which means positive passage at the top.
00:09:15.600 | We pass that through a query generation model. We use T5 for this, and that generates queries
00:09:22.720 | or questions. That's passed through a dense retrieval step, the negative mining step,
00:09:27.600 | which will hopefully return things that either partially answer or share many of the same
00:09:34.320 | words with our query or positive passage, and therefore are similar but do not actually match
00:09:42.240 | to our query. And then we pass all of those, the query, the positive and negative,
00:09:46.640 | through to our pseudo-labeling step, which is the cross-encoder step. Now the cross-encoder step
00:09:53.120 | is going to produce a margin score. So that's almost like the difference in similarity between
00:09:59.360 | our query positive pair and our query negative pair. We'll explain all this in much more depth
00:10:04.000 | later on. I assume right now it's probably quite confusing. So no worries, we will go through all
00:10:11.440 | of this. So as you've probably noticed, each of these steps requires the use of a pre-existing
00:10:16.560 | model that has been fine-tuned for each one of these parts. Now we don't need to have fine-tuned
00:10:22.480 | this. They're general models. They don't need to have been specially trained for that particular
00:10:28.960 | task with this particular domain, okay? These models are generally quite good at adapting to
00:10:34.080 | new domains, which is why we're using them. So let's dive deeper into what I've just described
00:10:41.840 | and begin looking at the query generation step of GPL. So as I mentioned, GPL is perfect
00:10:51.200 | for those scenarios where we have no labeled data, but we do need a lot of unstructured or
00:10:59.120 | unlabeled data, okay? So we need a lot of text data that is not labeled. That could be text
00:11:04.880 | scraped from web pages, from PDF documents, from anywhere where you can find text data.
00:11:11.680 | The only requirement is that this text data is actually relevant to your particular use case,
00:11:16.720 | e.g. is in-domain. Now, if we consider the case of maybe we have a use case where we need to build
00:11:25.360 | a semantic search retrieval model for German finance documents. For the in-domain of that,
00:11:32.640 | we might consider German finance news articles or German finance regulations, okay? They would
00:11:38.640 | be in-domain, but other documents like PyTorch documentation or English financial documents,
00:11:44.560 | they are not in-domain at all. It's almost like if you imagine you're studying for an exam in
00:11:50.560 | one topic, like biology, and you start studying some chemistry papers for that biology exam.
00:11:58.000 | Maybe some of that chemistry might be relevant, a crossover with your biology, but not much. You're
00:12:04.160 | never actually going to do that if you want to pass your exam at least. Now, in our examples
00:12:09.760 | in this video, we're going to be using the CORD-19 dataset, which is a set of papers that,
00:12:16.720 | I don't think all of them are specifically about COVID-19, but I think most of them at least
00:12:21.440 | mention COVID-19. There are a lot of files here. Let's have a look at the code we use to actually
00:12:28.240 | download that data. Okay, so we need to find the CORD-19 dataset from the Allen Institute for AI.
00:12:38.240 | It's this website here. So it's not available, at least when I checked,
00:12:43.200 | it wasn't available on Hugging Face Datasets or anywhere like that, so we need to
00:12:46.800 | pull it manually. So I'm just going to download that. We extract this tar file, and essentially,
00:12:54.320 | I'm not going to go explain all of this. I want to be quick through this a little bit,
00:12:57.600 | but you will be able to find the link to this in the description. So you can just run this script,
00:13:03.280 | but everything will be stored within this document_parses/pdf_json directory, which is just going to be a
00:13:09.280 | load of JSON files. There's a lot of them, just under 340,000 of them, and they're going to
00:13:17.200 | be named like this, and they're going to look like this. So they each have these paragraphs in there,
00:13:23.840 | which we're going to pull in and use in our examples. So once we've actually downloaded
00:13:29.040 | all of the text data, we're going to move on to the query generation code. So we have our
00:13:35.360 | CORD-19 data. We're going to read them from here. I use a generator function here. So we're going
00:13:41.760 | through, getting the text, and we're just yielding the passage. So this is one passage at a time,
00:13:47.600 | one paragraph at a time, that I'm pulling in, and I'm using a generator function to do that.
00:13:55.440 | Now, just be aware there are a lot of duplicates in this data set. So what I've done is create
00:14:01.120 | this passage dupes set. So just check every time we have a new passage, we just check if we already
00:14:08.480 | pulled it in. If so, we skip, move on to the next one. Otherwise, it's a unique passage, and we can
00:14:14.880 | pull it in and add it to that duplication check, and then yield it. So using yield here, because
00:14:22.560 | this is a generator function that we loop through one at a time. So basically, we're going to iterate
00:14:27.520 | through that function, and it will return us one passage at a time, which is what we're doing here.
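As a minimal sketch of that generator, assuming the CORD-19 files sit in a local document_parses/pdf_json directory and use the standard body_text/text fields (both of those are assumptions about your local setup), it looks something like this:

```python
import json
from pathlib import Path

def get_passages(json_dir="document_parses/pdf_json"):
    """Yield unique paragraphs from the CORD-19 JSON files, one passage at a time."""
    seen = set()  # duplication check
    for path in Path(json_dir).glob("*.json"):
        doc = json.loads(path.read_text())
        for para in doc.get("body_text", []):
            passage = para["text"].strip()
            if passage in seen:
                continue       # skip duplicates
            seen.add(passage)
            yield passage

passages = get_passages()
print(next(passages))  # returns one passage at a time
```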
00:14:32.560 | I returned two passages here. You can see they're pretty long. So there is a lot of text data in
00:14:38.320 | there. We're probably going to cut a lot of it out when we're feeding it into our models. That's fine.
00:14:43.280 | We just want to see how this method works. We don't need to make everything perfect. So what we're
00:14:51.040 | going to do is use this model here. So this is a T5 model trained on data from MS MARCO
00:14:58.560 | from 2018. So it's pre-COVID. It doesn't know anything about COVID. We're using Hugging Face
00:15:05.520 | Transformers here, AutoTokenizer, and then AutoModelForSeq2SeqLM.
00:15:10.480 | And one thing, just to be very aware, if you have a CUDA-enabled GPU, definitely try and use it. So
00:15:16.480 | here we're just moving that model to CUDA, to our GPU. It will make things much quicker. This step,
00:15:22.400 | otherwise, on CPU, will take a very long time. So if you can, it's best not to do
00:15:28.560 | that. I think for me, on the Tesla V100, I think this took maybe one or two hours to do for 200,000
00:15:39.680 | examples, which is a fair few. So here's just an example of what we're going to do here. So I'm
00:15:47.120 | taking one passage. We are going to tokenize the passages to create our inputs. And now,
00:15:53.040 | inputs will include the input IDs and attention masks. We need both of those. And we feed them
00:15:57.600 | into our model. And this generates three queries from that. One other thing that you really should
00:16:05.680 | make note of here is, again, I'm moving those tensors to CUDA, if I can. If I move my model
00:16:12.800 | to CUDA already, these also need to be on CUDA. And let's have a look at what we produce from that.
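Here is a minimal sketch of that single-passage example. The exact checkpoint is an assumption (any doc2query-style T5 model trained on MS MARCO should behave like the one in the video), and the generation settings are just reasonable defaults:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Assumed checkpoint: a T5 query generation model trained on pre-COVID MS MARCO data.
model_name = "BeIR/query-gen-msmarco-t5-base-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)  # this step is very slow on CPU, so use a GPU if you have one

passage = next(passages)  # one CORD-19 paragraph from the generator above

inputs = tokenizer(passage, truncation=True, max_length=256, return_tensors="pt").to(device)
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=64,
    do_sample=True,
    top_p=0.95,
    num_return_sequences=3,   # three queries per passage
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```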
00:16:17.760 | So this is, sometimes it's good, sometimes it's not good. It really depends. I think this is
00:16:23.920 | kind of a relevant example. But it's not perfect, because the queries are about SARS here,
00:16:30.080 | rather than COVID, even though we're talking about COVID. Obviously, the model doesn't know that.
00:16:34.400 | So this generative step, or query generation step, is very good for out-of-domain topics. Because,
00:16:43.840 | for example, if this said COVID-19, I think there's a good chance that this would say,
00:16:49.200 | where is COVID-19 infection found? Because it does a lot of word replacement. So it's seeing
00:16:59.520 | your answer, basically. And it's just kind of reformatting that answer into a question
00:17:04.720 | about the answer. That's a lot of what this T5 model is doing. So it can be quite good. And
00:17:11.840 | in some cases, it can be less good. But that's fine, because it's better than nothing. We really
00:17:19.600 | do need some sort of labeled data here. And in most cases, it works fairly well.
00:17:24.800 | So here, we get some OK examples. We have SARS rather than COVID. But otherwise, it's fine.
00:17:30.960 | So what I'm doing here is I'm creating a new directory called data, where I'm going to store
00:17:37.040 | all of the synthetic data we create. Importing TQDM for the progress bar,
00:17:43.040 | setting a target of 200,000 pairs that I want to create. Batch size, increase as much as you can.
00:17:49.360 | That's going to make it quicker. But obviously, it's restricted, depending on your hardware.
00:17:53.760 | And just specify a number of queries. And we're going to do this all in batches to make things
00:17:59.840 | a little bit faster. And then we reinitialize that generator. That's going to return as our
00:18:04.880 | passages. Because we've already run through three of them already. So I just want to
00:18:08.960 | restart it, essentially. So we're going to go through. I'm using this with TQDM as progress
00:18:14.880 | here. Because this is a generator, we don't know how long it is. So it's not going to give us a
00:18:20.080 | good progress bar if we just run TQDM on passages. So I'm specifying the total as a target. So total
00:18:27.280 | number of steps is the target that I have up here, the 200,000. So I'm going to increase the count by
00:18:32.400 | one with every new passage that we generate queries for. In fact, I'm going to increase it by three,
00:18:39.360 | I think. Yeah, because we create three queries per passage. So we do that. And once our count
00:18:47.840 | is greater than or equal to the target, 200,000, we stop the whole process.
00:18:51.840 | Here, for the passage batch, I'm appending a passage. So the passage that we have from here,
00:18:59.520 | I'm just replacing any tab and newline characters to keep it clean for when we're writing the data
00:19:05.760 | later on, more than anything. Because we're going to be using tab-separated files.
00:19:09.200 | We're going to encode everything in batches. So once the length of the batch is equal to
00:19:16.480 | the batch size that we specify up here, so 256, we're going to begin encoding everything.
00:19:21.360 | So we tokenize everything in the batch, and then we generate all of our queries.
00:19:26.400 | Again, generating three for each passage that we have. And then another thing is that we need to
00:19:34.640 | decode the queries that we've generated to human readable text. Because the next model we use is
00:19:40.320 | going to be a different model. So it cannot read the same tokens that this model outputs. So we
00:19:46.960 | need to decode that back to human text. And then we're going to loop through all of the decoded
00:19:53.520 | outputs, which is one query per line. And we have to consider there's actually three queries per
00:20:00.800 | passage. So we have maybe five passages here, and that means we have 15 queries on the other side.
00:20:06.160 | So we need to consider that when we're pairing those back together. So we use this passage
00:20:11.840 | IDX, which is the integer of i divided by the number of queries. So imagine,
00:20:18.320 | for passage 0, we are going to have queries 0, 1, and 2. You divide those
00:20:28.800 | by 3, and all of them are going to be less than 1. So 2 is the highest one, and 2 divided by 3 is 0.66.
00:20:38.880 | If you take the integer value of that, it's going to become 0, which maps to the passage on this
00:20:46.720 | side. So that maps to passage number 0. So that means 0, 1, 2, all mapped to passage 0. And then
00:20:53.680 | we do the next one, so 3, 4, 5. And that's going to map to 1. So that's how we're mapping our
00:20:59.920 | generated queries back to our passages. And then for each one of those, we just append it to line
00:21:07.840 | here. We're using a tab to separate them because we're just going to write them to file. And we
00:21:12.400 | increase the count. Refresh the batch every time we've been through a batch. And update the progress
00:21:19.520 | bar. So updating it, actually updating it by 3 here. No, sorry. Updating it by the size of the
00:21:28.480 | decoded output. So if we had five passages, that would be 15 because we have three queries per
00:21:34.720 | passage. OK, now we want to write those query passage pairs to file, which we do like that.
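Pulling the batching, the index mapping, and the file writing together, a condensed sketch of that loop (continuing with the assumed names from the sketches above; the output path is also an assumption) looks roughly like this:

```python
import os
from tqdm.auto import tqdm

os.makedirs("data", exist_ok=True)

target = 200_000      # number of (query, passage) pairs we want
batch_size = 256      # as large as your hardware allows
num_queries = 3       # queries generated per passage

lines, batch, count = [], [], 0

with tqdm(total=target) as progress:
    for passage in get_passages():
        batch.append(passage.replace("\t", " ").replace("\n", " "))  # keep it TSV-safe
        if len(batch) == batch_size:
            inputs = tokenizer(batch, truncation=True, padding=True,
                               max_length=256, return_tensors="pt").to(device)
            outputs = model.generate(**inputs, max_length=64, do_sample=True,
                                     top_p=0.95, num_return_sequences=num_queries)
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            for i, query in enumerate(decoded):
                passage_idx = int(i / num_queries)  # queries 0,1,2 -> passage 0; 3,4,5 -> passage 1
                lines.append(query.replace("\t", " ") + "\t" + batch[passage_idx])
            count += len(decoded)
            progress.update(len(decoded))
            batch = []
        if count >= target:
            break

with open("data/pairs.tsv", "w", encoding="utf-8") as fp:
    fp.write("\n".join(lines))
```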
00:21:44.560 | OK, so that's the query generation step. There's a lot in there. But from this, we now have our
00:21:50.800 | query passage pairs. Now, it is worth noting that, like we saw before, query generation is not
00:21:58.640 | perfect. It can generate noisy, imperfect, super nonsensical queries. And this is where GPL improved
00:22:10.560 | upon GenQ. If you know GenQ, we covered it in the last chapter. And GenQ just relies on this query
00:22:17.040 | generation step. GPL, later on, we have this pseudo-labeling step, which kind of cleans up
00:22:24.160 | any noisy data that we might have, or at least cleans up to an extent, which is really useful.
00:22:31.280 | So that's great. We've finished with our query generation step. And now we're ready to really
00:22:38.000 | move on to the next step of negative mining, which I think is probably one of the most interesting
00:22:44.320 | steps. So now, if you think about the data we have, we have our queries and the assumed positive
00:22:50.480 | passages, or the positively paired query passages that we have at the moment. Now, suppose we
00:22:58.400 | fine-tune our sentence transformer or bi-encoder with just those. Our sentence transformer is just
00:23:04.720 | going to learn how to put things together. So it's going to learn how to place these queries
00:23:09.440 | and passages in the same vector space. And yes, OK, that will work to an extent. But the performance
00:23:16.960 | is not great. And what we need is a way to actually find negative passages. So there was a
00:23:24.640 | paper on RocketQA, and they had this really cool chart that I liked,
00:23:31.440 | where you have your model performance when they've trained it with
00:23:37.680 | hard negatives. So, negative samples that are quite hard for the model to figure out are negative,
00:23:43.680 | because they're very similar to your positive. And we'll explain more about hard negatives later on.
00:23:48.560 | And where they train the model without any hard negatives, so just the positives.
00:23:52.480 | And it's very clear that models perform better when there are hard negatives. It's almost like
00:24:00.320 | you can think of it as making your exams harder. The people that pass the exams
00:24:06.480 | now will have studied even harder than they would have in the past. Because you've made the exam
00:24:13.280 | harder. They need to study harder. And as a result of that, they will be better because of that.
00:24:19.680 | Their knowledge will be better because the exam was harder. And they had to try harder to pass it.
00:24:25.440 | It's very similar in why we include negatives. It makes the training process harder for our model.
00:24:33.360 | But at the same time, the model that we output from that is much better. So to actually get
00:24:40.880 | these negatives, we perform what's called a negative mining step. Now, this negative mining
00:24:46.960 | step is used to find highly similar passages to our positive passages that are obviously very
00:24:55.200 | similar but are not the same. So we are assuming that these very similar passages are-- maybe they
00:25:02.400 | talk about the same topic, but they don't really answer our query. We're assuming something like
00:25:08.560 | this. We don't actually know, but we're assuming this. So we're performing a semantic search,
00:25:15.920 | information retrieval step to actually return these possible negatives.
00:25:22.080 | And then we can use them as negative training examples. So now what we will have is our query,
00:25:28.960 | positive passage P plus, and assumed negative passage P minus. And the result of this is that
00:25:36.160 | our model is going to have to learn very nuanced differences between that negative or assumed
00:25:40.960 | negative and that positive in order to be able to separate them out and understand that, OK,
00:25:46.480 | the positive is the pair for our query, but this negative, even though it's super similar, is not
00:25:53.840 | the pair for our query. So it's just giving our model an extra task to perform here.
00:26:01.360 | Now, with all that in mind, we also need to understand that we're assuming that all the
00:26:06.320 | passages we're returning are negatives. We don't actually know that they are. And again,
00:26:13.840 | this is something that we're going to handle later on in the pseudo-labeling step. But for now,
00:26:17.920 | we just assume that they are, in fact, negative passages to our query.
00:26:24.000 | So let's move on to the actual implementation of negative mining.
00:26:28.800 | So the first thing we need to consider for this negative mining step is that there are
00:26:33.040 | two parts to it. There is an embedding or retrieval model. Again, we're going to use
00:26:38.800 | something that has been trained on data from pre-COVID times. And what that is going to do
00:26:46.800 | is take our passages and our queries and translate them into vectors. OK, but we need
00:26:53.280 | somewhere to store those vectors. So we need a vector database to do that. So we're going to
00:26:58.560 | use a vector database to store all of those vectors. And then we're going to perform a
00:27:02.400 | search through all of those with our queries in order to identify the most similar passages.
00:27:09.600 | And then we're going to say, OK, is this a positive passage for this query? If not,
00:27:14.800 | then great. We're going to assume it's a negative passage for our query. We're just
00:27:19.280 | going to go through and do that. So the first thing we're going to do is load our model.
00:27:24.000 | So we're using this msmarco-distilbert-base-tas-b. Again, that has been trained on pre-COVID
00:27:31.920 | data. And we're loading that into Sentence Transformers and just setting the max sequence
00:27:36.480 | length to 256 here. OK, and then we're going to initialize a Pinecone index. So this is going to
00:27:45.120 | be our vector database. We're going to store all of our vectors. So for that, we do need an API
00:27:50.720 | key. It's all free. You don't need to pay anything for this. All the infrastructure and stuff is
00:27:54.720 | handled for us, which is great. So we go to app.pinecone.io, make an account if you need to.
00:28:01.520 | And then you'll get an API key. Now, the way I do it is maybe not the best way, it's just easy.
00:28:07.760 | I'm just storing my API key in a file called secret in the same directory as my code. I'm
00:28:14.160 | just reading that in here. OK, so my API key. Oh, as well, another thing you need to do here
00:28:22.640 | is to install that, you need to pip install pinecone-client. OK, not pinecone, pinecone-client.
00:28:30.240 | So we initialize pinecone with our API key. This is just a default environment. There are other
00:28:38.720 | environments available if you want to use them, but I don't think you really need to.
00:28:42.880 | And then we create the index. OK, so we have a negative mining index. So here, I'm just saying
00:28:49.600 | you can check your currently running indexes with pinecone.list_indexes(). If negative-mine is not in
00:28:58.560 | there, then I'm going to create the negative-mine index. And there's a few things you need to pass, so
00:29:02.880 | dimension. The dimension is what we have here, so the embedding dimension that our
00:29:09.040 | embedding model is outputting. So you specify that, 768. The metric, this is important. We
00:29:16.240 | need to be using dot product metric here for GPL. And also, the number of pods. So by default,
00:29:24.160 | this is one. What you can do is increase this. So I think I increased it to 70, which is probably
00:29:31.600 | like massively overkill for this. But that shortened the runtime. So with one pod, I had a runtime,
00:29:39.200 | I think, of one hour 30 or one hour 40. With 70, which again is probably overkill, maybe 40
00:29:46.240 | would do the same, I don't know. With that, it was, I think, 40 minutes. So a lot faster,
00:29:54.400 | but you do have to pay for anything that's more than one pod. So if you're in a rush,
00:30:00.000 | then fair enough, fine. Go with that. Otherwise, you just stick with one pod. It's free.
00:30:07.680 | Okay. And then you connect. So you, this creates your index, and then you connect to your index.
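A minimal sketch of that setup, using the older pinecone-client API the video is based on; the environment value and the secret-file handling are assumptions:

```python
import pinecone  # pip install pinecone-client

with open("secret", "r") as fp:          # API key kept in a local file called 'secret'
    api_key = fp.read().strip()

pinecone.init(api_key=api_key, environment="us-west1-gcp")  # assumed default environment

if "negative-mine" not in pinecone.list_indexes():
    pinecone.create_index(
        "negative-mine",
        dimension=768,         # embedding dimension of the bi-encoder
        metric="dotproduct",   # GPL needs dot product similarity
        pods=1,                # more pods is faster, but anything beyond one pod is paid
    )

index = pinecone.Index("negative-mine")
```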
00:30:15.040 | So that's pinecone.Index with negative-mine. Okay. And then we want to go through, we're going to use
00:30:21.280 | our model. Actually, this bit here, we're creating our file reading generator. So it's just going to
00:30:27.120 | yield the query and passage pairs. I include this in here because I got a value error because there's
00:30:33.120 | some weird, literally one row of weird data in there. So I just added that in there to skip that
00:30:40.400 | one. And yeah, initializing that generator. And then here, I'm going to go through and actually
00:30:47.200 | encode the passages. So you see, we're creating a passage batch here. We're doing it in batches
00:30:54.880 | again, make it faster. So I'm adding my passages to the passage batch, also adding IDs to the ID
00:31:03.520 | batch. And then once we reach the batch size, which is 64 in this example, we encode our passage
00:31:11.600 | batch. We also convert them to a list because we need it to be a list when we're upserting
00:31:16.640 | everything. So upsert just means update or insert. It's just database lingo. We need
00:31:24.480 | it to be a list when we're pushing this because it's pushing it through an API call here, so
00:31:29.280 | JSON request. So we need to convert that to a list rather than NumPy array, which is, I think,
00:31:34.640 | the default. And then we just create a list of IDs and vectors and upload that to Pinecone, our index.
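A condensed sketch of that embed-and-upsert loop, continuing from the index above; the bi-encoder is called retriever here, the passage text is stored as metadata so we can read it back at query time, and the pairs file path carries over from the earlier sketch (all assumptions):

```python
from sentence_transformers import SentenceTransformer

retriever = SentenceTransformer("msmarco-distilbert-base-tas-b")
retriever.max_seq_length = 256

def upsert_batch(ids, passages):
    embeds = retriever.encode(passages).tolist()       # plain lists, not NumPy, for the JSON request
    meta = [{"passage": p} for p in passages]          # keep the text so query hits carry it back
    index.upsert(vectors=list(zip(ids, embeds, meta)))

seen, ids, batch = set(), [], []
with open("data/pairs.tsv", "r", encoding="utf-8") as fp:
    for line in fp:
        try:
            query, passage = line.rstrip("\n").split("\t")
        except ValueError:
            continue                 # skip the odd malformed row
        if passage in seen:
            continue                 # each passage appears three times, one per generated query
        seen.add(passage)
        ids.append(str(len(seen) - 1))
        batch.append(passage)
        if len(batch) == 64:
            upsert_batch(ids, batch)
            ids, batch = [], []

if batch:
    upsert_batch(ids, batch)

print(index.describe_index_stats())  # dimension, index fullness, and total vector count
```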
00:31:45.600 | And then we just refresh those batches. So we start again, and then we do 64 and then do the
00:31:51.280 | encode and upsert again. And then at the end here, I just wanted to check the number of vectors we
00:31:56.480 | have in the index. So we see the dimensionality of the index, as well as the index
00:32:02.960 | fullness; this will tell you pretty much how quick it's going to run. At zero, it's perfect. It means
00:32:07.760 | it's basically empty. It'll run pretty quick. And then you have the vector count here. So remember,
00:32:15.280 | we have 200,000 examples or pairs, but not all of them are being used. So they are all being used,
00:32:25.840 | but not all of them are unique because for each passage that we have, we have three queries.
00:32:30.480 | So three unique queries, but three duplicated passages. So obviously, out of those 200,000
00:32:39.520 | passages that we have, we need to divide that by three to get the actual number of unique passages
00:32:44.720 | that we have in there, which is something like 76,840. So the database is now set up for
00:32:54.480 | us to begin negative mining. It's full of all of our passages. So what we're going to do is loop
00:32:59.520 | through each of the queries in pairs, which we created here. So we're going to loop through all
00:33:05.840 | of those here in batches of batch size again, so 100. And we're going to initialize this triplets
00:33:13.840 | list where we're going to store our query, positive and negative. At the moment,
00:33:18.560 | we just have query positive, remember? So we're going to go through there, get our queries,
00:33:24.000 | get our positives, and we're going to create the query embeddings. So that's in batches again.
00:33:32.560 | Then we search for the top 10 most similar matches to our query, and then we loop through all of
00:33:40.640 | those. So for query, positive passage, and query response, this query response will actually have
00:33:48.880 | the 10 possible or the 10 high similarity passages for that query inside there. So extract those,
00:33:56.320 | remember? 10 in there. And I'm just going to shuffle those. So I'm going to shuffle them,
00:34:00.560 | and then we're going to loop through them. Now, we do this so that we're not just returning the
00:34:04.800 | most similar one all the time, but we're returning one of the most similar top 10 instead. And then
00:34:12.160 | we extract native passage from that one record that hit. And then one thing we really need to
00:34:20.080 | consider here is, OK, if we've got all of our passages in there, we're also going to have the
00:34:24.320 | positive passage for our query as well in there. So it's pretty likely that we're going to return
00:34:30.960 | that, at least for a few of our queries for sure. So we need to check that negative passage we're
00:34:37.280 | looking at does not match to the positive passage for that query. And then if not,
00:34:43.600 | then that means it's a negative passage or assumed negative passage, and we can append
00:34:48.240 | it to our triplets. So we have query, tab, positive, tab, negative. And then we say then file.
00:34:55.120 | OK? Now, one last thing. Before we move on, we should delete our index. So if you're on the
00:35:04.640 | free tier, you're just using that one pod, it's fine. You're not going to pay anything anyway.
00:35:08.800 | But I guess it's also good practice to remove it after you're done with it. But if you are paying,
00:35:13.520 | of course, you want to do this so you're not spending any more money than you need to. So
00:35:17.680 | that's the negative mining step. Now we can move on to what is the final data preparation step,
00:35:24.400 | which is the pseudo-labeling. Now, pseudo-labeling is essentially where we use a cross-encoder model
00:35:31.760 | to kind of clean up the data from the previous two steps. So what this cross-encoder is going
00:35:39.920 | to do is generate similarity scores for both the query positive and the query negative pairs. OK?
00:35:47.040 | So we pass both those into cross-encoder model, and it will output predicted similarity scores
00:35:53.440 | for the pairs. So first thing you need to do is actually initialize a cross-encoder model. Again,
00:35:59.360 | this should have been trained on pre-COVID data. And then we're going to use a generator function
00:36:06.880 | as we have been doing throughout this entire thing, just to read the data or the pairs,
00:36:11.840 | triplets in this case, from file, and then yield them. And what we're going to do is use that
00:36:17.920 | cross-encoder, using this function here, to calculate both the similarity of the positive
00:36:24.800 | and the negative. And then we subtract those to get the margin between them, so the separation
00:36:29.680 | between them, which is ideal for when we're performing Margin MSE loss, which we will be
00:36:35.920 | very soon for fine-tuning our model. So we go through using our generator here. We get a line
00:36:43.840 | at a time: a query, positive, negative. We get a positive score, negative score, and then we
00:36:49.440 | calculate the margin between them. And then we're going to append those to label lines and save it
00:36:55.120 | all to file. So it's a pretty quick step, actually, for the pseudo-labeling step. But this final step
00:37:03.920 | is also very important in ensuring that we have high-quality training data. Without it, we would
00:37:12.400 | need to assume that all passages returned in the negative mining step are actually negatives,
00:37:18.080 | and they're irrelevant for our particular query. In reality, this is obviously never the case,
00:37:25.200 | because some negative passages are obviously going to be more negative or less negative
00:37:30.960 | than others. And maybe some of them are not even negative in the first place.
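Before looking at the example from the paper, here is a minimal sketch of that labeling pass; the cross-encoder checkpoint named below is an assumption (any MS MARCO cross-encoder trained on pre-COVID data plays the same role), and the file paths carry over from the sketches above:

```python
from sentence_transformers import CrossEncoder

ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # assumed pre-COVID MS MARCO cross-encoder

labeled = []
with open("data/triplets.tsv", "r", encoding="utf-8") as fp:
    for line in fp:
        query, positive, negative = line.rstrip("\n").split("\t")
        # similarity scores for the (query, positive) and (query, negative) pairs
        pos_score, neg_score = ce.predict([(query, positive), (query, negative)])
        margin = pos_score - neg_score        # this margin is the label for Margin MSE loss
        labeled.append("\t".join([query, positive, negative, str(margin)]))

with open("data/triplets_margin.tsv", "w", encoding="utf-8") as fp:
    fp.write("\n".join(labeled))
```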
00:37:37.040 | So there was this chart or example from the GPL paper that I really liked. I've just adapted it
00:37:43.840 | for our particular use case. So the query is, what are the symptoms of COVID-19? The positive,
00:37:50.960 | so the first row there, is COVID-19 symptoms include fever, coughing, loss of sense of smell
00:37:56.160 | or taste. So that positive passage is the passage that we started with. We generated that query in
00:38:03.360 | the generative or query generation stage. And then we retrieved these negatives in the negative mining
00:38:10.000 | stage. Now, just have a look at each one of these. So the first one, fever, coughing, and a loss of
00:38:16.160 | sense of smell are common COVID symptoms. That is not actually a negative. It's a positive. But this
00:38:22.320 | is very likely to happen in the negative mining stage, because we're returning the most similar
00:38:26.800 | other passages. And those other very similar passages could very easily be actual genuine
00:38:33.840 | positives that are just not marked as being the pair for our query. We don't actually know that
00:38:42.080 | they're negatives. So in this case, this is a false negative, because it should be a positive.
00:38:47.040 | Now, the next one, we have these easy negatives. The next one, symptoms are physical or mental
00:38:54.080 | features that indicate a condition of disease. Now, this is a pretty typical easy negative,
00:39:00.240 | because it has some sort of crossover. So the question is, what are the symptoms of COVID-19?
00:39:06.400 | And this is just defining what a symptom is. But it's not about COVID-19 or the symptoms of COVID-19.
00:39:12.640 | This is an easy negative, because it mentions one of those keywords. But other than that,
00:39:19.040 | it's not really similar at all. So it will be quite easy for our model to separate that.
00:39:23.680 | So it will be easy for it to look at this easy negative and say, OK, it's not relevant for our
00:39:28.160 | query. It's very different to the positive that I have here. That's why it's an easy negative.
00:39:34.160 | Next one, another easy negative is COVID-19 is believed to have spread from animals to
00:39:39.600 | humans in late 2019. Again, it's talking about COVID-19, but it's not the symptoms of COVID-19.
00:39:45.680 | So maybe it's slightly harder for our model to separate this one. But it's still a super easy
00:39:50.560 | negative. And then we have a final one. So this is a hard negative. Coughs are a symptom of many
00:39:56.000 | illnesses, including flu, COVID-19, and asthma. In this case, it's actually a partial answer,
00:40:03.280 | because it tells us, OK, cough is one of the symptoms. But it's still not an answer. We're
00:40:08.240 | asking about the symptoms. We want to know about multiple symptoms. This is kind of partially
00:40:13.840 | answers, but it doesn't really. And what I really like about this is we look at the scores on the
00:40:19.600 | right here. We have the scores when you're using something like pseudo-labeling from GPL,
00:40:24.320 | and scores when you're not using pseudo-labeling, like if you're using GenQ. Now, GenQ, even
00:40:31.360 | for the false negative, is seeing this as a full-on negative. So you're just going to
00:40:36.800 | confuse your model if you're doing this. If it has two things that are talking about the same thing,
00:40:42.320 | and then you're telling your model, actually, one of them is not relevant,
00:40:47.200 | you're really going to damage the performance of your model by doing that. And it's the same
00:40:51.760 | with the other ones as well. There's almost like a sliding scale of relevance here. It's not just
00:40:58.000 | black or white. There's all these shades of gray in the middle. And when using GPL and pseudo-labeling,
00:41:04.800 | we can fill in those shades of gray, and we can see that there is a sliding scale of relevance.
00:41:11.120 | Without pseudo-labeling, we really miss that. So it's a very simple step, but really important
00:41:18.480 | for the performance of our model. Okay, so we now have the fully prepared data, and we can actually
00:41:25.840 | move on to actually fine-tuning our sentence transformer using Margin MSE loss. Now, this
00:41:34.800 | fine-tuning portion is not anything new or unique. It's actually very common. Margin MSE loss is used
00:41:44.560 | for a lot of sentence transformer models. So looking back at the generated data, we have the
00:41:50.400 | format of Q, positive passage, negative passage, and margin. Let's have a look at how those fit
00:41:58.640 | into the Margin MSE loss function. So we have the query, positive, negative. We pass those through.
00:42:07.360 | We get the similarity here. So this is our similarity of the query and the positive.
00:42:15.200 | Okay, and this bi-encoder here is the one we're training, by the way. So this bi-encoder is
00:42:22.240 | the final model that we're fine-tuning, that we're trying to adapt to our new domain.
00:42:28.320 | And then we also calculate using the query and negative vectors here. We calculate the similarity
00:42:35.360 | between those to get the similarity of Q and negative. We subtract those to get this here,
00:42:45.040 | which is delta hat, our predicted margin. With our predicted margin, we can compare that against
00:42:53.280 | our true margin, or what we assume is the true margin, from our cross-encoder. And we feed those into
00:43:00.640 | the Margin MSE loss function over here. So we have the predicted minus the true for each sample.
00:43:10.560 | We square all of that to get the squared error. And then we're going over for all of our samples,
00:43:19.440 | taking the average. And that is our Margin MSE loss. Now, what is really nice is that we can
00:43:27.840 | actually use the default training methods in the Sentence Transformers library for this.
00:43:34.880 | Because Margin MSE loss is just a normal method. There's nothing new here at all.
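Written out for a batch of N triplets, with sim denoting the bi-encoder dot product and delta_i the cross-encoder margin from the pseudo-labeling step, the loss described above is:

```latex
\hat{\delta}_i = \mathrm{sim}(Q_i, P_i^{+}) - \mathrm{sim}(Q_i, P_i^{-}), \qquad
\mathcal{L}_{\text{MarginMSE}} = \frac{1}{N}\sum_{i=1}^{N}\bigl(\hat{\delta}_i - \delta_i\bigr)^2
```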
00:43:40.960 | OK. So from Sentence Transformers, I'm going to import InputExample, which is just a
00:43:46.240 | data format we always use in our training here. So we're opening up the triplets margin file here.
00:43:54.080 | I'm just reading all the lines here. I'm not going to use a generator. This time,
00:43:58.720 | I'm going to just create our training data, which is just a list of these input examples.
00:44:03.440 | We pass the query positive and negative as a text in here, so as our triplet.
00:44:10.480 | And then the label, we take the float of margin, because margin is a string from our TSV.
00:44:17.520 | And from this, we get 200,000 training examples. And we can load these examples into a
00:44:24.400 | data loader here. So we're just using a normal torch DataLoader, nothing special here.
00:44:31.920 | We use empty cache here just to clear our GPU in case we have anything on there. So if you keep
00:44:38.400 | running this script again and again, trying to get it working or just modifying different
00:44:44.000 | parameters to see how it goes, you'll want to include this in here. The batch size. Batch size
00:44:50.320 | is important with Margin MSE loss. The larger you can get, the better. I did see, I think,
00:44:56.000 | in the sentence transformers documentation, it used something like 64. And that's hard. You're
00:45:02.000 | going to need a pretty good GPU for that. And to be fair, even 32, you need a good GPU. This is using
00:45:09.040 | one Tesla V100 for this. And it works. It's fine. So as high as you can get,
00:45:17.680 | 32, I think, is good. Even better if you can get 64.
00:45:23.200 | And, yeah, we just initialize data loader. We set the batch size and we shuffle to true.
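A minimal sketch of that data loading, assuming the labeled triplets file written in the pseudo-labeling sketch above:

```python
import torch
from torch.utils.data import DataLoader
from sentence_transformers import InputExample

training_data = []
with open("data/triplets_margin.tsv", "r", encoding="utf-8") as fp:
    for line in fp:
        query, positive, negative, margin = line.rstrip("\n").split("\t")
        training_data.append(
            InputExample(texts=[query, positive, negative], label=float(margin))
        )

torch.cuda.empty_cache()   # clear anything left on the GPU from earlier runs

batch_size = 32            # as large as your GPU allows; 64 if you can manage it
loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
```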
00:45:27.600 | And so the next step we want to do here is initialize the bi-encoder or sentence
00:45:35.760 | transformer model that we're going to be fine tuning using domain adaptation. So this is the
00:45:42.320 | same one we used earlier in the retrieval step. We're actually going to be taking that and fine
00:45:47.280 | tuning it. So we set the model max sequence length here, again, as we did before. And then
00:45:56.240 | we initialize the Margin MSE loss function as we would with normal sentence transformers.
00:46:01.280 | And then we just run model.fit. We train for one epoch or actually, it depends. You can
00:46:08.800 | get better performance if you train for more. I'll show you. And I just do 10% of the steps as warmup here
00:46:14.640 | like we usually do. And I'm going to save it. Same model name, but I've added COVID onto the end
00:46:19.600 | there. And this doesn't take long. I think this was-- maybe it was 40 minutes per epoch on the
00:46:26.960 | Tesla V100. So it's not bad. Now, in terms of the amount of training steps you want to actually
00:46:35.280 | put your model through with GPL: so in the paper, there's a really nice chart
00:46:43.360 | that shows that after around 100,000 steps, for that particular use case, it leveled off. The
00:46:49.520 | performance didn't really improve that much more from there. However, I did find-- so I'm using
00:46:55.600 | 200,000 examples. So one epoch for me is already 200,000 training steps, which is more. And I found
00:47:04.960 | that actually training for 10 epochs-- I didn't test in between. I just went from one. And I also
00:47:11.040 | tested 10. I found that 10, actually, the performance seemed to be better. I didn't really
00:47:17.600 | quantitatively assess this. I just looked at it qualitatively. And it was performing or is
00:47:23.280 | returning better results than the model that had been trained for one epoch. So rather than
00:47:29.200 | 200,000 steps, it had been trained for 2 million steps, which is a fair bit more than 100,000
00:47:35.520 | mentioned here. But that's just what I found. I don't know why. I suppose it will depend very much
00:47:43.920 | on the data set you're using in particular. But I would definitely recommend you just try and test
00:47:50.480 | a few different models. The training doesn't take too long for this step. So you have the luxury of
00:47:56.560 | actually being able to do that and testing these different numbers of training steps.
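And a sketch of the fine-tuning itself, continuing from the loader above; the output path is an assumption, and one epoch is just the starting point, more epochs may help as discussed:

```python
from sentence_transformers import SentenceTransformer, losses

# The same pre-COVID bi-encoder used in the negative mining step.
model = SentenceTransformer("msmarco-distilbert-base-tas-b")
model.max_seq_length = 256

loss = losses.MarginMSELoss(model)

epochs = 1
warmup_steps = int(len(loader) * 0.1)   # 10% of one epoch as warmup

model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path="msmarco-distilbert-base-tas-b-covid",
    show_progress_bar=True,
)
```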
00:48:02.160 | So I mean, once training is complete, we can actually test that model as we usually would.
00:48:07.920 | We have the model as well. So I can show you. So we have the model over on-- if we go to
00:48:14.720 | huggingface.co/models. So if you search in here, you type Pinecone. And one of these is
00:48:23.680 | the model that has been trained on COVID. So hit this one here. So you can actually copy this. And
00:48:30.240 | you can just use it. I think it should maybe have an example there. Ah, so you write this. So
00:48:36.720 | from sentence_transformers, you import SentenceTransformer. And you just do model equals
00:48:42.960 | SentenceTransformer with your model name, which in this case is what you can see up here, Pinecone MS MARCO.
00:48:48.240 | And you can test the model that I've actually trained here. Again, this was trained on 10 epochs,
00:48:53.680 | not one. OK, so in this evaluation here, you can see what I was looking at. So load the old model.
00:49:00.400 | It's just the msmarco-distilbert-base-tas-b. And then here, I've got the model trained for one
00:49:07.440 | epoch. And here, the model trained for 10 epochs, which is equivalent to the model I just showed
00:49:12.960 | you on Hugging Face Hub. So I've taken these tests. So we have these queries and three answers
00:49:22.400 | or three passages for each one of those. Now, these are not even really perfect matches to the
00:49:28.240 | sort of data we were training on. But they're kind of similar in that they just include COVID-19.
00:49:34.240 | OK, so it's not a perfect match. It's not even perfectly in the domain. But it's good enough,
00:49:40.240 | I think. But I think that really shows how useful this technique is because it does work even with
00:49:46.640 | this kind of not perfect match between what I'm testing on and what I'm actually training on.
00:49:52.080 | So here, I'm just going through those. I'm getting the dot product score between those
00:49:56.560 | and sorting based on that. So the highest rated passage for each query is returned.
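As a small sketch of that comparison; the fine-tuned model ID is approximate (check the Pinecone organization on the Hugging Face Hub for the exact name), and the query and passages here are just illustrative placeholders, not the ones from the video:

```python
from sentence_transformers import SentenceTransformer, util

old_model = SentenceTransformer("msmarco-distilbert-base-tas-b")
new_model = SentenceTransformer("pinecone/msmarco-distilbert-base-tas-b-covid")  # assumed model ID

query = "How is COVID-19 transmitted?"
passages = [
    "Ebola is transmitted through direct contact with the bodily fluids of an infected person.",
    "HIV is transmitted through blood and certain other bodily fluids.",
    "The coronavirus spreads mainly through respiratory droplets produced by infected people.",
]

for name, model in [("old", old_model), ("gpl", new_model)]:
    xq = model.encode(query, convert_to_tensor=True)
    xp = model.encode(passages, convert_to_tensor=True)
    scores = util.dot_score(xq, xp)[0]                       # dot product, matching training
    ranked = sorted(zip(passages, scores.tolist()), key=lambda x: x[1], reverse=True)
    print(name, "->", ranked[0][0])                          # highest scoring passage per model
```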
00:50:02.400 | So for the old model, how is COVID-19 transmitted? We get Ebola, HIV. And then at the very bottom,
00:50:09.680 | we have Corona, OK? Again, what is the latest name variant of Corona? And then it's these here.
00:50:17.920 | Again, COVID-19 is right at the bottom. We have Corona lager right at the top.
00:50:23.440 | I don't know if the rest of the world calls it lager or beer, but I put lager there anyway.
00:50:30.400 | And then we say, OK, what are the symptoms of COVID-19? Then we get flu. Corona comes second
00:50:34.880 | here. And then we have the symptoms definition again. And then how will most people identify
00:50:40.160 | that they have contracted the coronavirus? And then after drinking too many bottles of Corona
00:50:44.480 | beer, most people are hungover, right? So obviously, it's not working very well.
00:50:48.160 | And then we do the GPL model. So this is the training on one epoch. We see slightly better
00:50:53.200 | results. So Corona is number one here. And then here, it's number two. So it's not great, but fine.
00:51:02.960 | What are the symptoms of COVID-19? This one, I guess, right? So that's good. And then here,
00:51:10.880 | how will most people identify that they have contracted the coronavirus? Again, it's doing
00:51:14.400 | the drinking too many bottles of Corona. Again, then you have COVID-19 second place.
00:51:20.480 | Then we have the model that's trained for 10 epochs. This one, every time. So number one,
00:51:25.360 | number one, and number one. So it's pretty cool to see that. And yeah, I think it's super
00:51:35.120 | interesting that this technique can actually do that using nothing more than just unstructured
00:51:41.840 | text data. So that's it for this video on generative pseudo-labeling, or GPL.
00:51:48.480 | Now, I think it's already very impressive what it can do. And more importantly, I think,
00:51:54.080 | or what is more interesting for me is where this technique will go in the future. Is this going to
00:52:00.000 | become a new standard in training sentence transformers, where you're building the
00:52:04.080 | synthetic data? I feel like there's really a lot that we covered in this video and from the GPL
00:52:11.520 | paper to unpack and think about. We have negative mining. We have pseudo-labeling. We have all
00:52:16.640 | these different techniques that go together and produce this really cool and, in my opinion,
00:52:22.960 | super powerful technique that we can use. For me, it's just super impressive that you can actually do this.
00:52:30.800 | So I hope there is more in this field in GPL or some new form of GPL in the future.
00:52:39.120 | I'm very excited to see that. But for now, of course, that's it. I hope this video has been
00:52:45.760 | useful. I hope you're excited about this as I am. So thank you very much for watching,
00:52:50.800 | and I will see you again in the next one. Bye.