Training BERT #3 - Next Sentence Prediction (NSP)
Chapters
0:00
7:23 Tokenization
10:32 Create a Labels Tensor
11:20 Calculate Our Loss
00:00:00.000 |
Hi, welcome to the video. Here we're going to have a look at using next sentence prediction or NSP 00:00:06.720 |
for fine-tuning our BERT models. Now, in a few of the previous videos we covered masked language modeling 00:00:14.320 |
and how we use masked language modeling to fine-tune our models. NSP is like the other half 00:00:20.880 |
of fine-tuning for BERT. So both of those techniques were used during the actual training of BERT, 00:00:28.480 |
so when Google trained BERT initially, they used both of these methods. And whereas MLM is identifying 00:00:37.840 |
or almost training on the relationships between words, next sentence prediction is training on 00:00:44.160 |
more long-term relationships between sentences rather than words. And in the original BERT paper 00:00:54.000 |
it was found that without NSP, because they tried training BERT without NSP as well, 00:00:58.880 |
BERT performed worse on every single metric. So it is pretty important and obviously if we take 00:01:06.480 |
this approach, we take masked language modeling and NSP and apply both of those to training our models, 00:01:13.360 |
fine-tuning our models, we're going to get better results than if we just use MLM. So what is NSP? 00:01:21.040 |
NSP consists of giving BERT two sentences, sentence A and sentence B, 00:01:26.640 |
and saying, "Hey BERT, does sentence B come after sentence A?" And then BERT will say, "Okay, 00:01:32.960 |
sentence B is the next sentence after sentence A, or it is not the next sentence after sentence A." 00:01:41.520 |
So if we took these three sentences that are on the screen, we have one, two, and three, right? 00:01:51.200 |
One and two, if you ask BERT, "Does sentence two come after sentence one?" Then we'd kind of want 00:01:58.960 |
BERT to say no, right? Because clearly they're talking about completely different topics, 00:02:04.720 |
and the type of language and everything in there just doesn't really match up. But then if we have 00:02:10.880 |
a look at sentence three and sentence one, they do match up. So sentence three is quite possibly 00:02:19.600 |
the follow-on sentence after sentence one. So in that case, we would expect BERT to say, 00:02:28.160 |
"This is the next sentence." So let's have a look at how NSP looks within BERT itself. 00:02:36.480 |
So here we have the core BERT model, and during fine-tuning or pre-training, we add this other 00:02:47.440 |
head on top of BERT. So this is the BERT for pre-training head. 00:02:51.120 |
And the BERT for pre-training head contains two different heads inside it. 00:02:57.360 |
And that is our NSP head and our masked language modeling head. 00:03:04.480 |
Now, we just want to focus on the NSP head for now. And as well, we don't need to fine-tune or 00:03:15.360 |
train our models with both of these heads. We can actually do it one by one. We could use masked 00:03:20.480 |
language modeling only, or we could use NSP only. But the full approach to pre-training BERT is 00:03:27.680 |
using both. So if we have a look inside our NSP head, we'll find that we have a feed-forward 00:03:35.440 |
neural network, and that will output two different values. Now, these two values are our "is not the 00:03:45.920 |
next sequence" there, and our "is the next sequence," which is there. Okay, so value zero 00:03:55.760 |
is the next sentence. Value one is not the next sentence. Now, we have the final outputs from our 00:04:07.440 |
final encoder in BERT at the bottom here. And we don't actually use all of these activations. We 00:04:16.160 |
only use the CLS token activation, which is over at the left here. So this here is our CLS token. 00:04:26.080 |
Okay, and when I say this is our CLS token, what I mean is that this is not the CLS token itself. The 00:04:34.640 |
CLS token is down here, at the input. So we input the CLS token, and this output is what comes out after 00:04:46.240 |
being processed by 12 or so encoders within BERT itself. So this is the output representation of 00:04:57.280 |
that CLS token. Now, the activations from that get fed into our feed-forward neural network, 00:05:04.240 |
and the dimensionality that we have here is 768 for that single token. This is in the BERT 00:05:14.880 |
base model, by the way. And that gets translated into our dimensionality here, which is just the 00:05:23.200 |
two outputs. So that's essentially how NSP works. Once we have our two outputs here, we just take 00:05:35.920 |
the argmax of both of those. So we take both over here, and we just take an argmax function of that, 00:05:43.120 |
and that will output us either 0 or 1, where 0 is the isNext class, and 1 is the notNext class. 00:05:56.320 |
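As a rough sketch of what that head is doing (illustrative only, not the actual HuggingFace implementation, which also runs the CLS activation through BERT's pooler first), it boils down to something like:

```python
import torch
import torch.nn as nn

# The 768-d [CLS] output feeds a small feed-forward layer producing two logits;
# argmax over those logits gives the predicted class.
nsp_head = nn.Linear(768, 2)

cls_output = torch.rand(1, 768)            # stand-in for BERT's [CLS] activation
logits = nsp_head(cls_output)              # shape (1, 2)
prediction = torch.argmax(logits, dim=-1)  # 0 = isNext, 1 = notNext
```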
And that's how NSP works. So let's dive into the code and see how all this works in Python. 00:06:08.960 |
Okay, so we're going to be using HuggingFace's transformers and PyTorch. So we'll import both 00:06:14.640 |
of those. And from transformers, we just need the BERT tokenizer class and the BERT for next 00:06:24.000 |
sentence prediction class. And BERT for next sentence prediction. 00:06:36.960 |
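As a sketch, those imports (plus the torch import mentioned just below) look like this:

```python
from transformers import BertTokenizer, BertForNextSentencePrediction
import torch
```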
Then we also want to import Torch. And we're going to use two sentences here. So both of these 00:06:49.200 |
are from the Wikipedia page on the American Civil War. And these are both consecutive sentences. So 00:06:57.120 |
going back to what we looked at before, we would be hoping that BERT would output a 0 label for 00:07:03.520 |
both of these, because sentence B is the next sentence after sentence A. This one being sentence 00:07:10.720 |
B, this one being sentence A. So execute that. And we now have three different steps that we need to 00:07:21.360 |
take. And that is tokenization, create a classification label, so the 0 or the 1, so that 00:07:29.440 |
we can train the model. And then from that, we calculate the loss. So the first step there is 00:07:36.240 |
tokenization. So we tokenize. It's pretty easy. All we do is inputs, tokenizer, and then we pass 00:07:44.320 |
text and text2. And we are using PyTorch here. So I want to return a PyTorch tensor. 00:07:56.400 |
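A minimal sketch of that tokenization step, assuming the tokenizer has already been initialized (we do that just below) and using placeholder strings in place of the two Wikipedia sentences shown in the video:

```python
# Two consecutive sentences; in the video these come from the Wikipedia
# article on the American Civil War (placeholders used here)
text = "Sentence A goes here."   # sentence A
text2 = "Sentence B goes here."  # sentence B, the true next sentence

# Tokenize sentence A and sentence B together and return PyTorch tensors
inputs = tokenizer(text, text2, return_tensors='pt')
```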
And make sure that's 'pt'. Now we also need to initialize the tokenizer. So 00:08:19.600 |
we'll just use BERT base uncased for now. Obviously, you can use another BERT model if you 00:08:27.280 |
want. And I'm just going to copy that and initialize our model as well. 00:08:34.080 |
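A sketch of that initialization, assuming the bert-base-uncased checkpoint:

```python
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
```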
OK, now rerun that. And we'll get this warning. That's because we're using these models that are 00:08:45.600 |
used for training or for fine-tuning. So it's just telling us that we shouldn't really use 00:08:50.160 |
this for inference. You need to train it first. And that's fine, because that's our intention. 00:08:55.920 |
Now from these inputs, we'll get a few different tensors. So we have input IDs, token type IDs, 00:09:05.040 |
and attention mask. Now for next sentence prediction, we do need all of these. 00:09:12.000 |
So this is a little bit different to masked language modeling. With masked language modeling, 00:09:15.680 |
we don't actually need token type IDs. But for next sentence prediction, we do. 00:09:20.640 |
So let's have a look at what we have inside these. So input IDs is just our tokenized text. 00:09:31.600 |
And you see that we pass these two sentences here. And they're actually both within the same 00:09:38.160 |
sentence or the same tensor here, input IDs. And they're separated by this 102 in the middle, 00:09:45.360 |
which is a separator token. So before that, all of these tokens are our text variable 00:09:52.320 |
or sentence A. And then afterwards, we have our text 2 variable, which is sentence B. 00:09:57.360 |
And we can see this mirrored in the token type IDs tensor as well. So 00:10:04.560 |
all the way along here up to here, that's our sentence A. So we have zeros for sentence A. 00:10:10.640 |
And then following that, we have ones representing sentence B. And then we have our attention mask, 00:10:19.200 |
which is just ones because the attention mask is a one where it's a real token and a zero 00:10:23.600 |
where we have a padding token. So we don't really need to worry about that tensor at all. 00:10:29.680 |
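To inspect those three tensors yourself, something like this works:

```python
print(inputs.keys())             # input_ids, token_type_ids, attention_mask
print(inputs['input_ids'])       # token IDs, with 102 ([SEP]) between sentence A and B
print(inputs['token_type_ids'])  # 0s for sentence A, 1s for sentence B
print(inputs['attention_mask'])  # all 1s here, since there is no padding
```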
Now, the next step here is that we need to create a labels tensor. So to do that, we just 00:10:38.160 |
write labels. And we just need to make sure that when we do this, we use a long tensor. 00:10:45.120 |
Okay, so we use a long tensor. And in here, we need to pass a list containing a single 00:10:55.120 |
value, which is either our zero for is the next sentence, or one for is not the next sentence. 00:11:00.640 |
In our case, our two sentences are supposed to be together. So we will pass a zero in here. 00:11:07.120 |
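That labels tensor is just a single long value, something like:

```python
# 0 = sentence B is the next sentence, 1 = it is not
labels = torch.LongTensor([0])
```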
And run that. And if we just have a look at what we get from there, 00:11:14.000 |
we see that we get this integer tensor. So now we're ready to calculate our loss, 00:11:22.000 |
which is really easy. So we have our model up here, which we have already initialized. 00:11:26.720 |
So we just take that. And all we do is pass our inputs from here 00:11:32.880 |
into our model as keyword arguments. So that's what these two symbols are for. 00:11:40.000 |
And then we also pass labels to the labels parameter. Okay. And that will output a couple 00:11:50.160 |
of tensors for us. So we can execute that. And let's have a look what we have. 00:11:54.800 |
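As a sketch, that forward pass looks like:

```python
# Unpack the tokenized inputs as keyword arguments and pass the labels too,
# so the model returns a loss alongside the logits
outputs = model(**inputs, labels=labels)
print(outputs.keys())  # loss, logits
```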
So you see that we get these two tensors, we have the logits, and we also have the loss tensor. So 00:12:03.440 |
let's have a look at the logits. And we should be able to recognize this from earlier on, where 00:12:08.720 |
we saw those two nodes, and we had the two values: index zero for is next and index one 00:12:15.600 |
for is not next. So let's have a look. You see here that we get both of those. So this is our 00:12:24.720 |
activation for is the next sentence. This is our activation for is not the next sentence. 00:12:33.600 |
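A quick sketch of checking the prediction and the loss:

```python
print(outputs.logits)                        # two activations: [isNext, notNext]
print(torch.argmax(outputs.logits, dim=-1))  # tensor([0]) -> is the next sentence
print(outputs.loss)                          # small, since the prediction matches the label
```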
So if we take the argmax of the output logits, we get zero, which means it is the next sentence. Okay. And we also have the 00:12:44.560 |
loss. And this loss tensor, that will only be output if we pass our labels here. Otherwise, 00:12:52.240 |
we just get a logits tensor. So when we're training, obviously, we need labels so that 00:12:58.480 |
we can calculate the loss. And if we just have a look at that, we see it's just a loss value, 00:13:06.640 |
which is very small because the model is predicting a zero and the label that we've 00:13:15.040 |
provided is also a zero. So the loss is pretty good there. So that is how NSP works. Obviously, 00:13:23.680 |
it's slightly different if you're actually training your model. And I am going to cover 00:13:30.480 |
that in the next video. So I'll leave a link to that in the description. But for now, that's it 00:13:38.000 |
for this. So thank you very much for watching, and I'll see you again in the next one.