
Training BERT #1 - Masked-Language Modeling (MLM)


Chapters

0:00 Introduction
2:15 Walkthrough
4:22 Tutorial

Whisper Transcript

00:00:00.000 | Okay, so in this video, we're going to have a look at what I think is the more interesting
00:00:05.400 | side of transformers, which is how we actually train them. So typically with transformers,
00:00:12.160 | what we do is download a pre-trained model from Hugging Face. And then at that point,
00:00:18.400 | we can either use the pre-trained model as is, which in a lot of cases will be good enough
00:00:24.120 | for what we need. But at other times, we might want to actually fine-tune the model.
00:00:32.560 | And that is what I'll be showing you how to do here. So for BERT, there are two different
00:00:40.760 | training or fine-tuning approaches that we can use, and we can even use both of them
00:00:45.920 | together. But for this video, what we're going to look at is how to use masked-language
00:00:51.920 | modeling, which is called MLM. And MLM is probably the more important of
00:01:00.000 | those two core training approaches, the other one being next sentence prediction. So what
00:01:08.760 | MLM is: we essentially give BERT an input sequence, like this one here, which would be our
00:01:19.440 | input sequence, and we ask BERT to predict the same input sequence as the output. And
00:01:27.720 | BERT will optimize the weights within its encoder layers in order to produce this output.
00:01:36.080 | Now obviously that's pretty easy, so what we do is mask some random tokens
00:01:45.200 | within the input. So here we might mask one, and what we do is replace it with another
00:01:54.000 | token: a special token called the mask token, [MASK]. And when we're
00:02:02.400 | doing MLM, we would typically mask around 15% of the input tokens. So let's take a look
00:02:15.840 | at how that looks. This might look a little complex, but it's pretty straightforward.
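To make the input/target relationship concrete, here is a minimal, dependency-free sketch at the word level. It assumes whitespace-split words stand in for tokens and uses a hypothetical helper, `mask_words`; the pipeline in this video masks token IDs with PyTorch instead.

```python
import random


def mask_words(words, mask_prob=0.15, mask_token="[MASK]", seed=None):
    """Build one MLM training pair from a list of words.

    Each word is independently replaced by mask_token with probability
    mask_prob; the target is the original, unmasked sequence.
    """
    rng = random.Random(seed)  # seeded RNG so the example is repeatable
    masked = [mask_token if rng.random() < mask_prob else w for w in words]
    return masked, list(words)


sentence = "the quick brown fox jumps over the lazy dog".split()
masked, target = mask_words(sentence, seed=0)
print(" ".join(masked))   # input: some words replaced by [MASK]
print(" ".join(target))   # target: the original sentence
```

The model only ever sees the masked version as input, but it is asked to reproduce the full original sequence as output.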
00:02:20.960 | So down here, we have our input from the previous slide. We process that through our tokenizer
00:02:29.600 | like we normally would with transformers. And then in the middle here, though I haven't drawn
00:02:34.760 | it, there's a masking function, and that masking function will mask around
00:02:42.320 | 15% of the tokens in the input IDs tensor. So here we have a mask token, and these tokens will
00:02:51.720 | then get processed by BERT in the middle here. BERT will output a set of vectors, which
00:02:59.580 | all have length 768 in BERT base. Different BERT models use different vector
00:03:07.480 | lengths, but we'll go with 768 here. We then pass those vectors through a feed-forward network,
00:03:14.360 | and that will output our logits up here. Each one of those has a size
00:03:25.440 | equal to the vocab size, and with the model we'll use (bert-base-uncased) the vocab size is 30,522,
00:03:32.040 | so roughly 30,500. And then from there, to get the predicted token
00:03:41.260 | for each one of those logits, we apply a softmax function to get a probability distribution.
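The softmax-then-argmax step can be sketched in plain Python for a single position. The five-entry vocabulary and the logit values below are made up for illustration; a real bert-base-uncased head produces 30,522 logits per position.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_token_id(logits):
    """Return the index (token ID) of the most probable vocabulary entry."""
    probs = softmax(logits)
    return max(range(len(probs)), key=probs.__getitem__)


# Hypothetical 5-word vocabulary and logits for one masked position.
vocab = ["[PAD]", "election", "attacked", "the", "vote"]
logits = [0.1, 4.2, 1.3, 0.7, 2.9]
token_id = predict_token_id(logits)
print(token_id, vocab[token_id])  # 1 election
```

Because softmax is monotonic, taking the argmax over the raw logits gives the same token ID; the softmax is only needed if you want the probabilities themselves.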
00:03:49.480 | And then we apply an argmax function, which is what you can see here. So this is just
00:03:56.280 | an example of one of those logits over here. We apply the softmax, we get the probability
00:04:03.500 | distribution, and then we apply our argmax to get our final token ID, which we can then
00:04:10.500 | decode using our tokenizer to get an actual English word. So that's
00:04:18.880 | how it works. Let's have a look at how we actually do that in code. Okay, so first we'll
00:04:25.200 | need to import everything we need. We're using transformers here, specifically
00:04:29.960 | the BertTokenizer and BertForMaskedLM classes, and we'll also be importing torch as
00:04:37.600 | well. So from transformers, we import BertTokenizer and also BertForMaskedLM, the model class for MLM,
00:04:54.240 | masked-language modeling. And then we also want to import torch. Okay. And then
00:05:06.280 | what I want to do here is initialize our two models, well, our tokenizer and model. And
00:05:13.960 | I do that just as we normally would with Hugging Face transformers. So we do BertTokenizer
00:05:23.400 | .from_pretrained, and here we have bert-base-uncased. And then we also want our model,
00:05:39.000 | which is BertForMaskedLM, and this will also use from_pretrained, again with the same
00:05:49.560 | model, so bert-base-uncased. Okay. So that's our tokenizer and model. And I'm also going
00:05:58.800 | to use this example text here. So we see here, this mask should be "election", and
00:06:09.840 | this one here should be "attacked". Okay. Now, execute that. I've made a typo here. Okay.
00:06:30.840 | And now what we want to do is actually tokenize that chunk of text. To do that, we
00:06:39.880 | write inputs equals our tokenizer, and all we do is pass our text in there. We're using
00:06:45.840 | PyTorch here, so we set return_tensors='pt'. Okay. And let's have a look at what tensors
00:06:54.840 | we get back from that. So you see we have our input IDs, token type IDs, and attention mask.
00:07:02.200 | Now we don't need to worry about token type IDs whatsoever for MLM. The attention mask
00:07:12.240 | is used in MLM, but I'm not going to go into any details on it. So all we want to focus
00:07:19.440 | on in this video is input IDs. So let's have a look at what we have there. There are a
00:07:28.200 | few things that I want to point out. First, we have our special tokens. So we have the
00:07:33.880 | CLS, or classifier, token here. We have the separator token, SEP. And we also have our
00:07:41.080 | mask tokens, one here and one here. Everything in between is an actual token from our
00:07:52.320 | text. So now we have our inputs, and what we do is use these inputs initially
00:08:03.800 | to create our labels. But what I've done here is already mask our inputs, so what I'm
00:08:11.880 | going to do is just replace these with the actual words. So this is "election".
00:08:21.140 | And this one is "attacked". So just rerun that and that. Okay. And now what we can do with
00:08:33.860 | that is actually create our target labels. The target labels need to be contained
00:08:41.700 | within a tensor called labels, which we create like that. It just needs to be a copy of the
00:08:51.700 | input IDs tensor, and to create that copy, we call detach() and then clone().
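Why copy rather than just assign? In PyTorch, `labels = inputs.input_ids` would only create a second name for the same tensor, so masking the input IDs would also overwrite the labels; `.clone()` allocates new storage, and `.detach()` separates the copy from the autograd graph. The same aliasing-versus-copy behaviour can be shown with plain Python lists. Apart from 101 ([CLS]), 102 ([SEP]) and 103 ([MASK]), the IDs below are placeholders.

```python
# Illustrative token IDs: 101 = [CLS], 102 = [SEP]; the middle IDs are placeholders.
input_ids = [101, 2023, 2003, 1037, 3231, 102]

aliased = input_ids        # same object: any change shows up under both names
labels = list(input_ids)   # independent copy, analogous to .detach().clone()

input_ids[2] = 103         # mask one position in the input (103 = [MASK])
print(aliased[2], labels[2])  # 103 2003
```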
00:09:05.300 | Okay. So that creates our copy, which is not going to be connected to our input IDs. And
00:09:15.980 | now if we just have a look at our inputs, we can see input IDs at the top, and we have
00:09:21.340 | labels at the bottom. They're just copies. Okay. Now what we want to do is mask a random
00:09:29.620 | number of input IDs or tokens within the input IDs tensor, but not the labels tensor. Now
00:09:39.660 | to do that, we can use the PyTorch rand function, torch.rand. Using that, what we'll
00:09:49.380 | do is create a random array of floats with the same dimensions as the input IDs tensor.
00:09:58.100 | So all we do is pass input_ids.shape into there. And if we check the shape
00:10:07.600 | of it afterwards, we get 1 by 62, which matches the input IDs shape here. And we can have a look
00:10:16.820 | at what we have there. It's just a set of floats between zero and one. Now, if we want
00:10:25.180 | to select a random 15% of those, we create a new array, mask_arr, and
00:10:36.020 | this will be equal to rand < 0.15, that is, rand less than 0.15. Okay. And this will
00:10:51.540 | select roughly 15% of those. Let me show you what that looks like. This creates a Boolean
00:10:58.740 | array, and you'll see most of these are False; the True values are where we'll
00:11:08.140 | put our mask tokens later on. Now there's one here, and this one is covering our separator
00:11:16.420 | token. Now we don't want to mask our separator or classifier token; we don't want to mask
00:11:21.340 | any special tokens. So what we can do is add an extra little bit of logic there,
00:11:30.860 | like this. We take inputs.input_ids and say not equal to 101, which
00:11:43.220 | is our classifier token. Let's have a look at what this looks like, and we see that
00:11:48.020 | now we get True for everything except for our classifier token. And we multiply this
00:11:54.980 | by the same rule, but for our separator token, 102. So now you see that we have False here and
00:12:02.180 | False here. Now all we need to do is combine this with our mask array logic up here by multiplying, and
00:12:11.180 | we also put brackets around each condition. This will make sure that these two are always False
00:12:19.300 | no matter what. Okay. Now what I want to do is actually get the index positions of each
00:12:30.220 | one of these True values, and to do that, we write torch.flatten. This is going to flatten
00:12:37.980 | the tensor that we get out from the next bit of code, which will probably make more sense
00:12:45.940 | if we build it up step by step. Let's start from the first part of the code: mask_arr[0], the first (and only) row of our mask array.
00:12:55.260 | On that, we call nonzero(), which gets us a vector of
00:13:07.700 | indices where we have True, or nonzero, values. And what we want to do is convert
00:13:14.340 | that into a list like that. But you see that we have a list within a list. So this is where
00:13:24.420 | torch.flatten comes in. So we add another bracket around this and we call torch.flatten.
00:13:35.940 | And then we convert it to a list. That gives us a list of indices where we have these
00:13:42.380 | True values. So that's our selection. And now what we want to do is use that selection
00:13:57.780 | to select those specific indices within our input IDs tensor.
00:14:12.540 | So we select the first row, index zero, followed by selection, giving input_ids[0, selection]. And
00:14:21.820 | we set those equal to 103. Then let's have a look and see if that works.
00:14:28.220 | 103 is our mask token, by the way. And you can see here that we now have
00:14:36.900 | mask tokens in those positions. So we just masked a random selection of roughly 15% of those tokens.
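The whole masking procedure above (uniform random floats, the 0.15 threshold, excluding [CLS] = 101 and [SEP] = 102, collecting the selected indices, and writing [MASK] = 103) can be mirrored in plain Python for a single sequence. This is a sketch rather than the video's code: it works on a Python list instead of a 1 x N tensor, `mask_input_ids` is a hypothetical helper name, and the non-special token IDs are placeholders.

```python
import random

# Special-token IDs for bert-base-uncased, as used in the video.
CLS_ID, SEP_ID, MASK_ID = 101, 102, 103


def mask_input_ids(input_ids, mask_prob=0.15, seed=None):
    """Mask roughly mask_prob of the non-special tokens in one sequence.

    Returns (masked_ids, labels), where labels is an untouched copy of the
    original IDs, following the same steps as the PyTorch code in the video.
    """
    rng = random.Random(seed)
    labels = list(input_ids)                     # targets keep the real tokens
    rand = [rng.random() for _ in input_ids]     # like torch.rand(shape): floats in [0, 1)
    mask_arr = [r < mask_prob and tok not in (CLS_ID, SEP_ID)
                for r, tok in zip(rand, input_ids)]
    selection = [i for i, m in enumerate(mask_arr) if m]  # like nonzero() + flatten
    masked = list(input_ids)
    for i in selection:
        masked[i] = MASK_ID                      # like input_ids[0, selection] = 103
    return masked, labels


# 62 tokens, matching the 1 x 62 shape in the video; the middle IDs are placeholders.
ids = [CLS_ID] + list(range(2000, 2060)) + [SEP_ID]
masked, labels = mask_input_ids(ids, seed=42)
print(sum(tok == MASK_ID for tok in masked), "of", len(ids), "tokens masked")
```

Because each token is masked independently with probability 0.15, the number of masked tokens varies from run to run; 15% is only the expected fraction.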
00:14:46.140 | And then from there, we can pass all of this into our model, and because our inputs now include labels, the model will calculate
00:14:51.500 | our loss as well as the logits that we saw before. We do that as we normally would when we're
00:14:58.220 | using Hugging Face and PyTorch: we call model(**inputs), passing our inputs as keyword arguments. So let's
00:15:09.540 | look at what outputs it gives us. We'll see we have these two tensors: loss
00:15:15.380 | and logits. Now, let's have a look at what that loss looks like. Okay, so we
00:15:23.300 | get this value here. So that is our loss. And of course, with that loss, we can actually
00:15:29.440 | optimize our model. Okay, so that's how masked-language modeling works. Now, when we're actually
00:15:41.740 | training a model using masked-language modeling, obviously the code is slightly different.
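For reference, the loss that BertForMaskedLM returns when labels are passed is a cross-entropy between each position's logits and its label ID, computed with torch.nn.CrossEntropyLoss, whose default ignore_index is -100. In this video the labels are a full copy of the input IDs, so every position, masked or not, contributes to the loss. A common alternative (not what this video does) sets the labels of unmasked positions to -100 so that only masked tokens are scored. The plain-Python sketch below assumes made-up logits for three positions over a four-token vocabulary; `cross_entropy` is a hypothetical helper, not a library function.

```python
import math


def cross_entropy(logits_per_position, labels, ignore_index=-100):
    """Mean of -log softmax(logits)[label] over positions whose label != ignore_index."""
    total, count = 0.0, 0
    for logits, label in zip(logits_per_position, labels):
        if label == ignore_index:
            continue                                   # ignored positions contribute nothing
        m = max(logits)                                # stabilise the log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_sum_exp - logits[label]           # negative log-likelihood of the label
        count += 1
    return total / count if count else float("nan")   # nothing left to average over


# Hypothetical logits for 3 positions over a 4-token vocabulary.
logits = [[2.0, 0.5, 0.1, 0.1],
          [0.2, 3.0, 0.2, 0.2],
          [1.0, 1.0, 1.0, 1.0]]
all_positions = cross_entropy(logits, [0, 1, 2])       # labels = full copy, as in the video
masked_only = cross_entropy(logits, [-100, 1, -100])   # only position 1 is scored
print(round(all_positions, 4), round(masked_only, 4))
```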
00:15:48.660 | But there's also a reasonable amount of depth that we need to go into for that, so I'm not
00:15:54.540 | going to include it in this video. But I am going to do a video on actually training a
00:16:01.740 | model using masked-language modeling pretty soon, and I'll leave a link to that in the
00:16:07.660 | description because I know some of you probably want to watch that to understand how to actually
00:16:12.380 | train your own models using this. But that's it for this video. I hope it's been useful.
00:16:20.260 | And I will see you again in the next one.