
Fine-tune Sentence Transformers the OG Way (with NLI Softmax loss)


Chapters

0:00 Intro
0:42 NLI Fine-tuning
1:44 Softmax Loss Training Overview
5:47 Preprocessing NLI Data
12:48 PyTorch Process
19:48 Using Sentence-Transformers
30:45 Results
35:49 Outro

Transcript

Hi, welcome to the video. We're going to be covering how we can train an SBERT model (a Sentence Transformer, or Sentence-BERT model) using what is essentially the original way of training, or fine-tuning, these models: Softmax loss. So let's start with a quick overview of the training approach.

Now, Softmax loss is part of what we could call the natural language inference (NLI) approach to fine-tuning these models. Within that category of training, we have two approaches: Softmax loss, or Softmax classification loss, which we're going to cover, and Multiple Negatives Ranking loss.

Now, in reality, you probably wouldn't use Softmax loss, because it's nowhere near as good as the other form of loss, Multiple Negatives Ranking loss, which I'm going to call MNR from now on. MNR is more effective, but Softmax loss is the original, and that's why we're covering it here.

We're not going to go into too much depth; I'm just going to run through it quickly. When we're training these models, we can use either what is called a Siamese network or a triplet network. What you can see right now is a Siamese network.

So we have almost like two copies of the same BERT, so they're like Siamese twins. And the idea is we would have two sentences, sentence A and sentence B, and we would feed both of those through our BERT model and produce the token embeddings that we usually get from BERT.

And then we apply a pooling operation, such as mean pooling, and from that we get our sentence embeddings, u and v. What we're doing is optimizing to get those sentence embeddings as close as possible for similar sentences, and as far apart as possible for dissimilar sentences.

So that's the start of the model, but it's not the full model. We continue by concatenating those two embeddings, u and v. We also do one more operation: we take u and v and compute the difference between them.

The bars here denote the absolute value, taken element by element, so we get a vector of non-negative values, |u - v|, representing the difference between the two vectors, and we concatenate that as well. The result is one big vector: u, then v, then |u - v| at the end.

We take that vector, which is what you can see over here, and feed it into a very simple feed-forward neural network. Each of these sentence embeddings has a dimensionality of 768, so the input dimension of our feed-forward network is 768 multiplied by 3, which is 2,304.

The output is our three output activations here. If you watched the last video or read the last article, you'll remember that we had three labels in our NLI training data: entailment, neutral, and contradiction. So for each sentence pair, we're trying to classify how the two sentences relate to each other.

Then we have the true label over here, and we optimize using cross-entropy loss, which is what you can see over here. The reason this is called Softmax loss is that there is a softmax function within the cross-entropy loss. So that's the process at a high level.
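To make the shapes concrete, here is a minimal NumPy sketch of the classifier head and loss described above. It is not the notebook's code: random vectors stand in for the pooled BERT embeddings u and v, the weights are untrained, and the helper names (build_features, softmax_cross_entropy) are hypothetical. It builds (u, v, |u - v|), maps the 2,304 features to three logits, and applies softmax cross-entropy against the true label.

```python
import numpy as np

EMBED_DIM = 768   # BERT-base sentence embedding size
NUM_LABELS = 3    # 0 = entailment, 1 = neutral, 2 = contradiction

def build_features(u: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Concatenate (u, v, |u - v|); the absolute difference is element-wise."""
    return np.concatenate([u, v, np.abs(u - v)])

def softmax(logits: np.ndarray) -> np.ndarray:
    """Turn raw scores into probabilities (max subtracted for numerical stability)."""
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()

def softmax_cross_entropy(logits: np.ndarray, label: int) -> float:
    """Cross-entropy: negative log of the softmax probability of the true class."""
    return float(-np.log(softmax(logits)[label]))

rng = np.random.default_rng(0)
u = rng.normal(size=EMBED_DIM)  # stand-in for the pooled embedding of sentence A
v = rng.normal(size=EMBED_DIM)  # stand-in for the pooled embedding of sentence B

features = build_features(u, v)                                # shape (2304,)
W = rng.normal(scale=0.02, size=(3 * EMBED_DIM, NUM_LABELS))  # untrained weights
b = np.zeros(NUM_LABELS)
logits = features @ W + b                                     # shape (3,)

loss = softmax_cross_entropy(logits, label=1)  # pretend the true label is neutral
print(features.shape, logits.shape, loss)
```

During training, W and b are learned together with BERT's weights. Once training is finished, this classifier is discarded, and only the pooled embeddings are used as sentence embeddings.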

So let's jump in. We're mainly going to focus on how we format our data to feed into this, and then on how we actually train all of this using the sentence-transformers library. That's going to be our main focus.

But we'll also very quickly run through the code in PyTorch, so you can see how it works. If you really want to dive in, you can take a look at that code and figure out how it works; it's not hard to read. So let's jump into it.

OK, so the first thing: I've added a little note up here with information about what's in our data, that is, the labels. We're going to look at the data itself now anyway. I'm going to use the Hugging Face Datasets library.

So we just import datasets. We're actually using two different datasets: the Stanford Natural Language Inference (SNLI) dataset, and the Multi-Genre NLI (MNLI) dataset. For SNLI, we write datasets.load_dataset, and the dataset is called snli.

We want the training split of that, so train (I said subset, but it's a split). Then we can have a look at what's inside the data. We have three features, which you can also call columns: premise, hypothesis, and label.

In the diagrams we saw earlier, we had sentence A, which is the premise, and sentence B, which is the hypothesis, and then the label at the end, so it's the same thing here. If we want a quick look at one of those rows, we just do this.

So we get what you can see here. What label do we have? We have label 1, which, as we can see up here, is neutral: the premise and the hypothesis could both be true, but they are not necessarily related. Here, the premise is that a person on a horse jumps over a broken down airplane.

The hypothesis is that the person is training his horse for a competition. So they could both be about the same topic, but not necessarily about the same thing; one doesn't entail the other. Let's try to find one that is a contradiction. OK, so this one is a contradiction.

Again, the premise is that a person on a horse jumps over a broken down airplane, and the hypothesis is that a person is at a diner, ordering an omelette. Those two can't both be true at the same time, so they're a contradiction. And for the last label, entailment, we have this one.

Again, a person on a horse jumps over a broken down airplane, and the hypothesis is that a person is outdoors, on a horse. Here, the premise entails the hypothesis. So that's the data. And like I said, we have two of these datasets: SNLI and MNLI.

We load MNLI the same way, with datasets.load_dataset. It's part of the glue dataset, and the subset is mnli. Again, we want the split to be train. If we have a look, we'll see a very similar format, but not exactly the same.

So we have premise, hypothesis, and label, and then we also have this idx column. To merge these two datasets, we need to reformat our MNLI dataset a little. First, we remove that column, so we write mnli.remove_columns

and pass it idx. Let's make sure it works: that column is now gone. But if we try to merge these, we still get an error, which is annoying, but it's fine. I'm going to call the merged result dataset, and we create it with datasets.concatenate_datasets.

We just pass them both in as a list, [snli, mnli], and we get this error. The schema, meaning the format of the datasets, is different. Even though they both contain the same columns, I think one of them has a slightly different format; maybe one of them allows nulls.

In fact, it does say so here. So to fix that, we just change the schema of one of the datasets. We change the SNLI dataset by writing snli.cast(mnli.features), which casts SNLI to MNLI's features.

And now the concatenation works. So we have our merged dataset, and we can see it has roughly 943,000 rows. Inside this data, though, there are some rows we don't want. We should only have the labels 0, 1, and 2, which are listed in the note at the top.

But some rows actually have the label -1. These are erroneous rows that we don't want in there: they're pairs where no label could be agreed on. So we're going to remove them. We write dataset = dataset.filter(...).

Inside filter, we write a lambda function that selects rows where the label value is not -1. It returns False if the label is equal to -1, so those rows are filtered out.

And otherwise it returns True, to keep the other rows. If we print the dataset before and after, we had roughly 942,800 rows and now have roughly 942,000, so it removed, I think, 700 or so rows. If we're using the sentence-transformers way of training the models, this is pretty much all we have to do.

There's one more step, which is to convert the data into a list of input examples, but we'll come to that in a moment. First, I'm going to quickly cover the other training approach, using plain PyTorch. It's quite a bit more complicated, though.

And at least when I was training with that approach, the model was nowhere near as good as when I trained it using sentence-transformers, so I wouldn't recommend it. But if you're interested, this is how we do it. Let me switch over to that notebook.

You'll see it's basically doing the same thing: we're loading the dataset. So what's different as we come down? The main addition is that we import the BERT tokenizer. Further down, we're filtering, which is what we did before.

Nothing new there. But here we're doing something different: we're tokenizing our text. Using the map function, we tokenize both the premise sentences and the hypothesis sentences, and get the input IDs and attention mask for each. If we look down here, this is what we end up with.

I've removed all the other features, so all we have are the labels, plus the input IDs and attention mask for both the premise (sentence A) and the hypothesis (sentence B). After that, we also need to call dataset.set_format.

Because we're using PyTorch, we set it to the torch format. From there, it's typical PyTorch: we set up a data loader with a batch size of 16, which is what the SBERT paper used. Coming down, these are all just examples.

So I'm going to go a little further down, to here. Remember that in the diagram, we passed sentence A and sentence B through BERT (the Siamese BERT), and then a pooling method took our token embeddings, which are 512 vectors of 768 dimensions each, and compressed them into a single 768-dimensional vector.

That's what this function does. When we're using Sentence Transformers, we don't need to worry about this. And by Sentence Transformers here I mean the library, or framework, which we're going to cover soon; that's probably a bit confusing.

Sentence Transformers is also the name of the models themselves. Here, I'm taking the mean pooling: it averages the token embeddings while excluding padding tokens, which is why we don't just take a plain average. We use the attention mask to remove those padded positions.
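The masked mean pooling described above can be sketched in NumPy as follows. The notebook does the same with PyTorch tensors; mean_pool is a hypothetical name. Padding positions get a mask value of 0, so they contribute nothing to either the sum or the token count.

```python
import numpy as np

def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings per sequence, ignoring padding positions.

    token_embeddings: (batch, seq_len, dim), e.g. (16, 512, 768) for BERT-base
    attention_mask:   (batch, seq_len), 1 for real tokens and 0 for padding
    """
    # Broadcast the mask over the embedding dimension: (batch, seq_len, 1)
    mask = attention_mask[..., np.newaxis].astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)       # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)       # real-token count, never zero
    return summed / counts

# One sequence of three tokens; the last one is padding with large junk values
emb = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(mean_pool(emb, mask))  # [[2. 3.]]
```

With a plain average, the padding token would drag the result to roughly [34.7, 35.3] instead of [2, 3].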

Further down, we set our device: we check whether we have a CUDA-enabled GPU, and move our model to it if we can. Then these are the layers we use. As I said before, we concatenate our u and v vectors, the sentence embeddings, along with |u - v|, and pass them to a feed-forward neural network.

That feed-forward network's input size is our sentence embedding size multiplied by 3, and it outputs three labels, or classes. We also use a cross-entropy loss function between the network's output and our actual labels. After that comes what I mean when I say there's quite a bit of code in the very manual PyTorch approach, compared with using the Sentence Transformers library.

Here we're using get_linear_schedule_with_warmup. That just says that for the first 10% of our training steps, we warm up the learning rate, so we don't start training at the full 1e-5 straight away; we ramp up to it gradually.

In the SBERT paper, they used 2e-5. For me, the loss just bounced around a lot with that, so I halved it. But if you can get it working with 2e-5, that's what they used in the paper, so it's probably better. They also only trained for one epoch.
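Assuming the scheduler is the transformers library's get_linear_schedule_with_warmup, the learning rate it produces can be sketched in plain Python: a linear ramp from 0 to the base rate over the warm-up steps, then a linear decay back to 0 by the final step. linear_warmup_lr is a hypothetical helper written for illustration, not part of any library.

```python
def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step` under linear warm-up followed by linear decay to zero."""
    if step < warmup_steps:
        # Warm-up phase: ramp linearly from 0 up to base_lr
        factor = step / max(1, warmup_steps)
    else:
        # Decay phase: fall linearly from base_lr to 0 at the last step
        factor = max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return base_lr * factor

total = 1000
warmup = int(0.1 * total)  # warm up over the first 10% of steps
for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_lr(step, 1e-5, warmup, total))
```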

I'm also using AdamW, which is Adam with weight decay. Then this is the training loop. tqdm is just a progress bar. We train for one epoch, make sure the model is in training mode, and initialize a loop that gets all the batches from our data loader.

Then we get all the data out of each batch; this is just PyTorch stuff. Next we get u and v. Actually, sorry: u and v here are still token embeddings, and here we convert them into sentence embeddings. Then we compute |u - v|, the absolute difference vector.

Here we concatenate it all together, creating the vector that we feed into the feed-forward neural network, and then we optimize based on the loss. That's pretty much it. Then we save the model down here. So that's how we train it in PyTorch.

You can see here I was experimenting to see what happened if I trained for two epochs. It's better to just stick with one: even though the loss was lower, in the end the performance wasn't any different. So I'd train for one epoch. OK, let's go back to the code and work through the actual training with sentence-transformers, which is what I'd recommend.

As I said before, we need a list of input examples. InputExample is just a data format that the sentence-transformers library uses. So we write from sentence_transformers import InputExample. I'm also going to write from tqdm.auto import tqdm.

That's just for our progress bar, so we can see what's happening. Then we create our training samples (or training examples, whatever you want to call them), which start as an empty list.

Then we literally just loop through all of our training data, our dataset, extracting what we need from it: sentence A, sentence B, and the label. So we write for row in tqdm(dataset), wrapping it in tqdm so we have a progress bar and can see where we are.

Inside the loop, all we need to do is train_samples.append(InputExample(...)). Before you get confused: on screen I wrote train_samples in the for loop, but it should be dataset, because we're looping through our dataset, not the empty list. InputExample takes two arguments here: texts and label.

For texts, we pass the input text that our model is going to process: row['premise'] and row['hypothesis'], our two text features from the dataset. We also need a label, which is just row['label'].

These are just the feature names from our dataset, which you can still see up here. We run that; unfortunately, it can take a little bit of time. Then, remember that when we went through the PyTorch approach, we had a PyTorch data loader.

We need a data loader here as well. Sometimes you can use special data loaders from the Sentence Transformers library, which are quite good, but for this we're just using a normal PyTorch data loader. So we write from torch.utils.data import DataLoader.

Same as the paper, we use a batch size of 16. The loader is just loader = DataLoader(...): we pass in those train samples and specify our batch size, and if you'd like to shuffle, which in this case we do, you also set shuffle=True. And that should work.

So now we have our data loader, and next we initialize our model using Sentence Transformers. Sentence Transformers uses modules to set up the model: we'll have a transformer module, which I'll call bert, and a pooling module for our mean pooling layer.

So from sentence_transformers, we also import models, along with SentenceTransformer. We initialize those two modules. First the BERT module: models.Transformer. This uses Hugging Face models, so we can put any model from Hugging Face in here.

I'm going to use bert-base-uncased. Then we want our pooler, which is models.Pooling. We pass it bert.get_word_embedding_dimension(), which is the 768 of our token embeddings, and of course of our sentence embedding as well.

We also set the type of pooling we're going to do. Looking at the pooling mode options, we have these different ways of pooling: the CLS token; max pooling, which I've never actually seen used; mean pooling scaled by the square root of the sequence length, which is interesting; and this one.

This is mean pooling, and that's what we're going to use. So those are our two modules. Then we initialize our model by writing SentenceTransformer. By the way, if you already have a sentence transformer model that you want to load,

you'd just write the model's name in here, something like the all-mpnet model, whatever it's called exactly. But we can also initialize the model from the modules we just created, by passing bert followed by the pooler. Then it shows us the details of that model.

OK, so ignore that; this bit here is what we want. This is our sentence transformer structure. Our transformer is the BERT model, with its maximum sequence length. Then we have the pooler, with the word embedding dimension it expects, 768.

And here we see the different pooling modes: pooling_mode_mean_tokens is True, and the rest are False. From there, we also need to initialize our loss function, which is pretty straightforward as well. Again from sentence_transformers, we import losses.

There are plenty of different losses you can use (just take a look at the documentation), but we're using Softmax loss. So we write loss = losses.SoftmaxLoss(...). Now, what does our loss function need?

First, we pass in the model, so it can get the model parameters: model=model. It also needs the sentence embedding dimension, which is 768 again.

We get that from model.get_sentence_embedding_dimension(). Then we need to pass how many labels our model will predict, num_labels. We already know it's three; I'm sure you could get that dynamically from the dataset as well if you want.

But I'm just going to put three, and I think that's it. We have our loss, our model, and our data, so we should be OK to start training. I'm going to go for one epoch. Next, we set how many warm-up steps we want.

Again, we warm up over 10% of training, so we want 0.1 multiplied by the length of our data loader (warm-up is counted in training steps, and there is one step per batch). We need that to be an integer, so I'm rounding very roughly there.
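Because warm-up is counted in optimizer steps (batches), the number of warm-up steps follows from the number of rows and the batch size. Here is a plain-Python sketch, using a hypothetical helper name and the approximate row count quoted earlier.

```python
import math

def warmup_steps_for(num_rows: int, batch_size: int = 16, epochs: int = 1,
                     warmup_fraction: float = 0.1) -> int:
    """Warm-up steps as a fraction of the total number of optimizer steps (batches)."""
    steps_per_epoch = math.ceil(num_rows / batch_size)  # same as len(DataLoader), drop_last=False
    return int(steps_per_epoch * epochs * warmup_fraction)

# Roughly 942,000 NLI pairs after filtering -> 58,875 batches of 16
print(warmup_steps_for(942_000))  # 5887
```

If you took 10% of the number of rows instead (about 94,200 here), the warm-up would be longer than the whole epoch of 58,875 steps.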

The rough rounding is fine. Next, we start training with model.fit, which is a bit like Keras in TensorFlow. We pass our train_objectives: a list containing a single tuple of our data loader and our loss.

I think the reason for this format is that if you have multiple training objectives, you can add another loader and another loss and keep going; that's why we have a tuple within a list. Then we pass the number of epochs and the number of warm-up steps, which is just warmup_steps again.

We also need the output path, meaning where we're going to save the model. I'd just put something like sbert; I think sbert_test_b is what I've called it later on. And the last argument is show_progress_bar.

This is True by default, but when I ran it, it printed loads of lines, starting a new line on every single update, which is obviously quite annoying. So I set it to False. And that's how you train the model.

I'm not going to run it again, because I've already trained this sbert_test_b model. Instead, I'll switch over to the notebook where I trained it and show you the results. This notebook covers pretty much what we just went through.

Here we have the training, and the training time, which I didn't mention just now, was one hour and 15 minutes on an RTX 3090. So it's reasonably quick, though it depends on what you're training on. Just below, I define some sentences, mostly random, complete nonsense.

But some of them do align. For example, "one thinks she saw her raw fish and rice change position" pairs with "seeing her sushi move". Similarly, "weaving with spaghetti" pairs with "knit with noodles", and "dental specialist with construction materials" pairs with "dentist with chewing bricks".

So some pairs are similar in meaning even though they don't share any of the same descriptive words; they mean roughly the same thing. We've loaded the model here. If you've saved the model, which model.fit does automatically, you just take the output path, come down here, and pass it to SentenceTransformer.

Then you assign the result to the model variable. That's all I've done there. With the model loaded, I use it to encode those sentences, which are just in a list, to create our sentence embeddings.

Then I compute the cosine similarity, again using sentence-transformers, which makes it a bit easier, and compare all of those sentences against each other. It's just a very qualitative view of how these embeddings are doing. And we get these results here, which are pretty good, actually.
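The comparison step amounts to a pairwise cosine similarity matrix. Here is a NumPy sketch of the same computation that sentence_transformers.util.cos_sim performs, plus a hypothetical top_pairs helper that ranks sentence pairs. Four hand-made vectors stand in for the real sentence embeddings; with the trained model they would come from encoding the sentence list, as noted in the comments.

```python
import numpy as np

def cosine_similarity_matrix(embeddings: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the row vectors of `embeddings`."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    return normed @ normed.T

def top_pairs(sim: np.ndarray, k: int = 3):
    """Return the k most similar (i, j, score) pairs with i < j, highest first."""
    n = sim.shape[0]
    pairs = [(i, j, float(sim[i, j])) for i in range(n) for j in range(i + 1, n)]
    return sorted(pairs, key=lambda p: p[2], reverse=True)[:k]

# With the trained model this would be roughly:
#   model = SentenceTransformer("./sbert_test_b")
#   embeddings = model.encode(sentences)
embeddings = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 0.2, 1.0],
])
sim = cosine_similarity_matrix(embeddings)
print(top_pairs(sim, k=2))
```

Ranking the off-diagonal scores like this is what lets you check whether the intended pairs come out on top, and how far they stand out from the non-pairs.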

It's getting the right ones: 7 and 5, 9 and 1, and I think 4 and 3 are the pairs we wanted, and they do in fact get the three highest scores. But a lot of the other non-pairs are still scored fairly high. And as I said before, Softmax loss is no longer the best way of training, or fine-tuning, your model.

There are other ways, like MNR loss. Let me show you some of the charts, including MNR loss, so you can see the difference. This first one is for plain BERT. You can see the scores are all very flat. It almost gets there: two of the correct pairs are within its top three, but not all of them.

But it's very flat: all the values are very close to each other, which makes it hard to differentiate between similar and dissimilar pairs, and that's not what we want. Of course, BERT hasn't been trained for this, so you can't expect it to. And this is my PyTorch model. It does better than BERT, but the performance still isn't there compared with the model trained using Sentence Transformers, which isn't shown here.

This is the actual model that the Sentence Transformers authors trained themselves. You can see there's better differentiation between the values here. But compare that to this one, an MNR model that I trained using the same Sentence Transformers library, and we see a big difference.

These are our similar pairs, and they stick out much more than they did with the other models. Everything else is scored very low, but these stick out a lot. That's really the difference between the models: this one is much better because it separates similar and dissimilar pairs very well.

It's just a lot more accurate. And that's only my MNR model; the actual Sentence Transformers MNR model is much better than this. I think that's it for this video. In the next one, as you've probably guessed, we're going to look at how we can use MNR loss, or multiple negatives ranking loss, to build a model, which I personally think is a lot more interesting.

Softmax loss is pretty interesting, but it's not particularly intuitive. Even the SBERT authors said the same thing; that's actually where I got it from. When you think about it, it's kind of hard to understand why it works, because of that odd concatenation at the end.

We're training a classifier to get embeddings, which seems strange. MNR loss is much more intuitive and makes a lot more sense; I think it's more interesting, and you get much better results with it. So we're going to cover that in the next video, which should be pretty interesting.

But for now, that's it on Softmax loss. Thank you very much for watching. I hope it's been useful. And I will see you in the next one.