Is GPL the Future of Sentence Transformers? | Generative Pseudo-Labeling Deep Dive
Chapters
0:00 Intro
1:08 Semantic Web and Other Uses
4:36 Why GPL?
7:31 How GPL Works
10:37 Query Generation
12:08 CORD-19 Dataset and Download
13:27 Query Generation Code
21:53 Query Generation is Not Perfect
22:39 Negative Mining
26:28 Negative Mining Implementation
27:21 Negative Mining Code
35:19 Pseudo-Labeling
35:55 Pseudo-Labeling Code
37:01 Importance of Pseudo-Labeling
41:20 Margin MSE Loss
43:40 MarginMSE Fine-tune Code
46:30 Choosing Number of Steps
48:54 Fast Evaluation
51:43 What's Next for Sentence Transformers?
00:00:00.720 |
Today we're going to dive into what I think is probably one of the most interesting techniques 00:00:06.720 |
in semantic search for a very long time. That technique is called generative pseudo-labelling 00:00:13.120 |
or GPL. It's from Nils Reimers' team at the UKP Lab, and this is primarily the work of Kexin Wang. 00:00:22.880 |
Now Kexin Wang also produced the TSDAE paper maybe around a year or so ago as well, and TSDAE 00:00:34.240 |
is the unsupervised training method for sentence transformers, and GPL 00:00:41.440 |
takes that and kicks it up another notch. So we're going to be using GPL to take unstructured 00:00:50.880 |
data and generate structured data from it for training sentence transformers. Now before we 00:00:59.520 |
really dive deep into GPL, I want to give a little bit of background and set the scene of where we 00:01:06.640 |
might want to use this sort of technique. So in 1999 there was a concept known as semantic web 00:01:16.080 |
described by the creator of the World Wide Web, Tim Berners-Lee. Now Tim Berners-Lee had this 00:01:22.960 |
sort of dream of the web in the way that we know it today but where you have machines roaming the 00:01:32.240 |
web and being able to just understand everything. But when you start to become a bit more niche, 00:01:43.520 |
it gets hard to actually find that sort of data set and pretty much impossible most of the time. 00:01:48.720 |
So for example let's say you're in the finance industry, you have these internal documents, 00:01:56.000 |
fairly technical financial documents and you want to build a question answering system 00:02:01.840 |
where people or staff can ask a question in natural language and return answers from those 00:02:08.080 |
documents. It's going to be hard for you to find a model, you do have financial Q&A models but 00:02:15.760 |
they've been fine-tuned on like personal finance questions on Reddit and okay kind of similar but 00:02:22.800 |
not really. Like personal finance and technical finance documentation in a finance company, 00:02:29.600 |
very different. So it's very hard to actually find a data set that is going to satisfy what you need 00:02:37.120 |
in that scenario because it's fairly niche. And it's not even that niche, 00:02:40.560 |
there's a lot of companies that need to do that sort of thing. Another example is a project I 00:02:45.760 |
worked on was for the Dhivehi language. We wanted to create a set of language models for the Dhivehi 00:02:52.400 |
language, which is the national language of the Maldives. Now there are not that many Dhivehi speakers 00:02:58.480 |
worldwide, so it's pretty niche. So it's very hard to find labeled data, and what we ended up 00:03:06.400 |
doing was actually using unsupervised methods. We used TSDAE for the actual sentence transformer model. 00:03:13.840 |
What we have there is another use case where it's fairly niche and it's hard to find labeled 00:03:20.560 |
data for. But there is unlabeled data, unstructured data in both of those scenarios. In your finance 00:03:27.680 |
company you have the documents there, you just don't have them labeled in a way that 00:03:33.680 |
you can train a model with. Or in the Dhivehi language example, there are many web 00:03:40.160 |
pages in Dhivehi that we can scrape data from, but they're not labeled. So for these more niche topics, 00:03:49.360 |
which cover a vast majority of use cases, you really either need to spend a lot of money 00:03:57.200 |
labeling your data or we need to find something else that allows us to either synthetically 00:04:02.880 |
generate data or just train on the unlabeled data. Now training on the unlabeled data 00:04:10.320 |
doesn't really work. It works to an extent, TSDAE showed that it can, but you're looking at very 00:04:19.120 |
generic semantic similarity and it's not perfect. GPL again, it's not perfect but it's definitely 00:04:30.160 |
much better. So back to the semantic web example, in the case of the semantic web, the majority of 00:04:39.200 |
the semantic web is going to be these niche topics that we don't have labeled data for. 00:04:45.200 |
GPL is not a solution, it's not going to create the semantic web or anything like that. 00:04:50.880 |
But it does allow us to actually target those previously inaccessible domains that actually 00:04:59.520 |
begin producing models that can intelligently comprehend the meaning behind the language in 00:05:07.840 |
those particular domains which I think is super fascinating and the fact that you can actually do 00:05:13.520 |
this is so cool. Now there is a lot to it. This video I'm sure will be quite long but we really 00:05:21.760 |
are going to go into the details and I think by the end of this video you should know everything 00:05:26.960 |
you need to know about GPL and you will certainly be able to apply it in practice. 00:05:31.680 |
So we've introduced where we might want to use it. I just want to now have a look at sort of 00:05:40.400 |
like an overview of GPL and where we are going to use it in this video. So GPL is, you can use 00:05:47.840 |
it in two ways. You can use it to fine tune a pre-trained model. So when I say pre-trained, 00:05:55.200 |
I mean a transformer model that hasn't been fine-tuned specifically for semantic search. 00:06:00.400 |
So bert-base-cased, for example, would be a pre-trained model from the Hugging Face Hub. You take 00:06:05.440 |
that and you can use GPL to fine-tune it for semantic search. You'll get okay results with 00:06:10.640 |
that. The better way of using GPL is for domain adaption. So domain adaption is where you already 00:06:19.120 |
have a semantic search model. So for example, what we are going to look at, we have a model that has 00:06:25.040 |
been trained on MS Marco from data from 2018. This model, if you give it some questions about COVID-19, 00:06:36.400 |
because 2018 was pre-COVID, it really kind of struggles with even simple questions. You ask 00:06:44.560 |
it a simple question about COVID-19 and it will start returning you answers from about like Ebola 00:06:51.040 |
and flu rather than COVID-19 because it's confused. It's never seen that before. 00:06:57.280 |
So this is a typical issue with sentence transformers. They're very brittle 00:07:04.960 |
and they don't adapt new domains very well whatsoever. 00:07:08.800 |
So GPL is best used in these sorts of scenarios where we want to adapt a model to a particular 00:07:17.440 |
domain. Now, as I said, our example is going to be COVID. We're going to teach this model 00:07:24.640 |
to understand COVID even though it's never seen anything about COVID before. 00:07:30.880 |
So let's have a look at how GPL actually works. So at a very high level, GPL consists of three data 00:07:38.800 |
preparation steps followed by one fine-tuning step. So the data preparation steps are the steps 00:07:46.240 |
that produce the synthetic data set that we then use in that fine-tuning step. So that fine-tuning 00:07:53.440 |
step is actually a common supervised learning method. The unsupervised, if you want to call it 00:07:59.920 |
that, part of this method is that we are actually generating the text or the data that we'll be 00:08:09.360 |
using automatically. We don't need to label it ourselves. So these three data preparation steps 00:08:16.160 |
are the key to GPL. They are query generation, which is where we create queries from passages, 00:08:26.400 |
so like paragraphs of text. The second is negative mining. So that is an information 00:08:34.160 |
retrieval step where we retrieve passages that are similar to our positive passage, the passage from 00:08:41.920 |
before, but that do not actually answer our query, or we assume they do not answer our query. 00:08:47.600 |
And then the pseudo-labeling step kind of cleans up the data from those two steps, 00:08:55.280 |
from those two earlier parts, and uses a cross-encoder model to identify the similarity 00:09:01.200 |
scores to assign to our passages and query pairs that we've generated already. 00:09:09.200 |
So you can see that process here. We start with P+, which means positive passage at the top. 00:09:15.600 |
We pass that through a query generation model. We use T5 for this, and that generates queries 00:09:22.720 |
or questions. That's passed through a dense retrieval step, the negative mining step, 00:09:27.600 |
which will hopefully return things that either partially answer or share many of the same 00:09:34.320 |
words with our query or positive passage, and therefore are similar but do not actually match 00:09:42.240 |
to our query. And then we pass all of those, the query, the positive and negative, 00:09:46.640 |
through to our pseudo-labeling step, which is the cross-encoder step. Now the cross-encoder step 00:09:53.120 |
is going to produce a margin score. So that's almost like the difference in similarity between 00:09:59.360 |
our query positive pair and our query negative pair. We'll explain all this in much more depth 00:10:04.000 |
later on. I assume right now it's probably quite confusing. So no worries, we will go through all 00:10:11.440 |
of this. So as you've probably noticed, each of these steps requires the use of a pre-existing 00:10:16.560 |
model that has been fine-tuned for each one of these parts. Now we don't need to have fine-tuned 00:10:22.480 |
this. They're general models. They don't need to have been specially trained for that particular 00:10:28.960 |
task with this particular domain, okay? These models are generally quite good at adapting to 00:10:34.080 |
new domains, which is why we're using them. So let's dive deeper into what I've just described 00:10:41.840 |
and begin looking at the query generation step of GPL. So as I mentioned, GPL is perfect 00:10:51.200 |
for those scenarios where we have no labeled data, but we do need a lot of unstructured or 00:10:59.120 |
unlabeled data, okay? So we need a lot of text data that is not labeled. That could be text 00:11:04.880 |
scraped from web pages, from PDF documents, from anywhere where you can find text data. 00:11:11.680 |
The only requirement is that this text data is actually relevant to your particular use case, 00:11:16.720 |
e.g. is in-domain. Now, if we consider the case of maybe we have a use case where we need to build 00:11:25.360 |
a semantic search retrieval model for German finance documents. For the in-domain of that, 00:11:32.640 |
we might consider German finance news articles or German finance regulations, okay? They would 00:11:38.640 |
be in-domain, but other documents like PyTorch documentation or English financial documents, 00:11:44.560 |
they are not in-domain at all. It's almost like if you imagine you're studying for an exam in 00:11:50.560 |
one topic, like biology, and you start studying some chemistry papers for that biology exam. 00:11:58.000 |
Maybe some of that chemistry might be relevant, a crossover of your biology, but not much. You're 00:12:04.160 |
never actually going to do that if you want to pass your exam at least. Now, in our examples 00:12:09.760 |
in this video, we're going to be using the CORD-19 dataset, which is a set of papers that, 00:12:16.720 |
I don't think all of them are specifically about COVID-19, but I think most of them at least 00:12:21.440 |
mention COVID-19. There are a lot of files here. Let's have a look at the code we use to actually 00:12:28.240 |
download that data. Okay, so we need to find the CORD-19 dataset from the Allen Institute for AI. 00:12:38.240 |
It's this website here. So it's not available, at least when I checked, 00:12:43.200 |
it wasn't available on Hugging Face Datasets or anywhere like that, so we need to 00:12:46.800 |
pull it manually. So I'm just going to download that. We create this tar file, and essentially, 00:12:54.320 |
I'm not going to explain all of this; I want to go through this a little bit quickly, 00:12:57.600 |
but you will be able to find the link to this in the description. So you can just run this script, 00:13:03.280 |
but everything will be stored within this document_parses/pdf_json directory, which is just going to be a 00:13:09.280 |
load of JSON files. There's a lot of them, 300, just under 340,000 of them, and they're going to 00:13:17.200 |
be named like this, and they're going to look like this. So they each have these paragraphs in there, 00:13:23.840 |
which we're going to pull in and use in our examples. So once we've actually downloaded 00:13:29.040 |
all of the test data, we're going to move on to the query generation code. So we have our 00:13:35.360 |
CORD-19 data. We're going to read them from here. I use a generator function here. So we're going 00:13:41.760 |
through, getting the text, and we're just yielding the passage. So this is one passage at a time, 00:13:47.600 |
one paragraph at a time that I'm pulling in, and I'm using a generator function here to do that. 00:13:55.440 |
Now, just be aware there are a lot of duplicates in this data set. So what I've done is create 00:14:01.120 |
this passage dupes set. So just check every time we have a new passage, we just check if we already 00:14:08.480 |
pulled it in. If so, we skip, move on to the next one. Otherwise, it's a unique passage, and we can 00:14:14.880 |
pull it in and add it to that duplication check, and then yield it. So using yield here, because 00:14:22.560 |
this is a generator function that we loop through one at a time. So basically, we're going to iterate 00:14:27.520 |
through that function, and it will return us one passage at a time, which is what we're doing here. 00:14:32.560 |
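As a rough sketch of that deduplicating generator (the directory name and JSON layout here follow the CORD-19 document_parses/pdf_json format, but treat the exact paths as assumptions):

```python
import json
from pathlib import Path

def get_passages(json_dir="document_parses/pdf_json"):
    """Yield unique body-text paragraphs from the CORD-19 JSON files."""
    passage_dupes = set()  # passages we have already yielded
    for path in Path(json_dir).glob("*.json"):
        with open(path, "r", encoding="utf-8") as f:
            doc = json.load(f)
        for paragraph in doc.get("body_text", []):
            passage = paragraph["text"].strip()
            if passage in passage_dupes:
                continue  # duplicate, skip it
            passage_dupes.add(passage)
            yield passage

passages = get_passages()
print(next(passages))  # first unique paragraph
```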
I returned two passages here. You can see they're pretty long. So there is a lot of text data in 00:14:38.320 |
there. We're probably going to cut a lot of it out when we're feeding it into our models. That's fine. 00:14:43.280 |
We just want to see how this method works; we're not trying to make everything perfect. So what we're 00:14:51.040 |
going to do is use this model here. So this is a T5 model trained on data from MS MARCO 00:14:58.560 |
from 2018. So it's pre-COVID. It doesn't know anything about COVID. We're using Hugging Face 00:15:05.520 |
Transformers here, AutoTokenizer, and then AutoModelForSeq2SeqLM. 00:15:10.480 |
And one thing, just to be very aware, if you have a CUDA-enabled GPU, definitely try and use it. So 00:15:16.480 |
here we're just moving that model to CUDA, to our GPU. It will make things much quicker. This step, 00:15:22.400 |
otherwise, on CPU, will take a very long time, so if you can avoid it, it's best not to do 00:15:28.560 |
that. I think for me, on the Tesla V100, I think this took maybe one or two hours to do for 200,000 00:15:39.680 |
examples, which is a fair few. So here's just an example of what we're going to do here. So I'm 00:15:47.120 |
taking one passage. We are going to tokenize the passages to create our inputs. And now, 00:15:53.040 |
inputs will include the input IDs and attention masks. We need both of those. And we feed them 00:15:57.600 |
into our model. And this generates three queries from that. One other thing that you really should 00:16:05.680 |
make note of here is, again, I'm moving those tensors to CUDA, if I can. If I move my model 00:16:12.800 |
to CUDA already, these also need to be on CUDA. And let's have a look at what we produce from that. 00:16:17.760 |
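Putting that single-passage step together, a minimal sketch might look like this (the exact checkpoint is an assumption; any T5 model trained for MS MARCO query generation, such as BeIR/query-gen-msmarco-t5-base-v1, fits the description):

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "BeIR/query-gen-msmarco-t5-base-v1"  # assumed query-generation checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

passage = next(passages)  # one passage from the generator above

# tokenize the passage and move the tensors to the same device as the model
inputs = tokenizer(passage, truncation=True, max_length=256, return_tensors="pt").to(device)

# sample three candidate queries for this single passage
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=64,
    do_sample=True,
    top_p=0.95,
    num_return_sequences=3,
)
queries = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(queries)
```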
So this is, sometimes it's good, sometimes it's not good. It really depends. I think this is 00:16:23.920 |
kind of a relevant example. But it's not perfect, because the queries are about SARS here, 00:16:30.080 |
rather than COVID, even though we're talking about COVID. Obviously, the model doesn't know that. 00:16:34.400 |
So this generative step, or query generation step, is very good for out-of-domain topics. Because, 00:16:43.840 |
for example, if this said COVID-19, I think there's a good chance that this would say, 00:16:49.200 |
where is COVID-19 infection found? Because it does a lot of word replacement. So it's seeing 00:16:59.520 |
your answer, basically. And it's just kind of reformatting that answer into a question 00:17:04.720 |
about the answer. That's a lot of what this T5 model is doing. So it can be quite good. And 00:17:11.840 |
in some cases, it can be less good. But that's fine, because it's better than nothing. We really 00:17:19.600 |
do need some sort of labeled data here. And in most cases, it works fairly well. 00:17:24.800 |
So here, we get some OK examples. We have SARS rather than COVID. But otherwise, it's fine. 00:17:30.960 |
So what I'm doing here is I'm creating a new directory called data, where I'm going to store 00:17:37.040 |
all of the synthetic data we create. Importing TQDM for the progress bar, 00:17:43.040 |
setting a target of 200,000 pairs that I want to create. Batch size, increase as much as you can. 00:17:49.360 |
That's going to make it quicker. But obviously, it's restricted, depending on your hardware. 00:17:53.760 |
And just specify a number of queries. And we're going to do this all in batches to make things 00:17:59.840 |
a little bit faster. And then we reinitialize that generator that's going to return us our 00:18:04.880 |
passages. Because we've already run through three of them already. So I just want to 00:18:08.960 |
restart it, essentially. So we're going to go through. I'm using this with TQDM as progress 00:18:14.880 |
here. Because this is a generator, we don't know how long it is. So it's not going to give us a 00:18:20.080 |
good progress bar if we just run TQDM on passages. So I'm specifying the total as a target. So total 00:18:27.280 |
number of steps is the target that I have up here, the 200,000. So I'm going to increase the count by 00:18:32.400 |
one with every new passage that we generate queries for. In fact, I'm going to increase it by three, 00:18:39.360 |
I think. Yeah, because we create three queries per passage. So we do do that. And once our count 00:18:47.840 |
is greater than or equal to the target, 200,000, we stop the whole process. 00:18:51.840 |
Here, for the passage batch, I'm appending a passage. So the passage that we have from here, 00:18:59.520 |
I'm just replacing any tab and newline characters to keep it clean for when we're writing the data 00:19:05.760 |
later on, more than anything. Because we're going to be using tab-separated files. 00:19:09.200 |
We're going to encode everything in batches. So once the length of the batch is equal to 00:19:16.480 |
the batch size that we specify up here, so 256, we're going to begin encoding everything. 00:19:21.360 |
So we tokenize everything in the batch, and then we generate all of our queries. 00:19:26.400 |
Again, generating three for each passage that we have. And then another thing is that we need to 00:19:34.640 |
decode the queries that we've generated to human readable text. Because the next model we use is 00:19:40.320 |
going to be a different model. So it cannot read the same tokens that this model outputs. So we 00:19:46.960 |
need to decode that back to human text. And then we're going to loop through all of the decoded 00:19:53.520 |
outputs, which is one query per line. And we have to consider there's actually three queries per 00:20:00.800 |
passage. So we have maybe five passages here, and that means we have 15 queries on the other side. 00:20:06.160 |
So we need to consider that when we're pairing those back together. So we use this passage 00:20:11.840 |
IDX, which is the integer of I divided by number of queries. So imagine you have, 00:20:18.320 |
so we have for passage 0, we are going to have queries 0, 1, and 2. 0, 1, and 2, you divide those 00:20:28.800 |
by 3. And all of them are going to be less than 1. So 2 is the highest one; 2 divided by 3 is 0.66. 00:20:38.880 |
If you take the integer value of that, it's going to become 0, which maps to the passage on this 00:20:46.720 |
side. So that maps to passage number 0. So that means 0, 1, 2, all mapped to passage 0. And then 00:20:53.680 |
we do the next one, so 3, 4, 5. And that's going to map to 1. So that's how we're mapping our 00:20:59.920 |
generated queries back to our passages. And then for each one of those, we just append it to line 00:21:07.840 |
here. We're using a tab to separate them because we're just going to write them to file. And we 00:21:12.400 |
increase the count. Refresh the batch every time we've been through a batch. And update the progress 00:21:19.520 |
bar. So updating it, actually updating it by 3 here. No, sorry. Updating it by the size of the 00:21:28.480 |
decoded output. So if we had five passages, that would be 15 because we have three queries per 00:21:34.720 |
passage. OK, now we want to write those query passage pairs to file, which we do like that. 00:21:44.560 |
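Condensed, the batched loop looks something like this, reusing the tokenizer, model, device and get_passages generator sketched above (the data/ directory and output file name are assumptions):

```python
import os
from tqdm.auto import tqdm

os.makedirs("data", exist_ok=True)

target = 200_000    # total (query, passage) pairs to create
batch_size = 256    # as large as your hardware allows
num_queries = 3     # queries generated per passage

lines, batch, count = [], [], 0
passages = get_passages()  # reinitialize the generator

with tqdm(total=target) as progress:
    for passage in passages:
        # strip tabs/newlines so the tab-separated file stays clean
        batch.append(passage.replace("\t", " ").replace("\n", " "))
        if len(batch) == batch_size:
            inputs = tokenizer(batch, truncation=True, padding=True,
                               max_length=256, return_tensors="pt").to(device)
            outputs = model.generate(**inputs, max_length=64, do_sample=True,
                                     top_p=0.95, num_return_sequences=num_queries)
            # decode back to human-readable text for the next (different) model
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            for i, query in enumerate(decoded):
                # queries 0,1,2 -> passage 0; queries 3,4,5 -> passage 1; etc.
                passage_idx = int(i / num_queries)
                lines.append(query.replace("\t", " ") + "\t" + batch[passage_idx])
            count += len(decoded)
            progress.update(len(decoded))
            batch = []
        if count >= target:
            break

with open("data/pairs.tsv", "w", encoding="utf-8") as f:
    f.write("\n".join(lines))
```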
OK, so that's the query generation step. There's a lot in there. But from this, we now have our 00:21:50.800 |
query passage pairs. Now, it is worth noting that, like we saw before, query generation is not 00:21:58.640 |
perfect. It can generate noisy, imperfect, or even nonsensical queries. And this is where GPL improves 00:22:10.560 |
upon GenQ. If you know GenQ, we covered it in the last chapter. And GenQ just relies on this query 00:22:17.040 |
generation step. GPL, later on, we have this pseudo-labeling step, which kind of cleans up 00:22:24.160 |
any noisy data that we might have, or at least cleans up to an extent, which is really useful. 00:22:31.280 |
So that's great. We've finished with our query generation step. And now we're ready to really 00:22:38.000 |
move on to the next step of negative mining, which I think is probably one of the most interesting 00:22:44.320 |
steps. So now, if you think about the data we have, we have our queries and the assumed positive 00:22:50.480 |
passages, or the positively paired query passages that we have in the moment. Now, suppose we 00:22:58.400 |
fine-tune our sentence transformer or bi-encoder with just those. Our sentence transformer is just 00:23:04.720 |
going to learn how to put things together. So it's going to learn how to place these queries 00:23:09.440 |
and passages in the same vector space. And yes, OK, that will work to an extent. But the performance 00:23:16.960 |
is not great. And what we need is a way to actually find negative passages. So there was a 00:23:24.640 |
paper on RocketQA. And they had this really cool chart that I liked, 00:23:31.440 |
where you have your performance, or model performance, when they've trained it with 00:23:37.680 |
hard negatives. So negative samples are quite hard for the model to figure out that it's negative, 00:23:43.680 |
because it's very similar to your positive. And we'll explain more about hard negatives later on. 00:23:48.560 |
And where they train the model without any hard negatives, so just the positives. 00:23:52.480 |
And it's very clear that models perform better when there are hard negatives. It's almost like 00:24:00.320 |
you can think of it as making your exams harder. The people that pass the exams 00:24:06.480 |
now will have studied even harder than they would have in the past. Because you've made the exam 00:24:13.280 |
harder. They need to study harder. And as a result of that, they will be better because of that. 00:24:19.680 |
Their knowledge will be better because the exam was harder. And they had to try harder to pass it. 00:24:25.440 |
It's very similar in why we include negatives. It makes the training process harder for our model. 00:24:33.360 |
But at the same time, the model that we output from that is much better. So to actually get 00:24:40.880 |
these negatives, we perform what's called a negative mining step. Now, this negative mining 00:24:46.960 |
step is used to find highly similar passages to our positive passages that are obviously very 00:24:55.200 |
similar but are not the same. So we are assuming that these very similar passages are-- maybe they 00:25:02.400 |
talk about the same topic, but they don't really answer our query. We're assuming something like 00:25:08.560 |
this. We don't actually know, but we're assuming this. So we're performing a semantic search, 00:25:15.920 |
information retrieval step to actually return these possible negatives. 00:25:22.080 |
And then we can use them as negative training examples. So now what we will have is our query, 00:25:28.960 |
positive passage P plus, and assumed negative passage P minus. And the result of this is that 00:25:36.160 |
our model is going to have to learn very nuanced differences between that negative or assumed 00:25:40.960 |
negative and that positive in order to be able to separate them out and understand that, OK, 00:25:46.480 |
the positive is the pair for our query, but this negative, even though it's super similar, is not 00:25:53.840 |
the pair for our query. So it's just giving our model an extra task to perform here. 00:26:01.360 |
Now, with all that in mind, we also need to understand that we're assuming that all the 00:26:06.320 |
passages we're returning are negatives. We don't actually know that they are. And again, 00:26:13.840 |
this is something that we're going to handle later on in the pseudo-labeling step. But for now, 00:26:17.920 |
we just assume that they are, in fact, negative passages to our query. 00:26:24.000 |
So let's move on to the actual implementation of negative mining. 00:26:28.800 |
So the first thing we need to consider for this negative mining step is that there are 00:26:33.040 |
two parts to it. There is an embedding or retrieval model. Again, we're going to use 00:26:38.800 |
something that has been trained on data from pre-COVID times. And what that is going to do 00:26:46.800 |
is taking our passages and our queries and translate them into vectors. OK, but we need 00:26:53.280 |
somewhere to store those vectors. So we need a vector database to do that. So we're going to 00:26:58.560 |
use a vector database to store all of those vectors. And then we're going to perform a 00:27:02.400 |
search through all of those with our queries in order to identify the most similar passages. 00:27:09.600 |
And then we're going to say, OK, is this a positive passage for this query? If not, 00:27:14.800 |
then great. We're going to assume it's a negative passage for our query. We're just 00:27:19.280 |
going to go through and do that. So the first thing we're going to do is load our model. 00:27:24.000 |
So we're using this msmarco-distilbert-base-tas-b model. Again, that has been trained 00:27:31.920 |
on pre-COVID data. And we're loading that into Sentence Transformers and just setting the max sequence 00:27:36.480 |
length to 256 here. OK, and then we're going to initialize a Pinecone index. So this is going to 00:27:45.120 |
be our vector database. We're going to store all of our vectors. So for that, we do need an API 00:27:50.720 |
key. It's all free. You don't need to pay anything for this. All the infrastructure and stuff is 00:27:54.720 |
handled for us, which is great. So we go to app.pinecone.io, make an account if you need to. 00:28:01.520 |
And then you'll get an API key. Now, for me, it's not the best way to do it. It's just easy. 00:28:07.760 |
I'm just storing my API key in a file called secret in the same directory as my code. I'm 00:28:14.160 |
just reading that in here. OK, so my API key. Oh, as well, another thing you need to do here 00:28:22.640 |
is to install that, you need to pip install pinecone-client. OK, not pinecone, pinecone-client. 00:28:30.240 |
So we initialize pinecone with our API key. This is just a default environment. There are other 00:28:38.720 |
environments available if you want to use them, but I don't think you really need to. 00:28:42.880 |
And then we create the index. OK, so we have a negative mining index. So here, I'm just saying 00:28:49.600 |
you can check your currently running indexes with pinecone.list_indexes(). If negative-mine is not in 00:28:58.560 |
there, then I'm going to create the negative-mine index. And there's a few things you need to pass, so 00:29:02.880 |
dimension. The dimension is what we have here, so the embedding dimension that our 00:29:09.040 |
embedding model is outputting. So you specify that, 768. The metric, this is important. We 00:29:16.240 |
need to be using dot product metric here for GPL. And also, the number of pods. So by default, 00:29:24.160 |
this is one. What you can do is increase this. So I think I increased it to 70, which is probably 00:29:31.600 |
like massively overkill for this. But that shortened it. So with one pod, I had a runtime, 00:29:39.200 |
I think, of one hour 30 or one hour 40. With 70, which again, is probably overkill-- maybe 40 00:29:46.240 |
would do the same, I don't know-- it was, I think, 40 minutes. So a lot faster, 00:29:54.400 |
but you do have to pay for anything that's more than one pod. So if you're in a rush, 00:30:00.000 |
then fair enough, fine, go with that. Otherwise, you just stick with one pod. It's free. 00:30:07.680 |
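As a sketch of that setup -- note it follows the older pinecone-client interface described here (newer Pinecone SDKs expose a different API), and the environment name, the 'secret' key file and the 'negative-mine' index name are assumptions:

```python
import pinecone
from sentence_transformers import SentenceTransformer

# bi-encoder trained on pre-COVID MS MARCO data, used here only for retrieval
model = SentenceTransformer("msmarco-distilbert-base-tas-b")
model.max_seq_length = 256

with open("secret", "r") as f:  # API key kept in a local file called 'secret'
    api_key = f.read().strip()

pinecone.init(api_key=api_key, environment="us-west1-gcp")  # assumed default environment

if "negative-mine" not in pinecone.list_indexes():
    pinecone.create_index(
        "negative-mine",
        dimension=model.get_sentence_embedding_dimension(),  # 768
        metric="dotproduct",  # GPL relies on dot-product similarity
        pods=1,               # more pods = faster, but only one pod is free
    )
index = pinecone.Index("negative-mine")
```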
Okay. And then you connect. So this creates your index, and then you connect to your index. 00:30:15.040 |
So pinecone.Index, negative-mine. Okay. And then we want to go through, we're going to use 00:30:21.280 |
our model. Actually, this bit here, we're creating our file reading generator. So it's just going to 00:30:27.120 |
yield the query and passage pairs. I include this in here because I got a value error because there's 00:30:33.120 |
some weird, literally one row of weird data in there. So I just added that in there to skip that 00:30:40.400 |
one. And yeah, initializing that generator. And then here, I'm going to go through and actually 00:30:47.200 |
encode the passages. So you see, we're creating a passage batch here. We're doing it in batches 00:30:54.880 |
again, make it faster. So I'm adding my passages to the passage batch, also adding IDs to the ID 00:31:03.520 |
batch. And then once we reach the batch size, which is 64 in this example, we encode our passage 00:31:11.600 |
batch. We also convert them to a list because we need it to be a list when we're upserting 00:31:16.640 |
everything. So upsert just means update or insert. It's just like database lingo. We need 00:31:24.480 |
it to be a list when we're pushing this because it's pushing it through an API call here, so 00:31:29.280 |
JSON request. So we need to convert that to a list rather than NumPy array, which is, I think, 00:31:34.640 |
the default. And then we just create a list of IDs and vectors and upsert that to Pinecone, our index. 00:31:45.600 |
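A minimal version of that encode-and-upsert loop might look like this (the pairs file name matches the query-generation sketch above; the id_to_passage lookup and the passage dedupe are my own additions so that each unique passage is indexed once and can be mapped back later):

```python
def get_pairs(path="data/pairs.tsv"):
    """Yield (query, passage) tuples from the TSV, skipping malformed rows."""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                query, passage = line.rstrip("\n").split("\t")
            except ValueError:
                continue  # skip the odd malformed row
            yield query, passage

batch_size = 64
passage_batch, id_batch = [], []
id_to_passage = {}  # map vector IDs back to passage text for the mining step
seen = set()

for i, (query, passage) in enumerate(get_pairs()):
    if passage in seen:
        continue  # each passage appears three times (once per query); index it once
    seen.add(passage)
    _id = str(i)
    id_to_passage[_id] = passage
    passage_batch.append(passage)
    id_batch.append(_id)
    if len(passage_batch) == batch_size:
        # encode to vectors; Pinecone expects plain lists, not NumPy arrays
        embeds = model.encode(passage_batch).tolist()
        index.upsert(vectors=list(zip(id_batch, embeds)))
        passage_batch, id_batch = [], []

print(index.describe_index_stats())
```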
And then we just refresh those batches. So we start again, and then we do 64 and then do the 00:31:51.280 |
encode and upsert again. And then at the end here, I just wanted to check the number of vectors we 00:31:56.480 |
have in the index. So we see the dimensionality of the index, as well as the index 00:32:02.960 |
fullness. This will tell you pretty much how quickly it's going to run. At zero, it's perfect. It means 00:32:07.760 |
it's basically empty. It'll run pretty quick. And then you have the vector count here. So remember, 00:32:15.280 |
we have 200,000 examples or pairs, but not all of them are being used. So they are all being used, 00:32:25.840 |
but not all of them are unique because for each passage that we have, we have three queries. 00:32:30.480 |
So three unique queries, but three duplicated passages. So obviously, out of those 200,000 00:32:39.520 |
passages that we have, we need to divide that by three to get the actual number of unique passages 00:32:44.720 |
that we have in there, which is something like this 76840. So the database is now set up for 00:32:54.480 |
us to begin negative mining. It's full of all of our passages. So what we're going to do is loop 00:32:59.520 |
through each of the queries in pairs, which we created here. So we're going to loop through all 00:33:05.840 |
of those here in batches of batch size again, so 100. And we're going to initialize this triplets 00:33:13.840 |
list where we're going to store our query, positive and negative. At the moment, 00:33:18.560 |
we just have query positive, remember? So we're going to go through there, get our queries, 00:33:24.000 |
get our positives, and we're going to create the query embeddings. So that's in batches again. 00:33:32.560 |
Then we search for the top 10 most similar matches to our query, and then we loop through all of 00:33:40.640 |
those. So for query, positive passage, and query response, this query response will actually have 00:33:48.880 |
the 10 possible or the 10 highest-similarity passages for that query inside there. So extract those, 00:33:56.320 |
remember, 10 in there. And I'm just going to shuffle those. So I'm going to shuffle them, 00:34:00.560 |
and then we're going to loop through them. Now, we do this so that we're not just returning the 00:34:04.800 |
most similar one all the time, but we're returning one of the most similar top 10 instead. And then 00:34:12.160 |
we extract the negative passage from that one record, that hit. And then one thing we really need to 00:34:20.080 |
consider here is, OK, if we've got all of our passages in there, we're also going to have the 00:34:24.320 |
positive passage for our query as well in there. So it's pretty likely that we're going to return 00:34:30.960 |
that, at least for a few of our queries for sure. So we need to check that negative passage we're 00:34:37.280 |
looking at does not match to the positive passage for that query. And then if not, 00:34:43.600 |
then that means it's a negative passage or assumed negative passage, and we can append 00:34:48.240 |
it to our triplets. So we have query, tab, positive, tab, negative. And then we save that to file. 00:34:55.120 |
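Pulled together, the mining loop looks roughly like this, reusing model, index, get_pairs and id_to_passage from the sketches above (the batched query call and response format follow the older pinecone-client, and the triplets file name is an assumption):

```python
import random
from tqdm.auto import tqdm

pairs = list(get_pairs())  # (query, positive_passage) tuples
batch_size = 100
triplets = []

for i in tqdm(range(0, len(pairs), batch_size)):
    batch = pairs[i:i + batch_size]
    queries = [q for q, _ in batch]
    # embed the queries and fetch the 10 most similar passages for each one
    query_embeds = model.encode(queries).tolist()
    res = index.query(queries=query_embeds, top_k=10)
    for (query, positive), result in zip(batch, res["results"]):
        hits = result["matches"]
        random.shuffle(hits)  # don't always take the single most similar hit
        for hit in hits:
            negative = id_to_passage[hit["id"]]
            if negative != positive:
                # assumed negative: similar to the query, but not its positive
                triplets.append("\t".join([query, positive, negative]))
                break

with open("data/triplets.tsv", "w", encoding="utf-8") as f:
    f.write("\n".join(triplets))
```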
OK? Now, one last thing. Before we move on, we should delete our index. So if you're on the 00:35:04.640 |
free tier, you're just using that one pod, it's fine. You're not going to pay anything anyway. 00:35:08.800 |
But I guess it's also good practice to remove it after you're done with it. But if you are paying, 00:35:13.520 |
of course, you want to do this so you're not spending any more money than you need to. So 00:35:17.680 |
that's the negative mining step. Now we can move on to what is the final data preparation step, 00:35:24.400 |
which is the pseudo-labeling. Now, pseudo-labeling is essentially where we use a cross-encoder model 00:35:31.760 |
to kind of clean up the data from the previous two steps. So what this cross-encoder is going 00:35:39.920 |
to do is generate similarity scores for both the query positive and the query negative pairs. OK? 00:35:47.040 |
So we pass both those into cross-encoder model, and it will output predicted similarity scores 00:35:53.440 |
for the pairs. So first thing you need to do is actually initialize a cross-encoder model. Again, 00:35:59.360 |
this should have been trained on pre-COVID data. And then we're going to use a generator function 00:36:06.880 |
as we have been doing throughout this entire thing, just to read the data or the pairs, 00:36:11.840 |
triplets in this case, from file, and then yield them. And what we're going to do is use that 00:36:17.920 |
cross-encoder, using this function here, to calculate both the similarity of the positive 00:36:24.800 |
and the negative. And then we subtract those to get the margin between them, so the separation 00:36:29.680 |
between them, which is ideal for when we're performing margin MSE loss, which we will be 00:36:35.920 |
very soon for fine-tuning our model. So we go through using our generator here. We get a line, 00:36:43.840 |
a time, a query, positive, negative. We get a positive score, negative score, and then we 00:36:49.440 |
calculate the margin between them. And then we're going to append those to label lines and save it 00:36:55.120 |
all to file. So it's a pretty quick step, actually, for the pseudo-labeling step. 00:37:03.920 |
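As a sketch (the cross-encoder checkpoint and file names are assumptions; any MS MARCO cross-encoder trained on pre-COVID data matches the description):

```python
from sentence_transformers import CrossEncoder
from tqdm.auto import tqdm

ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # assumed pre-COVID cross-encoder

def get_triplets(path="data/triplets.tsv"):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            query, positive, negative = line.rstrip("\n").split("\t")
            yield query, positive, negative

label_lines = []
for query, positive, negative in tqdm(get_triplets()):
    # cross-encoder similarity for the (query, positive) and (query, negative) pairs
    p_score, n_score = ce.predict([(query, positive), (query, negative)])
    margin = float(p_score - n_score)  # the label we train against with margin MSE loss
    label_lines.append("\t".join([query, positive, negative, str(margin)]))

with open("data/triplets_margin.tsv", "w", encoding="utf-8") as f:
    f.write("\n".join(label_lines))
```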
But this final step is also very important in ensuring that we have high-quality training data. Without it, we would 00:37:12.400 |
need to assume that all passages returned in the negative mining step are actually negatives, 00:37:18.080 |
and they're irrelevant for our particular query. In reality, this is obviously never the case, 00:37:25.200 |
because some negative passages are obviously going to be more negative or less negative 00:37:30.960 |
than others. And maybe some of them are not even negative in the first place. 00:37:37.040 |
So there was this chart or example from the GPL paper that I really liked. I've just adapted it 00:37:43.840 |
for our particular use case. So the query is, what are the symptoms of COVID-19? The positive, 00:37:50.960 |
so the first row there, is COVID-19 symptoms include fever, coughing, loss of sense of smell 00:37:56.160 |
or taste. So that positive passage is the passage that we started with. We generated that query in 00:38:03.360 |
the generative or query generation stage. And then we retrieved these negatives in the negative mining 00:38:10.000 |
stage. Now, just have a look at each one of these. So the first one, fever, coughing, and a loss of 00:38:16.160 |
sense of smell are common COVID symptoms. That is not actually a negative. It's a positive. But this 00:38:22.320 |
is very likely to happen in the negative mining stage, because we're returning the most similar 00:38:26.800 |
other passages. And those other very similar passages could very easily be actual genuine 00:38:33.840 |
positives that are just not marked as being the pair for our query. We don't actually know that 00:38:42.080 |
they're negatives. So in this case, this is a false negative, because it should be a positive. 00:38:47.040 |
Now, the next one, we have these easy negatives. The next one, symptoms are physical or mental 00:38:54.080 |
features that indicate a condition of disease. Now, this is a pretty typical easy negative, 00:39:00.240 |
because it has some sort of crossover. So the question is, what are the symptoms of COVID-19? 00:39:06.400 |
And this is just defining what a symptom is. But it's not about COVID-19 or the symptoms of COVID-19. 00:39:12.640 |
This is an easy negative, because it mentions one of those keywords. But other than that, 00:39:19.040 |
it's not really similar at all. So it will be quite easy for our model to separate that. 00:39:23.680 |
So it will be easy for it to look at this easy negative and say, OK, it's not relevant for our 00:39:28.160 |
query. It's very different to the positive that I have here. That's why it's an easy negative. 00:39:34.160 |
Next one, another easy negative is COVID-19 is believed to have spread from animals to 00:39:39.600 |
humans in late 2019. Again, it's talking about COVID-19, but it's not the symptoms of COVID-19. 00:39:45.680 |
So maybe it's slightly harder for our model to separate this one. But it's still a super easy 00:39:50.560 |
negative. And then we have a final one. So this is a hard negative. Coughs are a symptom of many 00:39:56.000 |
illnesses, including flu, COVID-19, and asthma. In this case, it's actually a partial answer, 00:40:03.280 |
because it tells us, OK, cough is one of the symptoms. But it's still not an answer. We're 00:40:08.240 |
asking about the symptoms. We want to know about multiple symptoms. This kind of partially 00:40:13.840 |
answers it, but it doesn't really. And what I really like about this is we look at the scores on the 00:40:19.600 |
right here. We have the scores when you're using something like pseudo-labeling from GPL, 00:40:24.320 |
and scores when you're not using pseudo-labeling, like if you're using GenQ. Now, GenQ, even so, 00:40:31.360 |
like for the false negative, GenQ is seeing this as a full-on negative. So you're just going to 00:40:36.800 |
confuse your model if you're doing this. If it has two things that are talking about the same thing, 00:40:42.320 |
and then you're telling your model, actually, one of them is not relevant, your model is just like, 00:40:47.200 |
you're really going to damage the performance of your model from doing that. And it's the same 00:40:51.760 |
with the other ones as well. There's almost like a sliding scale of relevance here. It's not just 00:40:58.000 |
black or white. There's all these shades of gray in the middle. And when using GPL and pseudo-labeling, 00:41:04.800 |
we can fill in those shades of gray, and we can see that there is a sliding scale of relevance. 00:41:11.120 |
Without pseudo-labeling, we really miss that. So it's a very simple step, but really important 00:41:18.480 |
for the performance of our model. Okay, so we now have the fully prepared data, and we can actually 00:41:25.840 |
move on to actually fine-tuning our sentence transformer using margin MSE loss. Now, this 00:41:34.800 |
fine-tuning portion is not anything new or unique. It's actually very common. Margin MSE loss is used 00:41:44.560 |
for a lot of sentence transformer models. So looking back at the generated data, we have the 00:41:50.400 |
format of Q, positive passage, negative passage, and margin. Let's have a look at how those fit 00:41:58.640 |
into the margin MSE loss function. So we have the query, positive, negative. We pass those through. 00:42:07.360 |
We get the similarity here. So this is our similarity of the query and the positive. 00:42:15.200 |
Okay, and this bi-encoder here is the one we're training, by the way. So this bi-encoder is 00:42:22.240 |
the final model that we're fine-tuning, that we're trying to adapt to our new domain. 00:42:28.320 |
And then we also calculate using the query and negative vectors here. We calculate the similarity 00:42:35.360 |
between those to get the similarity of Q and negative. We subtract those to get this here, 00:42:45.040 |
which is delta hat, our predicted margin. With our predicted margin, we can compare that against 00:42:53.280 |
our true margin, or we assume it's a true margin from our cross-encoder. And we feed those into 00:43:00.640 |
the margin MSE loss function over here. So we have the predicted minus the true for each sample. 00:43:10.560 |
We square all of that to get the squared error. And then we're going over for all of our samples, 00:43:19.440 |
taking the average. And that is our margin MSE loss. 00:43:27.840 |
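Written out, with delta-hat the bi-encoder's predicted margin and delta the cross-encoder's label, the loss over a batch of N samples is:

```latex
\hat{\delta}_i = \mathrm{sim}(Q_i, P_i^{+}) - \mathrm{sim}(Q_i, P_i^{-}), \qquad
\mathcal{L}_{\text{MarginMSE}} = \frac{1}{N} \sum_{i=1}^{N} \left( \hat{\delta}_i - \delta_i \right)^2
```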
Now, what is really nice is that we can actually use the default training methods in the Sentence Transformers library for this. 00:43:34.880 |
Because margin MSE loss is just a standard loss. There's nothing new here at all. 00:43:40.960 |
OK. So from Sentence Transformers, I'm going to import InputExample, which is just a 00:43:46.240 |
data format we always use in our training here. So we're opening up the triplets margin file here. 00:43:54.080 |
I'm just reading all the lines here. I'm not going to use a generator. This time, 00:43:58.720 |
I'm going to just create our training data, which is just a list of these input examples. 00:44:03.440 |
We pass the query positive and negative as a text in here, so as our triplet. 00:44:10.480 |
And then the label, we take the float of margin, because margin is a string from our TSV. 00:44:17.520 |
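That loading step, roughly (the triplets_margin.tsv name matches the pseudo-labeling sketch above and is an assumption):

```python
from sentence_transformers import InputExample

training_data = []
with open("data/triplets_margin.tsv", "r", encoding="utf-8") as f:
    for line in f:
        query, positive, negative, margin = line.rstrip("\n").split("\t")
        training_data.append(InputExample(
            texts=[query, positive, negative],  # the (Q, P+, P-) triplet
            label=float(margin),                # cross-encoder margin as the target
        ))
print(len(training_data))  # roughly 200,000 examples
```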
And from this, we get these 200,000 training examples. And we can load these pairs into a 00:44:24.400 |
data loader here. So we're just using a normal torch DataLoader, nothing special here. 00:44:31.920 |
We use empty cache here just to clear our GPU in case we have anything on there. So if you keep 00:44:38.400 |
running this script again and again, trying to get it working or just modifying different 00:44:44.000 |
parameters to see how it goes, you'll want to include this in here. The batch size. Batch size 00:44:50.320 |
is important with margin MSE loss. The larger you can get, the better. I did see, I think, 00:44:56.000 |
in the sentence transformers documentation, it used something like 64. And that's hard. You're 00:45:02.000 |
going to need a pretty good GPU for that. And to be fair, even 32, you need a good GPU. This is using 00:45:09.040 |
one Tesla V100 for this. And it works. It's fine. So as high as you can get, 00:45:17.680 |
32, I think, is good. Even better if you can get 64. 00:45:23.200 |
And, yeah, we just initialize data loader. We set the batch size and we shuffle to true. 00:45:27.600 |
And so the next step we want to do here is initialize the bi-encoder or sentence 00:45:35.760 |
transformer model that we're going to be fine tuning using domain adaption. So this is the 00:45:42.320 |
same one we used earlier in the retrieval step. We're actually going to be taking that and fine 00:45:47.280 |
tuning it. So we set the model max sequence length here, again, as we did before. And then 00:45:56.240 |
we initialize the margin MSE loss function as we would with normal sentence transformers. 00:46:01.280 |
And then we just run like model fit. We train for one epoch or actually, it depends. You can 00:46:08.800 |
get better performance if you train for more. I'll show you. And I just do 10% warmup steps here, 00:46:14.640 |
like we usually do. And I'm going to save it. Same model name, but I've added COVID onto the end 00:46:19.600 |
there. And this doesn't take long. I think this was-- maybe it was 40 minutes per epoch on the 00:46:26.960 |
Tesla V100. So it's not bad. 00:46:35.280 |
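Put together, the fine-tuning portion is just standard Sentence Transformers training; a sketch under the same assumptions as above (the output path is hypothetical):

```python
import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses

torch.cuda.empty_cache()  # clear anything left on the GPU from earlier runs

batch_size = 32  # as large as your GPU allows; 64 is better if it fits
loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)

# the same pre-COVID bi-encoder we used for negative mining, now being adapted
model = SentenceTransformer("msmarco-distilbert-base-tas-b")
model.max_seq_length = 256

loss = losses.MarginMSELoss(model)

epochs = 1
warmup_steps = int(len(loader) * epochs * 0.1)  # 10% of total steps as warm-up

model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path="msmarco-distilbert-base-tas-b-covid",
    show_progress_bar=True,
)
```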
Now, in terms of the number of training steps you want to actually put your model through with GPL: in the paper, there's a really nice chart 00:46:43.360 |
that shows that after around 10,000 steps, for that particular use case, it leveled off. The 00:46:49.520 |
performance didn't really improve that much more from there. However, I did find-- so I'm using 00:46:55.600 |
200,000 examples. So one epoch for me is already 200,000 training steps, which is more. And I found 00:47:04.960 |
that actually training for 10 epochs-- I didn't test in between. I just went from one. And I also 00:47:11.040 |
tested 10. I found that 10, actually, the performance seemed to be better. I didn't really 00:47:17.600 |
quantitatively assess this. I just looked at it qualitatively. And it was performing or is 00:47:23.280 |
returning better results than the model that had been trained for one epoch. So for rather than 00:47:29.200 |
200,000 steps, it had been trained for 2 million steps, which is a fair bit more than 100,000 00:47:35.520 |
mentioned here. But that's just what I found. I don't know why. I suppose it will depend very much 00:47:43.920 |
on the data set you're using in particular. But I would definitely recommend you just try and test 00:47:50.480 |
a few different models. The training doesn't take too long for this step. So you have the luxury of 00:47:56.560 |
actually being able to do that and testing these different numbers of training steps. 00:48:02.160 |
So I mean, once training is complete, we can actually test that model as we usually would. 00:48:07.920 |
We have the model as well. So I can show you. So we have the model over on-- if we go to 00:48:14.720 |
huggingface.co/models, maybe. So if you search in here, you type Pinecone. And one of these is 00:48:23.680 |
the model that has been trained on COVID. So hit this one here. So you can actually copy this. And 00:48:30.240 |
you can just write-- I think it is-- it should maybe have an example. Ah, so you write this. So 00:48:36.720 |
you from Sentence Transformers, import Sentence Transformer. And you just do model equals Sentence 00:48:42.960 |
Transformer in your model name, which in this case is what you can see up here, Pinecone MS Marco. 00:48:48.240 |
And you can test the model that I've actually trained here. Again, this was trained on 10 epochs, 00:48:53.680 |
not one. OK, so in this evaluation here, you can see what I was looking at. So load the old model. 00:49:00.400 |
It's just the msmarco-distilbert-base-tas-b model. And then here, I've got the model trained for one 00:49:07.440 |
epoch. And here, the model trained for 10 epochs, which is equivalent to the model I just showed 00:49:12.960 |
you on Hugging Face Hub. So I've taken these tests. So we have these queries and three answers 00:49:22.400 |
or three passages for each one of those. Now, these are not even really perfect matches to the 00:49:28.240 |
sort of data we were training on. But they're kind of similar in that they just include COVID-19. 00:49:34.240 |
OK, so it's not a perfect match. It's not even perfectly in the domain. But it's good enough, 00:49:40.240 |
I think. But I think that really shows how useful this technique is because it does work even with 00:49:46.640 |
this kind of not perfect match between what I'm testing on and what I'm actually training on. 00:49:52.080 |
So here, I'm just going through those. I'm getting the dot product score between those 00:49:56.560 |
and sorting based on that. So the highest rated passage for each query is returned. 00:50:02.400 |
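A stripped-down version of that comparison (the Hub model name and the example passages here are illustrative assumptions, not the exact queries and answers used in the video):

```python
from sentence_transformers import SentenceTransformer, util

# assumed Hub name for the 10-epoch domain-adapted model
new_model = SentenceTransformer("pinecone/msmarco-distilbert-base-tas-b-covid")

query = "How is COVID-19 transmitted?"
passages = [
    "Ebola is transmitted through direct contact with the body fluids of infected people.",
    "COVID-19 spreads mainly through respiratory droplets produced when an infected person coughs.",
    "Influenza viruses circulate worldwide and cause seasonal epidemics.",
]

q_emb = new_model.encode(query)
p_embs = new_model.encode(passages)

# rank passages by dot-product score, highest first
scores = util.dot_score(q_emb, p_embs)[0]
for score, passage in sorted(zip(scores.tolist(), passages), reverse=True):
    print(f"{score:.2f}  {passage}")
```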
So for the old model, how is COVID-19 transmitted? We get Ebola, HIV. And then at the very bottom, 00:50:09.680 |
we have Corona, OK? Again, what is the latest named variant of Corona? And then it's these here. 00:50:17.920 |
Again, COVID-19 is right at the bottom. We have Corona lager right at the top. 00:50:23.440 |
I don't know if the rest of the world calls it lager or beer, but I put lager there anyway. 00:50:30.400 |
And then we say, OK, what are the symptoms of COVID-19? Then we get flu. Corona comes second 00:50:34.880 |
here. And then we have the symptom definition again. And then how will most people identify 00:50:40.160 |
that they have contracted the coronavirus? And then after drinking too many bottles of Corona 00:50:44.480 |
beer, most people are hungover, right? So obviously, it's not working very well. 00:50:48.160 |
And then we do the GPL model. So this is the training on one epoch. We see slightly better 00:50:53.200 |
results. So Corona is number one here. And then here, it's number two. So it's not great, but fine. 00:51:02.960 |
What are the symptoms of COVID-19? This one, I guess, right? So that's good. And then here, 00:51:10.880 |
how will most people identify that they have contracted the coronavirus? Again, it's doing 00:51:14.400 |
the drinking too many bottles of Corona. Again, then you have COVID-19 second place. 00:51:20.480 |
Then we have the model that's trained for 10 epochs. This one, every time. So number one, 00:51:25.360 |
number one, and number one. So it's pretty cool to see that. And yeah, I think it's super 00:51:35.120 |
interesting that this technique can actually do that using nothing more than just unstructured 00:51:41.840 |
text data. So that's it for this video on generative pseudo-labeling, or GPL. 00:51:48.480 |
Now, I think it's already very impressive what it can do. And more importantly, I think, 00:51:54.080 |
or what is more interesting for me is where this technique will go in the future. Is this going to 00:52:00.000 |
become a new standard in training sentence transformers, where you're building the 00:52:04.080 |
synthetic data? I feel like there's really a lot that we covered in this video and from the GPL 00:52:11.520 |
paper to unpack and think about. We have negative mining. We have pseudo-labeling. We have all 00:52:16.640 |
these different techniques that go together and produce this really cool, in my opinion, technique 00:52:22.960 |
that we can use. Super powerful. For me, it's just super impressive that you can actually do this. 00:52:30.800 |
So I hope there is more in this field in GPL or some new form of GPL in the future. 00:52:39.120 |
I'm very excited to see that. But for now, of course, that's it. I hope this video has been 00:52:45.760 |
useful. I hope you're as excited about this as I am. So thank you very much for watching, 00:52:50.800 |
and I will see you again in the next one. Bye.