
Training BERT #3 - Next Sentence Prediction (NSP)


Chapters

0:00
7:23 Tokenization
10:32 Create a Labels Tensor
11:20 Calculate Our Loss


00:00:00.000 | Hi, welcome to the video. Here we're going to have a look at using next sentence prediction or NSP
00:00:06.720 | for fine-tuning our BERT models. Now, in a few of the previous videos we covered masked language modeling
00:00:14.320 | and how we use masked language modeling to fine-tune our models. NSP is like the other half
00:00:20.880 | of fine-tuning for BERT. So both of those techniques during the actual training of BERT,
00:00:28.480 | so when Google trained BERT initially, they used both of these methods. And whereas MLM is identifying
00:00:37.840 | or almost training on the relationships between words, next sentence prediction is training on
00:00:44.160 | more long-term relationships between sentences rather than words. And in the original BERT paper
00:00:54.000 | it was found that without NSP, because they tried training BERT without NSP as well,
00:00:58.880 | BERT performed worse on every single metric. So it is pretty important and obviously if we take
00:01:06.480 | this approach, we take masked language modeling and NSP and apply both of those to training our models,
00:01:13.360 | fine-tuning our models, we're going to get better results than if we just use MLM. So what is NSP?
00:01:21.040 | NSP consists of giving BERT two sentences, sentence A and sentence B,
00:01:26.640 | and saying, "Hey BERT, does sentence B come after sentence A?" And then BERT will say, "Okay,
00:01:32.960 | sentence B is the next sentence after sentence A, or it is not the next sentence after sentence A."
00:01:41.520 | So if we took these three sentences that are on the screen, we have one, two, and three, right?
00:01:51.200 | One and two, if you ask BERT, "Does sentence two come after sentence one?" Then we'd kind of want
00:01:58.960 | BERT to say no, right? Because clearly they're talking about completely different topics,
00:02:04.720 | and the type of language and everything in there just doesn't really match up. But then if we have
00:02:10.880 | a look at sentence three and sentence one, they do match up. So sentence three is quite possibly
00:02:19.600 | the follow-on sentence after sentence one. So in that case, we would expect BERT to say,
00:02:28.160 | "This is the next sentence." So let's have a look at how NSP looks within BERT itself.
00:02:36.480 | So here we have the core BERT model, and during fine-tuning or pre-training, we add this other
00:02:47.440 | head on top of BERT. So this is the BERT for pre-training head.
00:02:51.120 | And the BERT for pre-training head contains two different heads inside it.
00:02:57.360 | And that is our NSP head and our masked language modeling head.
00:03:04.480 | Now, we just want to focus on the NSP head for now. And as well, we don't need to fine-tune or
00:03:15.360 | train our models with both of these heads. We can actually do it one by one. We could use masked
00:03:20.480 | language modeling only, or we could use NSP only. But the full approach to pre-training BERT is
00:03:27.680 | using both. So if we have a look inside our NSP head, we'll find that we have a feed-forward
00:03:35.440 | neural network, and that will output two different values. Now, these two values are our "is not the
00:03:45.920 | next sentence" there, and our "is the next sentence," which is there. Okay, so value zero
00:03:55.760 | is the next sentence. Value one is not the next sentence. Now, we have the final outputs from our
00:04:07.440 | final encoder in BERT at the bottom here. And we don't actually use all of these activations. We
00:04:16.160 | only use the CLS token activation, which is over at the left here. So this here is our CLS token.
00:04:26.080 | Okay, and when I say this is our CLS token, what I really mean is that this is the output for our CLS token. The
00:04:34.640 | CLS token itself is down here at the input. So we input the CLS token, and this output is what that token becomes after
00:04:46.240 | being processed by 12 or so encoders within BERT itself. So this is the output representation of
00:04:57.280 | that CLS token. Now, the activations from that get fed into our feed-forward neural network,
00:05:04.240 | and the dimensionality that we have here is 768 for that single token. This is in the BERT
00:05:14.880 | base model, by the way. And that gets translated into our dimensionality here, which is just the
00:05:23.200 | two outputs. So that's essentially how NSP works. Once we have our two outputs here, we just take
00:05:35.920 | the argmax of both of those. So we take both of those values over here, and we just apply an argmax function,
00:05:43.120 | and that will output us either 0 or 1, where 0 is the isNext class, and 1 is the notNext class.
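As a rough illustration of what's described above, here is a minimal sketch of an NSP-style head in PyTorch: a feed-forward layer mapping the 768-dimensional [CLS] activation of BERT base to two values. The names are illustrative, and this skips the pooler layer that the real BERT pre-training head applies before classification.

```python
import torch
import torch.nn as nn

# Illustrative sketch of an NSP-style head (not HuggingFace's internal class):
# a feed-forward layer that maps the 768-d [CLS] activation to two values,
# where index 0 = "is the next sentence" and index 1 = "is not the next sentence".
nsp_head = nn.Linear(768, 2)

cls_output = torch.randn(1, 768)            # stand-in for BERT's final [CLS] activation
logits = nsp_head(cls_output)               # shape (1, 2)
prediction = torch.argmax(logits, dim=-1)   # tensor([0]) or tensor([1])
```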
00:05:56.320 | And that's how NSP works. So let's dive into the code and see how all this works in Python.
00:06:08.960 | Okay, so we're going to be using HuggingFace's transformers and PyTorch. So we'll import both
00:06:14.640 | of those. And from transformers, we just need the BERT tokenizer class and the BERT for next
00:06:36.960 | sentence prediction class.
00:06:36.960 | Then we also want to import Torch. And we're going to use two sentences here. So both of these
00:06:49.200 | are from the Wikipedia page on the American Civil War. And these are both consecutive sentences. So
00:06:57.120 | going back to what we looked at before, we would be hoping that BERT would output a 0 label for
00:07:03.520 | both of these, because sentence B is the next sentence after sentence A. This one being sentence
00:07:10.720 | B, this one being sentence A. So execute that. And we now have three different steps that we need to
00:07:21.360 | take. And that is tokenization, create a classification label, so the 0 or the 1, so that
00:07:29.440 | we can train the model. And then from that, we calculate the loss. So the first step there is
00:07:36.240 | tokenization. So we tokenize. It's pretty easy. All we do is inputs, tokenizer, and then we pass
00:07:44.320 | text and text2. And we are using PyTorch here. So I want to return a PyTorch tensor.
00:07:56.400 | And make sure that's PT. Now we need to also initialize those. So
00:08:06.560 | tokenizer
00:08:10.000 | equals BERT tokenizer from pre-trained.
00:08:19.600 | And we'll just use BERT base uncased for now. Obviously, you can use another BERT model if you
00:08:27.280 | want. And I'm just going to copy that and initialize our model as well.
00:08:34.080 | OK, now rerun that. And we'll get this warning. That's because we're using these models that are
00:08:45.600 | used for training or for fine-tuning. So it's just telling us that we shouldn't really use
00:08:50.160 | this for inference. You need to train it first. And that's fine, because that's our intention.
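Putting the steps described so far into running order, a sketch might look like the following. The two sentences are placeholders, since the exact Wikipedia text used in the video isn't reproduced in the transcript, and the checkpoint name assumes bert-base-uncased.

```python
from transformers import BertTokenizer, BertForNextSentencePrediction
import torch

# Placeholder sentences standing in for the two consecutive sentences from the
# Wikipedia page on the American Civil War (the exact text isn't in the transcript).
text = "The American Civil War was fought between the Union and the Confederacy."
text2 = "It began in 1861 and ended in 1865."

# Initialize the tokenizer and the NSP model from the same pretrained checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')

# Tokenize sentence A and sentence B together, returning PyTorch tensors.
inputs = tokenizer(text, text2, return_tensors='pt')
```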
00:08:55.920 | Now from these inputs, we'll get a few different tensors. So we have input IDs, token type IDs,
00:09:05.040 | and attention mask. Now for next sentence prediction, we do need all of these.
00:09:12.000 | So this is a little bit different to masked language modeling. With masked language modeling,
00:09:15.680 | we don't actually need token type IDs. But for next sentence prediction, we do.
00:09:20.640 | So let's have a look at what we have inside these. So input IDs is just our tokenized text.
00:09:31.600 | And you see that we pass these two sentences here. And they're actually both within the same
00:09:38.160 | sentence or the same tensor here, input IDs. And they're separated by this 102 in the middle,
00:09:45.360 | which is a separator token. So before that, all these tokens, that is our text variable
00:09:52.320 | or sentence A. And then afterwards, we have our text 2 variable, which is sentence B.
00:09:57.360 | And we can see this mirrored in the token type IDs tensor as well. So
00:10:04.560 | all the way along here up to here, that's our sentence A. So we have zeros for sentence A.
00:10:10.640 | And then following that, we have ones representing sentence B. And then we have our attention mask,
00:10:19.200 | which is just ones because the attention mask is a one where it's a real token and a zero
00:10:23.600 | where we have a padding token. So we don't really need to worry about that tensor at all.
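A quick way to inspect those three tensors, continuing the sketch above:

```python
# Inspect the three tensors returned by the tokenizer.
print(inputs.keys())             # input_ids, token_type_ids, attention_mask
print(inputs['input_ids'])       # [CLS] sentence A [SEP] sentence B [SEP]; 102 is the [SEP] token ID
print(inputs['token_type_ids'])  # 0s for sentence A tokens, 1s for sentence B tokens
print(inputs['attention_mask'])  # 1s for real tokens, 0s for any padding tokens
```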
00:10:29.680 | Now, the next step here is that we need to create a labels tensor. So to do that, we just
00:10:38.160 | write labels. And we just need to make sure that when we do this, we use a long tensor.
00:10:45.120 | Okay, so we use a long tensor. And in here, we need to pass a list containing a single
00:10:55.120 | value, which is either our zero for is the next sentence, or one for is not the next sentence.
00:11:00.640 | In our case, our two sentences are supposed to be together. So we will pass a zero in here.
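In code, continuing the same sketch, the labels tensor is just:

```python
# 0 = sentence B is the next sentence after sentence A (IsNext);
# 1 would mean it is not (NotNext).
labels = torch.LongTensor([0])
```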
00:11:07.120 | And run that. And if we just have a look at what we get from there,
00:11:14.000 | we see that we get this integer tensor. So now we're ready to calculate our loss,
00:11:22.000 | which is really easy. So we have our model up here, which we have already initialized.
00:11:26.720 | So we just take that. And all we do is pass our inputs from here
00:11:32.880 | into our model as keyword arguments. So that's what these two asterisks are for.
00:11:40.000 | And then we also pass labels to the labels parameter. Okay. And that will output a couple
00:11:50.160 | of tensors for us. So we can execute that. And let's have a look what we have.
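Continuing the sketch, the forward pass with labels looks something like this:

```python
# Unpack the tokenizer output as keyword arguments and pass the labels,
# so the model returns a loss alongside the logits.
outputs = model(**inputs, labels=labels)
print(outputs.keys())  # loss and logits
```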
00:11:54.800 | So you see that we get these two tensors, we have the logits, and we also have the loss tensor. So
00:12:03.440 | let's have a look at the logits. And we should be able to recognize this from earlier on, where
00:12:08.720 | we saw those two nodes, and we had the two values: index zero for is next and index one
00:12:15.600 | for is not next. So let's have a look. You see here that we get both of those. So this is our
00:12:24.720 | activation for is the next sentence. This is our activation for is not the next sentence.
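Inspecting the logits, their argmax, and the loss from that output, as described here and in the lines that follow, might look like this in the same sketch:

```python
print(outputs.logits)                        # index 0 = IsNext activation, index 1 = NotNext activation
print(torch.argmax(outputs.logits, dim=-1))  # tensor([0]) -> predicted as the next sentence
print(outputs.loss)                          # small, since the prediction matches our label of 0
```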
00:12:30.080 | And if we were to take the argmax of those
00:12:33.600 | output logits, we get zero, which means it is the next sentence. Okay. And we also have the
00:12:44.560 | loss. And this loss tensor, that will only be output if we pass our labels here. Otherwise,
00:12:52.240 | we just get a logits tensor. So when we're training, obviously, we need labels so that
00:12:58.480 | we can calculate the loss. And if we just have a look at that, we see it's just a loss value,
00:13:06.640 | which is very small because the model is predicting a zero and the label that we've
00:13:15.040 | provided is also a zero. So the loss is pretty good there. So that is how NSP works. Obviously,
00:13:23.680 | it's slightly different if you're actually training your model. And I am going to cover
00:13:30.480 | that in the next video. So I'll leave a link to that in the description. But for now, that's it
00:13:38.000 | for this. So thank you very much for watching, and I'll see you again in the next one.