Fine-tune High Performance Sentence Transformers (with Multiple Negatives Ranking)
Chapters
0:00 Intro
1:02 NLI Training Data
2:56 Preprocessing
10:11 SBERT Finetuning Visuals
14:14 MNR Loss Visual
16:37 MNR in PyTorch
23:04 MNR in Sentence Transformers
34:20 Results
36:14 Outro
00:00:03.560 |
at something called a multiple negatives ranking 00:00:07.000 |
loss, or MNR loss, for training sentence transformers. 00:00:18.440 |
transformers that have been trained specifically 00:00:20.480 |
for building dense vector representations of text. 00:00:26.040 |
And the highest performing sentence transformers 00:00:30.040 |
at the moment are all trained using this method. 00:00:32.600 |
So they're trained on a natural language inference data set, 00:00:40.640 |
So in this video, we are going to learn about MNR loss, 00:00:46.120 |
We're going to work through a PyTorch example very quickly, 00:00:48.720 |
and then we're going to look at how we can implement it 00:00:51.080 |
ourselves using the sentence transformers library. 00:00:56.240 |
So if you saw our previous article and video, 00:01:07.440 |
you'll remember that we had an NLI data set, which 00:01:12.080 |
is built from the Stanford NLI data set and the multi-genre NLI 00:01:20.280 |
So in essence, all it is is a load of sentence pairs, 00:01:28.680 |
indicates whether those sentence pairs either infer each other, 00:01:34.240 |
are not really related, like they could both be true, 00:01:45.640 |
And what we covered in the previous video and article 00:01:51.080 |
With softmax loss, we use those labels, the 0, 1, or 2, 00:02:04.760 |
Now, with MNR loss, we don't actually use those labels. 00:02:14.160 |
where we use something called a Siamese network, which 00:02:17.680 |
means we only have an anchor and a positive sentence. 00:02:42.320 |
And a negative would be a sentence that, given the anchor, is not true. 00:02:42.320 |
And we can extract that from the NLI data set. 00:02:46.840 |
So if you've never seen any of this before, it's fine. 00:03:18.280 |
We have this 0, entailment, 1, neutral, and 2, contradiction. 00:03:30.040 |
e.g. a pair where the premise suggests the hypothesis, 00:03:34.800 |
so where the anchor implies the positive, or the other way around. 00:03:53.440 |
So we're using the Hugging Face datasets library here, 00:03:59.000 |
And we're getting the Stanford Natural Language Inference 00:04:10.840 |
So we have these three features, premise, hypothesis, 00:04:15.160 |
And if we come down here, we can see what one of those 00:04:35.440 |
does not necessarily mean the person on a horse 00:04:53.200 |
and a person on a horse jumps over a broken-down airplane. 00:04:55.880 |
If they're jumping over a broken-down airplane, 00:04:58.680 |
In fact, they're-- well, almost definitely outside. 00:05:06.560 |
So this is an entailment, which means it's an anchor-positive pair. 00:05:11.280 |
So I said before, we have two data sets, not just one. 00:05:16.320 |
So we have the Stanford Natural Language Inference Dataset, 00:05:20.360 |
and we also have the Multi-Genre Natural Language 00:05:26.800 |
So I'm just loading that with load_dataset here. 00:05:33.480 |
Within GLUE, there is this MNLI subset, which is the data that we want. 00:05:37.440 |
And we're just getting the training data from that. 00:05:42.040 |
maybe there's validation data, and there's also test data. 00:05:59.000 |
So what we can do is we just remove those columns. 00:06:07.640 |
because we're going to merge these two data sets. 00:06:13.020 |
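Roughly, the loading and column-removal steps look something like this (a sketch based on the description above; the exact split names and arguments in the notebook may differ):

```python
from datasets import load_dataset

# Stanford NLI: features are premise, hypothesis, label
snli = load_dataset('snli', split='train')

# Multi-Genre NLI is distributed as part of the GLUE benchmark
mnli = load_dataset('glue', 'mnli', split='train')

# GLUE's MNLI carries an extra 'idx' column that SNLI doesn't have,
# so we drop it before merging the two datasets
mnli = mnli.remove_columns(['idx'])
```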
So after that, what we do is we perform this Cast function. 00:06:20.660 |
Now, if we didn't perform this Cast function-- 00:06:24.620 |
So I need to run these anyway, so start running them now. 00:06:34.700 |
And OK, let's say I'm not going to do this Cast. 00:06:40.060 |
I'm going to merge those two data sets together. 00:06:43.500 |
I do that, and we get this ArrowInvalid error. 00:06:47.440 |
And the reason for that is because although those two 00:06:50.220 |
data sets look very similar, they're not actually 00:07:01.700 |
whereas the other one does not include that NotNull 00:07:05.180 |
So I assume that means that this other data set 00:07:09.020 |
can include null values, whereas this one cannot. 00:07:12.020 |
So we can't merge them both because they're not 00:07:17.180 |
So what we do, we come up here, and we have to do this Cast. 00:07:20.340 |
So we're casting the features of the SNLI dataset 00:07:29.900 |
We come down here, we can run it again, and it will work. 00:07:35.740 |
And now in Dataset, we have the full data sets. 00:07:44.660 |
And you can see that because we have 943,000 rows. 00:07:49.940 |
If you come up here, we only have 392,000 in the MNLI dataset. 00:07:54.260 |
And up here, we have 550,000 in the SNLI dataset. 00:08:03.780 |
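So the cast-and-merge step looks roughly like this (again a sketch, not necessarily the notebook's exact code):

```python
from datasets import concatenate_datasets

# The two datasets have slightly different feature schemas, so concatenating
# them directly raises an Arrow error; casting SNLI's features to match
# MNLI's resolves that
snli = snli.cast(mnli.features)

# Merging gives roughly 550K + 392K, about 943K rows
dataset = concatenate_datasets([snli, mnli])
```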
Now, in the previous video and article, I mentioned 00:08:18.620 |
Or it's not an error, but it's where whoever labeled the data, 00:08:26.980 |
they couldn't decide on the nature of the relationship 00:08:34.500 |
I think it's 700 or so sentences in there or pairs in there 00:08:48.060 |
Now, what we do is we use this filter method to remove. 00:08:58.500 |
equal to minus 1, which is saying that row is false, 00:09:02.020 |
e.g. we do not keep it if its label value is equal to minus 1. 00:09:12.100 |
because we actually only want anchor positive pairs, which 00:09:19.660 |
or we only want to keep the rows which have a 0 label. 00:09:27.220 |
So we need to modify this to just keep the 0 values. 00:09:35.060 |
So we'll say, OK, false if x label is not equal to 0. 00:09:52.980 |
and also removes the neutral and contradiction rows as well. 00:10:11.900 |
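In code, that filter is a one-liner (a sketch):

```python
# Keep only entailment pairs (label 0); this also drops the -1 rows where
# annotators couldn't agree, plus the neutral (1) and contradiction (2) rows
dataset = dataset.filter(lambda x: x['label'] == 0)
```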
So what we have here is we have that anchor positive pair 00:10:19.420 |
So here, our anchor would be the premise. 00:10:44.980 |
And as we usually would with a transform model, 00:10:52.740 |
We're going to do that with just a tokenizer method, 00:10:55.700 |
a pre-trained tokenizer from the base transformers library. 00:11:04.780 |
If we're using the sentence transformers library, 00:11:07.700 |
it's a lot easier, and we don't actually need to do that. 00:11:10.060 |
It will deal with that for us, which is quite nice. 00:11:13.060 |
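For the plain-transformers route, the tokenization looks something like this (a sketch; the max length and padding settings are assumptions):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# One batch of anchor (premise) and positive (hypothesis) sentences
batch = dataset[:32]
anchor = tokenizer(batch['premise'], max_length=128, padding='max_length',
                   truncation=True, return_tensors='pt')
positive = tokenizer(batch['hypothesis'], max_length=128, padding='max_length',
                     truncation=True, return_tensors='pt')
# Each of these is a dict containing 'input_ids' and 'attention_mask' tensors
```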
And what that produces is just a tokenized version 00:11:23.700 |
so I'm just going to put A for anchor, and over here, 00:11:29.140 |
And then what happens next is we have a single BERT model. 00:11:42.020 |
And then after we've processed the anchor data, 00:11:49.060 |
So it's like we're using-- for every single training step, 00:11:55.380 |
So we process both of those through our BERT model, 00:12:10.060 |
And then we use something called a mean pooling function. 00:12:12.940 |
So mean pooling function is, say we have some vectors here-- 00:12:23.300 |
What we're going to do is take the mean value 00:12:32.180 |
So let's say we have three dimensions here in our-- 00:12:36.260 |
no, that's a bad idea, because we have three vectors. 00:12:43.780 |
And we take the average across each of those dimensions. 00:12:46.660 |
And what we produce from that mean pooling operation 00:13:13.700 |
And we produce that sentence embedding both for A, our anchor, 00:13:30.860 |
so this is a Siamese network, this double network. 00:13:35.860 |
And the triple version of it is a triplet network. 00:13:44.660 |
So where before we had the contradiction label, 00:14:08.060 |
And then we would also get a negative vector-- 00:14:11.900 |
negative sentence embedding at the end there. 00:14:17.300 |
so the whole MNR loss thing, multiple negative ranking 00:14:23.500 |
is we take all of those vectors that we produce, 00:14:27.540 |
the A, which is the anchors, the P, and the N 00:14:32.260 |
for the positive and negatives, if we're using negatives. 00:14:36.260 |
If not, we just basically blank out the dark blue part 00:14:43.340 |
All we do is we calculate the cosine similarity 00:14:55.540 |
to go through the anchor and positive version of it for now. 00:15:13.180 |
between anchor 0 and positive 0, 1, 2, 3, and 4, and so on, 00:15:24.460 |
so the number of sentence embeddings that we have there 00:15:29.020 |
is equal to the batch size that we are using. 00:15:36.660 |
is that we would expect the cosine similarity between A0 00:15:40.780 |
and P0 to be greater than the cosine similarity between, 00:15:54.500 |
expect the cosine similarity between A3 and P3 00:15:58.420 |
to be greater than that between A3 and P0, or P1, or P2, or P4. 00:16:14.900 |
is just going to be the greatest value, i.e. the argmax. 00:16:28.540 |
So the labels for this are actually just 0, 1, 2, 3, 4, 00:16:38.620 |
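As a tiny toy illustration of that label trick (not the notebook's code):

```python
import torch

# Pretend a and p are the pooled anchor and positive embeddings for one batch
a = torch.randn(4, 768)
p = torch.randn(4, 768)

# (batch, batch) matrix: similarity of every anchor against every positive
scores = torch.nn.functional.cosine_similarity(
    a.unsqueeze(1), p.unsqueeze(0), dim=-1
)

# The target for anchor i is column i, so the labels are just 0..batch_size-1
labels = torch.arange(scores.size(0))   # tensor([0, 1, 2, 3])
loss = torch.nn.functional.cross_entropy(scores, labels)
```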
OK, so here we have what is our involved MNR loss training 00:16:48.980 |
We're just using PyTorch rather than the Sentence Transformers 00:16:55.740 |
going to come to where we actually start training. 00:17:05.660 |
So I mentioned before we had that mean pooling function 00:17:08.140 |
where we're getting the average across dimensions. 00:17:15.100 |
need to consider where we have padding values, where 00:17:24.340 |
want to consider those in our average function, 00:17:26.700 |
because then it's going to obviously bring down 00:17:28.780 |
our average a lot just for having more padding tokens, 00:17:34.820 |
So that's why it looks more complicated than you might expect. 00:17:46.620 |
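That attention-mask-aware mean pooling function looks roughly like this (a reconstruction of what's described, not necessarily the exact code on screen):

```python
import torch

def mean_pool(token_embeds: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Expand the mask to the embedding dimension so padding tokens
    # contribute zero to the sum
    mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    # Sum over the sequence dimension, then divide by the number of real
    # (non-padding) tokens; the clamp avoids dividing by zero
    summed = torch.sum(token_embeds * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts
```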
It's just a BERT model that we're using here, 00:17:49.020 |
plain, straight BERT model, nothing special about it. 00:17:56.700 |
And just moving it to a CUDA GPU if we have one. 00:18:01.020 |
So we're checking if CUDA is available, and it is. 00:18:14.020 |
We're checking and calculating similarity between pairs. 00:18:17.420 |
That's what we're doing here, just initializing that function. 00:18:23.220 |
So here we're using a categorical cross-entropy loss. 00:18:32.460 |
I think I use it a bit later on. 00:18:44.180 |
We multiply our similarity score value by the scale value 00:18:52.980 |
Down here, we're using the transformers optimization utilities. 00:18:52.980 |
saying, for the first 10% of the training data, 00:19:03.940 |
I want to warm up to this learning rate of 2e-5. 00:19:08.780 |
So we're not going to start at that learning rate initially. 00:19:29.540 |
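The optimizer and warm-up schedule would look something like this (a sketch, assuming `model` and `loader` are already defined; the optimizer choice is an assumption, while the 2e-5 rate and 10% warm-up come from the description above):

```python
import torch
from transformers import get_linear_schedule_with_warmup

optim = torch.optim.Adam(model.parameters(), lr=2e-5)

total_steps = len(loader)                 # one epoch over the data loader
warmup_steps = int(0.1 * total_steps)     # warm up over the first 10% of steps
scheduler = get_linear_schedule_with_warmup(
    optim, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)
```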
And in the anchor, we have our attention mask 00:19:49.820 |
So that's how our Transformer understands our text. 00:19:57.540 |
Yeah, I'll just mention here what is in there. 00:20:23.500 |
and we have positive sentence embeddings, which 00:20:37.860 |
And then we do the mean pooling to get the sentence embeddings. 00:20:44.940 |
the cosine similarity between each of our anchors 00:20:52.860 |
So we create that array of values of cosine similarities. 00:21:01.900 |
expect the true pair to have the highest cosine similarity. 00:21:05.940 |
So that is what we will do a little further down here. 00:21:14.940 |
it just outputs a tensor, which is 0, 1, 2, 3, 4, 00:21:25.340 |
So that's where you'd expect the argmax value to be. 00:21:31.100 |
So for A3, you'd expect the maximum cosine similarity 00:21:48.620 |
So we're calculating between the scores, which 00:22:00.700 |
here to be the maximum value in a specific row at the index 00:22:10.360 |
So that's the A3, P3 pair that I'm talking about. 00:22:24.140 |
But when I say that's all, it's actually quite a lot. 00:22:26.780 |
I mean, there's a lot of code going into this. 00:22:33.620 |
that's the labels tensor that I just mentioned, by the way. 00:22:49.780 |
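Putting those pieces together, one training step looks roughly like this (a condensed sketch assuming `model` is a plain BERT model, `mean_pool` is the pooling function above, and the loader yields the tokenized anchor and positive dicts):

```python
import torch

cos_sim = torch.nn.CosineSimilarity(dim=-1)
loss_fn = torch.nn.CrossEntropyLoss()
scale = 20.0  # similarity scores get multiplied by this before the softmax

for anchor, positive in loader:
    optim.zero_grad()

    # Encode both sides with the same BERT model, then mean-pool
    a_embeds = mean_pool(model(**anchor).last_hidden_state,
                         anchor['attention_mask'])
    p_embeds = mean_pool(model(**positive).last_hidden_state,
                         positive['attention_mask'])

    # (batch, batch) cosine similarity matrix, scaled
    scores = cos_sim(a_embeds.unsqueeze(1), p_embeds.unsqueeze(0)) * scale

    # Anchor i's true positive sits at column i
    labels = torch.arange(scores.size(0), device=scores.device)
    loss = loss_fn(scores, labels)

    loss.backward()
    optim.step()
    scheduler.step()
```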
And to be honest, if you do the same with sentence 00:22:53.760 |
transformers, you're probably going to get better results. 00:22:56.420 |
So I'm going to show you how to do that with sentence 00:23:27.220 |
which is the data format that our sentence transformers 00:23:37.300 |
So we want to do from sentence_transformers import InputExample 00:23:49.020 |
And what we'll do is we'll just initialize the list here, 00:24:32.460 |
We just generate them as we're performing the training. 00:24:36.580 |
So we don't need the label here, so we just write text. 00:24:52.780 |
And we also want the hypothesis, which is our positive. 00:25:10.900 |
So what you can do is if I just go from tqdm.auto, 00:25:33.060 |
but I think it's nice to be able to see that, especially with a dataset this big. 00:25:37.580 |
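That conversion step looks something like this (a sketch):

```python
from sentence_transformers import InputExample
from tqdm.auto import tqdm

train_samples = []
for row in tqdm(dataset):
    # No label needed for MNR loss: just the anchor (premise) and the
    # positive (hypothesis) as a pair of texts
    train_samples.append(InputExample(texts=[row['premise'], row['hypothesis']]))
```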
Now, here we need to initialize what is a data loader. 00:25:48.180 |
loader from the sentence transformers library. 00:25:54.340 |
So from sentence_transformers, I want to import datasets. 00:26:12.180 |
so we're just going to call that loader again, as we usually do. 00:26:28.020 |
And what this is going to do is, unlike a normal data 00:26:31.420 |
loader in PyTorch, which would just feed you-- 00:26:35.460 |
it would just feed you 32, in this case, samples at once, 00:26:56.060 |
But if you do have data sets which might have duplicates, 00:27:03.020 |
Because if you think, OK, if you have a duplicate, 00:27:06.020 |
and our labels are saying that pair A1 and P1 00:27:11.660 |
should be the same, but in reality, over here, 00:27:15.060 |
we have A7 and P7, which are exactly the same, 00:27:20.420 |
And it's going to say that A1 and P7 should be matching. 00:27:29.820 |
So that's why we use this no duplicates data loader 00:27:48.540 |
the batch size, which is equal to our batch size. 00:27:56.980 |
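Initializing that loader looks like this (a sketch; batch size 32 is an assumption):

```python
from sentence_transformers import datasets

batch_size = 32

# Unlike a plain PyTorch DataLoader, this makes sure no batch contains
# duplicate pairs, which would break the in-batch labels
loader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=batch_size)
```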
Now, what we need to do is initialize a model. 00:28:00.180 |
So in Sentence Transformers, we can do the same thing 00:28:03.620 |
as we do with Hugging Face Transformers, where 00:28:15.980 |
two modules, which is going to be the Transformer module, 00:28:21.780 |
And that's going to be followed by a pooling module 00:28:28.340 |
So we're going to write, from Sentence Transformers, 00:28:33.380 |
import models, and also Sentence Transformer. 00:28:48.180 |
So we initialize that, and we're using bert-base-uncased 00:29:09.060 |
needs to know the dimensionality of the vectors 00:29:15.860 |
We can say get_word_embedding_dimension, which is 768. 00:29:15.860 |
to know which type of pooling we want to use. 00:29:33.940 |
And we can see we have all these different methods in here. 00:29:43.140 |
We can take the mean and consider the square root 00:29:50.780 |
I don't know this one, so I could be completely wrong. 00:29:55.900 |
So this is a mean pooling method that I mentioned before. 00:29:59.140 |
We're going to be using that one, so we say true. 00:30:02.780 |
And then to actually initialize the model using 00:30:11.140 |
are going to write model equals sentence transformer. 00:30:15.460 |
And like I said before, this is how you would usually 00:30:19.660 |
So if you wanted to load a pre-trained model, 00:30:21.580 |
you'd pass a model name like bert-base-uncased in here. 00:30:21.580 |
We are initializing a new model using these two modules. 00:30:33.260 |
And we have BERT followed by the pooling function or module. 00:30:41.900 |
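So the two-module model ends up looking something like this (a sketch):

```python
from sentence_transformers import models, SentenceTransformer

# Transformer module: a plain, untuned BERT body
bert = models.Transformer('bert-base-uncased')

# Pooling module: mean pooling over the 768-dimensional token embeddings
pooler = models.Pooling(
    bert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)

model = SentenceTransformer(modules=[bert, pooler])
```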
And then we can have a look at what we have there. 00:30:53.820 |
And then here, we have the structure of our model. 00:30:57.220 |
So we can see with transformer, we're using the BERT model. 00:31:07.820 |
And we see that the only one of these values that is true 00:31:13.860 |
And the rest of them are false because we're not 00:31:26.300 |
So from Sentence Transformers, import losses. 00:31:32.940 |
And our loss function is going to be equal to losses. 00:31:36.660 |
And we have the multiple negatives ranking loss, 00:31:50.020 |
it knows the model parameters that it's dealing with. 00:31:54.580 |
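In code, that's roughly:

```python
from sentence_transformers import losses

# The loss takes the model so it knows which parameters it's optimizing
loss = losses.MultipleNegativesRankingLoss(model)
```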
And with that, we're ready to actually start training 00:32:09.220 |
so like I said before, we have that 10% of warm-up steps 00:32:16.220 |
But it isn't as complicated to set that this time. 00:32:27.620 |
Well, it would be the length of the loader, OK? 00:32:43.900 |
And then we can call model.fit, so a bit like TensorFlow. 00:32:43.900 |
now, we pass our model configuration and setup 00:32:57.780 |
So the first thing we want to do is set our train objectives, 00:33:03.820 |
which is just a list containing train objective pairs, 00:33:11.340 |
And that is loader followed by loss, so our objectives. 00:33:26.940 |
So we write warm-up steps is equal to warm-up. 00:33:36.780 |
So I think I use that MNR2 or something for the final one 00:33:45.940 |
And then after that, final thing, if you want. 00:34:00.940 |
So you can set show progress bar equal to false, if you want. 00:34:08.180 |
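So the warm-up calculation and the fit call look roughly like this (the output path here is just a hypothetical name, not necessarily the one used in the video):

```python
epochs = 1
# Warm up over the first 10% of the total training steps
warmup_steps = int(len(loader) * epochs * 0.1)

model.fit(
    train_objectives=[(loader, loss)],   # pairs of (data loader, loss)
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path='./sbert_mnr',           # hypothetical save directory
    show_progress_bar=True
)
```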
We run this, and this will fine-tune our model. 00:34:11.380 |
Now, I've already run it, so I'm not going to run it again. 00:34:17.140 |
So what I'm going to show you are just the results from that. 00:34:20.300 |
OK, so this is my other notebook where I already 00:34:30.700 |
So it's a load of random sentences, but a couple of them 00:34:44.740 |
We've got knit noodles and weaving spaghetti. 00:34:50.340 |
with construction materials and dentists with trim bricks. 00:35:14.220 |
of these different sentences, or the encodings produced 00:35:18.700 |
by our model, the sentence embeddings produced 00:35:21.220 |
by our model, and all of the equivalent embeddings 00:35:29.100 |
And we just use matplotlib and seaborn to actually plot that. 00:35:35.060 |
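The plotting step is roughly this (a sketch with made-up example sentences; the actual test sentences in the notebook differ):

```python
import matplotlib.pyplot as plt
import seaborn as sns
from sentence_transformers.util import cos_sim

# Hypothetical test sentences: a semantically similar pair mixed in with
# an unrelated one
sentences = [
    "Someone is knitting noodles",
    "A person is weaving spaghetti",
    "It looks like it might rain later",
]

embeddings = model.encode(sentences)                # one vector per sentence
scores = cos_sim(embeddings, embeddings).numpy()    # (n, n) similarity matrix

sns.heatmap(scores, annot=True, cmap='viridis')
plt.show()
```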
And we see that we get this nice visualization. 00:35:45.460 |
And those are the three pairs I mentioned before 00:35:52.060 |
And the rest of these values are all very low, 00:35:56.060 |
which is obviously very good because those pairs are not 00:36:09.020 |
where the true pairs are, or the true semantic pairs are there. 00:36:15.020 |
OK, so I think that's pretty much it for this video. 00:36:19.980 |
We've kind of gone through multiple negatives ranking 00:36:19.980 |
if you're going to train a sentence transformer, 00:36:38.060 |
go with, depending on the data that you have. 00:36:41.580 |
Now, I mean, that's everything for the video.