
Training BERT #3 - Next Sentence Prediction (NSP)


Chapters

0:00
7:23 Tokenization
10:32 Create a Labels Tensor
11:20 Calculate Our Loss

Transcript

Hi, welcome to the video. Here we're going to have a look at using next sentence prediction, or NSP, for fine-tuning our BERT models. In a few of the previous videos we covered masked language modeling (MLM) and how we use it to fine-tune our models. NSP is like the other half of fine-tuning for BERT.

Both of those techniques were used during the actual training of BERT, so when Google trained BERT initially, they used both of these methods. Whereas MLM is training on the relationships between words, next sentence prediction is training on longer-term relationships between sentences rather than words. And in the original BERT paper it was found that without NSP, because they tried training BERT without NSP as well, BERT performed worse on every single metric.

So it is pretty important, and if we take this approach, taking masked language modeling and NSP and applying both of those to training our models, fine-tuning our models, we're going to get better results than if we just use MLM. So what is NSP? NSP consists of giving BERT two sentences, sentence A and sentence B, and saying, "Hey BERT, does sentence B come after sentence A?" And then BERT will say, "Okay, sentence B is the next sentence after sentence A, or it is not the next sentence after sentence A." So if we took these three sentences that are on the screen, we have one, two, and three, right?

One and two, if you ask BERT, "Does sentence two come after sentence one?" Then we'd kind of want BERT to say no, right? Because clearly they're talking about completely different topics, and the type of language and everything in there just doesn't really match up. But then if we have a look at sentence three and sentence one, they do match up.

So sentence three is quite possibly the follow-on sentence after sentence one. So in that case, we would expect BERT to say, "This is the next sentence." So let's have a look at how NSP looks within BERT itself. So here we have the core BERT model, and during fine-tuning or pre-training, we add this other head on top of BERT.

So this is the BERT for pre-training head. And the BERT for pre-training head contains two different heads inside it: our NSP head and our masked language modeling head. Now, we just want to focus on the NSP head for now. And as well, we don't need to fine-tune or train our models with both of these heads.

We can actually do it one by one. We could use masked language modeling only, or we could use NSP only. But the full approach to pre-training BERT is using both. So if we have a look inside our NSP head, we'll find that we have a feed-forward neural network, and that will output two different values.

Now, these two values are our "is not the next sentence" there, and our "is the next sentence," which is there. Okay, so index zero is "is the next sentence" and index one is "is not the next sentence." Now, we have the final outputs from our final encoder in BERT at the bottom here.

And we don't actually use all of these activations. We only use the CLS token activation, which is over at the left here. So this here is our CLS token. Okay, and when I say this is our CLS token, I mean that this is the output for our CLS token. The CLS token itself is the input down here.

So we input the CLS token, and this output is what comes out after it has been processed by the encoder layers within BERT itself (12 of them in BERT base). So this is the output representation of that CLS token. Now, the activations from that get fed into our feed-forward neural network, and the dimensionality that we have here is 768 for that single token.

This is in the BERT base model, by the way. And that gets translated into our dimensionality here, which is just the two outputs. So that's essentially how NSP works. Once we have our two outputs here, we just take the argmax across the two of them, and that gives us either 0 or 1, where 0 is the isNext class, and 1 is the notNext class.
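The head described here can be sketched in PyTorch. This is only an illustration, not the actual transformers implementation: `nsp_head` and `cls_output` are hypothetical names, the head is reduced to a single linear layer, and the CLS vector is random rather than produced by BERT.

```python
import torch
from torch import nn

HIDDEN_SIZE = 768  # hidden size of BERT base, as stated in the video

# Hypothetical stand-in for the NSP head: one linear layer from 768 to 2.
nsp_head = nn.Linear(HIDDEN_SIZE, 2)

# Random placeholder for the [CLS] output vector of one sentence pair.
cls_output = torch.randn(1, HIDDEN_SIZE)

logits = nsp_head(cls_output)               # shape: (1, 2)
prediction = torch.argmax(logits, dim=-1)   # 0 = IsNext, 1 = NotNext

print(logits.shape, prediction.item())
```

Because the weights and the input are random, the predicted class here is arbitrary; the point is only the 768-to-2 mapping followed by an argmax.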

And that's how NSP works. So let's dive into the code and see how all this works in Python. Okay, so we're going to be using Hugging Face's transformers and PyTorch. So we'll import both of those. From transformers, we just need the BertTokenizer class and the BertForNextSentencePrediction class.

Then we also want to import torch. And we're going to use two sentences here. Both of these are from the Wikipedia page on the American Civil War, and they are consecutive sentences. So going back to what we looked at before, we would be hoping that BERT would output a 0 label for this pair, because sentence B is the next sentence after sentence A.

This one being sentence B, this one being sentence A. So execute that. And we now have three different steps that we need to take. And that is tokenization, create a classification label, so the 0 or the 1, so that we can train the model. And then from that, we calculate the loss.

So the first step there is tokenization. So we tokenize. It's pretty easy. All we do is write inputs equals tokenizer, and then we pass text and text2. And we are using PyTorch here, so I want to return PyTorch tensors, and for that we set return_tensors to 'pt'. Now we also need to initialize the tokenizer and the model.

So tokenizer equals BertTokenizer.from_pretrained, and we'll just use bert-base-uncased for now. Obviously, you can use another BERT model if you want. And I'm just going to copy that and initialize our model as well. OK, now rerun that. And we'll get this warning. That's because we're using these models that are used for training or for fine-tuning.

So it's just telling us that we shouldn't really use this for inference. You need to train it first. And that's fine, because that's our intention. Now from these inputs, we'll get a few different tensors. So we have input IDs, token type IDs, and attention mask. Now for next sentence prediction, we do need all of these.
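The setup and tokenization steps described so far can be sketched as follows. The helper names `load_tokenizer_and_model` and `encode_pair` are hypothetical and only wrap the calls mentioned in the video; the actual sentences from the video are not reproduced here.

```python
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

MODEL_NAME = "bert-base-uncased"  # checkpoint named in the video


def load_tokenizer_and_model(name=MODEL_NAME):
    """Load the tokenizer and the NSP model (downloads on first use)."""
    tokenizer = BertTokenizer.from_pretrained(name)
    model = BertForNextSentencePrediction.from_pretrained(name)
    return tokenizer, model


def encode_pair(tokenizer, sentence_a, sentence_b):
    """Tokenize two sentences into a single sequence of PyTorch tensors."""
    # Passing two texts makes the tokenizer build one paired input.
    return tokenizer(sentence_a, sentence_b, return_tensors="pt")
```

Calling `tokenizer, model = load_tokenizer_and_model()` and then `inputs = encode_pair(tokenizer, text, text2)` should give the `input_ids`, `token_type_ids`, and `attention_mask` tensors discussed here.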

So this is a little bit different to masked language modeling. With masked language modeling, we don't actually need token type IDs. But for next sentence prediction, we do. So let's have a look at what we have inside these. So input IDs is just our tokenized text. And you see that we pass these two sentences here.

And they're actually both within the same tensor here, input IDs. And they're separated by this 102 in the middle, which is the separator token. So before that, all these tokens, that is our text variable or sentence A. And then afterwards, we have our text2 variable, which is sentence B.

And we can see this mirrored in the token type IDs tensor as well. So all the way along here up to here, that's our sentence A. So we have zeros for sentence A. And then following that, we have ones representing sentence B. And then we have our attention mask, which is just ones, because the attention mask is a one where there is a real token and a zero where we have a padding token.
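To make the layout of these three tensors concrete, here is a small plain-Python sketch. The function `build_pair_inputs` and the token IDs passed to it are made up for illustration; only the special-token IDs (101 for [CLS], 102 for [SEP]) come from the bert-base-uncased vocabulary.

```python
CLS_ID = 101  # [CLS] in the bert-base-uncased vocabulary
SEP_ID = 102  # [SEP], the separator seen in the middle of input_ids


def build_pair_inputs(ids_a, ids_b):
    """Lay out two tokenized sentences as [CLS] A [SEP] B [SEP]."""
    input_ids = [CLS_ID] + ids_a + [SEP_ID] + ids_b + [SEP_ID]
    # Segment 0 covers [CLS], sentence A and the first [SEP];
    # segment 1 covers sentence B and the final [SEP].
    token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    # No padding here, so every position counts as a real token.
    attention_mask = [1] * len(input_ids)
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
    }


# Made-up token IDs standing in for two short sentences.
example = build_pair_inputs([2000, 2001, 2002], [3000, 3001])
print(example["input_ids"])       # [101, 2000, 2001, 2002, 102, 3000, 3001, 102]
print(example["token_type_ids"])  # [0, 0, 0, 0, 0, 1, 1, 1]
print(example["attention_mask"])  # [1, 1, 1, 1, 1, 1, 1, 1]
```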

So we don't need to really worry about that tensor at all. Now, the next step here is that we need to create a labels tensor. So to do that, we just write labels. And we just need to make sure that when we do this, we use a long tensor.

Okay, so we use a long tensor. And in here, we need to pass a list containing a single value, which is either our zero for is the next sentence, or one for is not the next sentence. In our case, our two sentences are supposed to be together. So we will pass a zero in here.
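A minimal sketch of that labels tensor, using the variable name `labels` from the video:

```python
import torch

# 0 = sentence B follows sentence A (IsNext), 1 = it does not (NotNext).
labels = torch.LongTensor([0])

print(labels)        # a tensor holding the single value 0
print(labels.dtype)  # torch.int64
```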

And run that. And if we just have a look at what we get from there, we see that we get this integer tensor. So now we're ready to calculate our loss, which is really easy. So we have our model up here, which we have already initialized. So we just take that.

And all we do is pass our inputs from here into our model as keyword arguments. So that's what these two asterisks are for. And then we also pass labels to the labels parameter. Okay. And that will output a couple of tensors for us. So we can execute that. And let's have a look at what we have.

So you see that we get these two tensors: we have the logits, and we also have the loss tensor. So let's have a look at the logits. And we should be able to recognize this from earlier on, where we saw those two nodes, with index zero for is next and index one for is not next.

So let's have a look. You see here that we get both of those. So this is our activation for is the next sentence. This is our activation for is not the next sentence. And if we were to take the argmax of those output logits, we get zero, which means it is the next sentence.
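The call itself can be sketched as below. To keep the example runnable without downloading a checkpoint, it uses a tiny, randomly initialised model built from a `BertConfig` with arbitrary sizes, and hand-built inputs with made-up token IDs; in the video the model is the pretrained one and `inputs` comes from the tokenizer.

```python
import torch
from transformers import BertConfig, BertForNextSentencePrediction

# Tiny random BERT so the sketch runs offline.
# These sizes are arbitrary assumptions, not those of bert-base-uncased.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64,
                    max_position_embeddings=32)
model = BertForNextSentencePrediction(config)
model.eval()  # switch off dropout

# Hand-built stand-in for the tokenizer output (made-up token IDs).
inputs = {
    "input_ids": torch.tensor([[1, 10, 11, 2, 20, 21, 2]]),
    "token_type_ids": torch.tensor([[0, 0, 0, 0, 1, 1, 1]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1]]),
}
labels = torch.LongTensor([0])

# ** unpacks the dict into input_ids=..., token_type_ids=..., attention_mask=...
outputs = model(**inputs, labels=labels)
print(outputs.keys())        # includes 'loss' and 'logits'
print(outputs.logits.shape)  # torch.Size([1, 2])

# Without labels there is nothing to compare against, so loss is None.
outputs_no_labels = model(**inputs)
print(outputs_no_labels.loss)
```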

Okay. And we also have the loss. And this loss tensor, that will only be output if we pass our labels here. Otherwise, we just get a logits tensor. So when we're training, obviously, we need labels so that we can calculate the loss. And if we just have a look at that, we see it's just a loss value, which is very small because the model is predicting a zero and the label that we've provided is also a zero.
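As a rough illustration of why the loss is small when the prediction matches the label, the sketch below assumes the loss is a standard cross-entropy over the two logits. The logit values are made up and are not the ones shown in the video.

```python
import torch
import torch.nn.functional as F

# Made-up logits: index 0 (IsNext) strongly favoured over index 1 (NotNext).
logits = torch.tensor([[6.0, -6.0]])

prediction = torch.argmax(logits, dim=-1)  # picks index 0

# Cross-entropy against the matching label (0) and the opposite label (1).
loss_match = F.cross_entropy(logits, torch.LongTensor([0]))
loss_mismatch = F.cross_entropy(logits, torch.LongTensor([1]))

print(prediction.item())     # 0
print(loss_match.item())     # very small
print(loss_mismatch.item())  # much larger
```

The same logits give a near-zero loss when the label agrees with the prediction and a large loss when it does not, which is the signal training relies on.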

So the loss is pretty good there. So that is how NSP works. Obviously, it's slightly different if you're actually training your model. And I am going to cover that in the next video. So I'll leave a link to that in the description. But for now, that's it for this.

So thank you very much for watching, and I'll see you again in the next one.