How to Build Custom Q&A Transformer Models in Python
Chapters
0:00
1:04 Download Our Data
1:09 SQuAD Dataset
6:40 Data Prep
28:53 Initialize Our Tokenizer
41:10 Create a PyTorch Dataset Object
47:56 Initialize the Optimizer
48:33 Initializing the Model
50:03 Initialize Our Data Loader
51:22 Training Loop
58:24 While Loop
61:03 Create a Data Loader
65:15 Check for an Exact Match
66:31 Calculate the Accuracy
67:47 Accuracy
68:31 Exact Match Accuracy
00:00:00.000 |
Hi and welcome to the video. Today we're going to go through how we can fine-tune a Q&A transformer 00:00:06.960 |
model. So for those of you that don't know, Q&A just means question answering, and it's one 00:00:14.320 |
of the biggest topics in NLP at the moment. There's a lot of models out there where you 00:00:20.720 |
ask a question and it will give you an answer. And one of the biggest things that you need to know 00:00:30.400 |
how to do when you are working with transformers, whether that's Q&A or any of the other 00:00:35.280 |
transformer-based solutions, is how to actually fine-tune those. So that's what we're going to be 00:00:43.440 |
doing in this video. We're going to go through how we can fine-tune a Q&A transformer model 00:00:49.360 |
in Python. So I think it's really interesting and I think you will enjoy it a lot. So let's 00:00:58.320 |
just go ahead and we can get started. Okay so first thing we need to do is actually 00:01:04.640 |
download our data. So we're going to be using the SQuAD dataset, which is the 00:01:11.360 |
Stanford Question Answering Dataset, which is essentially one of the better-known Q&A 00:01:17.760 |
data sets out there that we can use to fine-tune our model. So let's first create a folder. 00:01:29.920 |
using os.makedirs. We'll just call it squad. Obviously, call it and organize it as you want. 00:01:41.840 |
This is what I will be doing. Now the URL that we are going to be downloading this from is this. 00:01:52.000 |
Okay and there are actually two files here that we're going to be downloading 00:01:58.640 |
So because we're making a request to a URL we're going to import requests. 00:02:04.480 |
We can also use the wget library as well or if you're on Linux you can just use wget 00:02:12.240 |
directly in the terminal. It's up to you; we're going to be using requests. 00:02:17.840 |
Okay and to request our data we're going to be doing this. So it's just a get request. 00:02:30.160 |
Use a f string and we have the URL that we've already defined. 00:02:38.880 |
And then the training data that we'll be using is this file here. 00:02:52.800 |
Okay and we can see that we've successfully pulled that data in there. Okay so like I said 00:03:03.120 |
before there's actually two of these files that we want to extract. So what I'm going to do is 00:03:10.240 |
just put this into a for loop which will go through both of them. Just copy and paste this across. 00:03:30.400 |
And the other file is the same but instead of train we have dev. 00:03:37.200 |
And then the next thing we want to do after making our request is actually 00:03:44.480 |
saving this file to our drive. Okay and we want to put that inside this squad folder here. 00:03:53.120 |
So to do that we use open. And again we're going to use a f string here. 00:04:00.320 |
And we want to put inside the squad folder here. And then here we are just going to put our file 00:04:11.040 |
name which is file. Now we're writing this in binary because it's JSON so we put wb for our 00:04:19.360 |
flags here. And then within this namespace we are going to run through the file and download it in 00:04:28.480 |
chunks. So we do for chunk and then we iterate through the response like this. Let's use a chunk 00:04:42.160 |
size of four. And then we just want to write to the file like that. So that will download both 00:04:54.640 |
files. Just add the colon there. And we should be able to see 00:05:02.080 |
them here now. So in here we have data, we have essentially a lot of different topics. So the 00:05:08.560 |
first one is Beyonce. And then in here we will see, if we just come to here, we get a context. 00:05:16.240 |
But alongside this context we also have QAS which is question and answers. And each one of these 00:05:25.040 |
contains a question and answer pair. So we have this question, when did Beyonce start becoming 00:05:33.200 |
popular? So this answer is actually within this context. And what we want our model to do is 00:05:39.040 |
extract the answer from that context by telling us the start and end token of the answer within 00:05:46.240 |
that context. So we go zero and it is in the late 1990s. And we have answer start 269. So 00:05:56.080 |
that means that the answer starts at character 269. So if we go through here, we can find it here. 00:06:05.440 |
Okay, so this is the extract. And that's what we will be aiming for our model to actually extract. 00:06:12.400 |
But there will be a start point and also the end point as well, which is not included in here, 00:06:18.080 |
but we will add that manually quite soon. So that's our data. And then we'll also be testing 00:06:25.600 |
on the dev data as well, which is exactly the same. Okay, so we move on to the data prep. 00:06:42.080 |
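Put together, the download steps described above look roughly like this. The exact URL isn't shown in the transcript, so the standard SQuAD v2 location is assumed here, and the network call is guarded so the sketch still runs offline:

```python
import os
import requests

# create the folder for our data
os.makedirs('squad', exist_ok=True)

# assumption: the standard SQuAD v2 download location
url = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/'

for file in ['train-v2.0.json', 'dev-v2.0.json']:
    try:
        res = requests.get(f'{url}{file}', timeout=30)
        res.raise_for_status()
    except requests.RequestException:
        continue  # skip the file if the network is unavailable
    # write the response to disk in binary mode, chunk by chunk
    with open(f'squad/{file}', 'wb') as f:
        for chunk in res.iter_content(chunk_size=1024):
            f.write(chunk)
```

The video streams in chunks of 4 bytes; a larger chunk size such as 1024 downloads noticeably faster.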
So now we have our files here, we're going to want to read them in. So we're going to use 00:06:49.280 |
the JSON library for that. And like we saw before, there's quite a complex structure in these 00:06:57.440 |
JSONs, there's a lot of different layers. So we need to use a few for loops to run through each 00:07:03.360 |
of these and extract what we want, which is the context, questions and answers, all corresponding 00:07:11.280 |
to each other. So in the end, we're going to have lists of strings, which is going to be all of 00:07:15.680 |
these. And in the case of the answers, we will also have the starting position. So it will be a 00:07:20.080 |
list of dictionaries, where one value is a text and one value is the starting position. So to do 00:07:29.600 |
that, we're going to define a function called read squad. We'll define our path here as well. 00:07:41.600 |
And the first thing we need to do is actually open the JSON file. So we do with open path. 00:07:49.040 |
And again, we are using a binary file. So we're going to have b 00:07:54.320 |
as a flag. But instead of writing, we are reading, so we use r here. So rb. 00:08:10.400 |
And we just do json.load(f) here. So now we have our dictionary within this squad dict here. So 00:08:19.120 |
maybe whilst we're just building this function up, it's probably more useful to 00:08:23.680 |
do it here. So you can see what we're actually doing. So let's copy that across. 00:08:34.880 |
Of course, we do actually need to include the path. So let's take this. 00:08:57.760 |
Maybe we can load just a few rather than all of them. 00:09:09.280 |
Or we can investigate it like this. Okay, so we have the version and data, which we can actually 00:09:21.760 |
see over here, version and data. So we want to access the data. Now within data is we have a 00:09:28.880 |
list of all these different items, which is what I was trying to do before. So we go into data. 00:09:35.120 |
And just take a few of those. Okay, and then we get our different sections. 00:09:47.600 |
For the first one, let's just take zero, which is Beyonce. And then we have all of these. 00:09:55.120 |
So we're going to want to loop through each one of these, because we have this one next, 00:10:03.440 |
and we're going to keep needing to just run through all of these. 00:10:15.920 |
in squad dict. And remember, we need to include the data here. 00:10:24.240 |
Let's just see our say group title. So we can see a few of those. Okay, 00:10:34.320 |
then go through each one of those. So the second part of that are these, 00:10:40.800 |
these paragraphs. And when the paragraphs we have each one of our questions. So let's first 00:10:52.240 |
go with paragraphs. And we'll do the chop in here. 00:11:05.840 |
And the first thing we need to extract is the easiest one, which is our context. 00:11:17.680 |
However, that is also within a list. So now if we access the context, we get this. 00:11:26.640 |
So we're essentially going to need to jump through or loop through each one of these here. 00:11:35.280 |
Now we're gonna need to access the paragraphs and loop through each one of those. And then here, 00:11:40.880 |
we're going to access the context. So let's write that. So we already have one group here. So let's 00:11:49.600 |
just stick with that. And we're going to run through the passage in the paragraphs. So already 00:12:01.440 |
here, we're going through the for loop on this index. And now we're going to go through a loop 00:12:07.360 |
on this index. Let's keep that. So that means that we will be able to print the passage 00:12:23.120 |
context. And there we go. So here we have all of our context. So that's 00:12:32.320 |
one of our three items that we need to extract. Okay, so that's great. Let's put that all together. 00:12:39.600 |
So we're going to take this, put it here. And then we have our context. 00:12:48.560 |
Okay, that's great. Obviously, for each context, we have a few different questions and answers. 00:12:58.800 |
So we need to get those as well. Now, that requires us to go through another for loop. 00:13:06.000 |
So let's go this passage, we need to go into the QAS key and loop through this list of 00:13:14.320 |
question and answers. So we have this, and then we have our list. So another layer in our for loop 00:13:25.680 |
will be for question answer in that passage QAS. And then let's take a look at what we have there. 00:13:37.120 |
Okay, great. So we have plausible answers, question and answers. So what we want in here 00:13:47.280 |
is the question and answers. So question is our first one. 00:14:15.760 |
And then after we have extracted the question, we can move on to our answers. 00:14:20.480 |
As we see here, the answers comes as another list. Now each one of these lists all just have 00:14:28.080 |
one actual answer in there, which is completely fine. So we can access that in two ways. We can 00:14:33.680 |
either loop through or we can access the zero value of that array. Either way, it doesn't matter. 00:14:44.560 |
So all we need to do here is loop through those answers. 00:14:56.880 |
So in most cases, this should be completely fine. 00:15:05.440 |
As we can see here, most of these question and then they have the answers dictionary. 00:15:14.080 |
Which is fine. However, some of these are slightly different. 00:15:25.440 |
say, okay, we have this, which is talking about physics. 00:15:32.320 |
And then rather than having our answers array, we have these plausible answers, 00:15:41.520 |
which is obviously slightly different. And this is the case for a couple of those. 00:15:47.360 |
So from what I've seen, the best way to deal with this 00:15:52.480 |
is simply to have a check. If there is a plausible answers key within the dictionary, 00:16:00.160 |
we will include that as the answer rather than the actual answers dictionary. 00:16:06.800 |
So to do that, all we need to do is check if QA 00:16:13.520 |
keys contains plausible answers. If it does, we use that. Otherwise, we use answers. 00:16:42.000 |
we will use answers. So let's just add all of that into our for loop here. 00:16:57.760 |
So we have our context, and then we want to loop through the question answers. 00:17:13.120 |
Then once we're here, we need to do something slightly different, which is this plausible 00:17:27.360 |
answers. Okay. And then we use this access variable in order to define what we're going 00:17:36.960 |
to loop through next. So here we go for answers, answer, sorry, in QA 00:17:46.960 |
access, because this will switch between plausible answers and answers. 00:17:54.160 |
And then within this for loop, this is where we can begin adding this context, question, 00:18:01.600 |
and answer to a list of questions, context, and answers that we still need to define up here. 00:18:09.520 |
So each one of these is just going to be an empty list. 00:18:24.320 |
and we just append everything that we've extracted in this loop and the context, 00:18:51.600 |
So now let's take a look at a few of our contexts. Okay. We can see we have this, and because we 00:19:05.680 |
have multiple question answers for each context, the context does repeat over and over again. 00:19:10.800 |
But then we should see something slightly different when we go with answers. 00:19:20.080 |
And questions. Okay. So that's great. We have our data in a reasonable format now, 00:19:28.400 |
but we want to do this for both the training set and the validation set. So what we're going to do 00:19:36.560 |
is just going to put this into a function like we were going to do before, which is this read squad. 00:19:47.120 |
So here, we're going to read in our data, and then we run through it and transform it into 00:19:59.520 |
our three lists. And all we need to do now is actually return those three lists and answers. 00:20:13.600 |
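A sketch of the read_squad function as just described; the plausible_answers check covers the SQuAD v2-style entries mentioned above:

```python
import json

def read_squad(path):
    # read the raw SQuAD JSON from disk
    with open(path, 'rb') as f:
        squad_dict = json.load(f)

    contexts, questions, answers = [], [], []
    # walk the nested structure: data -> paragraphs -> qas -> answers
    for group in squad_dict['data']:
        for passage in group['paragraphs']:
            context = passage['context']
            for qa in passage['qas']:
                question = qa['question']
                # some entries hold 'plausible_answers' instead of 'answers'
                if 'plausible_answers' in qa.keys():
                    access = 'plausible_answers'
                else:
                    access = 'answers'
                for answer in qa[access]:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)
    return contexts, questions, answers
```

Then something like `train_contexts, train_questions, train_answers = read_squad('squad/train-v2.0.json')` for each set (the exact file names are an assumption).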
So now what we can do is execute this function for both our training and validation sets. 00:20:20.400 |
So we're going to train context, questions, and answers. 00:20:39.120 |
Okay. So that is one of them, and we can just copy that. 00:20:53.840 |
And we just want this to be our validation set. 00:21:04.640 |
Okay. So that's great. We now have the training context and the 00:21:14.400 |
validation context, which we can see right here. 00:21:23.840 |
So here, let's hope that there is a slight difference in what we see between both. 00:21:29.360 |
Okay. Great. That's what we would expect. Okay. So now we have our data almost in the right format. 00:21:46.320 |
We just need to add the ending position. So we already have the start position. If we take a look 00:21:54.320 |
in our train answers. Okay. We have the answer start, but we also need the answer end, and that's 00:22:03.840 |
not included within the data. So what we need to do here is actually define a function 00:22:17.760 |
that will go through each one of our answers and context and figure out where that ending 00:22:25.040 |
character actually is. And of course we could just say, okay, it's the length of the text. 00:22:32.080 |
We add that onto the answer start and we have our answer end. However, that unfortunately won't work 00:22:40.080 |
because some of the answer starts are actually incorrect and they're usually off by one or two 00:22:47.040 |
characters. So we actually need to go through and one, fix that and two, add our end indices. 00:22:55.040 |
So to do that, we're just going to define a new function, 00:23:02.720 |
which is going to be add end index. And here we will have our answers and the context, 00:23:13.440 |
and then we're going to just feed these in. So first thing we do is loop through each 00:23:19.360 |
answer and context pair. And then we extract something which is called the gold text, 00:23:39.200 |
which is essentially the answer that we are looking for. It's called the golden text or gold text. 00:23:47.760 |
So simply our answer and within that, the text. So we are pulling this out here. 00:23:58.160 |
So we should already know the starting index. So what we do here is simply pull that out as well. 00:24:14.800 |
And then the end index ideally will be the start plus the length of the gold text. 00:24:26.640 |
However, that's not always the case because like I said before, they can be off by one or two 00:24:34.560 |
characters. So we need to add in some logic just to deal with that. So in our first case, let's 00:24:43.040 |
assume that the characters are not off. So if context start to end equals the gold text, 00:25:03.680 |
this means everything is good and we don't need to worry about it. So we can modify the 00:25:12.000 |
original dictionary and we can add answer end into there. And we made that equal to our end index. 00:25:22.720 |
However, if that's not the case, that means we have a problem. It's one of those 00:25:29.200 |
dodgy question answer pairs. And so this time what we can do is we'll add an else statement. 00:25:40.000 |
So we're just going to go through when the position is off by one or two characters, 00:25:44.720 |
because it is not off by any more than that in the squad dataset. 00:25:48.400 |
Loop through each of those and we'll say, okay, if the context, 00:25:54.320 |
and then in here we need to add the start index and this again. So let's just copy and paste 00:26:02.640 |
side cross be easier. But this time we're checking to see if it is off by one or two characters. 00:26:10.000 |
So we just do minus n, and it's always minus n; it's always shifted to the left 00:26:17.040 |
rather than to the right. So this is fine. So in this case, the answer is off by n characters. 00:26:28.640 |
And so we need to update our answer, start value, and also add our answer end value. 00:26:45.840 |
So start index minus N and we also have the end. 00:26:59.040 |
So that's great. We can take that and we can apply it to our train and validation sets. 00:27:10.240 |
So all we do here is call the function and we just see train answers and train context. 00:27:22.000 |
And of course we can just copy this and do the same for our validation set. 00:27:38.960 |
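The end-index logic described above can be sketched like so. It trusts the given start when the slice matches the gold text, and otherwise checks for a left shift of one or two characters:

```python
def add_end_idx(answers, contexts):
    # add an 'answer_end' character index to every answer dict
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        if context[start_idx:end_idx] == gold_text:
            # the given start is already correct
            answer['answer_end'] = end_idx
        else:
            # in SQuAD the start is only ever shifted left, by 1 or 2 chars
            for n in [1, 2]:
                if context[start_idx - n:end_idx - n] == gold_text:
                    answer['answer_start'] = start_idx - n
                    answer['answer_end'] = end_idx - n
```

Then `add_end_idx(train_answers, train_contexts)` and the same again for the validation set.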
Okay. Perfect. So now if we have a quick look, we should be able to see that we have 00:27:57.680 |
And that means we can move on to actually encoding our text. 00:28:02.640 |
To tokenize or encode our text, this is where we bring in a tokenizer. So we need to import 00:28:18.960 |
the transformers library for this. And from transformers, we are going to import the DistilBERT 00:28:26.640 |
tokenizer. So DistilBERT is a smaller version of BERT, which is just going to run a bit quicker; 00:28:32.960 |
otherwise it would take a very long time. And we're going to import the fast version of this tokenizer 00:28:40.560 |
because this allows us to more easily adjust our character and then start locations to token 00:28:49.040 |
end and start locations later on. So first we need to actually initialize our tokenizer. 00:28:55.520 |
Which is super easy. All we're doing is loading it from a pre-trained model. 00:29:16.400 |
Okay. And then all we do to create our encodings is to call the tokenizer. 00:29:34.640 |
which is called tokenizer. And in here, we include our training context. 00:29:43.600 |
And the training questions. So what this will do 00:29:52.080 |
is actually merge these two strings together. So what we will have is our context, 00:30:03.200 |
and then there will be a separator token followed by the question. And this will be fed 00:30:08.240 |
into distilbert during training. I just want to add padding there as well. 00:30:16.720 |
And then we'll copy this and do the same for our validation set. 00:30:37.120 |
Okay. And this will convert our data into encoding objects. 00:30:41.520 |
So what we can do here is print out different parts that we have within our encodings. 00:30:52.720 |
So in here, so you have the input IDs. So let's access that. 00:31:02.800 |
And you'll find in here, we have a big list of all of our samples. So check that we have 130K. 00:31:11.680 |
And let's open one of those. Okay. And we have these token IDs, and this is what Bert will be 00:31:20.720 |
reading. Now, if we want to have a look at what this actually is in sort of human readable 00:31:27.920 |
language, we can use the tokenizer to just decode it for us. 00:31:32.640 |
Okay. And this is what we're feeding in. So we have a couple of these special tokens. This just 00:31:45.040 |
means it's the start of a sequence. And in here, we have a process form of our original context. 00:31:55.120 |
Now, you'll find that the context actually ends here. And like I said before, we have this 00:32:00.400 |
separator token. And then after that, we have our actual question. And this is what is being fed 00:32:08.720 |
into Bert, but obviously the token ID version. So it's just good to be aware of what is actually 00:32:15.920 |
being fed in and what we're actually using here. But this is a format that Bert is expecting. 00:32:22.400 |
And then after that, we have another separator token followed by all of our padding tokens, 00:32:26.560 |
because Bert is going to be expecting 512 tokens to be fed in for every one sample. 00:32:34.800 |
So we just need to fill that space essentially. So that's all that is doing. 00:32:39.200 |
So let's remove those and we can continue. So the next thing we need to add to our encodings 00:32:50.640 |
is the start and end positions, because at the moment, we just don't have them in there. 00:32:57.200 |
So to do that, we need to add an additional bit of logic. We use this character to token method. 00:33:19.360 |
And what we can do is actually modify this to use the character to token method, 00:33:29.440 |
remove the input IDs, because we just need to pass it the index of whichever encoding 00:33:36.640 |
we are wanting to modify or get the start and end position of. And in here, all we're doing is 00:33:45.040 |
converting from the character that we have found a position for to the token that we want to find 00:33:51.280 |
a position for. And what we need to add is train answers. We have our position again, because the 00:34:01.360 |
answers and encodings, the context in question needs to match up to the answer, of course, 00:34:07.120 |
that we're asking about. And we do answers start. So here, we're just feeding in the position of the 00:34:16.640 |
character. And this is answer. OK. So feeding in position of the character, 00:34:23.920 |
and we're expecting to return the position of the token, which is position 64. 00:34:31.760 |
So all we need to do now is do this for both of those. So for the start position and end position. 00:34:46.640 |
OK, but this is one of the limitations of this. Sometimes this is going to return nothing. As 00:34:57.680 |
you can see, it's not returning anything here. And that is because sometimes it is actually 00:35:03.760 |
returning the space. And when it looks at the space, the tokenizer sees that and says, 00:37:10.880 |
OK, that's nothing. We're not concerned about spaces. And it returns this non-value that you 00:35:15.520 |
can see here. So this is something that we need to consider and build in some added logic for. 00:35:25.200 |
So to do that, again, we're going to use a function to contain all of this. 00:35:36.880 |
Yeah, we'll have our encodings and our answers. And then we just modify this code. So we have the 00:35:47.200 |
encodings. We have the answers. And because we're collecting all of the token positions, 00:35:56.800 |
we also need to initialize lists to contain those. So we do startPositions and endPositions, both empty lists. 00:36:12.320 |
And now we just want to loop through every single answer and encoding that we have. 00:36:17.600 |
Like so. And here we have our startPositions. So we need to append that to our startPositionsList. 00:36:38.800 |
Then we just do the same for our endPositions, which is here. 00:36:43.040 |
Now, here we can deal with this problem that we had. So if we find that the endPositions, 00:36:58.400 |
the most recent one, so the negative one index, is none, that means it wasn't found. And it means 00:37:07.040 |
there is a space. So what we do is we change it to instead use the -1 version. And all this needs 00:37:16.000 |
to do is update the endPositions here. OK, that's great. But in some cases, this also happens with 00:37:26.080 |
the startPosition. But that is for a different reason. The reason it will occasionally happen 00:37:31.680 |
with startPosition is when the passage of data that we're adding in here, so you saw before we 00:37:39.040 |
had the context, a separated token, and then the question, sometimes the context passage is 00:37:46.160 |
truncated in order to fit in the question. So some of it will be cut off. And in that case, 00:37:54.400 |
we do have a bit of a problem. But we still need to just allow our code to run without any problems. 00:38:01.520 |
So what we do is we just modify the startPositions again, 00:38:08.160 |
just like we did with the endPositions. Obviously, only if it's none. And we just set it 00:38:16.320 |
to be equal to the maximum length that has been defined by the tokenizer. 00:38:26.240 |
And it's as simple as that. Now, the only final thing we need to do, which is because we're using 00:38:33.760 |
the encodings, is actually update those encodings to include this data. Because as of yet, we haven't 00:38:41.280 |
added that back in. So to do that, we can use this quite handy update method. 00:38:50.640 |
And just add in our data as a dictionary. So we have the startPositions, 00:38:56.240 |
startPositions. And then we also have our endPositions. 00:39:06.240 |
And then, again, we just need to apply this to our training. 00:39:18.400 |
We just need to apply this to our training and validation sets. And let's just modify that. 00:39:26.080 |
Let's add the train encodings here and the train answers. 00:40:01.680 |
And here we can see, great, now we have those startPositions and endPositions. 00:40:09.600 |
We can even have a quick look at what they look like. 00:40:16.000 |
And what we've done is actually not included the index here. So we're just taking it for 00:40:23.440 |
the very first item every single time. So let's just update that. 00:40:39.280 |
And now this should look a little bit better. So it's lucky we checked. 00:40:47.280 |
Okay, so our data, our training, and our validation sets are now up to date. 00:41:02.560 |
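Putting that together, the position-conversion function might look as follows; max_length stands in for tokenizer.model_max_length, and the end-position fallback steps back one character when char_to_token lands on a space:

```python
def add_token_positions(encodings, answers, max_length=512):
    # convert character start/end positions into token positions
    start_positions, end_positions = [], []
    for i in range(len(answers)):
        start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end']))
        # the context was truncated and the answer fell outside it
        if start_positions[-1] is None:
            start_positions[-1] = max_length
        # char_to_token returned None (a space), so step back one character
        if end_positions[-1] is None:
            end_positions[-1] = encodings.char_to_token(i, answers[i]['answer_end'] - 1)
    # write the new positions back into the encodings
    encodings.update({'start_positions': start_positions,
                      'end_positions': end_positions})
```

Then `add_token_positions(train_encodings, train_answers)` and the same for the validation encodings and answers.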
Okay, so our data at the moment is in the right format. We just need to use it to create a 00:41:12.320 |
PyTorch dataset object. So to do that, obviously, we need to import PyTorch. 00:41:22.960 |
And we define that dataset using a class. And we just pass in torch.utils.data.Dataset. 00:41:47.840 |
And this is coming from the Hugging Face Transformers documentation. I don't take credit for this. 00:42:15.280 |
And we essentially need to do this so that we can load in our data using the 00:42:19.680 |
PyTorch data loader later on, which makes things incredibly easy. 00:42:44.160 |
And then we just have one more function here, one method. 00:43:04.560 |
Okay, and we just need to return and also this as well. That should be okay. 00:43:12.640 |
So we apply this to our datasets to create dataset objects. 00:43:28.400 |
We have our encodings and then the same again for the validation set. 00:43:33.280 |
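The dataset class, closely following the pattern from the Hugging Face docs:

```python
import torch

class SquadDataset(torch.utils.data.Dataset):
    # wraps the encodings dict so PyTorch's DataLoader can batch it
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # return one sample as a dict of tensors
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings['input_ids'])
```

Then `train_dataset = SquadDataset(train_encodings)` and `val_dataset = SquadDataset(val_encodings)`.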
Okay, so that is our data almost fully prepared. All we do now is load it into a data loader 00:43:47.760 |
object. But this is everything on the data side done, which is great because I know it does 00:43:55.520 |
take some time and I know it's not the most interesting part of it, but it's just something 00:44:00.160 |
that we need to do and need to understand what we're doing as well. So now we get to 00:44:09.760 |
the more interesting bit. So we'll just add the imports in here. 00:44:26.800 |
We're going to import the Adam optimizer with weight decay (AdamW), which is pretty commonly used 00:44:41.680 |
for transformer models when you are fine tuning. Because transformer models are generally very 00:44:48.880 |
large models and they can overfit very easily. So this Adam optimizer with weight decay 00:44:56.960 |
essentially just reduces the chances of that happening, 00:45:01.520 |
which is supposed to be very useful and quite important. So obviously we're going to use that. 00:45:18.160 |
So TQDM is a progress bar that we are going to be using so that we can actually see 00:45:30.400 |
the progress of our training. Otherwise, we're just going to be sat there for 00:45:35.440 |
probably quite a long time, not knowing what is actually happening. And trust me, 00:45:40.800 |
it won't take long before you start questioning whether anything is happening, because it takes 00:45:45.200 |
a long time to train these models. So they are our imports and I'm being stupid again here. 00:45:54.400 |
That's from, did that twice. Okay, so that's all good. So now we just need to do a few 00:46:01.520 |
little bits for the setup. So we need to tell PyTorch whether we're using CPU or GPU. 00:46:10.400 |
In my case, it will be a GPU. If you're using CPU, this is going to take you a very long time 00:46:16.720 |
to train. And it's still going to take you a long time on GPU. So just be aware of that. 00:46:22.640 |
But what we're going to do here is say device. 00:46:35.680 |
Otherwise, we are going to use the CPU. And good luck if that is what you're doing. 00:46:52.080 |
So once we've defined the device, we want to move our model over to it. 00:47:02.720 |
So we just do model.to(device). So this .to method is essentially a way of transferring data between 00:47:11.200 |
different hardware components, so your CPU or GPU. It's quite useful. And then we want to 00:47:17.840 |
activate our model for training. So there's two things we have here. So we have .train and eval. 00:47:27.440 |
So when we're in train mode, there's a lot of different layers and different parts of your 00:47:32.880 |
model that will behave differently depending on whether you are using the model for training 00:47:38.320 |
or you're using it for inference, which is predictions. So we just need to make sure our 00:47:44.000 |
model is in the right mode for whatever we're doing. And later on, we'll switch it to eval 00:47:50.160 |
to make some predictions. So that's almost everything. So we just need to initialize 00:47:57.520 |
the optimizer. And here, we're using the Adam optimizer with weight decay, AdamW. 00:48:03.360 |
We need to pass in our model parameters and also give it a learning rate. And we're just going to 00:48:15.840 |
use this value here. All of these are the recommended parameters for what we are doing here. 00:48:23.840 |
So the one thing that I have somehow missed is defining the, actually initializing the model. 00:48:35.280 |
So let's just add that in. And all we're doing here is loading, again, a pre-trained one. 00:48:42.560 |
So like we did before when we were loading the Transformers tokenizer. 00:48:47.280 |
This time, it's for question answering. So this DistilBertForQuestionAnswering is a DistilBERT 00:49:03.360 |
model with a question answering head added onto the end of it. So essentially with Transformers, 00:49:10.720 |
you have all these different heads that you add on. And they will do different things depending on 00:49:16.000 |
what head it has on there. So let's initialize that from pre-trained. 00:49:26.720 |
And we're using the same one we used up here, which is distilbert-base-uncased. 00:49:37.680 |
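The setup steps just described, in one place. The learning rate isn't shown on screen in the transcript, so the commonly recommended fine-tuning value of 5e-5 is assumed; AdamW is taken from torch.optim here, though the video imports it from transformers:

```python
import torch
from torch.optim import AdamW  # Adam with weight decay
from transformers import DistilBertForQuestionAnswering

# use the GPU if available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# DistilBERT with a question-answering head on top
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
model.to(device)
model.train()  # put the model into training mode

# assumption: 5e-5 is the usual fine-tuning learning rate
optim = AdamW(model.parameters(), lr=5e-5)
```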
And sometimes you will need to download that. Fortunately, I don't need to as I've already 00:49:43.600 |
done that. But this can also take a little bit of time, not too long though. And you get a nice 00:49:48.160 |
progress bar, hopefully, as well. Okay, so now that is all settled, we can initialize our data 00:49:57.120 |
loader. So all we're doing here is using the PyTorch data loader object. And we just pass in 00:50:15.520 |
our training data set, the batch size. So how many we want to train on at once in parallel before 00:50:23.120 |
updating the model weights, which will be 16. And we also would like to shuffle the data because we 00:50:31.200 |
don't want to train the model on a single batch where it just learns about Beyonce, and then the 00:50:35.680 |
next one where it's learning about Chopin, and it keeps switching from topic to topic 00:50:42.160 |
but never, within a single batch, having a good mix of different things to learn about. 00:50:52.640 |
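The loader itself is a one-liner. Here it is with a small stand-in dataset so the sketch runs on its own; in the video, train_dataset is the SquadDataset object built above:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# stand-in for the SquadDataset built from train_encodings
train_dataset = TensorDataset(torch.arange(64).reshape(32, 2))

# batches of 16, shuffled so each batch mixes different topics
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

for batch in train_loader:
    print(batch[0].shape)  # each batch stacks 16 samples
```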
So it is data sets. Seems a bit of a weird name to me, so I'm just going to change it. 00:51:03.760 |
And I also can't spell. There we go. And that is everything, so we can actually begin 00:51:22.000 |
our training loop. So we're going to go for three epochs. 00:51:33.200 |
And what we want to start with here is a loop object. So we do this mainly because we're using 00:51:42.560 |
TQDM as a progress bar. Otherwise, we wouldn't need to do this. There'd be no point in doing it. 00:51:49.680 |
And all this is doing is kind of like pre-initializing our loop that we are going 00:51:56.960 |
to go through. So we're going to obviously loop through every batch within the train loader. So we 00:52:02.880 |
just add that in here. And then there's this other parameter, which I don't know if we... 00:52:10.160 |
So let's leave it. But essentially, you can add leave equals true in order to leave your progress 00:52:17.520 |
bar in the same place with every epoch. Whereas at the moment, with every epoch, what it will do is 00:52:23.600 |
create a new progress bar. We are going to create a new progress bar. But if you don't want to do 00:52:28.880 |
that and you want it to just stay in the same place, you add leave equals true into this function 00:52:36.000 |
here. So after that, we need to go through each batch within our loop. And the first thing that 00:52:46.320 |
we need to do is set all of our calculated gradients to zero. So with every iteration 00:52:56.320 |
that we go through here or every batch, at the end of it, we are going to calculate gradients, 00:53:00.720 |
which tells the model in which direction to change the weights within the model. 00:53:07.200 |
And obviously, when we go into the next iteration, we don't want those gradients to still be there. 00:53:14.320 |
So all we're doing here is re-initializing those gradients at the start of every 00:53:19.920 |
loop. So we have a fresh set of gradients to work with every time. 00:53:28.320 |
So this is everything that is relevant that we're going to be feeding into the training process. 00:53:34.880 |
So we have everything within our batch. And then in here, we have 00:53:42.400 |
all of our different items. So we can actually see-- go here. 00:53:50.960 |
We want to add in all of these. And we also want to move them across to the 00:54:04.000 |
GPU, in my case, or whatever device you are working on. 00:54:12.080 |
And we'll do that for the attention mask start positions and end positions. 00:54:37.120 |
So these start and end positions are essentially the labels, 00:54:41.200 |
they're the targets that we want our model to optimize for. 00:54:44.400 |
And the input IDs and attention masks are the inputs. 00:55:06.880 |
So now we have those defined. We just need to feed them into our model for training. 00:55:11.600 |
And we will output the results of that training batch to the outputs variable. 00:55:19.280 |
So we pass our model the input IDs, and then it needs the attention mask. 00:55:36.560 |
And we also want our start positions and end positions. 00:55:42.880 |
Now, from our training batch, we want to extract the loss. 00:56:02.480 |
And then we want to calculate the gradients for every parameter. 00:56:13.120 |
And then we use the step method here to actually apply those gradient updates. 00:56:26.080 |
And then this final little bit here is purely for us to see, this is our progress bar. 00:56:38.640 |
So we call it a loop. We set the description, which is going to be our epoch. 00:56:44.480 |
And then it would probably be quite useful to also see the loss in there as well. 00:56:48.720 |
We will set that as a postfix. So it will appear after the progress bar. 00:57:06.960 |
Okay, and that should be everything. Okay, so that looks pretty good. We have our model training. 00:57:14.160 |
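The whole training loop just described can be sketched as below. The real video uses the DistilBERT model, the SQuAD encodings, and the AdamW optimizer initialized in earlier steps; here a tiny `ToyQA` module and random batch tensors are hypothetical stand-ins so the sketch runs end to end. The loop shape itself follows the video: zero the gradients, move the batch to the device, forward pass with start and end positions, backward pass, optimizer step, and a tqdm bar showing the loss.

```python
# Hedged sketch of the training loop; ToyQA and the random data are stand-ins.
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

class ToyQA(torch.nn.Module):
    # minimal model whose forward mirrors the QA signature used in the video
    def __init__(self, vocab=50, dim=16):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, dim)
        self.head = torch.nn.Linear(dim, 2)       # one start + one end logit per token
    def forward(self, input_ids, attention_mask=None,
                start_positions=None, end_positions=None):
        logits = self.head(self.emb(input_ids))   # (batch, seq, 2)
        start_logits, end_logits = logits[..., 0], logits[..., 1]
        ce = torch.nn.CrossEntropyLoss()
        loss = (ce(start_logits, start_positions)
                + ce(end_logits, end_positions)) / 2
        return loss, start_logits, end_logits

# hypothetical stand-in data: 32 "encodings" of sequence length 10
data = [{'input_ids': torch.randint(0, 50, (10,)),
         'attention_mask': torch.ones(10, dtype=torch.long),
         'start_positions': torch.randint(0, 10, ()),
         'end_positions': torch.randint(0, 10, ())} for _ in range(32)]
train_loader = DataLoader(data, batch_size=16, shuffle=True)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = ToyQA().to(device)
model.train()
optim = torch.optim.AdamW(model.parameters(), lr=5e-5)

for epoch in range(3):
    loop = tqdm(train_loader, leave=True)         # leave=True keeps one bar per epoch
    for batch in loop:
        optim.zero_grad()                         # clear the last step's gradients
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        loss, _, _ = model(input_ids, attention_mask=attention_mask,
                           start_positions=start_positions,
                           end_positions=end_positions)
        loss.backward()                           # gradients for every parameter
        optim.step()                              # apply the updates
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())        # show the loss after the bar
```

With the real Transformers model, the forward call takes the same keyword arguments and the loss comes back on the output object in the same first position.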
And as I said, this will take a little bit of time. So I will let that run. 00:57:21.280 |
And then we will go back to the model and we will run the training batch. 00:57:33.920 |
Okay, so we have this NoneType error here. And this is because within our end positions, 00:57:48.240 |
we would normally expect integers, but we're also getting some None values because the 00:57:52.880 |
code that we used earlier, where we're checking if the end position is None, essentially wasn't good enough. 00:57:59.760 |
So as a fix for that, we'll just go back and we'll add like a while loop, which will keep 00:58:06.240 |
checking if it's None. And every time it is None, reduce the value that we are checking by one. 00:58:12.480 |
So go back up here, and this is where the problem is coming from. 00:58:21.040 |
So we're just going to change this to be a while loop. 00:58:29.520 |
And just initialize essentially a counter here. 00:58:32.800 |
And we'll use this as our go back value. And every time the end position is still None, 00:58:45.600 |
we'll just add one to go back. And this should work. 00:58:51.680 |
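The while-loop fix can be sketched like this. The stand-in `char_to_token` below is hypothetical: in the real code it's the tokenizer's character-to-token lookup, which returns None when no token covers that character (for example after truncation), so we keep stepping the character index back by one until a valid token position comes back.

```python
# Sketch of the while-loop fix with a hypothetical char_to_token stand-in:
# character positions 7-9 map to no token (None), mimicking the real failure.
def char_to_token(char_idx):
    mapping = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 3}
    return mapping.get(char_idx)       # None when no token covers the character

answer_end = 10                        # hypothetical end-character index
go_back = 1                            # how many characters we've stepped back
end_position = char_to_token(answer_end - go_back)
while end_position is None:            # keep shifting back one character at a time
    go_back += 1
    end_position = char_to_token(answer_end - go_back)
```

So instead of a single `if ... is None` check, the loop guarantees we always land on a real token position.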
So we need to remember to rerun anything we need to rerun. Yeah. 00:59:18.720 |
Okay. And that looks like it solved the issue. So great. We can just leave that 00:59:25.920 |
training for a little while and I will see you when it's done. 00:59:29.920 |
Okay. So the model's finished and we'll go ahead and just save it. 00:59:38.240 |
So obviously we'll need to do that whenever actually doing this on any other projects. 00:59:47.120 |
So I'm just going to call it distilbert-custom. 00:59:49.120 |
And it's super easy to save. We just do save pre-trained and the model path. 01:00:00.160 |
Now, as well as this, we might also want to save the tokenizer so we have everything in one place. 01:00:10.080 |
So to do that, we also just use tokenizer and save pre-trained again. 01:00:16.160 |
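Those two save calls can be sketched as below. In the video they run on the fine-tuned model and tokenizer; to keep this runnable offline, a tiny randomly initialized model and a tokenizer built from a local vocab file stand in for them, and a temporary directory stands in for the models folder.

```python
# Sketch of the save step; the model, tokenizer, and paths here are
# hypothetical offline stand-ins for the fine-tuned objects from the video.
import os, tempfile
from transformers import (DistilBertConfig, DistilBertForQuestionAnswering,
                          DistilBertTokenizerFast)

tmp = tempfile.mkdtemp()
vocab_file = os.path.join(tmp, 'vocab.txt')
with open(vocab_file, 'w') as f:
    f.write('\n'.join(['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', 'hello']))

model = DistilBertForQuestionAnswering(
    DistilBertConfig(vocab_size=6, dim=8, n_layers=1, n_heads=1, hidden_dim=16))
tokenizer = DistilBertTokenizerFast(vocab_file=vocab_file)

model_path = os.path.join(tmp, 'distilbert-custom')
model.save_pretrained(model_path)       # writes config.json plus the weights
tokenizer.save_pretrained(model_path)   # writes the tokenizer files alongside
```

Saving both into the same folder is what lets you later reload everything with a single `from_pretrained(model_path)` on each class.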
Okay. So if we go back into our folder here, we can see models, and we have this distilbert-custom. And then 01:00:29.680 |
in here we have all of the files we need to build our PyTorch model. It's a little bit different if 01:00:37.760 |
we're using TensorFlow, but the actual saving process is practically the same. So now we've 01:00:45.680 |
finished training. We want to switch it out of the training mode. So we use model.eval. 01:00:51.840 |
And we just get all this information about our model as well. We don't actually need any of that. 01:01:00.480 |
And just like before, we want to create a data loader. So for that, I'm just going to call it 01:01:08.240 |
ValLoader. And it's exactly the same code as before. In fact, it's probably better if we just 01:01:13.760 |
copy and paste some of this. At least the loop. So what we're going to do here is 01:01:24.960 |
take the same loop and apply it as a validation run with our validation data. 01:01:31.920 |
Just paste up there. We'll initialize this data loader. This time, of course, with the validation data set. 01:01:47.440 |
Now, this time, we do want to keep a log of accuracy. So we will keep that there. And we 01:01:55.680 |
also don't need to run multiple epochs because we're not training this time. We're just running 01:02:00.640 |
through all of the batches within our loop of validation data. So this is now a validation 01:02:07.760 |
loader. And we just loop through each of those batches. So we don't need to do anything with 01:02:13.920 |
the gradients here. And because we're not doing anything with the gradients, we actually 01:02:21.360 |
have this torch.no_grad in to stop PyTorch from calculating any gradients. Because this will obviously 01:02:29.440 |
save us a bit of time when we're processing all of this. And we put those in there. The 01:02:37.760 |
outputs, we do still want this. But of course, we don't need to be putting in the start and 01:02:44.560 |
end positions. So we can remove those. And this time, we want to pull out the start prediction 01:02:53.360 |
and end prediction. So if we have a look at what our outputs look like before, you see 01:03:02.000 |
we have this model output. And within here, we have a few different tensors which each 01:03:09.680 |
have an accessible name. So the ones that we care about are start logits. And that will 01:03:21.520 |
give us the logits for our start position, which is essentially like a set of predictions 01:03:30.880 |
where the highest value within that vector marks the predicted token position. So we can do that 01:03:39.200 |
for both. And you'll see we get these tensors. Now, we only want the index of the largest value in each 01:03:50.640 |
one of these vectors here, because that will give us the predicted token position. So to get that, we use 01:03:57.800 |
the argmax function. And if we just use it by itself, that will give us the maximum index 01:04:09.480 |
within the whole thing. But we don't want that. We want one for every single vector 01:04:15.000 |
or every row. And to do that, we just set dim equal to 1. And there you go. We get a 01:04:23.480 |
full batch of outputs. So these are our starting positions. And then we also want to do the 01:04:30.760 |
same for our ending positions. So we just change start to end. So it's pretty easy. 01:04:43.500 |
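That argmax step can be sketched with a couple of hypothetical logit rows: each row of the tensor is one sequence in the batch, and `dim=1` makes argmax return one index per row rather than a single index over the whole tensor.

```python
# Sketch of the argmax step with hypothetical logits for a batch of two
# short sequences; argmax(dim=1) picks the best token position per row.
import torch

start_logits = torch.tensor([[0.1, 2.5, 0.3, 0.0],
                             [1.0, 0.2, 0.1, 3.2]])

start_pred = torch.argmax(start_logits, dim=1)  # one position per sequence
# start_pred -> tensor([1, 3])
```

Without `dim=1`, argmax would flatten the tensor and return a single index, which is the behavior mentioned above that we don't want.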
Now obviously, we want to be doing this within our loop, because this is only doing one batch. 01:04:50.860 |
And we need to do this for every single batch. So we're just going to assign them to a variable. 01:05:09.260 |
And there we have our predictions. And all we need to do now is check for an exact match. 01:05:17.680 |
So what I mean by exact match is we want to see whether the start positions here, which 01:05:24.320 |
we can rename to the true values, whether these are equal to the predicted values down 01:05:33.260 |
here. And to calculate that, so let me just run this so we have one batch. That shouldn't 01:05:46.060 |
take too long to process, and we can just write the code out. So to check this, we just 01:05:51.700 |
use the double equal syntax here. And this will just check for matches between two arrays. 01:06:03.460 |
So we have the startPredictions and the startTrueValues. So we'll check for those. 01:06:22.340 |
So if we just have a look at what we have here, we get this array of true or false. 01:06:28.540 |
So these ones don't look particularly good. But that's fine. We just want to calculate 01:06:33.020 |
the accuracy here. So let's take the sum, and we also want to divide that by the length. 01:06:52.340 |
Okay so that will give us our accuracy within the tensor. And we just take it out using 01:06:57.660 |
the item method. But we also just need to include brackets around this, because at the 01:07:03.540 |
moment we're trying to take item of the length value. Okay and then that gives us our very 01:07:10.660 |
poor accuracy on this final batch. So we can take that, and within here we want to append 01:07:18.060 |
that to our accuracy list. And then we also want to do that for the endPrediction as well. 01:07:45.700 |
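The exact-match calculation just described looks like this in miniature, using hypothetical true and predicted positions: the element-wise `==` gives a boolean tensor, and summing it then dividing by the batch length gives the fraction of exact matches.

```python
# Sketch of the exact-match accuracy for one batch (hypothetical positions).
import torch

start_true = torch.tensor([12, 31, 7, 50])
start_pred = torch.tensor([12, 30, 7, 49])

acc = ((start_pred == start_true).sum() / len(start_pred)).item()
# two of the four positions match exactly -> 0.5
```

The same expression is then repeated with the end positions, and both values get appended to the running accuracy list.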
And we'll just let that run through, and then we can calculate our accuracy from the end 01:07:50.580 |
of that. And then we can have a quick look at our accuracy here. And we can see fortunately 01:07:59.020 |
it's not as bad as it first seemed. So we're getting a lot of 93%, 100%, 81%. That's generally 01:08:09.300 |
pretty good. So of course if we want to get the overall accuracy, all we do is sum that 01:08:22.420 |
and divide by the length. And we get 63.6% for an exact match accuracy. So what I mean 01:08:35.180 |
by exact match is, say if we take a look at a few of these that do not match. So we have 01:08:43.900 |
a 75% match on the fourth batch. Although that won't be particularly useful, because 01:08:52.940 |
we can't see that batch right now. So let's just take the last batch, because we have 01:08:56.820 |
these values here. Now if we look at what startTrue is, we get these values. Then if 01:09:06.600 |
we look at startPred, we get this. So none of these match, but a couple of them do get 01:09:18.140 |
pretty close. So these final four, all of these count as 0% on the exact match. But 01:09:26.560 |
in reality, if you look at what we predicted, for every single one of them, it's predicting 01:09:32.220 |
just one token before. So it's getting quite close, but it's not an exact match. So it 01:09:36.380 |
scores 0. So when you consider that with our 63.6% accuracy here, that means that this 01:09:45.940 |
model is actually probably doing pretty well. It's not perfect, of course, but it's doing 01:09:50.860 |
pretty well. So overall, that's everything for this video. We've gone all the way through 01:09:58.540 |
this. If you do want the code for this, I'm going to make sure I keep a link to it in 01:10:03.580 |
the description. So check that out if you just want to copy this across. But for now, 01:10:10.180 |
that's everything. So thank you very much for watching, and I will see you again soon.