Lesson 13: Deep Learning Foundations to Stable Diffusion
Chapters
0:00 Introduction
2:54 Linear models & rectified lines (ReLU) diagram
10:15 Multi Layer Perceptron (MLP) from scratch
18:15 Loss function from scratch - Mean Squared Error (MSE)
23:14 Gradients and backpropagation diagram
31:30 Matrix calculus resources
33:27 Gradients and backpropagation code
38:15 Chain rule visualized + how it applies
49:08 Using Python’s built-in debugger
60:47 Refactoring the code
00:00:00.000 |
Hi everybody, and welcome to lesson 13, where we're going to start talking about back propagation. 00:00:10.120 |
Before we do, I'll just mention that there was some great success amongst the folks in 00:00:13.880 |
the class during the week flexing their tensor manipulation muscles. 00:00:26.920 |
So far the fastest meanshift algorithm, which has a similar accuracy to the one I displayed, 00:00:36.160 |
is one that actually randomly chooses a subset of the data points. 00:00:42.000 |
And I actually think that's a great approach. 00:00:43.520 |
Very often random sampling and random projections are two excellent ways of speeding up algorithms. 00:00:55.060 |
So it'd be interesting to see if anybody during the rest of the course comes up with anything faster. 00:01:04.920 |
Also been seeing some good Einstein summation examples and implementations and continuing 00:01:10.440 |
to see lots of good diff edit implementations. 00:01:16.480 |
So congratulations to all the students and I hope those of you following along the videos 00:01:20.840 |
in the MOOC will be working on the same homework as well and sharing your results on the fast.ai forums. 00:01:31.520 |
So now we're going to take a look at notebook number three in the normal repo, course 22p1 00:01:50.560 |
And we're going to be looking at the forward and backward passes of a simple multi-layer perceptron. 00:02:02.080 |
The initial stuff up here is just imports and settings, and 00:02:08.560 |
copying and pasting some stuff from previous notebooks around paths and parameters. 00:02:16.520 |
So we'll often be kind of copying and pasting stuff from one notebook to another like this. 00:02:24.720 |
And I'm also loading in our data for MNIST as tensors. 00:02:33.480 |
So we, to start with, need to create the basic architecture of our neural network. 00:02:42.680 |
And I did mention at the start of the course that we would briefly review everything as we go. 00:02:48.840 |
So we should briefly review what basic neural networks are and why they are what they are. 00:02:55.840 |
So to start with, let's consider a linear model. 00:03:17.440 |
So let's start by considering a linear model of, well, let's take the most simple example 00:03:29.040 |
possible, which is we're going to pick a single pixel from our MNIST pictures. 00:03:36.720 |
And so that will be our X. And for our Y values, then we'll have some loss function of how 00:03:51.120 |
good is this model, sorry, not some loss function. 00:04:00.200 |
For our Y value, we're going to be looking at how likely is it that this is, say, the 00:04:07.280 |
number three based on the value of this one pixel. 00:04:12.880 |
So the pixel, its value will be X and the probability of being the number three, we'll 00:04:20.280 |
call Y. And if we just have a linear model, then it's going to look like this. 00:04:29.840 |
And so in this case, it's saying that the brighter this pixel is, the more likely it is to be a three. Now, there are a couple of problems with this. 00:04:44.080 |
The first one, obviously, is that as a linear model, it's very limiting because maybe we 00:04:51.360 |
actually are trying to draw something that looks more like this. 00:05:02.120 |
Well, there's actually a neat trick we can use to do that. 00:05:06.920 |
What we could do is, well, let's first talk about something we can't do. 00:05:14.440 |
Something we can't do is to add a bunch of additional lines. 00:05:20.000 |
So consider what happens if we say, OK, well, let's add a few different lines. 00:05:36.320 |
Well, the answer is, of course, that the sum of the two lines will itself be a line. 00:05:40.280 |
So it's not going to help us at all match the actual curve that we want. 00:05:48.240 |
Instead, we could create a line like this that actually we could create this line. 00:06:13.240 |
And now consider what happens if we add this original line to this new-- well, it's not quite a line, it's a rectified line. 00:06:21.000 |
So what we would get is this-- everything to the left of this point is going to not be 00:06:34.800 |
changed if I add these two lines together, because this is zero all the way. 00:06:39.520 |
And everything to the right of it is going to be reduced. 00:06:44.280 |
So we might end up with instead-- so this would all disappear here. 00:06:52.520 |
And instead, we would end up with something like this. 00:07:03.340 |
We could add an additional line that looks a bit like that. 00:07:09.280 |
So it would go-- but this time, it could go even further out here, and it could be something 00:07:23.280 |
Well, again, at the point underneath here, it's always zero, so it won't do anything 00:07:30.440 |
But after that, it's going to make it even more negatively sloped. 00:07:36.360 |
And if you can see, using this approach, we could add up lots of these rectified lines, 00:07:46.200 |
And we could create any shape we want with enough of them. 00:07:50.680 |
And these lines are very easy to create, because actually, all we need to do is to create just 00:07:58.840 |
a regular line, which we can move up, down, left, right, change its angle, whatever. 00:08:13.180 |
And then just say, if it's greater than zero, truncate it to zero. 00:08:18.080 |
Or we could do the opposite for a line going the opposite direction. 00:08:20.480 |
If it's less than zero, we could say, truncate it to zero. 00:08:24.520 |
And that would get rid of, as we want, this whole section here, and make it flat. 00:08:38.560 |
And so we can sum up a bunch of these together to basically match any arbitrary curve we want. 00:08:52.840 |
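To make that concrete, here is a tiny sketch (not from the lesson notebook; the slopes and offsets are made up) of summing a few rectified lines in Python:

```python
import torch

# A "rectified line": an ordinary line w*x + b, truncated at zero on one side.
def rectified_line(x, w, b):
    return (w * x + b).clamp_min(0.)   # anything below zero becomes zero

x = torch.linspace(0., 1., 100)

# Summing several of them (arbitrary slopes and offsets here) gives a
# piecewise-linear curve; with enough of them we can approximate any shape.
y = (rectified_line(x, 3., -0.3)
     + rectified_line(x, -5., 2.)
     + rectified_line(x, 2., -1.2))
```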
Well, the other thing we should mention, of course, is that we're going to have not just 00:08:56.960 |
one pixel, but we're going to have lots of pixels. 00:09:02.680 |
So to start with the only slightly less simple approach, we 00:09:18.160 |
could have something where we've got pixel number one and pixel number two. 00:09:22.800 |
We're looking at two different pixels to see how likely they are to be the number three. 00:09:28.880 |
And so that would allow us to draw more complex shapes that have some kind of surface between 00:09:49.040 |
OK, and then we can do exactly the same thing: to create these surfaces, we can add up lots of these. 00:10:01.400 |
But now they're going to be kind of rectified planes. 00:10:08.360 |
We're going to be adding together a bunch of lines, each one of which is truncated at zero. 00:10:16.680 |
And so to do that, we'll start out by just defining a few variables. 00:10:22.520 |
So n is the number of training examples, m is the number of pixels, c is the number of 00:10:36.480 |
possible values of our digits, and so here they are, 50,000 samples, 784 pixels, and 10 possible values. 00:10:46.320 |
OK, so what we do is we basically decide ahead of time how many of these line segment thingies 00:10:59.640 |
And so the number that we create in a layer is called the number of hidden nodes or activations. 00:11:07.100 |
So let's just arbitrarily decide on creating 50 of those. 00:11:11.520 |
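As a rough sketch of that setup (using random stand-ins for the MNIST tensors loaded earlier in the notebook, since that loading code isn't shown here):

```python
import torch

# Hypothetical stand-ins for the MNIST tensors loaded at the top of the notebook
x_train = torch.rand(50_000, 784)            # pixel values
y_train = torch.randint(0, 10, (50_000,))    # digit labels

n, m = x_train.shape          # n = 50,000 training examples, m = 784 pixels
c = int(y_train.max() + 1)    # c = 10 possible digit values
nh = 50                       # number of hidden activations we arbitrarily chose
```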
So in order to create lots of lines, which we're going to truncate at zero, we can do a matrix multiplication. 00:11:21.520 |
So with a matrix multiplication, we're going to have something where we've got 50,000 rows and 784 columns. 00:11:50.600 |
And we're going to multiply that by something with 784 rows and 10 columns. 00:12:07.600 |
Well, that's because if we take the very first row of this first matrix here, row one, we 00:12:13.560 |
have 784 values, they're the pixel values of the first image. 00:12:20.280 |
And so each of those 784 values will be multiplied by the corresponding one of these 784 values in the first column, 00:12:31.480 |
And that's going to give us a number in our output. 00:12:35.720 |
So our output is going to be 50,000 images by 10. 00:12:50.640 |
And so that result, we'll multiply those together and we'll add them up. 00:12:54.520 |
And that result is going to end up over here in this first cell. 00:12:59.520 |
And so each of these columns is going to eventually represent, if this is a linear model, in this 00:13:07.120 |
case -- this is just the example of doing a linear model -- each of these cells is going to be a probability. 00:13:12.720 |
So this first column will be the probability of being a zero. 00:13:15.480 |
And the second column will be the probability of one. 00:13:17.640 |
The third column will be the probability of being a two and so forth. 00:13:21.280 |
So that's why we're going to have these 10 columns, each one allowing us to weight the 00:13:28.040 |
Now, of course, we're going to do something a bit more tricky than that, which is actually 00:13:31.760 |
we're going to have our 784 columns of input going into a 784 by 50 weight matrix to create the 50 hidden activations. 00:13:41.280 |
Then we're going to truncate those at zero and then multiply that by a 50 by 10 to create the outputs. 00:13:51.320 |
So the way SGD works is we start with just this is our weight matrix here. 00:14:06.160 |
The way it works is that this weight matrix is initially filled with random values. 00:14:13.000 |
Also, of course, this contains our pixel values, this contains the results. 00:14:26.100 |
It's going to have, as we discussed, 50,000 by 50 random values. 00:14:42.060 |
So we call those the biases, the things we add. 00:14:48.560 |
So we'll need one for each output, so 50 of those. 00:14:52.720 |
And then as we just mentioned, layer two will be a matrix that goes from the 50 hidden activations to the output. 00:15:00.920 |
And now I'm going to do something totally cheating to simplify some of the calculations 00:15:11.400 |
That's because I'm not going to use cross entropy just yet. 00:15:18.380 |
So actually, I'm going to create one output, which will literally just be what number we think it is. 00:15:40.880 |
And so then we're going to compare those to the actual-- so these will be our y-predictors. 00:15:45.480 |
We normally use a little hat for that, and we're going to compare that to our actuals. 00:15:52.680 |
And yeah, in this very hacky approach, let's say we predict over here the number 9, and 00:15:59.880 |
And we'll compare those together using MSE, which will be a stupid way to do it, because 00:16:07.120 |
it's saying that a prediction of 9 is further away from being 2 than a prediction of 4 is, 00:16:12.640 |
in terms of how correct it is, which is not what we want at all. 00:16:17.480 |
But this is what we're going to do just to simplify our starting point. 00:16:20.560 |
So that's why we're going to have a single output for this weight matrix and a single bias. 00:16:28.720 |
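A minimal sketch of the parameter initialisation being described, assuming the shapes given in the lesson (784 inputs, 50 hidden activations, and the single-output MSE hack):

```python
import torch

m, nh = 784, 50
w1 = torch.randn(m, nh)    # layer 1 weights, initially random: 784 x 50
b1 = torch.zeros(nh)       # layer 1 biases: one per hidden activation
w2 = torch.randn(nh, 1)    # layer 2 weights: 50 x 1, for the single output
b2 = torch.zeros(1)        # layer 2 bias
```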
So a linear-- let's create a function for putting x through a linear layer with these weights and biases. 00:16:43.560 |
So if we multiply our x-- oh, we're doing x valid this time. 00:16:48.680 |
So just to clarify, x valid is 10,000 by 784. 00:16:56.520 |
So if we put x valid through our weights and biases with a linear layer, we end up with 00:17:01.680 |
a 10,000 by 50 matrix, so 10,000 50-long sets of hidden activations. 00:17:09.000 |
They're not quite ready yet, because we have to put them through ReLU. 00:17:18.720 |
And so here's what it looks like when we go through the linear layer and then the ReLU. 00:17:23.420 |
And you can see here's a tensor with a bunch of things, some of which are 0 and the rest of which are positive. 00:17:28.480 |
So that's the result of this matrix multiplication. 00:17:32.200 |
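A sketch of the linear layer and ReLU just described (the weights here are random stand-ins, as is x_valid, since the real data loading isn't shown in this transcript):

```python
import torch

def lin(x, w, b):
    return x @ w + b            # matrix multiply plus bias

def relu(x):
    return x.clamp_min(0.)      # truncate everything below zero to zero

w1, b1 = torch.randn(784, 50), torch.zeros(50)
x_valid = torch.rand(10_000, 784)       # stand-in for the real validation images
t = relu(lin(x_valid, w1, b1))          # 10,000 x 50 hidden activations
```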
OK, so to create our basic MLP, multi-layer perceptron, from scratch, we will take our input. 00:17:45.440 |
We will create our first layer's output with a linear. 00:17:50.240 |
Then that will go through a ReLU, and then through the second linear. 00:17:51.820 |
So the first one uses w1 and b1, these ones. 00:18:05.560 |
And as we hoped, when we pass in the validation set, we get back 10,000 predictions, so a 10,000 by 1 output. 00:18:14.840 |
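Putting that together, a sketch of the simple MLP forward pass (lin, relu, x_valid and the weights w1, b1, w2, b2 are assumed from the sketches above):

```python
def model(xb):
    l1 = lin(xb, w1, b1)        # first linear layer: 784 -> 50
    l2 = relu(l1)               # truncate at zero
    return lin(l2, w2, b2)      # second linear layer: 50 -> 1

res = model(x_valid)            # shape: 10,000 x 1
```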
OK, so let's use our ridiculous loss function of MSE. 00:18:26.800 |
And our x valid-- sorry, our y valid is just a vector. 00:18:32.160 |
Now what's going to happen if I do res minus y valid? 00:18:39.240 |
So before you continue in the video, have a think about that. 00:18:43.400 |
What's going to happen if I do res minus y valid, thinking about the NumPy broadcasting rules? 00:18:59.120 |
We've ended up with a 10,000 by 10,000 matrix. 00:19:06.520 |
Now we would expect an MSE to just contain 10,000 points. 00:19:17.800 |
The reason it happened is because we have to start out at the last dimension and go right to left. 00:19:28.680 |
And we compare the 10,000 to the 1 and say, are they compatible? 00:19:35.240 |
And the answer is-- that's right, Alexei in the chat's got it right-- broadcasting rules. 00:19:39.620 |
So the answer is that this 1 will be broadcast over these 10,000. 00:19:44.120 |
So this pair here will give us 10,000 outputs. 00:19:58.880 |
Now, if you remember the rules, it inserts a unit axis for us. 00:20:04.640 |
So that means each of the 10,000 outputs from here will end up being broadcast across the other 10,000. 00:20:14.080 |
So that means that will end up-- for each of those 10,000, we'll have another 10,000. 00:20:17.560 |
So we'll end up with a 10,000 by 10,000 output. 00:20:26.840 |
Well, what we really want is for this to be 10,000 comma 1 here. 00:20:35.000 |
If that was 10,000 comma 1, then we'd compare these two right to left. 00:20:43.720 |
And there's nothing to broadcast because they're the same. 00:20:46.160 |
And then we'll go to the next one, 10,000 to 10,000. 00:20:55.680 |
Or alternatively, we could remove this dimension. 00:21:01.960 |
We're then going to add right to left, compatible 10,000. 00:21:16.760 |
So in this case, I got rid of the trailing comma 1. 00:21:21.720 |
One is just to say, OK, grab every row and the zeroth column of res. 00:21:27.720 |
And that's going to turn it from a 10,000 by 1 into a 10,000. 00:21:35.280 |
Now, dot squeeze removes all trailing unit vectors and possibly also prefix unit vectors. 00:21:48.560 |
So let's say res none comma colon comma none. 00:22:07.440 |
OK, so if I go Q dot squeeze dot shape, OK, so all the unit vectors get removed. 00:22:24.840 |
Sorry, all the unit dimensions get removed, I should say. 00:22:29.000 |
OK, so now that we've got a way to remove that axis that we didn't want, we can use it. 00:22:35.800 |
And if we do that subtraction, now we get 10,000 just like we wanted. 00:22:40.720 |
So now let's get our training and validation wise. 00:22:45.920 |
We'll turn them into floats because we're using MSE. 00:22:50.200 |
So let's calculate our predictions for the training set, which is 50,000 by 1. 00:22:55.880 |
And so if we create an MSE function that just does what we just said we wanted. 00:23:00.160 |
So it does the subtraction and then squares it and then takes the mean, that's MSE. 00:23:08.080 |
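A sketch of that MSE, with the trailing unit axis removed by indexing as just discussed:

```python
def mse(output, targ):
    # output: (n, 1) predictions; targ: (n,) float targets
    return (output[:, 0] - targ).pow(2).mean()
```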
So there we go, we now have a loss function being applied to our training set. 00:23:22.480 |
So as we briefly discussed last time, gradients are slopes. 00:23:33.760 |
And in fact, maybe it would even be easier to look at last time. 00:23:41.480 |
So this was last time's notebook. And so we saw how the gradient at this point is the slope here. 00:24:03.280 |
And so it's the, as we discussed, rise over run. 00:24:10.000 |
Now, so that means as we increase, in this case, time by one, the distance increases by how much? 00:24:35.160 |
The reason it's interesting is because let's consider our neural network. 00:24:42.960 |
Our neural network is some function that takes two things, two groups of things. 00:24:49.360 |
It contains a matrix of our inputs and it contains our weight matrix. 00:25:01.320 |
And we want to and let's assume we're also putting it through a loss function. 00:25:11.040 |
So let's say, well, I mean, I guess we can be explicit about that. 00:25:14.800 |
So we could say we then take the result of that and we put it through some loss function. 00:25:20.480 |
So these are the predictions and we compare it to our actual dependent variable. 00:25:53.280 |
So if we can get the derivative of the loss with respect to, let's say, one particular weight, what would that tell us? 00:26:11.680 |
Well, it's saying as I increase the weight by a little bit, what happens to the loss? 00:26:21.320 |
And if it says, oh, well, that would make the loss go down, then obviously I want to 00:26:29.300 |
And if it says, oh, it makes the loss go up, then obviously I want to do the opposite. 00:26:34.320 |
So the derivative of the loss with respect to the weights, each one of those tells us how changing that weight changes the loss. 00:26:43.700 |
And so to remind you, we then take each derivative times a little 00:26:49.760 |
bit and subtract it from the original weights. 00:26:55.920 |
And we do that a bunch of times and that's called SGD. 00:27:05.280 |
Now there's something interesting going on here, which is that in this case, there's just one input and one output. 00:27:16.340 |
And so the derivative is a single number at any point. 00:27:20.320 |
It's the speed in this case, the vehicle's going. 00:27:24.560 |
But consider a more complex function like say this one. 00:27:35.160 |
Now in this case, there's one output, but there's two inputs. 00:27:39.700 |
And so if we want to take the derivative of this function, then we actually need to say, 00:27:46.120 |
well, what happens if we increase X by a little bit? 00:27:49.080 |
And also what happens if we increase Y by a little bit? 00:27:55.680 |
And so in that case, the derivative is actually going to contain two numbers, right? 00:28:02.080 |
It's going to contain the derivative of Z with respect to Y. 00:28:08.640 |
And it's going to contain the derivative of Z with respect to X. 00:28:12.520 |
What happens if we change each of these two numbers? 00:28:14.640 |
So for example, these could be, as we discussed, two different weights in our neural network 00:28:33.840 |
So we don't normally write them all like that. 00:28:36.160 |
We would just say, use this little squiggly symbol to say the derivative of the loss across 00:28:43.160 |
all of them with respect to all of the weights. 00:28:48.240 |
OK, and that's just saying that there's a whole bunch of them. 00:28:56.560 |
OK, so it gets more complicated still, though, because think about what happens if, for example, 00:29:08.880 |
you're in the first layer, where we've got a weight matrix that's going to end up giving us 50 outputs. 00:29:15.400 |
So for every image, we're going to have 784 inputs to our function, and we're going to have 50 outputs. 00:29:26.520 |
And so in that case, I can't even draw it, right? 00:29:31.880 |
Because like for every-- even if I had two inputs and two outputs, then as I increase 00:29:36.620 |
my first input, I'd actually need to say, how does that change both of the two outputs? 00:29:44.260 |
And as I change my second input, how does that change both of my two outputs? 00:29:50.320 |
So for the full thing, you actually are going to end up with a matrix of derivatives. 00:29:57.640 |
It basically says, for every input that you change by a little bit, how much does it change every one of the outputs? 00:30:11.980 |
So that's what we're going to be doing, is we're going to be calculating these derivatives, 00:30:18.480 |
but rather than being single numbers, they're going to actually contain matrices with a 00:30:24.480 |
row for every input and a column for every output. 00:30:28.480 |
And a single cell in that matrix will tell us, as I change this input by a little bit, how much does this particular output change? 00:30:39.600 |
Now eventually, we will end up with a single number for every input. 00:30:49.280 |
And that's because our loss in the end is going to be a single number. 00:30:53.800 |
And this is like a requirement that you'll find when you try to use SGD, is that your loss has to be a single number. 00:31:01.880 |
And so we generally get it by either doing the sum or a mean or something like that. 00:31:09.480 |
But as you'll see, on the way there we're going to have to be dealing with these matrices of derivatives. 00:31:18.600 |
So I just want to mention, as I might have said before, I can't even remember. 00:31:32.420 |
There is this paper that Terence Parr and I wrote a while ago, which goes through all of this matrix calculus. 00:31:42.400 |
And it basically assumes that you only know high school calculus, and if you don't check 00:31:50.400 |
out Khan Academy, but then it describes matrix calculus in those terms. 00:31:55.680 |
So it's going to explain to you exactly, and it works through lots and lots of examples. 00:32:04.440 |
So for example, as it mentions here, when you have this matrix of derivatives, we call that the Jacobian matrix. 00:32:17.320 |
So there's all these words, it doesn't matter too much if you know them or not, but it's 00:32:22.540 |
convenient to be able to talk about the matrix of all of the derivatives if somebody just says "the Jacobian." 00:32:31.920 |
It's a bit easier than saying the matrix of all of the derivatives, where all of the rows 00:32:36.520 |
are all the inputs and all the columns are the outputs. 00:32:43.160 |
So yeah, if you want to really understand, get to a point where papers are easier to 00:32:50.040 |
read in particular, it's quite useful to know this notation and definitions of words. 00:32:59.600 |
You can certainly get away without it, it's just something to consider. 00:33:06.360 |
Okay, so we need to be able to calculate derivatives, at least of a single variable. 00:33:14.480 |
And I am not going to worry too much about that, a, because that is something you do 00:33:19.100 |
in high school math, and b, because your computer can do it for you. 00:33:25.360 |
And so you can do it symbolically, using something called SymPy, which is really great. 00:33:31.480 |
So if you create two symbols called x and y, you can say please differentiate x squared 00:33:41.240 |
with respect to x, and if you do that, SymPy will tell you the answer is 2x. 00:33:48.240 |
If you say differentiate 3x squared plus 9 with respect to x, SymPy will tell you that it's 6x. 00:33:58.480 |
And a lot of you probably will have used Wolfram Alpha, that does something very similar. 00:34:05.760 |
I kind of quite like this because I can quickly do it inside my notebook and include it in 00:34:14.720 |
So basically, yeah, you can quickly calculate derivatives on a computer. 00:34:25.280 |
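A sketch of doing that with SymPy in a notebook cell:

```python
from sympy import symbols, diff

x, y = symbols('x y')
print(diff(x**2, x))         # -> 2*x
print(diff(3*x**2 + 9, x))   # -> 6*x
```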
Having said that, I do want to talk about why the derivative of 3x squared plus 9 equals 00:34:31.680 |
6x, because that is going to be very important. 00:34:46.440 |
So we're going to start with the information that the derivative of a to the b with respect to a is b times a to the b minus 1. 00:35:10.260 |
So for example, the derivative of x squared with respect to x equals 2x. 00:35:17.140 |
So that's just something I'm hoping you'll remember from high school, or refresh your memory on. 00:35:26.080 |
So what we could now do is we could rewrite this as y equals 3u plus 9, where u is x squared. 00:35:50.920 |
The derivative of two things being added together is simply the sum of their derivatives. 00:36:03.000 |
Sorry, b times a to the power of b minus 1 is what it should be, which would be 2x to the power of 1, so just 2x. 00:36:19.480 |
So we get the derivative of 3u plus 9 is actually just-- well, it's going to be the derivative of 3u plus the derivative of 9. 00:36:32.480 |
Now the derivative of any constant with respect to a variable is 0. 00:36:36.800 |
Because if I change something, an input, it doesn't change the constant. 00:36:44.480 |
And so we're going to end up with dy/du equals something plus 0. 00:36:52.280 |
And the derivative of 3u with respect to u is just 3 because it's just a line. 00:37:02.880 |
Well, the cool thing is that dy/dx is actually just equal to dy/du times du/dx. 00:37:19.080 |
But for now then, let's recognize we've got dy/du and, sorry, du/dx. 00:37:26.840 |
So we can now multiply these two bits together. 00:37:31.140 |
And we will end up with 2x times 3, which is 6x, which is what SymPy told us. 00:37:40.840 |
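Written out, the derivation just described is (with the substitution u = x squared):

```latex
y = 3x^2 + 9 = 3u + 9, \qquad u = x^2
\frac{dy}{du} = 3, \qquad \frac{du}{dx} = 2x
\frac{dy}{dx} = \frac{dy}{du}\cdot\frac{du}{dx} = 3 \cdot 2x = 6x
```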
OK, this is something we need to know really well. 00:37:50.080 |
So to understand it intuitively, we're going to take a look at an interactive animation. 00:38:00.900 |
So I found this nice interactive animation on this page here, webspace.ship.edu/msrenault. 00:38:14.800 |
OK, and the idea here is that we've got a wheel spinning around. 00:38:21.040 |
And each time it spins around, this is x going up. 00:38:25.800 |
OK, so at the moment, there's some change in x dx over a period of time. 00:38:34.480 |
All right, now, this wheel is eight times bigger than this wheel. 00:38:43.440 |
So each time this goes around once, if we connect the two together, this wheel would 00:38:49.560 |
be going around four times faster because the difference between-- the multiple between 00:39:03.360 |
So now that this wheel has got twice as big a circumference as the u wheel, each time this 00:39:09.620 |
goes around once, this is going around two times. 00:39:15.920 |
So the change in u, each time x goes around once, the change in u will be two. 00:39:31.520 |
Now we could make this interesting by connecting this wheel to this wheel. 00:39:36.920 |
Now this wheel is twice as small as this wheel. 00:39:44.020 |
So now we can see that, again, each time this spins around once, this spins around twice 00:39:50.600 |
because this has twice the circumference of this. 00:39:57.560 |
Now that means every time this goes around once, this goes around twice. 00:40:02.440 |
Every time this one goes around once, this one goes around twice. 00:40:05.760 |
So therefore, every time this one goes around once, this one goes around four times. 00:40:16.560 |
So you can see here how the two-- well, how the du dx has to be multiplied with the dy du to get the overall dy dx. 00:40:28.440 |
So this is what's going on in the chain rule. 00:40:31.760 |
And this is what you want to be thinking about is this idea that you've got one function feeding into another function. 00:40:41.480 |
And so you have to multiply the two impacts to get the impact of the x wheel on the y 00:40:51.600 |
I find this-- personally, I find this intuition quite useful. 00:41:00.240 |
Well, the reason we care about this is because we want to calculate the gradient of our MSE loss applied to our model. 00:41:19.520 |
And so our inputs are going through a linear, they're going through a ReLU, they're going 00:41:24.600 |
through another linear, and then they're going through an MSE. 00:41:30.000 |
And so we're going to have to combine those all together. 00:41:36.120 |
So if our steps are that loss function is-- so we've got the loss function, which is some 00:41:50.280 |
function of the predictions and the actuals, and then we've got the second layer is a function 00:42:08.760 |
of-- actually, let's call this the output of the second layer. 00:42:17.280 |
It's slightly weird notation, but hopefully it's not too bad-- is going to be a function of the ReLU activations. 00:42:30.960 |
And the ReLU activations are a function of the first layer. 00:42:36.640 |
And the first layer is a function of the inputs. 00:42:39.640 |
Oh, and of course, this also has weights and biases. 00:42:46.520 |
So we're basically going to have to calculate the derivative of that. 00:42:54.640 |
But then remember that this is itself a function. 00:42:57.640 |
So then we'll need to multiply that derivative by the derivative of that. 00:43:01.880 |
But that's also a function, so we have to multiply that derivative by this. 00:43:05.800 |
But that's also a function, so we have to multiply that derivative by this. 00:43:13.600 |
And we're going to take its derivative, and then we're going to gradually keep multiplying 00:43:25.000 |
So backpropagation sounds pretty fancy, but it's actually just using the chain rule-- 00:43:31.640 |
gosh, I didn't spell that very well-- prop-gation-- it's just using the chain rule. 00:43:39.200 |
And as you'll see, it's also just taking advantage of a computational trick of memoizing some intermediate calculations. 00:43:46.000 |
And in our chat, Siva made a very good point about understanding nonlinear functions in 00:43:54.440 |
this case, which is just to consider that the wheels could be growing and shrinking 00:44:01.400 |
But you're still going to have this same compound effect, which I really like that. 00:44:12.800 |
There's also a question in the chat about why is this colon, comma, zero being placed 00:44:18.020 |
in the function, given that we can do it outside the function? 00:44:21.520 |
Well, the point is we want an MSE function that will apply to any output. 00:44:28.800 |
So we haven't actually modified preds or anything like that, or Y_train. 00:44:39.320 |
So we want this to be able to apply to anything without us having to pre-process it. 00:44:55.800 |
So here's going to do a forward pass and a backward pass. 00:44:58.200 |
So the forward pass is where we calculate the loss. 00:45:15.420 |
So the loss is going to be the output of our neural net minus our target, squared, and then take the mean. 00:45:29.200 |
And then our output is going to be the output of the second linear layer. 00:45:37.280 |
The second linear layer's input will be the ReLU. 00:45:41.840 |
So we're going to take our input, put it through a linear layer, put that through a ReLU, put 00:45:45.480 |
that through a linear layer, and calculate the MSE. 00:45:50.280 |
OK, that bit hopefully is pretty straightforward. 00:45:57.780 |
So the backward pass, what I'm going to do-- and you'll see why in a moment-- is I'm going 00:46:08.000 |
to store the gradients of the loss with respect to each layer's input in that input itself. 00:46:19.560 |
I could call it anything I like, and I'm just going to call it .g. 00:46:23.200 |
So I'm going to create a new attribute called out.g, which is going to contain the gradients. 00:46:29.200 |
You don't have to do it this way, but as you'll see, it turns out pretty convenient. 00:46:34.000 |
So that's just going to be 2 times the difference, because we've got difference squared. 00:46:50.480 |
And because we took the mean, we have to do the same thing here: divide by the input shape. 00:46:58.400 |
And now what we need to do is multiply by the gradients of the previous layer. 00:47:11.280 |
So the gradient of a linear layer, we're going to need to know the weights of the layer. 00:47:17.480 |
We're going to need to know the biases of the layer. 00:47:21.800 |
And then we're also going to know the input to the linear layer, because that's the thing 00:47:31.440 |
And then we're also going to need the output, because we have to multiply by the gradients of the output, because of the chain rule. 00:47:40.240 |
So again, we're going to store the gradients of our input. 00:47:46.660 |
So this would be the gradients of our output with respect to the input. 00:47:52.800 |
Because the weights, so a matrix multiplier is just a whole bunch of linear functions. 00:48:00.640 |
But you have to multiply it by the gradient of the outputs because of the chain rule. 00:48:06.100 |
And then the gradient of the outputs with respect to the weights is going to be the inputs. 00:48:22.440 |
The derivatives of the bias is very straightforward. 00:48:27.440 |
It's the gradients of the output added together because the bias is just a constant value. 00:48:37.040 |
So for the chain rule, we simply just use the output gradient times 1, which is the output gradient. 00:48:44.560 |
So for this one here, again, we have to do the same thing we've been doing before, which 00:48:48.920 |
is multiply by the output gradients because of the chain rule. 00:48:58.200 |
So every single one of those has to be multiplied by the output gradients. 00:49:05.440 |
And so that's why we have to do an unsqueezed minus 1. 00:49:08.520 |
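Here is a sketch of the forward and backward pass being described (lin, relu and the weights w1, b1, w2, b2 are assumed from the earlier sketches; the gradients are stashed in .g attributes as just explained):

```python
def lin_grad(inp, out, w, b):
    # backward pass of a linear layer: out = inp @ w + b
    inp.g = out.g @ w.t()                                   # gradient w.r.t. the input
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)   # gradient w.r.t. the weights
    b.g = out.g.sum(0)                                      # gradient w.r.t. the bias

def forward_and_backward(inp, targ):
    # forward pass
    l1 = lin(inp, w1, b1)
    l2 = relu(l1)
    out = lin(l2, w2, b2)
    diff = out[:, 0] - targ
    loss = diff.pow(2).mean()

    # backward pass -- each gradient gets stored in a .g attribute
    out.g = 2. * diff[:, None] / inp.shape[0]   # gradient of the MSE
    lin_grad(l2, out, w2, b2)
    l1.g = (l1 > 0.).float() * l2.g             # ReLU derivative times the chain rule
    lin_grad(inp, l1, w1, b1)
```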
So what I'm going to do now is I'm going to show you how I would experiment with this 00:49:17.040 |
And I would encourage you to do the same thing. 00:49:19.440 |
It's a little harder to do this one cell by cell because we kind of want to put it all 00:49:25.880 |
So we need a way to explore the calculations interactively. 00:49:31.640 |
And the way we do that is by using the Python debugger. 00:49:36.160 |
Here is how you-- let me see a few ways to do this. 00:49:45.920 |
So if you say pdb.set_trace() in your code, then that tells the debugger to stop execution there. 00:49:58.660 |
So if I call forward and backward, you can see here it's stopped. 00:50:03.520 |
And the interactive Python debugger, ipdb, has popped up. 00:50:07.640 |
With an arrow pointing at the line of code, it's about to run. 00:50:11.840 |
And at this point, there's a whole range of things we can do to find out what they are. 00:50:18.360 |
Understanding how to use the Python debugger is one of the most powerful things I think 00:50:27.400 |
So one of the most useful things you can do is to print something. 00:50:34.200 |
But in a debugger, you want to be able to do things quickly. 00:50:36.000 |
So instead of typing print, I'll just type p. 00:50:39.160 |
So for example, let's take a look at the shape of the input. 00:50:53.780 |
So I've got a 50,000 by 50 input to the last layer. 00:50:57.600 |
These are the hidden activations coming into the last layer for every one of our images. 00:51:16.760 |
You don't have to use the p at all if your variable name is not the same as any of these 00:51:46.140 |
So the output of this is-- let's see if it makes sense. 00:51:52.200 |
We put a new axis on the end; unsqueeze minus 1 is the same as indexing it with dot dot dot, comma, None. 00:52:10.720 |
And then for out.g's unsqueeze, we're putting the new axis in the first dimension. 00:52:14.440 |
So we're going to have 50,000 by 50 by 1 times 50,000 by 1 by 1. 00:52:21.280 |
And so we're only going to end up getting this broadcasting happening over these last 00:52:25.320 |
two dimensions, which is why we end up with 50,000 by 50 by 1. 00:52:30.600 |
And then with summing up-- this makes sense, right? 00:52:36.960 |
Each image is individually contributing to the derivative. 00:52:42.920 |
And so we want to add them all up to find their total impact, because remember the sum 00:52:46.680 |
of a bunch of-- the derivative of the sum of functions is the sum of the derivatives 00:52:54.680 |
Now this is one of these situations where if you see a times and a sum and an unsqueeze, 00:53:00.400 |
it's not a bad idea to think about Einstein summation notation. 00:53:09.280 |
So first of all, let's just see how we can do some more stuff in the debugger. 00:53:17.920 |
So press C for continue, and it keeps running until it comes back again to the same spot. 00:53:25.040 |
And the reason we've come to the same spot twice is because lin grad is called two times. 00:53:31.640 |
So we would expect that the second time, we're going to get a different bunch of inputs and 00:53:45.760 |
And so I can print out a tuple of the inputs and output gradient. 00:53:49.100 |
So now, yeah, so this is the first layer going into the second layer. 00:53:58.000 |
To find out what called this function, you just type w. 00:54:06.080 |
Oh, forward and backward was called-- see the arrow? 00:54:09.280 |
That called lin grad the second time, and now we're here in w.g equals. 00:54:15.600 |
If we want to find out what w.g ends up being equal to, I can press N to say, go to the 00:54:23.560 |
And so now we've moved from line five to line six. 00:54:26.480 |
So the instruction point is now looking at line six. 00:54:28.640 |
So I could now print out, for example, w.g.shape. 00:54:37.800 |
One person on the chat has pointed out that you can use breakpoint instead of this import 00:54:46.240 |
Unfortunately, the breakpoint keyword doesn't currently work in Jupyter or in IPython. 00:54:56.840 |
That's why I'm doing it the old fashioned way. 00:54:58.800 |
So this way, maybe they'll fix the bug at some point. 00:55:09.120 |
But I would definitely suggest looking up a Python pdb tutorial to become very familiar 00:55:14.480 |
with this incredibly powerful tool because it really is so very handy. 00:55:21.960 |
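A minimal sketch of that workflow (the function it's dropped into is just illustrative):

```python
import pdb

def lin_grad(inp, out, w, b):
    pdb.set_trace()         # execution stops here and drops into the interactive debugger
    inp.g = out.g @ w.t()
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)

# At the (i)pdb prompt you can then type, for example:
#   p inp.shape      print the shape of the input
#   c                continue running until the next breakpoint
#   n                step to the next line
#   w                show the stack ("where"), i.e. what called this function
```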
So if I just press continue again, it keeps running all the way to the end and it's now 00:55:29.880 |
So when it's finished, we would find that there will now be, for example, a w1.g because 00:55:40.480 |
this is the gradients that it just calculated. 00:55:47.400 |
And there would also be a xtrain.g and so forth. 00:55:54.440 |
OK, so let's see if we can simplify this a little bit. 00:55:58.880 |
So I would be inclined to take these out and give them their own variable names just to 00:56:06.080 |
Would have been better if I'd actually done this before the debugging, so it'd be a bit 00:56:14.240 |
So let's set I and O equal to input and output dot g dot unsqueeze. 00:56:20.960 |
OK, so we'll get rid of our breakpoint and double check that we've got our gradients 00:56:42.560 |
And I guess before we run it, we should probably set those to zero. 00:56:50.920 |
What I would do here to try things out is I'd put my breakpoint there and then I would 00:57:01.280 |
And so I realize here that what we're actually doing is we're basically doing exactly the same thing as an Einstein summation. 00:57:23.280 |
Because I've just got this is being replicated and then I'm summing over that dimension, 00:57:31.680 |
because that's the multiplication that I'm doing. 00:57:33.280 |
So I'm basically multiplying the first dimension of each and then summing over that dimension. 00:57:52.640 |
And I've got zeros because I did x_train dot zero, that was silly. 00:58:14.920 |
So we've multiplied this repeating index. 00:58:17.740 |
So we were just multiplying the first dimensions together and then summing over them. 00:58:23.840 |
Now that's not quite the same thing as a matrix multiplication, but we could turn it into 00:58:28.120 |
the same thing as a matrix multiplication just by swapping i and j so that they're the other way around. 00:58:44.640 |
So that would become a matrix multiplication if we just use the transpose. 00:58:49.080 |
And in numpy, the transpose is the capital T attribute. 00:58:53.920 |
So here is exactly the same thing using a matrix multiply and a transpose. 00:59:08.840 |
So that tells us, now that we've checked in our debugger, that we can actually replace all of that with a transpose and a matrix multiply. 00:59:39.800 |
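A small sketch of checking that equivalence on toy tensors (shapes made up for the example):

```python
import torch

inp   = torch.randn(1_000, 50, dtype=torch.float64)   # stand-in for the layer input
out_g = torch.randn(1_000, 1, dtype=torch.float64)    # stand-in for the output gradient

a = (inp.unsqueeze(-1) * out_g.unsqueeze(1)).sum(0)   # broadcast multiply, then sum over images
b = torch.einsum('ij,ik->jk', inp, out_g)             # Einstein summation version
c = inp.t() @ out_g                                   # transpose plus matrix multiply

print(torch.allclose(a, b), torch.allclose(a, c))     # True True
```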
OK, so hopefully that's convinced you that the debugger is a really handy thing for playing 00:59:47.120 |
around with numeric programming ideas or coding in general. 00:59:52.680 |
And so I think now is a good time to take a break. 01:00:07.800 |
So we've calculated our derivatives and we want to test them. 01:00:13.200 |
Luckily PyTorch already has derivatives implemented. 01:00:17.080 |
So I'm going to totally cheat and use PyTorch to calculate the same derivatives. 01:00:24.160 |
So don't worry about how this works yet, because we're actually going to be doing all this 01:00:28.520 |
For now, I'm just going to run it all through PyTorch and check that their derivatives are 01:00:42.480 |
And obviously, what we've got is clunkier than what we'd normally do in PyTorch. 01:00:46.160 |
There's some really cool refactoring that we can do. 01:00:51.240 |
So what we're going to do is we're going to create a whole class for each of our functions, 01:00:55.760 |
for the ReLU function and for the linear function. 01:01:00.920 |
So the way that we're going to do this is we're going to create a dunder call. 01:01:22.280 |
And we're just going to set that to print hello. 01:01:32.720 |
And then I call it as if it was a function, oops, missing the dunder bit here. 01:01:45.560 |
So in other words, you know, everything can be changed in Python. 01:01:54.320 |
And to do that, you simply define dunder call. 01:02:11.240 |
It just says it's just a little bit of syntax, sugary kind of stuff to say I want to be able 01:02:16.280 |
to treat it as if it's a function without any method at all. 01:02:26.080 |
But because it's got this special magic name, dunder call, you don't have to write the dot method name at all; you can just call it directly. 01:02:32.920 |
So here, if we create an instance of the relu class, we can treat it as a function. 01:02:39.280 |
And what it's going to do is it's going to take its input and do the relu on it. 01:02:44.400 |
But if you look back at the forward and backward, there's something very interesting about the 01:02:49.280 |
backward pass, which is that it has to know about, for example, this intermediate calculation 01:03:00.200 |
This intermediate calculation gets passed over here because of the chain rule, we're going 01:03:05.120 |
to need some of the intermediate calculations and not just the chain rule because of actually 01:03:12.920 |
So we need to actually store each of the layer intermediate calculations. 01:03:19.880 |
And so that's why relu doesn't just calculate and return the output, but it also stores the output, and the input too. 01:03:32.120 |
So that way, then, when we call backward, we know how to calculate that. 01:03:38.840 |
We set the inputs gradient because remember we stored the input, so we can do that. 01:03:45.200 |
And it's going to just be, oh, input greater than zero, dot float. 01:03:49.960 |
So that's the definition of the derivative of a relu and then chain rule. 01:04:01.500 |
So that's how we can calculate the forward pass and the backward pass for relu. 01:04:07.640 |
And we're not going to have to then store all this intermediate stuff separately; it's all saved inside the class for us. 01:04:12.680 |
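A sketch of that Relu class as described:

```python
class Relu():
    def __call__(self, inp):
        self.inp = inp                   # store the input for the backward pass
        self.out = inp.clamp_min(0.)     # store the output too
        return self.out

    def backward(self):
        # derivative of ReLU (1 where input > 0, else 0), times the chain rule
        self.inp.g = (self.inp > 0).float() * self.out.g
```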
So we can do the same thing for a linear layer. 01:04:14.800 |
Now linear layer needs some additional state, weights and biases, relu doesn't, so there's 01:04:23.360 |
So when we create a linear layer, we have to say, what are its weights? 01:04:29.300 |
And then when we call it on the forward pass, just like before, we store the input. 01:04:36.320 |
And just like before, we calculate the output and store it and then return it. 01:04:44.960 |
And then for the backward pass, it's the same thing. 01:04:50.040 |
So the input gradients we calculate just like before; oh, dot t with brackets, with a little t, is exactly the 01:04:57.080 |
same as big T as a property. 01:05:01.440 |
So that's the same thing, that's just the transpose. 01:05:10.440 |
Again with the chain rule and the bias, just like we did it before. 01:05:14.120 |
And they're all being stored in the appropriate places. 01:05:19.680 |
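And a sketch of that linear layer class, using the transpose trick from earlier for the weight gradient:

```python
class Lin():
    def __init__(self, w, b):
        self.w, self.b = w, b            # the extra state a linear layer needs

    def __call__(self, inp):
        self.inp = inp
        self.out = inp @ self.w + self.b
        return self.out

    def backward(self):
        self.inp.g = self.out.g @ self.w.t()
        self.w.g = self.inp.t() @ self.out.g   # transpose + matmul, as derived above
        self.b.g = self.out.g.sum(0)
```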
We don't just calculate the MSE, but we also store it. 01:05:23.260 |
And we also, now the MSE needs two things, an input and a target. 01:05:30.020 |
So then in the backward pass, we can calculate its gradient of the input as being two times 01:05:44.560 |
So our model now, it's much easier to define. 01:05:49.480 |
We can just create a bunch of layers, linear w1b1, relu, linear w2b2. 01:05:57.680 |
And then we can store an instance of the MSE. 01:06:01.080 |
This is not calling MSE, it's creating an instance of the MSE class. 01:06:06.680 |
This is an instance of the relu class, so they're just being stored. 01:06:09.760 |
So then when we call the model, we pass it our inputs and our target. 01:06:16.720 |
We go through each layer, set x equal to the result of calling that layer, and then pass 01:06:25.080 |
So there's something kind of interesting here that you might've noticed, 01:06:45.620 |
which is that we don't have two separate functions inside our model, 01:06:51.040 |
the loss function being applied to a separate neural net. 01:06:56.720 |
But we've actually integrated the loss function directly into the neural net, into the model. 01:07:02.320 |
See how the loss is being calculated inside the model? 01:07:06.520 |
Now that's neither better nor worse than having it separately. 01:07:10.440 |
And so generally, a lot of hugging face stuff does it this way. 01:07:13.120 |
They actually put the loss inside the forward. 01:07:17.100 |
The stuff in fast.ai and a lot of other libraries does it separately, where the loss is a separate function. 01:07:23.800 |
And the model only returns the result of putting it through the layers. 01:07:27.000 |
So for this model, we're going to actually do the loss function inside the model. 01:07:34.760 |
So self.loss.backwards-- so self.loss is the MSE object. 01:07:42.400 |
And it's stored when it was called here, it was storing, remember, the inputs, the targets, 01:07:48.960 |
the outputs, so it can calculate the backward. 01:07:52.920 |
And then we go through each layer in reverse, right? 01:07:55.960 |
This is backpropagation: going backwards, in reverse, calling backward on each one. 01:08:11.740 |
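A sketch of the Mse and Model classes being described (Lin and Relu as in the sketches above), with the loss computed inside the model and backward() walking the layers in reverse:

```python
class Mse():
    def __call__(self, inp, targ):
        self.inp, self.targ = inp, targ
        self.out = (inp[:, 0] - targ).pow(2).mean()
        return self.out

    def backward(self):
        self.inp.g = 2. * (self.inp[:, 0] - self.targ)[:, None] / self.targ.shape[0]

class Model():
    def __init__(self, w1, b1, w2, b2):
        self.layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]
        self.loss = Mse()

    def __call__(self, x, targ):
        for l in self.layers: x = l(x)   # forward through each layer
        return self.loss(x, targ)        # loss computed inside the model

    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers): l.backward()   # backprop: reverse order
```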
So now we can calculate the model, we can calculate the loss, we can call backward, 01:08:21.080 |
and then we can check that each of the gradients that we stored earlier are equal to each of the ones we calculated before. 01:08:35.080 |
OK, so William's asked a very good question, that is, if you do put the loss inside here, 01:08:44.520 |
how on earth do you actually get predictions? 01:08:48.920 |
So generally, what happens is, in practice, hugging face models do something like this. 01:08:55.520 |
I'll say self.preds equals x, and then they'll say self.finalloss equals that, and then return 01:09:16.720 |
And that way-- I guess you don't even need that last bit. 01:09:20.120 |
Well, that's really the-- anyway, that is what they do, so I'll leave it there. 01:09:23.920 |
And so that way, you can kind of check, like, model.preds, for example. 01:09:33.520 |
Or alternatively, you can return not just the loss, but both as a dictionary, stuff 01:09:39.140 |
So there's a few different ways you could do it. 01:09:40.680 |
Actually, now I think about it, I think that's what they do, is they actually return both 01:09:45.480 |
as a dictionary, so it would be like return dictionary loss equals that, comma, preds equals 01:10:08.200 |
that, something like that, I guess, is what they would do. 01:10:12.360 |
Anyway, there's a few different ways to do it. 01:10:15.520 |
OK, so hopefully you can see that this is really making it nice and easy for us to do 01:10:24.000 |
our forward pass and our backward pass without all of this manual fiddling around. 01:10:31.760 |
Every class now can be totally, separately considered and can be combined however we 01:10:41.440 |
So you could try creating a bigger neural net if you want to. 01:10:47.080 |
So basically, as a rule of thumb, when you see repeated code -- self.inp equals inp, self.inp 01:10:53.800 |
equals inp, self.out equals ..., return self.out, self.out equals ..., return self.out -- it's time to refactor. 01:11:01.740 |
And so what we can do is, a simple refactoring is to create a new class called module. 01:11:08.600 |
And module's going to do those things we just said. 01:11:13.800 |
And it's going to call something called self.forward in order to create our self.out, because remember, 01:11:19.800 |
that was one of the things we had again and again and again, self.out, self.out. 01:11:27.040 |
And so now, there's going to be a thing called forward, which actually, in this, it doesn't 01:11:33.700 |
do anything, because the whole purpose of this module is to be inherited. 01:11:37.800 |
When we call backward, it's going to call self.backward passing in self.out, because 01:11:43.480 |
notice, all of our backwards always wanted to get hold of self.out, self.out, self.out, 01:11:55.760 |
So let's pass that in, and pass in those arguments that we stored earlier. 01:12:01.280 |
And so star means take all of the arguments, regardless of whether it's 0, 1, 2, or more, and put them into a tuple called args. 01:12:09.800 |
And then that's what happens when it's inside the actual signature. 01:12:13.240 |
And then when you call a function using star, it says take this list and expand them into 01:12:18.160 |
separate arguments, calling backward with each one separately. 01:12:22.880 |
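A sketch of that Module base class:

```python
class Module():
    def __call__(self, *args):
        self.args = args                    # store whatever arguments were passed in
        self.out = self.forward(*args)      # subclasses implement forward()
        return self.out

    def forward(self):
        raise NotImplementedError           # the whole point of Module is to be inherited

    def backward(self):
        self.bwd(self.out, *self.args)      # hand back self.out plus the stored arguments
```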
So now, for relu, look how much simpler it is. 01:12:29.420 |
So the old relu had to do all this storing stuff manually. 01:12:34.080 |
And it had all the self.stuff as well, but now we can get rid of all of that and just 01:12:39.040 |
implement forward, because that's the thing that's being called, and that's the thing 01:12:47.740 |
And so now the forward of relu just does the one thing we want, which also makes the code 01:12:53.920 |
Ditto for backward: it just does the one thing we want. 01:12:58.880 |
Now, we still have to multiply it, but I still have to do the chain rule manually. 01:13:03.880 |
But the same thing for linear, same thing for MSE. 01:13:09.520 |
And one thing to point out here is that there's often opportunities to manually speed things 01:13:19.920 |
up when you create custom autograd functions in PyTorch. 01:13:24.440 |
And here's an example, look, this calculation is being done twice, which seems like a waste, 01:13:33.280 |
So at the cost of some memory, we could instead store that calculation as diff. 01:13:47.560 |
Right, and I guess we'd have to store it for use later, so it'll need to be self.diff. 01:13:57.920 |
And at the cost of that memory, we could now remove this redundant calculation because 01:14:08.280 |
we've done it once before already and stored it and just use it directly. 01:14:18.960 |
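A sketch of that memoisation, with Mse written as a subclass of the Module base class sketched above:

```python
class Mse(Module):
    def forward(self, inp, targ):
        self.diff = inp[:, 0] - targ                      # stored once, at the cost of some memory
        return self.diff.pow(2).mean()

    def bwd(self, out, inp, targ):
        inp.g = 2. * self.diff[:, None] / targ.shape[0]   # reuse the stored difference
```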
And this is something that you can often do in neural nets. 01:14:22.080 |
So there's this compromise between storing things, the memory use of that, and then the 01:14:32.880 |
computational speed up of not having to recalculate it. 01:14:38.360 |
And so now we can call it in the same way, create our model, passing in all of those 01:14:43.360 |
So you can see with our model, so the model hasn't changed at this point, the definition 01:14:50.680 |
was up here, we just pass in the layers, sorry, not the layers, the weights for the layers. 01:15:03.640 |
Create the loss, call backward, and look, it's the same, hooray. 01:15:11.800 |
Okay, so thankfully PyTorch has written all this for us. 01:15:19.320 |
And remember, according to the rules of our game, once we've reimplemented it, we're allowed to use PyTorch's version. 01:15:34.280 |
So if we want to create a linear layer, just like this one, rather than inheriting from 01:15:38.160 |
our module, we will inherit from PyTorch's nn.Module. 01:15:45.200 |
So we can create our random numbers. 01:15:49.280 |
So in this case, rather than passing in the already randomized weights, we're actually 01:15:52.720 |
going to generate the random weights ourselves and the zeroed biases. 01:15:57.400 |
And then here's our linear layer, which you could also use Lin for that, of course, to 01:16:10.240 |
Because PyTorch already knows the derivatives of all of the functions in PyTorch, and it knows how to apply the chain rule. 01:16:23.000 |
It'll actually do that entirely for us, which is very cool. 01:16:31.320 |
So let's create a model that is an nn.Module; otherwise it's exactly the same as before. 01:16:37.120 |
And now we're going to use PyTorch's MSE loss, because we've already implemented it ourselves. 01:16:42.200 |
It's very common to import torch.nn.functional as capital F. This is where lots of these handy functions live. 01:16:53.720 |
And so now you know why we need the colon, comma, None, because you saw the problem if we don't have matching shapes. 01:17:03.120 |
And remember, we stored our gradients in something called dot G. PyTorch stores them in something 01:17:09.520 |
called dot grad, but it's doing exactly the same thing. 01:17:22.520 |
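A sketch of the nn.Module version described here, with F.mse_loss computed inside forward (names and shapes are assumptions based on the description; keeping the layers in a plain Python list is a simplification):

```python
import torch
from torch import nn
import torch.nn.functional as F

class Linear(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        # generate the random weights and zeroed biases ourselves this time
        self.w = torch.randn(n_in, n_out).requires_grad_()
        self.b = torch.zeros(n_out).requires_grad_()
    def forward(self, inp): return inp @ self.w + self.b

class Model(nn.Module):
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        self.layers = [Linear(n_in, nh), nn.ReLU(), Linear(nh, n_out)]
    def forward(self, x, targ):
        for l in self.layers: x = l(x)
        return F.mse_loss(x, targ[:, None])   # after backward(), gradients land in .grad
```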
So we've created a matrix multiplication from scratch. 01:17:29.940 |
We've created a complete backprop system of modules. 01:17:34.480 |
We can now calculate both the forward pass and the backward pass for linear layers and 01:17:40.440 |
ReLUs, so we can create a multi-layer perceptron. 01:17:44.120 |
So we're now up to a point where we can train a model. 01:18:00.840 |
This cell's also the same as before, so we won't go through it. 01:18:03.920 |
Here's the same model that we had before, so we won't go through it. 01:18:10.840 |
OK, so the first thing we should do, I think, is to improve our loss function so it's not that silly MSE. 01:18:21.180 |
So if you watched part one, you might recall that there are some Excel notebooks. 01:18:28.360 |
One of those Excel notebooks is the cross-entropy example. 01:18:37.120 |
So just to remind you, what we're doing now is which we're saying, OK, rather than outputting 01:18:50.400 |
a single number for each image, we're going to instead output 10 numbers for each image. 01:19:03.400 |
And so that's going to be a one hot encoded set of-- it'll be like 1, 0, 0, 0, et cetera. 01:19:15.080 |
And so then that's going to be-- well, actually, the outputs won't be 1, 0, 0. 01:19:19.280 |
They'll be basically probabilities, won't they? 01:19:32.940 |
So if it's the digit 0, for example, it might be 1, 0, 0, 0, 0, dot, dot, dot for all the 01:19:43.560 |
And so to see how good is it-- so in this case, it's really good. 01:19:47.560 |
It had a 0.99 probability prediction that it's 0. 01:19:51.360 |
And indeed, it is, because this is the one hot encoded version. 01:19:55.840 |
And so the way we implement that is we don't even need to actually do the one hot encoding at all. 01:20:06.760 |
We can actually just directly store the integer, but we can treat it as if it's one hot encoded. 01:20:12.080 |
So we can just store the actual target 0 as an integer. 01:20:19.440 |
So the way we do that is we say, for example, for a single output, oh, it could be, let's 01:20:34.840 |
What we do for Softmax is we go e to the power of each of those outputs. 01:20:46.160 |
So here's the e to the power of each of those outputs. 01:21:03.040 |
And then for the loss function, we then compare those Softmaxes to the one hot encoded version. 01:21:12.320 |
Then it's going to have a 1 for dog and 0 everywhere else. 01:21:18.280 |
And then Softmax, this is from this nice blog post here. 01:21:27.520 |
This is the calculation sum of the ones and zeros. 01:21:32.840 |
So each of the ones and zeros multiplied by the log of the probabilities. 01:21:37.920 |
So here is the log probability times the actuals. 01:21:44.520 |
And since the actuals are either 0 or 1, and only one of them is going to be a 1, we're 01:21:51.120 |
And so if we add them up, it's all 0 except for one of them. 01:21:59.800 |
So in this special case where the output's one hot encoded, then doing the one hot encoded 01:22:07.480 |
multiplied by the log Softmax is actually identical to simply saying, oh, dog is in this particular position. 01:22:17.640 |
Let's just look it up directly and take its log Softmax. 01:22:31.600 |
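A sketch of that lookup, with integer targets indexing directly into the log-softmax outputs (function name is just illustrative):

```python
def nll(log_sm_preds, targets):
    # pick out, for each row, the log-probability of the correct class, then average
    return -log_sm_preds[range(targets.shape[0]), targets].mean()
```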
So if you haven't seen that before, then yeah, go and watch the part one video where we went 01:22:45.440 |
It's e to the power of each output divided by the sum of them, or we can use sigma notation to write it. 01:22:54.880 |
And as you can see, Jupyter Notebook lets us use LaTeX. 01:23:00.800 |
If you haven't used LaTeX before, it's actually surprisingly easy to learn. 01:23:05.320 |
You just put dollar signs around your equations like this, and backslash is 01:23:12.280 |
going to be kind of like your functions, if you like. 01:23:15.400 |
And curly braces -- curly curlies -- are used for arguments. 01:23:22.040 |
So you can see here, here is e to the power of and then underscore is used for subscript. 01:23:27.240 |
So this is X subscript I and power of is used for superscripts. 01:23:39.760 |
So it's actually, yeah, learning LaTeX is easier than you might expect. 01:23:43.720 |
It can be quite convenient for writing these functions when you want to. 01:23:50.080 |
As we'll see in a moment, well, actually, as you've already seen, in cross entropy, 01:23:55.480 |
we don't really want Softmax, we want log of Softmax. 01:23:59.960 |
So log of Softmax is, here it is: we've got x dot exp, so e to the x, divided by x dot exp dot sum. 01:24:12.520 |
And we're going to sum up over the last dimension. 01:24:16.420 |
And then we actually want to keep that dimension so that when we do the divided by, we have 01:24:22.560 |
a trailing unit axis, for exactly the same reason we saw when we did our MSE loss function. 01:24:28.420 |
So if you sum with keep dim equals true, it leaves a unit axis in that last position. 01:24:35.720 |
So we don't have to put it back to avoid that horrible out of product issue. 01:24:40.520 |
So this is the equivalent of this and then dot log. 01:24:50.200 |
So there is the log of the Softmax with the predictions. 01:24:55.160 |
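A sketch of that log-softmax:

```python
import torch

def log_softmax(x):
    # exponentiate, divide by the sum over the last dimension (keepdim keeps the
    # trailing unit axis so the broadcasting works), then take the log
    return (x.exp() / x.exp().sum(-1, keepdim=True)).log()

preds = torch.randn(5, 10)        # hypothetical model outputs: 5 images, 10 classes
sm = log_softmax(preds)           # each row's .exp() now sums to 1
```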
Now in terms of high school math that you may have forgotten, but you definitely are 01:25:01.160 |
going to want to know, a key piece that in that list of things is log and exponent rules. 01:25:12.680 |
So check out Khan Academy or similar if you've forgotten them. 01:25:17.700 |
But a quick reminder is, for example, the one we mentioned here, log of A over B equals 01:25:33.600 |
log of A minus log of B and equivalently log of A times B equals log of A plus log of B. 01:25:56.600 |
And these are very handy because for example, division can take a long time, multiply can 01:26:03.100 |
create really big numbers that have lots of floating point error. 01:26:07.600 |
Being able to replace these things with pluses and minuses is very handy indeed. 01:26:12.420 |
In fact, I used to give people an interview question 20 years ago, a company which I did 01:26:21.840 |
SQL actually only has a sum function for group by clauses. 01:26:28.960 |
And I used to ask people how you would deal with calculating a compound interest column 01:26:35.040 |
where the answer is basically that you have to say, because this compound interest is 01:26:41.320 |
So it has to be the sum of the log of the column and then e to the power of all that. 01:26:48.800 |
So it's like all kinds of little places that these things come in handy, but they come 01:27:02.200 |
So we're going to take advantage of that because we've got a divided by that's being logged. 01:27:10.280 |
And also rather handily, we're therefore going to have the log of x dot exp minus the log of the sum. 01:27:24.840 |
And the log of e to the x is just x, so that first part ends up just being x. 01:27:29.220 |
So log softmax is just x minus all of this, logged, and here it is: all of this, logged. 01:27:45.640 |
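So, as a sketch, the simplified version looks something like this (still without the numerical stability trick that comes next):

```python
def log_softmax_simplified(x):
    # log(e^x / sum(e^x)) = x - log(sum(e^x))
    return x - x.exp().sum(-1, keepdim=True).log()
```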
Now there's another very cool trick, which is one of these things I figured out myself 01:27:53.760 |
and then discovered other people had known it for years. 01:27:57.800 |
So not my trick, but it's always nice to rediscover things. 01:28:07.560 |
This piece here, the log of this sum, right, this sum here: we've got x.exp.sum. 01:28:16.480 |
Now x could contain some pretty big numbers, and e to the power of those is going to be really big. 01:28:22.160 |
And e to the power of things creating really big numbers is a problem, 01:28:26.800 |
because there's much less precision in your computer's floating point handling the further you get from zero. 01:28:35.120 |
So we don't want really big numbers, particularly because we're going to be taking derivatives. 01:28:39.800 |
And so if you're in an area that's not very precise as far as floating point math is concerned, 01:28:45.840 |
then the derivatives are going to be a disaster. 01:28:47.320 |
They might even be zero because you've got two numbers that the computer can't even recognize as being different. 01:28:53.720 |
So this is bad, but there's a nice trick we can do to make it a lot better. 01:28:59.360 |
What we can do is we can calculate the max of x, right, and we'll call that a. 01:29:07.000 |
And so then rather than doing the log of the sum of e to the xi, we're instead going to 01:29:24.440 |
define a as being the maximum of all of our x values. 01:29:36.440 |
Now if we then subtract that from every number, that means none of the numbers are going to 01:29:44.800 |
be big by definition because we've subtracted it from all of them. 01:29:49.360 |
Now the problem is that's given us a different result, right? 01:29:53.640 |
But if you think about it, let's expand this sum. 01:29:57.200 |
It's e to the power of x1, if we don't include our minus a, plus e to the power of x2, plus e to the power of x3, and so on. 01:30:10.600 |
Okay, now we've just subtracted a from each of our exponents, which means we're now getting a different answer. 01:30:27.360 |
The bad news is that you've got more high school math to remember, which is exponent 01:30:32.640 |
rules. So x to the a plus b equals x to the a times x to the b. 01:30:44.440 |
And similarly, x to the a minus b equals x to the a divided by x to the b. 01:30:54.800 |
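In symbols:

$$x^{a+b} = x^{a}\,x^{b} \qquad x^{a-b} = \frac{x^{a}}{x^{b}}$$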
And to convince yourself that's true, consider, for example, 2 to the power of 2 plus 3. 01:31:03.480 |
Well, 2 to the power of 2 is just 2 times 2, and 2 to the power of 3 is 2 times 2 times 2, 01:31:10.360 |
so multiplying them together gives 2 to the power of 5. 01:31:15.280 |
You've got two 2s here, and you've got another three 2s there. 01:31:19.580 |
So we're just adding up the exponents to get the total exponent. 01:31:23.880 |
So we can take advantage of this here and say, oh, well, this is equal to e to 01:31:28.600 |
the x1 over e to the a, plus e to the x2 over e to the a, plus e to the x3 over e to the a, and so on. 01:31:59.780 |
Because if we now multiply that all by e to the a, these would cancel out and we'd get the original sum back. 01:32:08.060 |
So that means we simply have to multiply this by that, and this gives us exactly the same thing. 01:32:19.160 |
But, critically, this is no longer ever going to involve a giant number. 01:32:24.200 |
So this might seem a bit weird; we're doing extra calculations. 01:32:27.080 |
It's not a simplification, it's a complexification, but it's one that's going to make the floating point arithmetic easier. 01:32:35.460 |
So that's our trick: rather than doing the log of this sum, what we actually do is the log 01:32:40.560 |
of e to the a times the sum of e to the x minus a. 01:32:46.640 |
And since we've got the log of a product, that's just the sum of the logs, and the log of e to the a is just a. 01:32:56.880 |
So this here is called the log sum exp trick. 01:33:09.960 |
Oops, people pointing out that I've made a mistake, thank you. 01:33:17.280 |
That, of course, should have been inside the log, you can't just go sticking it on the 01:33:24.280 |
outside like a crazy person, that's what I meant to say. 01:33:32.480 |
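So, written in one line with a as the maximum, the identity being used is:

$$\log\sum_{i} e^{x_{i}} \;=\; a + \log\sum_{i} e^{x_{i}-a}$$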
OK, so here is the log sum exp trick in code. I'll call it m instead of a, which is a bit silly, 01:33:32.480 |
but anyway: we find the maximum on the last dimension, and then here is the m plus the log of the sum of e to the x minus m. 01:33:40.000 |
OK, so that's just another way of doing that. 01:34:03.160 |
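A minimal sketch of that function, following the description above (the notebook's version may differ slightly in the details):

```python
def logsumexp(x):
    m = x.max(-1)[0]                                 # row-wise maximum
    return m + (x - m[:, None]).exp().sum(-1).log()  # m + log(sum(e^(x - m)))
```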
So now we can rewrite log softmax as x minus log sum exp. 01:34:10.080 |
And we're not going to use our version because pytorch already has one. 01:34:17.200 |
And if we check, here we go, here's our results. 01:34:26.520 |
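So the stable log softmax ends up as just this (here using PyTorch's built-in logsumexp, as mentioned):

```python
def log_softmax(x):
    return x - x.logsumexp(-1, keepdim=True)
```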
And so then, as we've discussed, the cross entropy loss is minus the sum of the outputs times the log softmax. 01:34:33.860 |
And as we discussed, our outputs are one hot encoded, or actually they're stored as just the integer class indices. 01:34:41.360 |
So what we can do, and I guess I should make that more clear there, 01:35:00.720 |
is simply rewrite that as the negative of the log softmax at the target index. 01:35:17.840 |
There's a lot of cool things you can do with array indexing in pytorch and numpy. 01:35:27.120 |
Here are the first three actual values in YTrain; they're 5, 0, and 4. 01:35:35.760 |
Now what we want to do is find, in our softmax predictions, 01:35:42.680 |
the fifth prediction in the zeroth row, the zeroth prediction in the first row, and the fourth prediction in the second row. 01:36:00.720 |
This is going to be what we add up for the first three rows of our loss function. 01:36:13.960 |
If we index using two lists, we can put here 0, 1, 2 for the first list. 01:36:21.480 |
And for the second list, we can put YTrain up to 3, which is 5, 0, 4. 01:36:25.640 |
And this is actually going to return the values at positions (0, 5), (1, 0), and (2, 4). 01:36:39.940 |
Which is, as you see, exactly the same thing. 01:36:44.840 |
So therefore, this is actually giving us what we need for the cross entropy loss. 01:36:55.500 |
So if we take range of our target's first dimension, or zero index dimension, which 01:37:03.680 |
is all this is, and the target, and then take the negative of that dot mean, that gives 01:37:11.080 |
us our cross entropy loss, which is pretty neat, in my opinion. 01:37:22.800 |
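As a sketch of that step (assuming sm_pred holds the log softmax of the predictions and y_train the integer targets):

```python
def nll(inp, targ):
    # pick out the log probability of the correct class for each row, negate, and average
    return -inp[range(targ.shape[0]), targ].mean()

loss = nll(sm_pred, y_train)
```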
So PyTorch calls this negative log likelihood loss, but that's all it is. 01:37:32.800 |
And so if we take the negative log likelihood, and we pass it the log softmax, we get the same result. 01:37:44.920 |
And this particular combination in PyTorch is called F dot cross entropy. 01:37:51.400 |
Yep, F dot cross entropy gives us exactly the same thing. 01:37:56.400 |
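A quick hedged check of that equivalence, reusing the assumed names from above (pred being the raw model outputs):

```python
import torch.nn.functional as F

assert torch.allclose(F.cross_entropy(pred, y_train),
                      nll(log_softmax(pred), y_train))
```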
So we have now re-implemented the cross entropy loss. 01:38:00.080 |
And there's a lot of confusing things going on there, a lot. 01:38:06.320 |
And so this is one of those places where you should pause the video and go back and look 01:38:11.000 |
at each step and think not just like, what is it doing, but why is it doing it? 01:38:16.600 |
And also try typing in lots of different values yourself to see if you can see what's going 01:38:21.940 |
on, and then put this aside and test yourself by re-implementing log soft max, and cross 01:38:32.300 |
entropy yourself, and compare them to PyTorch's values. 01:38:36.680 |
And so that's a piece of homework for you for this week. 01:38:43.800 |
So now that we've got that, we can actually create a training loop. 01:38:46.480 |
So let's set our loss function to be cross entropy, and grab a first mini batch with a batch size of 64. 01:38:57.520 |
It's going to be the rows from 0 up to 64 from our training set. 01:39:08.400 |
So for each of the 64 images in the mini batch, we have 10 probabilities, one for each digit. 01:39:14.920 |
And our y is just-- in fact, let's print those out. 01:39:31.800 |
So we're going to start with a bad loss because it's entirely random at this point. 01:39:38.080 |
OK, so for each of the predictions we made-- so those are our predictions. 01:39:49.640 |
And so remember, those predictions are a 64 by 10. 01:39:56.120 |
So for each one of these 64 rows, we have to go in and see where is the highest number. 01:40:05.480 |
So if we go through here, we can go through each one. 01:40:13.840 |
OK, it looks like this is the highest number. 01:40:18.520 |
So you've got to find the index of the highest number. 01:40:20.880 |
The function to find the index of the highest number is called argmax. 01:40:28.760 |
And I guess we could have also written this probably as preds.argmax. 01:40:35.840 |
I actually prefer normally to do it this way. 01:40:39.120 |
OK, and the reason we want this is because we want to be able to calculate accuracy. 01:40:47.360 |
But we just like to be able to see how we're going, because it's a metric, 01:40:51.720 |
something that we use for understanding how training is going. 01:41:03.040 |
So we check whether the argmax equals the target; if you turn those booleans into floats, they'll be ones and zeros. 01:41:05.280 |
And the mean of those floats is the accuracy. 01:41:07.840 |
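As a sketch, the accuracy metric just described:

```python
def accuracy(preds, yb):
    # index of the highest prediction per row, compared against the target labels
    return (preds.argmax(dim=-1) == yb).float().mean()
```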
So our current accuracy, not surprisingly, is around 10%. 01:41:23.680 |
Now let's train. For each epoch, we're going to go through from 0 up to n, 01:41:23.680 |
skipping by 64, the batch size, each time. 01:41:30.920 |
And so we're going to create a slice that starts at i, 01:41:34.600 |
so starting at 0, and goes up to i plus 64, unless we've gone past the end. 01:41:38.840 |
And so then we will slice into our training set for the x and the y mini batches. 01:41:48.240 |
We will then calculate our predictions and our loss function, and call backward. 01:41:55.840 |
So the way I did this originally was I had all of these in separate cells. 01:42:10.200 |
And I just typed in i equals 0 and then went through one cell at a time to check each step. 01:42:29.560 |
OK, so once we've done backward, we can then, with torch.no_grad, 01:42:45.320 |
update the weights to the existing weights minus the gradients times the learning rate, 01:42:53.840 |
and then zero out the gradients for the weights and biases. 01:43:23.160 |
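Putting those steps together, a minimal sketch of the kind of training loop being described; the model structure, layer attributes, and hyperparameter names here are assumptions for illustration, not necessarily what the notebook uses:

```python
bs, lr, epochs = 64, 0.5, 3  # batch size, learning rate, number of epochs (illustrative values)

for epoch in range(epochs):
    for i in range(0, n, bs):
        s = slice(i, min(n, i + bs))       # current mini batch, clipped at the end
        xb, yb = x_train[s], y_train[s]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()
        with torch.no_grad():
            for l in model.layers:          # assumes layers exposing weight/bias tensors
                if hasattr(l, 'weight'):
                    l.weight -= l.weight.grad * lr
                    l.bias   -= l.bias.grad * lr
                    l.weight.grad.zero_()
                    l.bias.grad.zero_()
```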
So you can see that our accuracy on the training set-- 01:43:28.040 |
measuring on the training set is a bit unfair, but it's only three epochs-- 01:43:36.600 |
has come up nicely: it trains pretty quickly and is not terrible at all. 01:43:42.680 |
All right, so what we're going to do next time 01:43:50.120 |
is we're going to refactor this training loop, making it 01:43:56.960 |
dramatically simpler step by step, until eventually we 01:44:04.200 |
will get it down to something much, much shorter. 01:44:12.440 |
And then we're going to add a validation set to it 01:44:19.240 |
And then, yeah, we'll be in a pretty good position, 01:44:21.960 |
I think, to start training some more interesting models. 01:44:47.040 |
A good exercise for this week, now that we've kind of got all these key basic pieces in place, 01:44:51.840 |
is to really try to recreate them yourself, without peeking, as much as you can manage. 01:45:07.480 |
recreate something that steps through layers, 01:45:10.120 |
and even see if you can recreate the idea of the dot forward and dot backward methods. 01:45:18.200 |
Make sure it's all in your head really clearly so that you fully understand what's going on. 01:45:26.280 |
At the very least, if you don't have time for that, rerun the notebook from scratch. 01:45:39.040 |
So if you go to Kernel, Restart and Clear Output, 01:45:43.160 |
it will delete all the outputs, and then before you run each cell, try to think about what the output should be. 01:45:55.880 |
Hope you have a great week, and I will see you next time.