Fine-tune High Performance Sentence Transformers (with Multiple Negatives Ranking)


Chapters

0:00 Intro
1:02 NLI Training Data
2:56 Preprocessing
10:11 SBERT Finetuning Visuals
14:14 MNR Loss Visual
16:37 MNR in PyTorch
23:04 MNR in Sentence Transformers
34:20 Results
36:14 Outro

Transcript

Hi, welcome to the video. We're going to be having a look today at something called multiple negatives ranking loss, or MNR loss, for training sentence transformers. Now, if you're new to sentence transformers, they're essentially transformer-based NLP models that have been trained specifically for building dense vector representations of text.

And the highest performing sentence transformers at the moment are all trained using this method. So they're trained on a natural language inference dataset, and they are trained with MNR loss. So in this video, we are going to learn about MNR loss and how it works. We're going to work through a PyTorch example very quickly, and then we're going to look at how we can implement it ourselves using the sentence transformers library.

So let's jump straight into it. So if you saw our previous article and video, you'll remember that we had an NLI dataset, which is built from the Stanford NLI dataset and the Multi-Genre NLI dataset. We're using the same dataset here. So in essence, all it is is a load of sentence pairs, and there is a label 0, 1, or 2, which indicates whether those sentence pairs entail each other, are not really related, like they could both be true but not necessarily because of each other, or contradict each other.

So there are the three labels. And what we covered in the previous video and article was something called softmax loss. With softmax loss, we use those labels, the 0, 1, or 2, to produce a classification. We optimize on that label. Now, with MNR loss, we don't actually use those labels.

We just use the sentence pairs. And what I'm going to show you is where we use something called a Siamese network, which means we only have an anchor and a positive sentence. Now, what that means is an anchor is a-- you can think of it as like a base sentence.

And a positive to that anchor sentence is just a sentence that indicates that the anchor is true. Now, we could also have negatives. And a negative would indicate that that anchor is not true. And we can extract that from the NLI dataset. So let's start pre-processing our data and have a look at what we need to do.

So this code here, we already wrote it in the previous video and article. But we're going to go through it. So if you've never seen any of this before, it's fine. We're going to go through it. I'm going to explain everything. It's not a problem. So basically, up here, all I'm saying is what I just told you.

So these are our labels in our data. We have this 0, entailment, 1, neutral, and 2, contradiction. And then I'm just saying, with MNR, we don't actually need those labels. All we need are anchor-positive pairs, i.e. a pair where the premise suggests the hypothesis, so where the positive indicates the anchor or the other way around.

So essentially, what we need are just rows where we have the label 0. So we're going to do that. But first, we need to actually get our data. So we're using the Hugging Face datasets library here, which is very good. And we're getting the Stanford Natural Language Inference dataset here.
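For reference, a minimal sketch of that loading step with the datasets library (assuming the standard snli dataset id on the Hugging Face hub):

```python
from datasets import load_dataset

# load the Stanford NLI training split; each row has premise, hypothesis, label
snli = load_dataset('snli', split='train')
print(snli)
```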

Now, down here, this format you can see is the format of the dataset. So we have these three features, premise, hypothesis, and label. And if we come down here, we can see what one of those looks like. So we have premise, a person on a horse jumped over a broken-down airplane, and this hypothesis, a person is training his horse for a competition.

And the label for that is 1. Come up here, 1 means neutral. So basically, this here, a person is training his horse for a competition, does not necessarily mean the person on a horse is jumping over a broken-down airplane. And then we come down here, and I think this one's entailment.

So this one is a pair that we want. This is an anchor-positive pair. And we have a person is outdoors on a horse, and a person on a horse jumps over a broken-down airplane. If they're jumping over a broken-down airplane, they're probably outside. In fact, they-- well, almost definitely outside.

So this indicates it. So this is an entailment, which means it's an anchor-positive pair. So I said before, we have two datasets, not just one. So we have the Stanford Natural Language Inference dataset, and we also have the Multi-Genre Natural Language Inference dataset. And that's what we're getting here, MNLI.

So I'm just loading that with the load_dataset function here. We're loading it from the GLUE dataset. Within GLUE, there is this MNLI, which is the data that we want. And we're just getting the training data from that. We don't-- because I think there's also-- maybe there's validation data, and there's also test data.

We don't want that. We just want the training data. So we can see the format for this dataset, almost exactly the same. We just have this extra ID. So what we can do is we just remove those columns. So we call mnli.remove_columns, and we specify that we don't want the ID column, because we're going to merge these two datasets.
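Roughly, that looks like this (assuming the extra column is called idx, as in the GLUE version of MNLI):

```python
from datasets import load_dataset

# load the Multi-Genre NLI training split from the GLUE collection
mnli = load_dataset('glue', 'mnli', split='train')

# drop the extra index column so both datasets share the same columns
mnli = mnli.remove_columns(['idx'])
```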

But to merge them, they both need to have the exact same format. So after that, what we do is we perform this cast. Now, if we didn't perform this cast-- let me show you. So I need to run these anyway, so I'll start running them now. Come down here.

I'm going to load the dataset and remove the columns. And OK, let's say I'm not going to do this cast. I'm just going to concatenate the datasets. I'm going to merge those two datasets together. I do that, and we get this ArrowInvalid error. And the reason for that is because although those two datasets look very similar, they're not actually exactly the same.

The data types used in one of them include this not-null specification, whereas the other one does not include that not-null specification. So I assume that means that this other dataset can include null values, whereas this one cannot. So we can't merge them both because they're not the same data type.

So what we do, we come up here, and we have to do this cast. So we're casting the features of the SNLI dataset to be the same as the MNLI dataset features. So we do that, run it. We come down here, we can run it again, and it will work.
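A sketch of the cast and merge, assuming the snli and mnli objects from the earlier steps:

```python
from datasets import concatenate_datasets

# align SNLI's feature types with MNLI's, then merge the two datasets
snli = snli.cast(mnli.features)
dataset = concatenate_datasets([snli, mnli])
print(dataset)  # roughly 943K rows in total
```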

And now in dataset, we have the full data. That's both MNLI and SNLI. And you can see that because we have 943,000 rows. If you come up here, we only have 392,000 in the MNLI dataset. And up here, we have 550,000 in the SNLI dataset. So we have all the data.

And now, in the previous video and article, I mentioned we have these negative 1 labels. Now, this is an error in the data. Or it's not an error, but it's where whoever labeled the data couldn't decide on the nature of the relationship between the pair of sentences.

So they just put minus 1. There aren't very many. I think it's 700 or so pairs in there that are labeled with this minus 1. But that's not a label. We can't do anything with that in our data. Now, what we do is we use this filter method to remove them.

So we say false for the row. So the row is false if the label is equal to minus 1, which is saying that row is false, i.e. we do not keep it if its label value is equal to minus 1. Otherwise, we do keep it. Now, things are a bit different now because we actually only want anchor-positive pairs, which means we want to remove-- or we only want to keep the rows which have a 0 label.

We want to remove everything else. So we need to modify this to just keep the 0 values. So we'll say, OK, false if x label is not equal to 0. Else it's true. So this is going to keep only 0 values and remove everything else. So that removes those erroneous minus 1 rows and also removes the neutral and contradiction rows as well.
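In code, that filter is something like this:

```python
# keep only entailment (label 0) rows, i.e. our anchor-positive pairs;
# this also drops the handful of erroneous -1 labels
dataset = dataset.filter(
    lambda x: False if x['label'] != 0 else True
)
# equivalently: dataset = dataset.filter(lambda x: x['label'] == 0)
```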

Now, we remove that and let that run. Now, while that's running, let me show you some visuals of how this will work. So what we have here is we have that anchor-positive pair from before. So here, our anchor would be, I think, the premise. And the positive would be our hypothesis.

But it would obviously only be rows where the label is equal to 0, which is the entailment label. Now, we have our anchor and positive. And as we usually would with a transformer model, we tokenize them. So we tokenize them. We're going to do that with just a tokenizer method, a pre-trained tokenizer from the base transformers library.

Or we do that if we're using PyTorch. If we're using the sentence transformers library, it's a lot easier, and we don't actually need to do that. It will deal with that for us, which is quite nice. And what that produces is just a tokenized version of the anchor and the positive.
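For the PyTorch route, that tokenization step might look roughly like this; the anchors and positives lists and the bert-base-uncased checkpoint are assumptions for illustration:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# tokenize anchors (premises) and positives (hypotheses) separately,
# padding/truncating to a fixed length (512 tokens, matching the count mentioned below)
anchor_tokens = tokenizer(anchors, max_length=512, padding='max_length',
                          truncation=True, return_tensors='pt')
positive_tokens = tokenizer(positives, max_length=512, padding='max_length',
                            truncation=True, return_tensors='pt')
```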

So here, we have our-- so I'm just going to put A for anchor, and over here, P for positive. And then what happens next is we have a single BERT model. We actually visualize it as two BERT models, because we're processing the anchor. And then after we've processed the anchor data, we move on to processing the positive data.

So it's like we're using-- for every single training step, we're using the same BERT model twice. So we process both of those through our BERT model, and that produces token embeddings. So the token embeddings are 512 dense vectors, each containing 768 dimensions. And then we use something called a mean pooling function.

So mean pooling function is, say we have some vectors here-- 1, 2, and 3. What we're going to do is take the mean value across each dimension. So let's say we have three dimensions here in our-- no, that's a bad idea, because we have three vectors. Let's say we have five dimensions here.

And we take the average across each of those dimensions. And what we produce from that mean pooling operation is a single five-dimensional vector. So we have the mean of each of those-- across each of those dimensions. That's the mean pooling operation. And obviously, from that, we produce what is our sentence embedding, which is a single one-dimensional dense vector.

And we produce that sentence embedding both for A, our anchor, and for P, our positive. And what we have here-- so I don't know if I mentioned it, but this double network is a Siamese network. And the triple network is a triplet network. And it works in the exact same way, but we also have a negative.

So where before we had the contradiction label, I think the label is 2. But I could be wrong. I think it-- no, I think it is 2. That would be a negative sentence because it contradicts the anchor. And what we do is we process that as well. And then we would also get a negative vector-- negative sentence embedding at the end there.

And what we do with that during training-- so the whole MNR loss thing, multiple negative ranking thing-- is we take all of those vectors that we produce, the A, which is the anchors, the P, and the N for the positive and negatives, if we're using negatives. If not, we just basically blank out the dark blue part of this visual.

All we do is we calculate the cosine similarity between our anchor and our positive, and maybe our negative as well. But I'm just going to say I'm going to go through the anchor and positive version of it for now. So we calculate the cosine similarity between the anchor and positive.

And we do that for every anchor. So anchor 0 for the first one, we would calculate the cosine similarity between anchor 0 and positive 0, 1, 2, 3, and 4, and so on, up until the batch size. So what we have there is the actual-- so the number of sentence embeddings that we have there is equal to the batch size that we are using.

And obviously, what we would expect with our anchor and positive pair is that we would expect the cosine similarity between A0 and P0 to be greater than the cosine similarity between, let's say, A0 and P1, or P2, or P3, or P4. And likewise, if we had A3, we would expect the cosine similarity between A3 and P3 to be greater than that between A3 and P0, or P1, or P2, or P4.

So that's how we optimize. We say, OK, for A0 the target is that the highest similarity, the argmax, lands on P0, on the pair A0-P0. And for A3, it's going to be with P3. So the labels for this are actually just 0, 1, 2, 3, 4, up until the batch size, which we will see.

I'm going to show you that. OK, so here we have what is our more involved MNR loss training notebook. So I'm just going to take you through this. We're just using PyTorch rather than the sentence transformers library here. Now, I'm going to go through, and we're going to come to where we actually start training.

So we have a mean pooling operation here. So I mentioned before we had that mean pooling function where we're getting the average across dimensions. That's what this is dealing with. The reason it looks more complicated than just taking the average is because we need to consider where we have padding values, where the mask value is 0, because we don't want to consider those in our average function, because then it's going to obviously bring down our average a lot just for having more padding tokens, which doesn't make sense, obviously.
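The masked mean pooling being described is roughly this (a sketch, not the exact notebook code):

```python
import torch

def mean_pool(token_embeds, attention_mask):
    # expand the attention mask so padded positions are zeroed out
    mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    # sum the real token embeddings, then divide by the count of real tokens
    summed = torch.sum(token_embeds * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts
```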

So that's why it looks more complicated than you probably expect for an averaging function. And then, so we're using PyTorch here, we're setting up the model. It's just a BERT model that we're using here, a plain, straight BERT model, nothing special about it. I think it's bert-base-uncased.

And we're just moving it to a CUDA GPU if we have one, checking if it's available. And then here we're defining some layers that we're going to be using in MNR loss. So we have the cosine similarity. I said before we're doing that. We're calculating similarity between pairs.

That's what we're doing here, just initializing that function. And we also have a loss function, of course. So here we're using a categorical cross-entropy loss. And we'll see how that works. We also have a scale value. I think I do use it a bit later on.

But it's fine. It's not really too important. We just multiply our similarity score value by the scale value later on. Down here, we're using the transformers optimization utilities. So what we're doing here is we're saying, for the first 10% of the training steps, I want to warm up to this learning rate of 2e-5.
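A sketch of that warm-up setup; the model and loader names are placeholders for whatever the notebook uses:

```python
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

optim = AdamW(model.parameters(), lr=2e-5)

total_steps = len(loader)  # one epoch in this example
warmup_steps = int(0.1 * total_steps)  # warm up over the first ~10% of steps
scheduler = get_linear_schedule_with_warmup(
    optim, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)
```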

So we're not going to start at that learning rate initially. We're going to slowly build up to it. And yeah, that's all we're doing there. And we can see our batch here. So we have attention mask. So sorry, we have the anchor. And in the anchor, we have our attention mask and we have our input IDs.

And then in positive, we also have the same. We have attention mask and input IDs. So input IDs and attention mask, if you use the Transformers library before, you probably recognize these. It's just the input tensors that we use when we're feeding text into a model. So that's how our Transformer understands our text.

Yeah, I'll just mention here what is in there. And let's come down. Is there anything important? I don't think so. OK, let's come down to the training loop. So yeah, using a scale value of 20 there. And we come to here. So here, we have our anchor embeddings. So these are sentence embeddings.

So we have the anchor sentence embeddings and we have positive sentence embeddings, which we've output from our BERT model. And what we do, we do the mean pooling. So when I said sentence embeddings here, they are the token embeddings, sorry. So we have 512 token embeddings. And then we do the mean pooling to get the sentence embeddings.

And then what we do is we calculate the cosine similarity between each of our anchors and all of the positive values. So we create that array of values of cosine similarities. And on each row, obviously, we would expect the true pair to have the highest cosine similarity. So that is what we will do a little further down here.

So labels, what this value does here is just outputs a tensor, which is 0, 1, 2, 3, 4, up until the batch size of the data. So that's where you'd expect the argmax value to be. So for A3, you'd expect the maximum cosine similarity to be in the third index, where P3, where it's been compared to P3.

And then here, we are calculating the loss. So we're calculating between the scores, which we have up here, and the labels. So this is taking the-- we're basically looking for this value here to be the maximum value in a specific row at the index equal to the current label.

So that's the A3, P3 pair that I'm talking about. Now we're just optimizing on that. And that's it. That's all there is to MNR loss in PyTorch. But when I say that's all, it's actually quite a lot. I mean, there's a lot of code going into this. And it's very confusing.
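Compressed into one place, the core of the step just described looks roughly like this (variable names are illustrative; a and p are the mean-pooled anchor and positive embeddings):

```python
import torch

cos_sim = torch.nn.CosineSimilarity(dim=-1)
loss_fn = torch.nn.CrossEntropyLoss()
scale = 20.0

def mnr_loss(a, p):
    # cosine similarity of every anchor against every positive -> (batch, batch)
    scores = cos_sim(a.unsqueeze(1), p.unsqueeze(0)) * scale
    # the matching positive for anchor i sits at column i
    labels = torch.arange(scores.size(0), device=scores.device)
    return loss_fn(scores, labels)
```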

Oh, here is the-- that's the labels tensor that I just mentioned, by the way. So you can see it's just counting up to our batch size, which is 32. So that's the PyTorch implementation. But it's complicated. And to be honest, if you do the same with sentence transformers, you're probably going to get better results.

So I'm going to show you how to do that with sentence transformers. It's a lot easier and a lot more effective. OK. So in sentence transformers, so we're back in the first notebook now. So we have our dataset, 314,000 rows or so. And it is just anchor-positive pairs. So with sentence transformers, we use something called an InputExample.

So we have a big list of InputExamples, which is the data format that the sentence transformers library training methods expect. So we want to do from sentence_transformers import InputExample. And what we'll do is we'll just initialize the list here, so samples. And this is very simple.

We're just going to go for sample in-- or for row, let's say, for row in dataset. We want to say samples.append. And in here, we have our InputExample. So the thing we just imported, the object, the special sentence transformers object. And in the previous video and article, this accepts two different parameters.

It accepts the texts and the label. Now, like I said, we don't have labels. We just generate them as we're performing the training. So we don't need the label here, so we just write texts. And in here, we-- so a list. And in there, we want our row. We want the premise, which is our anchor.

And we also want the hypothesis, which is our positive. OK, and that's all we need, OK? So that will take a while. So what you can do is, if I just do from tqdm.auto import tqdm, we can add this to our loop here so that we have a nice progress bar, so we can see how far along we are, how long it's going to take.
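Put together, that sample-building loop is roughly:

```python
from tqdm.auto import tqdm
from sentence_transformers import InputExample

# build the list of anchor-positive training samples
samples = []
for row in tqdm(dataset):
    samples.append(InputExample(
        texts=[row['premise'], row['hypothesis']]
    ))
```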

And I mean, it's very quick anyway, but I think it's nice to be able to see that, especially for the longer data sets. Now, here we need to initialize what is a data loader. Usually, we use the PyTorch data loader. This time, we're not going to. This time, we're going to use a special data loader from the sentence transformers library.

And I'll explain why in just a moment. So from sentence transformers, I want to import data sets. We set the batch size, as we usually do. We're going to set that equal to 32. And we're creating a data loader, so we're just going to call that loader again, as we usually do.

And in here, we want to write datasets, and we want the NoDuplicatesDataLoader. So you can see that there. And what this is going to do is, unlike a normal data loader in PyTorch, which would just feed you-- it would just feed you 32, in this case, samples at once, it wouldn't check the data in there.

Our NoDuplicatesDataLoader checks that you don't have any duplicates within the same batch. Now, realistically, with this dataset, we're probably not going to get that anyway. But if you do have datasets which might have duplicates, you can use this to avoid that issue. Because if you think, OK, if you have a duplicate, and our labels are saying that pair A1 and P1 should be the same, but in reality, over here, we have A7 and P7, which are exactly the same, it's going to confuse the model.

And it's going to say that A1 and P7 should be matching. But in reality, it should just be A1 and P1. So that's why we use this NoDuplicatesDataLoader, to remove any possibility of that happening. Now, as well, if that happens occasionally, it's not really an issue anyway.
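A minimal sketch of that loader setup, assuming the samples list built above:

```python
from sentence_transformers import datasets

batch_size = 32

# batches the samples while making sure no duplicates land in the same batch
loader = datasets.NoDuplicatesDataLoader(samples, batch_size=batch_size)
```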

But it's nice to just be careful with it. So we have our samples, and we want the batch size, which is equal to our batch size. So that's our data loader. Now, what we need to do is initialize a model. So in sentence transformers, we can do the same thing as we do with Hugging Face Transformers, where we load a pre-trained model.

But we can also initialize a new model using what are called modules. So in this case, we're going to use two modules, which is going to be the Transformer module, so the actual BERT itself. And that's going to be followed by a pooling module for the mean pooling that we do.

So we're going to write, from sentence_transformers, import models, and also SentenceTransformer. And we say, OK, BERT is equal to models.Transformer. So we initialize that, and we're using a bert-base-uncased model. And we also have our pooling module, and that is models.Pooling.

And in here, the pooling approach needs to know the dimensionality of the vectors that it's going to be dealing with. And we can get that from BERT. We can say, get word embedding dimension, which is the 768. And as well as that, we also need to know which type of pooling we want to use.

And for that, we can write pooling. And we can see we have all these different methods in here. So we have the pooling mode. We can use the CLS token. We can take the maximum value. We can take the mean and consider the square root of the length of the tokens.

I don't know this one, so I could be completely wrong. And we also have the mean. So this is a mean pooling method that I mentioned before. We're going to be using that one, so we say true. And then to actually initialize the model using those two parts or two modules, we are going to write model equals sentence transformer.

And like I said before, this is how you would usually load a pre-trained model. So if you wanted to load a pre-trained model, you'd put something like bert-base-uncased in here. We're not doing that. We are initializing a new model using these two modules. So we write modules equals.

And we have BERT followed by the pooling function or module. And then we can have a look at what we have there. So we have the model. So hopefully, it doesn't take too long. We get our list. This is fine. It's just coming from, I think, Hugging Face. And then here, we have the structure of our model.

So we can see with transformer, we're using the BERT model. And in here, the pooling-- we have the embedding dimension, 768. And we see that the only one of these values that is true is the pooling mode mean tokens. And the rest of them are false because we're not using those methods.
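The module setup being described looks roughly like this, assuming a bert-base-uncased checkpoint:

```python
from sentence_transformers import models, SentenceTransformer

bert = models.Transformer('bert-base-uncased')
pooler = models.Pooling(
    bert.get_word_embedding_dimension(),  # 768 for BERT base
    pooling_mode_mean_tokens=True         # mean pooling over token embeddings
)

model = SentenceTransformer(modules=[bert, pooler])
```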

Now, all we need to do is do the-- we need to initialize the loss function. So from sentence_transformers, import losses. And our loss function is going to be equal to losses.MultipleNegativesRankingLoss, which we can see there. And all we need to pass to that is the model.

So it knows what model it is-- it knows the model parameters that it's dealing with. So let's run that. And with that, we're ready to actually start training or fine-tuning the model. So we'll say, OK, we want to use one epoch. And the number of warm-up steps-- so like I said before, we have that 10% of warm-up steps that we want to use.

But it isn't as complicated to set that this time. We just want to say, OK, 10%, so 0.1, multiplied by the total number of steps. What is the total number of steps? Well, it would be the length of the loader, OK? And we wrap that in an int, OK? And that's our warm-up steps.

And then we can call model.fit, so like TensorFlow. And then in here, we just pass a few-- now, we pass our model configuration and setup and so on. So the first thing we want to do is set our train objectives, which is just a list containing train objective pairs, you could say.

For us, we just have one of those. And that is loader followed by loss, so our objectives. And then we have the epochs. So epochs, just put one there. It's probably easier. We have number of warm-up steps. So we write warm-up steps is equal to warm-up. And then we have the output path, so where we're going to save our model.

So I think I used MNR2 or something for the final one that I put together. And then after that, one final thing, if you want. So this will come up with a progress bar like we saw before. But the progress bar, for me, will just print every single step update.

So it's quite annoying. So you can set show progress bar equal to false, if you want. And then you just run this, OK? We run this, and this will fine-tune our model. Now, I've already run it, so I'm not going to run it again. It takes a while. So what I'm going to show you are just the results from that.
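Pulled together, the loss, warm-up and fit call described above look roughly like this (the output path is just a placeholder):

```python
from sentence_transformers import losses

loss = losses.MultipleNegativesRankingLoss(model)

epochs = 1
warmup_steps = int(0.1 * len(loader) * epochs)  # warm up over the first 10% of steps

model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path='./sbert_mnr',  # hypothetical output directory
    show_progress_bar=False
)
```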

OK, so this is my other notebook where I already ran all of what you just saw. We have our sentences here. So it's a load of random sentences, but a couple of them do match up, right? So we have this sushi one here, and there's another sushi one here.

But what I've done is not use the same words in both sentences. We've got knit noodles and weaving spaghetti. And we also have dental specialists with construction materials and dentists with trim bricks. So similar in concept, but they don't share any of the same words. And it's quite abstract as well.

So down here, so we have our model. We're just encoding those sentences. And then what I'm doing is I'm just creating a similarity matrix between all of these different sentences, or the encodings produced by our model, the sentence embeddings produced by our model, and all of the equivalent embeddings from this list of sentences.

And we just use matplotlib and seaborn to actually plot that. And we see that we get this nice visualization. And we see that we have 4 and 3 that align, and also 9 and 1, and 7 and 5. And those are the three pairs I mentioned before that are very similar.

And the rest of these values are all very low, which is obviously very good because those pairs are not similar. Maybe they have some sort of similarity, but they're not similar. So we can obviously see straight away where the true pairs are, or the true semantic pairs are there.
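A sketch of that evaluation step; sentences stands in for the list of test sentences described above:

```python
import seaborn as sns
import matplotlib.pyplot as plt
from sentence_transformers import util

# encode the test sentences with the fine-tuned model
embeddings = model.encode(sentences)

# pairwise cosine similarity between every pair of sentence embeddings
sim = util.cos_sim(embeddings, embeddings).numpy()

sns.heatmap(sim, annot=True, cmap='Blues')
plt.show()
```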

OK, so I think that's pretty much it for this video. We've gone through multiple negatives ranking loss and how to implement it. And like I said before, if you're going to train a sentence transformer, or fine-tune a sentence transformer, this is the approach I would probably go with, depending on the data that you have.

Now, I mean, that's everything for the video. So thank you very much for watching. I hope you've enjoyed it. And I will see you in the next one.