Building makemore Part 4: Becoming a Backprop Ninja
Chapters
0:00 intro: why you should care & fun history
7:26 starter code
13:01 exercise 1: backpropping the atomic compute graph
65:17 brief digression: Bessel’s correction in batchnorm
86:31 exercise 2: cross entropy loss backward pass
96:37 exercise 3: batch norm layer backward pass
110:02 exercise 4: putting it all together
114:24 outro
00:00:00.000 |
Hi everyone. So today we are once again continuing our implementation of Makemore. 00:00:04.240 |
Now so far we've come up to here, multilayer perceptrons, and our neural net looked like this, 00:00:10.720 |
and we were implementing this over the last few lectures. Now I'm sure everyone is very excited 00:00:14.880 |
to go into recurrent neural networks and all of their variants and how they work, and the diagrams 00:00:19.360 |
look cool and it's very exciting and interesting and we're going to get a better result, but 00:00:22.880 |
unfortunately I think we have to remain here for one more lecture. And the reason for that is we've 00:00:28.880 |
already trained this multilayer perceptron, right, and we are getting pretty good loss, and I think 00:00:32.960 |
we have a pretty decent understanding of the architecture and how it works, but the line of 00:00:37.840 |
code here that I take an issue with is here: loss.backward(). That is, we are taking PyTorch 00:00:43.760 |
Autograd and using it to calculate all of our gradients along the way, and I would like to 00:00:48.400 |
remove the use of loss.backward(), and I would like us to write our backward pass manually on 00:00:52.800 |
the level of tensors. And I think that this is a very useful exercise for the following reasons. 00:00:58.720 |
I actually have an entire blog post on this topic, but I'd like to call backpropagation a 00:01:03.520 |
leaky abstraction. And what I mean by that is backpropagation doesn't just make your neural 00:01:09.120 |
networks just work magically. It's not the case that you can just stack up arbitrary Lego blocks 00:01:13.440 |
of differentiable functions and just cross your fingers and backpropagate and everything is great. 00:01:17.680 |
Things don't just work automatically. It is a leaky abstraction in the sense that 00:01:22.880 |
you can shoot yourself in the foot if you do not understand its internals. 00:01:26.640 |
It will magically not work or not work optimally, and you will need to understand how it works 00:01:32.240 |
under the hood if you're hoping to debug it and if you are hoping to address it in your neural net. 00:01:36.480 |
So this blog post here from a while ago goes into some of those examples. So for example, 00:01:42.560 |
we've already covered them, some of them already. For example, the flat tails of these functions 00:01:48.320 |
and how you do not want to saturate them too much because your gradients will die. 00:01:53.120 |
The case of dead neurons, which I've already covered as well. The case of exploding or 00:01:57.840 |
vanishing gradients in the case of recurrent neural networks, which we are about to cover. 00:02:01.680 |
And then also, you will often come across some examples in the wild. This is a snippet that I 00:02:08.480 |
found in a random code base on the internet where they actually have a very subtle but pretty major 00:02:14.400 |
bug in their implementation. And the bug points at the fact that the author of this code does not 00:02:20.000 |
actually understand backpropagation. So what they're trying to do here is they're trying to 00:02:23.360 |
clip the loss at a certain maximum value. But actually what they're trying to do is they're 00:02:28.000 |
trying to clip the gradients to have a maximum value instead of trying to clip the loss at a 00:02:32.000 |
maximum value. And indirectly, they're basically causing some of the outliers to be actually 00:02:38.400 |
ignored because when you clip a loss of an outlier, you are setting its gradient to zero. 00:02:44.080 |
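To make that failure mode concrete, here is a tiny illustrative sketch (not the code base being discussed): clamping the loss saturates it for outliers, so the gradient through the clamp is exactly zero and those examples silently stop contributing to learning.

```python
import torch

# pretend per-example losses; the second one is an outlier
x = torch.tensor([1.0, 50.0], requires_grad=True)
clipped = torch.clamp(x, max=10.0).sum()  # "clipping the loss" at a maximum value
clipped.backward()
print(x.grad)  # tensor([1., 0.]) -- the outlier's gradient was silently zeroed
```

Clipping the gradients themselves (for example with torch.nn.utils.clip_grad_norm_ after loss.backward()) bounds the update without discarding the outliers entirely, which is usually what is intended.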
And so have a look through this and read through it. But there's basically a bunch of subtle issues 00:02:49.840 |
that you're going to avoid if you actually know what you're doing. And that's why I don't think 00:02:53.600 |
it's the case that because PyTorch or other frameworks offer autograd, it is okay for us 00:02:58.240 |
to ignore how it works. Now, we've actually already covered autograd and we wrote micrograd. 00:03:04.640 |
But micrograd was an autograd engine only on the level of individual scalars. So the atoms were 00:03:10.160 |
single individual numbers. And I don't think it's enough. And I'd like us to basically think about 00:03:15.120 |
backpropagation on the level of tensors as well. And so in a summary, I think it's a good exercise. 00:03:20.240 |
I think it is very, very valuable. You're going to become better at debugging neural networks and 00:03:25.520 |
making sure that you understand what you're doing. It is going to make everything fully explicit. So 00:03:29.920 |
you're not going to be nervous about what is hidden away from you. And basically, in general, 00:03:34.000 |
we're going to emerge stronger. And so let's get into it. A bit of a fun historical note here is 00:03:39.760 |
that today, writing your backward pass by hand and manually is not recommended and no one does it 00:03:44.560 |
except for the purposes of exercise. But about 10 years ago in deep learning, this was fairly 00:03:49.360 |
standard and in fact pervasive. So at the time, everyone used to write their own backward pass 00:03:53.680 |
by hand manually, including myself. And it's just what you would do. So we used to write backward 00:03:58.800 |
pass by hand. And now everyone just calls loss.backward(). We've lost something. I wanted to give 00:04:04.400 |
you a few examples of this. So here's a 2006 paper from Geoff Hinton and Ruslan Salakhutdinov 00:04:12.560 |
in Science that was influential at the time. And this was training some architectures called 00:04:17.840 |
restricted Boltzmann machines. And basically, it's an autoencoder trained here. And this is from 00:04:24.160 |
roughly 2010. I had a library for training restricted Boltzmann machines. And this was 00:04:29.680 |
at the time written in MATLAB. So Python was not used for deep learning pervasively. It was all 00:04:34.080 |
MATLAB. And MATLAB was this scientific computing package that everyone would use. So we would 00:04:40.240 |
write MATLAB, which is barely a programming language as well. But it had a very convenient 00:04:45.680 |
tensor class. And it was this computing environment. And you would run here. It would all run on a CPU, 00:04:50.480 |
of course. But you would have very nice plots to go with it and a built-in debugger. And it 00:04:54.560 |
was pretty nice. Now, the code in this package in 2010 that I wrote for fitting restricted 00:05:00.800 |
Boltzmann machines, to a large extent, is recognizable. But I wanted to show you how you 00:05:05.440 |
would-- well, I'm creating the data and the xy batches. I'm initializing the neural net. So it's 00:05:11.520 |
got weights and biases, just like we're used to. And then this is the training loop, where we 00:05:15.920 |
actually do the forward pass. And then here, at this time, they didn't even necessarily use back 00:05:20.640 |
propagation to train neural networks. So this, in particular, implements contrastive divergence, 00:05:25.920 |
which estimates a gradient. And then here, we take that gradient and use it for a parameter update 00:05:32.000 |
along the lines that we're used to. Yeah, here. But you can see that, basically, people were 00:05:37.920 |
meddling with these gradients directly and inline and themselves. It wasn't that common to use an 00:05:42.720 |
autograd engine. Here's one more example from a paper of mine from 2014 called Deep Fragment 00:05:48.720 |
Embeddings. And here, what I was doing is I was aligning images and text. And so it's kind of 00:05:54.160 |
like CLIP, if you're familiar with it. But instead of working on the level of entire images 00:05:58.240 |
and entire sentences, it was working on the level of individual objects and little pieces of 00:06:02.240 |
sentences. And I was embedding them and then calculating very much a CLIP-like loss. 00:06:06.480 |
And I dug up the code from 2014 of how I implemented this. And it was already in NumPy 00:06:12.880 |
and Python. And here, I'm implementing the cost function. And it was standard to implement not 00:06:18.880 |
just the cost, but also the backward pass manually. So here, I'm calculating the image 00:06:23.680 |
embeddings, sentence embeddings, the loss function. I calculate the scores. This is the loss function. 00:06:29.920 |
And then once I have the loss function, I do the backward pass right here. So I backward 00:06:34.640 |
through the loss function and through the neural net. And I append regularization. 00:06:38.720 |
So everything was done by hand manually. And you would just write out the backward pass. 00:06:43.680 |
And then you would use a gradient checker to make sure that your numerical estimate of the gradient 00:06:47.840 |
agrees with the one you calculated during the backpropagation. So this was very standard for 00:06:52.080 |
a long time. But today, of course, it is standard to use an autograd engine. But it was definitely 00:06:57.680 |
useful. And I think people sort of understood how these neural networks work on a very intuitive 00:07:01.280 |
level. And so I think it's a good exercise again. And this is where we want to be. 00:07:05.280 |
OK, so just as a reminder from our previous lecture, this is the Jupyter notebook that 00:07:08.800 |
we implemented at the time. And we're going to keep everything the same. So we're still going 00:07:14.080 |
to have a two-layer multilayer perceptron with a batch normalization layer. So the forward pass 00:07:18.480 |
will be basically identical to this lecture. But here, we're going to get rid of loss.backward. 00:07:22.800 |
And instead, we're going to write the backward pass manually. Now, here's the starter code for 00:07:27.120 |
this lecture. We are becoming a backprop ninja in this notebook. And the first few cells here 00:07:33.600 |
are identical to what we are used to. So we are doing some imports, loading in the data set, 00:07:38.160 |
and processing the data set. None of this changed. Now, here I'm introducing a utility function that 00:07:43.680 |
we're going to use later to compare the gradients. So in particular, we are going to have the 00:07:47.360 |
gradients that we estimate manually ourselves. And we're going to have gradients that PyTorch 00:07:51.920 |
calculates. And we're going to be checking for correctness, assuming, of course, that PyTorch 00:07:56.240 |
is correct. Then here, we have the initialization that we are quite used to. So we have our 00:08:02.400 |
embedding table for the characters, the first layer, second layer, and a batch normalization 00:08:07.200 |
in between. And here's where we create all the parameters. Now, you will note that I changed 00:08:11.920 |
the initialization a little bit to be small numbers. So normally, you would set the biases 00:08:16.800 |
to be all zero. Here, I am setting them to be small random numbers. And I'm doing this because 00:08:21.760 |
if your variables are initialized to exactly zero, sometimes what can happen is that can mask 00:08:27.120 |
an incorrect implementation of a gradient. Because when everything is zero, it sort of 00:08:32.560 |
simplifies and gives you a much simpler expression of the gradient than you would otherwise get. 00:08:36.720 |
And so by making it small numbers, I'm trying to unmask those potential errors in these 00:08:41.520 |
calculations. You also notice that I'm using b1 in the first layer. I'm using a bias, despite 00:08:48.400 |
batch normalization right afterwards. So this would typically not be what you do, because we 00:08:53.280 |
talked about the fact that you don't need a bias. But I'm doing this here just for fun, because 00:08:58.000 |
we're going to have a gradient with respect to it. And we can check that we are still calculating it 00:09:01.520 |
correctly, even though this bias is spurious. So here, I'm calculating a single batch. And then 00:09:07.680 |
here, I am doing a forward pass. Now, you'll notice that the forward pass is significantly 00:09:12.160 |
expanded from what we are used to. Here, the forward pass was just here. Now, the reason that 00:09:18.560 |
the forward pass is longer is for two reasons. Number one, here, we just had an f dot cross 00:09:23.280 |
entropy. But here, I am bringing back an explicit implementation of the loss function. And number 00:09:28.720 |
two, I've broken up the implementation into manageable chunks. So we have a lot more 00:09:35.280 |
intermediate tensors along the way in the forward pass. And that's because we are about to go 00:09:39.440 |
backwards and calculate the gradients in this back propagation from the bottom to the top. 00:09:44.880 |
So we're going to go upwards. And just like we have, for example, the logprobs tensor 00:09:50.080 |
in the forward pass, in the backward pass we're going to have a dlogprobs, which is going to store the 00:09:54.560 |
derivative of the loss with respect to the logprobs tensor. And so we're going to be prepending 00:09:59.040 |
d to every one of these tensors and calculating it along the way of this back propagation. 00:10:04.960 |
So as an example, we have a bnraw here. We're going to be calculating a dbnraw. 00:10:09.360 |
So here, I'm telling PyTorch that we want to retain the grad of all these intermediate values, 00:10:15.360 |
because here in exercise one, we're going to calculate the backward pass. So we're going 00:10:19.920 |
to calculate all these d variables and use the CMP function I've introduced above to check our 00:10:25.920 |
correctness with respect to what PyTorch is telling us. This is going to be exercise one, 00:10:31.040 |
where we sort of back propagate through this entire graph. Now, just to give you a very quick 00:10:36.080 |
preview of what's going to happen in exercise two and below, here we have fully broken up the loss 00:10:42.240 |
and back propagated through it manually in all the little atomic pieces that make it up. 00:10:47.040 |
But here we're going to collapse the loss into a single cross entropy call. 00:10:50.720 |
And instead, we're going to analytically derive using math and paper and pencil, the gradient of 00:10:58.480 |
the loss with respect to the logits. And instead of back propagating through all of its little 00:11:02.320 |
chunks one at a time, we're just going to analytically derive what that gradient is, 00:11:06.160 |
and we're going to implement that, which is much more efficient, as we'll see in a bit. 00:11:09.840 |
Then we're going to do the exact same thing for batch normalization. So instead of breaking up 00:11:14.880 |
batch normalization into all the little tiny components, we're going to use pen and paper and 00:11:19.760 |
mathematics and calculus to derive the gradient through the batch norm layer. So we're going to 00:11:25.920 |
calculate the backward pass through the batch norm layer in a much more efficient expression, 00:11:29.680 |
instead of backward propagating through all of its little pieces independently. 00:11:32.720 |
So that's going to be exercise three. And then in exercise four, we're going to put it all together. 00:11:38.800 |
And this is the full code of training this two-layer MLP. And we're going to basically 00:11:43.680 |
insert our manual backprop, and we're going to take out loss.backward(). And you will basically 00:11:48.720 |
see that you can get all the same results using fully your own code. And the only thing we're 00:11:55.440 |
using from PyTorch is the torch.tensor to make the calculations efficient. But otherwise, you 00:12:01.360 |
will understand fully what it means to forward and backward the neural net and train it. And I think 00:12:05.680 |
that'll be awesome. So let's get to it. Okay, so I ran all the cells of this notebook all the way up 00:12:11.120 |
to here. And I'm going to erase this. And I'm going to start implementing the backward pass, starting 00:12:16.240 |
with dlogprobs. So we want to understand what should go here to calculate the gradient of the 00:12:21.360 |
loss with respect to all the elements of the logprobs tensor. Now, I'm going to give away the 00:12:26.560 |
answer here. But I wanted to put a quick note here that I think will be most pedagogically useful 00:12:31.120 |
for you is to actually go into the description of this video and find the link to this Jupyter 00:12:36.400 |
notebook. You can find it both on GitHub, but you can also find Google Colab with it. So you don't 00:12:40.480 |
have to install anything, you'll just go to a website on Google Colab. And you can try to 00:12:44.640 |
implement these derivatives or gradients yourself. And then if you are not able to, come back to my video 00:12:50.800 |
and see me do it. And so work in tandem: try it first yourself, and then see me give away the 00:12:56.640 |
answer. And I think that'll be most valuable to you. And that's how I recommend you go through 00:13:00.080 |
this lecture. So we are starting here with dlogprobs. Now, dlogprobs will hold the derivative 00:13:07.120 |
of the loss with respect to all the elements of logprobs. What is inside logprobs? The shape of 00:13:13.360 |
this is 32 by 27. So it's not going to surprise you that dlogprobs should also be an array of 00:13:20.400 |
size 32 by 27, because we want the derivative of the loss with respect to all of its elements. 00:13:25.200 |
So the sizes of those are always going to be equal. Now, how does logprobs influence the loss? 00:13:33.040 |
Loss is negative logprobs indexed with range of n and Yb, and then the mean of that. Now, 00:13:41.920 |
just as a reminder, yb is just basically an array of all the correct indices. 00:13:49.760 |
So what we're doing here is we're taking the logprobs array of size 32 by 27. 00:13:55.120 |
And then we are going in every single row. And in each row, we are plucking out the index 8 and 00:14:05.680 |
then 14 and 15 and so on. So we're going down the rows. That's the iterator range of n. And 00:14:10.720 |
then we are always plucking out the index at the column specified by this tensor yb. So in the 00:14:16.560 |
zeroth row, we are taking the eighth column. In the first row, we're taking the 14th column, etc. 00:14:22.400 |
And so logprobs indexed like this plucks out all those log probabilities of the correct next character in a 00:14:31.200 |
sequence. So that's what that does. And the shape of this, or the size of it, is of course 32, 00:14:37.040 |
because our batch size is 32. So these elements get plucked out, and then their mean and the 00:14:44.720 |
negative of that becomes loss. So I always like to work with simpler examples to understand the 00:14:51.680 |
numerical form of the derivative. What's going on here is once we've plucked out these examples, 00:14:57.040 |
we're taking the mean and then the negative. So the loss basically, 00:15:03.120 |
I can write it this way, is the negative of the mean of, say, a plus b plus c. And the mean of those three 00:15:08.960 |
numbers is just dividing by three. That would be how we achieve the mean of three 00:15:13.360 |
numbers a, b, c, although we actually have 32 numbers here. And so what is basically the loss 00:15:19.440 |
by say like da, right? Well, if we simplify this expression mathematically, this is negative 1 over 00:15:26.080 |
3 of a plus negative 1 over 3 of b plus negative 1 over 3 of c. And so what is the loss by da? 00:15:34.720 |
It's just negative 1 over 3. And so you can see that if we don't just have a, b, and c, 00:15:39.280 |
but we have 32 numbers, then d loss by d, you know, every one of those numbers is going to be 00:15:45.600 |
1 over n more generally, because n is the size of the batch, 32 in this case. So d loss by 00:15:55.040 |
d logprobs is negative 1 over n in all these places. Now, what about the other elements 00:16:03.680 |
inside logprobs? Because logprobs is a large array. You see that logprobs.shape is 32 by 27, 00:16:09.360 |
but only 32 of them participate in the loss calculation. So what's the derivative of all 00:16:15.600 |
the other, most of the elements that do not get plucked out here? Well, their loss intuitively 00:16:21.440 |
is zero. Sorry, their gradient intuitively is zero. And that's because they do not participate 00:16:26.160 |
in the loss. So most of these numbers inside this tensor do not feed into the loss. And so if we 00:16:32.320 |
were to change these numbers, then the loss doesn't change, which is the equivalent of us saying that 00:16:38.320 |
the derivative of the loss with respect to them is zero. They don't impact it. 00:16:41.600 |
So here's a way to implement this derivative then. We start out with torch.zeros of shape 32 by 27, 00:16:50.160 |
or let's just say, instead of doing this, because we don't want to hard code numbers, 00:16:54.000 |
let's do torch.zeros_like(logprobs). So basically, this is going to create an array of zeros exactly 00:17:00.320 |
in the shape of logprobs. And then we need to set the derivative of negative 1 over n 00:17:06.160 |
inside exactly these locations. So here's what we can do. dlogprobs indexed in the identical way 00:17:14.240 |
will just be set to negative 1.0 divided by n. Right, just like we derived here. 00:17:21.280 |
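Putting those two pieces together, a sketch of the line being described (assuming the notebook's names: n is the batch size, Yb holds the target indices):

```python
# -1/n at the plucked-out (row, Yb[row]) positions, 0 everywhere else
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
```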
So now let me erase all of this reasoning. And then this is the candidate derivative 00:17:28.240 |
for dlogprobs. Let's uncomment the first line and check that this is correct. 00:17:32.240 |
Okay, so CMP ran. And let's go back to CMP. And you see that what it's doing is it's calculating if 00:17:42.880 |
the calculated value by us, which is dt, is exactly equal to t.grad as calculated by PyTorch. 00:17:48.800 |
And then this is making sure that all of the elements are exactly equal, and then converting 00:17:54.400 |
this to a single Boolean value, because we don't want a Boolean tensor, we just want a Boolean 00:17:58.720 |
value. And then here, we are making sure that, okay, if they're not exactly equal, maybe they 00:18:04.320 |
are approximately equal because of some floating point issues, but they're very, very close. 00:18:08.960 |
So here we are using torch.allclose, which has a little bit of wiggle room available, because 00:18:14.160 |
sometimes you can get very, very close. But if you use a slightly different calculation, because 00:18:19.040 |
of floating point arithmetic, you can get a slightly different result. So this is checking 00:18:24.480 |
if you get an approximately close result. And then here, we are checking the maximum, 00:18:28.880 |
basically the value that has the highest difference, and what is the difference, 00:18:34.480 |
and the absolute value difference between those two. And so we are printing whether we have an 00:18:38.800 |
exact equality, an approximate equality, and what is the largest difference. And so here, 00:18:45.600 |
we see that we actually have exact equality. And so therefore, of course, we also have an 00:18:50.480 |
approximate equality, and the maximum difference is exactly zero. So basically, our dlogprobs 00:18:56.880 |
is exactly equal to what PyTorch calculated to be logprobs.grad in its backpropagation. 00:19:02.960 |
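For reference, a minimal sketch of what a comparison helper like this might look like (the exact formatting in the notebook may differ):

```python
def cmp(s, dt, t):
    # dt is our manually computed gradient; t.grad is what PyTorch computed for tensor t
    ex = torch.all(dt == t.grad).item()          # exact element-wise equality
    app = torch.allclose(dt, t.grad)             # approximate equality (floating point wiggle)
    maxdiff = (dt - t.grad).abs().max().item()   # largest absolute difference
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')
```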
So, so far, we're doing pretty well. Okay, so let's now continue our backpropagation. 00:19:08.640 |
We have that logprobs depends on probs through a log. So all the elements of probs are having 00:19:14.080 |
log applied to them element-wise. Now, if we want dprobs, then remember your micrograd training. 00:19:21.280 |
We have like a log node, it takes in probs and creates logprobs. And dprobs will be the local 00:19:28.880 |
derivative of that individual operation, log, times the derivative of the loss with respect 00:19:34.240 |
to its output, which in this case is dlogprobs. So what is the local derivative of this operation? 00:19:40.160 |
Well, we are taking log element-wise, and we can come here and we can see, well, WolframAlpha 00:19:44.400 |
is your friend, that d by dx of log of x is just simply one over x. So therefore, in this case, 00:19:50.640 |
x is probs. So we have d by dx is one over x, which is one over probs, and then this is the 00:19:56.720 |
local derivative, and then times we want to chain it. So this is the chain rule, times dlogprobs. 00:20:03.440 |
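As a sketch:

```python
# chain rule through the element-wise log: d/dx log(x) = 1/x
dprobs = (1.0 / probs) * dlogprobs
```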
Then let me uncomment this and let me run the cell in place. And we see that the derivative 00:20:08.880 |
of probs as we calculated here is exactly correct. And so notice here how this works. 00:20:14.160 |
Probs is going to be inverted and then element-wise multiplied here. 00:20:19.840 |
So if your probs is very, very close to one, that means you are, your network is currently predicting 00:20:25.520 |
the character correctly, then this will become one over one, and dlogprobs just gets passed through. 00:20:31.520 |
But if your probabilities are incorrectly assigned, so if the correct character here 00:20:36.320 |
is getting a very low probability, then 1.0 dividing by it will boost this, 00:20:42.240 |
and then multiply by dlogprobs. So basically what this line is doing intuitively is it's taking 00:20:48.320 |
the examples that have a very low probability currently assigned, and it's boosting their 00:20:52.720 |
gradient. You can look at it that way. Next up is counts_sum_inv. So we want 00:21:00.640 |
the derivative of this. Now let me just pause here and kind of introduce what's happening here 00:21:05.920 |
in general, because I know it's a little bit confusing. We have the logits that come out of 00:21:09.440 |
the neural net. Here what I'm doing is I'm finding the maximum in each row, and I'm subtracting it 00:21:15.440 |
for the purpose of numerical stability. And we talked about how if you do not do this, 00:21:19.840 |
you run into numerical issues if some of the logits take on too large values, 00:21:23.680 |
because we end up exponentiating them. So this is done just for safety, numerically. 00:21:29.600 |
Then here's the exponentiation of all the logits to create our counts. And then we want to take 00:21:36.640 |
the sum of these counts and normalize so that all of the probs sum to 1. Now here, instead of using 00:21:42.640 |
1 over counts_sum, I use raising to the power of negative 1. Mathematically, they are identical. 00:21:48.080 |
I just found that there's something wrong with the PyTorch implementation of the backward pass 00:21:51.520 |
of division, and it gives a weird result. But that doesn't happen for **-1, so I'm using this 00:21:59.120 |
formula instead. But basically, all that's happening here is we got the logits, we want 00:22:04.080 |
to exponentiate all of them, and we want to normalize the counts to create our probabilities. 00:22:08.880 |
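For orientation, the expanded forward-pass chunk being described looks roughly like this (variable names as I understand them from the notebook):

```python
# a numerically stable softmax, broken up into named intermediate tensors
logit_maxes = logits.max(1, keepdim=True).values  # per-row max, for numerical stability
norm_logits = logits - logit_maxes                # broadcast subtract
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1                   # same as 1.0 / counts_sum
probs = counts * counts_sum_inv
```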
It's just that it's happening across multiple lines. So now, here, we want to first take the 00:22:19.360 |
derivative, we want to backpropagate into counts_sum_inv and then into counts as well. 00:22:24.320 |
So what should dcounts_sum_inv be? Now, we actually have to be careful here, because 00:22:29.840 |
we have to scrutinize and be careful with the shapes. So counts.shape and counts_sum_inv.shape 00:22:37.040 |
are different. So in particular, counts is 32 by 27, but this counts_sum_inv is 32 by 1. 00:22:45.360 |
And so in this multiplication here, we also have an implicit broadcasting that PyTorch will do, 00:22:52.240 |
because it needs to take this column tensor of 32 numbers and replicate it horizontally 27 times 00:22:57.280 |
to align these two tensors so it can do an element-wise multiply. 00:23:00.720 |
So really what this looks like is the following, using a toy example again. 00:23:05.040 |
What we really have here is just probs is counts times counts_sum_inv, so it's C equals A times B, 00:23:10.960 |
but A is 3 by 3 and B is just 3 by 1, a column tensor. And so PyTorch internally replicated 00:23:18.320 |
these elements of B, and it did that across all the columns. So for example, B1, which is the 00:23:24.160 |
first element of B, would be replicated here across all the columns in this multiplication. 00:23:28.240 |
And now we're trying to backpropagate through this operation to counts_sum_inv. 00:23:33.040 |
So when we are calculating this derivative, it's important to realize that this looks like a single 00:23:40.880 |
operation, but actually is two operations applied sequentially. The first operation that PyTorch 00:23:46.480 |
did is it took this column tensor and replicated it across all the columns, basically 27 times. 00:23:54.560 |
So that's the first operation, it's a replication. And then the second operation is the 00:23:58.160 |
multiplication. So let's first backpropagate through the multiplication. If these two arrays 00:24:04.480 |
were of the same size and we just have A and B, both of them 3 by 3, then how do we backpropagate 00:24:11.600 |
through a multiplication? So if we just have scalars and not tensors, then if you have C 00:24:16.080 |
equals A times B, then what is the derivative of C with respect to B? Well, it's just A. 00:24:22.000 |
So that's the local derivative. So here in our case, undoing the multiplication and 00:24:28.240 |
backpropagating through just the multiplication itself, which is element-wise, is going to be 00:24:32.640 |
the local derivative, which in this case is simply counts, because counts is the A. 00:24:40.160 |
So it's the local derivative, and then times, because of the chain rule, dprobs. 00:24:44.240 |
So this here is the derivative, or the gradient, but with respect to replicated B. 00:24:50.800 |
But we don't have a replicated B, we just have a single B column. So how do we now backpropagate 00:24:57.120 |
through the replication? And intuitively, this B1 is the same variable, and it's just reused 00:25:03.440 |
multiple times. And so you can look at it as being equivalent to a case we've encountered 00:25:09.440 |
in micrograd. And so here, I'm just pulling out a random graph we used in micrograd. 00:25:14.320 |
We had an example where a single node has its output feeding into two branches of basically 00:25:21.440 |
the graph until the loss function. And we're talking about how the correct thing to do in 00:25:25.520 |
the backward pass is we need to sum all the gradients that arrive at any one node. So 00:25:30.960 |
across these different branches, the gradients would sum. So if a node is used multiple times, 00:25:36.880 |
the gradients for all of its uses sum during backpropagation. So here, B1 is used multiple 00:25:43.120 |
times in all these columns, and therefore the right thing to do here is to sum horizontally 00:25:49.280 |
within each row, across the columns. So we want to sum in dimension 1, but we want to retain this dimension 00:25:57.200 |
so that counts_sum_inv and its gradient are going to be exactly the same shape. 00:26:01.600 |
So we want to make sure that we pass keepdim as True so we don't lose this dimension. 00:26:05.840 |
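A sketch of that line (same variable-name caveat):

```python
# local derivative of the element-wise multiply is counts;
# then sum over the broadcast (column) dimension, keeping the 32x1 column shape
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
```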
And this will make dcounts_sum_inv be exactly of shape 32 by 1. So revealing this comparison as 00:26:14.080 |
well and running this, we see that we get an exact match. So this derivative is exactly correct. 00:26:22.080 |
And let me erase this. Now let's also backpropagate into counts, which is the other 00:26:28.720 |
variable here to create probs. So from probs to counts_sum_inv, we just did that. 00:26:33.200 |
Let's go into counts as well. So dcounts will be... 00:26:37.120 |
counts is our A, so dC by dA is just B. So therefore it's counts_sum_inv. 00:26:47.520 |
And then times, chain rule, dprobs. Now counts_sum_inv is 32 by 1, dprobs is 32 by 27. 00:26:56.720 |
So those will broadcast fine and will give us dcounts. There's no additional summation 00:27:04.960 |
required here. There will be a broadcasting that happens in this multiply here because 00:27:11.040 |
counts_sum_inv needs to be replicated again to correctly multiply dprobs. But that's going 00:27:16.560 |
to give the correct result as far as this single operation is concerned. 00:27:20.960 |
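As a sketch, this first contribution is:

```python
# first branch into counts: local derivative of the multiply is counts_sum_inv,
# which broadcasts against the 32x27 dprobs; not yet the full gradient for counts
dcounts = counts_sum_inv * dprobs
```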
So we've backpropagated from probs to counts, but we can't actually check the derivative of counts. 00:27:27.920 |
I have it much later on. And the reason for that is because counts_sum_inv depends on counts. 00:27:34.560 |
And so there's a second branch here that we have to finish because counts_sum_inv backpropagates into 00:27:39.040 |
counts_sum and counts_sum will backpropagate into counts. And so counts is a node that is being 00:27:44.320 |
used twice. It's used right here in probs and it goes through this other branch through 00:27:48.800 |
counts_sum_inv. So even though we've calculated the first contribution of it, we still have 00:27:53.840 |
to calculate the second contribution of it later. Okay, so we're continuing with this branch. 00:27:59.120 |
We have the derivative for counts_sum_inv. Now we want the derivative of counts_sum. 00:28:02.640 |
So dcounts_sum equals, what is the local derivative of this operation? So this is basically an 00:28:08.560 |
element-wise one over counts_sum. So counts_sum raised to the power of negative one is the same 00:28:14.480 |
as one over counts_sum. If we go to WolframAlpha, we see that x to the negative one, d by dx of it, 00:28:21.360 |
is basically negative x to the negative two. Negative one over x squared is the same as 00:28:26.720 |
negative x to the negative two. So dcounts_sum here will be: the local derivative is going to be negative 00:28:36.240 |
counts_sum to the negative two, that's the local derivative, times chain rule, which is dcounts_sum_inv. 00:28:44.320 |
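In code, roughly:

```python
# d/dx (x**-1) = -x**-2, chained with the gradient from above
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
```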
So that's dcounts_sum. Let's uncomment this and check that I am correct. Okay, so we have perfect 00:28:53.200 |
equality. And there's no sketchiness going on here with any shapes because these are of the 00:28:59.440 |
same shape. Okay, next up we want to back propagate through this line. We have that 00:29:03.920 |
counts_sum is counts.sum along the rows. So I wrote out some help here. We have to keep in 00:29:11.440 |
mind that counts, of course, is 32 by 27, and counts_sum is 32 by 1. So in this back propagation, 00:29:17.680 |
we need to take this column of derivatives and transform it into an array of derivatives, 00:29:24.640 |
two-dimensional array. So what is this operation doing? We're taking some kind of an input, 00:29:30.400 |
like say a three-by-three matrix A, and we are summing up the rows into a column tensor B, 00:29:36.240 |
B1, B2, B3, that is basically this. So now we have the derivatives of the loss with respect to 00:29:42.720 |
B, all the elements of B. And now we want to derive the loss with respect to all these little 00:29:48.400 |
As. So how do the Bs depend on the As is basically what we're after. What is the local derivative of 00:29:54.800 |
this operation? Well, we can see here that B1 only depends on these elements here. The derivative of 00:30:01.840 |
B1 with respect to all of these elements down here is zero. But for these elements here, like 00:30:07.040 |
A11, A12, etc., the local derivative is one, right? So DB1 by DA11, for example, is one. So it's one, 00:30:16.000 |
one, and one. So when we have the derivative of loss with respect to B1, the local derivative of 00:30:22.720 |
B1 with respect to these inputs is zeroes here, but it's one on these guys. So in the chain rule, 00:30:28.880 |
we have the local derivative times the derivative of B1. And so because the local derivative is one 00:30:37.360 |
on these three elements, the local derivative multiplying the derivative of B1 will just be 00:30:42.400 |
the derivative of B1. And so you can look at it as a router. Basically, an addition is a router 00:30:49.440 |
of gradient. Whatever gradient comes from above, it just gets routed equally to all the elements 00:30:53.840 |
that participate in that addition. So in this case, the derivative of B1 will just flow equally to 00:31:00.080 |
the derivative of A11, A12, and A13. So if we have a derivative of all the elements of B in this 00:31:06.800 |
column tensor, which is dcounts_sum that we've calculated just now, we basically see that what 00:31:13.440 |
that amounts to is all of these are now flowing to all these elements of A, and they're doing that 00:31:19.600 |
horizontally. So basically what we want is we want to take the dcounts_sum of size 32 by 1, 00:31:25.600 |
and we just want to replicate it 27 times horizontally to create a 32 by 27 array. So 00:31:32.320 |
there's many ways to implement this operation. You could, of course, just replicate the tensor, 00:31:36.400 |
but I think maybe one clean one is that dcounts is simply torch.ones_like, so just a 00:31:44.560 |
two-dimensional array of ones in the shape of counts, so 32 by 27, times dcounts_sum. 00:31:52.080 |
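As a sketch (the += accumulates into the first-branch contribution computed earlier, which is explained next):

```python
# broadcasting routes dcounts_sum equally back to every column of each row;
# += because dcounts already holds the contribution from the probs branch
dcounts += torch.ones_like(counts) * dcounts_sum
```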
So this way we're letting the broadcasting here basically implement the replication. You can look 00:31:57.840 |
at it that way. But then we have to also be careful because dcounts was already calculated. 00:32:05.040 |
We calculated earlier here, and that was just the first branch, and we're now finishing the second 00:32:10.000 |
branch. So we need to make sure that these gradients add, so plus equals. And then here, 00:32:15.280 |
let's comment out the comparison, and let's make sure, crossing fingers, that we have the 00:32:23.360 |
correct result. So PyTorch agrees with us on this gradient as well. Okay, hopefully we're getting 00:32:28.880 |
a hang of this now. Counts is an element-wise exp of norm_logits. So now we want dnorm_logits. 00:32:35.440 |
And because it's an element-wise operation, everything is very simple. What is the local 00:32:40.080 |
derivative of e to the x? It's famously just e to the x. So this is the local derivative. 00:32:45.760 |
That is the local derivative. Now we already calculated it, and it's inside counts, 00:32:52.240 |
so we might as well potentially just reuse counts. That is the local derivative, times dcounts. 00:33:01.840 |
Funny as that looks, counts times dcounts is the derivative of norm_logits. 00:33:05.520 |
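As a sketch:

```python
# d/dx exp(x) = exp(x), which we already have stored in counts
dnorm_logits = counts * dcounts
```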
And now let's erase this, and let's verify, and it looks good. 00:33:10.160 |
So that's dnorm_logits. Okay, so we are here on this line now. We have dnorm_logits, 00:33:19.600 |
and we're trying to calculate dlogits and dlogit_maxes, so backpropagating through this line. 00:33:25.280 |
Now we have to be careful here because the shapes, again, are not the same, and so there's an 00:33:29.440 |
implicit broadcasting happening here. So norm_logits has the shape of 32 by 27. Logits does as well, 00:33:36.240 |
but logit_maxes is only 32 by 1. So there's a broadcasting here in the minus. Now here I tried 00:33:43.920 |
to sort of write out a toy example again. We basically have that this is our C equals A minus 00:33:49.680 |
B, and we see that because of the shape, these are 3 by 3, but this one is just a column. 00:33:55.120 |
And so for example, every element of C, we have to look at how it came to be. And every element of 00:34:00.880 |
C is just the corresponding element of A minus basically that associated B. So it's very clear 00:34:09.520 |
now that the derivatives of every one of these Cs with respect to their inputs are 1 for the 00:34:16.560 |
corresponding A, and it's a negative 1 for the corresponding B. And so therefore, the derivatives 00:34:26.080 |
on the C will flow equally to the corresponding As and then also to the corresponding Bs, 00:34:32.720 |
but then in addition to that, the Bs are broadcast, so we'll have to do the additional sum 00:34:37.280 |
just like we did before. And of course, the derivatives for Bs will undergo A minus because 00:34:42.960 |
the local derivative here is negative 1. So dC32 by dB3 is negative 1. So let's just implement that. 00:34:51.760 |
Basically, dlogits will be exactly copying the derivative on norm_logits. So dlogits equals 00:35:01.840 |
dnorm_logits, and I'll do a .clone() for safety, so we're just making a copy. 00:35:06.720 |
And then we have that dlogit_maxes will be the negative of dnorm_logits because of the negative 00:35:14.560 |
sign. And then we have to be careful because logit_maxes is a column. And so just like we saw 00:35:22.320 |
before, because we keep replicating the same elements across all the columns, then in the 00:35:29.280 |
backward pass, because we keep reusing this, these are all just like separate branches of use of that 00:35:35.120 |
one variable. And so therefore, we have to do a sum along dimension 1, with keepdim equals True, 00:35:40.320 |
so that we don't destroy this dimension. And then dlogit_maxes will be the same shape. 00:35:45.680 |
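A sketch of these two lines:

```python
# first branch into logits: the subtraction passes the gradient straight through
dlogits = dnorm_logits.clone()
# the broadcast-subtracted column picks up a minus sign and a sum over its replicated uses
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
```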
Now, we have to be careful because this dlogits is not the final dlogits, and that's because 00:35:51.040 |
not only do we get gradient signal into logits through here, but logit_maxes is a function of 00:35:57.200 |
logits, and that's a second branch into logits. So this is not yet our final derivative for logits, 00:36:02.720 |
we will come back later for the second branch. For now, dlogit_maxes is the final derivative. 00:36:08.080 |
So let me uncomment this cmp here, and let's just run this for logit_maxes, to see if PyTorch agrees with 00:36:15.600 |
us. So that was the derivative through this line. Now, before we move on, I want to pause here 00:36:22.960 |
briefly, and I want to look at these logit_maxes and especially their gradients. We've talked 00:36:27.680 |
previously in the previous lecture, that the only reason we're doing this is for the numerical 00:36:32.160 |
stability of the softmax that we are implementing here. And we talked about how if you take these 00:36:37.520 |
Logits for any one of these examples, so one row of this Logits tensor, if you add or subtract any 00:36:43.680 |
value equally to all the elements, then the value of the probs will be unchanged. You're not changing 00:36:49.760 |
the softmax. The only thing that this is doing is it's making sure that exp doesn't overflow. 00:36:54.480 |
And the reason we're using a max is because then we are guaranteed that each row of Logits, 00:36:59.200 |
the highest number, is zero. And so this will be safe. And so basically that has repercussions. 00:37:08.560 |
If it is the case that changing logit_maxes does not change the probs, and therefore does not 00:37:14.480 |
change the loss, then the gradient on logit_maxes should be zero. Because saying those two things 00:37:20.160 |
is the same thing. So indeed, we hope that these are very, very small numbers. Indeed, we hope this is zero. 00:37:26.000 |
Now, because of floating point sort of wonkiness, this doesn't come out exactly zero. Only in some 00:37:31.840 |
of the rows it does. But we get extremely small values, like 1e-9 or 1e-10. And so this is 00:37:37.440 |
telling us that the values of logit_maxes are not impacting the loss, as they shouldn't. 00:37:42.240 |
It feels kind of weird to backpropagate through this branch, honestly, because 00:37:46.640 |
if you have any implementation of, like, F.cross_entropy in PyTorch, and you block together 00:37:53.200 |
all of these elements, and you're not doing the backpropagation piece by piece, 00:37:56.640 |
then you would probably assume that the derivative through here is exactly zero. 00:38:00.560 |
So you would be sort of skipping this branch, because it's only done for numerical stability. 00:38:09.200 |
But it's interesting to see that even if you break up everything into the full atoms, 00:38:12.960 |
and you still do the computation as you'd like with respect to numerical stability, 00:38:16.800 |
the correct thing happens. And you still get very, very small gradients here, 00:38:21.680 |
basically reflecting the fact that the values of these do not matter with respect to the final loss. 00:38:26.960 |
Okay, so let's now continue backpropagation through this line here. We've just calculated 00:38:31.760 |
the gradient for logit_maxes, and now we want to backprop into logits through this second branch. Now here, 00:38:37.040 |
of course, we took Logits, and we took the max along all the rows, and then we looked at its 00:38:41.920 |
values here. Now the way this works is that in PyTorch, this thing here, the max returns both 00:38:50.800 |
the values, and it returns the indices at which those maximum values occurred. Now, 00:38:56.320 |
in the forward pass, we only used values, because that's all we needed. But in the backward pass, 00:39:00.720 |
it's extremely useful to know about where those maximum values occurred. And we have the indices 00:39:06.560 |
at which they occurred. And this will, of course, help us do the backpropagation. Because what 00:39:11.600 |
should the backward pass be here in this case? We have the logits tensor, which is 32 by 27, 00:39:16.800 |
and in each row, we find the maximum value, and then that value gets plucked out into logit_maxes. 00:39:21.840 |
And so intuitively, basically, the derivative flowing through here then should be: 00:39:32.320 |
the local derivative is 1 for the appropriate entry that was plucked out, and then times the 00:39:38.400 |
global derivative of the logit_maxes. So really what we're doing here, if you think through it, 00:39:42.960 |
is we need to take the dlogit_maxes, and we need to scatter it to the correct positions 00:39:48.320 |
in these Logits from where the maximum values came. And so I came up with one line of code 00:39:57.920 |
that does that. Let me just erase a bunch of stuff here. So the line of-- you could do it 00:40:02.560 |
very similar to what we've done here, where we create a zeros, and then we populate the correct 00:40:07.680 |
elements. So we use the indices here, and we would set them to be 1. But you can also use one-hot. 00:40:14.400 |
So F.one_hot, and then I'm taking the logits.max over the first dimension, .indices, and I'm 00:40:14.400 |
telling PyTorch that the number of classes for every one of these one-hot vectors should be 27. And so what this 00:40:22.400 |
is going to do is-- okay, I apologize, this is crazy. plt.imshow of this. It's really just an 00:40:40.640 |
array of where the maxes came from in each row, and that element is 1, and all the other elements 00:40:46.320 |
are 0. So it's a one-hot vector in each row, and these indices are now populating a single 1 in 00:40:52.640 |
the proper place. And then what I'm doing here is I'm multiplying by the dlogit_maxes. And keep in 00:40:58.480 |
mind that this is a column of 32 by 1. And so when I'm doing this times the dlogit_maxes, the dlogit_maxes 00:41:08.240 |
will broadcast, and that column will get replicated, and then an element-wise multiply will ensure 00:41:14.320 |
that each of these just gets routed to whichever one of these bits is turned on. And so that's 00:41:19.760 |
another way to implement this kind of an operation. And both of these can be used. I just thought I 00:41:27.200 |
would show an equivalent way to do it. And I'm using += because we already calculated dlogits 00:41:32.240 |
here, and this is now the second branch. So let's look at logits and make sure that this is correct. 00:41:39.600 |
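A sketch of that one-liner (F here stands for torch.nn.functional):

```python
import torch.nn.functional as F

# scatter dlogit_maxes back to the argmax position in each row;
# += accumulates with the first branch (dnorm_logits.clone()) computed earlier
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
```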
And we see that we have exactly the correct answer. Next up, we want to continue with logits 00:41:46.720 |
here. That is an outcome of a matrix multiplication and a bias offset in this linear layer. So I've 00:41:54.880 |
printed out the shapes of all these intermediate tensors. We see that logits is of course 32 by 27, 00:42:00.320 |
as we've just seen. Then the h here is 32 by 64. So these are 64-dimensional hidden states. 00:42:07.120 |
And then this w matrix projects those 64-dimensional vectors into 27 dimensions. 00:42:13.040 |
And then there's a 27-dimensional offset, which is a one-dimensional vector. Now we should note 00:42:19.120 |
that this plus here actually broadcasts, because h multiplied by w2 will give us a 32 by 27. 00:42:26.320 |
And so then this plus b2 is a 27-dimensional vector here. Now in the rules of broadcasting, 00:42:33.520 |
what's going to happen with this bias vector is that this one-dimensional vector of 27 00:42:38.240 |
will get aligned with a padded dimension of 1 on the left. And it will basically become a row 00:42:44.160 |
vector. And then it will get replicated vertically 32 times to make it 32 by 27. And then there's an 00:42:50.400 |
element-wise addition. Now the question is, how do we back propagate from logits to the hidden 00:42:58.320 |
states, the weight matrix w2, and the bias b2? And you might think that we need to go to some 00:43:04.400 |
matrix calculus, and then we have to look up the derivative for a matrix multiplication. 00:43:10.640 |
But actually, you don't have to do any of that. And you can go back to first principles and derive 00:43:14.240 |
this yourself on a piece of paper. And specifically what I like to do, and what I find works well for 00:43:19.680 |
me, is you find a specific small example that you then fully write out. And then in the process of 00:43:25.360 |
analyzing how that individual small example works, you will understand the broader pattern. 00:43:29.760 |
And you'll be able to generalize and write out the full general formula for how these derivatives 00:43:35.760 |
flow in an expression like this. So let's try that out. So pardon the low budget production here, 00:43:41.120 |
but what I've done here is I'm writing it out on a piece of paper. Really what we are interested 00:43:45.600 |
in is we have a multiply b plus c, and that creates a d. And we have the derivative of the 00:43:52.800 |
loss with respect to d, and we'd like to know what the derivative of the loss is with respect to a, 00:43:56.480 |
b, and c. Now these here are little two-dimensional examples of a matrix 00:44:01.200 |
multiplication. 2 by 2 times a 2 by 2 plus a 2, a vector of just two elements, c1 and c2, 00:44:08.960 |
gives me a 2 by 2. Now notice here that I have a bias vector here called c, and the bias vector 00:44:16.880 |
is c1 and c2. But as I described over here, that bias vector will become a row vector in the 00:44:21.920 |
broadcasting and will replicate vertically. So that's what's happening here as well. c1, c2 00:44:26.880 |
is replicated vertically, and we see how we have two rows of c1, c2 as a result. 00:44:31.760 |
So now when I say write it out, I just mean like this. Basically break up this matrix 00:44:37.920 |
multiplication into the actual thing that's going on under the hood. So as a result of matrix 00:44:43.920 |
multiplication and how it works, d11 is the result of a dot product between the first row of a and 00:44:49.600 |
the first column of b. So a11, b11 plus a12, b21 plus c1, and so on and so forth for all the other 00:44:59.840 |
elements of d. And once you actually write it out, it becomes obvious this is just a bunch of 00:45:04.320 |
multiplies and adds. And we know from micrograd how to differentiate multiplies and adds. And so 00:45:11.440 |
this is not scary anymore. It's not just matrix multiplication. It's just tedious, unfortunately, 00:45:16.640 |
but this is completely tractable. We have dl by d for all of these, and we want dl by all these 00:45:23.280 |
little other variables. So how do we achieve that and how do we actually get the gradients? 00:45:27.280 |
Okay, so the low budget production continues here. So let's, for example, derive the derivative of 00:45:33.040 |
the loss with respect to a11. We see here that a11 occurs twice in our simple expression, 00:45:39.440 |
right here, right here, and influences d11 and d12. So what is dL by da11? Well, it's dL by dd11 00:45:49.440 |
times the local derivative of d11, which in this case is just b11, because that's what's multiplying 00:45:56.080 |
a11 here. And likewise here, the local derivative of d12 with respect to a11 is just b12. And so 00:46:04.240 |
b12 will, in the chain rule, therefore, multiply dl by d12. And then because a11 is used both to 00:46:11.440 |
produce d11 and d12, we need to add up the contributions of both of those sort of chains 00:46:18.720 |
that are running in parallel. And that's why we get a plus, just adding up those two contributions. 00:46:25.680 |
And that gives us dl by d a11. We can do the exact same analysis for the other one, 00:46:31.280 |
for all the other elements of A. And when you simply write it out, it's just super simple 00:46:36.240 |
taking of gradients on expressions like this. You find that this matrix dl by da that we're after, 00:46:47.360 |
right, if we just arrange all of them in the same shape as A takes, so A is just a 2x2 matrix, 00:46:53.680 |
so dl by da here will be also just the same shape tensor with the derivatives now, so dl by da11, 00:47:04.000 |
etc. And we see that actually we can express what we've written out here as a matrix multiply. 00:47:09.920 |
And so it just so happens that all of these formulas that we've derived here by taking 00:47:16.400 |
gradients can actually be expressed as a matrix multiplication. And in particular, 00:47:20.720 |
we see that it is the matrix multiplication of these two matrices. So it is the dl by d 00:47:28.480 |
and then matrix multiplying B, but B transpose actually. So you see that B21 and B12 have 00:47:36.640 |
changed place, whereas before we had, of course, B11, B12, B21, B22. So you see that this other 00:47:44.880 |
matrix B is transposed. And so basically what we have, long story short, just by doing very simple 00:47:50.800 |
reasoning here, by breaking up the expression in the case of a very simple example, is that dl by 00:47:56.480 |
da is, which is this, is simply equal to dl by dd matrix multiplied with B transpose. 00:48:03.520 |
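Written out for the toy example, with $D = AB + c$, the result just derived is:

$$\frac{\partial L}{\partial A} = \frac{\partial L}{\partial D}\, B^{\top}$$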
So that is what we have so far. Now we also want the derivative with respect to 00:48:10.800 |
B and C. Now for B, I'm not actually doing the full derivation because honestly, it's not deep. 00:48:18.800 |
It's just annoying. It's exhausting. You can actually do this analysis yourself. You'll 00:48:23.680 |
also find that if you take these expressions and you differentiate with respect to B instead of A, 00:48:28.240 |
you will find that dl by db is also a matrix multiplication. In this case, you have to take 00:48:33.760 |
the matrix A and transpose it and matrix multiply that with dl by dd. And that's what gives you the 00:48:40.720 |
dl by db. And then here for the offsets C1 and C2, if you again just differentiate with respect to C1, 00:48:48.560 |
you will find an expression like this and C2, an expression like this. And basically you'll 00:48:55.840 |
find that dl by dc is simply, because they're just offsetting these expressions, you just have 00:49:01.360 |
to take the dl by dd matrix of the derivatives of d and you just have to sum across the columns. 00:49:10.480 |
And that gives you the derivatives for C. So long story short, the backward pass of a matrix 00:49:17.680 |
multiply is a matrix multiply. And instead of, just like we had d equals A times B plus C, 00:49:23.040 |
in a scalar case, we sort of like arrive at something very, very similar, but now 00:49:28.080 |
with a matrix multiplication instead of a scalar multiplication. So the derivative of 00:49:34.720 |
d with respect to A is dl by dd matrix multiply B transpose. And here it's A transpose multiply dl 00:49:44.320 |
by dd. But in both cases, it's a matrix multiplication with the derivative and the 00:49:49.760 |
other term in the multiplication. And for C, it is a sum. Now I'll tell you a secret. I can never 00:49:58.560 |
remember the formulas that we just derived for backpropagating from matrix multiplication, 00:50:02.560 |
and I can backpropagate through these expressions just fine. And the reason this works is because 00:50:07.040 |
the dimensions have to work out. So let me give you an example. Say I want to create dh. 00:50:12.640 |
Then what should dh be? Number one, I have to know that the shape of dh must be the same as 00:50:19.760 |
the shape of h. And the shape of h is 32 by 64. And then the other piece of information I know 00:50:25.760 |
is that dh must be some kind of matrix multiplication of d logits with w2. 00:50:32.560 |
And d logits is 32 by 27, and w2 is 64 by 27. There is only a single way to make the shape 00:50:40.880 |
work out in this case, and it is indeed the correct result. In particular here, h needs to 00:50:47.440 |
be 32 by 64. The only way to achieve that is to take a d logits and matrix multiply it with… 00:50:54.400 |
You see how I have to take w2, but I have to transpose it to make the dimensions work out. 00:51:00.240 |
So w2 transpose. And it is the only way to matrix multiply those two pieces to make the 00:51:05.680 |
shapes work out. And that turns out to be the correct formula. So if we come here, 00:51:09.520 |
we want dh, which is dA. And we see that dA is dL by dD matrix multiply B transpose. 00:51:17.120 |
So that is d logits multiply, and B is w2, so w2 transpose, which is exactly what we have here. 00:51:24.880 |
So there is no need to remember these formulas. Similarly, now if I want dw2, 00:51:30.960 |
well I know that it must be a matrix multiplication of d logits and h. And maybe there are a few 00:51:38.400 |
transposes… Like there is one transpose in there as well. And I do not know which way it is, 00:51:41.920 |
so I have to come to w2. And I see that its shape is 64 by 27, and that has to come from 00:51:48.480 |
some matrix multiplication of these two. And so to get a 64 by 27, I need to take h, I need to 00:51:57.680 |
transpose it, and then I need to matrix multiply it. So that will become 64 by 32. And then I need 00:52:04.080 |
to matrix multiply it with 32 by 27. And that is going to give me a 64 by 27. So I need to matrix 00:52:09.440 |
multiply this with d logits, just like that. That is the only way to make the dimensions 00:52:14.000 |
work out, and just use matrix multiplication. And if we come here, we see that that is exactly 00:52:19.680 |
what is here. So a transpose, a for us is h, multiplied with d logits. So that is dw2. And then 00:52:28.480 |
db2 is just the vertical sum. And actually, in the same way, there is only one way to make the 00:52:37.120 |
shapes work out. I do not have to remember that it is a vertical sum along the 0th axis, because 00:52:41.840 |
that is the only way that this makes sense. Because b2's shape is 27, and 00:52:47.600 |
d logits here is 32 by 27. So knowing that it is just a sum over d logits in some direction, 00:52:56.480 |
that direction must be 0, because I need to eliminate this dimension. So it is this. 00:53:04.720 |
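As a quick reference, and assuming the notebook's tensors (h, W2, b2, dlogits) are in scope, the three lines we just reasoned out by shape matching look roughly like this:

    dh = dlogits @ W2.T     # (32, 27) @ (27, 64) -> (32, 64), same shape as h
    dW2 = h.T @ dlogits     # (64, 32) @ (32, 27) -> (64, 27), same shape as W2
    db2 = dlogits.sum(0)    # sum over the batch dimension -> (27,), same shape as b2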
So this is kind of like the hacky way. Let me copy, paste, and delete that. 00:53:11.040 |
And let me swing over here. And this is our backward pass for the linear layer, 00:53:15.120 |
hopefully. So now let us uncomment these three. And we are checking that we 00:53:21.040 |
got all the three derivatives correct. And run. And we see that h, w2, and b2 are all exactly 00:53:30.720 |
correct. So we backpropagated through a linear layer. Now next up, we have derivative for the 00:53:38.960 |
h already. And we need to backpropagate through tanh into hpreact. So we want to derive dhpreact. 00:53:46.000 |
And here we have to backpropagate through a tanh. And we have already done this in micrograd. 00:53:51.120 |
And we remember that tanh has a very simple backward formula. Now unfortunately, 00:53:55.600 |
if I just put in d by dx of tanh of x into Wolfram Alpha, it lets us down. It tells us that it is a 00:54:00.960 |
hyperbolic secant function squared of x. It is not exactly helpful. But luckily, Google image 00:54:06.800 |
search does not let us down. And it gives us the simpler formula. And in particular, if you have 00:54:11.440 |
that a is equal to tanh of z, then da by dz, backpropagating through tanh, is just 1 minus a 00:54:18.160 |
squared. And take note that 1 minus a squared, a here is the output of the tanh, not the input to 00:54:24.800 |
the tanh, z. So the da by dz is here formulated in terms of the output of that tanh. And here also, 00:54:32.960 |
in Google image search, we have the full derivation if you want to actually take the 00:54:36.400 |
actual definition of tanh and work through the math to figure out 1 minus tanh squared of z. 00:54:41.440 |
So 1 minus a squared is the local derivative. In our case, that is 1 minus the output of tanh 00:54:50.720 |
squared, which here is h. So it's h squared. And that is the local derivative. And then times the 00:54:58.160 |
chain rule, dh. So that is going to be our candidate implementation. So if we come here 00:55:04.160 |
and then uncomment this, let's hope for the best. And we have the right answer. 00:55:11.280 |
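Written out, that single candidate line is roughly the following, using the same notebook variables:

    dhpreact = (1.0 - h**2) * dh    # local derivative of tanh, expressed via its output h, times the chain rule dh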
Okay, next up, we have dhpreact. And we want to backpropagate into the gain, the bn_raw, 00:55:17.600 |
and the bn_bias. So here, these are the batch norm parameters, bn_gain and bn_bias inside the batch 00:55:23.120 |
norm that take the bn_raw that is an exact unit Gaussian, and they scale it and shift it. And 00:55:29.520 |
these are the parameters of the batch norm. Now, here, we have a multiplication. But it's worth 00:55:34.880 |
noting that this multiply is very, very different from this matrix multiply here. A matrix multiply 00:55:39.920 |
involves dot products between the rows and columns of the matrices involved. This is an element-wise 00:55:45.200 |
multiply. So things are quite a bit simpler. Now, we do have to be careful with some of the 00:55:49.440 |
broadcasting happening in this line of code, though. So you see how bn_gain and bn_bias are 00:55:55.760 |
1 by 64, but dhpreact and bn_raw are 32 by 64. So we have to be careful with that and make sure that 00:56:04.080 |
all the shapes work out fine and that the broadcasting is correctly backpropagated. 00:56:07.600 |
So in particular, let's start with dbn_gain. So dbn_gain should be, and here, this is again, 00:56:16.240 |
element-wise multiply. And whenever we have a times b equals c, we saw that the local derivative 00:56:21.680 |
here is just, if this is a, the local derivative is just the b, the other one. So the local 00:56:26.880 |
derivative is just bn_raw and then times chain rule. So dhpreact. So this is the candidate 00:56:36.640 |
gradient. Now, again, we have to be careful because bn_gain is of size 1 by 64. But this here 00:56:45.280 |
would be 32 by 64. And so the correct thing to do in this case, of course, is that bn_gain, 00:56:53.120 |
here is a row vector of 64 numbers, it gets replicated vertically in this operation. 00:56:57.840 |
And so therefore, the correct thing to do is to sum because it's being replicated. 00:57:02.960 |
And therefore, all the gradients in each of the rows that are now flowing backwards need to sum 00:57:09.040 |
up to that same tensor dbn_gain. So we have to sum across dimension zero, across all the examples, 00:57:16.640 |
basically, which is the direction in which this gets replicated. And now we have to be also 00:57:21.280 |
careful because bn_gain is of shape 1 by 64. So in fact, I need keepdim to be True. Otherwise, 00:57:29.840 |
I would just get 64. Now, I don't actually really remember why the bn_gain and the bn_bias, 00:57:36.320 |
I made them be 1 by 64. But the biases b1 and b2, I just made them be one dimensional vectors, 00:57:45.280 |
they're not two dimensional tensors. So I can't recall exactly why I left the gain and the bias 00:57:51.520 |
as two dimensional. But it doesn't really matter as long as you are consistent and you're keeping 00:57:55.440 |
it the same. So in this case, we want to keep the dimension so that the tensor shapes work. 00:58:01.360 |
Next up, we have bn_raw. So dbn_raw will be bn_gain multiplying dh_preact. That's our chain 00:58:13.840 |
rule. Now, what about the dimensions of this? We have to be careful, right? So dh_preact is 32 by 00:58:22.960 |
64, bn_gain is 1 by 64. So it will just get replicated to create this multiplication, 00:58:30.880 |
which is the correct thing because in a forward pass, it also gets replicated in just the same 00:58:34.880 |
way. So in fact, we don't need the brackets here, we're done. And the shapes are already correct. 00:58:40.080 |
And finally, for the bias, very similar. This bias here is very, very similar to the bias we saw in 00:58:47.600 |
the linear layer. And we see that the gradients from h_preact will simply flow into the biases 00:58:53.280 |
and add up because these are just offsets. And so basically, we want this to be dh_preact, 00:59:00.000 |
but it needs to sum along the right dimension. And in this case, similar to the gain, 00:59:04.960 |
we need to sum across the zeroth dimension, the examples, because of the way that the 00:59:09.280 |
bias gets replicated vertically. And we also want to have keepdim as True. And so this will 00:59:15.920 |
basically take this and sum it up and give us a 1 by 64. So this is the candidate implementation 00:59:22.880 |
and makes all the shapes work. Let me bring it up down here. And then let me uncomment these three 00:59:30.240 |
lines to check that we are getting the correct result for all the three tensors. And indeed, 00:59:36.400 |
we see that all of that got backpropagated correctly. So now we get to the batch norm layer. 00:59:41.040 |
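For reference, the three candidate lines we just wrote are roughly the following sketch, with keepdim=True preserving the 1 by 64 shape of the parameters:

    dbngain = (bnraw * dhpreact).sum(0, keepdim=True)    # element-wise multiply, then sum out the replicated batch dimension
    dbnraw = bngain * dhpreact                           # broadcasting replicates bngain over the 32 rows, as in the forward pass
    dbnbias = dhpreact.sum(0, keepdim=True)              # offsets: gradients simply add up over the batch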
We see how here bn_gain and bn_bias are the parameters, so the backpropagation ends. 00:59:46.640 |
But bn_raw now is the output of the standardization. So here, what I'm doing, of course, 00:59:52.960 |
is I'm breaking up the batch norm into manageable pieces so we can backpropagate through each line 00:59:57.040 |
individually. But basically, what's happening is bn_mean_i is the mean. So this is the bn_mean_i. 01:00:06.160 |
I apologize for the variable naming. bn_diff is x minus mu. bn_diff2 is x minus mu squared 01:00:14.560 |
here inside the variance. bn_var is the variance, so sigma squared. This is bn_var. 01:00:21.520 |
And it's basically the sum of squares. So this is the x minus mu squared and then the sum. 01:00:28.800 |
Now, you'll notice one departure here. Here, it is normalized as 1 over m, 01:00:34.160 |
which is the number of examples. Here, I am normalizing as 1 over n minus 1 instead of m. 01:00:40.720 |
And this is deliberate, and I'll come back to that in a bit when we are at this line. 01:00:44.640 |
It is something called the Bessel's correction, but this is how I want it in our case. 01:00:49.600 |
bn_var_inv then becomes basically bn_var plus epsilon. Epsilon is 1e-5. 01:00:57.120 |
And then the 1 over square root is the same as raising to the power of negative 0.5, 01:01:03.680 |
right? Because raising to the power of 0.5 is the square root, and the negative makes it 1 over the square root. 01:01:08.800 |
So bn_var_inv is 1 over this denominator here. And then we can see that bn_raw, which is the x hat 01:01:16.080 |
here, is equal to the bn_diff, the numerator, multiplied by the bn_var_inv. And this line here 01:01:26.160 |
that creates hpreact was the last piece, and we've already backpropagated through it. 01:01:29.920 |
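For reference, the broken-up forward pass we are about to backpropagate through looks roughly like this, with variable names roughly as in the notebook and n being the batch size of 32:

    bnmeani = 1/n * hprebn.sum(0, keepdim=True)       # mu
    bndiff = hprebn - bnmeani                          # x minus mu
    bndiff2 = bndiff**2                                # (x minus mu) squared
    bnvar = 1/(n-1) * bndiff2.sum(0, keepdim=True)     # sigma squared, with Bessel's correction
    bnvar_inv = (bnvar + 1e-5)**-0.5                   # 1 over sqrt(sigma squared plus epsilon)
    bnraw = bndiff * bnvar_inv                         # x hat
    hpreact = bngain * bnraw + bnbias                  # scale and shift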
So now what we want to do is we are here, and we have bn_raw, and we have to first backpropagate 01:01:37.040 |
into bn_diff and bn_var_inv. So now we're here, and we have dbn_raw, and we need to backpropagate 01:01:44.880 |
through this line. Now, I've written out the shapes here, and indeed bn_var_inv is of shape 1 01:01:51.680 |
by 64, so there is a broadcasting happening here that we have to be careful with. But it is just 01:01:57.520 |
an element-wise simple multiplication. By now, we should be pretty comfortable with that. To get 01:02:02.000 |
dbn_diff, we know that this is just bn_var_inv multiplied with dbn_raw. And conversely, to get 01:02:13.200 |
dbn_var_inv, we need to take bn_diff and multiply that by dbn_raw. So this is the candidate, but of 01:02:24.240 |
course we need to make sure that broadcasting is obeyed. So in particular, bn_var_inv multiplying 01:02:30.080 |
with dbn_raw will be okay and give us 32 by 64 as we expect. But dbn_var_inv would be taking a 32 01:02:40.160 |
by 64, multiplying it by 32 by 64. So this is a 32 by 64. But of course this bn_var_inv is only 1 01:02:50.640 |
by 64. So the second line here needs a sum across the examples, and because there's this dimension 01:02:58.400 |
here, we need to make sure that keep_dim is true. So this is the candidate. Let's erase this and 01:03:06.480 |
let's swing down here and implement it. And then let's uncomment the checks for dbn_var_inv and dbn_diff. 01:03:14.800 |
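As a sketch, the two candidate lines being implemented here are roughly:

    dbndiff = bnvar_inv * dbnraw                           # first branch only; a second contribution will be added later
    dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)    # sum out the broadcasted batch dimension to get back to 1 by 64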
Now, we'll actually notice that dbn_diff, by the way, is going to be incorrect. So when I run this, 01:03:24.640 |
bn_var_inv is correct. bn_diff is not correct. And this is actually expected, because we're not 01:03:31.440 |
done with bn_diff. So in particular, when we slide here, we see here that bn_raw is a function of bn_diff, 01:03:38.480 |
but actually bn_var_inv is a function of bn_var, which is a function of bn_diff2, 01:03:43.360 |
which is a function of bn_diff. So it comes here. So bn_diff, these variable names are crazy, 01:03:50.400 |
I'm sorry. It branches out into two branches, and we've only done one branch of it. We have to 01:03:55.600 |
continue our backpropagation and eventually come back to bn_diff, and then we'll be able to do a 01:03:59.680 |
+= and get the actual correct gradient. For now, it is good to verify that cmp also works. It doesn't 01:04:06.160 |
just lie to us and tell us that everything is always correct. It can in fact detect when your 01:04:11.440 |
gradient is not correct. So that's good to see as well. Okay, so now we have the derivative here, 01:04:16.480 |
and we're trying to backpropagate through this line. And because we're raising to a power of 01:04:20.880 |
-0.5, I brought up the power rule. And we see that basically we have that dbn_var will now be, 01:04:27.360 |
we bring down the exponent, so -0.5 times x, which is this, and now raise to the power of -0.5-1, 01:04:37.360 |
which is -1.5. Now, we would have to also apply a small chain rule here in our head, 01:04:44.320 |
because we need to take further derivative of bn_var with respect to this expression here 01:04:49.840 |
inside the bracket. But because this is an element-wise operation, and everything is 01:04:53.680 |
fairly simple, that's just 1. And so there's nothing to do there. So this is the local 01:04:58.560 |
derivative, and then times the global derivative to create the chain rule. This is just times 01:05:03.760 |
dbn_var_inv. So this is our candidate. Let me bring this down and uncomment the check. 01:05:11.920 |
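In code, that candidate is roughly:

    dbnvar = (-0.5 * (bnvar + 1e-5)**-1.5) * dbnvar_inv    # power rule, times the chain rule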
And we see that we have the correct result. Now, before we backpropagate through the next line, 01:05:19.440 |
I want to briefly talk about the note here, where I'm using the Bessel's correction, 01:05:22.640 |
dividing by n-1, instead of dividing by n, when I normalize here the sum of squares. 01:05:29.680 |
Now, you'll notice that this is a departure from the paper, which uses 1/n instead, 01:05:34.000 |
not 1/(n-1). There, m is our n. And so it turns out that there are two ways of estimating variance 01:05:42.400 |
of an array. One is the biased estimate, which is 1/n, and the other one is the unbiased estimate, 01:05:49.120 |
which is 1/(n-1). Now, confusingly, in the paper, this is not very clearly described, 01:05:55.760 |
and also it's a detail that kind of matters, I think. They are using the biased version 01:06:00.400 |
at training time, but later, when they are talking about the inference, they are mentioning that when 01:06:05.360 |
they do the inference, they are using the unbiased estimate, which is the n-1 version, basically, 01:06:13.440 |
for inference, and to calibrate the running mean and the running variance, basically. 01:06:20.080 |
And so they actually introduce a train-test mismatch, where in training, they use the 01:06:24.320 |
biased version, and in test time, they use the unbiased version. I find this extremely confusing. 01:06:30.000 |
You can read more about the Bessel's correction and why dividing by n-1 gives you a better 01:06:35.680 |
estimate of the variance in the case where you have population sizes or samples for a population 01:06:40.400 |
that are very small. And that is indeed the case for us, because we are dealing with mini-batches, 01:06:46.960 |
and these mini-batches are a small sample of a larger population, which is the entire training 01:06:51.760 |
set. And so it just turns out that if you just estimate it using 1/n, that actually almost always 01:06:58.000 |
underestimates the variance. And it is a biased estimator, and it is advised that you use the 01:07:02.800 |
unbiased version and divide by n-1. And you can go through this article here that I liked 01:07:07.680 |
that actually describes the full reasoning, and I'll link it in the video description. 01:07:11.120 |
Now, when you look at torch.var, you'll notice that it takes an unbiased flag, 01:07:17.360 |
whether or not you want to divide by n or n-1. Confusingly, they do not mention what the default 01:07:24.080 |
is for unbiased, but I believe unbiased by default is true. I'm not sure why the docs 01:07:29.440 |
here don't state that. Now, in the BatchNorm1d documentation, again it is kind of wrong 01:07:35.840 |
and confusing. It says that the standard deviation is calculated via the biased estimator, 01:07:40.560 |
but this is actually not exactly right, and people have pointed out that it is not right 01:07:45.040 |
in a number of issues since then, because actually the rabbit hole is deeper, and they follow the 01:07:50.960 |
paper exactly, and they use the biased version for training. But when they're estimating the 01:07:55.840 |
running standard deviation, they are using the unbiased version. So again, there's the train 01:08:00.400 |
test mismatch. So long story short, I'm not a fan of train test discrepancies. I basically kind of 01:08:07.200 |
consider the fact that we use the biased version at training time and the unbiased at test time, 01:08:13.040 |
I basically consider this to be a bug, and I don't think that there's a good reason for that. 01:08:16.800 |
They don't really go into the detail of the reasoning behind it in this paper. So that's 01:08:22.240 |
why I basically prefer to use the Bessel's correction in my own work. Unfortunately, 01:08:27.360 |
batch norm does not take a keyword argument that tells you whether or not you want to use the 01:08:32.400 |
unbiased version or the biased version in both train and test, and so therefore anyone using 01:08:36.560 |
batch normalization basically in my view has a bit of a bug in the code. And this turns out to 01:08:43.040 |
be much less of a problem if your mini batch sizes are a bit larger. But still, I just find it kind 01:08:48.560 |
of unpalatable. So maybe someone can explain why this is okay. But for now, I prefer to use the 01:08:54.320 |
unbiased version consistently both during training and at test time, and that's why I'm using 1/(n-1) 01:09:00.720 |
here. Okay, so let's now actually backpropagate through this line. So the first thing that I 01:09:07.840 |
always like to do is I like to scrutinize the shapes first. So in particular here, looking at 01:09:12.480 |
the shapes of what's involved, I see that bn_var shape is 1 by 64, so it's a row vector, and bn_diff2.shape 01:09:20.640 |
is 32 by 64. So clearly here we're doing a sum over the zeroth axis to squash the first dimension 01:09:29.600 |
of the shapes here using a sum. So that right away actually hints to me that there will be some kind 01:09:36.240 |
of a replication or broadcasting in the backward pass. And maybe you're noticing the pattern here, 01:09:41.120 |
but basically anytime you have a sum in the forward pass, that turns into a replication 01:09:46.720 |
or broadcasting in the backward pass along the same dimension. And conversely, when we have a 01:09:51.840 |
replication or a broadcasting in the forward pass, that indicates a variable reuse. And so in the 01:09:58.880 |
backward pass, that turns into a sum over the exact same dimension. And so hopefully you're 01:10:03.600 |
noticing that duality, that those two are kind of like the opposites of each other in the forward 01:10:07.360 |
and the backward pass. Now once we understand the shapes, the next thing I like to do always is I 01:10:12.480 |
like to look at a toy example in my head to sort of just like understand roughly how the variable 01:10:18.240 |
dependencies go in the mathematical formula. So here we have a two-dimensional array, bn_diff2, 01:10:25.520 |
which we are scaling by a constant, and then we are summing vertically over the columns. So if 01:10:31.840 |
we have a 2x2 matrix A and then we sum over the columns and scale, we would get a row vector b1, 01:10:37.120 |
b2, and b1 depends on A in this way, where it's just the first column of A summed and scaled, and b2 in this way, 01:10:45.200 |
where it's the second column summed and scaled. And so looking at this basically, what we want 01:10:52.400 |
to do now is we have the derivatives on b1 and b2, and we want to back propagate them into A's. 01:10:57.600 |
And so it's clear that just differentiating in your head, the local derivative here is 1 over 01:11:01.920 |
n minus 1 times 1 for each one of these A's. And basically the derivative of b1 has to flow 01:11:11.600 |
through the columns of A scaled by 1 over n minus 1. And that's roughly what's happening here. 01:11:18.560 |
So intuitively, the derivative flow tells us that dbn_diff2 will be the local derivative of 01:11:27.280 |
this operation. And there are many ways to do this, by the way, but I like to do something 01:11:30.960 |
like this, torch.ones_like of bn_diff2. So I'll create a large array, two-dimensional, of ones, 01:11:38.640 |
and then I will scale it. So 1.0 divide by n minus 1. So this is an array of 1 over n minus 1. 01:11:48.480 |
And that's sort of like the local derivative. And now for the chain rule, I will simply just 01:11:53.360 |
multiply it by dbn_var. And notice here what's going to happen. This is 32 by 64, 01:12:01.440 |
and this is just 1 by 64. So I'm letting the broadcasting do the replication, 01:12:07.040 |
because internally in PyTorch, basically dbn_var, which is 1 by 64 row vector, 01:12:13.040 |
will in this multiplication get copied vertically until the two are of the same shape, and then 01:12:18.960 |
there will be an element-wise multiply. And so the broadcasting is basically doing the replication. 01:12:24.480 |
And I will end up with the derivative dbn_diff2 here. So this is the candidate solution. 01:12:31.920 |
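In code, the candidate is roughly:

    dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar    # local derivative 1/(n-1), broadcasting replicates dbnvar over the rows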
Let's bring it down here. Let's uncomment this line where we check it, and let's hope for the 01:12:38.080 |
best. And indeed, we see that this is the correct formula. Next up, let's differentiate here into 01:12:44.320 |
bn_diff. So here we have that bn_diff is element-wise squared to create bn_diff2. So this is a relatively 01:12:52.320 |
simple derivative, because it's a simple element-wise operation. So it's kind of like the 01:12:55.920 |
scalar case. And we have that dbn_diff should be, if this is x squared, then the derivative of this 01:13:02.640 |
is 2x. So it's simply 2 times bn_diff, that's the local derivative, and then times chain rule. And 01:13:10.560 |
the shape of these is the same. They are of the same shape. So times this. So that's the backward 01:13:16.800 |
pass for this variable. Let me bring that down here. And now we have to be careful, because we 01:13:21.840 |
already calculated dbn_diff, right? So this is just the end of the other branch coming back to bn_diff, 01:13:30.880 |
because bn_diff was already backpropagated to way over here from bn_raw. So we now completed the 01:13:37.520 |
second branch. And so that's why I have to do plus equals. And if you recall, we had an incorrect 01:13:43.040 |
derivative for bn_diff before. And I'm hoping that once we append this last missing piece, 01:13:48.320 |
we have the exact correctness. So let's run. And bn_diff2 and bn_diff now actually show the exact 01:13:55.360 |
correct derivatives. So that's comforting. Okay, so let's now backpropagate through this line here. 01:14:01.520 |
The first thing we do, of course, is we check the shapes. And I wrote them out here. And basically, 01:14:07.840 |
the shape of this is 32 by 64. H_prebn is the same shape. But bn_mean_i is a row vector, 1 by 64. 01:14:16.080 |
So this minus here will actually do broadcasting. And so we have to be careful with that. 01:14:20.560 |
And as a hint to us, again, because of the duality, a broadcasting in the forward pass 01:14:25.360 |
means a variable reuse. And therefore, there will be a sum in the backward pass. 01:14:29.200 |
So let's write out the backward pass here now. Backpropagate into the H_prebn. Because these 01:14:38.000 |
are the same shape, then the local derivative for each one of the elements here is just one 01:14:42.640 |
for the corresponding element in here. So basically, what this means is that the gradient 01:14:48.400 |
just simply copies. It's just a variable assignment. It's equality. So I'm just going 01:14:52.880 |
to clone this tensor just for safety to create an exact copy of dbn_diff. And then here, 01:15:00.640 |
to backpropagate into this one, what I'm inclined to do here is dbn_mean_i will basically be 01:15:08.400 |
what is the local derivative? Well, it's negative torch.ones_like, in the shape of bn_diff. 01:15:19.280 |
Right? And then times the derivative here, dbn_diff. 01:15:29.440 |
And this here is the backpropagation for the replicated bn_mean_i. So I still have to 01:15:38.640 |
backpropagate through the replication in the broadcasting, and I do that by doing a sum. 01:15:44.240 |
So I'm going to take this whole thing, and I'm going to do a sum over the zeroth dimension, 01:15:48.800 |
which was the replication. So if you scrutinize this, by the way, you'll notice that this is 01:15:57.120 |
the same shape as that. And so what I'm doing here doesn't actually make that much sense, 01:16:01.920 |
because it's just an array of ones multiplying dbn_diff. So in fact, I can just do this, 01:16:10.080 |
and that is equivalent. So this is the candidate backward pass. Let me copy it here. 01:16:16.160 |
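So the two candidate lines are roughly:

    dhprebn = dbndiff.clone()                      # the subtraction just passes the gradient through
    dbnmeani = (-dbndiff).sum(0, keepdim=True)     # minus sign from the subtraction; the sum undoes the broadcast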
And then let me uncomment this one and this one. Enter. And it's wrong. 01:16:25.040 |
Damn. Actually, sorry, this is supposed to be wrong. And it's supposed to be wrong because 01:16:34.560 |
we are backpropagating from bn_diff into h_prebn, but we're not done because bn_mean_i 01:16:41.120 |
depends on h_prebn, and there will be a second portion of that derivative coming from this 01:16:45.840 |
second branch. So we're not done yet, and we expect it to be incorrect. So there you go. 01:16:49.920 |
So let's now backpropagate from bn_mean_i into h_prebn. 01:16:54.400 |
And so here again, we have to be careful because there's a broadcasting along, 01:17:01.280 |
or there's a sum along the zeroth dimension. So this will turn into broadcasting in the backward 01:17:05.920 |
pass now. And I'm going to go a little bit faster on this line because it is very similar to the 01:17:10.560 |
line that we had before, multiple lines in the past, in fact. So dh_prebn will be, 01:17:18.880 |
the gradient will be scaled by 1/n, and then basically this gradient here, dbn_mean_i, 01:17:27.280 |
is going to be scaled by 1/n, and then it's going to flow across all the columns and deposit 01:17:33.120 |
itself into dh_prebn. So what we want is this thing scaled by 1/n. Let me put the constant 01:17:40.480 |
up front here. So scale down the gradient, and now we need to replicate it across all the 01:17:52.400 |
rows here. So I like to do that by torch.ones_like of basically h_prebn. 01:18:01.760 |
And I will let the broadcasting do the work of replication. So 01:18:14.960 |
like that. So this is dh_prebn, and hopefully we can plus equals that. 01:18:22.560 |
So this here is broadcasting, and then this is the scaling. So this should be correct. Okay. 01:18:33.840 |
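Putting that together, the second contribution is roughly:

    dhprebn += 1.0/n * torch.ones_like(hprebn) * dbnmeani    # scale by 1/n, then broadcasting replicates it across the rows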
So that completes the backpropagation of the batch norm layer, and we are now here. Let's 01:18:39.360 |
backpropagate through the linear layer 1 here. Now because everything is getting a little vertically 01:18:44.800 |
crazy, I copy-pasted the line here, and let's just backpropagate through this one line. 01:18:48.960 |
So first, of course, we inspect the shapes, and we see that this is 32 by 64. embcat is 32 by 30, 01:18:57.760 |
w1 is 30 by 64, and b1 is just 64. So as I mentioned, backpropagating through linear 01:19:06.880 |
layers is fairly easy just by matching the shapes, so let's do that. We have that dembcat 01:19:14.080 |
should be some matrix multiplication of dh_prebn with w1 and one transpose thrown in there. 01:19:21.600 |
So to make embcat be 32 by 30, I need to take dh_prebn, 32 by 64, and multiply it by w1 dot 01:19:36.000 |
transpose. To get dw1, I need to end up with 30 by 64. So to get that, I need to take embcat transpose 01:19:49.680 |
and multiply that by dh_prebn. And finally, to get db1, this is an addition, and we saw that 01:20:04.640 |
basically I need to just sum the elements in dh_prebn along some dimension. And to make the 01:20:10.800 |
dimensions work out, I need to sum along the 0th axis here to eliminate this dimension, and we do 01:20:17.280 |
not keep dims, so that we want to just get a single one-dimensional vector of 64. So these are the 01:20:24.320 |
claimed derivatives. Let me put that here, and let me uncomment three lines and cross our fingers. 01:20:34.000 |
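Written out, the three claimed derivatives are roughly, with W1, b1, and embcat as in the notebook:

    dembcat = dhprebn @ W1.T     # (32, 64) @ (64, 30) -> (32, 30)
    dW1 = embcat.T @ dhprebn     # (30, 32) @ (32, 64) -> (30, 64)
    db1 = dhprebn.sum(0)         # (64,), no keepdim because b1 is a one-dimensional vector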
Everything is great. Okay, so we now continue almost there. We have the derivative of embcat, 01:20:38.960 |
and we want to backpropagate it into emb. So I again copied this line over here. So this is the 01:20:47.040 |
forward pass, and then this is the shapes. So remember that the shape here was 32 by 30, 01:20:52.640 |
and the original shape of emb was 32 by 3 by 10. So this layer in the forward pass, as you recall, 01:20:58.400 |
did the concatenation of these three 10-dimensional character vectors. And so now we just want to 01:21:05.200 |
undo that. So this is actually a relatively straightforward operation, because the backward 01:21:11.120 |
pass of the... What is a view? A view is just a representation of the array. It's just a logical 01:21:16.640 |
form of how you interpret the array. So let's just reinterpret it to be what it was before. 01:21:21.760 |
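In code, that reinterpretation is just a single view call, roughly:

    demb = dembcat.view(emb.shape)    # reinterpret (32, 30) back as (32, 3, 10)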
So in other words, demb is not 32 by 30. It is basically dembcat, but if you view it as 01:21:31.840 |
the original shape, so just emb.shape, you can pass in tuples into view. And so this should just be... 01:21:41.440 |
Okay, we just re-represent that view, and then we uncomment this line here, and hopefully... 01:21:51.040 |
Yeah, so the derivative of emb is correct. So in this case, we just have to re-represent the shape 01:21:56.880 |
of those derivatives into the original view. So now we are at the final line, and the only 01:22:01.600 |
thing that's left to backpropagate through is this indexing operation here, C at xb. So as I did 01:22:08.480 |
before, I copy-pasted this line here, and let's look at the shapes of everything that's involved 01:22:12.560 |
and remind ourselves how this worked. So emb.shape was 32 by 3 by 10. So it's 32 examples, and then 01:22:21.680 |
we have three characters. Each one of them has a 10-dimensional embedding, and this was achieved 01:22:27.920 |
by taking the lookup table C, which has 27 possible characters, each of them 10-dimensional, 01:22:34.240 |
and we looked up at the rows that were specified inside this tensor xb. So xb is 32 by 3, 01:22:42.880 |
and it's basically giving us, for each example, the identity or the index of which character 01:22:47.760 |
is part of that example. And so here I'm showing the first five rows of this tensor xb. 01:22:56.240 |
And so we can see that, for example, for the first example in this batch, the 01:23:01.520 |
first character, the first character, and the fourth character come into the neural net, 01:23:05.440 |
and then we want to predict the next character in the sequence that follows 1, 1, 4. 01:23:10.800 |
So basically what's happening here is there are integers inside xb, and each one of these 01:23:18.160 |
integers is specifying which row of C we want to pluck out, right? And then we arrange those rows 01:23:25.360 |
that we've plucked out into 32 by 3 by 10 tensor, and we just package them into this tensor. And now 01:23:33.920 |
what's happening is that we have d_emb. So for every one of these basically plucked out rows, 01:23:40.320 |
we have their gradients now, but they're arranged inside this 32 by 3 by 10 tensor. 01:23:45.920 |
So all we have to do now is we just need to route this gradient backwards through this assignment. 01:23:51.600 |
So we need to find which row of C did every one of these 10-dimensional embeddings come from, 01:23:58.080 |
and then we need to deposit them into dC. So we just need to undo the indexing, and of course, 01:24:06.240 |
if any of these rows of C was used multiple times, which almost certainly is the case, 01:24:10.880 |
like the row 1 and 1 was used multiple times, then we have to remember that the gradients 01:24:15.120 |
that arrive there have to add. So for each occurrence, we have to have an addition. 01:24:21.200 |
So let's now write this out. And I don't actually know of a much better way to do this than a 01:24:25.360 |
for loop, unfortunately, in Python. So maybe someone can come up with a vectorized efficient 01:24:30.960 |
operation, but for now, let's just use for loops. So let me create a torch.zeros_like 01:24:36.080 |
C to initialize just a 27 by 10 tensor of all zeros. And then honestly, for k in range, 01:24:46.080 |
xb.shape at 0. Maybe someone has a better way to do this, but for j in range, xb.shape at 1, 01:24:54.320 |
this is going to iterate over all the elements of xb, all these integers. 01:25:02.000 |
And then let's get the index at this position. So the index ix is basically xb at k, j. 01:25:11.520 |
So an example of that is 11 or 14 and so on. And now in a forward pass, we basically took 01:25:20.800 |
the row of C at index, and we deposited it into emb at k, j. That's what happened. That's where 01:25:32.000 |
they are packaged. So now we need to go backwards, and we just need to route d_emb at the position 01:25:38.400 |
k, j. We now have these derivatives for each position, and it's 10-dimensional. And you just 01:25:46.160 |
need to go into the correct row of C. So dC, rather, at ix is this, but plus equals, because 01:25:55.920 |
there could be multiple occurrences. Like the same row could have been used many, many times. 01:25:59.760 |
And so all of those derivatives will just go backwards through the indexing, and they will add. 01:26:07.840 |
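Putting that into code, the candidate is a small double for loop, roughly as follows, assuming the notebook's C, Xb (the xb tensor of input indices), and demb are in scope:

    dC = torch.zeros_like(C)              # (27, 10) accumulator for the gradients
    for k in range(Xb.shape[0]):          # loop over the 32 examples
        for j in range(Xb.shape[1]):      # loop over the 3 character positions
            ix = Xb[k, j]                 # which row of C was plucked out here
            dC[ix] += demb[k, j]          # route the gradient back, adding over repeated rows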
So this is my candidate solution. Let's copy it here. Let's uncomment this and cross our fingers. 01:26:19.520 |
Yay! So that's it. We've backpropagated through this entire beast. So there we go. 01:26:29.360 |
Totally made sense. So now we come to exercise two. It basically turns out that in this first 01:26:34.800 |
exercise, we were doing way too much work. We were backpropagating way too much. And it was all good 01:26:39.600 |
practice and so on, but it's not what you would do in practice. And the reason for that is, for 01:26:44.080 |
example, here I separated out this loss calculation over multiple lines, and I broke it up all to its 01:26:51.040 |
smallest atomic pieces, and we backpropagated through all of those individually. But it turns 01:26:55.600 |
out that if you just look at the mathematical expression for the loss, then actually you can 01:27:01.440 |
do the differentiation on pen and paper, and a lot of terms cancel and simplify. And the mathematical 01:27:06.560 |
expression you end up with can be significantly shorter and easier to implement than backpropagating 01:27:11.440 |
through all the little pieces of everything you've done. So before we had this complicated forward 01:27:16.240 |
pass going from logits to the loss. But in PyTorch, everything can just be glued together into a 01:27:21.920 |
single call, f.crossentropy. You just pass in logits and the labels, and you get the exact same loss, 01:27:27.120 |
as I verify here. So our previous loss and the fast loss coming from the chunk of operations as 01:27:32.880 |
a single mathematical expression is the same, but it's much, much faster in a forward pass. 01:27:37.840 |
It's also much, much faster in backward pass. And the reason for that is, if you just look at 01:27:42.480 |
the mathematical form of this and differentiate again, you will end up with a very small and short 01:27:46.800 |
expression. So that's what we want to do here. We want to, in a single operation or in a single go, 01:27:52.480 |
or like very quickly, go directly into dlogits. And we need to implement dlogits as a function of 01:27:59.600 |
logits and ybs. But it will be significantly shorter than whatever we did here, where to get 01:28:06.640 |
to dlogits, we had to go all the way here. So all of this work can be skipped in a much, much simpler 01:28:13.040 |
mathematical expression that you can implement here. So you can give it a shot yourself. Basically, 01:28:19.440 |
look at what exactly is the mathematical expression of loss and differentiate with respect to the 01:28:24.480 |
logits. So let me show you a hint. You can, of course, try it for yourself. But if not, I can 01:28:32.000 |
give you some hint of how to get started mathematically. So basically, what's happening 01:28:37.520 |
here is we have logits. Then there's a softmax that takes the logits and gives you probabilities. 01:28:42.480 |
Then we are using the identity of the correct next character to pluck out a row of probabilities. 01:28:49.600 |
Take the negative log of it to get our negative log probability. And then we average up all the 01:28:54.880 |
log probabilities or negative log probabilities to get our loss. So basically, what we have is 01:29:00.960 |
for a single individual example, rather, we have that loss is equal to negative log probability, 01:29:06.560 |
where p here is kind of like thought of as a vector of all the probabilities. So at the yth 01:29:13.200 |
position, where y is the label, and we have that p here, of course, is the softmax. So the 01:29:21.200 |
ith component of p, of this probability vector, is just the softmax function. So raising all the 01:29:28.080 |
logits basically to the power of e and normalizing so everything sums to one. Now, if you write out 01:29:36.480 |
p of y here, you can just write out the softmax. And then basically what we're interested in is 01:29:41.120 |
we're interested in the derivative of the loss with respect to the ith logit. And so basically, 01:29:48.800 |
it's a d by d li of this expression here, where we have l indexed with the specific label y, 01:29:55.680 |
and on the bottom, we have a sum over j of e to the lj and the negative log of all that. 01:29:59.680 |
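In symbols, a rough restatement of what is on screen for a single example with label y is:

    \text{loss} = -\log p_y, \qquad p_i = \frac{e^{l_i}}{\sum_j e^{l_j}}, \qquad \text{so} \qquad \text{loss} = -\log \frac{e^{l_y}}{\sum_j e^{l_j}},

and the quantity we want is \frac{\partial \text{loss}}{\partial l_i}.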
So potentially, give it a shot, pen and paper, and see if you can actually derive the expression 01:30:04.720 |
for the loss by d li. And then we're going to implement it here. Okay, so I'm going to give 01:30:09.920 |
away the result here. So this is some of the math I did to derive the gradients analytically. And so 01:30:17.120 |
we see here that I'm just applying the rules of calculus from your first or second year of 01:30:20.800 |
bachelor's degree, if you took it. And we see that the expressions actually simplify quite a bit. 01:30:25.920 |
You have to separate out the analysis in the case where the ith index that you're interested in 01:30:30.480 |
inside logits is either equal to the label or it's not equal to the label. And then the expressions 01:30:35.600 |
simplify and cancel in a slightly different way. And what we end up with is something very, 01:30:39.920 |
very simple. We either end up with basically p at i, where p is again this vector of probabilities 01:30:47.120 |
after a softmax, or p at i minus one, where we just simply subtract to one. But in any case, 01:30:52.800 |
we just need to calculate the softmax p, and then in the correct dimension, we need to subtract to 01:30:58.400 |
one. And that's the gradient, the form that it takes analytically. So let's implement this, 01:31:03.280 |
basically. And we have to keep in mind that this is only done for a single example. But here we 01:31:07.680 |
are working with batches of examples. So we have to be careful of that. And then the loss for a 01:31:13.200 |
batch is the average loss over all the examples. So in other words, 01:31:18.160 |
it is the loss for each individual example summed up and then divided by n. 01:31:23.840 |
And we have to back propagate through that as well and be careful with it. So d logits is going to be 01:31:29.760 |
f dot softmax. PyTorch has a softmax function that you can call. And we want to apply the softmax on 01:31:37.120 |
the logits. And we want to go in the dimension that is one. So basically, we want to do the 01:31:43.440 |
softmax along the rows of these logits. Then at the correct positions, we need to subtract a one. 01:31:49.840 |
So d logits at iterating over all the rows and indexing into the columns provided by the 01:31:58.320 |
correct labels inside yb, we need to subtract one. And then finally, it's the average loss that is 01:32:05.520 |
the loss. And in the average, there's a one over n of all the losses added up. And so we need to 01:32:11.440 |
also back propagate through that division. So the gradient has to be scaled down by n as well, 01:32:17.440 |
because of the mean. But this otherwise should be the result. So now if we verify this, 01:32:24.720 |
we see that we don't get an exact match. But at the same time, the maximum difference between dlogits 01:32:31.600 |
from PyTorch and our dlogits here is on the order of 5e-9. So it's a tiny, tiny number. 01:32:38.960 |
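For reference, the whole fast backward pass for the loss fits in a few lines, roughly like this, assuming torch.nn.functional is imported as F and Yb holds the labels (the yb tensor):

    dlogits = F.softmax(logits, 1)     # softmax along the rows, shape (32, 27)
    dlogits[range(n), Yb] -= 1         # subtract 1 at the positions of the correct labels
    dlogits /= n                       # account for the 1/n in the mean over the batch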
So because of floating point wonkiness, we don't get the exact bitwise result, 01:32:44.320 |
but we basically get the correct answer approximately. Now I'd like to pause here 01:32:50.400 |
briefly before we move on to the next exercise, because I'd like us to get an intuitive sense 01:32:54.480 |
of what d logits is, because it has a beautiful and very simple explanation, honestly. So here, 01:33:01.200 |
I'm taking d logits, and I'm visualizing it. And we can see that we have a batch of 32 examples 01:33:06.480 |
of 27 characters. And what is d logits intuitively, right? d logits is basically the 01:33:13.600 |
probabilities matrix from the forward pass. But then here, these black squares are the positions of the 01:33:18.240 |
correct indices, where we subtracted a 1. And so what is this doing, right? These are the derivatives 01:33:25.280 |
on d logits. And so let's look at just the first row here. So that's what I'm doing here. I'm 01:33:33.040 |
calculating the probabilities of these logits, and then I'm taking just the first row. And this is 01:33:37.600 |
the probability row. And then d logits of the first row, and multiplying by n just for us so that 01:33:44.320 |
we don't have the scaling by n in here, and everything is more interpretable. 01:33:47.440 |
We see that it's exactly equal to the probability, of course, but then the position of the correct 01:33:52.800 |
index has a minus equals 1. So minus 1 on that position. And so notice that if you take d logits 01:34:00.160 |
at 0, and you sum it, it actually sums to 0. And so you should think of these gradients here at 01:34:09.120 |
each cell as like a force. We are going to be basically pulling down on the probabilities 01:34:17.280 |
of the incorrect characters, and we're going to be pulling up on the probability 01:34:21.120 |
at the correct index. And that's what's basically happening in each row. And the amount of push and 01:34:30.160 |
pull is exactly equalized, because the sum is 0. So the amount to which we pull down on the 01:34:35.840 |
probabilities, and the amount that we push up on the probability of the correct character is equal. 01:34:40.480 |
So the repulsion and the attraction are equal. And think of the neural net now as a massive 01:34:47.200 |
pulley system or something like that. We're up here on top of d logits, 01:34:52.800 |
and we're pulling down the probabilities of the incorrect characters and pulling up the probability of the correct one. 01:34:56.480 |
And in this complicated pulley system, because everything is mathematically just determined, 01:35:01.920 |
just think of it as sort of like this tension translating to this complicated pulley mechanism. 01:35:06.720 |
And then eventually we get a tug on the weights and the biases. And basically in each update, 01:35:11.600 |
we just kind of like tug in the direction that we like for each of these elements, 01:35:15.600 |
and the parameters are slowly giving in to the tug. And that's what training a neural net 01:35:20.000 |
kind of like looks like on a high level. And so I think the forces of push and pull in these 01:35:25.280 |
gradients are actually very intuitive here. We're pushing and pulling on the correct answer and the 01:35:30.720 |
incorrect answers. And the amount of force that we're applying is actually proportional to the 01:35:36.400 |
probabilities that came out in the forward pass. And so for example, if our probabilities came out 01:35:41.520 |
exactly correct, so they would have had zero everywhere except for one at the correct position, 01:35:47.680 |
then the d logits would be a row of all zeros for that example. There would be no push and pull. 01:35:53.440 |
So the amount to which your prediction is incorrect is exactly the amount by which you're 01:35:58.640 |
going to get a pull or a push in that dimension. So if you have, for example, a very confidently 01:36:04.240 |
mispredicted element here, then what's going to happen is that element is going to be pulled down 01:36:10.080 |
very heavily, and the correct answer is going to be pulled up to the same amount. And the other 01:36:15.600 |
characters are not going to be influenced too much. So the amount to which you mispredict is 01:36:21.040 |
then proportional to the strength of the pull. And that's happening independently in all the 01:36:26.240 |
dimensions of this tensor. And it's sort of very intuitive and very easy to think through. And 01:36:31.440 |
that's basically the magic of the cross-entropy loss and what it's doing dynamically in the 01:36:35.920 |
backward pass of the neural net. So now we get to exercise number three, which is a very fun exercise, 01:36:41.600 |
depending on your definition of fun. And we are going to do for batch normalization exactly what 01:36:46.560 |
we did for cross-entropy loss in exercise number two. That is, we are going to consider it as a 01:36:51.280 |
glued single mathematical expression and backpropagate through it in a very efficient manner, 01:36:56.000 |
because we are going to derive a much simpler formula for the backward pass of batch normalization. 01:37:00.320 |
And we're going to do that using pen and paper. So previously, we've broken up batch normalization 01:37:05.760 |
into all of the little intermediate pieces and all the atomic operations inside it, and then we 01:37:09.920 |
backpropagated through it one by one. Now we just have a single sort of forward pass of a batch norm, 01:37:17.760 |
and it's all glued together, and we see that we get the exact same result as before. 01:37:22.320 |
Now for the backward pass, we'd like to also implement a single formula basically for 01:37:27.920 |
backpropagating through this entire operation, that is the batch normalization. 01:37:31.360 |
So in the forward pass previously, we took hprebn, the hidden states of the pre-batch normalization, 01:37:38.720 |
and created hpreact, which is the hidden states just before the activation. In the batch 01:37:44.320 |
normalization paper, hprebn is x and hpreact is y. So in the backward pass, what we'd like to do now 01:37:51.280 |
is we have dhpreact, and we'd like to produce dhprebn, and we'd like to do that in a very 01:37:57.760 |
efficient manner. So that's the name of the game, calculate dhprebn given dhpreact. And for the 01:38:04.320 |
purposes of this exercise, we're going to ignore gamma and beta and their derivatives, because they 01:38:09.440 |
take on a very simple form in a very similar way to what we did up above. So let's calculate this 01:38:15.920 |
given that right here. So to help you a little bit like I did before, I started off the 01:38:22.800 |
implementation here on pen and paper, and I took two sheets of paper to derive the mathematical 01:38:28.240 |
formulas for the backward pass. And basically to set up the problem, just write out the mu, 01:38:34.640 |
sigma square, variance, xi hat, and yi, exactly as in the paper, except for the Bessel correction. 01:38:40.880 |
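For reference, the setup written out, as in the paper but with the Bessel-corrected variance, is roughly:

    \mu = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
    \sigma^2 = \frac{1}{m-1}\sum_{i=1}^{m} (x_i - \mu)^2, \qquad
    \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad
    y_i = \gamma \hat{x}_i + \beta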
And then in the backward pass, we have the derivative of the loss with respect to all the 01:38:46.160 |
elements of y. And remember that y is a vector. There's multiple numbers here. So we have all 01:38:53.520 |
of the derivatives with respect to all the y's. And then there's a gamma and a beta, and this is 01:38:59.200 |
kind of like the compute graph. The gamma and the beta, there's the x hat, and then the mu and the 01:39:04.480 |
sigma squared, and the x. So we have dl by dyi, and we want dl by dxi for all the i's in these 01:39:13.360 |
vectors. So this is the compute graph, and you have to be careful because I'm trying to note here 01:39:20.320 |
that these are vectors. There's many nodes here inside x, x hat, and y, but mu and sigma, sorry, 01:39:28.480 |
sigma square are just individual scalars, single numbers. So you have to be careful with that. You 01:39:33.600 |
have to imagine there's multiple nodes here, or you're going to get your math wrong. 01:39:36.400 |
So as an example, I would suggest that you go in the following order, one, two, three, four, 01:39:43.920 |
in terms of the backpropagation. So backpropagate into x hat, then into sigma square, then into mu, 01:39:49.280 |
and then into x. Just like in a topological sort in micrograd, we would go from right to left. 01:39:56.080 |
You're doing the exact same thing, except you're doing it with symbols and on a piece of paper. 01:40:00.240 |
So for number one, I'm not giving away too much. If you want dl by dxi hat, then we just take dl 01:40:11.200 |
by dyi and multiply it by gamma, because of this expression here, where any individual yi is just 01:40:17.440 |
gamma times x i hat plus beta. So it didn't help you too much there, but this gives you basically 01:40:24.320 |
the derivatives for all the x hats. And so now, try to go through this computational graph and 01:40:31.360 |
derive what is dl by d sigma square, and then what is dl by d mu, and then what is dl by dx, 01:40:38.960 |
eventually. So give it a go, and I'm going to be revealing the answer one piece at a time. 01:40:44.080 |
Okay, so to get dl by d sigma square, we have to remember again, like I mentioned, that there are 01:40:49.440 |
many x hats here. And remember that sigma square is just a single individual number here. So when 01:40:57.120 |
we look at the expression for dl by d sigma square, we have to actually consider all the possible 01:41:04.400 |
paths that we basically have that there's many x hats, and they all depend on sigma square. 01:41:13.920 |
So sigma square has a large fan out. There's lots of arrows coming out from sigma square into all 01:41:19.040 |
the x hats. And then there's a backpropagating signal from each x hat into sigma square. And 01:41:25.360 |
that's why we actually need to sum over all those i's from i equal to one to m of the dl by d xi hat, 01:41:34.720 |
which is the global gradient, times d xi hat by d sigma square, which is the local gradient 01:41:41.280 |
of this operation here. And then mathematically, I'm just working it out here, and I'm simplifying, 01:41:47.920 |
and you get a certain expression for dl by d sigma square. And we're going to be using this 01:41:52.560 |
expression when we backpropagate into mu, and then eventually into x. So now let's continue 01:41:57.200 |
our backpropagation into mu. So what is dl by d mu? Now again, be careful that mu influences x hat, 01:42:04.320 |
and x hat is actually lots of values. So for example, if our mini-batch size is 32, as it is 01:42:09.520 |
in our example that we were working on, then this is 32 numbers and 32 arrows going back to mu. 01:42:15.840 |
And then mu going to sigma square is just a single arrow, because sigma square is a scalar. 01:42:19.840 |
So in total, there are 33 arrows emanating from mu, and then all of them have gradients coming 01:42:26.160 |
into mu, and they all need to be summed up. And so that's why when we look at the expression for dl 01:42:32.400 |
by d mu, I am summing up over all the gradients of dl by d xi hat times d xi hat by d mu. So that's 01:42:41.040 |
the that's this arrow, and that's 32 arrows here, and then plus the one arrow from here, which is dl 01:42:46.720 |
by d sigma square times d sigma square by d mu. So now we have to work out that expression, and 01:42:52.880 |
let me just reveal the rest of it. Simplifying the first term here is not complicated, and you 01:42:59.280 |
just get an expression here. For the second term though, there's something really interesting that 01:43:03.120 |
happens. When we look at d sigma square by d mu and we simplify, at one point if we assume that 01:43:11.120 |
in a special case where mu is actually the average of xi's, as it is in this case, then if we plug 01:43:18.800 |
that in, then actually the gradient vanishes and becomes exactly zero. And that makes the entire 01:43:24.480 |
second term cancel. And so, if you just have a mathematical expression like this, and you look 01:43:30.640 |
at d sigma square by d mu, you would get some mathematical formula for how mu impacts sigma 01:43:36.720 |
square. But if it is the special case that mu is actually equal to the average, as it is in the 01:43:41.920 |
case of batch normalization, that gradient will actually vanish and become zero. So the whole 01:43:46.800 |
term cancels, and we just get a fairly straightforward expression here for dl by d mu. 01:43:51.600 |
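As I reconstruct them, the two intermediate results take roughly this form, with the second term of dl by d mu already vanished:

    \frac{\partial L}{\partial \sigma^2} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,(x_i - \mu)\left(-\tfrac{1}{2}\right)(\sigma^2 + \epsilon)^{-3/2}, \qquad
    \frac{\partial L}{\partial \mu} = -(\sigma^2 + \epsilon)^{-1/2}\sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}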
Okay, and now we get to the craziest part, which is deriving dl by d xi, which is ultimately what 01:43:58.000 |
we're after. Now let's count, first of all, how many numbers are there inside x? As I mentioned, 01:44:04.320 |
there are 32 numbers. There are 32 little xi's. And let's count the number of arrows emanating 01:44:09.440 |
from each xi. There's an arrow going to mu, an arrow going to sigma square, and then there's 01:44:15.280 |
an arrow going to x hat. But this arrow here, let's scrutinize that a little bit. Each xi hat 01:44:21.440 |
is just a function of xi and all the other scalars. So xi hat only depends on xi and none of the other 01:44:29.200 |
x's. And so therefore, there are actually, in this single arrow, there are 32 arrows. But those 32 01:44:35.440 |
arrows are going exactly parallel. They don't interfere. They're just going parallel between x 01:44:40.720 |
and x hat. You can look at it that way. And so how many arrows are emanating from each xi? There 01:44:45.360 |
are three arrows, mu, sigma square, and the associated x hat. And so in backpropagation, 01:44:52.400 |
we now need to apply the chain rule, and we need to add up those three contributions. 01:44:57.120 |
So here's what that looks like if I just write that out. 01:45:00.080 |
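Roughly, the chain rule being applied for each element is:

    \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\,\frac{\partial \hat{x}_i}{\partial x_i} + \frac{\partial L}{\partial \mu}\,\frac{\partial \mu}{\partial x_i} + \frac{\partial L}{\partial \sigma^2}\,\frac{\partial \sigma^2}{\partial x_i}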
We're chaining through mu, sigma square, and through x hat. And those three terms are just 01:45:09.120 |
here. Now, we already have three of these. We have dl by dx i hat. We have dl by d mu, 01:45:16.880 |
which we derived here. And we have dl by d sigma square, which we derived here. But we need three 01:45:22.080 |
other terms here. This one, this one, and this one. So I invite you to try to derive them. It's 01:45:28.240 |
not that complicated. You're just looking at these expressions here and differentiating with respect 01:45:32.240 |
to xi. So give it a shot, but here's the result, or at least what I got. I'm just differentiating 01:45:43.920 |
with respect to xi for all of these expressions. And honestly, I don't think there's anything too 01:45:47.520 |
tricky here. It's basic calculus. Now, what gets a little bit more tricky is we are now going to 01:45:52.960 |
plug everything together. So all of these terms multiplied with all of these terms and add it up 01:45:57.680 |
according to this formula. And that gets a little bit hairy. So what ends up happening is 01:46:02.240 |
you get a large expression. And the thing to be very careful with here, of course, is 01:46:10.320 |
we are working with a dl by d xi for a specific i here. But when we are plugging in some of these 01:46:16.080 |
terms, like say this term here, dl by d sigma squared, you see how dl by d sigma squared, 01:46:24.400 |
I end up with an expression. And I'm iterating over little i's here. But I can't use i as the 01:46:30.640 |
variable when I plug in here, because this is a different i from this i. This i here is just a 01:46:36.240 |
placeholder, like a local variable for a for loop in here. So here, when I plug that in, you notice 01:46:41.840 |
that I renamed the i to a j, because I need to make sure that this j is not this i. This j is 01:46:48.640 |
like a little local iterator over 32 terms. And so you have to be careful with that when you're 01:46:53.920 |
plugging in the expressions from here to here. You may have to rename i's into j's. And you have 01:46:58.240 |
to be very careful what is actually an i with respect to dl by d xi. So some of these are j's, 01:47:05.680 |
some of these are i's. And then we simplify this expression. And I guess the big thing to notice 01:47:13.520 |
here is that a bunch of terms just kind of come out to the front, and you can refactor them. There's 01:47:17.840 |
a sigma squared plus epsilon raised to the power of negative 3 over 2. This sigma squared plus 01:47:22.080 |
epsilon can actually be separated out into three terms. Each of them is sigma squared plus epsilon 01:47:28.000 |
to the negative 1 over 2. So the three of them multiplied is equal to this. And then those three 01:47:34.000 |
terms can go different places because of the multiplication. So one of them actually comes 01:47:38.480 |
out to the front and will end up here outside. One of them joins up with this term, and one of 01:47:45.200 |
them joins up with this other term. And then when you simplify the expression, you'll notice that 01:47:50.400 |
some of these terms that are coming out are just the xi hats. So you can simplify just by rewriting 01:47:55.920 |
that. And what we end up with at the end is a fairly simple mathematical expression over here 01:48:00.800 |
that I cannot simplify further. But basically, you'll notice that it only uses the stuff we have 01:48:06.080 |
and it derives the thing we need. So we have dl by dy for all the i's, and those are used plenty 01:48:13.680 |
of times here. And also in addition, what we're using is these xi hats and xj hats, and they just 01:48:18.720 |
come from the forward pass. And otherwise, this is a simple expression, and it gives us dl by dxi 01:48:25.520 |
for all the i's, and that's ultimately what we're interested in. So that's the end of BatchNorm 01:48:31.440 |
backward pass analytically. Let's now implement this final result. Okay, so I implemented the 01:48:37.680 |
expression into a single line of code here, and you can see that the max diff is tiny, 01:48:43.120 |
so this is the correct implementation of this formula. Now, I'll just basically tell you that 01:48:49.760 |
getting this formula here from this mathematical expression was not trivial, and there's a lot 01:48:54.480 |
going on packed into this one formula. And this is a whole exercise by itself, because you have 01:49:00.000 |
to consider the fact that this formula here is just for a single neuron and a batch of 32 examples. 01:49:05.440 |
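Plugging the local derivatives $\partial \hat{x}_i/\partial x_i = (\sigma^2+\epsilon)^{-1/2}$, $\partial \mu/\partial x_i = 1/m$, and $\partial \sigma^2/\partial x_i = 2(x_i-\mu)/(m-1)$ into the chain-rule sum and simplifying gives, per neuron and for a batch of size $m$ (again a reconstruction, with $g_j = \partial L/\partial y_j$ the incoming gradients and Bessel's correction in $\sigma^2$):

$$
\frac{\partial L}{\partial x_i}
\;=\;
\frac{\gamma\,(\sigma^2+\epsilon)^{-1/2}}{m}
\left[\, m\,g_i \;-\; \sum_{j=1}^{m} g_j \;-\; \frac{m}{m-1}\,\hat{x}_i \sum_{j=1}^{m} g_j\,\hat{x}_j \right]
$$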
But what I'm doing here is we actually have 64 neurons, and so this expression has to 01:49:11.920 |
evaluate the BatchNorm backward pass for all of those 64 neurons in parallel and independently. 01:49:16.720 |
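Here is a minimal, self-contained sketch of that kind of single-line implementation, checked against autograd on a toy batch; the variable names (`hprebn`, `bnraw`, `bnvar_inv`, `bngain`, and so on) are assumptions modeled on the lecture's code, not the exact notebook.

```python
import torch

# Toy batchnorm forward pass (with Bessel's correction, dividing by n-1),
# then the single-line manual backward pass, checked against autograd.
n, d = 32, 64                                   # batch size, number of neurons
hprebn = torch.randn(n, d, requires_grad=True)  # batchnorm input
bngain = torch.randn(1, d, requires_grad=True)  # gamma
bnbias = torch.randn(1, d, requires_grad=True)  # beta

bnmean = hprebn.mean(0, keepdim=True)
bnvar = ((hprebn - bnmean)**2).sum(0, keepdim=True) / (n - 1)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = (hprebn - bnmean) * bnvar_inv
hpreact = bngain * bnraw + bnbias               # batchnorm output

loss = hpreact.sum()                            # some scalar loss for the check
loss.backward()

dhpreact = torch.ones_like(hpreact)             # dloss/dhpreact for this toy loss
# the one-line batchnorm backward pass, broadcast over all d columns at once:
dhprebn = bngain * bnvar_inv / n * (n * dhpreact - dhpreact.sum(0)
                                    - n / (n - 1) * bnraw * (dhpreact * bnraw).sum(0))

print('max diff:', (dhprebn - hprebn.grad).abs().max().item())  # should be tiny
```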
So this has to happen basically in every single column of the inputs here. And in addition to 01:49:25.200 |
that, you see how there are a bunch of sums here, and we need to make sure that when I do those 01:49:29.280 |
sums that they broadcast correctly onto everything else that's here. And so getting this expression 01:49:34.800 |
is just like highly non-trivial, and I invite you to basically look through it and step through it, 01:49:38.080 |
and it's a whole exercise to make sure that this checks out. But once all the shapes agree, 01:49:44.800 |
and once you convince yourself that it's correct, you can also verify that PyTorch gets the exact 01:49:48.800 |
same answer as well. And so that gives you a lot of peace of mind that this mathematical formula 01:49:53.520 |
is correctly implemented here and broadcasted correctly and replicated in parallel for all 01:49:58.720 |
of the 64 neurons inside this BatchNorm layer. Okay, and finally, exercise number four asks you 01:50:05.280 |
to put it all together. And here we have a redefinition of the entire problem. So you see 01:50:10.320 |
that we re-initialized the neural net from scratch and everything. And then here, instead of calling 01:50:15.280 |
loss.backward, we want to have the manual backpropagation here as we derived it up above. 01:50:20.160 |
So go up, copy-paste all the chunks of code that we've already derived, put them here, and derive 01:50:26.000 |
your own gradients, and then optimize this neural net basically using your own gradients all the way 01:50:31.200 |
to the calibration of the BatchNorm and the evaluation of the loss. And I was able to 01:50:35.840 |
achieve quite a good loss, basically the same loss you would achieve before. And that shouldn't be 01:50:40.240 |
surprising because all we've done is we've really gotten into loss.backward, and we've pulled 01:50:45.040 |
out all the code and inserted it here. But those gradients are identical, and everything is 01:50:50.160 |
identical, and the results are identical. It's just that we have full visibility on exactly what 01:50:55.040 |
goes on under the hood of loss.backward in this specific case. Okay, and this is all of our 01:51:01.200 |
code. This is the full backward pass using basically the simplified backward pass for 01:51:05.760 |
the cross-entropy loss and the BatchNormalization. So backpropagating through cross-entropy, 01:51:11.120 |
the second layer, the tanh nonlinearity, the BatchNormalization, through the first layer, 01:51:18.480 |
and through the embedding. And so you see that this is only maybe, what is this, 20 lines of 01:51:22.880 |
code or something like that? And that's what gives us gradients. And now we can potentially 01:51:27.600 |
erase loss.backward. So the way I have the code set up is you should be able to run this 01:51:32.560 |
entire cell once you fill this in, and this will run for only 100 iterations and then break. 01:51:37.280 |
And it breaks because it gives you an opportunity to check your gradients against PyTorch. 01:51:41.760 |
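As a small self-contained example in the spirit of this kind of check, here is the simplified cross-entropy backward pass compared against PyTorch's own gradient; the sizes and names (`n`, `vocab_size`, `logits`, `Yb`) are made up for illustration.

```python
import torch
import torch.nn.functional as F

n, vocab_size = 32, 27
logits = torch.randn(n, vocab_size, requires_grad=True)
Yb = torch.randint(0, vocab_size, (n,))          # hypothetical targets

loss = F.cross_entropy(logits, Yb)
loss.backward()                                  # PyTorch's gradient, for comparison

dlogits = F.softmax(logits, 1).detach()          # manual backward: softmax, minus one-hot, over n
dlogits[range(n), Yb] -= 1
dlogits /= n

maxdiff = (dlogits - logits.grad).abs().max().item()
print(f'approximate: {torch.allclose(dlogits, logits.grad)} | maxdiff: {maxdiff}')
```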
So here, we see our gradients are not exactly equal. They are approximately equal, and the 01:51:49.600 |
differences are tiny, 1e-9 or so. And I don't exactly know where they're coming from, 01:51:54.240 |
to be honest. So once we have some confidence that the gradients are basically correct, 01:51:58.320 |
we can take out the gradient checking. We can disable this breaking statement. 01:52:04.400 |
And then we can basically disable loss.backward. We don't need it anymore. 01:52:11.200 |
Feels amazing to say that. And then here, when we are doing the update, we're not going to use 01:52:17.520 |
p.grad. That's the old PyTorch way. We don't have that anymore because we're not calling backward. 01:52:23.360 |
We are going to use this update where, you see, I've arranged the grads 01:52:30.080 |
to be in the same order as the parameters, and I'm zipping them up, the gradients and the parameters, 01:52:35.120 |
into p and grad. And then here, I'm going to step with just the grad that we derived manually. 01:52:40.880 |
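A minimal sketch of that update step; `parameters` and `grads` below are toy stand-ins for the real lists (arranged in the same order), and `lr` is an assumed learning rate.

```python
import torch

# toy stand-ins for the real parameter and gradient lists, in matching order
parameters = [torch.randn(10, 5), torch.randn(5)]
grads = [torch.randn(10, 5), torch.randn(5)]
lr = 0.1

for p, grad in zip(parameters, grads):
    p.data += -lr * grad   # step directly with the manually derived gradient
```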
So the last piece is that none of this now requires gradients from PyTorch. And so one 01:52:49.280 |
thing you can do here is you can use with torch.no_grad() and indent this whole code block under it. 01:52:57.200 |
And really what you're saying is you're telling PyTorch that, "Hey, I'm not going to call backward 01:53:01.200 |
on any of this." And this allows PyTorch to be a bit more efficient with all of it. 01:53:05.520 |
And then we should be able to just run this. And it's running. And you see that 01:53:15.680 |
loss.backward is commented out and we're optimizing. So we're going to leave this run, 01:53:22.560 |
and hopefully we get a good result. Okay, so I allowed the neural net to finish optimization. 01:53:27.920 |
Then here, I calibrate the batch norm parameters because I did not keep track of the running mean 01:53:35.040 |
and variance in the training loop. Then here, I evaluated the loss. And you see that we actually obtained a 01:53:40.400 |
pretty good loss, very similar to what we've achieved before. And then here, I'm sampling 01:53:44.880 |
from the model. And we see some of the name-like gibberish that we're sort of used to. So basically, 01:53:49.760 |
the model worked and samples pretty decent results compared to what we were used to. 01:53:55.520 |
So everything is the same. But of course, the big deal is that we did not use loss.backward. 01:54:00.080 |
We did not use PyTorch autograd, and we estimated our gradients ourselves by hand. 01:54:04.400 |
And so hopefully, you're looking at this, the backward pass of this neural net, 01:54:08.320 |
and you're thinking to yourself, actually, that's not too complicated. Each one of these layers is 01:54:14.480 |
like three lines of code or something like that. And most of it is fairly straightforward, 01:54:18.960 |
potentially with the notable exception of the batch normalization backward pass. 01:54:22.800 |
Otherwise, it's pretty good. Okay, and that's everything I wanted to cover for this lecture. 01:54:27.440 |
So hopefully, you found this interesting. And what I liked about it, honestly, is that it gave us a 01:54:32.400 |
very nice diversity of layers to backpropagate through. And I think it gives a pretty nice 01:54:38.000 |
and comprehensive sense of how these backward passes are implemented and how they work. 01:54:42.240 |
And you'd be able to derive them yourself. But of course, in practice, you probably don't want to, 01:54:46.240 |
and you want to use the PyTorch autograd. But hopefully, you have some intuition about how 01:54:50.240 |
gradients flow backwards through the neural net, starting at the loss, and how they flow through 01:54:55.200 |
all the variables and all the intermediate results. And if you understood a good chunk of it, and if 01:55:00.640 |
you have a sense of that, then you can count yourself as one of these buff doges on the left, 01:55:04.720 |
instead of the doges on the right here. Now, in the next lecture, we're actually going to go to 01:55:10.400 |
recurrent neural nets, LSTMs, and all the other variants of RNNs. And we're going to start to 01:55:16.240 |
complexify the architecture and start to achieve better log likelihoods. And so I'm really looking