Training BERT #1 - Masked-Language Modeling (MLM)
Chapters
0:00 Introduction
2:15 Walkthrough
4:22 Tutorial
Okay, so in this video we're going to have a look at what I think is the more interesting side of transformers, which is how we actually train them. Typically with transformers, what we do is download a pre-trained model from Hugging Face, and at that point we can either use the pre-trained model as is, which in a lot of cases is good enough, or we can fine-tune the model, and that is what I'll be showing you how to do here. For BERT, there are two core training or fine-tuning approaches that we can use, and we can even use both of them together. In this video we're going to look at masked-language modeling, or MLM, which is probably the more important of those two core training approaches, the other one being next sentence prediction (NSP).
What MLM does is essentially give BERT an input sequence, like this, and ask BERT to predict that same sequence as the output. BERT will optimize the weights within its encoder layers in order to produce that output. Now, on its own that would obviously be too easy, so what we do is mask some random tokens within the input. Here we might mask one token and replace it with a special token called the mask token, written [MASK]. When we're doing MLM, we typically mask around 15% of the input tokens.
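As a quick illustration (a made-up toy sentence, not the passage used later in this video), masking a single token looks like this:

```python
# Toy example (hypothetical sentence) of what the masking step does to the input.
original = "The quick brown fox jumps over the lazy dog"
masked   = "The quick brown fox [MASK] over the lazy dog"
# During MLM, BERT's job is to predict the original token ("jumps") at the [MASK] position.
```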
If we take a look at how that works end to end, it might look a little complex, but it's pretty straightforward. Down here we have our input from the previous slide. We process that through our tokenizer, like we normally would with transformers. In the middle there's a masking function (I haven't drawn it), and that masking function masks around 15% of the tokens in the input IDs tensor. So here we have a mask token, and those inputs then get processed by BERT, which outputs a set of vectors that each have a length of 768. Different BERT models use different hidden sizes, but we'll go with 768 here. We then pass those vectors through a feed-forward network, which outputs our logits, and each of those logits has a size equal to the vocabulary size, which for this model is around 30,500. From there, to get the predicted token for each of those logits, we apply a softmax function to get a probability distribution, and then we apply an argmax function, which is what you can see here: this is just an example for one of those logit vectors. We apply the softmax, get the probability distribution, then apply the argmax to get our final token ID, which we can then decode using our tokenizer to get an actual word in English. So that's how it works.
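As a minimal sketch of that last step, using dummy stand-ins for the logits (in practice they come out of the model, and the tokenizer is the one we load below):

```python
import torch

# Dummy stand-ins: in practice `logits` comes from the model;
# vocab_size is ~30,522 for bert-base-uncased.
batch_size, seq_len, vocab_size = 1, 62, 30522
logits = torch.randn(batch_size, seq_len, vocab_size)

probs = torch.softmax(logits, dim=-1)    # probability distribution over the vocab
token_ids = torch.argmax(probs, dim=-1)  # most likely token ID at each position
# The token IDs can then be decoded back to words with the tokenizer, e.g.
# tokenizer.decode(token_ids[0])
```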
Let's have a look at how we actually do that in code. First we'll need to import everything we need. We're using transformers here, specifically the BertTokenizer and BertForMaskedLM classes, and we'll also be importing torch. So from transformers we import our tokenizer and our BertForMaskedLM, the class for masked-language modeling, and then we import torch as well.
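The imports are simply:

```python
from transformers import BertTokenizer, BertForMaskedLM
import torch
```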
Then what I want to do is initialize our tokenizer and model, and I do that just as we normally would with Hugging Face transformers. We call BertTokenizer.from_pretrained, using bert-base-uncased, and then we also want our model, which is BertForMaskedLM, loaded with from_pretrained as well, again using the same bert-base-uncased model. So that's our tokenizer and model.
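And the initialization:

```python
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
```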
I'm also going to use this example text here. You can see that this first mask should be 'election', and this one here should be 'attacked'. Okay, let's execute that. (I've made a typo here.) Okay.
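The passage itself is only shown on screen, so here is a hypothetical stand-in that contains the two words the video masks ('election' and 'attacked'):

```python
# Hypothetical stand-in for the on-screen passage (not necessarily the video's exact text).
text = ("After Abraham Lincoln won the November 1860 presidential election "
        "on an anti-slavery platform, an initial seven slave states declared "
        "their secession from the country. War broke out in April 1861 when "
        "secessionist forces attacked Fort Sumter.")
```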
Now what we want to do is tokenize that chunk of text. To do that, we create our inputs by calling the tokenizer and passing the text in. We're using PyTorch here, so we want it to return PyTorch tensors, which means setting return_tensors to 'pt'. Okay.
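In code, that's:

```python
inputs = tokenizer(text, return_tensors='pt')
inputs.keys()  # input_ids, token_type_ids, attention_mask
```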
Let's have a look at the tensors we get back from that. You can see we have our input IDs, token type IDs, and attention mask. We don't need to worry about the token type IDs at all for MLM, and while MLM does use the attention mask, I'm not going to go into any detail on it here. All we want to focus on in this video is the input IDs, so let's have a look at what we have there. There are a few things I want to point out. First, we have our special tokens: the CLS (classifier) token here and the SEP (separator) token here. We also have our mask tokens, one here and one here, and everything in between is a real token from our text.
So now we have our inputs, and we're going to use these inputs to create our labels. But what I've done here is already masked the inputs in the text itself, so what I'm going to do is replace those masks with the actual words: this one is 'election', and this one is 'attacked'. Then I just rerun those cells. Okay.
Now what we can do is create our target labels. The target labels need to be contained within a tensor called labels, and it just needs to be a copy of the input IDs tensor. To create that copy, we detach the input IDs and then clone them, which gives us a copy that is not connected to the original input IDs tensor.
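Which looks like this:

```python
# The labels are an untouched copy of the (still un-masked) input IDs.
inputs['labels'] = inputs['input_ids'].detach().clone()
```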
Now, if we have a look at our inputs, we can see the input IDs at the top and the labels at the bottom, and at this point they're just copies of each other. What we want to do next is mask a random selection of tokens within the input IDs tensor, but not within the labels tensor. To do that, we can use PyTorch's rand function to create a random array of floats with the same dimensions as the input IDs tensor; all we do is pass the input IDs' shape into it. If we check the shape afterwards we get 1 by 62, which matches the input IDs, and if we look at the values, they're just floats between zero and one.
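In code:

```python
# Random floats in [0, 1) with the same shape as input_ids, e.g. torch.Size([1, 62])
rand = torch.rand(inputs['input_ids'].shape)
rand.shape
```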
Now, to select a random 15% of those, we create a new array, the mask array, which is equal to rand less than 0.15, and that selects roughly 15% of the positions. Let me show you what that looks like: it's a Boolean array where most values are false, and the true values mark where we'll put our mask tokens later on.
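The condition itself:

```python
# True for roughly 15% of positions (special tokens not excluded yet)
mask_arr = rand < 0.15
```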
One of those true values, though, is covering our separator token, and we don't want to mask the separator or classifier tokens; we don't want to mask any special tokens at all. So we can add a little extra logic. We take the input IDs and say not equal to 101, which is the classifier token ID, and if we look at the result we get true for everything except the classifier token. Then we multiply that by the same rule for the separator token, whose ID is 102, so now we have false in both of those positions. All we need to do then is multiply this into the mask array logic from before, putting brackets around each condition, and that makes sure those two positions are always false, no matter what.
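Putting the whole condition together:

```python
# ~15% of positions, excluding the [CLS] (ID 101) and [SEP] (ID 102) special tokens
mask_arr = (rand < 0.15) * (inputs['input_ids'] != 101) * (inputs['input_ids'] != 102)
```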
Okay. Now what I want to do is get the index positions of each of those true values. To do that we use torch.flatten, which is just going to flatten the tensor we get out of the next bit of code; it will make sense in a moment. Starting from the inside: we take our mask array and call nonzero on it, which gives us a vector of indices where we have true (non-zero) values. We then convert that into a list, but you'll see we get a list within a list, and that's where torch.flatten comes in: we wrap the whole thing in torch.flatten and then convert it to a list, and that gives us a flat list of the indices where we have those true values.
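As a sketch, indexing the first (and only) sequence in the batch:

```python
# Flat list of the token positions we will mask, e.g. [4, 11, 27, ...]
selection = torch.flatten(mask_arr[0].nonzero()).tolist()
```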
So that's our selection. Now we want to use that selection to index those specific positions within our input IDs tensor. We index the first row (index zero) followed by the selection, and we set those positions equal to 103, which is the mask token ID. Let's have a look and check that it worked: you can see that we now have mask tokens in those positions, so we've just masked a random, roughly 15% of the tokens.
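And the assignment:

```python
# 103 is the [MASK] token ID for bert-base-uncased
inputs['input_ids'][0, selection] = 103
```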
From there, we can pass all of this into our model, and the model will calculate the loss and the logits that we saw before. We do that as we normally would with Hugging Face and PyTorch: we call the model and pass our inputs in as keyword arguments.
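The forward pass:

```python
outputs = model(**inputs)
outputs.keys()  # contains 'loss' and 'logits'
```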
If we look at what that gives us, we'll see we have two tensors: the loss and the logits. Let's have a look at what the loss looks like. We get this value here, and that is our loss, and of course, with that loss we can actually optimize our model.
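The video defers real training to a follow-up, but as a minimal sketch of what a single optimization step could look like (assuming a standard PyTorch optimizer such as torch.optim.AdamW and a typical learning rate), it would be something along these lines:

```python
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)  # assumed optimizer and learning rate

model.train()
outputs = model(**inputs)   # forward pass computes loss and logits
outputs.loss.backward()     # backpropagate the MLM loss
optimizer.step()            # update the model weights
optimizer.zero_grad()
```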
Okay, so that's how masked-language modeling works. Now, when we're actually training a model using masked-language modeling, the code is obviously slightly different, and there's a reasonable amount of depth we'd need to go into for that, so I'm not going to include it in this video. I am going to do a video on actually training a model with masked-language modeling pretty soon, though, and I'll leave a link to it in the description, because I know some of you will probably want to watch that to understand how to train your own models using this approach. But that's it for this video. I hope it's been useful.