Training BERT #5 - Training With BertForPretraining

Chapters
0:00 Introduction
1:07 Import Data
3:22 Prepare Data
5:14 Training Data
10:45 Mask Data
00:00:00.000 | 
Hi, welcome to the video. Here we're going to have a look at how we can pre-train BERT. 00:00:06.800 | 
So what I mean by pre-train is fine-tune BERT using the same approaches that are used to actually 00:00:14.560 | 
pre-train BERT itself. So we would use these when we want to teach BERT to better understand 00:00:22.400 | 
the style of language in our specific use cases. So we'll jump straight into it but what we're 00:00:31.920 | 
going to see is essentially two different methods applied together. So when we're pre-training we're 00:00:39.840 | 
using something called masked language modeling or MLM and also next sentence prediction or NSP. 00:00:47.120 |
Now in a few previous videos I've covered all of these so if you do want to go into a little more 00:00:53.280 | 
depth then I would definitely recommend having a look at those. But in this video we're just 00:00:58.640 | 
going to go straight into actually training a BERT model using both of those methods using 00:01:04.640 | 
the pre-training class. So first we need to import everything that we need. So I'm going to import 00:01:12.880 |
requests, because I'm going to use requests to download the data we're using, which is from here. You'll find a 00:01:18.320 |
link in the description for that. And we also need to import our tokenizer and model classes 00:01:28.080 |
from transformers. So from transformers we're going to import BertTokenizer and BertForPreTraining. 00:01:42.720 |
Now like I said before this BERT for pre-training class contains both 00:01:47.040 | 
an MLM head and an NSP head. So once we have that we also need to import torch as well so let me 00:01:56.320 | 
import torch. Once we have that we can initialize our tokenizer and model. So we initialize our 00:02:06.080 | 
tokenizer like this, so BertTokenizer, and it's from_pretrained, and we're going to be using the 00:02:14.320 |
BERT base uncased model. Obviously you can use whichever BERT model you'd like. 00:02:22.000 |
And for our model we have the BertForPreTraining class. So that's our tokenizer and model, now let's get 00:02:34.400 |
our data. We don't need to worry about that warning, it's just telling us that we need to train it 00:02:40.240 |
basically if we want to use it for inference predictions. 00:02:49.360 |
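As a rough sketch of what we've set up so far (assuming the bert-base-uncased checkpoint mentioned in the video), the imports and initialization might look like this:

    import requests
    import torch
    from transformers import BertTokenizer, BertForPreTraining

    # tokenizer plus the pre-training model, which contains both the MLM and NSP heads
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForPreTraining.from_pretrained('bert-base-uncased')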
So to get our data we're going to pull it from here, so let me copy that, and it's just requests.get, and we paste that in there, and we 00:03:00.640 |
should see a 200 status code, that's good. And then we just extract the data using the text attribute. 00:03:07.440 |
So text is that. We also need to split it, because it's a set of paragraphs that are split by a new 00:03:15.600 |
line character, and we can see those in here. 00:03:26.160 |
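A quick sketch of that download-and-split step; the URL below is only a placeholder standing in for the actual link in the video description:

    # the real URL is linked in the video description; this one is a placeholder
    data_url = 'https://example.com/meditations/clean.txt'
    res = requests.get(data_url)
    print(res.status_code)  # should print 200

    # one paragraph per element, paragraphs separated by newline characters
    text = res.text.split('\n')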
Now we need to prepare our data both for NSP and MLM, so we'll go with NSP first, and to do that we need to create a set of random sentences. So sentence A 00:03:34.080 |
and B where the sentence B is not related to sentence A. We need roughly 50 percent of those 00:03:42.320 | 
and then the other 50 percent we want it to be sentence A is followed by sentence B so they are 00:03:49.520 | 
more coherent. So we're basically teaching BERT to distinguish between coherence and 00:03:55.360 | 
non-coherence between sentences so like long-term dependencies. 00:04:00.320 | 
And we just want to be aware that within our text so we have this one paragraph that has 00:04:09.520 | 
multiple sentences so if we split by this we have those. So we need to create essentially a list of 00:04:18.960 | 
all of the different sentences that we have that we can just pull from when we're creating our 00:04:23.200 | 
training data for NSP. Now to do that we're going to use this comprehension here and what we do is 00:04:31.360 | 
write sentence so for each sentence for each paragraph in the text so this variable 00:04:42.560 | 
for sentence in para.split so this is where we're getting our sentence variable from 00:04:48.880 | 
And we just want to be aware that, if we have a look at this one, we see we get this empty 00:04:57.120 |
sentence. We get that for all of our paragraphs, so we just want to not include those, so we say if 00:05:02.080 |
sentence is not equal to that empty string. We're also going to need to get the length of that 00:05:12.000 |
bag for later as well. 00:05:20.720 |
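The comprehension described here might look roughly like this sketch, splitting each paragraph on periods and dropping the empty strings:

    # a flat "bag" of every sentence across all paragraphs, skipping empty splits
    bag = [sentence for para in text for sentence in para.split('.') if sentence != '']
    bag_size = len(bag)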
Now what we do is create our NSP training data, so we want that 50/50 split, so we're going to use the random library to create that 50/50 randomness. We want to initialize a 00:05:30.640 |
list of sentence A's, a list of sentence B's, and also a list of labels, and then what we do is we're 00:05:41.440 |
going to loop through each paragraph in our text so for paragraph in text we want to extract each 00:05:50.560 | 
sentence from the paragraph, so we're going to use something similar to what we've done here, so we write 00:05:55.680 |
sentences and this is going to be a list of all the sentences within each paragraph so sentence 00:06:01.360 | 
for sentence in paragraph dot split by a period character and we also want to make sure we're not 00:06:11.840 | 
including those empty ones so if sentence is not equal to empty then once we're there what we want 00:06:22.800 | 
to do is get the number of sentences within each sentences variable, so 00:06:32.560 |
just get length and the reason we do that is because we want to check that a couple of times 00:06:37.680 | 
in the next few lines of code and first time we check that is now so we check that the number of 00:06:44.320 | 
sentences is greater than one now this because we're concatenating two sentences to create our 00:06:52.000 | 
training data we don't want to get just one sentence we need it where we have for example 00:06:57.280 | 
in this one, with multiple sentences, so that we can select like this sentence followed by this 00:07:02.320 |
sentence we can't do that with these because there's no guarantee that this paragraph here 00:07:07.680 | 
is going to be talking about the same topic as this paragraph here so we just avoid that and 00:07:12.640 | 
in here first thing we want to do is set our start sentence so this is where sentence a is going to 00:07:19.280 | 
come from and we're going to randomly select say for this example we want to randomly select 00:07:25.840 | 
any of the first one two three sentences okay we'd want to select any of these three 00:07:34.000 | 
but not this one, because if this is sentence A, we don't have a sentence B which follows it to extract, 00:07:42.240 |
so we write random rand int zero up to the length of num sentences 00:07:49.200 | 
minus two now we can now get our sentence a which is append and we just write sentences 00:08:00.000 | 
start and then for our sentence b 50% we want to select a random one from bag 00:08:09.040 | 
up here 50% of the time we want to select the genuine next sentence so say if random 00:08:15.280 | 
dot random, so this will select a random float between zero and one, is greater than 0.5, 00:08:24.480 |
we'll make this our coherent version, so sentences 00:08:35.440 |
start plus one, and that means our label will have to be zero, because that means 00:08:44.320 |
that these two sentences are coherent sentence b does follow sentence a 00:08:48.800 | 
otherwise we select a random sentence for sentence B, so we do append, and here we would write bag, 00:09:00.880 |
and we need to select a random one, so we do random, same as we did earlier on for the start, 00:09:06.240 |
we do random rand int from zero to the length of the bag size minus one so now we also need 00:09:16.960 | 
to do the label, which is going to be one in this case. We can execute that now and that will work. I go 00:09:24.640 |
a little more into depth on this in the previous NSP video, so I'll leave a link to that in the 00:09:32.000 |
description if you want to go through it. 00:09:37.040 |
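As a sketch of the loop just described, reusing the bag and bag_size from earlier:

    import random

    sentence_a = []
    sentence_b = []
    label = []

    for paragraph in text:
        sentences = [sentence for sentence in paragraph.split('.') if sentence != '']
        num_sentences = len(sentences)
        if num_sentences > 1:
            # pick a sentence A that still has a genuine next sentence after it
            start = random.randint(0, num_sentences - 2)
            sentence_a.append(sentences[start])
            if random.random() > 0.5:
                # ~50% of the time: the coherent, genuine next sentence (label 0)
                sentence_b.append(sentences[start + 1])
                label.append(0)
            else:
                # ~50% of the time: a random, unrelated sentence from the bag (label 1)
                sentence_b.append(bag[random.randint(0, bag_size - 1)])
                label.append(1)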
And now what we can do is tokenize our data. So to do that we just write inputs and we use our tokenizer, so this is just normal, you know, 00:09:43.840 |
Hugging Face transformers, and we just write sentence A and sentence B, so Hugging Face 00:09:53.440 |
transformers will know what we want to do and will deal with the formatting for us, which is 00:09:58.000 |
pretty useful. We want to return PyTorch tensors, so return_tensors equals 'pt', and we need to set 00:10:11.200 |
everything to a max length of 512 tokens, so max_length equals 512, truncation needs to be set 00:10:22.400 |
to True, and we also need to set padding equal to 'max_length'. Okay, 00:10:29.280 |
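In code, that tokenization call might look like this sketch:

    # tokenize the A/B sentence pairs; the tokenizer handles the [CLS]/[SEP] formatting for us
    inputs = tokenizer(sentence_a, sentence_b,
                       return_tensors='pt',
                       max_length=512,
                       truncation=True,
                       padding='max_length')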
so that creates three different tensors for us 00:10:36.560 | 
input_ids, token_type_ids, and attention_mask. Now for the pre-training model we need two more tensors: 00:10:46.560 |
we need our next sentence label tensor, so to create that we write inputs next_sentence_label, 00:10:59.840 |
containing our labels which we created before, in the correct dimensionality, so that's why we're 00:11:10.080 |
using the list here and the transpose, and we can have a look at what that creates as well, 00:11:17.120 |
so let's have a look at the first 10, and we get that, okay. 00:11:25.440 |
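A sketch of how that tensor might be built from the labels list created in the NSP step:

    # next_sentence_label needs shape [batch_size, 1], hence the nested list and the transpose
    inputs['next_sentence_label'] = torch.LongTensor([label]).T
    inputs['next_sentence_label'][:10]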
And now what we want to do is create our mask data, so we need the labels for our mask first. When we do this, what we'll do is we're going to 00:11:34.640 |
clone the input_ids tensor, we're going to use that clone for the labels tensor, and then we're going 00:11:40.800 |
to go back to our input_ids and mask around 15 percent of the tokens in that tensor. So let's create that 00:11:47.360 |
labels tensor, it's going to be equal to inputs input_ids, detach, and clone. 00:12:03.360 |
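That clone step, as a one-line sketch:

    # clone input_ids to use as the MLM labels; input_ids itself will be masked next
    inputs['labels'] = inputs['input_ids'].detach().clone()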
Okay, so now we'll see in here we have all of the tensors we need, but we still need to mask around 15 percent of these before 00:12:09.120 |
moving on to training our model, and to do that we'll create a random array 00:12:14.720 |
using torch.rand that needs to be in the same shape as our input_ids, 00:12:20.560 |
and that will just create a big tensor with values between zero and one. What we want to do 00:12:29.840 |
is mask around 15 percent of those, so we will write something like this, okay, and that will give us 00:12:37.920 |
our mask here but we also don't want to mask special tokens which we are doing here we're 00:12:43.840 | 
masking our classification tokens and we're also masking padding tokens up here so we need to add 00:12:49.680 | 
a little bit more logic to that so let me just add this to a variable so we add that logic 00:13:00.720 | 
and input_ids is not equal to 101, which is our CLS token, which is what 00:13:12.800 |
we get down here, so we can actually see the impact, see we get False now, 00:13:17.360 |
and we also want to do the same for separator tokens, 00:13:25.040 |
which is 102, we can't see any of those, and our padding tokens, which use zero, 00:13:30.560 |
so you see these will all go to False now, like so. 00:13:36.560 |
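Putting that masking logic together, a sketch might look like:

    # random floats in [0, 1) with the same shape as input_ids
    rand = torch.rand(inputs['input_ids'].shape)
    # mask ~15% of tokens, but never [CLS] (101), [SEP] (102), or [PAD] (0)
    mask_arr = ((rand < 0.15)
                * (inputs['input_ids'] != 101)
                * (inputs['input_ids'] != 102)
                * (inputs['input_ids'] != 0))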
so that's our masking array and now what we want to do is loop through all of these extract 00:13:47.120 | 
the points at which they are not False, so where we have the mask, and use those index values 00:13:56.400 |
to mask our actual input ids up here to do that we go for i in range inputs input ids dot shape 00:14:08.880 | 
zero this is like iterating through each row and what we do here is we get selection 00:14:18.720 | 
so these are the indices where we have true values from the mask array 00:14:22.880 | 
and we do that using torch flatten mask array 00:14:32.320 | 
at the given index where they are non-zero and we want to create a list from that 00:14:40.160 | 
Okay, so we have that. Oh, and let me show you what the selection looks like quickly, 00:14:47.920 |
so it's just a selection of indices to mask and we want to apply that to our inputs 00:14:58.160 | 
input_ids, so at the current index we select those specific items and we set them equal to 103, 00:15:09.760 |
which is the masking token ID. 00:15:20.160 |
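A sketch of that row-by-row masking loop:

    # for each row, collect the indices to mask and replace them with the [MASK] token ID (103)
    for i in range(inputs['input_ids'].shape[0]):
        # indices where mask_arr is True for this row
        selection = torch.flatten(mask_arr[i].nonzero()).tolist()
        inputs['input_ids'][i, selection] = 103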
Okay, so that's our masking, and now what we need to do is take all of our data here and load it into a PyTorch dataloader, and to do that 00:15:27.200 |
we need to reformat our data into a PyTorch dataset object, and we do that here. The 00:15:34.160 |
main thing to note is we pass our data into this initialization, which assigns it to this self 00:15:42.080 |
encodings attribute, and then here we say, okay, given a certain index we want to extract the 00:15:49.760 |
tensors in a dictionary format for that index, and then here we're just returning 00:15:57.040 |
the length, which is how many samples we have in the full dataset. 00:16:02.720 |
So run that, and we initialize our dataset using that class, so we write dataset equals 00:16:12.320 |
MeditationsDataset, pass our data in there, which is inputs, and then with that we can create our 00:16:19.600 |
dataloader like this, so torch.utils.data.DataLoader, and we pass in our dataset. 00:16:29.280 |
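A sketch of that dataset class and the dataloader (the class name follows the "meditations dataset" naming used in the video; batch size and shuffling get added a little later):

    class MeditationsDataset(torch.utils.data.Dataset):
        def __init__(self, encodings):
            # the dict of tensors: input_ids, token_type_ids, attention_mask,
            # next_sentence_label, and labels
            self.encodings = encodings

        def __getitem__(self, idx):
            # return the tensors for one sample, in dictionary format
            return {key: val[idx] for key, val in self.encodings.items()}

        def __len__(self):
            # number of samples in the full dataset
            return self.encodings['input_ids'].shape[0]

    dataset = MeditationsDataset(inputs)
    loader = torch.utils.data.DataLoader(dataset)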
Okay, so that's ready. Now we need to set up our training loop, so the first thing we need to do is 00:16:38.400 |
check if we are on GPU or not, and if we are, we use it. We do that like so: device equals torch 00:16:45.040 |
device cuda if torch cuda is available, else torch device cpu. So that's saying use the GPU if we have 00:16:52.000 |
a CUDA-enabled GPU, otherwise use the CPU. Then what we want to do is move our model over to that 00:17:01.040 |
device, and we also want to activate the training mode of our model. 00:17:09.040 |
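In code, a sketch of that device setup:

    # use a CUDA GPU if one is available, otherwise fall back to the CPU
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # move the model over to that device and switch on training mode
    model.to(device)
    model.train()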
And then we need to initialize our optimizer. I'm going to be using Adam with weight decay, 00:17:17.040 |
so from transformers import AdamW, and we initialize it like this, so optim equals AdamW, 00:17:30.480 |
we pass our model parameters to that, and we also pass a learning rate, so the learning rate 00:17:36.960 |
is going to be 5e-5. 00:17:45.040 |
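As a sketch; note that recent transformers releases have deprecated or removed their AdamW class, in which case torch.optim.AdamW is the drop-in alternative:

    from transformers import AdamW  # or: from torch.optim import AdamW on newer versions

    # Adam with weight decay, learning rate 5e-5
    optim = AdamW(model.parameters(), lr=5e-5)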
Okay, and now we can create our training loop, so we're going to use tqdm to create the progress bar, and we're going to go through 00:17:53.280 |
two epochs, so for epoch in range(2). We initialize our loop by wrapping it within tqdm, and in here 00:18:04.480 |
we have our dataloader, and we set leave equal to True so that we can see that progress bar, 00:18:10.480 |
and then we loop through each batch within that loop. 00:18:19.440 |
Oh, up here I didn't actually set the batches, my mistake. So up here, 00:18:25.440 |
where we initialize the dataloader, I'm going to set batch size 00:18:29.680 |
equal to 16 and also shuffle the dataset as well. 00:18:34.720 |
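So the dataloader line from earlier becomes something like:

    # batched, shuffled dataloader
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)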
Okay, so for batch in loop. Here we want to initialize the gradients on our optimizer, 00:18:47.920 |
and then we need to load in each of our tensors, and there are quite a few of them. So we have the 00:18:54.320 |
input keys, and we need to load in each one of these, so input_ids equals batch, and we access this like a 00:19:06.000 |
dictionary, so 'input_ids'. We also want to move each one of those tensors that we're using to our device, 00:19:23.120 |
and we have the attention mask, the next sentence labels, and also the labels. 00:19:41.040 |
okay and now we can actually process that through our model 00:19:47.200 | 
so in here we just need to pass all of these tensors that we have so input ids 00:20:17.200 | 
Okay, so there's quite a lot going into our model. Now what we want to do is extract the loss 00:20:27.200 |
from that, then we backpropagate that loss to calculate the gradients for every parameter in our model, and then using those 00:20:35.280 |
we can update our weights using our optimizer. Then what we want to do is print the relevant 00:20:42.720 |
info to our progress bar that we set up using tqdm and loop, so loop.set_description, 00:20:51.840 |
and here I'm going to put the epoch info, so the epoch we're currently on, 00:21:05.440 |
along with the loss information, so loss.item(). Okay, we can run that, 00:21:15.600 |
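Pulling the whole training loop together, a sketch might look like the following; the exact keyword arguments passed to the model and the set_postfix call for the loss display are assumptions based on the standard BertForPreTraining signature and the tqdm API, not a verbatim copy of the video's code:

    from tqdm import tqdm

    epochs = 2
    for epoch in range(epochs):
        # wrap the dataloader in tqdm so we get a progress bar per epoch
        loop = tqdm(loader, leave=True)
        for batch in loop:
            # reset gradients from the previous step
            optim.zero_grad()
            # move each tensor in the batch over to our device
            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            next_sentence_label = batch['next_sentence_label'].to(device)
            labels = batch['labels'].to(device)
            # forward pass through both the MLM and NSP heads
            outputs = model(input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            next_sentence_label=next_sentence_label,
                            labels=labels)
            # extract the loss, backpropagate, and update the weights
            loss = outputs.loss
            loss.backward()
            optim.step()
            # show the current epoch and loss on the progress bar
            loop.set_description(f'Epoch {epoch}')
            loop.set_postfix(loss=loss.item())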
and you see that our model is now training. So we're now training a model using both 00:21:23.760 |
masked language modeling and next sentence prediction, and we haven't needed to take any 00:21:29.600 |
structured data, we've just taken a book and pulled all of the data and formatted it in the correct 00:21:35.520 |
way for us to actually train a better model, which I think is really cool. So that's it 00:21:40.720 |
for this video I hope it's been useful and I'll see you in the next one.