
How to Build Custom Q&A Transformer Models in Python


Chapters

0:00
1:04 Download Our Data
1:09 SQuAD Data Set
6:40 Data Prep
28:53 Initialize Our Tokenizer
41:10 Create a PyTorch Dataset Object
47:56 Initialize the Optimizer
48:33 Initializing the Model
50:03 Initialize Our Data Loader
51:22 Training Loop
58:24 While Loop
61:03 Create a Data Loader
65:15 Check for an Exact Match
66:31 Calculate the Accuracy
67:47 Accuracy
68:31 Exact Match Accuracy

Whisper Transcript

00:00:00.000 | Hi and welcome to the video. Today we're going to go through how we can fine-tune a QnA transformer
00:00:06.960 | model. So for those of you that don't know QnA just means question and answering and it's one
00:00:14.320 | of the biggest topics in NLP at the moment. There's a lot of models out there where you
00:00:20.720 | ask a question and it will give you an answer. And one of the biggest things that you need to know
00:00:30.400 | how to do when you are working with transformers, whether that's QnA or any of the other
00:00:35.280 | transformer-based solutions, is how to actually fine-tune those. So that's what we're going to be
00:00:43.440 | doing in this video. We're going to go through how we can fine-tune a QnA transformer model
00:00:49.360 | in Python. So I think it's really interesting and I think you will enjoy it a lot. So let's
00:00:58.320 | just go ahead and we can get started. Okay so first thing we need to do is actually
00:01:04.640 | download our data. So we're going to be using the SQuAD dataset, which is the
00:01:11.360 | Stanford Question Answering Dataset, which is essentially one of the better-known QnA
00:01:17.760 | data sets out there that we can use to fine-tune our model. So let's first create a folder.
00:01:26.560 | I'm just going to use os
00:01:29.920 | to make the directory. We'll just call it squad. Obviously call this and organize it as you want.
00:01:41.840 | This is what I will be doing. Now the URL that we are going to be downloading this from is this.
00:01:52.000 | Okay and there are actually two files here that we're going to be downloading
00:01:55.600 | but both will be coming from the same URL.
00:01:58.640 | So because we're making a request to a URL we're going to import requests.
00:02:04.480 | We can also use the wget library as well or if you're on Linux you can just use wget
00:02:12.240 | directly in the terminal. It's up to you; we're going to be using requests.
00:02:17.840 | Okay and to request our data we're going to be doing this. So it's just a get request.
00:02:30.160 | Use an f-string and we have the URL that we've already defined.
00:02:38.880 | And then the training data that we'll be using is this file here.
00:02:44.080 | Okay requests.
00:02:52.800 | Okay and we can see that we've successfully pulled that data in there. Okay so like I said
00:03:03.120 | before there's actually two of these files that we want to extract. So what I'm going to do is
00:03:10.240 | just put this into a for loop which will go through both of them. Just copy and paste this across.
00:03:19.280 | Rename this file.
00:03:30.400 | And the other file is the same but instead of train we have dev.
00:03:34.560 | Okay so here we're making our request.
00:03:37.200 | And then the next thing we want to do after making our request is actually
00:03:44.480 | saving this file to our drive. Okay and we want to put that inside this squad folder here.
00:03:53.120 | So to do that we use open. And again we're going to use an f-string here.
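The download-and-save step being described here can be sketched as follows. This is a minimal sketch, not the code from the video: the video uses requests, while this version swaps in the standard library's urllib.request so it has no third-party dependency. The function name download_files and its parameters are illustrative, and the base URL and file names are left as arguments because the transcript does not spell them out.

```python
import os
import urllib.request


def download_files(base_url, filenames, out_dir="squad", chunk_size=8192):
    """Download each file from base_url into out_dir, writing it chunk by chunk."""
    os.makedirs(out_dir, exist_ok=True)
    saved = []
    for name in filenames:
        dest = os.path.join(out_dir, name)
        # "wb" because the response body arrives as raw bytes.
        with urllib.request.urlopen(f"{base_url}{name}") as response, open(dest, "wb") as f:
            while True:
                chunk = response.read(chunk_size)
                if not chunk:  # an empty read means the download is finished
                    break
                f.write(chunk)
        saved.append(dest)
    return saved
```

Calling download_files with the dataset URL and the train and dev file names would save both files into the squad folder. A larger chunk size than the one mentioned in the video is used here as the default simply to reduce the number of writes.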
00:04:00.320 | And we want to put inside the squad folder here. And then here we are just going to put our file
00:04:11.040 | name which is file. Now we're writing this in binary, because the response comes back as bytes, so we put wb for our
00:04:19.360 | flag here. And then within this with block we are going to run through the response and download it in
00:04:28.480 | chunks. So we do for chunk and then we iterate through the response like this. Let's use a chunk
00:04:42.160 | size of four. And then we just want to write to the file like that. So that will download both
00:04:54.640 | files. Just add the colon there. So that will download both files. We should be able to see
00:05:02.080 | them here now. So in here we have data, we have essentially a lot of different topics. So the
00:05:08.560 | first one is Beyonce. And then in here we will see, if we just come to here, we get a context.
00:05:16.240 | But alongside this context we also have QAS which is question and answers. And each one of these
00:05:25.040 | contains a question and answer pair. So we have this question, when did Beyonce start becoming
00:05:33.200 | popular? So this answer is actually within this context. And what we want our model to do is
00:05:39.040 | extract the answer from that context by telling us the start and end token of the answer within
00:05:46.240 | that context. So we go zero and it is in the late 1990s. And we have answer start 269. So
00:05:56.080 | that means that at character 269, we get I. So if we go through here, we can find it here.
00:06:05.440 | Okay, so this is the extract. And that's what we will be aiming for our model to actually extract.
00:06:12.400 | But there will be a start point and also the end point as well, which is not included in here,
00:06:18.080 | but we will add that manually quite soon. So that's our data. And then we'll also be testing
00:06:25.600 | on the dev data as well, which is exactly the same. Okay, so we move on to the data prep.
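To make the nesting concrete, here is a hand-written miniature of the layout just described. The context, question and title are made up for illustration and are not taken from the dataset; only the key names follow what is shown in the video. It also shows how answer_start is a character offset into the context, and why the end offset has to be computed separately.

```python
# A toy stand-in for one entry of the file (not real SQuAD data).
context = "The band formed in 1990 and became popular in the late 1990s."
answer_text = "in the late 1990s"

toy = {
    "version": "toy",
    "data": [
        {
            "title": "Example",
            "paragraphs": [
                {
                    "context": context,
                    "qas": [
                        {
                            "question": "When did the band become popular?",
                            "answers": [
                                {
                                    "text": answer_text,
                                    # Character offset of the answer within the context.
                                    "answer_start": context.index(answer_text),
                                }
                            ],
                        }
                    ],
                }
            ],
        }
    ],
}

answer = toy["data"][0]["paragraphs"][0]["qas"][0]["answers"][0]
start = answer["answer_start"]
end = start + len(answer["text"])  # the file only stores the start
extracted = context[start:end]
print(start, end, repr(extracted))
```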
00:06:42.080 | So now we have our files here, we're going to want to read them in. So we're going to use
00:06:49.280 | the JSON library for that. And like we saw before, there's quite a complex structure in these
00:06:57.440 | JSONs, there's a lot of different layers. So we need to use a few for loops to filter through each
00:07:03.360 | of these and extract what we want, which is the context, questions and answers, all corresponding
00:07:11.280 | to each other. So in the end, we're going to have lists of strings, which is going to be all of
00:07:15.680 | these. And in the case of the answers, we will also have the starting position. So it will be a
00:07:20.080 | list of dictionaries, where one value is a text and one value is the starting position. So to do
00:07:29.600 | that, we're going to define a function called read squad. We'll define our path here as well.
00:07:41.600 | And the first thing we need to do is actually open the JSON file. So we do with open path.
00:07:49.040 | And again, we are using a binary file. So we're going to have b
00:07:54.320 | as a flag. But instead of writing, we are reading, so we use r here. So rb.
00:08:10.400 | And just do json.load(f) here. So now we have our dictionary within this squad_dict here. So
00:08:19.120 | maybe whilst we're just building this function up, it's probably more useful to
00:08:23.680 | do it here. So you can see what we're actually doing. So let's copy that across.
00:08:33.360 | And then we'll fill this out afterwards.
00:08:34.880 | Of course, we do actually need to include the path. So let's take this.
00:08:56.240 | And then we can see what's inside here.
00:08:57.760 | Maybe we can load just a few rather than all of them.
00:09:09.280 | Or we can investigate it like this. Okay, so we have the version and data, which we can actually
00:09:21.760 | see over here, version and data. So we want to access the data. Now within data is we have a
00:09:28.880 | list of all these different items, which is what I was trying to do before. So we go into data.
00:09:35.120 | And just take a few of those. Okay, and then we get our different sections.
00:09:47.600 | For the first one, let's just take zero, which is Beyonce. And then we have all of these.
00:09:55.120 | So we're going to want to loop through each one of these, because we have this one next,
00:10:03.440 | and we're going to keep needing to just run through all of these.
00:10:10.560 | So to do that, we want to do for group
00:10:15.920 | in squad dict. And remember, we need to include the data here.
00:10:24.240 | Let's just see our say group title. So we can see a few of those. Okay,
00:10:34.320 | then go through each one of those. So the second part of that are these,
00:10:40.800 | these paragraphs. And within the paragraphs we have each one of our questions. So let's first
00:10:52.240 | go with paragraphs. And we'll do the Chopin one here.
00:10:56.800 | Sorry, it's a list. There we go.
00:11:05.840 | And the first thing we need to extract is the easiest one, which is our context.
00:11:17.680 | However, that is also within a list. So now if we access the context, we get this.
00:11:26.640 | So we're essentially going to need to jump through or loop through each one of these here.
00:11:35.280 | Now we're gonna need to access the paragraphs and loop through each one of those. And then here,
00:11:40.880 | we're going to access the context. So let's write that. So we already have one group here. So let's
00:11:49.600 | just stick with that. And we're going to run through the passage in the paragraphs. So already
00:12:01.440 | here, we're going through the for loop on this index. And now we're going to go through a loop
00:12:07.360 | on this index. Let's keep that. So that means that we will be able to print the passage
00:12:23.120 | context. And there we go. So here we have all of our context. So that's
00:12:32.320 | one of our three items that we need to extract. Okay, so that's great. Let's put that all together.
00:12:39.600 | So we're going to take this, put it here. And then we have our context.
00:12:48.560 | Okay, that's great. Obviously, for each context, we have a few different questions and answers.
00:12:58.800 | So we need to get those as well. Now, that requires us to go through another for loop.
00:13:06.000 | So let's go this passage, we need to go into the QAS key and loop through this list of
00:13:14.320 | question and answers. So we have this, and then we have our list. So another layer in our for loop
00:13:25.680 | will be for question answer in that passage QAS. And then let's take a look at what we have there.
00:13:37.120 | Okay, great. So we have plausible answers, question and answers. So what we want in here
00:13:47.280 | is the question and answers. So question is our first one.
00:13:52.080 | Perfect. So we have the questions now.
00:14:15.760 | And then after we have extracted the question, we can move on to our answers.
00:14:20.480 | As we see here, the answers comes as another list. Now each one of these lists all just have
00:14:28.080 | one actual answer in there, which is completely fine. So we can access that in two ways. We can
00:14:33.680 | either loop through or we can access the zero value of that array. Either way, it doesn't matter.
00:14:44.560 | So all we need to do here is loop through those answers.
00:14:49.120 | Or if we want, just go with QA answers zero.
00:14:56.880 | So in most cases, this should be completely fine.
00:15:05.440 | As we can see here, most of these have the question and then they have the answers dictionary.
00:15:14.080 | Which is fine. However, some of these are slightly different.
00:15:20.160 | So if we scroll right down to the end here,
00:15:25.440 | say, okay, we have this, which is talking about physics.
00:15:32.320 | And then rather than having our answers array, we have these plausible answers,
00:15:41.520 | which is obviously slightly different. And this is the case for a couple of those.
00:15:47.360 | So from what I've seen, the best way to deal with this
00:15:52.480 | is simply to have a check. If there is a plausible answers key within the dictionary,
00:16:00.160 | we will include that as the answer rather than the actual answers dictionary.
00:16:06.800 | So to do that, all we need to do is check if QA
00:16:13.520 | keys contains plausible answers. If it does, we use that. Otherwise, we use answers.
00:16:30.960 | Okay. Then we use this one. Otherwise,
00:16:42.000 | we will use answers. So let's just add all of that into our for loop here.
00:16:57.760 | So we have our context, and then we want to loop through the question answers.
00:17:02.320 | And this is where we get our question.
00:17:13.120 | Then once we're here, we need to do something slightly different, which is this plausible
00:17:27.360 | answers. Okay. And then we use this access variable in order to define what we're going
00:17:36.960 | to loop through next. So here we go for answers, answer, sorry, in QA
00:17:46.960 | access, because this will switch between plausible answers and answers.
00:17:54.160 | And then within this for loop, this is where we can begin adding this context, question,
00:18:01.600 | and answer to a list of questions, context, and answers that we still need to define up here.
00:18:09.520 | So each one of these is just going to be an empty list.
00:18:22.160 | And then all we do, copy this across,
00:18:24.320 | and we just append everything that we've extracted in this loop and the context,
00:18:35.840 | question, and answer.
00:18:50.640 | And that should work.
00:18:51.600 | So now let's take a look at a few of our contexts. Okay. We can see we have this, and because we
00:19:05.680 | have multiple question answers for each context, the context does repeat over and over again.
00:19:10.800 | But then we should see something slightly different when we go with answers.
00:19:20.080 | And questions. Okay. So that's great. We have our data in a reasonable format now,
00:19:28.400 | but we want to do this for both the training set and the validation set. So what we're going to do
00:19:36.560 | is just going to put this into a function like we were going to do before, which is this read squad.
00:19:47.120 | So here, we're going to read in our data, and then we run through it and transform it into
00:19:59.520 | our three lists. And all we need to do now is actually return those three lists: contexts, questions, and answers.
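Putting the loops described above together, the function might look roughly like this. It is a sketch reconstructed from the spoken walkthrough, so variable names such as group, passage, qa and access follow the narration but may not match the on-screen code exactly.

```python
import json


def read_squad(path):
    """Flatten a SQuAD-style JSON file into parallel lists."""
    # json.load accepts a file opened in binary mode.
    with open(path, "rb") as f:
        squad_dict = json.load(f)

    contexts, questions, answers = [], [], []
    for group in squad_dict["data"]:
        for passage in group["paragraphs"]:
            context = passage["context"]
            for qa in passage["qas"]:
                question = qa["question"]
                # Some entries carry "plausible_answers"; use it when present.
                access = "plausible_answers" if "plausible_answers" in qa else "answers"
                for answer in qa[access]:
                    # The context repeats once per question/answer pair.
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)
    return contexts, questions, answers
```

Each returned answer is the original dictionary with its text and answer_start keys, so the three lists stay aligned index by index.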
00:20:13.600 | So now what we can do is execute this function for both our training and validation sets.
00:20:20.400 | So we're going to train context, questions, and answers.
00:20:39.120 | Okay. So that is one of them, and we can just copy that.
00:20:53.840 | And we just want this to be our validation set.
00:21:01.280 | Like so.
00:21:04.640 | Okay. So that's great. We now have the training context and the
00:21:14.400 | validation context, which we can see right here.
00:21:23.840 | So here, let's hope that there is a slight difference in what we see between both.
00:21:29.360 | Okay. Great. That's what we would expect. Okay. So now we have our data almost in the right format.
00:21:46.320 | We just need to add the ending position. So we already have the start position. If we take a look
00:21:54.320 | in our train answers. Okay. We have the answer start, but we also need the answer end, and that's
00:22:03.840 | not included within the data. So what we need to do here is actually define a function
00:22:17.760 | that will go through each one of our answers and context and figure out where that ending
00:22:25.040 | character actually is. And of course we could just say, okay, it's the length of the text.
00:22:32.080 | We add that onto the answer start and we have our answer end. However, that unfortunately won't work
00:22:40.080 | because some of the answer starts are actually incorrect and they're usually off by one or two
00:22:47.040 | characters. So we actually need to go through and one, fix that and two, add our end indices.
00:22:55.040 | So to do that, we're just going to define a new function,
00:23:02.720 | which is going to be add end index. And here we will have our answers and the context,
00:23:13.440 | and then we're going to just feed these in. So first thing we do is loop through each
00:23:19.360 | answer and context pair. And then we extract something which is called the gold text,
00:23:39.200 | which is essentially the answer that we are looking for. It's called the golden text or gold text.
00:23:47.760 | So simply our answer and within that, the text. So we are pulling this out here.
00:23:58.160 | So we should already know the starting index. So what we do here is simply pull that out as well.
00:24:14.800 | And then the end index ideally will be the start plus the length of the gold text.
00:24:26.640 | However, that's not always the case because like I said before, they can be off by one or two
00:24:34.560 | characters. So we need to add in some logic just to deal with that. So in our first case, let's
00:24:43.040 | assume that the characters are not off. So if context start to end equals the gold text,
00:25:03.680 | this means everything is good and we don't need to worry about it. So we can modify the
00:25:12.000 | original dictionary and we can add answer end into there. And we made that equal to our end index.
00:25:22.720 | However, if that's not the case, that means we have a problem. It's one of those
00:25:29.200 | dodgy question answer pairs. And so this time what we can do is we'll add an else statement.
00:25:40.000 | So we're just going to go through when the position is off by one or two characters,
00:25:44.720 | because it is not off by any more than that in the squad dataset.
00:25:48.400 | Loop through each of those and we'll say, okay, if the context,
00:25:54.320 | and then in here we need to add the start index and this again. So let's just copy and paste
00:26:02.640 | this across, it'll be easier. But this time we're checking to see if it is off by one or two characters.
00:26:10.000 | So just do minus N, and it's always minus N. It's always shifted to the left
00:26:17.040 | rather than shifted to the right. So this is fine. So in this case, the answer is off by N characters.
00:26:28.640 | And so we need to update our answer start value, and also add our answer end value.
00:26:45.840 | So start index minus N and we also have the end.
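The end-index logic just walked through can be sketched like this. The function name add_end_idx and the guard against a negative start are choices made for this sketch; the off-by-one-or-two check to the left follows the description above.

```python
def add_end_idx(answers, contexts):
    """Add 'answer_end' to each answer dict, fixing starts that are 1-2 chars late."""
    for answer, context in zip(answers, contexts):
        gold_text = answer["text"]
        start_idx = answer["answer_start"]
        end_idx = start_idx + len(gold_text)

        if context[start_idx:end_idx] == gold_text:
            # The recorded start is correct.
            answer["answer_end"] = end_idx
        else:
            # Otherwise try shifting one or two characters to the left.
            for n in (1, 2):
                if start_idx - n >= 0 and context[start_idx - n:end_idx - n] == gold_text:
                    answer["answer_start"] = start_idx - n
                    answer["answer_end"] = end_idx - n
                    break
```

The function modifies the answer dictionaries in place, so it is called once for the training lists and once for the validation lists without reassigning anything.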
00:26:59.040 | So that's great. We can take that and we can apply it to our train and validation sets.
00:27:10.240 | So all we do here is call the function and we just pass in train answers and train context.
00:27:22.000 | And of course we can just copy this and do the same for our validation set.
00:27:38.960 | Okay. Perfect. So now if we have a quick look, we should be able to see that we have
00:27:45.280 | a few of these ending points as well.
00:27:51.440 | Okay. So I think that looks pretty good.
00:27:57.680 | And that means we can move on to actually encoding our text.
00:28:02.640 | To tokenize or encode our text, this is where we bring in a tokenizer. So we need to import
00:28:18.960 | the transformers library for this. And from transformers, we are going to import DistilBERT.
00:28:26.640 | So DistilBERT is a smaller version of BERT, which is just going to run a bit quicker,
00:28:32.960 | but it will still take a very long time. And we're going to import the fast version of this tokenizer
00:28:40.560 | because this allows us to more easily convert our character end and start locations to token
00:28:49.040 | end and start locations later on. So first we need to actually initialize our tokenizer.
00:28:55.520 | Which is super easy. All we're doing is loading it from a pre-trained model.
00:29:16.400 | Okay. And then all we do to create our encodings is to call the tokenizer.
00:29:32.160 | So we'll do the training set first,
00:29:34.640 | which is called tokenizer. And in here, we include our training context.
00:29:43.600 | And the training questions. So what this will do
00:29:52.080 | is actually merge these two strings together. So what we will have is our context,
00:30:03.200 | and then there will be a separator token followed by the question. And this will be fed
00:30:08.240 | into distilbert during training. I just want to add padding there as well.
00:30:16.720 | And then we'll copy this and do the same for our validation set.
00:30:37.120 | Okay. And this will convert our data into encoding objects.
00:30:41.520 | So what we can do here is print out different parts that we have within our encodings.
00:30:52.720 | So in here, so you have the input IDs. So let's access that.
00:31:02.800 | And you'll find in here, we have a big list of all of our samples. So check that we have 130K.
00:31:11.680 | And let's open one of those. Okay. And we have these token IDs, and this is what Bert will be
00:31:20.720 | reading. Now, if we want to have a look at what this actually is in sort of human readable
00:31:27.920 | language, we can use the tokenizer to just decode it for us.
00:31:32.640 | Okay. And this is what we're feeding in. So we have a couple of these special tokens. This just
00:31:45.040 | means it's the start of a sequence. And in here, we have a processed form of our original context.
00:31:55.120 | Now, you'll find that the context actually ends here. And like I said before, we have this
00:32:00.400 | separator token. And then after that, we have our actual question. And this is what is being fed
00:32:08.720 | into Bert, but obviously the token ID version. So it's just good to be aware of what is actually
00:32:15.920 | being fed in and what we're actually using here. But this is a format that Bert is expecting.
00:32:22.400 | And then after that, we have another separator token followed by all of our padding tokens,
00:32:26.560 | because Bert is going to be expecting 512 tokens to be fed in for every one sample.
00:32:34.800 | So we just need to fill that space essentially. So that's all that is doing.
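The layout of one encoded sample, as described above, can be illustrated with a toy function. This is not the tokenizer itself: layout_pair is a hypothetical helper that only mimics the ordering of the special tokens ([CLS], [SEP], [PAD]) around already-split tokens, and it ignores truncation.

```python
def layout_pair(context_tokens, question_tokens, max_length):
    """Mimic the token order of a context/question pair, padded to max_length."""
    tokens = ["[CLS]"] + context_tokens + ["[SEP]"] + question_tokens + ["[SEP]"]
    # Fill whatever space is left with padding tokens.
    padding = ["[PAD]"] * max(0, max_length - len(tokens))
    return tokens + padding


sample = layout_pair(["the", "late", "1990s"], ["when", "?"], max_length=10)
print(sample)
```

In the real encodings each of these strings is replaced by its token ID, and max_length is 512 rather than 10.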
00:32:39.200 | So let's remove those and we can continue. So the next thing we need to add to our encodings
00:32:50.640 | is the start and end positions, because at the moment, we just don't have them in there.
00:32:57.200 | So to do that, we need to add an additional bit of logic. We use this char_to_token method.
00:33:06.960 | So if we just take out one of these,
00:33:15.360 | let's take the first one. OK, we have this.
00:33:19.360 | And what we can do is actually modify this to use the char_to_token method,
00:33:29.440 | remove the input IDs, because we just need to pass it the index of whichever encoding
00:33:36.640 | we are wanting to modify or get the start and end position of. And in here, all we're doing is
00:33:45.040 | converting from the character that we have found a position for to the token that we want to find
00:33:51.280 | a position for. And what we need to add is train answers. We have our position again, because the
00:34:01.360 | answers and encodings, the context in question needs to match up to the answer, of course,
00:34:07.120 | that we're asking about. And we do answers start. So here, we're just feeding in the position of the
00:34:16.640 | character. And this is answer. OK. So feeding in position of the character,
00:34:23.920 | and we're expecting to return the position of the token, which is position 64.
00:34:31.760 | So all we need to do now is do this for both of those. So for the start position and end position.
00:34:38.880 | So here we should get a different value.
00:34:46.640 | OK, but this is one of the limitations of this. Sometimes this is going to return nothing. As
00:34:57.680 | you can see, it's not returning anything here. And that is because sometimes it is actually
00:35:03.760 | returning the space. And when it looks at the space, the tokenizer sees that and says,
00:35:10.880 | OK, that's nothing. We're not concerned about spaces. And it returns this None value that you
00:35:15.520 | can see here. So this is something that we need to consider and build in some added logic for.
00:35:25.200 | So to do that, again, we're going to use a function to contain all of this.
00:35:30.880 | And call it addTokenPositions.
00:35:36.880 | Yeah, we'll have our encodings and our answers. And then we just modify this code. So we have the
00:35:47.200 | encodings. We have the answers. And because we're collecting all of the token positions,
00:35:56.800 | we also need to initialize a list to contain those. So we do startPositions, emptyList, and endPositions.
00:36:12.320 | And now we just want to loop through every single answer and encoding that we have.
00:36:17.600 | Like so. And here we have our startPositions. So we need to append that to our startPositionsList.
00:36:38.800 | Then we just do the same for our endPositions, which is here.
00:36:43.040 | Now, here we can deal with this problem that we had. So if we find that the endPositions,
00:36:58.400 | the most recent one, so the negative one index, is None, that means it wasn't found. And it means
00:37:07.040 | there is a space. So what we do is we change it to instead use the -1 version. And all this needs
00:37:16.000 | to do is update the endPositions here. OK, that's great. But in some cases, this also happens with
00:37:26.080 | the startPosition. But that is for a different reason. The reason it will occasionally happen
00:37:31.680 | with startPosition is when the passage of data that we're adding in here, so you saw before we
00:37:39.040 | had the context, a separator token, and then the question, sometimes the context passage is
00:37:46.160 | truncated in order to fit in the question. So some of it will be cut off. And in that case,
00:37:54.400 | we do have a bit of a problem. But we still need to just allow our code to run without any problems.
00:38:01.520 | So what we do is we just modify the startPositions again,
00:38:08.160 | just like we did with the endPositions. Obviously, only if it's none. And we just set it
00:38:16.320 | to be equal to the maximum length that has been defined by the tokenizer.
00:38:26.240 | And it's as simple as that. Now, the only final thing we need to do, which is because we're using
00:38:33.760 | the encodings, is actually update those encodings to include this data. Because as of yet, we haven't
00:38:41.280 | added that back in. So to do that, we can use this quite handy update method.
00:38:50.640 | And just add in our data as a dictionary. So we have the startPositions,
00:38:56.240 | startPositions. And then we also have our endPositions.
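The function built up over the last few minutes might look roughly like this. It is a sketch from the narration: it already uses the per-sample index i (the fix made a little later in the video), and it takes max_length as an argument where the video reads the maximum length from the tokenizer. The final fallback for an end position that is still None after one retry is an assumption added for this sketch, not something stated in the transcript.

```python
def add_token_positions(encodings, answers, max_length):
    """Convert character start/end offsets into token positions on the encodings."""
    start_positions = []
    end_positions = []
    for i in range(len(answers)):
        start_positions.append(encodings.char_to_token(i, answers[i]["answer_start"]))
        end_positions.append(encodings.char_to_token(i, answers[i]["answer_end"]))

        # The end offset can land on a space: retry one character earlier.
        if end_positions[-1] is None:
            end_positions[-1] = encodings.char_to_token(i, answers[i]["answer_end"] - 1)
        # The start can be missing when the context was truncated.
        if start_positions[-1] is None:
            start_positions[-1] = max_length
        # Assumption (not in the transcript): same fallback if the end is still missing.
        if end_positions[-1] is None:
            end_positions[-1] = max_length

    encodings.update({"start_positions": start_positions, "end_positions": end_positions})
```

With the real tokenizer output, this would be called once with the training encodings and answers and once with the validation ones.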
00:39:06.240 | And then, again, we just need to apply this to our training and validation sets. And let's just modify that.
00:39:26.080 | Let's add the train encodings here and the train answers.
00:39:47.520 | And do that again for the validation set.
00:39:54.480 | So now let's take a look at our encodings.
00:40:01.680 | And here we can see, great, now we have those startPositions and endPositions.
00:40:09.600 | We can even have a quick look at what they look like.
00:40:16.000 | And what we've done is actually not included the index here. So we're just taking it for
00:40:23.440 | the very first item every single time. So let's just update that.
00:40:28.160 | So obviously, that won't get us very far.
00:40:36.320 | And just update that as well.
00:40:39.280 | And now this should look a little bit better. So it's lucky we checked.
00:40:47.280 | Okay, so our data, our training, and our validation sets are now up to date.
00:41:02.560 | Okay, so our data at the moment is in the right format. We just need to use it to create a
00:41:12.320 | PyTorch dataset object. So to do that, obviously, we need to import PyTorch.
00:41:22.960 | And we define that dataset using a class. And just pass in torch.utils.data.Dataset.
00:41:35.520 | We need to initialize that like so.
00:41:47.840 | And this is coming from the Hugging Face Transformers documentation. I don't take credit for this.
00:42:15.280 | And we essentially need to do this so that we can load in our data using the
00:42:19.680 | PyTorch data loader later on, which makes things incredibly easy.
00:42:44.160 | And then we just have one more function here, one method.
00:43:04.560 | Okay, and we just need to return and also this as well. That should be okay.
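The class is typed on screen without much narration, so here is a sketch of what a dataset wrapper of this kind usually looks like. The class name SquadDataset and the exact method bodies are assumptions based on the common pattern for wrapping tokenizer encodings; they are not read from the video.

```python
import torch


class SquadDataset(torch.utils.data.Dataset):
    """Wrap a dict-like set of encodings so a DataLoader can index into it."""

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # Return one sample: every field sliced at idx and turned into a tensor.
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings["input_ids"])
```

An instance would be created once from the training encodings and once from the validation encodings.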
00:43:12.640 | So we apply this to our datasets to create dataset objects.
00:43:28.400 | We have our encodings and then the same again for the validation set.
00:43:33.280 | Okay, so that is our data almost fully prepared. All we do now is load it into a data loader
00:43:47.760 | object. But this is everything on the data side done, which is great because I know it does
00:43:55.520 | take some time and I know it's not the most interesting part of it, but it's just something
00:44:00.160 | that we need to do and need to understand what we're doing as well. So now we get to
00:44:09.760 | the more interesting bit. So we'll just add the imports in here.
00:44:22.560 | So we need our data loader.
00:44:26.800 | We're going to import the Adam optimizer with weight decay, which is pretty commonly used
00:44:41.680 | for transformer models when you are fine-tuning. Because transformer models are generally very
00:44:48.880 | large models and they can overfit very easily. So this Adam optimizer with weight decay
00:44:56.960 | essentially just reduces the chances of that happening,
00:45:01.520 | which is supposed to be very useful and quite important. So obviously we're going to use that.
00:45:14.160 | And then final bit is TQDM.
00:45:18.160 | So TQDM is a progress bar that we are going to be using so that we can actually see
00:45:30.400 | the progress of our training. Otherwise, we're just going to be sat there for
00:45:35.440 | probably quite a long time, not knowing what is actually happening. And trust me,
00:45:40.800 | it won't take long before you start questioning whether anything is happening, because it takes
00:45:45.200 | a long time to train these models. So they are our imports and I'm being stupid again here.
00:45:54.400 | That's from, did that twice. Okay, so that's all good. So now we just need to do a few
00:46:01.520 | little bits for the setup. So we need to tell PyTorch whether we're using CPU or GPU.
00:46:10.400 | In my case, it will be a GPU. If you're using CPU, this is going to take you a very long time
00:46:16.720 | to train. And it's still going to take you a long time on GPU. So just be aware of that.
00:46:22.640 | But what we're going to do here is say device.
00:46:32.960 | It's CUDA, if CUDA is available.
00:46:35.680 | Otherwise, we are going to use the CPU. And good luck if that is what you're doing.
00:46:52.080 | So once we've defined the device, we want to move our model over to it.
00:47:02.720 | So we just do model.to(device). So this .to method is essentially a way of transferring data between
00:47:11.200 | different hardware components, so your CPU or GPU. It's quite useful. And then we want to
00:47:17.840 | activate our model for training. So there's two things we have here. So we have .train and eval.
00:47:27.440 | So when we're in train mode, there's a lot of different layers and different parts of your
00:47:32.880 | model that will behave differently depending on whether you are using the model for training
00:47:38.320 | or you're using it for inference, which is predictions. So we just need to make sure our
00:47:44.000 | model is in the right mode for whatever we're doing. And later on, we'll switch it to eval
00:47:50.160 | to make some predictions. So that's almost everything. So we just need to initialize
00:47:57.520 | the optimizer. And here, we're using the Adam optimizer with weight decay.
00:48:03.360 | We need to pass in our model parameters and also give it a learning rate. And we're just going to
00:48:15.840 | use this value here. All of these are the recommended parameters for what we are doing here.
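The setup steps above (pick a device, move the model, switch to training mode, create the optimizer) can be sketched as follows. Several things here are stand-ins: a tiny nn.Linear layer replaces the real model so the snippet runs without downloading anything, torch.optim.AdamW is used in place of the optimizer import from the video, and the learning rate is a placeholder because the value used in the video is not spoken aloud.

```python
import torch
from torch import nn

# Stand-in for the question answering model (assumption for this sketch).
model = nn.Linear(4, 2)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Training mode: layers such as dropout behave differently than in eval mode.
model.train()

learning_rate = 1e-4  # placeholder, not the value from the video
optim = torch.optim.AdamW(model.parameters(), lr=learning_rate)
```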
00:48:23.840 | So the one thing that I have somehow missed is defining the, actually initializing the model.
00:48:35.280 | So let's just add that in. And all we're doing here is loading, again, a pre-trained one.
00:48:42.560 | So like we did before when we were loading the Transformers tokenizer.
00:48:47.280 | This time, it's for question answering. So this DistilBertForQuestionAnswering is a DistilBERT
00:49:03.360 | model with a question-answering head added onto the end of it. So essentially with Transformers,
00:49:10.720 | you have all these different heads that you add on. And they will do different things depending on
00:49:16.000 | what head it has on there. So let's initialize that from pre-trained.
00:49:26.720 | And we're using the same one we used up here, which is distilbert-base-uncased.
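To show what the question-answering head looks like, here is a sketch that builds a tiny, randomly initialised model from a config so it runs without downloading anything. The config sizes are arbitrary assumptions; in the video the model is loaded with `from_pretrained` and the `distilbert-base-uncased` weights instead.

```python
from transformers import DistilBertConfig, DistilBertForQuestionAnswering

# In the video:
#   model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
# Here: a tiny random model (assumed sizes) so the sketch runs offline.
config = DistilBertConfig(
    vocab_size=100,
    max_position_embeddings=32,
    n_layers=1,
    n_heads=2,
    dim=32,
    hidden_dim=64,
)
model = DistilBertForQuestionAnswering(config)

# The head is a single linear layer producing two values per token:
# one start logit and one end logit.
print(model.qa_outputs)
```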
00:49:37.680 | And sometimes you will need to download that. Fortunately, I don't need to as I've already
00:49:43.600 | done that. But this can also take a little bit of time, not too long though. And you get a nice
00:49:48.160 | progress bar, hopefully, as well. Okay, so now that is all settled, we can initialize our data
00:49:57.120 | loader. So all we're doing here is using the PyTorch DataLoader object. And we just pass in
00:50:15.520 | our training data set, the batch size. So how many we want to train on at once in parallel before
00:50:23.120 | updating the model weights, which will be 16. And we also would like to shuffle the data because we
00:50:31.200 | don't want to train the model on a single batch where it just learns about Beyonce, and then the
00:50:35.680 | next one is all about Chopin. It would keep switching
00:50:42.160 | between topics from batch to batch, but never,
00:50:48.000 | within a single batch, have a good mix of different things to learn about.
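A minimal sketch of the data loader with a batch size of 16 and shuffling switched on. The 40-item `TensorDataset` is a hypothetical stand-in for the tokenized SQuAD training set.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in dataset: 40 examples, each just an integer ID.
train_dataset = TensorDataset(torch.arange(40))

# shuffle=True reorders the examples every epoch, so each batch is a mix
# rather than a run of consecutive examples from one source article.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

batch_sizes = [len(batch[0]) for batch in train_loader]
print(batch_sizes)  # the last batch holds the 8 leftover examples
```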
00:50:52.640 | So it is data sets. Seems a bit of a weird name to me, so I'm just going to change it.
00:51:03.760 | And I also can't spell. There we go. And that is everything, so we can actually begin
00:51:22.000 | our training loop. So we're going to go for three epochs.
00:51:33.200 | And what we want to start with here is a loop object. So we do this mainly because we're using
00:51:42.560 | tqdm as a progress bar. Otherwise, we wouldn't need to do this. There'd be no point in doing it.
00:51:49.680 | And all this is doing is kind of like pre-initializing our loop that we are going
00:51:56.960 | to go through. So we're going to obviously loop through every batch within the train loader. So we
00:52:02.880 | just add that in here. And then there's this other parameter, which I don't know if we...
00:52:10.160 | So let's leave it. But essentially, you can add leave=True in order to leave your progress
00:52:17.520 | bar in the same place with every epoch. Whereas at the moment, with every epoch, what it will do is
00:52:23.600 | create a new progress bar. But if you don't want to do
00:52:28.880 | that and you want it to just stay in the same place, you add leave=True into this function
00:52:36.000 | here. So after that, we need to go through each batch within our loop. And the first thing that
00:52:46.320 | we need to do is set all of our calculated gradients to zero. So with every iteration
00:52:56.320 | that we go through here or every batch, at the end of it, we are going to calculate gradients,
00:53:00.720 | which tells the model in which direction to change the weights within the model.
00:53:07.200 | And obviously, when we go into the next iteration, we don't want those gradients to still be there.
00:53:14.320 | So all we're doing here is re-initializing those gradients at the start of every
00:53:19.920 | loop. So we have a fresh set of gradients to work with every time.
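Why the gradients need resetting can be seen in a tiny example: PyTorch adds new gradients onto whatever is already stored. The single weight `w` and the SGD optimizer are hypothetical stand-ins used only to make the numbers easy to follow.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
optim = torch.optim.SGD([w], lr=0.1)

(3 * w).sum().backward()
first = w.grad.item()  # gradient of 3*w with respect to w is 3

# A second backward pass without resetting adds onto the stored gradient.
(3 * w).sum().backward()
accumulated = w.grad.item()  # 3 + 3

# Resetting first gives a fresh gradient for the next batch.
optim.zero_grad()
(3 * w).sum().backward()
fresh = w.grad.item()
```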
00:53:25.200 | And here, we just want to pull in our data.
00:53:28.320 | So this is everything that is relevant that we're going to be feeding into the training process.
00:53:34.880 | So we have everything within our batch. And then in here, we have
00:53:42.400 | all of our different items. So we can actually see-- go here.
00:53:50.960 | We want to add in all of these. And we also want to move them across to the
00:54:04.000 | GPU, in my case, or whatever device you are working on.
00:54:12.080 | And we'll do that for the attention mask, start positions, and end positions.
00:54:37.120 | So these start and end positions are essentially the labels,
00:54:41.200 | they're the targets that we want our model to optimize for.
00:54:44.400 | And the input IDs and attention masks are the inputs.
00:55:06.880 | So now we have those defined. We just need to feed them into our model for training.
00:55:11.600 | And we will output the results of that training batch to the outputs variable.
00:55:19.280 | So we call our model with the input IDs, and we need the attention mask.
00:55:36.560 | And we also want our start positions and end positions.
00:55:42.880 | Now, from our training batch, we want to extract the loss.
00:56:02.480 | And then we want to calculate the gradient of the loss for every parameter.
00:56:06.160 | And this is for our gradient update.
00:56:26.080 | And then we use the step method here to actually update the model weights using those gradients.
00:56:30.400 | And then this final little bit here is purely for us to see, this is our progress bar.
00:56:38.640 | So we take our loop. We set the description, which is going to be our epoch.
00:56:44.480 | And then it would probably be quite useful to also see the loss in there as well.
00:56:48.720 | We will set that as a postfix. So it will appear after the progress bar.
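Putting the steps of the training loop together, a sketch might look like this. Several parts are assumptions made so it runs quickly and offline: the `ToyQADataset` class with random token IDs stands in for the tokenized SQuAD data, the model is a tiny randomly initialised DistilBERT rather than the pretrained one, and the learning rate is a guess. Note that tqdm documents `leave=True` as keeping the finished bar on screen after the loop ends.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import DistilBertConfig, DistilBertForQuestionAnswering

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ToyQADataset(Dataset):
    """Hypothetical stand-in for the tokenized SQuAD data: random token IDs."""

    def __init__(self, n=32, seq_len=16):
        g = torch.Generator().manual_seed(0)
        self.data = {
            "input_ids": torch.randint(1, 100, (n, seq_len), generator=g),
            "attention_mask": torch.ones(n, seq_len, dtype=torch.long),
            "start_positions": torch.randint(0, seq_len, (n,), generator=g),
            "end_positions": torch.randint(0, seq_len, (n,), generator=g),
        }

    def __len__(self):
        return len(self.data["input_ids"])

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.data.items()}


# Tiny random model (assumed sizes); the video uses the pretrained weights.
config = DistilBertConfig(vocab_size=100, max_position_embeddings=32,
                          n_layers=1, n_heads=2, dim=32, hidden_dim=64)
model = DistilBertForQuestionAnswering(config).to(device)
model.train()
optim = torch.optim.AdamW(model.parameters(), lr=5e-5)  # assumed learning rate
train_loader = DataLoader(ToyQADataset(), batch_size=16, shuffle=True)

losses = []
for epoch in range(3):
    loop = tqdm(train_loader, leave=True)  # progress bar over the batches
    for batch in loop:
        optim.zero_grad()  # clear gradients left over from the last batch
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        start_positions = batch["start_positions"].to(device)
        end_positions = batch["end_positions"].to(device)
        # Passing the positions as labels makes the model return a loss.
        outputs = model(input_ids, attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
        loss = outputs.loss
        loss.backward()  # compute gradients for every parameter
        optim.step()     # update the weights using those gradients
        loop.set_description(f"Epoch {epoch}")
        loop.set_postfix(loss=loss.item())
        losses.append(loss.item())
```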
00:57:06.960 | Okay, and that should be everything. Okay, so that looks pretty good. We have our model training.
00:57:14.160 | And as I said, this will take a little bit of time. So I will let that run.
00:57:33.920 | Okay, so we have this NoneType error here. And this is because within our end positions,
00:57:48.240 | we would normally expect integers, but we're also getting some None values because the
00:57:52.880 | code that we used earlier, where we're checking if the end position is None, essentially wasn't good enough.
00:57:59.760 | So as a fix for that, we'll just go back and we'll add like a while loop, which will keep
00:58:06.240 | checking if it's None. And every time it is None, shift the position that we are looking up back by one.
00:58:12.480 | So go back up here, and this is where the problem is coming from.
00:58:21.040 | So we're just going to change this to be a while loop.
00:58:29.520 | And just initialize essentially a counter here.
00:58:32.800 | And we'll use this as our go_back value. And every time the end position is still None,
00:58:45.600 | we'll just add one to go_back. And this should work.
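The while-loop fix can be sketched without a real tokenizer. Everything here is a hypothetical stand-in: `char_to_token` mimics a character-to-token lookup using whitespace-separated words, returning None for spaces and out-of-range indexes, and the example sentence is made up. The extra bound on `go_back` is an added safety check, not something mentioned in the video.

```python
context = "Beyonce released Lemonade in 2016"
answer_text = "Lemonade"
answer_start = context.index(answer_text)
answer_end = answer_start + len(answer_text)  # points at the space after the answer


def build_lookup(text):
    """Map each non-space character index to the index of its word."""
    lookup, token_index, in_word = {}, -1, False
    for i, ch in enumerate(text):
        if ch.isspace():
            in_word = False
            continue
        if not in_word:
            token_index += 1
            in_word = True
        lookup[i] = token_index
    return lookup


lookup = build_lookup(context)


def char_to_token(char_index):
    # Hypothetical stand-in: None when the character belongs to no token.
    return lookup.get(char_index)


start_position = char_to_token(answer_start)
end_position = char_to_token(answer_end)  # None, because it lands on a space

go_back = 1
# Keep stepping back one character until the lookup returns a token.
while end_position is None and go_back <= answer_end:
    end_position = char_to_token(answer_end - go_back)
    go_back += 1
```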
00:58:51.680 | So we need to remember to rerun anything we need to rerun. Yeah.
00:59:18.720 | Okay. And that looks like it solved the issue. So great. We can just leave that
00:59:25.920 | training for a little while and I will see you when it's done.
00:59:29.920 | Okay. So the model's finished and we'll go ahead and just save it.
00:59:38.240 | So obviously we'll need to do that whenever actually doing this on any other projects.
00:59:47.120 | So I'm just going to call it distilbert-custom.
00:59:49.120 | And it's super easy to save. We just call save_pretrained with the model path.
01:00:00.160 | Now, as well as this, we might also want to save the tokenizer so we have everything in one place.
01:00:10.080 | So to do that, we also just use tokenizer and save_pretrained again.
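A sketch of saving and reloading with `save_pretrained`. To stay runnable offline it saves a tiny randomly initialised model (assumed config sizes) into a temporary folder; the folder location is an assumption, and only the name `distilbert-custom` comes from the narration. Saving the tokenizer works the same way, by calling `save_pretrained` on it with the same path.

```python
import os
import tempfile

from transformers import DistilBertConfig, DistilBertForQuestionAnswering

# Tiny random model (assumed sizes) standing in for the fine-tuned one.
config = DistilBertConfig(vocab_size=100, max_position_embeddings=32,
                          n_layers=1, n_heads=2, dim=32, hidden_dim=64)
model = DistilBertForQuestionAnswering(config)

# Assumed location: a temporary directory instead of a project folder.
model_path = os.path.join(tempfile.mkdtemp(), "distilbert-custom")
model.save_pretrained(model_path)  # writes the config and the weights

saved_files = sorted(os.listdir(model_path))
print(saved_files)

# The saved folder can be loaded back just like a pretrained model name.
reloaded = DistilBertForQuestionAnswering.from_pretrained(model_path)
```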
01:00:16.160 | Okay. So if we go back into our folder here, we see models, and we have this distilbert-custom. And then
01:00:29.680 | in here we have all of the files we need to build our PyTorch model. It's a little bit different if
01:00:37.760 | we're using TensorFlow, but the actual saving process is practically the same. So now we've
01:00:45.680 | finished training. We want to switch it out of the training mode. So we use model.eval().
01:00:51.840 | And we just get all this information about our model as well. We don't actually need any of that.
01:01:00.480 | And just like before, we want to create a data loader. So for that, I'm just going to call it
01:01:08.240 | val_loader. And it's exactly the same code as before. In fact, it's probably better if we just
01:01:13.760 | copy and paste some of this. At least the loop. So what we're going to do here is
01:01:24.960 | take the same loop and apply it as a validation run with our validation data.
01:01:31.920 | Just paste up there. We'll initialize this data loader. This time, of course, with the validation
01:01:37.600 | set. We'll start with the same batch size.
01:01:47.440 | Now, this time, we do want to keep a log of accuracy. So we will keep that there. And we
01:01:55.680 | also don't need to run multiple epochs because we're not training this time. We're just running
01:02:00.640 | through all of the batches within our loop of validation data. So this is now a validation
01:02:07.760 | loader. And we just loop through each of those batches. So we don't need to do anything with
01:02:13.920 | the gradients here. And because we're not doing anything with the gradients, we actually
01:02:21.360 | add torch.no_grad in to stop PyTorch from calculating any gradients. Because this will obviously
01:02:29.440 | save us a bit of time when we're processing all of this. And we put those in there. The
01:02:37.760 | outputs, we do still want this. But of course, we don't need to be putting in the start and
01:02:44.560 | end positions. So we can remove those. And this time, we want to pull out the start prediction
01:02:53.360 | and end prediction. So if we have a look at what our outputs look like before, you see
01:03:02.000 | we have this model output. And within here, we have a few different tensors which each
01:03:09.680 | have an accessible name. So the ones that we care about are start_logits. And that will
01:03:21.520 | give us the logits for our start position, which is essentially like a set of predictions
01:03:30.880 | where the highest value within that vector represents the predicted token position. So we can do that
01:03:39.200 | for both. And you'll see we get these tensors. Now, we only want the index of the largest value in each
01:03:50.640 | one of these vectors here, because that will give us the token position. So to get that, we use
01:03:57.800 | the argmax function. And if we just use it by itself, that will give us the maximum index
01:04:09.480 | within the whole thing. But we don't want that. We want one for every single vector
01:04:15.000 | or every row. And to do that, we just set dim=1. And there you go. We get a
01:04:23.480 | full batch of outputs. So these are our starting positions. And then we also want to do the
01:04:30.760 | same for our ending positions. So we just change start to end. So it's pretty easy.
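The difference between `argmax` with and without `dim=1` is easy to see on a small example. The logits below are made-up numbers for a batch of three sequences with four tokens each.

```python
import torch

# Made-up start logits: 3 sequences (rows) x 4 token positions (columns).
start_logits = torch.tensor([
    [0.1, 2.5, 0.3, 0.0],
    [1.2, 0.4, 0.9, 3.1],
    [0.2, 0.1, 4.0, 0.5],
])

# Without dim: one index into the flattened tensor.
flat_index = torch.argmax(start_logits)

# With dim=1: the position of the largest logit in every row.
start_pred = torch.argmax(start_logits, dim=1)

print(flat_index)  # tensor(10)
print(start_pred)  # tensor([1, 3, 2])
```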
01:04:43.500 | Now obviously, we want to be doing this within our loop, because this is only doing one batch.
01:04:50.860 | And we need to do this for every single batch. So we're just going to assign them to a variable.
01:05:09.260 | And there we have our predictions. And all we need to do now is check for an exact match.
01:05:17.680 | So what I mean by exact match is we want to see whether the start positions here, which
01:05:24.320 | we can rename to the true values, whether these are equal to the predicted values down
01:05:33.260 | here. And to calculate that, so let me just run this so we have one batch. That shouldn't
01:05:46.060 | take too long to process, and we can just write the code out. So to check this, we just
01:05:51.700 | use the double equal syntax here. And this will just check for matches between two arrays.
01:06:03.460 | So we have the start predictions and the start true values. So we'll check for those.
01:06:22.340 | So if we just have a look at what we have here, we get this array of true or false.
01:06:28.540 | So these ones don't look particularly good. But that's fine. We just want to calculate
01:06:33.020 | the accuracy here. So let's take the sum, and we also want to divide that by the length.
01:06:52.340 | Okay so that will give us our accuracy within the tensor. And we just take it out using
01:06:57.660 | the item method. But we also just need to include brackets around this, because at the
01:07:03.540 | moment we're trying to take item of the length value. Okay and then that gives us our very
01:07:10.660 | poor accuracy on this final batch. So we can take that, and within here we want to append
01:07:18.060 | that to our accuracy list. And then we also want to do that for the end prediction as well.
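The whole validation pass can be sketched as below. As before, the data and model are stand-ins: `ToyQADataset` holds random token IDs in place of the tokenized validation set, and the model is a tiny randomly initialised DistilBERT, so the accuracy it prints is meaningless. The names `start_true`, `start_pred`, and `acc` are assumptions chosen to match the narration.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import DistilBertConfig, DistilBertForQuestionAnswering

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ToyQADataset(Dataset):
    """Hypothetical stand-in for the tokenized validation data."""

    def __init__(self, n=32, seq_len=16):
        g = torch.Generator().manual_seed(0)
        self.data = {
            "input_ids": torch.randint(1, 100, (n, seq_len), generator=g),
            "attention_mask": torch.ones(n, seq_len, dtype=torch.long),
            "start_positions": torch.randint(0, seq_len, (n,), generator=g),
            "end_positions": torch.randint(0, seq_len, (n,), generator=g),
        }

    def __len__(self):
        return len(self.data["input_ids"])

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.data.items()}


config = DistilBertConfig(vocab_size=100, max_position_embeddings=32,
                          n_layers=1, n_heads=2, dim=32, hidden_dim=64)
model = DistilBertForQuestionAnswering(config).to(device)
model.eval()  # switch off training-only behaviour such as dropout

val_loader = DataLoader(ToyQADataset(), batch_size=16)

acc = []
for batch in val_loader:
    with torch.no_grad():  # no gradients needed for evaluation
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        start_true = batch["start_positions"].to(device)
        end_true = batch["end_positions"].to(device)
        # No start/end positions are passed, so no loss is computed.
        outputs = model(input_ids, attention_mask=attention_mask)
        start_pred = torch.argmax(outputs.start_logits, dim=1)
        end_pred = torch.argmax(outputs.end_logits, dim=1)
        # Fraction of exact matches in this batch, for start and for end.
        acc.append(((start_pred == start_true).sum() / len(start_pred)).item())
        acc.append(((end_pred == end_true).sum() / len(end_pred)).item())

# Overall exact-match accuracy: the mean of the logged values.
overall = sum(acc) / len(acc)
print(overall)
```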
01:07:45.700 | And we'll just let that run through, and then we can calculate our accuracy from the end
01:07:50.580 | of that. And then we can have a quick look at our accuracy here. And we can see fortunately
01:07:59.020 | it's not as bad as it first seemed. So we're getting a lot of 93%, 100%, 81%. That's generally
01:08:09.300 | pretty good. So of course if we want to get the overall accuracy, all we do is sum that
01:08:22.420 | and divide by the length. And we get 63.6% for an exact match accuracy. So what I mean
01:08:35.180 | by exact match is, say if we take a look at a few of these that do not match. So we have
01:08:43.900 | a 75% match on the fourth batch. Although that won't be particularly useful, because
01:08:52.940 | we can't see that batch right now. So let's just take the last batch, because we have
01:08:56.820 | these values here. Now if we look at what start_true is, we get these values. Then if
01:09:06.600 | we look at start_pred, we get this. So none of these match, but a couple of them do get
01:09:18.140 | pretty close. So these final four, all of these count as 0% on the exact match. But
01:09:26.560 | in reality, if you look at what we predicted, for every single one of them, it's predicting
01:09:32.220 | just one token before. So it's getting quite close, but it's not an exact match. So it
01:09:36.380 | scores 0. So when you consider that with our 63.6% accuracy here, that means that this
01:09:45.940 | model is actually probably doing pretty well. It's not perfect, of course, but it's doing
01:09:50.860 | pretty well. So overall, that's everything for this video. We've gone all the way through
01:09:58.540 | this. If you do want the code for this, I'm going to make sure I keep a link to it in
01:10:03.580 | the description. So check that out if you just want to copy this across. But for now,
01:10:10.180 | that's everything. So thank you very much for watching, and I will see you again soon.