
Building makemore Part 4: Becoming a Backprop Ninja


Chapters

0:00 intro: why you should care & fun history
7:26 starter code
13:01 exercise 1: backproping the atomic compute graph
65:17 brief digression: bessel’s correction in batchnorm
86:31 exercise 2: cross entropy loss backward pass
96:37 exercise 3: batch norm layer backward pass
110:02 exercise 4: putting it all together
114:24 outro

Whisper Transcript | Transcript Only Page

00:00:00.000 | Hi everyone. So today we are once again continuing our implementation of Makemore.
00:00:04.240 | Now so far we've come up to here, multilayer perceptrons, and our neural net looked like this,
00:00:10.720 | and we were implementing this over the last few lectures. Now I'm sure everyone is very excited
00:00:14.880 | to go into recurrent neural networks and all of their variants and how they work, and the diagrams
00:00:19.360 | look cool and it's very exciting and interesting and we're going to get a better result, but
00:00:22.880 | unfortunately I think we have to remain here for one more lecture. And the reason for that is we've
00:00:28.880 | already trained this multilayer perceptron, right, and we are getting pretty good loss, and I think
00:00:32.960 | we have a pretty decent understanding of the architecture and how it works, but the line of
00:00:37.840 | code here that I take an issue with is here, loss.backward. That is, we are taking PyTorch
00:00:43.760 | Autograd and using it to calculate all of our gradients along the way, and I would like to
00:00:48.400 | remove the use of loss.backward, and I would like us to write our backward pass manually on
00:00:52.800 | the level of tensors. And I think that this is a very useful exercise for the following reasons.
00:00:58.720 | I actually have an entire blog post on this topic, but I'd like to call backpropagation a
00:01:03.520 | leaky abstraction. And what I mean by that is backpropagation doesn't just make your neural
00:01:09.120 | networks just work magically. It's not the case that you can just stack up arbitrary Lego blocks
00:01:13.440 | of differentiable functions and just cross your fingers and backpropagate and everything is great.
00:01:17.680 | Things don't just work automatically. It is a leaky abstraction in the sense that
00:01:22.880 | you can shoot yourself in the foot if you do not understand its internals.
00:01:26.640 | It will magically not work or not work optimally, and you will need to understand how it works
00:01:32.240 | under the hood if you're hoping to debug it and if you are hoping to address it in your neural net.
00:01:36.480 | So this blog post here from a while ago goes into some of those examples. So for example,
00:01:42.560 | we've already covered them, some of them already. For example, the flat tails of these functions
00:01:48.320 | and how you do not want to saturate them too much because your gradients will die.
00:01:53.120 | The case of dead neurons, which I've already covered as well. The case of exploding or
00:01:57.840 | vanishing gradients in the case of recurrent neural networks, which we are about to cover.
00:02:01.680 | And then also, you will often come across some examples in the wild. This is a snippet that I
00:02:08.480 | found in a random code base on the internet where they actually have a very subtle but pretty major
00:02:14.400 | bug in their implementation. And the bug points at the fact that the author of this code does not
00:02:20.000 | actually understand backpropagation. So what they're trying to do here is they're trying to
00:02:23.360 | clip the loss at a certain maximum value. But actually what they're trying to do is they're
00:02:28.000 | trying to clip the gradients to have a maximum value instead of trying to clip the loss at a
00:02:32.000 | maximum value. And indirectly, they're basically causing some of the outliers to be actually
00:02:38.400 | ignored because when you clip a loss of an outlier, you are setting its gradient to zero.
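To make the failure mode concrete, here is a small hypothetical PyTorch sketch (not the snippet from the video) showing how clamping a per-example loss silently zeroes the gradient of an outlier, whereas clipping the gradients would not:

```python
import torch

x = torch.tensor([1.0, 2.0, 50.0], requires_grad=True)  # 50.0 stands in for an outlier
per_example_loss = x**2                                  # stand-in for a per-example loss
clipped = per_example_loss.clamp(max=10.0)               # "clipping the loss" at a maximum value
clipped.sum().backward()
print(x.grad)  # prints tensor([2., 4., 0.]): the outlier's gradient is zero, so it is ignored
# what is presumably intended instead is clipping the gradients, e.g.
# torch.nn.utils.clip_grad_norm_([x], max_norm=1.0) after backward()
```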
00:02:44.080 | And so have a look through this and read through it. But there's basically a bunch of subtle issues
00:02:49.840 | that you're going to avoid if you actually know what you're doing. And that's why I don't think
00:02:53.600 | it's the case that because PyTorch or other frameworks offer autograd, it is okay for us
00:02:58.240 | to ignore how it works. Now, we've actually already covered autograd and we wrote micrograd.
00:03:04.640 | But micrograd was an autograd engine only on the level of individual scalars. So the atoms were
00:03:10.160 | single individual numbers. And I don't think it's enough. And I'd like us to basically think about
00:03:15.120 | backpropagation on the level of tensors as well. And so in a summary, I think it's a good exercise.
00:03:20.240 | I think it is very, very valuable. You're going to become better at debugging neural networks and
00:03:25.520 | making sure that you understand what you're doing. It is going to make everything fully explicit. So
00:03:29.920 | you're not going to be nervous about what is hidden away from you. And basically, in general,
00:03:34.000 | we're going to emerge stronger. And so let's get into it. A bit of a fun historical note here is
00:03:39.760 | that today, writing your backward pass by hand and manually is not recommended and no one does it
00:03:44.560 | except for the purposes of exercise. But about 10 years ago in deep learning, this was fairly
00:03:49.360 | standard and in fact pervasive. So at the time, everyone used to write their own backward pass
00:03:53.680 | by hand manually, including myself. And it's just what you would do. So we used to write backward
00:03:58.800 | pass by hand. And now everyone just calls loss.backward. We've lost something. I wanted to give
00:04:04.400 | you a few examples of this. So here's a 2006 paper from Geoff Hinton and Ruslan Salakhutdinov
00:04:12.560 | in science that was influential at the time. And this was training some architectures called
00:04:17.840 | restricted Boltzmann machines. And basically, it's an autoencoder trained here. And this is from
00:04:24.160 | roughly 2010. I had a library for training restricted Boltzmann machines. And this was
00:04:29.680 | at the time written in MATLAB. So Python was not used for deep learning pervasively. It was all
00:04:34.080 | MATLAB. And MATLAB was this scientific computing package that everyone would use. So we would
00:04:40.240 | write MATLAB, which is barely a programming language as well. But it had a very convenient
00:04:45.680 | tensor class. And it was this computing environment. And you would run here. It would all run on a CPU,
00:04:50.480 | of course. But you would have very nice plots to go with it and a built-in debugger. And it
00:04:54.560 | was pretty nice. Now, the code in this package in 2010 that I wrote for fitting restricted
00:05:00.800 | Boltzmann machines, to a large extent, is recognizable. But I wanted to show you how you
00:05:05.440 | would-- well, I'm creating the data and the xy batches. I'm initializing the neural net. So it's
00:05:11.520 | got weights and biases, just like we're used to. And then this is the training loop, where we
00:05:15.920 | actually do the forward pass. And then here, at this time, they didn't even necessarily use back
00:05:20.640 | propagation to train neural networks. So this, in particular, implements contrastive divergence,
00:05:25.920 | which estimates a gradient. And then here, we take that gradient and use it for a parameter update
00:05:32.000 | along the lines that we're used to. Yeah, here. But you can see that, basically, people were
00:05:37.920 | meddling with these gradients directly and inline and themselves. It wasn't that common to use an
00:05:42.720 | autograd engine. Here's one more example from a paper of mine from 2014 called Deep Fragment
00:05:48.720 | Embeddings. And here, what I was doing is I was aligning images and text. And so it's kind of
00:05:54.160 | like CLIP, if you're familiar with it. But instead of working on the level of entire images
00:05:58.240 | and entire sentences, it was working on the level of individual objects and little pieces of
00:06:02.240 | sentences. And I was embedding them and then calculating very much a CLIP-like loss.
00:06:06.480 | And I dug up the code from 2014 of how I implemented this. And it was already in NumPy
00:06:12.880 | and Python. And here, I'm implementing the cost function. And it was standard to implement not
00:06:18.880 | just the cost, but also the backward pass manually. So here, I'm calculating the image
00:06:23.680 | embeddings, sentence embeddings, the loss function. I calculate the scores. This is the loss function.
00:06:29.920 | And then once I have the loss function, I do the backward pass right here. So I backward
00:06:34.640 | through the loss function and through the neural net. And I append regularization.
00:06:38.720 | So everything was done by hand manually. And you would just write out the backward pass.
00:06:43.680 | And then you would use a gradient checker to make sure that your numerical estimate of the gradient
00:06:47.840 | agrees with the one you calculated during the backpropagation. So this was very standard for
00:06:52.080 | a long time. But today, of course, it is standard to use an autograd engine. But it was definitely
00:06:57.680 | useful. And I think people sort of understood how these neural networks work on a very intuitive
00:07:01.280 | level. And so I think it's a good exercise again. And this is where we want to be.
00:07:05.280 | OK, so just as a reminder from our previous lecture, this is the Jupyter notebook that
00:07:08.800 | we implemented at the time. And we're going to keep everything the same. So we're still going
00:07:14.080 | to have a two-layer multilayer perceptron with a batch normalization layer. So the forward pass
00:07:18.480 | will be basically identical to this lecture. But here, we're going to get rid of loss.backward.
00:07:22.800 | And instead, we're going to write the backward pass manually. Now, here's the starter code for
00:07:27.120 | this lecture. We are becoming a backprop ninja in this notebook. And the first few cells here
00:07:33.600 | are identical to what we are used to. So we are doing some imports, loading in the data set,
00:07:38.160 | and processing the data set. None of this changed. Now, here I'm introducing a utility function that
00:07:43.680 | we're going to use later to compare the gradients. So in particular, we are going to have the
00:07:47.360 | gradients that we estimate manually ourselves. And we're going to have gradients that PyTorch
00:07:51.920 | calculates. And we're going to be checking for correctness, assuming, of course, that PyTorch
00:07:56.240 | is correct. Then here, we have the initialization that we are quite used to. So we have our
00:08:02.400 | embedding table for the characters, the first layer, second layer, and a batch normalization
00:08:07.200 | in between. And here's where we create all the parameters. Now, you will note that I changed
00:08:11.920 | the initialization a little bit to be small numbers. So normally, you would set the biases
00:08:16.800 | to be all zero. Here, I am setting them to be small random numbers. And I'm doing this because
00:08:21.760 | if your variables are initialized to exactly zero, sometimes what can happen is that can mask
00:08:27.120 | an incorrect implementation of a gradient. Because when everything is zero, it sort of
00:08:32.560 | simplifies and gives you a much simpler expression of the gradient than you would otherwise get.
00:08:36.720 | And so by making it small numbers, I'm trying to unmask those potential errors in these
00:08:41.520 | calculations. You also notice that I'm using b1 in the first layer. I'm using a bias, despite
00:08:48.400 | batch normalization right afterwards. So this would typically not be what you do, because we
00:08:53.280 | talked about the fact that you don't need a bias. But I'm doing this here just for fun, because
00:08:58.000 | we're going to have a gradient with respect to it. And we can check that we are still calculating it
00:09:01.520 | correctly, even though this bias is spurious. So here, I'm calculating a single batch. And then
00:09:07.680 | here, I am doing a forward pass. Now, you'll notice that the forward pass is significantly
00:09:12.160 | expanded from what we are used to. Here, the forward pass was just here. Now, the reason that
00:09:18.560 | the forward pass is longer is for two reasons. Number one, here, we just had an F.cross_entropy call.
00:09:23.280 | But here, I am bringing back an explicit implementation of the loss function. And number
00:09:28.720 | two, I've broken up the implementation into manageable chunks. So we have a lot more
00:09:35.280 | intermediate tensors along the way in the forward pass. And that's because we are about to go
00:09:39.440 | backwards and calculate the gradients in this back propagation from the bottom to the top.
00:09:44.880 | So we're going to go upwards. And just like we have, for example, the logprobs tensor
00:09:50.080 | in a forward pass, in a backward pass, we're going to have a dlogprobs, which is going to store the
00:09:54.560 | derivative of the loss with respect to the logprobs tensor. And so we're going to be pre-pending
00:09:59.040 | d to every one of these tensors and calculating it along the way of this back propagation.
00:10:04.960 | So as an example, we have a bnraw here. We're going to be calculating a dbnraw.
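For concreteness, the cross-entropy portion of that expanded forward pass looks roughly like the sketch below; the variable names follow the starter notebook, and random stand-ins for logits, Yb, and n are used here so the snippet runs on its own:

```python
import torch

n = 32                                            # batch size
logits = torch.randn(n, 27)                       # stand-in for the output of the second linear layer
Yb = torch.randint(0, 27, (n,))                   # stand-in for the target character indices

logit_maxes = logits.max(1, keepdim=True).values  # row-wise max, for numerical stability
norm_logits = logits - logit_maxes                # subtract the max from each row
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1                   # used instead of 1.0 / counts_sum
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()             # mean negative log-probability of the correct characters
```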
00:10:09.360 | So here, I'm telling PyTorch that we want to retain the grad of all these intermediate values,
00:10:15.360 | because here in exercise one, we're going to calculate the backward pass. So we're going
00:10:19.920 | to calculate all these d variables and use the CMP function I've introduced above to check our
00:10:25.920 | correctness with respect to what PyTorch is telling us. This is going to be exercise one,
00:10:31.040 | where we sort of back propagate through this entire graph. Now, just to give you a very quick
00:10:36.080 | preview of what's going to happen in exercise two and below, here we have fully broken up the loss
00:10:42.240 | and back propagated through it manually in all the little atomic pieces that make it up.
00:10:47.040 | But here we're going to collapse the loss into a single cross entropy call.
00:10:50.720 | And instead, we're going to analytically derive using math and paper and pencil, the gradient of
00:10:58.480 | the loss with respect to the logits. And instead of back propagating through all of its little
00:11:02.320 | chunks one at a time, we're just going to analytically derive what that gradient is,
00:11:06.160 | and we're going to implement that, which is much more efficient, as we'll see in a bit.
00:11:09.840 | Then we're going to do the exact same thing for batch normalization. So instead of breaking up
00:11:14.880 | batch normalization into all the little tiny components, we're going to use pen and paper and
00:11:19.760 | mathematics and calculus to derive the gradient through the batch normal layer. So we're going to
00:11:25.920 | calculate the backward pass through batch normal layer in a much more efficient expression,
00:11:29.680 | instead of backward propagating through all of its little pieces independently.
00:11:32.720 | So that's going to be exercise three. And then in exercise four, we're going to put it all together.
00:11:38.800 | And this is the full code of training this two-layer MLP. And we're going to basically
00:11:43.680 | insert our manual backprop, and we're going to take out loss.backward. And you will basically
00:11:48.720 | see that you can get all the same results using fully your own code. And the only thing we're
00:11:55.440 | using from PyTorch is the torch.tensor to make the calculations efficient. But otherwise, you
00:12:01.360 | will understand fully what it means to forward and backward the neural net and train it. And I think
00:12:05.680 | that'll be awesome. So let's get to it. Okay, so I ran all the cells of this notebook all the way up
00:12:11.120 | to here. And I'm going to erase this. And I'm going to start implementing backward pass, starting
00:12:16.240 | with dlogprops. So we want to understand what should go here to calculate the gradient of the
00:12:21.360 | loss with respect to all the elements of the logprops tensor. Now, I'm going to give away the
00:12:26.560 | answer here. But I wanted to put a quick note here that I think will be most pedagogically useful
00:12:31.120 | for you is to actually go into the description of this video and find the link to this Jupyter
00:12:36.400 | notebook. You can find it both on GitHub, but you can also find Google Colab with it. So you don't
00:12:40.480 | have to install anything, you'll just go to a website on Google Colab. And you can try to
00:12:44.640 | implement these derivatives or gradients yourself. And then if you are not able to come to my video
00:12:50.800 | and see me do it, and so work in tandem and try it first yourself and then see me give away the
00:12:56.640 | answer. And I think that'll be most valuable to you. And that's how I recommend you go through
00:13:00.080 | this lecture. So we are starting here with dlogprobs. Now, dlogprobs will hold the derivative
00:13:07.120 | of the loss with respect to all the elements of logprobs. What is inside logprobs? The shape of
00:13:13.360 | this is 32 by 27. So it's not going to surprise you that dlogprobs should also be an array of
00:13:20.400 | size 32 by 27, because we want the derivative of the loss with respect to all of its elements.
00:13:25.200 | So the sizes of those are always going to be equal. Now, how does logprobs influence the loss?
00:13:33.040 | Loss is negative logprobs indexed with range of n and Yb and then the mean of that. Now,
00:13:41.920 | just as a reminder, yb is just basically an array of all the correct indices.
00:13:49.760 | So what we're doing here is we're taking the logprobs array of size 32 by 27.
00:13:55.120 | And then we are going in every single row. And in each row, we are plucking out the index 8 and
00:14:05.680 | then 14 and 15 and so on. So we're going down the rows. That's the iterator range of n. And
00:14:10.720 | then we are always plucking out the index at the column specified by this tensor yb. So in the
00:14:16.560 | zeroth row, we are taking the eighth column. In the first row, we're taking the 14th column, etc.
00:14:22.400 | And so logprobs indexed like this plucks out all those log probabilities of the correct next character in a
00:14:31.200 | sequence. So that's what that does. And the shape of this, or the size of it, is of course 32,
00:14:37.040 | because our batch size is 32. So these elements get plucked out, and then their mean and the
00:14:44.720 | negative of that becomes loss. So I always like to work with simpler examples to understand the
00:14:51.680 | numerical form of the derivative. What's going on here is once we've plucked out these examples,
00:14:57.040 | we're taking the mean and then the negative. So the loss basically,
00:15:03.120 | I can write it this way, is the negative of say a plus b plus c, and the mean of those three
00:15:08.960 | numbers would be say negative, would divide three. That would be how we achieve the mean of three
00:15:13.360 | numbers a, b, c, although we actually have 32 numbers here. And so what is basically the loss
00:15:19.440 | by say like da, right? Well, if we simplify this expression mathematically, this is negative 1 over
00:15:26.080 | 3 of a plus negative 1 over 3 of b plus negative 1 over 3 of c. And so what is the loss by da?
00:15:34.720 | It's just negative 1 over 3. And so you can see that if we don't just have a, b, and c,
00:15:39.280 | but we have 32 numbers, then d loss by d, you know, every one of those numbers is going to be
00:15:45.600 | 1 over n more generally, because n is the size of the batch, 32 in this case. So d loss by
00:15:55.040 | d logprobs is negative 1 over n in all these places. Now, what about the other elements
00:16:03.680 | inside logprobs? Because logprobs is a large array. You see that logprobs.shape is 32 by 27,
00:16:09.360 | but only 32 of them participate in the loss calculation. So what's the derivative of all
00:16:15.600 | the other, most of the elements that do not get plucked out here? Well, their loss intuitively
00:16:21.440 | is zero. Sorry, their gradient intuitively is zero. And that's because they do not participate
00:16:26.160 | in the loss. So most of these numbers inside this tensor do not feed into the loss. And so if we
00:16:32.320 | were to change these numbers, then the loss doesn't change, which is the equivalent of us saying that
00:16:38.320 | the derivative of the loss with respect to them is zero. They don't impact it.
00:16:41.600 | So here's a way to implement this derivative then. We start out with torch.zeros of shape 32 by 27,
00:16:50.160 | or let's just say, instead of doing this, because we don't want to hard code numbers,
00:16:54.000 | let's do torch.zeros_like(logprobs). So basically, this is going to create an array of zeros exactly
00:17:00.320 | in the shape of logprobs. And then we need to set the derivative of negative 1 over n
00:17:06.160 | inside exactly these locations. So here's what we can do. dlogprobs indexed in the identical way
00:17:14.240 | will just be set to negative 1.0 divided by n. Right, just like we derived here.
00:17:21.280 | So now let me erase all of this reasoning. And then this is the candidate derivative
00:17:28.240 | for dlogprobs. Let's uncomment the first line and check that this is correct.
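In code, the candidate gradient just described looks roughly like this, continuing with the notebook's variables (Yb holds the target indices and n is the batch size):

```python
# -1/n at the plucked-out (correct-character) positions, zero everywhere else
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
```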
00:17:32.240 | Okay, so CMP ran. And let's go back to CMP. And you see that what it's doing is it's calculating if
00:17:42.880 | the calculated value by us, which is dt, is exactly equal to t.grad as calculated by PyTorch.
00:17:48.800 | And then this is making sure that all of the elements are exactly equal, and then converting
00:17:54.400 | this to a single Boolean value, because we don't want a Boolean tensor, we just want a Boolean
00:17:58.720 | value. And then here, we are making sure that, okay, if they're not exactly equal, maybe they
00:18:04.320 | are approximately equal because of some floating point issues, but they're very, very close.
00:18:08.960 | So here we are using torch.allclose, which has a little bit of a wiggle available, because
00:18:14.160 | sometimes you can get very, very close. But if you use a slightly different calculation, because
00:18:19.040 | of floating point arithmetic, you can get a slightly different result. So this is checking
00:18:24.480 | if you get an approximately close result. And then here, we are checking the maximum,
00:18:28.880 | basically the value that has the highest difference, and what is the difference,
00:18:34.480 | and the absolute value difference between those two. And so we are printing whether we have an
00:18:38.800 | exact equality, an approximate equality, and what is the largest difference. And so here,
00:18:45.600 | we see that we actually have exact equality. And so therefore, of course, we also have an
00:18:50.480 | approximate equality, and the maximum difference is exactly zero. So basically, our dlogprobs
00:18:56.880 | is exactly equal to what PyTorch calculated to be logprobs.grad in its backpropagation.
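For reference, a comparison utility along the lines just described might look like this sketch (the cmp in the actual notebook may differ in its exact formatting):

```python
def cmp(s, dt, t):
    # compare a manually derived gradient dt against PyTorch's t.grad
    ex = torch.all(dt == t.grad).item()          # exact element-wise equality
    app = torch.allclose(dt, t.grad)             # approximate equality, allowing floating point wiggle
    maxdiff = (dt - t.grad).abs().max().item()   # largest absolute difference between the two
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')
```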
00:19:02.960 | So, so far, we're doing pretty well. Okay, so let's now continue our backpropagation.
00:19:08.640 | We have that logprobs depends on probs through a log. So all the elements of probs are being
00:19:14.080 | element-wise applied log to. Now, if we want dprobs, then remember your micrograd training.
00:19:21.280 | We have like a log node, it takes in probs and creates logprobs. And dprobs will be the local
00:19:28.880 | derivative of that individual operation, log, times the derivative of the loss with respect
00:19:34.240 | to its output, which in this case is dlogprobs. So what is the local derivative of this operation?
00:19:40.160 | Well, we are taking log element-wise, and we can come here and we can see, well, WolframAlpha
00:19:44.400 | is your friend, that d by dx of log of x is just simply one over x. So therefore, in this case,
00:19:50.640 | x is probs. So we have d by dx is one over x, which is one over probs, and then this is the
00:19:56.720 | local derivative, and then times we want to chain it. So this is chain rule, times dlogprobs.
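Written out, using the notebook's variables from the forward pass:

```python
# local derivative of log(x) is 1/x, chained with the gradient flowing in from above
dprobs = (1.0 / probs) * dlogprobs
```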
00:20:03.440 | Then let me uncomment this and let me run the cell in place. And we see that the derivative
00:20:08.880 | of probs as we calculated here is exactly correct. And so notice here how this works.
00:20:14.160 | Probs is going to be inverted and then element-wise multiplied here.
00:20:19.840 | So if your probs is very, very close to one, that means you are, your network is currently predicting
00:20:25.520 | the character correctly, then this will become one over one, and dlogprobs just gets passed through.
00:20:31.520 | But if your probabilities are incorrectly assigned, so if the correct character here
00:20:36.320 | is getting a very low probability, then 1.0 dividing by it will boost this,
00:20:42.240 | and then multiply by dlogprobs. So basically what this line is doing intuitively is it's taking
00:20:48.320 | the examples that have a very low probability currently assigned, and it's boosting their
00:20:52.720 | gradient. You can look at it that way. Next up is counts_sum_inv. So we want
00:21:00.640 | the derivative of this. Now let me just pause here and kind of introduce what's happening here
00:21:05.920 | in general, because I know it's a little bit confusing. We have the logits that come out of
00:21:09.440 | the neural net. Here what I'm doing is I'm finding the maximum in each row, and I'm subtracting it
00:21:15.440 | for the purpose of numerical stability. And we talked about how if you do not do this,
00:21:19.840 | you run into numerical issues if some of the logits take on too large values,
00:21:23.680 | because we end up exponentiating them. So this is done just for safety, numerically.
00:21:29.600 | Then here's the exponentiation of all the logits to create our counts. And then we want to take
00:21:36.640 | the sum of these counts and normalize so that all of the probs sum to 1. Now here, instead of using
00:21:42.640 | 1 over COUNTSUM, I use raised to the power of negative 1. Mathematically, they are identical.
00:21:48.080 | I just found that there's something wrong with the PyTorch implementation of the backward pass
00:21:51.520 | of division, and it gives a weird result. But that doesn't happen for **-1, so I'm using this
00:21:59.120 | formula instead. But basically, all that's happening here is we got the logits, we want
00:22:04.080 | to exponentiate all of them, and we want to normalize the counts to create our probabilities.
00:22:08.880 | It's just that it's happening across multiple lines. So now, here, we want to first take the
00:22:19.360 | derivative, we want to backpropagate into counts_sum_inv and then into counts as well.
00:22:24.320 | So what should dcounts_sum_inv be? Now, we actually have to be careful here, because
00:22:29.840 | we have to scrutinize and be careful with the shapes. So counts.shape and then counts_sum_inv.shape
00:22:37.040 | are different. So in particular, counts is 32 by 27, but this counts_sum_inv is 32 by 1.
00:22:45.360 | And so in this multiplication here, we also have an implicit broadcasting that PyTorch will do,
00:22:52.240 | because it needs to take this column tensor of 32 numbers and replicate it horizontally 27 times
00:22:57.280 | to align these two tensors so it can do an element-wise multiply.
00:23:00.720 | So really what this looks like is the following, using a toy example again.
00:23:05.040 | What we really have here is just probs is counts times counts_sum_inv, so it's C equals A times B,
00:23:10.960 | but A is 3 by 3 and B is just 3 by 1, a column tensor. And so PyTorch internally replicated
00:23:18.320 | these elements of B, and it did that across all the columns. So for example, B1, which is the
00:23:24.160 | first element of B, would be replicated here across all the columns in this multiplication.
00:23:28.240 | And now we're trying to backpropagate through this operation to counts_sum_inv.
00:23:33.040 | So when we are calculating this derivative, it's important to realize that this looks like a single
00:23:40.880 | operation, but actually is two operations applied sequentially. The first operation that PyTorch
00:23:46.480 | did is it took this column tensor and replicated it across all the columns, basically 27 times.
00:23:54.560 | So that's the first operation, it's a replication. And then the second operation is the
00:23:58.160 | multiplication. So let's first backpropagate through the multiplication. If these two arrays
00:24:04.480 | were of the same size and we just have A and B, both of them 3 by 3, then how do we backpropagate
00:24:11.600 | through a multiplication? So if we just have scalars and not tensors, then if you have C
00:24:16.080 | equals A times B, then what is the derivative of C with respect to B? Well, it's just A.
00:24:22.000 | So that's the local derivative. So here in our case, undoing the multiplication and
00:24:28.240 | backpropagating through just the multiplication itself, which is element-wise, is going to be
00:24:32.640 | the local derivative, which in this case is simply counts, because counts is the A.
00:24:40.160 | So it's the local derivative, and then times, because of the chain rule, dprobs.
00:24:44.240 | So this here is the derivative, or the gradient, but with respect to replicated B.
00:24:50.800 | But we don't have a replicated B, we just have a single B column. So how do we now backpropagate
00:24:57.120 | through the replication? And intuitively, this B1 is the same variable, and it's just reused
00:25:03.440 | multiple times. And so you can look at it as being equivalent to a case we've encountered
00:25:09.440 | in micrograd. And so here, I'm just pulling out a random graph we used in micrograd.
00:25:14.320 | We had an example where a single node has its output feeding into two branches of basically
00:25:21.440 | the graph until the loss function. And we're talking about how the correct thing to do in
00:25:25.520 | the backward pass is we need to sum all the gradients that arrive at any one node. So
00:25:30.960 | across these different branches, the gradients would sum. So if a node is used multiple times,
00:25:36.880 | the gradients for all of its uses sum during backpropagation. So here, B1 is used multiple
00:25:43.120 | times in all these columns, and therefore the right thing to do here is to sum horizontally
00:25:49.280 | across all the rows. So we want to sum in dimension 1, but we want to retain this dimension
00:25:57.200 | so that counts_sum_inv and its gradient are going to be exactly the same shape.
00:26:01.600 | So we want to make sure that we keep keepdim as True so we don't lose this dimension.
00:26:05.840 | And this will make dcounts_sum_inv be exactly shape 32 by 1. So revealing this comparison as
00:26:14.080 | well and running this, we see that we get an exact match. So this derivative is exactly correct.
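The line being checked here is, roughly (using the notebook's variables):

```python
# probs = counts * counts_sum_inv, where the (32, 1) column was broadcast across 27 columns,
# so the gradient contributions are summed back horizontally into that column
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
```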
00:26:22.080 | And let me erase this. Now let's also backpropagate into counts, which is the other
00:26:28.720 | variable here to create probs. So from probs to counts_sum_inv, we just did that.
00:26:33.200 | Let's go into counts as well. So dcounts will be...
00:26:37.120 | counts is our A, so dC by dA is just B. So therefore it's counts_sum_inv.
00:26:47.520 | And then times chain rule, dprobs. Now counts_sum_inv is 32 by 1, dprobs is 32 by 27.
00:26:56.720 | So those will broadcast fine and will give us dcounts. There's no additional summation
00:27:04.960 | required here. There will be a broadcasting that happens in this multiply here because
00:27:11.040 | counts_sum_inv needs to be replicated again to correctly multiply dprobs. But that's going
00:27:16.560 | to give the correct result as far as this single operation is concerned.
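In code, this first contribution is roughly:

```python
# first branch into dcounts; the second branch, through counts_sum, is added later
dcounts = counts_sum_inv * dprobs   # (32, 1) broadcasts against (32, 27)
```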
00:27:20.960 | So we've backpropagated from probs to counts, but we can't actually check the derivative of counts.
00:27:27.920 | I have it much later on. And the reason for that is because counts_sum_inv depends on counts.
00:27:34.560 | And so there's a second branch here that we have to finish because counts_sum_inv backpropagates into
00:27:39.040 | counts_sum and counts_sum will backpropagate into counts. And so counts is a node that is being
00:27:44.320 | used twice. It's used right here in probs and it goes through this other branch through
00:27:48.800 | counts_sum_inv. So even though we've calculated the first contribution of it, we still have
00:27:53.840 | to calculate the second contribution of it later. Okay, so we're continuing with this branch.
00:27:59.120 | We have the derivative for counts_sum_inv. Now we want the derivative of counts_sum.
00:28:02.640 | So dcounts_sum equals, what is the local derivative of this operation? So this is basically an
00:28:08.560 | element-wise one over counts_sum. So counts_sum raised to the power of negative one is the same
00:28:14.480 | as one over counts_sum. If we go to WolframAlpha, we see that x to the negative one, d by dx of it,
00:28:21.360 | is basically negative x to the negative two. Negative one over x squared is the same as
00:28:26.720 | negative x to the negative two. So dcounts_sum here will be, local derivative is going to be negative
00:28:36.240 | counts_sum to the negative two, that's the local derivative, times chain rule, which is dcounts_sum_inv.
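Written out, that derivative is roughly:

```python
# d/dx of x**-1 is -x**-2, chained with the gradient coming from counts_sum_inv
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
```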
00:28:44.320 | So that's dcounts_sum. Let's uncomment this and check that I am correct. Okay, so we have perfect
00:28:53.200 | equality. And there's no sketchiness going on here with any shapes because these are of the
00:28:59.440 | same shape. Okay, next up we want to back propagate through this line. We have that
00:29:03.920 | counts_sum is counts.sum along the rows. So I wrote out some help here. We have to keep in
00:29:11.440 | mind that counts, of course, is 32 by 27, and counts_sum is 32 by 1. So in this back propagation,
00:29:17.680 | we need to take this column of derivatives and transform it into an array of derivatives,
00:29:24.640 | two-dimensional array. So what is this operation doing? We're taking some kind of an input,
00:29:30.400 | like say a three-by-three matrix A, and we are summing up the rows into a column tensor B,
00:29:36.240 | B1, B2, B3, that is basically this. So now we have the derivatives of the loss with respect to
00:29:42.720 | B, all the elements of B. And now we want to derive the loss with respect to all these little
00:29:48.400 | As. So how do the Bs depend on the As is basically what we're after. What is the local derivative of
00:29:54.800 | this operation? Well, we can see here that B1 only depends on these elements here. The derivative of
00:30:01.840 | B1 with respect to all of these elements down here is zero. But for these elements here, like
00:30:07.040 | A11, A12, etc., the local derivative is one, right? So DB1 by DA11, for example, is one. So it's one,
00:30:16.000 | one, and one. So when we have the derivative of loss with respect to B1, the local derivative of
00:30:22.720 | B1 with respect to these inputs is zeroes here, but it's one on these guys. So in the chain rule,
00:30:28.880 | we have the local derivative times the derivative of B1. And so because the local derivative is one
00:30:37.360 | on these three elements, the local derivative multiplying the derivative of B1 will just be
00:30:42.400 | the derivative of B1. And so you can look at it as a router. Basically, an addition is a router
00:30:49.440 | of gradient. Whatever gradient comes from above, it just gets routed equally to all the elements
00:30:53.840 | that participate in that addition. So in this case, the derivative of B1 will just flow equally to
00:31:00.080 | the derivative of A11, A12, and A13. So if we have a derivative of all the elements of B in this
00:31:06.800 | column tensor, which is dcounts_sum that we've calculated just now, we basically see that what
00:31:13.440 | that amounts to is all of these are now flowing to all these elements of A, and they're doing that
00:31:19.600 | horizontally. So basically what we want is we want to take the dcounts_sum of size 32 by 1,
00:31:25.600 | and we just want to replicate it 27 times horizontally to create 32 by 27 array. So
00:31:32.320 | there's many ways to implement this operation. You could, of course, just replicate the tensor,
00:31:36.400 | but I think maybe one clean one is that dcounts is simply torch.ones_like, so just a
00:31:44.560 | two-dimensional array of ones in the shape of counts, so 32 by 27, times dcounts_sum.
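A sketch of that line, including the accumulation discussed next:

```python
# the ones broadcast dcounts_sum back across the 27 columns; += accumulates this
# second branch on top of the contribution already stored in dcounts
dcounts += torch.ones_like(counts) * dcounts_sum
```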
00:31:52.080 | So this way we're letting the broadcasting here basically implement the replication. You can look
00:31:57.840 | at it that way. But then we have to also be careful because dcounts was already calculated.
00:32:05.040 | We calculated earlier here, and that was just the first branch, and we're now finishing the second
00:32:10.000 | branch. So we need to make sure that these gradients add, so plus equals. And then here,
00:32:15.280 | let's uncomment the comparison, and let's make sure, crossing fingers, that we have the
00:32:23.360 | correct result. So PyTorch agrees with us on this gradient as well. Okay, hopefully we're getting
00:32:28.880 | a hang of this now. Counts is an element-wise exp of norm_logits. So now we want dnorm_logits.
00:32:35.440 | And because it's an element-wise operation, everything is very simple. What is the local
00:32:40.080 | derivative of e to the x? It's famously just e to the x. So this is the local derivative.
00:32:45.760 | That is the local derivative. Now we already calculated it, and it's inside counts,
00:32:52.240 | so we might as well potentially just reuse counts. That is the local derivative, times dcounts.
00:32:58.640 | [typing]
00:33:01.840 | Funny as that looks, counts times dcounts is the derivative on norm_logits.
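Written out, that is roughly:

```python
# d/dx of exp(x) is exp(x), which we already have computed as counts
dnorm_logits = counts * dcounts
```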
00:33:05.520 | And now let's erase this, and let's verify, and it looks good.
00:33:10.160 | So that's dnorm_logits. Okay, so we are here on this line now, norm_logits. We have that,
00:33:19.600 | and we're trying to calculate dlogits and dlogit_maxes, so backpropagating through this line.
00:33:25.280 | Now we have to be careful here because the shapes, again, are not the same, and so there's an
00:33:29.440 | implicit broadcasting happening here. So norm_logits has the shape of 32 by 27. Logits does as well,
00:33:36.240 | but logit_maxes is only 32 by 1. So there's a broadcasting here in the minus. Now here I tried
00:33:43.920 | to sort of write out a toy example again. We basically have that this is our C equals A minus
00:33:49.680 | B, and we see that because of the shape, these are 3 by 3, but this one is just a column.
00:33:55.120 | And so for example, every element of C, we have to look at how it came to be. And every element of
00:34:00.880 | C is just the corresponding element of A minus basically that associated B. So it's very clear
00:34:09.520 | now that the derivatives of every one of these Cs with respect to their inputs are 1 for the
00:34:16.560 | corresponding A, and it's a negative 1 for the corresponding B. And so therefore, the derivatives
00:34:26.080 | on the C will flow equally to the corresponding As and then also to the corresponding Bs,
00:34:32.720 | but then in addition to that, the Bs are broadcast, so we'll have to do the additional sum
00:34:37.280 | just like we did before. And of course, the derivatives for Bs will undergo A minus because
00:34:42.960 | the local derivative here is negative 1. So dC32 by dB3 is negative 1. So let's just implement that.
00:34:51.760 | Basically, dlogits will be exactly copying the derivative on norm_logits. So dlogits equals
00:35:01.840 | dnorm_logits, and I'll do a dot clone for safety, so we're just making a copy.
00:35:06.720 | And then we have that dlogit_maxes will be the negative of dnorm_logits because of the negative
00:35:14.560 | sign. And then we have to be careful because logit_maxes is a column. And so just like we saw
00:35:22.320 | before, because we keep replicating the same elements across all the columns, then in the
00:35:29.280 | backward pass, because we keep reusing this, these are all just like separate branches of use of that
00:35:35.120 | one variable. And so therefore, we have to do a sum along dimension 1, with keepdim equals True,
00:35:40.320 | so that we don't destroy this dimension. And then dlogit_maxes will be the same shape.
00:35:45.680 | Now, we have to be careful because this dlogits is not the final dlogits, and that's because
00:35:51.040 | not only do we get gradient signal into logits through here, but logit_maxes is a function of
00:35:57.200 | logits, and that's a second branch into logits. So this is not yet our final derivative for logits,
00:36:02.720 | we will come back later for the second branch. For now, dlogit_maxes is the final derivative.
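A sketch of these two lines, using the notebook's variables:

```python
# norm_logits = logits - logit_maxes, where the (32, 1) column logit_maxes was broadcast
dlogits = dnorm_logits.clone()                       # first branch into logits
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)  # sum the broadcast column back up
```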
00:36:08.080 | So let me uncomment this cmp here, and let's just run this. And for logit_maxes, PyTorch agrees with
00:36:15.600 | us. So that was the derivative through this line. Now, before we move on, I want to pause here
00:36:22.960 | briefly, and I want to look at these logit_maxes and especially their gradients. We've talked
00:36:27.680 | previously in the previous lecture, that the only reason we're doing this is for the numerical
00:36:32.160 | stability of the softmax that we are implementing here. And we talked about how if you take these
00:36:37.520 | Logits for any one of these examples, so one row of this Logits tensor, if you add or subtract any
00:36:43.680 | value equally to all the elements, then the value of the probs will be unchanged. You're not changing
00:36:49.760 | the softmax. The only thing that this is doing is it's making sure that exp doesn't overflow.
00:36:54.480 | And the reason we're using a max is because then we are guaranteed that each row of Logits,
00:36:59.200 | the highest number, is zero. And so this will be safe. And so basically that has repercussions.
00:37:08.560 | If it is the case that changing logit_maxes does not change the probs, and therefore does not
00:37:14.480 | change the loss, then the gradient on logit_maxes should be zero. Because saying those two things
00:37:20.160 | is the same. So indeed, we hope that this is very, very small numbers. Indeed, we hope this is zero.
00:37:26.000 | Now, because of floating point sort of wonkiness, this doesn't come out exactly zero. Only in some
00:37:31.840 | of the rows it does. But we get extremely small values, like 1e-9 or 1e-10. And so this is
00:37:37.440 | telling us that the values of logit_maxes are not impacting the loss, as they shouldn't.
00:37:42.240 | It feels kind of weird to backpropagate through this branch, honestly, because
00:37:46.640 | if you have any implementation of like F.cross_entropy in PyTorch, and you block together
00:37:53.200 | all of these elements, and you're not doing the backpropagation piece by piece,
00:37:56.640 | then you would probably assume that the derivative through here is exactly zero.
00:38:00.560 | So you would be sort of skipping this branch, because it's only done for numerical stability.
00:38:09.200 | But it's interesting to see that even if you break up everything into the full atoms,
00:38:12.960 | and you still do the computation as you'd like with respect to numerical stability,
00:38:16.800 | the correct thing happens. And you still get very, very small gradients here,
00:38:21.680 | basically reflecting the fact that the values of these do not matter with respect to the final loss.
00:38:26.960 | Okay, so let's now continue backpropagation through this line here. We've just calculated
00:38:31.760 | dlogit_maxes, and now we want to backprop into logits through this second branch. Now here,
00:38:37.040 | of course, we took logits, and we took the max along all the rows, and then we looked at its
00:38:41.920 | values here. Now the way this works is that in PyTorch, this thing here, the max returns both
00:38:50.800 | the values, and it returns the indices at which those maximum values occurred. Now,
00:38:56.320 | in the forward pass, we only used values, because that's all we needed. But in the backward pass,
00:39:00.720 | it's extremely useful to know about where those maximum values occurred. And we have the indices
00:39:06.560 | at which they occurred. And this will, of course, help us do the backpropagation. Because what
00:39:11.600 | should the backward pass be here in this case? We have the logits tensor, which is 32 by 27,
00:39:16.800 | and in each row, we find the maximum value, and then that value gets plucked out into logit_maxes.
00:39:21.840 | And so intuitively, basically, the derivative flowing through here then should be 1 times
00:39:32.320 | the local derivative is 1 for the appropriate entry that was plucked out, and then times the
00:39:38.400 | global derivative of the logit_maxes. So really what we're doing here, if you think through it,
00:39:42.960 | is we need to take the dlogit_maxes, and we need to scatter it to the correct positions
00:39:48.320 | in these Logits from where the maximum values came. And so I came up with one line of code
00:39:57.920 | that does that. Let me just erase a bunch of stuff here. So the line of-- you could do it
00:40:02.560 | very similar to what we've done here, where we create a zeros, and then we populate the correct
00:40:07.680 | elements. So we use the indices here, and we would set them to be 1. But you can also use one-hot.
00:40:14.400 | So F.one_hot, and then I'm taking the logits.max over the first dimension, dot indices, and I'm
00:40:22.400 | telling PyTorch that the dimension of every one of these tensors should be 27. And so what this
00:40:31.600 | is going to do is-- okay, I apologize, this is crazy. plt.imshow of this. It's really just an
00:40:40.640 | array of where the maxes came from in each row, and that element is 1, and all the other elements
00:40:46.320 | are 0. So it's a one-hot vector in each row, and these indices are now populating a single 1 in
00:40:52.640 | the proper place. And then what I'm doing here is I'm multiplying by dlogit_maxes. And keep in
00:40:58.480 | mind that this is a column of 32 by 1. And so when I'm doing this times dlogit_maxes, dlogit_maxes
00:41:08.240 | will broadcast, and that column will get replicated, and then an element-wise multiply will ensure
00:41:14.320 | that each of these just gets routed to whichever one of these bits is turned on. And so that's
00:41:19.760 | another way to implement this kind of an operation. And both of these can be used. I just thought I
00:41:27.200 | would show an equivalent way to do it. And I'm using += because we already calculated dlogits
00:41:32.240 | here, and this is now the second branch. So let's check dlogits and make sure that this is correct.
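A sketch of that line, assuming F is torch.nn.functional as imported in the starter cells:

```python
# scatter dlogit_maxes to the position of each row's maximum (second branch into logits)
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
```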
00:41:39.600 | And we see that we have exactly the correct answer. Next up, we want to continue with logits
00:41:46.720 | here. That is an outcome of a matrix multiplication and a bias offset in this linear layer. So I've
00:41:54.880 | printed out the shapes of all these intermediate tensors. We see that logits is of course 32 by 27,
00:42:00.320 | as we've just seen. Then the h here is 32 by 64. So these are 64-dimensional hidden states.
00:42:07.120 | And then this w matrix projects those 64-dimensional vectors into 27 dimensions.
00:42:13.040 | And then there's a 27-dimensional offset, which is a one-dimensional vector. Now we should note
00:42:19.120 | that this plus here actually broadcasts, because h multiplied by w2 will give us a 32 by 27.
00:42:26.320 | And so then this plus b2 is a 27-dimensional vector here. Now in the rules of broadcasting,
00:42:33.520 | what's going to happen with this bias vector is that this one-dimensional vector of 27
00:42:38.240 | will get aligned with a padded dimension of 1 on the left. And it will basically become a row
00:42:44.160 | vector. And then it will get replicated vertically 32 times to make it 32 by 27. And then there's an
00:42:50.400 | element-wise add. Now the question is, how do we back propagate from logits to the hidden
00:42:58.320 | states, the weight matrix w2, and the bias b2? And you might think that we need to go to some
00:43:04.400 | matrix calculus, and then we have to look up the derivative for a matrix multiplication.
00:43:10.640 | But actually, you don't have to do any of that. And you can go back to first principles and derive
00:43:14.240 | this yourself on a piece of paper. And specifically what I like to do, and what I find works well for
00:43:19.680 | me, is you find a specific small example that you then fully write out. And then in the process of
00:43:25.360 | analyzing how that individual small example works, you will understand the broader pattern.
00:43:29.760 | And you'll be able to generalize and write out the full general formula for how these derivatives
00:43:35.760 | flow in an expression like this. So let's try that out. So pardon the low budget production here,
00:43:41.120 | but what I've done here is I'm writing it out on a piece of paper. Really what we are interested
00:43:45.600 | in is we have a multiply b plus c, and that creates a d. And we have the derivative of the
00:43:52.800 | loss with respect to d, and we'd like to know what the derivative of the loss is with respect to a,
00:43:56.480 | b, and c. Now these here are little two-dimensional examples of a matrix
00:44:01.200 | multiplication. 2 by 2 times a 2 by 2 plus a 2, a vector of just two elements, c1 and c2,
00:44:08.960 | gives me a 2 by 2. Now notice here that I have a bias vector here called c, and the bias vector
00:44:16.880 | is c1 and c2. But as I described over here, that bias vector will become a row vector in the
00:44:21.920 | broadcasting and will replicate vertically. So that's what's happening here as well. c1, c2
00:44:26.880 | is replicated vertically, and we see how we have two rows of c1, c2 as a result.
00:44:31.760 | So now when I say write it out, I just mean like this. Basically break up this matrix
00:44:37.920 | multiplication into the actual thing that's going on under the hood. So as a result of matrix
00:44:43.920 | multiplication and how it works, d11 is the result of a dot product between the first row of a and
00:44:49.600 | the first column of b. So a11, b11 plus a12, b21 plus c1, and so on and so forth for all the other
00:44:59.840 | elements of d. And once you actually write it out, it becomes obvious this is just a bunch of
00:45:04.320 | multiplies and adds. And we know from micrograd how to differentiate multiplies and adds. And so
00:45:11.440 | this is not scary anymore. It's not just matrix multiplication. It's just tedious, unfortunately,
00:45:16.640 | but this is completely tractable. We have dl by d for all of these, and we want dl by all these
00:45:23.280 | little other variables. So how do we achieve that and how do we actually get the gradients?
00:45:27.280 | Okay, so the low budget production continues here. So let's, for example, derive the derivative of
00:45:33.040 | the loss with respect to a11. We see here that a11 occurs twice in our simple expression,
00:45:39.440 | right here, right here, and influences d11 and d12. So what is dl by d a11? Well, it's dl by d11
00:45:49.440 | times the local derivative of d11, which in this case is just b11, because that's what's multiplying
00:45:56.080 | a11 here. And likewise here, the local derivative of d12 with respect to a11 is just b12. And so
00:46:04.240 | b12 will, in the chain rule, therefore, multiply dl by d12. And then because a11 is used both to
00:46:11.440 | produce d11 and d12, we need to add up the contributions of both of those sort of chains
00:46:18.720 | that are running in parallel. And that's why we get a plus, just adding up those two contributions.
00:46:25.680 | And that gives us dl by d a11. We can do the exact same analysis for the other one,
00:46:31.280 | for all the other elements of A. And when you simply write it out, it's just super simple
00:46:36.240 | taking of gradients on expressions like this. You find that this matrix dl by da that we're after,
00:46:47.360 | right, if we just arrange all of them in the same shape as A takes, so A is just a 2x2 matrix,
00:46:53.680 | so dl by da here will be also just the same shape tensor with the derivatives now, so dl by da11,
00:47:04.000 | etc. And we see that actually we can express what we've written out here as a matrix multiply.
00:47:09.920 | And so it just so happens that all of these formulas that we've derived here by taking
00:47:16.400 | gradients can actually be expressed as a matrix multiplication. And in particular,
00:47:20.720 | we see that it is the matrix multiplication of these two matrices. So it is the dl by d
00:47:28.480 | and then matrix multiplying B, but B transpose actually. So you see that B21 and B12 have
00:47:36.640 | changed place, whereas before we had, of course, B11, B12, B21, B22. So you see that this other
00:47:44.880 | matrix B is transposed. And so basically what we have, long story short, just by doing very simple
00:47:50.800 | reasoning here, by breaking up the expression in the case of a very simple example, is that dl by
00:47:56.480 | da is, which is this, is simply equal to dl by dd matrix multiplied with B transpose.
00:48:03.520 | So that is what we have so far. Now we also want the derivative with respect to
00:48:10.800 | B and C. Now for B, I'm not actually doing the full derivation because honestly, it's not deep.
00:48:18.800 | It's just annoying. It's exhausting. You can actually do this analysis yourself. You'll
00:48:23.680 | also find that if you take these expressions and you differentiate with respect to B instead of A,
00:48:28.240 | you will find that dl by db is also a matrix multiplication. In this case, you have to take
00:48:33.760 | the matrix A and transpose it and matrix multiply that with dl by dd. And that's what gives you the
00:48:40.720 | dl by db. And then here for the offsets C1 and C2, if you again just differentiate with respect to C1,
00:48:48.560 | you will find an expression like this and C2, an expression like this. And basically you'll
00:48:55.840 | find that dl by dc is simply, because they're just offsetting these expressions, you just have
00:49:01.360 | to take the dl by dd matrix of the derivatives of d and you just have to sum down the columns, that is, over the rows.
00:49:10.480 | And that gives you the derivatives for C. So long story short, the backward pass of a matrix
00:49:17.680 | multiply is a matrix multiply. And instead of, just like we had d equals A times B plus C,
00:49:23.040 | in a scalar case, we sort of like arrive at something very, very similar, but now
00:49:28.080 | with a matrix multiplication instead of a scalar multiplication. So the derivative of
00:49:34.720 | d with respect to A is dl by dd matrix multiply B transpose. And here it's A transpose multiply dl
00:49:44.320 | by dd. But in both cases, it's a matrix multiplication with the derivative and the
00:49:49.760 | other term in the multiplication. And for C, it is a sum. Now I'll tell you a secret. I can never
00:49:58.560 | remember the formulas that we just derived for backpropagating from matrix multiplication,
00:50:02.560 | and I can backpropagate through these expressions just fine. And the reason this works is because
00:50:07.040 | the dimensions have to work out. So let me give you an example. Say I want to create dh.
00:50:12.640 | Then what should dh be? Number one, I have to know that the shape of dh must be the same as
00:50:19.760 | the shape of h. And the shape of h is 32 by 64. And then the other piece of information I know
00:50:25.760 | is that dh must be some kind of matrix multiplication of dlogits with w2.
00:50:32.560 | And dlogits is 32 by 27, and w2 is 64 by 27. There is only a single way to make the shape
00:50:40.880 | work out in this case, and it is indeed the correct result. In particular here, dh needs to
00:50:47.440 | be 32 by 64. The only way to achieve that is to take dlogits and matrix multiply it with…
00:50:54.400 | You see how I have to take w2, but I have to transpose it to make the dimensions work out.
00:51:00.240 | So w2 transpose. And it is the only way to matrix multiply those two pieces to make the
00:51:05.680 | shapes work out. And that turns out to be the correct formula. So if we come here,
00:51:09.520 | we want dh, which is dA. And we see that dA is dL by dD matrix multiply B transpose.
00:51:17.120 | So that is d logits multiply, and B is w2, so w2 transpose, which is exactly what we have here.
00:51:24.880 | So there is no need to remember these formulas. Similarly, now if I want dw2,
00:51:30.960 | well I know that it must be a matrix multiplication of d logits and h. And maybe there is a few
00:51:38.400 | transpose… Like there is one transpose in there as well. And I do not know which way it is,
00:51:41.920 | so I have to come to w2. And I see that its shape is 64 by 27, and that has to come from
00:51:48.480 | some matrix multiplication of these two. And so to get a 64 by 27, I need to take h, I need to
00:51:57.680 | transpose it, and then I need to matrix multiply it. So that will become 64 by 32. And then I need
00:52:04.080 | to matrix multiply it with 32 by 27. And that is going to give me a 64 by 27. So I need to matrix
00:52:09.440 | multiply this with d logits dot shape, just like that. That is the only way to make the dimensions
00:52:14.000 | work out, and just use matrix multiplication. And if we come here, we see that that is exactly
00:52:19.680 | what is here. So A transpose, A for us is h, multiplied with d logits. So that is dw2. And then
00:52:28.480 | db2 is just the vertical sum. And actually, in the same way, there is only one way to make the
00:52:37.120 | shapes work out. I do not have to remember that it is a vertical sum along the 0th axis, because
00:52:41.840 | that is the only way that this makes sense. Because b2's shape is 27, and
00:52:47.600 | d logits here is 32 by 27. So knowing that it is just a sum over d logits in some direction,
00:52:56.480 | that direction must be 0, because I need to eliminate this dimension. So it is this.
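Putting that shape-matching argument into code, a sketch of the three candidates (assuming the notebook's names dlogits, h, W2, b2, with logits 32 by 27, h 32 by 64, W2 64 by 27, and b2 of size 27):

```python
# backward through: logits = h @ W2 + b2
dh  = dlogits @ W2.T    # (32, 27) @ (27, 64) -> (32, 64), matches h.shape
dW2 = h.T @ dlogits     # (64, 32) @ (32, 27) -> (64, 27), matches W2.shape
db2 = dlogits.sum(0)    # sum over the batch (0th) dimension -> (27,), matches b2.shape
```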
00:53:04.720 | So this is kind of like the hacky way. Let me copy, paste, and delete that.
00:53:11.040 | And let me swing over here. And this is our backward pass for the linear layer,
00:53:15.120 | hopefully. So now let us uncomment these three. And we are checking that we
00:53:21.040 | got all the three derivatives correct. And run. And we see that h, w2, and b2 are all exactly
00:53:30.720 | correct. So we backpropagated through a linear layer. Now next up, we have derivative for the
00:53:38.960 | h already. And we need to backpropagate through tanh into hpreact. So we want to derive dhpreact.
00:53:46.000 | And here we have to backpropagate through a tanh. And we have already done this in micrograd.
00:53:51.120 | And we remember that tanh is a very simple backward formula. Now unfortunately,
00:53:55.600 | if I just put in d by dx of tanh of x into Wolfram Alpha, it lets us down. It tells us that it is a
00:54:00.960 | hyperbolic secant function squared of x. It is not exactly helpful. But luckily, Google image
00:54:06.800 | search does not let us down. And it gives us the simpler formula. And in particular, if you have
00:54:11.440 | that a is equal to tanh of z, then da by dz, backpropagating through tanh, is just 1 minus a
00:54:18.160 | squared. And take note that 1 minus a squared, a here is the output of the tanh, not the input to
00:54:24.800 | the tanh, z. So the da by dz is here formulated in terms of the output of that tanh. And here also,
00:54:32.960 | in Google image search, we have the full derivation if you want to actually take the
00:54:36.400 | actual definition of tanh and work through the math to figure out 1 minus tanh squared of z.
00:54:41.440 | So 1 minus a squared is the local derivative. In our case, that is 1 minus the output of the tanh
00:54:50.720 | squared, which here is h. So it's 1 minus h squared. And that is the local derivative. And then times the
00:54:58.160 | chain rule, dh. So that is going to be our candidate implementation. So if we come here
00:55:04.160 | and then uncomment this, let's hope for the best. And we have the right answer.
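In code, the candidate just described is a one-liner (a sketch, written in terms of the tanh output h):

```python
# backward through: h = torch.tanh(hpreact); the local derivative is 1 - h**2
dhpreact = (1.0 - h**2) * dh
```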
00:55:11.280 | Okay, next up, we have dhpreact. And we want to backpropagate into the gain, the bn_raw,
00:55:17.600 | and the bn_bias. So here, these are the batch norm parameters, bn_gain and bn_bias inside the batch
00:55:23.120 | norm that take the bn_raw that is the exact unit Gaussian, and they scale it and shift it. And
00:55:29.520 | these are the parameters of the batch norm. Now, here, we have a multiplication. But it's worth
00:55:34.880 | noting that this multiply is very, very different from this matrix multiply here. Matrix multiplies
00:55:39.920 | are dot products between rows and columns of the matrices involved. This is an element-wise
00:55:45.200 | multiply. So things are quite a bit simpler. Now, we do have to be careful with some of the
00:55:49.440 | broadcasting happening in this line of code, though. So you see how bn_gain and bn_bias are
00:55:55.760 | 1 by 64, but dhpreact and bn_raw are 32 by 64. So we have to be careful with that and make sure that
00:56:04.080 | all the shapes work out fine and that the broadcasting is correctly backpropagated.
00:56:07.600 | So in particular, let's start with dbn_gain. So dbn_gain should be, and here, this is again,
00:56:16.240 | element-wise multiply. And whenever we have a times b equals c, we saw that the local derivative
00:56:21.680 | here is just, if this is a, the local derivative is just the b, the other one. So the local
00:56:26.880 | derivative is just bn_raw and then times chain rule. So dhpreact. So this is the candidate
00:56:36.640 | gradient. Now, again, we have to be careful because bn_gain is of size 1 by 64. But this here
00:56:45.280 | would be 32 by 64. And so the thing to notice in this case, of course, is that bn_gain
00:56:53.120 | here is a row vector of 64 numbers, and it gets replicated vertically in this operation.
00:56:57.840 | And so therefore, the correct thing to do is to sum because it's being replicated.
00:57:02.960 | And therefore, all the gradients in each of the rows that are now flowing backwards need to sum
00:57:09.040 | up to that same tensor dbn_gain. So we have to sum across the zeroth dimension, all the examples,
00:57:16.640 | basically, which is the direction in which this gets replicated. And now we have to be also
00:57:21.280 | careful because bn_gain is of shape 1 by 64. So in fact, I need keepdim as true. Otherwise,
00:57:29.840 | I would just get 64. Now, I don't actually really remember why the bn_gain and the bn_bias,
00:57:36.320 | I made them be 1 by 64. But the biases b1 and b2, I just made them be one dimensional vectors,
00:57:45.280 | they're not two dimensional tensors. So I can't recall exactly why I left the gain and the bias
00:57:51.520 | as two dimensional. But it doesn't really matter as long as you are consistent and you're keeping
00:57:55.440 | it the same. So in this case, we want to keep the dimension so that the tensor shapes work.
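As a sketch, the dbn_gain candidate just described, with the sum over the examples and keepdim to preserve the 1 by 64 shape (notebook-style variable names assumed):

```python
# backward into the gain of: hpreact = bngain * bnraw + bnbias
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)   # (1, 64), matches bngain.shape
```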
00:58:01.360 | Next up, we have bn_raw. So dbn_raw will be bn_gain multiplying dh_preact. That's our chain
00:58:13.840 | rule. Now, what about the dimensions of this? We have to be careful, right? So dh_preact is 32 by
00:58:22.960 | 64, bn_gain is 1 by 64. So it will just get replicated to create this multiplication,
00:58:30.880 | which is the correct thing because in a forward pass, it also gets replicated in just the same
00:58:34.880 | way. So in fact, we don't need the brackets here, we're done. And the shapes are already correct.
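And the dbn_raw candidate, where the broadcasting of the 1 by 64 gain against the 32 by 64 gradient does the replication for us (a sketch):

```python
# backward into bnraw of: hpreact = bngain * bnraw + bnbias
dbnraw = bngain * dhpreact   # (1, 64) broadcasts against (32, 64) -> (32, 64)
```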
00:58:40.080 | And finally, for the bias, very similar. This bias here is very, very similar to the bias we saw in
00:58:47.600 | the linear layer. And we see that the gradients from h_preact will simply flow into the biases
00:58:53.280 | and add up because these are just offsets. And so basically, we want this to be dh_preact,
00:59:00.000 | but it needs to sum along the right dimension. And in this case, similar to the gain,
00:59:04.960 | we need to sum across the zeroth dimension, the examples, because of the way that the
00:59:09.280 | bias gets replicated vertically. And we also want to have keepdim as true. And so this will
00:59:15.920 | basically take this and sum it up and give us a 1 by 64. So this is the candidate implementation
00:59:22.880 | and makes all the shapes work. Let me bring it up down here. And then let me uncomment these three
00:59:30.240 | lines to check that we are getting the correct result for all the three tensors. And indeed,
00:59:36.400 | we see that all of that got backpropagated correctly. So now we get to the batch norm layer.
00:59:41.040 | We see how here bn_gain and bn_bias are the parameters, so the backpropagation ends.
00:59:46.640 | But bn_raw now is the output of the standardization. So here, what I'm doing, of course,
00:59:52.960 | is I'm breaking up the batch norm into manageable pieces so we can backpropagate through each line
00:59:57.040 | individually. But basically, what's happening is bn_mean_i is the mean. So this is the bn_mean_i.
01:00:06.160 | I apologize for the variable naming. bn_diff is x minus mu. bn_diff2 is x minus mu squared
01:00:14.560 | here inside the variance. bn_var is the variance, so sigma squared. This is bn_var.
01:00:21.520 | And it's basically the sum of squares. So this is the x minus mu squared and then the sum.
01:00:28.800 | Now, you'll notice one departure here. Here, it is normalized as 1 over m,
01:00:34.160 | which is the number of examples. Here, I am normalizing as 1 over n minus 1 instead of m.
01:00:40.720 | And this is deliberate, and I'll come back to that in a bit when we are at this line.
01:00:44.640 | It is something called the Bessel's correction, but this is how I want it in our case.
01:00:49.600 | bn_var_inv then becomes basically bn_var plus epsilon. Epsilon is 1e-5.
01:00:57.120 | And then its 1 over square root is the same as raising to the power of negative 0.5,
01:01:03.680 | right? Because 0.5 is square root. And then negative makes it 1 over square root.
01:01:08.800 | So bn_var_inv is 1 over this denominator here. And then we can see that bn_raw, which is the x hat
01:01:16.080 | here, is equal to the bn_diff, the numerator, multiplied by the bn_var_inv. And this line here
01:01:26.160 | that creates hpreact was the last piece; we've already backpropagated through it.
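For reference, the broken-up forward pass being described looks roughly like this (a sketch with the lecture's variable names, n being the batch size, and the Bessel-corrected variance):

```python
# batch norm forward, broken up into manageable pieces
bnmeani = 1.0/n * hprebn.sum(0, keepdim=True)        # mu
bndiff = hprebn - bnmeani                            # x minus mu
bndiff2 = bndiff**2                                  # (x minus mu) squared
bnvar = 1.0/(n-1) * bndiff2.sum(0, keepdim=True)     # sigma squared, with Bessel's correction
bnvar_inv = (bnvar + 1e-5)**-0.5                     # 1 over sqrt(sigma squared plus eps)
bnraw = bndiff * bnvar_inv                           # x hat
hpreact = bngain * bnraw + bnbias                    # y, already backpropagated through above
```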
01:01:29.920 | So now what we want to do is we are here, and we have bn_raw, and we have to first backpropagate
01:01:37.040 | into bn_diff and bn_var_inv. So now we're here, and we have dbn_raw, and we need to backpropagate
01:01:44.880 | through this line. Now, I've written out the shapes here, and indeed bn_var_inv is a shape 1
01:01:51.680 | by 64, so there is a broadcasting happening here that we have to be careful with. But it is just
01:01:57.520 | an element-wise simple multiplication. By now, we should be pretty comfortable with that. To get
01:02:02.000 | dbn_diff, we know that this is just bn_var_inv multiplied with dbn_raw. And conversely, to get
01:02:13.200 | dbn_var_inv, we need to take bn_diff and multiply that by dbn_raw. So this is the candidate, but of
01:02:24.240 | course we need to make sure that broadcasting is obeyed. So in particular, bn_var_inv multiplying
01:02:30.080 | with dbn_raw will be okay and give us 32 by 64 as we expect. But dbn_var_inv would be taking a 32
01:02:40.160 | by 64, multiplying it by 32 by 64. So this is a 32 by 64. But of course this bn_var_inv is only 1
01:02:50.640 | by 64. So the second line here needs a sum across the examples, and because there's this dimension
01:02:58.400 | here, we need to make sure that keep_dim is true. So this is the candidate. Let's erase this and
01:03:06.480 | let's swing down here and implement it. And then let's comment out dbn_var_inv and dbn_diff.
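A sketch of those two candidates (with the dbn_diff one covering only this first branch for now):

```python
# backward through: bnraw = bndiff * bnvar_inv
dbndiff = bnvar_inv * dbnraw                           # first branch only; a second branch gets added later
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)    # sum over the examples to undo the broadcast -> (1, 64)
```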
01:03:14.800 | Now, we'll actually notice that dbn_diff, by the way, is going to be incorrect. So when I run this,
01:03:24.640 | bn_var_inv is correct. bn_diff is not correct. And this is actually expected, because we're not
01:03:31.440 | done with bn_diff. So in particular, when we slide here, we see here that bn_raw is a function of bn_diff,
01:03:38.480 | but actually bn_var_inv is a function of bn_var, which is a function of bn_diff2,
01:03:43.360 | which is a function of bn_diff. So it comes here. So bn_diff, these variable names are crazy,
01:03:50.400 | I'm sorry. It branches out into two branches, and we've only done one branch of it. We have to
01:03:55.600 | continue our backpropagation and eventually come back to bn_diff, and then we'll be able to do a
01:03:59.680 | += and get the actual correct gradient. For now, it is good to verify that cmp also works. It doesn't
01:04:06.160 | just lie to us and tell us that everything is always correct. It can in fact detect when your
01:04:11.440 | gradient is not correct. So that's good to see as well. Okay, so now we have the derivative here,
01:04:16.480 | and we're trying to backpropagate through this line. And because we're raising to a power of
01:04:20.880 | -0.5, I brought up the power rule. And we see that basically we have that dbn_var will now be,
01:04:27.360 | we bring down the exponent, so -0.5 times x, which is this, and now raised to the power of -0.5-1,
01:04:37.360 | which is -1.5. Now, we would have to also apply a small chain rule here in our head,
01:04:44.320 | because we need to take the further derivative of this expression here
01:04:49.840 | inside the bracket with respect to bn_var. But because this is an element-wise operation, and everything is
01:04:53.680 | fairly simple, that's just 1. And so there's nothing to do there. So this is the local
01:04:58.560 | derivative, and then times the global derivative to create the chain rule. This is just times
01:05:03.760 | dbn_var_inv. So this is our candidate. Let me bring this down and uncomment the check.
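As a sketch, the power-rule candidate just described:

```python
# backward through: bnvar_inv = (bnvar + 1e-5)**-0.5
dbnvar = (-0.5 * (bnvar + 1e-5)**-1.5) * dbnvar_inv
```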
01:05:11.920 | And we see that we have the correct result. Now, before we backpropagate through the next line,
01:05:19.440 | I want to briefly talk about the note here, where I'm using the Bessel's correction,
01:05:22.640 | dividing by n-1, instead of dividing by n, when I normalize here the sum of squares.
01:05:29.680 | Now, you'll notice that this is a departure from the paper, which uses 1/n instead,
01:05:34.000 | not 1/n-1. There, m is our n. And so it turns out that there are two ways of estimating variance
01:05:42.400 | of an array. One is the biased estimate, which is 1/n, and the other one is the unbiased estimate,
01:05:49.120 | which is 1/n-1. Now, confusingly, in the paper, this is not very clearly described,
01:05:55.760 | and also it's a detail that kind of matters, I think. They are using the biased version
01:06:00.400 | at training time, but later, when they are talking about the inference, they are mentioning that when
01:06:05.360 | they do the inference, they are using the unbiased estimate, which is the n-1 version, basically,
01:06:13.440 | for inference, and to calibrate the running mean and the running variance, basically.
01:06:20.080 | And so they actually introduce a train-test mismatch, where in training, they use the
01:06:24.320 | biased version, and in test time, they use the unbiased version. I find this extremely confusing.
01:06:30.000 | You can read more about the Bessel's correction and why dividing by n-1 gives you a better
01:06:35.680 | estimate of the variance in the case where you have population sizes or samples for a population
01:06:40.400 | that are very small. And that is indeed the case for us, because we are dealing with mini-batches,
01:06:46.960 | and these mini-batches are a small sample of a larger population, which is the entire training
01:06:51.760 | set. And so it just turns out that if you just estimate it using 1/n, that actually almost always
01:06:58.000 | underestimates the variance. And it is a biased estimator, and it is advised that you use the
01:07:02.800 | unbiased version and divide by n-1. And you can go through this article here that I liked
01:07:07.680 | that actually describes the full reasoning, and I'll link it in the video description.
01:07:11.120 | Now, when you calculate the variance with torch.var, you'll notice that it takes the unbiased flag,
01:07:17.360 | whether or not you want to divide by n or n-1. Confusingly, they do not mention what the default
01:07:24.080 | is for unbiased, but I believe unbiased by default is true. I'm not sure why the docs
01:07:29.440 | here don't state that. Now, in BatchNorm1d, the documentation again is kind of wrong
01:07:35.840 | and confusing. It says that the standard deviation is calculated via the biased estimator,
01:07:40.560 | but this is actually not exactly right, and people have pointed out that it is not right
01:07:45.040 | in a number of issues since then, because actually the rabbit hole is deeper, and they follow the
01:07:50.960 | paper exactly, and they use the biased version for training. But when they're estimating the
01:07:55.840 | running standard deviation, they are using the unbiased version. So again, there's the train
01:08:00.400 | test mismatch. So long story short, I'm not a fan of train-test discrepancies. The fact that we
01:08:07.200 | use the biased version at training time and the unbiased version at test time,
01:08:13.040 | I basically consider this to be a bug, and I don't think that there's a good reason for that.
01:08:16.800 | They don't really go into the detail of the reasoning behind it in this paper. So that's
01:08:22.240 | why I basically prefer to use the Bessel's correction in my own work. Unfortunately,
01:08:27.360 | batch norm does not take a keyword argument that tells you whether or not you want to use the
01:08:32.400 | unbiased version or the biased version in both train and test, and so therefore anyone using
01:08:36.560 | batch normalization basically in my view has a bit of a bug in the code. And this turns out to
01:08:43.040 | be much less of a problem if your mini batch sizes are a bit larger. But still, I just find it kind
01:08:48.560 | of unpalatable. So maybe someone can explain why this is okay. But for now, I prefer to use the
01:08:54.320 | unbiased version consistently both during training and at test time, and that's why I'm using 1/n-1
01:09:00.720 | here. Okay, so let's now actually backpropagate through this line. So the first thing that I
01:09:07.840 | always like to do is I like to scrutinize the shapes first. So in particular here, looking at
01:09:12.480 | the shapes of what's involved, I see that bn_var's shape is 1 by 64, so it's a row vector, and bn_diff2.shape
01:09:20.640 | is 32 by 64. So clearly here we're doing a sum over the zeroth axis to squash the zeroth dimension
01:09:29.600 | of the shapes here using a sum. So that right away actually hints to me that there will be some kind
01:09:36.240 | of a replication or broadcasting in the backward pass. And maybe you're noticing the pattern here,
01:09:41.120 | but basically anytime you have a sum in the forward pass, that turns into a replication
01:09:46.720 | or broadcasting in the backward pass along the same dimension. And conversely, when we have a
01:09:51.840 | replication or a broadcasting in the forward pass, that indicates a variable reuse. And so in the
01:09:58.880 | backward pass, that turns into a sum over the exact same dimension. And so hopefully you're
01:10:03.600 | noticing that duality, that those two are kind of like the opposites of each other in the forward
01:10:07.360 | and the backward pass. Now once we understand the shapes, the next thing I like to do always is I
01:10:12.480 | like to look at a toy example in my head to sort of just like understand roughly how the variable
01:10:18.240 | dependencies go in the mathematical formula. So here we have a two-dimensional array, bn_diff2,
01:10:25.520 | which we are scaling by a constant, and then we are summing vertically over the columns. So if
01:10:31.840 | we have a 2x2 matrix A and then we sum over the columns and scale, we would get a row vector b1,
01:10:37.120 | b2, and b1 depends on A in this way, where it's just the first column of A summed and scaled, and b2 in this way,
01:10:45.200 | where it's the second column summed and scaled. And so looking at this basically, what we want
01:10:52.400 | to do now is we have the derivatives on b1 and b2, and we want to back propagate them into A's.
01:10:57.600 | And so it's clear that just differentiating in your head, the local derivative here is 1 over
01:11:01.920 | n minus 1 times 1 for each one of these A's. And basically the derivative of b1 has to flow
01:11:11.600 | through the columns of A scaled by 1 over n minus 1. And that's roughly what's happening here.
01:11:18.560 | So intuitively, the derivative flow tells us that dbn_diff2 will be the local derivative of
01:11:27.280 | this operation. And there are many ways to do this, by the way, but I like to do something
01:11:30.960 | like this, torch.ones_like of bn_diff2. So I'll create a large two-dimensional array of ones,
01:11:38.640 | and then I will scale it. So 1.0 divided by n minus 1. So this is an array of 1 over n minus 1.
01:11:48.480 | And that's sort of like the local derivative. And now for the chain rule, I will simply just
01:11:53.360 | multiply it by dbn_var. And notice here what's going to happen. This is 32 by 64,
01:12:01.440 | and this is just 1 by 64. So I'm letting the broadcasting do the replication,
01:12:07.040 | because internally in PyTorch, basically dbn_var, which is 1 by 64 row vector,
01:12:13.040 | will in this multiplication get copied vertically until the two are of the same shape, and then
01:12:18.960 | there will be an element-wise multiply. And so the broadcasting is basically doing the replication.
01:12:24.480 | And I will end up with the derivative, dbn_diff2, here. So this is the candidate solution.
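In code, that candidate looks something like this (a sketch; n is the batch size):

```python
# backward through: bnvar = 1/(n-1) * bndiff2.sum(0, keepdim=True)
dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar   # broadcasting replicates dbnvar over the rows
```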
01:12:31.920 | Let's bring it down here. Let's uncomment this line where we check it, and let's hope for the
01:12:38.080 | best. And indeed, we see that this is the correct formula. Next up, let's differentiate here into
01:12:44.320 | bn_diff. So here we have that bn_diff is element-wise squared to create bn_diff2. So this is a relatively
01:12:52.320 | simple derivative, because it's a simple element-wise operation. So it's kind of like the
01:12:55.920 | scalar case. And we have that dbn_diff should be, if this is x squared, then the derivative of this
01:13:02.640 | is 2x. So it's simply 2 times bn_diff, that's the local derivative, and then times the chain rule. And
01:13:10.560 | the shape of these is the same. They are of the same shape. So times this. So that's the backward
01:13:16.800 | pass for this variable. Let me bring that down here. And now we have to be careful, because we
01:13:21.840 | already calculated dbn_diff, right? So this is just the end of the other branch coming back to bn_diff,
01:13:30.880 | because bn_diff was already backpropagated to way over here from bn_raw. So we now completed the
01:13:37.520 | second branch. And so that's why I have to do plus equals. And if you recall, we had an incorrect
01:13:43.040 | derivative for bn_diff before. And I'm hoping that once we append this last missing piece,
01:13:48.320 | we have the exact correctness. So let's run. And bn_diff2, bn_diff now actually shows the exact
01:13:55.360 | correct derivative. So that's comforting. Okay, so let's now backpropagate through this line here.
01:14:01.520 | The first thing we do, of course, is we check the shapes. And I wrote them out here. And basically,
01:14:07.840 | the shape of this is 32 by 64. H_prebn is the same shape. But bn_mean_i is a row vector, 1 by 64.
01:14:16.080 | So this minus here will actually do broadcasting. And so we have to be careful with that.
01:14:20.560 | And as a hint to us, again, because of the duality, a broadcasting in the forward pass
01:14:25.360 | means a variable reuse. And therefore, there will be a sum in the backward pass.
01:14:29.200 | So let's write out the backward pass here now. Backpropagate into the H_prebn. Because these
01:14:38.000 | are the same shape, then the local derivative for each one of the elements here is just one
01:14:42.640 | for the corresponding element in here. So basically, what this means is that the gradient
01:14:48.400 | just simply copies. It's just a variable assignment. It's equality. So I'm just going
01:14:52.880 | to clone this tensor just for safety to create an exact copy of dbn_diff. And then here,
01:15:00.640 | to backpropagate into this one, what I'm inclined to do here is dbn_mean_i will basically be,
01:15:08.400 | what is the local derivative? Well, it's negative torch.ones_like of the shape of bn_diff.
01:15:19.280 | Right? And then times the derivative here, dbn_diff.
01:15:29.440 | And this here is the backpropagation for the replicated bn_mean_i. So I still have to
01:15:38.640 | backpropagate through the replication in the broadcasting, and I do that by doing a sum.
01:15:44.240 | So I'm going to take this whole thing, and I'm going to do a sum over the zeroth dimension,
01:15:48.800 | which was the replication. So if you scrutinize this, by the way, you'll notice that this is
01:15:57.120 | the same shape as that. And so what I'm doing here doesn't actually make that much sense,
01:16:01.920 | because it's just an array of ones multiplying dbn_diff. So in fact, I can just do this,
01:16:10.080 | and that is equivalent. So this is the candidate backward pass. Let me copy it here.
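So, as a sketch, the candidate for this line, with the second branch into h_prebn still to come:

```python
# backward through: bndiff = hprebn - bnmeani
dhprebn = dbndiff.clone()      # the gradient just copies through the subtraction (first branch only)
dbnmeani = (-dbndiff).sum(0)   # sum over the examples to undo the broadcast of bnmeani
```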
01:16:16.160 | And then let me comment out this one and this one. Enter. And it's wrong.
01:16:25.040 | Damn. Actually, sorry, this is supposed to be wrong. And it's supposed to be wrong because
01:16:34.560 | we are backpropagating from bn_diff into h_prebn, but we're not done because bn_mean_i
01:16:41.120 | depends on h_prebn, and there will be a second portion of that derivative coming from this
01:16:45.840 | second branch. So we're not done yet, and we expect it to be incorrect. So there you go.
01:16:54.400 | So let's now backpropagate from bn_mean_i into h_prebn.
01:16:54.400 | And so here again, we have to be careful because there's a broadcasting along,
01:17:01.280 | or there's a sum along the zeroth dimension. So this will turn into broadcasting in the backward
01:17:05.920 | pass now. And I'm going to go a little bit faster on this line because it is very similar to the
01:17:10.560 | line that we had before, multiple lines in the past, in fact. So dh_prebn will be,
01:17:18.880 | the gradient will be scaled by 1/n, and then basically this gradient here, dbn_mean_i,
01:17:27.280 | is going to be scaled by 1/n, and then it's going to flow across all the columns and deposit
01:17:33.120 | itself into dh_prebn. So what we want is this thing scaled by 1/n. Let me put the constant
01:17:40.480 | up front here. So scale down the gradient, and now we need to replicate it across all the
01:18:01.760 | rows here. So I like to do that by torch.ones_like of basically h_prebn.
01:18:01.760 | And I will let the broadcasting do the work of replication. So
01:18:14.960 | like that. So this is dh_prebn, and hopefully we can plus equals that.
01:18:22.560 | So this here is broadcasting, and then this is the scaling. So this should be correct. Okay.
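A sketch of that second branch, with the scaling up front and broadcasting doing the replication across the rows:

```python
# second branch of dhprebn, through: bnmeani = 1/n * hprebn.sum(0, keepdim=True)
dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)
```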
01:18:33.840 | So that completes the backpropagation of the batch norm layer, and we are now here. Let's
01:18:39.360 | backpropagate through the linear layer 1 here. Now because everything is getting a little vertically
01:18:44.800 | crazy, I copy-pasted the line here, and let's just backpropagate through this one line.
01:18:48.960 | So first, of course, we inspect the shapes, and we see that this is 32 by 64. embcat is 32 by 30,
01:18:57.760 | w1 is 30 by 64, and b1 is just 64. So as I mentioned, backpropagating through linear
01:19:06.880 | layers is fairly easy just by matching the shapes, so let's do that. We have that dembcat
01:19:14.080 | should be some matrix multiplication of dh_prebn with w1 and one transpose thrown in there.
01:19:21.600 | So to make embcat be 32 by 30, I need to take dh_prebn, 32 by 64, and multiply it by w1 dot
01:19:36.000 | transpose. To get dw1, I need to end up with 30 by 64. So to get that, I need to take embcat transpose
01:19:49.680 | and multiply that by dh_prebn. And finally, to get db1, this is an addition, and we saw that
01:20:04.640 | basically I need to just sum the elements in dh_prebn along some dimension. And to make the
01:20:10.800 | dimensions work out, I need to sum along the 0th axis here to eliminate this dimension, and we do
01:20:17.280 | not keep dims, so that we want to just get a single one-dimensional vector of 64. So these are the
01:20:24.320 | claimed derivatives. Let me put that here, and let me uncomment three lines and cross our fingers.
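The three claimed derivatives, as a sketch in the notebook's names (embcat 32 by 30, W1 30 by 64, b1 of size 64):

```python
# backward through: hprebn = embcat @ W1 + b1
dembcat = dhprebn @ W1.T   # (32, 64) @ (64, 30) -> (32, 30)
dW1 = embcat.T @ dhprebn   # (30, 32) @ (32, 64) -> (30, 64)
db1 = dhprebn.sum(0)       # (64,), no keepdim since b1 is a one-dimensional vector
```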
01:20:34.000 | Everything is great. Okay, so we now continue, almost there. We have the derivative of embcat,
01:20:38.960 | and we want to backpropagate it into emb. So I again copied this line over here. So this is the
01:20:47.040 | forward pass, and then this is the shapes. So remember that the shape here was 32 by 30,
01:20:52.640 | and the original shape of emb was 32 by 3 by 10. So this layer in the forward pass, as you recall,
01:20:58.400 | did the concatenation of these three 10-dimensional character vectors. And so now we just want to
01:21:05.200 | undo that. So this is actually a relatively straightforward operation, because the backward
01:21:11.120 | pass of the... What is a view? A view is just a representation of the array. It's just a logical
01:21:16.640 | form of how you interpret the array. So let's just reinterpret it to be what it was before.
01:21:21.760 | So in other words, demb is not 32 by 30. It is basically dembcat, but if you view it as
01:21:31.840 | the original shape, so just emb.shape, you can pass in tuples into view. And so this should just be…
01:21:41.440 | Okay, we just re-represent that view, and then we uncomment this line here, and hopefully...
01:21:51.040 | Yeah, so the derivative of emb is correct. So in this case, we just have to re-represent the shape
01:21:56.880 | of those derivatives into the original view. So now we are at the final line, and the only
01:22:01.600 | thing that's left to backpropagate through is this indexing operation here, C at Xb. So as I did
01:22:08.480 | before, I copy-pasted this line here, and let's look at the shapes of everything that's involved
01:22:12.560 | and remind ourselves how this worked. So emb.shape was 32 by 3 by 10. So it's 32 examples, and then
01:22:21.680 | we have three characters. Each one of them has a 10-dimensional embedding, and this was achieved
01:22:27.920 | by taking the lookup table C, which has 27 possible characters, each of them 10-dimensional,
01:22:34.240 | and we looked up at the rows that were specified inside this tensor xb. So xb is 32 by 3,
01:22:42.880 | and it's basically giving us, for each example, the identity or the index of which character
01:22:47.760 | is part of that example. And so here I'm showing the first five rows of this tensor xb.
01:22:56.240 | And so we can see, for example, that in the first example in this batch, the
01:23:01.520 | first character, the first character, and the fourth character come into the neural net,
01:23:05.440 | and then we want to predict the next character in the sequence after the characters 1, 1, 4.
01:23:10.800 | So basically what's happening here is there are integers inside xb, and each one of these
01:23:18.160 | integers is specifying which row of C we want to pluck out, right? And then we arrange those rows
01:23:25.360 | that we've plucked out into 32 by 3 by 10 tensor, and we just package them into this tensor. And now
01:23:33.920 | what's happening is that we have demb. So for every one of these basically plucked out rows,
01:23:40.320 | we have their gradients now, but they're arranged inside this 32 by 3 by 10 tensor.
01:23:45.920 | So all we have to do now is we just need to route this gradient backwards through this assignment.
01:23:51.600 | So we need to find which row of C did every one of these 10-dimensional embeddings come from,
01:23:58.080 | and then we need to deposit them into D_c. So we just need to undo the indexing, and of course,
01:24:06.240 | if any of these rows of C was used multiple times, which almost certainly is the case,
01:24:15.120 | like row 1 here was used multiple times, then we have to remember that the gradients
01:24:15.120 | that arrive there have to add. So for each occurrence, we have to have an addition.
01:24:21.200 | So let's now write this out. And I don't actually know of a much better way to do this than a
01:24:25.360 | for loop, unfortunately, in Python. So maybe someone can come up with a vectorized efficient
01:24:30.960 | operation, but for now, let's just use for loops. So let me create a torch.zeros_like
01:24:36.080 | C to initialize just a 27 by 10 tensor of all zeros. And then honestly, for k in range,
01:24:46.080 | xb.shape at 0. Maybe someone has a better way to do this, but for j in range, xb.shape at 1,
01:24:54.320 | this is going to iterate over all the elements of xb, all these integers.
01:25:02.000 | And then let's get the index at this position. So the index is basically xb at k, j.
01:25:11.520 | So an example of that is 11 or 14 and so on. And now in a forward pass, we basically took
01:25:20.800 | the row of C at index, and we deposited it into emb at k, j. That's what happened. That's where
01:25:32.000 | they are packaged. So now we need to go backwards, and we just need to route d_emb at the position
01:25:38.400 | k, j. We now have these derivatives for each position, and it's 10-dimensional. And you just
01:25:55.920 | need to go into the correct row of C. So dC at ix, rather, is this, but plus equals, because
01:25:55.920 | there could be multiple occurrences. Like the same row could have been used many, many times.
01:25:59.760 | And so all of those derivatives will just go backwards through the indexing, and they will add.
01:26:07.840 | So this is my candidate solution. Let's copy it here. Let's uncomment this and cross our fingers.
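The candidate routing loop, as a sketch (Xb holds the integer character indices, demb holds the gradients arranged 32 by 3 by 10):

```python
# backward through the indexing: emb = C[Xb]
dC = torch.zeros_like(C)              # (27, 10)
for k in range(Xb.shape[0]):          # loop over the examples
    for j in range(Xb.shape[1]):      # loop over the context positions
        ix = Xb[k, j]                 # which row of C was plucked out here
        dC[ix] += demb[k, j]          # accumulate, because the same row can be reused
```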
01:26:19.520 | Yay! So that's it. We've backpropagated through this entire beast. So there we go.
01:26:29.360 | Totally made sense. So now we come to exercise two. It basically turns out that in this first
01:26:34.800 | exercise, we were doing way too much work. We were backpropagating way too much. And it was all good
01:26:39.600 | practice and so on, but it's not what you would do in practice. And the reason for that is, for
01:26:44.080 | example, here I separated out this loss calculation over multiple lines, and I broke it up all to its
01:26:51.040 | smallest atomic pieces, and we backpropagated through all of those individually. But it turns
01:26:55.600 | out that if you just look at the mathematical expression for the loss, then actually you can
01:27:01.440 | do the differentiation on pen and paper, and a lot of terms cancel and simplify. And the mathematical
01:27:06.560 | expression you end up with can be significantly shorter and easier to implement than backpropagating
01:27:11.440 | through all the little pieces of everything you've done. So before we had this complicated forward
01:27:16.240 | pass going from logits to the loss. But in PyTorch, everything can just be glued together into a
01:27:21.920 | single call, F.cross_entropy. You just pass in the logits and the labels, and you get the exact same loss,
01:27:27.120 | as I verify here. So our previous loss and the fast loss coming from the chunk of operations as
01:27:32.880 | a single mathematical expression is the same, but it's much, much faster in a forward pass.
01:27:37.840 | It's also much, much faster in backward pass. And the reason for that is, if you just look at
01:27:42.480 | the mathematical form of this and differentiate again, you will end up with a very small and short
01:27:46.800 | expression. So that's what we want to do here. We want to, in a single operation or in a single go,
01:27:52.480 | or like very quickly, go directly into dlogits. And we need to implement dlogits as a function of
01:27:59.600 | logits and ybs. But it will be significantly shorter than whatever we did here, where to get
01:28:06.640 | to dlogits, we had to go all the way here. So all of this work can be skipped in a much, much simpler
01:28:13.040 | mathematical expression that you can implement here. So you can give it a shot yourself. Basically,
01:28:19.440 | look at what exactly is the mathematical expression of loss and differentiate with respect to the
01:28:24.480 | logits. So let me show you a hint. You can, of course, try it for yourself. But if not, I can
01:28:32.000 | give you some hint of how to get started mathematically. So basically, what's happening
01:28:37.520 | here is we have logits. Then there's a softmax that takes the logits and gives you probabilities.
01:28:42.480 | Then we are using the identity of the correct next character to pluck out a row of probabilities.
01:28:49.600 | Take the negative log of it to get our negative log probability. And then we average up all the
01:28:54.880 | log probabilities or negative log probabilities to get our loss. So basically, what we have is
01:29:00.960 | for a single individual example, rather, we have that loss is equal to negative log probability,
01:29:06.560 | where p here is kind of like thought of as a vector of all the probabilities. So at the yth
01:29:13.200 | position, where y is the label, and we have that p here, of course, is the softmax. So the
01:29:21.200 | ith component of p, of this probability vector, is just the softmax function. So exponentiating all the
01:29:28.080 | logits basically and normalizing so everything sums to one. Now, if you write out
01:29:36.480 | p of y here, you can just write out the softmax. And then basically what we're interested in is
01:29:41.120 | we're interested in the derivative of the loss with respect to the ith logit. And so basically,
01:29:48.800 | it's a d by d li of this expression here, where we have l indexed with the specific label y,
01:29:55.680 | and on the bottom, we have a sum over j of e to the lj and the negative log of all that.
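Written out for a single example with label y and logits l, the expression just described is:

$$
\text{loss} = -\log p_y, \qquad p_i = \frac{e^{l_i}}{\sum_j e^{l_j}},
\qquad\text{so we want}\qquad
\frac{\partial\,\text{loss}}{\partial l_i}
= \frac{\partial}{\partial l_i}\left(-\log \frac{e^{l_y}}{\sum_j e^{l_j}}\right).
$$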
01:29:59.680 | So potentially, give it a shot, pen and paper, and see if you can actually derive the expression
01:30:04.720 | for the loss by d li. And then we're going to implement it here. Okay, so I'm going to give
01:30:09.920 | away the result here. So this is some of the math I did to derive the gradients analytically. And so
01:30:17.120 | we see here that I'm just applying the rules of calculus from your first or second year of
01:30:20.800 | bachelor's degree, if you took it. And we see that the expressions actually simplify quite a bit.
01:30:25.920 | You have to separate out the analysis in the case where the ith index that you're interested in
01:30:30.480 | inside logits is either equal to the label or it's not equal to the label. And then the expressions
01:30:35.600 | simplify and cancel in a slightly different way. And what we end up with is something very,
01:30:39.920 | very simple. We either end up with basically p at i, where p is again this vector of probabilities
01:30:47.120 | after a softmax, or p at i minus one, where we just simply subtract one. But in any case,
01:30:52.800 | we just need to calculate the softmax p, and then at the correct position, we need to subtract
01:30:58.400 | one. And that's the gradient, the form that it takes analytically. So let's implement this,
01:31:03.280 | basically. And we have to keep in mind that this is only done for a single example. But here we
01:31:07.680 | are working with batches of examples. So we have to be careful of that. And then the loss for a
01:31:13.200 | batch is the average loss over all the examples. So in other words, it is the loss for each
01:31:18.160 | individual example summed up and then divided by n.
01:31:23.840 | And we have to back propagate through that as well and be careful with it. So d logits is going to be
01:31:29.760 | f dot softmax. PyTorch has a softmax function that you can call. And we want to apply the softmax on
01:31:37.120 | the logits. And we want to go in the dimension that is one. So basically, we want to do the
01:31:43.440 | softmax along the rows of these logits. Then at the correct positions, we need to subtract a one.
01:31:49.840 | So d logits at iterating over all the rows and indexing into the columns provided by the
01:31:58.320 | correct labels inside yb, we need to subtract one. And then finally, it's the average loss that is
01:32:05.520 | the loss. And in the average, there's a one over n of all the losses added up. And so we need to
01:32:11.440 | also back propagate through that division. So the gradient has to be scaled down by n as well,
01:32:17.440 | because of the mean. But this otherwise should be the result. So now if we verify this,
01:32:24.720 | we see that we don't get an exact match. But at the same time, the maximum difference between d logits
01:32:31.600 | from PyTorch and our d logits here is on the order of 5e-9. So it's a tiny, tiny number.
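For reference, the candidate just described can be written as a short sketch (with logits, Yb and the batch size n from the notebook, and F standing for torch.nn.functional):

```python
# the short-form backward pass through cross entropy
dlogits = F.softmax(logits, 1)    # probabilities, one row per example
dlogits[range(n), Yb] -= 1        # subtract 1 at the correct label in each row
dlogits /= n                      # account for the 1/n of the mean over the batch
```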
01:32:38.960 | So because of floating point wonkiness, we don't get the exact bitwise result,
01:32:44.320 | but we basically get the correct answer approximately. Now I'd like to pause here
01:32:50.400 | briefly before we move on to the next exercise, because I'd like us to get an intuitive sense
01:32:54.480 | of what d logits is, because it has a beautiful and very simple explanation, honestly. So here,
01:33:01.200 | I'm taking d logits, and I'm visualizing it. And we can see that we have a batch of 32 examples
01:33:06.480 | of 27 characters. And what is d logits intuitively, right? d logits is basically the
01:33:13.600 | probabilities matrix from the forward pass. But then here, these black squares are the positions of the
01:33:18.240 | correct indices, where we subtracted a 1. And so what is this doing, right? These are the derivatives
01:33:25.280 | on d logits. And so let's look at just the first row here. So that's what I'm doing here. I'm
01:33:33.040 | calculating the probabilities of these logits, and then I'm taking just the first row. And this is
01:33:37.600 | the probability row. And then d logits of the first row, and multiplying by n just for us so that
01:33:44.320 | we don't have the scaling by n in here, and everything is more interpretable.
01:33:47.440 | We see that it's exactly equal to the probability, of course, but then the position of the correct
01:33:52.800 | index has a minus equals 1. So minus 1 on that position. And so notice that if you take d logits
01:34:00.160 | at 0, and you sum it, it actually sums to 0. And so you should think of these gradients here at
01:34:09.120 | each cell as like a force. We are going to be basically pulling down on the probabilities
01:34:17.280 | of the incorrect characters, and we're going to be pulling up on the probability
01:34:21.120 | at the correct index. And that's what's basically happening in each row. And the amount of push and
01:34:30.160 | pull is exactly equalized, because the sum is 0. So the amount to which we pull down on the
01:34:35.840 | probabilities, and the amount that we push up on the probability of the correct character is equal.
01:34:40.480 | So the repulsion and the attraction are equal. And think of the neural net now as a massive
01:34:47.200 | pulley system or something like that. We're up here on top of d logits, and we're pulling up,
01:34:52.800 | we're pulling down the probabilities of incorrect and pulling up the probability of the correct.
01:34:56.480 | And in this complicated pulley system, because everything is mathematically just determined,
01:35:01.920 | just think of it as sort of like this tension translating through this complicated pulley mechanism.
01:35:06.720 | And then eventually we get a tug on the weights and the biases. And basically in each update,
01:35:11.600 | we just kind of like tug in the direction that we like for each of these elements,
01:35:15.600 | and the parameters are slowly given in to the tug. And that's what training a neural net
01:35:20.000 | kind of like looks like on a high level. And so I think the forces of push and pull in these
01:35:25.280 | gradients are actually very intuitive here. We're pushing and pulling on the correct answer and the
01:35:30.720 | incorrect answers. And the amount of force that we're applying is actually proportional to the
01:35:36.400 | probabilities that came out in the forward pass. And so for example, if our probabilities came out
01:35:41.520 | exactly correct, so they would have had zero everywhere except for one at the correct position,
01:35:47.680 | then the d logits would be all a row of zeros for that example. There would be no push and pull.
01:35:53.440 | So the amount to which your prediction is incorrect is exactly the amount by which you're
01:35:58.640 | going to get a pull or a push in that dimension. So if you have, for example, a very confidently
01:36:04.240 | mispredicted element here, then what's going to happen is that element is going to be pulled down
01:36:10.080 | very heavily, and the correct answer is going to be pulled up to the same amount. And the other
01:36:15.600 | characters are not going to be influenced too much. So the amount to which you mispredict is
01:36:21.040 | then proportional to the strength of the pull. And that's happening independently in all the
01:36:26.240 | dimensions of this tensor. And it's sort of very intuitive and very easy to think through. And
01:36:31.440 | that's basically the magic of the cross-entropy loss and what it's doing dynamically in the
01:36:35.920 | backward pass of the neural net. So now we get to exercise number three, which is a very fun exercise,
01:36:41.600 | depending on your definition of fun. And we are going to do for batch normalization exactly what
01:36:46.560 | we did for cross-entropy loss in exercise number two. That is, we are going to consider it as a
01:36:51.280 | glued single mathematical expression and backpropagate through it in a very efficient manner,
01:36:56.000 | because we are going to derive a much simpler formula for the backward pass of batch normalization.
01:37:00.320 | And we're going to do that using pen and paper. So previously, we've broken up batch normalization
01:37:05.760 | into all of the little intermediate pieces and all the atomic operations inside it, and then we
01:37:09.920 | backpropagate it through it one by one. Now we just have a single sort of forward pass of a batch norm,
01:37:17.760 | and it's all glued together, and we see that we get the exact same result as before.
01:37:22.320 | Now for the backward pass, we'd like to also implement a single formula basically for
01:37:27.920 | backpropagating through this entire operation, that is the batch normalization.
01:37:31.360 | So in the forward pass previously, we took HPREBN, the hidden states of the pre-batch normalization,
01:37:38.720 | and created HPREACT, which is the hidden states just before the activation. In the batch
01:37:44.320 | normalization paper, HPREBN is x and HPREACT is y. So in the backward pass, what we'd like to do now
01:37:51.280 | is we have DHPREACT, and we'd like to produce DHPREBN, and we'd like to do that in a very
01:37:57.760 | efficient manner. So that's the name of the game, calculate DHPREBN given DHPREACT. And for the
01:38:04.320 | purposes of this exercise, we're going to ignore gamma and beta and their derivatives, because they
01:38:09.440 | take on a very simple form in a very similar way to what we did up above. So let's calculate this
01:38:15.920 | given that right here. So to help you a little bit like I did before, I started off the
01:38:22.800 | implementation here on pen and paper, and I took two sheets of paper to derive the mathematical
01:38:28.240 | formulas for the backward pass. And basically to set up the problem, just write out the mu,
01:38:34.640 | sigma squared (the variance), xi hat, and yi, exactly as in the paper, except for the Bessel correction.
01:38:40.880 | And then in the backward pass, we have the derivative of the loss with respect to all the
01:38:46.160 | elements of y. And remember that y is a vector. There's multiple numbers here. So we have all
01:38:53.520 | of the derivatives with respect to all the y's. And then there's a gamma and a beta, and this is
01:38:59.200 | kind of like the compute graph. The gamma and the beta, there's the x hat, and then the mu and the
01:39:04.480 | sigma squared, and the x. So we have dl by dyi, and we want dl by dxi for all the i's in these
01:39:13.360 | vectors. So this is the compute graph, and you have to be careful because I'm trying to note here
01:39:20.320 | that these are vectors. There's many nodes here inside x, x hat, and y, but mu and sigma, sorry,
01:39:28.480 | sigma square are just individual scalars, single numbers. So you have to be careful with that. You
01:39:33.600 | have to imagine there's multiple nodes here, or you're going to get your math wrong.
01:39:36.400 | So as an example, I would suggest that you go in the following order, one, two, three, four,
01:39:43.920 | in terms of the backpropagation. So backpropagate into x hat, then into sigma square, then into mu,
01:39:49.280 | and then into x. Just like in a topological sort in micrograd, we would go from right to left.
01:39:56.080 | You're doing the exact same thing, except you're doing it with symbols and on a piece of paper.
01:40:00.240 | So for number one, I'm not giving away too much. If you want dl by dxi hat, then we just take dl
01:40:11.200 | by dyi and multiply it by gamma, because of this expression here, where any individual yi is just
01:40:17.440 | gamma times x i hat plus beta. So it didn't help you too much there, but this gives you basically
01:40:24.320 | the derivatives for all the x hats. And so now, try to go through this computational graph and
01:40:31.360 | derive what is dl by d sigma square, and then what is dl by d mu, and then what is dl by dx,
01:40:38.960 | eventually. So give it a go, and I'm going to be revealing the answer one piece at a time.
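Step one, written out explicitly (this is the piece already given above), since $y_i = \gamma \hat{x}_i + \beta$:

$$
\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i}\,\gamma .
$$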
01:40:44.080 | Okay, so to get dl by d sigma square, we have to remember again, like I mentioned, that there are
01:40:49.440 | many x hats here. And remember that sigma square is just a single individual number here. So when
01:40:57.120 | we look at the expression for dl by d sigma square, we have to actually consider all the possible
01:41:04.400 | paths that we basically have that there's many x hats, and they all depend on sigma square.
01:41:13.920 | So sigma square has a large fan out. There's lots of arrows coming out from sigma square into all
01:41:19.040 | the x hats. And then there's a backpropagating signal from each x hat into sigma square. And
01:41:25.360 | that's why we actually need to sum over all those i's from i equal to one to m of the dl by d xi hat,
01:41:34.720 | which is the global gradient, times d xi hat by d sigma square, which is the local gradient
01:41:41.280 | of this operation here. And then mathematically, I'm just working it out here, and I'm simplifying,
01:41:47.920 | and you get a certain expression for dl by d sigma square. And we're going to be using this
01:41:52.560 | expression when we backpropagate into mu, and then eventually into x. So now let's continue
01:41:57.200 | our backpropagation into mu. So what is dl by d mu? Now again, be careful that mu influences x hat,
01:42:04.320 | and x hat is actually lots of values. So for example, if our mini-batch size is 32, as it is
01:42:09.520 | in our example that we were working on, then this is 32 numbers and 32 arrows going back to mu.
01:42:15.840 | And then mu going to sigma square is just a single arrow, because sigma square is a scalar.
01:42:19.840 | So in total, there are 33 arrows emanating from mu, and then all of them have gradients coming
01:42:26.160 | into mu, and they all need to be summed up. And so that's why when we look at the expression for dl
01:42:32.400 | by d mu, I am summing up over all the gradients of dl by d xi hat times d xi hat by d mu. So that's
01:42:41.040 | this arrow, and that's 32 arrows here, and then plus the one arrow from here, which is dl
01:42:46.720 | by d sigma square times d sigma square by d mu. So now we have to work out that expression, and
01:42:52.880 | let me just reveal the rest of it. Simplifying here is not complicated, the first term, and you
01:42:59.280 | just get an expression here. For the second term though, there's something really interesting that
01:43:03.120 | happens. When we look at d sigma square by d mu and we simplify, at one point if we assume that
01:43:11.120 | in a special case where mu is actually the average of xi's, as it is in this case, then if we plug
01:43:18.800 | that in, then actually the gradient vanishes and becomes exactly zero. And that makes the entire
01:43:30.640 | second term cancel. And so, if you just have a mathematical expression like this, and you look
01:43:30.640 | at d sigma square by d mu, you would get some mathematical formula for how mu impacts sigma
01:43:36.720 | square. But if it is the special case that mu is actually equal to the average, as it is in the
01:43:41.920 | case of batch normalization, that gradient will actually vanish and become zero. So the whole
01:43:46.800 | term cancels, and we just get a fairly straightforward expression here for dl by d mu.
01:43:51.600 | Okay, and now we get to the craziest part, which is deriving dl by d xi, which is ultimately what
01:43:58.000 | we're after. Now let's count, first of all, how many numbers are there inside x? As I mentioned,
01:44:04.320 | there are 32 numbers. There are 32 little xi's. And let's count the number of arrows emanating
01:44:09.440 | from each xi. There's an arrow going to mu, an arrow going to sigma square, and then there's
01:44:15.280 | an arrow going to x hat. But this arrow here, let's scrutinize that a little bit. Each xi hat
01:44:21.440 | is just a function of xi and all the other scalars. So xi hat only depends on xi and none of the other
01:44:29.200 | x's. And so therefore, there are actually, in this single arrow, there are 32 arrows. But those 32
01:44:35.440 | arrows are going exactly parallel. They don't interfere. They're just going parallel between x
01:44:40.720 | and x hat. You can look at it that way. And so how many arrows are emanating from each xi? There
01:44:45.360 | are three arrows, mu, sigma square, and the associated x hat. And so in backpropagation,
01:44:52.400 | we now need to apply the chain rule, and we need to add up those three contributions.
01:44:57.120 | So here's what that looks like if I just write that out.
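In symbols, the chain rule being written out here is:

$$
\frac{\partial L}{\partial x_i}
= \frac{\partial L}{\partial \hat{x}_i}\,\frac{\partial \hat{x}_i}{\partial x_i}
+ \frac{\partial L}{\partial \sigma^2}\,\frac{\partial \sigma^2}{\partial x_i}
+ \frac{\partial L}{\partial \mu}\,\frac{\partial \mu}{\partial x_i}.
$$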
01:45:00.080 | We're chaining through mu, sigma square, and through x hat. And those three terms are just
01:45:09.120 | here. Now, we already have three of these. We have dl by dx i hat. We have dl by d mu,
01:45:16.880 | which we derived here. And we have dl by d sigma square, which we derived here. But we need three
01:45:22.080 | other terms here. This one, this one, and this one. So I invite you to try to derive them. It's
01:45:28.240 | not that complicated. You're just looking at these expressions here and differentiating with respect
01:45:32.240 | to xi. So give it a shot, but here's the result, or at least what I got. I'm just differentiating
01:45:43.920 | with respect to xi for all of these expressions. And honestly, I don't think there's anything too
01:45:47.520 | tricky here. It's basic calculus. Now, what gets a little bit more tricky is we are now going to
01:45:52.960 | plug everything together. So all of these terms multiplied with all of these terms and add it up
01:45:57.680 | according to this formula. And that gets a little bit hairy. So what ends up happening is
01:46:02.240 | you get a large expression. And the thing to be very careful with here, of course, is
01:46:10.320 | we are working with a dl by d xi for a specific i here. But when we are plugging in some of these
01:46:16.080 | terms, like say this term here, dl by d sigma squared, you see how dl by d sigma squared,
01:46:24.400 | I end up with an expression. And I'm iterating over little i's here. But I can't use i as the
01:46:30.640 | variable when I plug in here, because this is a different i from this i. This i here is just a
01:46:36.240 | placeholder, like a local variable for a for loop in here. So here, when I plug that in, you notice
01:46:41.840 | that I renamed the i to a j, because I need to make sure that this j is not this i. This j is
01:46:48.640 | like a little local iterator over 32 terms. And so you have to be careful with that when you're
01:46:53.920 | plugging in the expressions from here to here. You may have to rename i's into j's. And you have
01:46:58.240 | to be very careful what is actually an i with respect to dl by d xi. So some of these are j's,
01:47:05.680 | some of these are i's. And then we simplify this expression. And I guess the big thing to notice
01:47:13.520 | here is a bunch of terms just kind of come out to the front, and you can refactor them. There's
01:47:17.840 | a sigma squared plus epsilon raised to the power of negative 3 over 2. This sigma squared plus
01:47:22.080 | epsilon can be actually separated out into three terms. Each of them are sigma squared plus epsilon
01:47:28.000 | to the negative 1 over 2. So the three of them multiplied is equal to this. And then those three
01:47:34.000 | terms can go different places because of the multiplication. So one of them actually comes
01:47:38.480 | out to the front and will end up here outside. One of them joins up with this term, and one of
01:47:45.200 | them joins up with this other term. And then when you simplify the expression, you'll notice that
01:47:50.400 | some of these terms that are coming out are just the xi hats. So you can simplify just by rewriting
01:47:55.920 | that. And what we end up with at the end is a fairly simple mathematical expression over here
01:48:00.800 | that I cannot simplify further. But basically, you'll notice that it only uses the stuff we have
01:48:06.080 | and it derives the thing we need. So we have dl by dy for all the i's, and those are used plenty
01:48:13.680 | of times here. And also in addition, what we're using is these xi hats and xj hats, and they just
01:48:18.720 | come from the forward pass. And otherwise, this is a simple expression, and it gives us dl by dxi
01:48:25.520 | for all the i's, and that's ultimately what we're interested in. So that's the end of BatchNorm
01:48:31.440 | backward pass analytically. Let's now implement this final result. Okay, so I implemented the
01:48:37.680 | expression into a single line of code here, and you can see that the max diff is tiny,
01:48:43.120 | so this is the correct implementation of this formula. Now, I'll just basically tell you that
01:48:49.760 | getting this formula here from this mathematical expression was not trivial, and there's a lot
01:48:54.480 | going on packed into this one formula. And this is a whole exercise by itself, because you have
01:49:00.000 | to consider the fact that this formula here is just for a single neuron and a batch of 32 examples.
01:49:05.440 | But what I'm doing here is we actually have 64 neurons, and so this expression has to
01:49:11.920 | evaluate the BatchNorm backward pass for all of those 64 neurons in parallel and independently.
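As a rough illustration of what such a single vectorized expression can look like (a sketch under my own variable names, not the lecture's exact notebook line; it assumes the forward pass used Bessel's correction, i.e. the variance divides by n-1):

    import torch

    def batchnorm_backward(dout, x, gamma, eps=1e-5):
        # dout : (n, d) gradient of the loss w.r.t. the batch norm output
        # x    : (n, d) pre-batchnorm activations, one column per neuron
        # gamma: (d,)   batch norm gain
        n = x.shape[0]
        mu = x.mean(0, keepdim=True)
        var = x.var(0, keepdim=True, unbiased=True)   # Bessel-corrected variance
        ivar = (var + eps) ** -0.5
        xhat = (x - mu) * ivar
        # simplified result: dL/dx_i = gamma*ivar/n *
        #   (n*dout_i - sum_j dout_j - n/(n-1) * xhat_i * sum_j dout_j*xhat_j)
        # the sums reduce over the batch dimension and broadcast back over all d columns
        return gamma * ivar / n * (
            n * dout
            - dout.sum(0, keepdim=True)
            - n / (n - 1) * xhat * (dout * xhat).sum(0, keepdim=True)
        )

    # quick sanity check against autograd; the max difference should be tiny
    x = torch.randn(32, 64, requires_grad=True)
    gamma, beta = torch.randn(64, requires_grad=True), torch.randn(64)
    out = gamma * (x - x.mean(0)) / torch.sqrt(x.var(0, unbiased=True) + 1e-5) + beta
    dout = torch.randn(32, 64)
    out.backward(dout)
    dx = batchnorm_backward(dout, x.detach(), gamma.detach())
    print((dx - x.grad).abs().max())

The keepdim=True sums are what make the same formula apply to every column, i.e. every neuron, at once, which is the broadcasting point being made here.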
01:49:16.720 | So this has to happen basically in every single column of the inputs here. And in addition to
01:49:25.200 | that, you see how there are a bunch of sums here, and we need to make sure that when I do those
01:49:29.280 | sums that they broadcast correctly onto everything else that's here. And so getting this expression
01:49:34.800 | is just like highly non-trivial, and I invite you to basically look through it and step through it,
01:49:38.080 | and it's a whole exercise to make sure that this checks out. But once all the shapes agree,
01:49:44.800 | and once you convince yourself that it's correct, you can also verify that PyTorch gets the exact
01:49:48.800 | same answer as well. And so that gives you a lot of peace of mind that this mathematical formula
01:49:53.520 | is correctly implemented here and broadcasted correctly and replicated in parallel for all
01:49:58.720 | of the 64 neurons inside this BatchNorm layer. Okay, and finally, exercise number four asks you
01:50:05.280 | to put it all together. And here we have a redefinition of the entire problem. So you see
01:50:10.320 | that we re-initialized the neural net from scratch and everything. And then here, instead of calling
01:50:15.280 | loss.backward, we want to have the manual backpropagation here as we derived it up above.
01:50:20.160 | So go up, copy-paste all the chunks of code that we've already derived, put them here, and derive
01:50:26.000 | your own gradients, and then optimize this neural net basically using your own gradients all the way
01:50:31.200 | to the calibration of the BatchNorm and the evaluation of the loss. And I was able to
01:50:35.840 | achieve quite a good loss, basically the same loss you would achieve before. And that shouldn't be
01:50:40.240 | surprising because all we've done is we've really gotten into loss.backward, and we've pulled
01:50:45.040 | out all the code and inserted it here. But those gradients are identical, and everything is
01:50:50.160 | identical, and the results are identical. It's just that we have full visibility on exactly what
01:50:55.040 | goes on under the hood of loss.backward in this specific case. Okay, and this is all of our
01:51:01.200 | code. This is the full backward pass using basically the simplified backward pass for
01:51:05.760 | the cross-entropy loss and the BatchNormalization. So backpropagating through cross-entropy,
01:51:11.120 | the second layer, the tanh nonlinearity, the BatchNormalization, through the first layer,
01:51:18.480 | and through the embedding. And so you see that this is only maybe, what is this, 20 lines of
01:51:22.880 | code or something like that? And that's what gives us gradients. And now we can potentially
01:51:27.600 | erase loss.backward. So the way I have the code set up is you should be able to run this
01:51:32.560 | entire cell once you fill this in, and this will run for only 100 iterations and then break.
01:51:37.280 | And it breaks because it gives you an opportunity to check your gradients against PyTorch.
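A minimal sketch of that kind of check (the helper name cmp and the exact printout are assumptions on my part, not necessarily the notebook's code):

    import torch

    def cmp(name, dt, t):
        # dt: manually derived gradient; t: tensor whose .grad was filled in by loss.backward()
        exact = torch.all(dt == t.grad).item()
        approx = torch.allclose(dt, t.grad)
        maxdiff = (dt - t.grad).abs().max().item()
        print(f'{name:15s} | exact: {exact} | approximate: {approx} | maxdiff: {maxdiff:.1e}')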
01:51:41.760 | So here, we see that our gradients are not exactly equal. They are approximately equal, and the
01:51:49.600 | differences are tiny, on the order of 1e-9 or so. And I don't exactly know where they're coming from,
01:51:54.240 | to be honest. So once we have some confidence that the gradients are basically correct,
01:51:58.320 | we can take out the gradient checking. We can disable this breaking statement.
01:52:04.400 | And then we can basically disable loss.backward. We don't need it anymore.
01:52:11.200 | Feels amazing to say that. And then here, when we are doing the update, we're not going to use
01:52:17.520 | p.grad. This is the old way of PyTorch. We don't have that anymore because we're not doing backward.
01:52:23.360 | We are going to use this update where, as you can see, I've arranged the grads
01:52:30.080 | to be in the same order as the parameters, and I'm zipping them up, the gradients and the parameters,
01:52:35.120 | into p and grad. And then here, I'm going to step with just the grad that we derived manually.
01:52:40.880 | So the last piece is that none of this now requires gradients from PyTorch. And so one
01:52:49.280 | thing you can do here is you can use torch.no_grad and wrap this whole code block in it.
01:52:57.200 | And really what you're saying is you're telling PyTorch that, "Hey, I'm not going to call backward
01:53:01.200 | on any of this." And this allows PyTorch to be a bit more efficient with all of it.
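A sketch of what that update can look like (the toy parameters, grads, and learning rate below are placeholders I made up so the snippet runs on its own; they stand in for the network's real parameters and the manually derived gradients):

    import torch

    # stand-ins for the real parameters and the manually derived gradients
    parameters = [torch.randn(27, 10, requires_grad=True), torch.randn(10, requires_grad=True)]
    grads = [torch.randn(27, 10), torch.randn(10)]  # arranged in the same order as parameters

    lr = 0.1  # placeholder learning rate
    with torch.no_grad():                         # we never call backward, so no graph is needed
        for p, grad in zip(parameters, grads):    # zip gradients and parameters into p and grad
            p += -lr * grad                       # step with the manual gradient, not p.grad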
01:53:05.520 | And then we should be able to just run this. And it's running. And you see that
01:53:15.680 | loss.backward is commented out and we're optimizing. So we're going to let this run,
01:53:22.560 | and hopefully we get a good result. Okay, so I allowed the neural net to finish optimization.
01:53:27.920 | Then here, I calibrate the batch norm parameters because I did not keep track of the running mean
01:53:35.040 | and variance in the training loop. Then here, I ran the loss. And you see that we actually obtained a
01:53:40.400 | pretty good loss, very similar to what we've achieved before. And then here, I'm sampling
01:53:44.880 | from the model. And we see some of the name-like gibberish that we're sort of used to. So basically,
01:53:49.760 | the model worked and samples pretty decent results compared to what we were used to.
01:53:55.520 | So everything is the same. But of course, the big deal is that we did not use loss.backward.
01:54:00.080 | We did not use PyTorch autograd, and we estimated our gradients ourselves by hand.
01:54:04.400 | And so hopefully, you're looking at this, the backward pass of this neural net,
01:54:08.320 | and you're thinking to yourself, actually, that's not too complicated. Each one of these layers is
01:54:14.480 | like three lines of code or something like that. And most of it is fairly straightforward,
01:54:18.960 | potentially with the notable exception of the batch normalization backward pass.
01:54:22.800 | Otherwise, it's pretty good. Okay, and that's everything I wanted to cover for this lecture.
01:54:27.440 | So hopefully, you found this interesting. And what I liked about it, honestly, is that it gave us a
01:54:32.400 | very nice diversity of layers to backpropagate through. And I think it gives a pretty nice
01:54:38.000 | and comprehensive sense of how these backward passes are implemented and how they work.
01:54:42.240 | And you'd be able to derive them yourself. But of course, in practice, you probably don't want to,
01:54:46.240 | and you want to use the PyTorch autograd. But hopefully, you have some intuition about how
01:54:50.240 | gradients flow backwards through the neural net, starting at the loss, and how they flow through
01:54:55.200 | all the variables and all the intermediate results. And if you understood a good chunk of it, and if
01:55:00.640 | you have a sense of that, then you can count yourself as one of these buff doges on the left,
01:55:04.720 | instead of the doges on the right here. Now, in the next lecture, we're actually going to go to
01:55:10.400 | recurrent neural nets, LSTMs, and all the other variants of RNNs. And we're going to start to
01:55:16.240 | complexify the architecture and start to achieve better log likelihoods. And so I'm really looking
01:55:21.360 | forward to that. And I'll see you then.