
Training BERT #5 - Training With BertForPretraining


Chapters

0:00 Introduction
1:07 Import Data
3:22 Power Data
5:14 Training Data
10:45 Mask Data

Whisper Transcript

00:00:00.000 | Hi, welcome to the video. Here we're going to have a look at how we can pre-train BERT.
00:00:06.800 | By pre-train, I mean fine-tuning BERT using the same approaches that are used to actually
00:00:14.560 | pre-train BERT itself. We would use these when we want to teach BERT to better understand
00:00:22.400 | the style of language in our specific use case. We'll jump straight into it, but what we're
00:00:31.920 | going to see is essentially two different methods applied together. When we're pre-training, we're
00:00:39.840 | using something called masked language modeling (MLM) and also next sentence prediction (NSP).
00:00:47.120 | I've covered both of these in previous videos, so if you want to go into a little more
00:00:53.280 | depth, I'd definitely recommend having a look at those. In this video, we're just
00:00:58.640 | going to go straight into training a BERT model with both of those methods, using
00:01:04.640 | the BertForPreTraining class. First, we need to import everything we need. I'm going to import
00:01:12.880 | requests, because I'm going to use it to download the data we're using, which is from here. You'll find a
00:01:18.320 | link in the description for that. We also need to import our tokenizer and model classes
00:01:28.080 | from transformers. So from transformers, we're going to import BertTokenizer
00:01:33.680 | and also BertForPreTraining.
00:01:42.720 | Like I said before, this BertForPreTraining class contains both
00:01:47.040 | an MLM head and an NSP head. We also need to import torch, so let me
00:01:56.320 | import torch. Once we have that, we can initialize our tokenizer and model. We initialize our
00:02:06.080 | tokenizer like this, with BertTokenizer.from_pretrained, and we're going to be using the
00:02:14.320 | bert-base-uncased model. Obviously, you can use whichever BERT model you'd like.
00:02:22.000 | For our model, we have the BertForPreTraining class. So that's our tokenizer and model; now let's get
00:02:34.400 | our data. We don't need to worry about that warning; it's just telling us that we need to train the model
00:02:40.240 | if we want to use it for inference predictions. To get our data, we're going to
00:02:49.360 | pull it from here, so let me copy that. It's just requests.get with the URL pasted in, and we
00:03:00.640 | should see a 200 status code, which is good. Then we extract the data using the text attribute.
00:03:07.440 | So that gives us text. We also need to split it, because it's a set of paragraphs separated by newline
00:03:15.600 | characters, and we can see those in here. Now we need to prepare our data for both NSP and MLM, so
00:03:26.160 | we'll go with NSP first. To do that, we need to create a set of random sentence pairs: sentence A
00:03:34.080 | and sentence B, where sentence B is not related to sentence A. We need roughly 50 percent of those,
00:03:42.320 | and for the other 50 percent we want sentence A to be followed by its real next sentence B, so they are
00:03:49.520 | coherent. So we're basically teaching BERT to distinguish between coherence and
00:03:55.360 | non-coherence between sentences, so it learns longer-term dependencies.
00:04:00.320 | We also want to be aware that, within our text, we have paragraphs like this one that have
00:04:09.520 | multiple sentences, which we get if we split by a period. So we need to create a list of
00:04:18.960 | all of the different sentences that we have, which we can pull from when we're creating our
00:04:23.200 | training data for NSP. To do that, we build bag with this list comprehension here, where we
00:04:31.360 | write sentence for each paragraph in text (this variable)
00:04:42.560 | and for each sentence in para.split('.'), which is where we're getting our sentence variable from.
00:04:48.880 | If we have a look at this one, we see that we get this empty
00:04:57.120 | sentence, and we get that for all of our paragraphs. We don't want to include those, so we add if
00:05:02.080 | sentence is not equal to that empty string. We're also going to need the length of that
00:05:12.000 | bag later. Now we create our NSP training data. We want that 50/50 split,
00:05:20.720 | so we're going to use the random library to create that randomness. We initialize a
00:05:30.640 | list of sentence A's, a list of sentence B's, and also a list of labels. Then we're
00:05:41.440 | going to loop through each paragraph in our text, so for paragraph in text, we want to extract each
00:05:50.560 | sentence from the paragraph, similar to what we've done here, so we write
00:05:55.680 | sentences, which is going to be a list of all the sentences within each paragraph: sentence
00:06:01.360 | for sentence in paragraph.split('.'), and we also want to make sure we're not
00:06:11.840 | including those empty ones, so if sentence is not equal to an empty string. Once we're there, what we want
00:06:22.800 | to do is get the number of sentences within the sentences variable, so
00:06:32.560 | we just get its length. We do that because we want to check it a couple of times
00:06:37.680 | in the next few lines of code, and the first time we check it is now: we check that the number of
00:06:44.320 | sentences is greater than one. This is because we're concatenating two sentences to create our
00:06:52.000 | training data, so a paragraph with a single sentence is no use; we need paragraphs like, for example,
00:06:57.280 | this one, with multiple sentences, so that we can select this sentence followed by this
00:07:02.320 | sentence. We can't do that across these single-sentence paragraphs, because there's no guarantee that this paragraph here
00:07:07.680 | is going to be talking about the same topic as this paragraph here, so we just avoid that, and
00:07:12.640 | in here, the first thing we want to do is set our start sentence; this is where sentence A is going to
00:07:19.280 | come from, and we're going to select it randomly. Say, for this example, we want to randomly select
00:07:25.840 | any of the first one, two, three sentences. We'd want to select any of these three,
00:07:34.000 | but not this one, because if this were sentence A, we wouldn't have a sentence B following it to extract.
00:07:42.240 | So we write random.randint from zero up to the number of sentences
00:07:49.200 | minus two (randint includes both endpoints). Now we can append our sentence A, which is just
00:08:00.000 | sentences[start]. Then, for sentence B, 50% of the time we want to select a random one from bag
00:08:09.040 | up here, and the other 50% of the time we want to select the genuine next sentence. So we say if
00:08:15.280 | random.random(), which returns a random float between zero and one, is greater than 0.5,
00:08:21.680 | then sentence B
00:08:24.480 | is going to be our coherent version, so we append
00:08:35.440 | sentences[start + 1], and that means our label has to be zero, because zero means
00:08:44.320 | that these two sentences are coherent: sentence B does follow sentence A.
00:08:48.800 | Otherwise, we select a random sentence for sentence B, so we append, and here we write bag,
00:09:00.880 | and we need to select a random one from it, so we use random the same way we did earlier for the start:
00:09:06.240 | random.randint from zero to the bag size minus one. Now we also need
00:09:16.960 | to append the label, which is going to be one in this case. We can execute that now, and that works. I go
00:09:24.640 | into a little more depth on this in the previous NSP video, so I'll leave a link to that in the
00:09:32.000 | description if you want to go through it. Now what we can do is tokenize our data.
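The sentence bag and the 50/50 NSP pairing described above can be sketched in plain Python. This is a minimal sketch under a few assumptions, not the exact notebook code. The function name `build_nsp_pairs`, its `seed` parameter, and the toy paragraphs are hypothetical. It also assumes `text` is already a list of paragraphs whose sentences are separated by periods.

```python
import random


def build_nsp_pairs(text, seed=None):
    """Hypothetical helper: build NSP sentence pairs from a list of paragraphs.

    Label 0 means sentence B really follows sentence A; label 1 means
    sentence B was drawn at random from the bag of all sentences.
    """
    rng = random.Random(seed)
    # Bag of every non-empty sentence across all paragraphs.
    bag = [s for para in text for s in para.split('.') if s != '']
    bag_size = len(bag)

    sentence_a, sentence_b, label = [], [], []
    for paragraph in text:
        sentences = [s for s in paragraph.split('.') if s != '']
        num_sentences = len(sentences)
        # Only multi-sentence paragraphs can supply a genuine next sentence.
        if num_sentences > 1:
            # randint is inclusive, so start never points at the last sentence.
            start = rng.randint(0, num_sentences - 2)
            sentence_a.append(sentences[start])
            if rng.random() > 0.5:
                sentence_b.append(sentences[start + 1])  # coherent pair
                label.append(0)
            else:
                sentence_b.append(bag[rng.randint(0, bag_size - 1)])  # random pair
                label.append(1)
    return sentence_a, sentence_b, label


paragraphs = [
    "The sky is blue. The grass is green. Water is wet.",
    "Only one sentence here.",
    "Cats sleep a lot. Dogs bark.",
]
a, b, labels = build_nsp_pairs(paragraphs, seed=0)
print(len(a), labels)
```

The sketch uses a seeded `random.Random` instance instead of the module-level functions so that runs are reproducible. The pairing logic is unchanged: `randint` for the start index and `random() > 0.5` for the coherent/random split.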
00:09:37.040 | To do that, we just write inputs and use our tokenizer; this is just normal
00:09:43.840 | Hugging Face transformers, and we pass sentence A and sentence B, so Hugging Face
00:09:53.440 | transformers will know what we want to do and will deal with the formatting for us, which is
00:09:58.000 | pretty useful. We want to return PyTorch tensors, so return_tensors='pt', and we need to set
00:10:11.200 | everything to a max length of 512 tokens, so max_length=512, truncation needs to be set
00:10:22.400 | to True, and we also need to set padding='max_length'. Okay.
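To make the tokenizer's pair call more concrete, here is a toy, pure-Python sketch of BERT sentence-pair encoding for a single pair. It is not the Hugging Face implementation. `encode_pair` is a hypothetical helper, and its inputs are assumed to be token IDs already produced by a WordPiece tokenizer (the numbers below are arbitrary). The special IDs assume bert-base-uncased ([CLS] = 101, [SEP] = 102, [PAD] = 0). Trimming the longer sequence first roughly mirrors what `truncation=True` (the 'longest_first' strategy) does.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special token IDs


def encode_pair(ids_a, ids_b, max_length=512):
    """Hypothetical toy sketch of BERT sentence-pair encoding."""
    if max_length < 3:
        raise ValueError("max_length must leave room for [CLS] and two [SEP]")
    ids_a, ids_b = list(ids_a), list(ids_b)
    # Reserve room for [CLS] A [SEP] B [SEP].
    budget = max_length - 3
    # Trim the longer sequence one token at a time (like 'longest_first').
    while len(ids_a) + len(ids_b) > budget:
        if len(ids_a) > len(ids_b):
            ids_a.pop()
        else:
            ids_b.pop()

    input_ids = [CLS_ID] + ids_a + [SEP_ID] + ids_b + [SEP_ID]
    # Segment 0 covers [CLS] A [SEP]; segment 1 covers B [SEP].
    token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    attention_mask = [1] * len(input_ids)

    # Pad on the right up to max_length (like padding='max_length').
    pad = max_length - len(input_ids)
    return {
        "input_ids": input_ids + [PAD_ID] * pad,
        "token_type_ids": token_type_ids + [0] * pad,
        "attention_mask": attention_mask + [0] * pad,
    }


enc = encode_pair([2023, 2003], [2008, 2001, 2204], max_length=10)
print(enc["input_ids"])  # [101, 2023, 2003, 102, 2008, 2001, 2204, 102, 0, 0]
```

In the real call, each of these lists becomes one row of the `input_ids`, `token_type_ids`, and `attention_mask` tensors.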
00:10:29.280 | So that creates three different tensors for us:
00:10:36.560 | input_ids, token_type_ids, and attention_mask. For the pre-training model, we need two more tensors.
00:10:46.560 | First, we need our next sentence label tensor, so to create that we write inputs['next_sentence_label'],
00:10:54.560 | and that needs to be a LongTensor
00:10:59.840 | containing the labels we created before, in the correct dimensionality, which is why we're
00:11:10.080 | putting the label list inside another list here and taking the transpose. We can have a look at what that creates as well.
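Here is a minimal check of the dimensionality point, using a hypothetical list of four NSP labels. `torch.LongTensor([label])` has shape 1 x N, and `.T` turns it into N x 1, giving one label per row of `input_ids`.

```python
import torch

# Hypothetical NSP labels for four pairs (0 = B follows A, 1 = random B).
label = [0, 1, 1, 0]

# Wrapping the list gives a 1 x N tensor; .T transposes it to N x 1.
next_sentence_label = torch.LongTensor([label]).T
print(next_sentence_label.shape)  # torch.Size([4, 1])
```

As far as I know, `BertForPreTraining` flattens this label with `.view(-1)` when it computes the NSP loss. A flat tensor of shape (N,) should therefore work too.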
00:11:17.120 | So let's have a look at the first 10, and we get that. Okay, now what we want to do is create our
00:11:25.440 | mask data. We need the labels for MLM first, so what we're going to do is
00:11:34.640 | clone the input_ids tensor and use that clone as the labels tensor, and then we're going
00:11:40.800 | to go back to our input_ids and mask around 15% of the tokens in that tensor. So let's create that
00:11:47.360 | labels tensor: it's going to be equal to inputs.input_ids.detach().clone(). Okay, so now we'll see
00:12:03.360 | in here that we have all of the tensors we need, but we still need to mask around 15% of these before
00:12:09.120 | moving on to training our model. To do that, we'll create a random array
00:12:14.720 | using torch.rand, which needs to be in the same shape as our input_ids,
00:12:20.560 | and that will just create a big tensor of values between zero and one. What we want to do
00:12:29.840 | is mask around 15% of those, so we write something like this, and that will give us
00:12:37.920 | our mask here. But we also don't want to mask special tokens, which we are doing here: we're
00:12:43.840 | masking our classification tokens, and we're also masking padding tokens up here, so we need to add
00:12:49.680 | a little bit more logic to that. Let me just assign this to a variable, and we add that logic,
00:13:00.000 | which says
00:13:00.720 | and input_ids is not equal to 101, which is our [CLS] token, which is what we
00:13:12.800 | get down here, so we can actually see the impact: we get False now.
00:13:17.360 | We also want to do the same for separator tokens,
00:13:25.040 | which are 102 (we can't see any of those here), and for our padding tokens we use zero,
00:13:30.560 | so you see these will all go False now, like so.
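The masking-array logic can be sketched on a small batch. The non-special token IDs are arbitrary placeholders, the seed is only there for repeatability, and the special IDs again assume bert-base-uncased.

```python
import torch

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special token IDs

# A hypothetical batch of two already-tokenized, padded sequence pairs.
input_ids = torch.tensor([
    [101, 2023, 2003, 1037, 3231, 102, 2009, 2001, 102, 0, 0, 0],
    [101, 7592, 2088, 102, 2129, 2024, 2017, 102, 0, 0, 0, 0],
])

torch.manual_seed(0)
# One uniform random number in [0, 1) per token, same shape as input_ids.
rand = torch.rand(input_ids.shape)

# Select roughly 15% of tokens, but never [CLS], [SEP], or [PAD].
mask_arr = (rand < 0.15) & (input_ids != CLS_ID) & \
    (input_ids != SEP_ID) & (input_ids != PAD_ID)
print(mask_arr)
```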
00:13:36.560 | So that's our masking array. Now what we want to do is loop through all of these, extract
00:13:47.120 | the points at which they are not False, so where we have the mask, and use those index values
00:13:56.400 | to mask our actual input_ids up here. To do that, we write for i in range(inputs.input_ids.shape[0]),
00:14:08.880 | which iterates through each row, and what we do here is get selection,
00:14:18.720 | which holds the indices where we have True values in the mask array,
00:14:22.880 | and we get those using torch.flatten on the mask array
00:14:32.320 | at the given row index, taking the positions where it is non-zero, and then we create a list from that.
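The per-row selection, and the assignment to 103 that follows it, can be sketched like this. A hand-written mask array stands in for the random one, and all token IDs are placeholders. The sketch also shows why `labels` is cloned first: the labels keep the original IDs at the masked positions.

```python
import torch

MASK_ID = 103  # [MASK] token ID in bert-base-uncased

input_ids = torch.tensor([
    [101, 2023, 2003, 1037, 3231, 102, 0, 0],
    [101, 7592, 2088, 2129, 2024, 102, 0, 0],
])
# Keep an untouched copy as the MLM labels before masking in place.
labels = input_ids.detach().clone()

# Hand-written mask array standing in for the random one.
mask_arr = torch.tensor([
    [False, True, False, False, True, False, False, False],
    [False, False, False, True, False, False, False, False],
])

for i in range(input_ids.shape[0]):
    # Indices of the True entries in row i, as a plain Python list.
    selection = torch.flatten(mask_arr[i].nonzero()).tolist()
    input_ids[i, selection] = MASK_ID

print(input_ids)
```

The loop mirrors the video's approach; `input_ids[mask_arr] = MASK_ID` is an equivalent vectorized one-liner. The video keeps every position in `labels`. A common variant sets the unmasked positions to -100, which PyTorch's `CrossEntropyLoss` ignores by default, so that only masked tokens contribute to the MLM loss.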
00:14:40.160 | Okay, so we have that. Let me quickly show you what the selection looks like:
00:14:47.920 | it's just a list of indices to mask, and we want to apply that to
00:14:58.160 | inputs.input_ids at the current row index, selecting those specific items and setting them equal to 103,
00:15:09.760 | which is the [MASK] token ID. So that's our masking, and now what we need to do is
00:15:20.160 | take all of our data here and load it into a PyTorch DataLoader. To do that,
00:15:27.200 | we need to reformat our data into a PyTorch Dataset object, and we do that here. The
00:15:34.160 | main thing to note is that we pass our data into the initialization, which assigns it to the
00:15:42.080 | self.encodings attribute. Then here we say: given a certain index, we want to extract the
00:15:49.760 | tensors in dictionary format for that index. And then here we're just returning
00:15:57.040 | the length, that is, how many samples we have in the full dataset.
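Here is a sketch of the dataset class described above. The class name follows the transcript. The body is the usual map-style pattern and is an assumption about the details, and the small `encodings` dict is hypothetical test data.

```python
import torch
from torch.utils.data import Dataset


class MeditationsDataset(Dataset):
    """Map-style dataset over a dict of equally sized tensors."""

    def __init__(self, encodings):
        # e.g. tokenizer output plus 'next_sentence_label' and 'labels'
        self.encodings = encodings

    def __getitem__(self, idx):
        # One sample: the idx-th row of every tensor, keyed by name.
        return {key: tensor[idx] for key, tensor in self.encodings.items()}

    def __len__(self):
        # Number of samples = number of rows in input_ids.
        return self.encodings["input_ids"].shape[0]


encodings = {
    "input_ids": torch.randint(0, 30522, (5, 8)),
    "next_sentence_label": torch.LongTensor([[0, 1, 0, 1, 1]]).T,
}
dataset = MeditationsDataset(encodings)
print(len(dataset), dataset[2]["input_ids"].shape)
```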
00:16:02.720 | We run that and initialize our dataset using that class, so we write dataset equals
00:16:12.320 | MeditationsDataset, passing our data in there, which is inputs. Then with that, we can create our
00:16:19.600 | data loader like this: torch.utils.data.DataLoader, and we pass in dataset.
00:16:29.280 | Okay, so that's ready. Now we need to set up our training loop. The first thing we need to do is
00:16:38.400 | check whether we are on a GPU or not, and if we are, use it. We do that like so:
00:16:45.040 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'), which says use the GPU if we have
00:16:52.000 | a CUDA-enabled GPU, otherwise use the CPU. Then we move our model over to that
00:17:01.040 | device, and we also activate the training mode of our model.
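The device selection and the two model calls can be sketched as follows. `torch.nn.Linear` is only a stand-in here for the `BertForPreTraining` model; moving a model and switching its mode work the same way for any `nn.Module`.

```python
import torch

# Use a CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Stand-in for BertForPreTraining; any nn.Module is moved the same way.
model = torch.nn.Linear(4, 2)
model.to(device)
# train() switches on training behaviour such as dropout.
model.train()

print(device, model.training)
```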
00:17:09.040 | Then we need to initialize our optimizer. I'm going to be using Adam with weight decay,
00:17:17.040 | so from transformers import AdamW, and initialize it like this: optim equals AdamW,
00:17:30.480 | and we pass our model parameters to that, and we also pass a learning rate. The learning rate
00:17:36.960 | is going to be 5e-5. Okay, and now we can create our training loop, so
00:17:45.040 | we're going to use tqdm to create the progress bar, and we're going to go through
00:17:53.280 | two epochs, so for epoch in range(2). We initialize our loop by wrapping it within tqdm, and in here
00:18:04.480 | we have our data loader, and we set leave=True so that we can see that progress bar.
00:18:10.480 | Then we loop through each batch within that loop.
00:18:19.440 | Oh, up here I didn't actually set the batch size, my mistake. So up here,
00:18:25.440 | where we initialize the data loader, I'm going to set batch_size
00:18:29.680 | equal to 16 and also shuffle the dataset as well.
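DataLoader only needs an object with `__getitem__` and `__len__`, and its default collate function stacks dictionaries of tensors key by key. The sketch below uses a plain list of dicts as a stand-in dataset (hypothetical data: 40 samples of length 12) to show the batch shapes that `batch_size=16, shuffle=True` produce. With 40 samples, the last batch holds the remaining 8.

```python
import torch
from torch.utils.data import DataLoader

# Stand-in dataset: a list of per-sample dicts (hypothetical data).
samples = [
    {
        "input_ids": torch.randint(0, 30522, (12,)),
        "next_sentence_label": torch.tensor([i % 2]),
    }
    for i in range(40)
]

loader = DataLoader(samples, batch_size=16, shuffle=True)
for batch in loader:
    # Each key is stacked: (16, 12) and (16, 1); the last batch has 8 rows.
    print(batch["input_ids"].shape, batch["next_sentence_label"].shape)
```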
00:18:34.720 | Okay, so for batch in loop, the first thing we want to do is zero the gradients on our optimizer,
00:18:47.920 | and then we need to load in each of our tensors, of which there are quite a few. We have
00:18:54.320 | these input keys, and we need to load in each one of them, so input_ids equals batch, which we access like a
00:19:06.000 | dictionary: batch['input_ids']. We also want to move each one of those tensors to our device,
00:19:16.400 | so we do that for each one of them,
00:19:23.120 | and we have attention_mask, next_sentence_label, and also labels.
00:19:41.040 | Okay, and now we can actually process that through our model.
00:19:47.200 | In here, we just need to pass all of the tensors that we have: input_ids,
00:19:53.440 | then token_type_ids (I'll just copy this),
00:19:59.600 | attention_mask,
00:20:10.240 | next_sentence_label,
00:20:11.840 | and labels.
00:20:17.200 | Okay, so there's quite a lot going into our model. Now what we want to do is extract the loss
00:20:27.200 | from the outputs, then calculate the gradient of that loss for every parameter in our model with a backward pass, and then using those
00:20:35.280 | gradients we update the parameters with our optimizer. Then what we want to do is print the relevant
00:20:42.720 | info to the progress bar that we set up with tqdm as loop, so loop.set_description,
00:20:51.840 | and here I'm going to put the epoch info, so the epoch we're currently on,
00:21:02.880 | and then I also want to set the postfix,
00:21:05.440 | which will contain the loss information, so loss.item(). Okay, we can run that.
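Here is the loop put together. The sketch runs the same steps (zero_grad, move tensors to the device, forward with all five inputs, backward, step, tqdm description and postfix) on a tiny, randomly initialised `BertForPreTraining` built from a small `BertConfig`, so it runs offline in seconds. The config sizes, the random stand-in data, and the toy masking are all assumptions for illustration; this is not bert-base-uncased. It also swaps `transformers.AdamW` for `torch.optim.AdamW`, because the transformers version is deprecated in favor of the PyTorch one.

```python
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertConfig, BertForPreTraining

# Tiny, randomly initialised config (assumed sizes, not bert-base-uncased).
config = BertConfig(vocab_size=200, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64,
                    max_position_embeddings=64)
model = BertForPreTraining(config)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.train()
optim = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Random stand-in data: 32 samples of length 16.
n, seq_len = 32, 16
label_ids = torch.randint(5, 200, (n, seq_len))
masked_ids = label_ids.clone()
# Mask roughly 15% of positions with ID 103, standing in for [MASK].
masked_ids[torch.rand(n, seq_len) < 0.15] = 103
samples = [
    {
        "input_ids": masked_ids[i],
        "token_type_ids": torch.zeros(seq_len, dtype=torch.long),
        "attention_mask": torch.ones(seq_len, dtype=torch.long),
        "next_sentence_label": torch.tensor([i % 2]),
        "labels": label_ids[i],
    }
    for i in range(n)
]
loader = DataLoader(samples, batch_size=16, shuffle=True)

losses = []
for epoch in range(2):
    loop = tqdm(loader, leave=True)
    for batch in loop:
        optim.zero_grad()
        # Move every tensor in the batch to the chosen device.
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(input_ids=batch["input_ids"],
                        token_type_ids=batch["token_type_ids"],
                        attention_mask=batch["attention_mask"],
                        next_sentence_label=batch["next_sentence_label"],
                        labels=batch["labels"])
        loss = outputs.loss  # MLM loss + NSP loss
        loss.backward()      # gradients for every parameter
        optim.step()         # update the parameters
        loop.set_description(f"Epoch {epoch}")
        loop.set_postfix(loss=loss.item())
        losses.append(loss.item())
```

A dict comprehension replaces the separate `.to(device)` line for each tensor; otherwise the steps follow the order described in the video.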
00:21:15.600 | And you see that our model is now training, so we're now training a model using both
00:21:23.760 | masked language modeling and next sentence prediction, and we haven't needed to take any
00:21:29.600 | structured data. We've just taken a book, pulled all of its text, and formatted it in the correct
00:21:35.520 | way for us to actually train a BERT model, which I think is really cool. So that's it
00:21:40.720 | for this video. I hope it's been useful, and I'll see you in the next one.