Welcome to lesson 10, which I've rather enthusiastically titled "Wrapping up our CNN". But looking at how many things we want to cover, I've added a "nearly" to the end, and I'm not actually sure how nearly we'll get there. We'll see. We'll probably have a few more things to cover next week as well.
I just wanted to remind you, after hearing from a few folks during the week who are very sad that they're not quite keeping up with everything: that's totally okay. Don't worry. As I mentioned in lesson one, I'm trying to give you enough here to keep you busy until the next part, next year.
So you can dive into the bits you're interested in, go back and look over stuff, and, yeah, don't feel like you have to understand everything within a week of first hearing it. But also, if you're not putting in the time on the homework, or you didn't put in the time on the homework in the last part, expect to have to go back and re-cover things, particularly because a lot of the stuff we covered in part one I'm assuming you're deeply comfortable with at this point.
Not because you're stupid if you're not, but just because it gives you the opportunity to go back and re-study it and practice and experiment until you are deeply comfortable. So yeah, if you're finding it whizzing along at a pace, that is because it is whizzing along at a pace.
Also, it's covering a lot more software-engineering kind of stuff. Those of you who are practicing software engineers will be thinking this is all pretty straightforward, and those of you who are not will be thinking, wow, there's a lot here. Part of that is because I think data scientists need to be good software engineers.
So I'm trying to show you some of these things, but it's stuff which people can spend years learning. So hopefully, for those of you who haven't done software engineering before, this is the start of a long process of becoming better software engineers. And there are some useful tips, hopefully.
So to remind you, we're trying to recreate fastai and much of PyTorch from these foundations, and starting to make things even better. Today you'll actually see some bits, well, in fact you've already seen some bits, that are going to be even better. I think the next version of fastai will have this new callback system, which I think is better than the old one.
And today we're going to be showing you some new, previously unpublished research, which will be finding its way into fastai and maybe into other libraries as well. We're going to try to stick to, and we will stick to, using nothing but these foundations. We're working through developing a modern CNN model, and we've got to the point where we've done our training loop, and it's a nice, flexible training loop.
So from here, when I say we're going to finish out a modern CNN model, it's not just going to be some basic, getting-by model: we're actually going to endeavor to get something that is approximately state of the art on ImageNet in the next week or two.
So that's the goal. And in our testing at this point, we're feeling pretty good about showing you some ImageNet results that maybe haven't been seen before. So that's where we're going to try to head as a group, and these are some of the things we're going to be covering to get there.
One of the things you might not have seen before, in this section called optimization, is LAMB. The reason is that this was going to be some of the unpublished research we were going to show you: a new optimization algorithm that we've been developing. The framework is still going to be new, but the particular approach to using it was actually published by Google two days ago, so we've kind of been scooped there.
It's a cool paper, really great, and they introduce a new optimization algorithm called LAMB, which we'll show you how to implement very easily. And if you're wondering how we're able to do that so fast, it's because we've kind of been working on the same thing ourselves for a few weeks now.
Then from next week, we'll also start developing a completely new fastai module called fastai.audio. So you'll be seeing how to actually create modules, and how to write Jupyter documentation and tests. And we're going to be learning about the things audio needs, such as complex numbers and Fourier transforms, which, if you're like me, is the point where you're going, oh, what, no, because I managed to spend my life avoiding complex numbers and Fourier transforms on the whole.
But don't worry, it'll be OK. It's actually not at all bad, or at least the bits we need to learn about are not at all bad, and you'll totally get it even if you've never touched these before. We'll be learning about audio formats and spectrograms, doing data augmentation on things that aren't images, and some particular kinds of loss functions and architectures for audio.
And as much as anything, it'll just be a great exercise in: OK, I've got some different data type that's not in fastai, so how do I build up all the bits I need to make it work? Then we'll be looking at neural translation as a way to learn about sequence-to-sequence models with attention, and then we'll go deeper and deeper into attention models, looking at the Transformer and its even more fantastic descendant, Transformer-XL.
And then we'll wrap up our Python adventures with a deep dive into some really interesting vision topics, which are going to require building some bigger models. So we'll talk about how to build your own deep learning box, and how to run big experiments on AWS with a new library we've developed called fastec2.
And then we're going to see exactly what happened last course when we did that U-Net super-resolution image generation, and what some of the pieces there are. We've actually got some really exciting new results to show you, which have been done in collaboration with some really cool partners, so I'm looking forward to showing you that.

To give you a tip: generative video models are what we're going to be looking at. Then we'll be looking at some interesting different applications: DeViSE, CycleGAN, and object detection. And then Swift, of course. The Swift lessons are coming together nicely, and I'm really excited about them. We'll be covering as much of the same territory as we can, but obviously it'll be in Swift and in only two lessons, so it won't be everything. We'll try to give you enough of a taste that you'll feel like you understand why Swift is important and how to get started building something similar in Swift.
And maybe building out the whole thing in Swift will take the next 12 months. Who knows? We'll see. So we're going to start today on 05a_foundations. What we're going to do is re-cover some of the software engineering and math basics that we were relying on last week, going into a little bit more detail.
Specifically, we'll be looking at callbacks and variance, and a couple of other Python concepts like dunder special methods. If you're familiar with those things and you're watching the video, feel free to skip ahead until we get to the new material. Callbacks, as I'm sure you've seen, are super important for fastai, and in general they're a really useful technique in software engineering.
And they're great for researchers, because they let you build things that you can quickly adjust, adding things in and pulling them out again. So what is a callback? Let's look at an example. Here's a function called f, which prints out "hi".
And I'm going to create a button using ipywidgets, which is a framework for creating GUI widgets in Python. So if I type w, it shows me a button which says "Click me", and I can click on it, and nothing happens.
So how do I get something to happen? Well, I need to pass a function to the ipywidgets framework to say: please run this function when you click on this button. The ipywidgets docs say that there's an on_click method which can register a function to be called when the button is clicked.
So let's try running that method, passing it f, my function. OK, nothing happened; it didn't run anything. But now if I click on here: oh, hi, hi. What happened is I told w that when a click occurs, it should call back to my f function and run it.
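To make the register-then-call-back idea concrete without needing ipywidgets installed, here's a minimal sketch using a tiny, hypothetical `ToyButton` class. It is not the ipywidgets API; it only mimics the `on_click` pattern described above, where the framework stores your function and calls it later, passing the button in.

```python
class ToyButton:
    """Hypothetical stand-in for a GUI button; not the real ipywidgets Button."""
    def __init__(self, description):
        self.description = description
        self._handlers = []          # functions registered via on_click

    def on_click(self, fn):
        # Store the function object itself; nothing is called yet.
        self._handlers.append(fn)

    def click(self):
        # Simulate a user click: call back every registered function with the button.
        for fn in self._handlers:
            fn(self)

def f(o):
    print('hi')

w = ToyButton('Click me')
w.on_click(f)   # pass f itself: no parentheses after f
w.click()       # now f runs and prints "hi"
```

The key design point is that `on_click` receives `f` rather than `f()`: the framework decides when to run it.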
Anybody who's done GUI programming will be extremely comfortable with this idea, and if you haven't, this will be kind of mind-bending. So f is a callback. It's not a particular class, it doesn't have a particular signature, and it's not a particular library. It's a concept: a function that we treat as an object.
So look, we're not calling the function; there are no parentheses after f. We're passing the function itself to this method, and it says, please call back to me when something happens, in this case when I click. So there's our starting point. Callbacks like these, which a GUI framework in particular calls when some event happens, are often called events.
So if you've heard of events, they're a kind of callback, and callbacks are a kind of what we'd call a function pointer. They can be much more general than that, as you'll see, but it's basically a way of passing in something to say: call back to this when something happens.
Now, by the way, these widgets are really worth looking at if you're interested in building some analytical GUIs. Here's a great example from the Plotly documentation of the kinds of things you can create with widgets. And it's not just for creating applications for others to use: if you want to experiment with different types of function or hyperparameters, or explore some data you've collected, widgets are a great way to do that, and as you can see, they're very, very easy to use.
In part one, you saw the image labeling stuff that was built with widgets like this. So that's how you can use somebody else's callback. Now let's create our own callback, and the event it's going to call back on is the completion of a calculation.
So let's create a function called slow_calculation. It's going to do five calculations, each adding i squared to a result, and each one is going to take a second, because we're going to add a sleep there. So this is kind of like an epoch of deep learning.
It's some calculation that takes a while. If we call slow_calculation, it's going to take five seconds to calculate the sum of i squared, and there, it's done it. But I'd really like to know how it's going, to get some progress. So we could add a parameter that lets you pass in a callback, and add just one line of code that says: if there's a callback, call it and pass in the epoch number.
Then we could create a function called show_progress that prints out "Awesome! We've finished epoch" followed by the epoch number. Look, it takes a parameter and we're passing a parameter, so we can now call slow_calculation, pass in show_progress, and it will call back to our function after each epoch.
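Here's a sketch of `slow_calculation` and `show_progress` as just described. The `delay` parameter is an assumption added for this example (it isn't in the lesson) so the demo can run without waiting five seconds; with its default it sleeps one second per step, as described.

```python
from time import sleep

def slow_calculation(cb=None, delay=1.0):
    # delay: hypothetical knob added for this sketch; the lesson always sleeps 1 second.
    res = 0
    for i in range(5):
        res += i*i
        sleep(delay)
        if cb: cb(i)          # the one extra line: call back with the epoch number
    return res

def show_progress(epoch):
    print(f"Awesome! We've finished epoch {epoch}!")

slow_calculation(show_progress, delay=0)   # prints one line per epoch, returns 30
```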
So there's our starting point for our callback. Now, what you'll notice tends to happen with stuff we do in fast.ai is that we'll start somewhere like this, which for many of you is trivially easy, and at some point during the next hour or two you might reach a point where you're feeling totally lost.
And the trick, if you're watching the video, is to go back to the point where it was trivially easy, figure out where you suddenly noticed you were totally lost, and find the bit in the middle where you kind of missed something. Because we're going to keep building up from trivially easy stuff, just like we did with that matrix multiplication, right?
So we're going to gradually build up from here and look at more and more interesting callbacks, but we're starting with this wonderfully short and simple line of code. So rather than defining a function just for the purpose of using it once, we can actually define the function at the point we use it using lambda notation.
Lambda notation is just another way of creating a function. Rather than saying def, we say lambda; rather than putting the arguments in parentheses, we put them before a colon; and then we write the thing we want to do. So this is identical to the previous one.
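A small sketch of the lambda form. `slow_calculation` is repeated (with the same hypothetical `delay` knob, an addition for this example) so the block runs on its own; the `def` version and the lambda versions of the callback behave the same way.

```python
from time import sleep

def slow_calculation(cb=None, delay=1.0):
    res = 0
    for i in range(5):
        res += i*i
        sleep(delay)          # hypothetical knob; the lesson always sleeps 1 second
        if cb: cb(i)
    return res

# The same callback written two ways: with def...
def show_progress(epoch): print(f"Awesome! We've finished epoch {epoch}!")
# ...and as a lambda: arguments before the colon, the expression after it.
show_progress_lambda = lambda epoch: print(f"Awesome! We've finished epoch {epoch}!")

# Or define it inline, right at the point of use:
slow_calculation(lambda o: print(f"Awesome! We've finished epoch {o}!"), delay=0)
```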
It's just a convenience for times when you want to define the callback at the same time that you use it, and it can make your code a little bit more concise. What if you wanted to be able to define what exclamation to use in the string as well? Now we've got two parameters.
But we can't pass this show_progress to slow_calculation. Let's try it. Right? It tries to call back, and remember cb is now show_progress, so it calls show_progress passing epoch as exclamation, and then epoch is missing. So that's an error: we've called a function that takes two arguments with only one.
So we have to convert this into a function with only one argument. lambda o is a function with only one argument, and this function calls show_progress with a particular exclamation. So we've converted something with two arguments into something with one argument. We might want to make it really easy for people to create different progress indicators with different exclamations.
So we could create a function called make_show_progress that returns that lambda. Now we can call make_show_progress here, and that's the same thing. This is a little bit awkward, so generally you might see it done like this instead.
You see this in fastai all the time. We define the function inside, which is basically the same as our lambda, and then we return that function. This is kind of interesting, because you might think of defining a function as a declarative thing: as soon as you define it, it's part of the thing that's compiled. If you use C++, that's how it works.
In Python, that's not how it works. When you define a function, you're basically saying the same as this: there's a variable with this name, which is a function. And that's how come we can take something that's passed to this function and use it inside here.
So every time we call make_show_progress, it's going to create a new function internally, _inner, with a different exclamation, and it'll work the same as before. This thing where you create a function that stores some information from the external context, which can be different every time, is called a closure.
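A sketch of both adapters just described: a lambda that turns the two-argument `show_progress` into a one-argument callback, and `make_show_progress`, which returns a new closure each time it's called. Having `show_progress` return its message as well as print it is an addition for this example, only so the result is easy to check.

```python
def show_progress(exclamation, epoch):
    msg = f"{exclamation}! We've finished epoch {epoch}!"
    print(msg)
    return msg   # returned too, only to make the result easy to check

# Adapting a two-argument function into a one-argument callback with a lambda...
cb = lambda o: show_progress("OK I guess", o)

# ...or with a closure: _inner remembers `exclamation` from the enclosing call.
def make_show_progress(exclamation):
    def _inner(epoch):
        return show_progress(exclamation, epoch)
    return _inner

f2 = make_show_progress("Terrific")
f3 = make_show_progress("Amazing")   # a brand-new _inner function each time
f2(1)   # prints: Terrific! We've finished epoch 1!
```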
It's a concept you'll come across a lot, particularly if you're a JavaScript programmer. So we could say f2 = make_show_progress("Terrific"), and f2 now contains that closure: it actually remembers what exclamation you passed it. Because it's so common to want to take a function that takes two parameters and turn it into a function that takes one, Python and most languages have a way to do that, which is called partial function application.
The standard library module functools has this thing called partial. If you call partial and pass it a function, followed by some arguments for that function, it returns a new function in which those arguments are always already given. So let's check it out. We could run it like that, or we could say f2 equals this partial function application.
And if I hit shift-tab on f2, you can see this is now a function that just takes epoch. It just takes epoch because show_progress took two parameters and we've already passed it one. So this now takes one parameter, which is what we need, and that's why we can pass it as our callback.
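The same adaptation with `functools.partial`, as a minimal sketch. As before, `show_progress` returning its message is an addition for this example so the result can be checked.

```python
from functools import partial

def show_progress(exclamation, epoch):
    msg = f"{exclamation}! We've finished epoch {epoch}!"
    print(msg)
    return msg   # returned as well, only so the output is easy to check

# partial pre-fills `exclamation`, leaving a function that takes just `epoch`.
f2 = partial(show_progress, "OK I guess")
f2(3)   # prints: OK I guess! We've finished epoch 3!
```

The partial object also records what it wraps: `f2.func` is the original function and `f2.args` holds the pre-filled positional arguments.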
So we saw a lot of those techniques last week. Most of what we saw last week, though, did not use a function as a callback, but a class. We could do exactly the same thing: pretty much any place you can use a closure, you can also use a class.
Instead of storing some state away inside a closure, we can store our state, in this case the exclamation, inside self, passing it into __init__. So here's exactly the same thing as we saw before, but as a class. __call__ (dunder call) is a special magic name which will be called if you take an object, in this case a ProgressShowingCallback object, and call it with parentheses.
So if I write cb("hi"), you see I'm taking that object and treating it as if it's a function, and that will call __call__. If you've used other languages, like C++, this is called a functor. More generally, in Python it's called a callable, so it's something a lot of languages have.
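A sketch of the class-based version: the exclamation lives on `self` instead of in a closure, and `__call__` makes the instance callable like a function. Returning the message is, again, an addition for checking.

```python
class ProgressShowingCallback:
    def __init__(self, exclamation="Awesome"):
        # State lives on self, rather than being captured by a closure.
        self.exclamation = exclamation

    def __call__(self, epoch):
        # Runs whenever the instance itself is called with parentheses: cb(epoch).
        msg = f"{self.exclamation}! We've finished epoch {epoch}!"
        print(msg)
        return msg

cb = ProgressShowingCallback("Just super")
cb("hi")   # prints: Just super! We've finished epoch hi!
```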
All right, so now we can use that as a callback, just like before. The next thing to look at for our callbacks is *args and **kwargs ("star args" and "star star kwargs"). For those of you who don't know what these mean, let's create a function that takes *args and **kwargs and prints out args and kwargs.
So if I call that function, passing it 3, 'a', thing1="hello", you'll see that all the things passed as positional arguments end up in a tuple called args, and all the things passed as keyword arguments end up in a dictionary called kwargs. That's literally all these things do.
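A minimal sketch of that demo; the function also returns what it collected so the result can be checked.

```python
def f(*args, **kwargs):
    # Positional arguments are collected into the tuple `args`,
    # keyword arguments into the dict `kwargs`.
    print(f"args: {args}; kwargs: {kwargs}")
    return args, kwargs

f(3, 'a', thing1="hello")   # prints: args: (3, 'a'); kwargs: {'thing1': 'hello'}
```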
PyTorch uses this, for example, when you create an nn.Sequential: it takes what you pass in as *args, so you just pass the layers directly and they're collected into a tuple. So why do we use this? There are a few reasons, but one of the common ones is when you want to wrap some other class or object: you can take a bunch of stuff as **kwargs and pass it off to some other function or object.
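A tiny sketch of the wrapping pattern, using the standard library's `sorted` as a stand-in for "some other function or object" (in fastai's case that might be something like a PyTorch DataLoader). `sorted_desc` is a hypothetical name made up for this example.

```python
def sorted_desc(items, **kwargs):
    # We fix `reverse` ourselves and forward every other keyword argument to sorted().
    # Downside: `key` never appears in sorted_desc's own signature, so tooltips won't show it.
    return sorted(items, reverse=True, **kwargs)

sorted_desc([3, 1, 2])                    # [3, 2, 1]
sorted_desc(["bb", "a", "ccc"], key=len)  # ['ccc', 'bb', 'a']
```

The comment about tooltips is the same trade-off discussed next: forwarding is convenient, but it hides the forwarded options from the wrapper's signature.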
We're getting better at this, and we're removing a lot of these usages, but in the early days of fastai version 1 we were actually overusing kwargs. Quite often there'd be a lot of stuff that wasn't obviously in the parameter list of a function that ended up in kwargs, and then we would pass it down to, I don't know, the PyTorch DataLoader initializer or something.
So we've been gradually removing those usages, because this is mostly helpful for quick-and-dirty throwing things together. In R, they use an ellipsis for the same thing, and they kind of overuse it; quite often it's hard to see what's going on. You might have noticed in Matplotlib that a lot of the time, the thing you're trying to pass isn't there when you hit shift-tab.
It's the same thing: they're using kwargs. So there are some downsides to using it, but there are some places where you really want to use it. For example, take a look at this. Let's rewrite slow_calculation, but this time we're going to allow the user to create a callback that is called before the calculation occurs and after the calculation occurs.
And the after-calculation one is a bit tricky, because it's going to take two parameters now: both the epoch number and the value we've calculated so far. So we can't just call cb(i). We now actually have to assume that the callback has some particular methods.
So here, for example, is a PrintStepCallback, which before the calculation just says "About to start" and after the calculation says "Done step", and there it's running. In this case the callback doesn't actually care about the epoch number or the value, so it just has *args, **kwargs in both places.
It doesn't have to worry about exactly what's being passed in, because it's not using any of it. So this is quite a good use of these: creating a function that's going to be called from somewhere else when you don't care about one or more of the parameters, or when you want to make things more flexible.
In this case we don't get an error. But if we remove them, which it looks like we should be able to do since we don't use anything, here's the problem: it tried to call before_calc(i), and before_calc doesn't take an i. So if you accept both positional and keyword arguments, it'll always work everywhere.
And here we can actually use them: let's use epoch and value to print out those details. So now you can see it printing them out. In this case I've put **kwargs at the end, because maybe in the future there'll be some other things passed in, and we want to make sure this doesn't break.
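A sketch of the before/after version with the two callbacks just described. The `delay` parameter is a hypothetical addition for this example; everything else follows the description: `before_calc` gets the epoch, and `after_calc` gets the epoch plus the running value as a keyword.

```python
from time import sleep

def slow_calculation(cb, delay=1.0):
    res = 0
    for i in range(5):
        cb.before_calc(i)            # epoch passed positionally
        res += i*i
        sleep(delay)                 # hypothetical knob; the lesson sleeps 1 second
        cb.after_calc(i, val=res)    # epoch positionally, running value as a keyword
    return res

class PrintStepCallback:
    # Uses none of its arguments, so it accepts (and ignores) anything.
    # With plain `def before_calc(self):` the call above would raise a TypeError.
    def before_calc(self, *args, **kwargs): print("About to start")
    def after_calc(self, *args, **kwargs): print("Done step")

class PrintStatusCallback:
    # Uses epoch and val; the trailing **kwargs tolerates extra keywords added later.
    def before_calc(self, epoch, **kwargs): print(f"About to start: {epoch}")
    def after_calc(self, epoch, val, **kwargs): print(f"After {epoch}: {val}")

slow_calculation(PrintStatusCallback(), delay=0)
```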
So it makes it more resilient. The next thing we might want to do with callbacks is actually change something. We did a couple of things like that last week: one was that we wanted to be able to cancel out of a loop to stop early; the other thing we might want to do is actually change the value of something.
So in order to stop early, we could check. And another thing we might want to handle is: what if you don't want to define before_calc or after_calc? We wouldn't want everything to break. So we can check whether a callback method is defined and only call it if it is.
And we could actually check the return value and then do something based on the return value. So here's something which will cancel out of our loop if the value that's been calculated so far is over 10. So here we stop. Okay? What if you actually want to change the way the calculation is being done?
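Before getting to that, here's a sketch of the two ideas just described: only calling a hook if the callback defines it, and treating a truthy return value from `after_calc` as a request to stop. The `delay` parameter is, as before, a hypothetical addition for this example.

```python
from time import sleep

def slow_calculation(cb=None, delay=1.0):
    res = 0
    for i in range(5):
        # Only call a hook if the callback object actually defines it.
        if cb and hasattr(cb, 'before_calc'): cb.before_calc(i)
        res += i*i
        sleep(delay)   # hypothetical knob; the lesson sleeps 1 second
        if cb and hasattr(cb, 'after_calc'):
            # A truthy return value is treated as "stop now".
            if cb.after_calc(i, res):
                print("stopping early")
                break
    return res

class PrintAfterCallback:
    # No before_calc at all; the hasattr check above simply skips it.
    def after_calc(self, epoch, val):
        print(f"After {epoch}: {val}")
        if val > 10: return True

slow_calculation(PrintAfterCallback(), delay=0)   # stops after epoch 3, returning 14
```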
We could even change the way the calculation is being done by putting our calculation function into a class, so that the value it has calculated becomes an attribute of the class. Now we can write a callback that reaches back inside the calculator and changes it, right?
So this one is going to double the result if it's less than three. If we run this, we now have to create the calculator and call its calc method, because it's a class, and you can see it gives a different value. We're also taking advantage of this in the callbacks that we're using.
So this is kind of the ultimately flexible callback system. You'll see that in this case we have to pass the calculator object to the callback. The way we do that is with a callback method here, which checks whether the callback is defined and, if it is, grabs it and calls it, passing in the calculator object itself so that it's available.
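A sketch of the class-based calculator and the modifying callback as described: the `callback` method looks the hook up by name and passes the calculator itself, so the hook can change `res`. The `delay` attribute and having `calc` return `res` are additions for this example.

```python
from time import sleep

class SlowCalculator:
    def __init__(self, cb=None, delay=1.0):
        self.cb, self.res, self.delay = cb, 0, delay   # delay: hypothetical knob

    def callback(self, cb_name, *args):
        # Look the hook up by name; skip it if there's no callback or no such method.
        # (Naming this method __call__ instead would allow self('after_calc', i).)
        if not self.cb: return None
        cb = getattr(self.cb, cb_name, None)
        if cb: return cb(self, *args)   # pass the calculator itself to the hook

    def calc(self):
        for i in range(5):
            self.callback('before_calc', i)
            self.res += i*i
            sleep(self.delay)
            if self.callback('after_calc', i):
                print("stopping early")
                break
        return self.res   # returned only for convenience

class ModifyingCallback:
    def after_calc(self, calc, epoch):
        print(f"After {epoch}: {calc.res}")
        if calc.res > 10: return True
        if calc.res < 3: calc.res = calc.res * 2   # reach back in and change the state

SlowCalculator(ModifyingCallback(), delay=0).calc()   # 15 instead of 30
```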
What we actually did last week is that we didn't name this method callback; we named it __call__ (dunder call), which means we were able to do it like this, okay? Now, which do you prefer? It's kind of up to you. We had so many callbacks being called that I felt the extra noise of giving it a name was a bit messy.
On the other hand, you might feel that calling a callback isn't something you'd expect __call__ to do, in which case you can do it the other way. So there are pros and cons; neither is right or wrong. Okay, so that's callbacks. We've been using dunder thingys a lot. Dunder thingys look like this: a name with a double underscore on each side, like __init__.
In Python, a dunder thingy is special somehow. Most languages let you define special behaviors. For example, in C++ there's an operator keyword: if you define a function called operator followed by something like plus, you're defining the plus operator. So most languages tend to have special magic names you can give things that make something a constructor or a destructor or an operator.
I like that in Python all of the magic names actually look magic. They all look like that, which I think is a really good way to do it. The Python docs have a data model reference which tells you about all these special method names, and you can go through and see all the special things you can get your methods to do.
For example, you can override how an object behaves with less than or equal to, et cetera. There's a particular list I suggest you get to know, and this is the list. Go to those docs and see what these do, because we use all of them in this course.
So here's an example: SloppyAdder. You pass in some number that you're going to add up, and when you add two of them together, it gives you the result of adding them up, but wrong by 0.01. That method is called __add__ (dunder add), because it's what gets called when you use plus.
This one is called __init__ (dunder init), because it's what gets called when an object gets constructed. And this is __repr__ (dunder repr), because it's what gets called when you print it out. So now I can create a one-adder and a two-adder, add them together, and see the result.
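A minimal sketch of `SloppyAdder` with the three dunder methods just described.

```python
class SloppyAdder:
    def __init__(self, o):
        # Called when the object is constructed: SloppyAdder(1)
        self.o = o

    def __add__(self, b):
        # Called for a + b; deliberately wrong by 0.01
        return SloppyAdder(self.o + b.o + 0.01)

    def __repr__(self):
        # Called when the object is displayed or printed
        return str(self.o)

a = SloppyAdder(1)
b = SloppyAdder(2)
print(a + b)   # 1 + 2 + 0.01, i.e. 3.01
```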
So that's kind of an example of how these special dunder methods work. So that's a bit of that Python stuff. There's another bit of code stuff that I wanted to show you, which you'll need to be doing a lot of, which is you need to be really good at browsing source code.
If you're going to be contributing stuff to fastai, or to fastai for Swift for TensorFlow, or just building your own more complex projects, you need to be able to jump around source code. Or even just to find out how PyTorch does something: if you're doing some research, you need to really understand what's going on under the hood.
Here's a list of things you should know how to do in your editor of choice. Any editor that can't do all of these things is worth replacing with one that can. Most editors can do them: Emacs can, Visual Studio Code can, Sublime can, and so can the editor I use most of the time, Vim.
I'll show you what these things are in Vim. On the forums there are already some topics saying how to do these things in other editors. If you don't find one that seems any good, and you've got some tips about how to do these or other useful things in your editor of choice, feel free to create your own topic.
I'm going to show you in Vim for no particular reason other than that it's what I use. One of the things I like about Vim is that I can use it in a terminal, which I find super helpful because I'm working on remote machines all the time, and I like to be at least as productive in a terminal as I am on my local computer.
The first thing you should be able to do is jump to a symbol. A symbol would be something like a class or a function. For example, I might want to jump straight to the definition of create_cnn, but I can't quite remember the name of the function.
I would type :tag create, since I'm pretty sure it's create_ something, and then I'd press tab a few times and it would cycle through. There it is, create_cnn, and then I'd hit enter. That's the first thing your editor should do: make it easy to jump to a tag even if you can't remember exactly what it's called.
The second thing is that you should be able to put your cursor on something like cnn_learner, hit a key, which in Vim's case is Ctrl-], and it should take you to the definition of that thing. Okay, here's cnn_learner; what's this thing called a DataBunch? Ctrl-], okay, there's DataBunch.
You'll also see that my Vim is folding classes and functions to make it easier for me to see exactly what's in this file. In some editors this is called outlining, in some it's called folding; most editors should do it. Then there should be a way to go back to where you were before; in Vim that's Ctrl-T, for going back up the tag stack.
So here's my cnn_learner, here's my create_cnn, and you can see that this makes it nice and easy to jump around. Something I find super helpful is also being able to jump into the source code of libraries I'm using. For example, here's kaiming_normal_: I've got my Vim configured so that if I hit Ctrl-] on it, it takes me to the definition of kaiming_normal_ in the PyTorch source code.
I find docstrings kind of annoying, so I have mine folded up by default, but I can always open them up. If you use Vim, the way to do this is to add additional tags files for any packages you want to be able to jump to. I'm sure most editors can do something pretty similar.
Now that I've seen how kaiming_normal_ works, I can use the same Ctrl-T to jump back to where I was in my fastai source code. The only other thing that's particularly important to know how to do is more general searches. Let's say I wanted to find all the places I've used lambda, since we talked about lambda today. I use a particular tool called ack, so I can say ack lambda, and here's a list of all the places I've used lambda. I can click on one, and it will jump to the code where it's used.
Again, most editors should do something like that for you. With that basic set of stuff, you should be able to get around pretty well. If you're a professional software engineer, I know you know all this. If you're not, hopefully you're feeling pretty excited right now to discover that editors can do more than you realized.
Sometimes people will jump on our GitHub and say, I don't know how to find out where a function you're calling comes from, because you don't list all your imports at the top of the file. But this is exactly the kind of place where you should be using your editor to tell you.
In fact, one place where GUI editors can be pretty good is that often, if you just point at something, they'll pop up something telling you exactly where that symbol comes from. I don't have that set up in Vim, so I just hit Ctrl-] to see where something comes from.
Okay, so those are some tips about things you should be able to do when browsing source code; if you don't know how to do them yet, please Google, look at the forums, and practice. Something else we were looking at a lot last week, and that you need to know pretty well, is variance.
So here's a quick refresher on what variance is, or for those of you who haven't studied it before, here's what variance is. Roughly speaking, variance measures the average of how far away each data point is from the mean. So here's some data, t, and here's the mean of that data, m. The average distance of each data point from the mean would then be t, the data points, minus m, the mean, averaged.
Oh, that's zero. That didn't work. Well, of course it didn't work: the mean is defined as the thing which is in the middle, right? So that's always zero. We need to do something else, so that the positives and negatives don't cancel out, and there are two main ways to fix it.
One is squaring each difference before we take the mean, like so. The other is taking the absolute value of each difference, turning all the negatives into positives before we take the mean. They're both common fixes for this problem. You can see, though, that the first is now on a totally different scale, right?
The numbers were 1, 2, 4, and 18, and this is 47. So we need to undo that squaring: after we've squared, we take the square root at the end. So here are two numbers that represent how far things are from the mean, or in other words, how much they vary.
If everything's pretty similar to each other, those two numbers will be small. If they're wildly different from each other, those two numbers will be big. This one here is called the standard deviation, and it's defined as the square root of this one here, which is called the variance.
And this one here is called the mean absolute deviation. You could replace this m here with various other things, like the median, for example. Now, we have one outlier here, 18. In the version where we took a square in the middle, the number is higher, because squaring takes that 18's big distance from the mean and makes it much bigger.
So in other words, standard deviation is more sensitive to outliers than mean absolute deviation. So for that reason, the mean absolute deviation is very often the thing you want to be using because in machine learning outliers are more of a problem than to help a lot of the time.
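Here's a quick sketch in plain Python of the three calculations just described, assuming the same four data points (1, 2, 4 and 18). It uses the "divide by n" form of the variance, which is what averaging gives you; statistics libraries often default to dividing by n − 1 instead.

```python
import math

t = [1.0, 2.0, 4.0, 18.0]   # assumed data: the four values from the example
n = len(t)
m = sum(t) / n               # the mean

# Naive "average distance": positives and negatives cancel, so this is 0
mean_diff = sum(x - m for x in t) / n

# Fix 1: square, average (variance), then undo the squaring (standard deviation)
var = sum((x - m) ** 2 for x in t) / n
std = math.sqrt(var)

# Fix 2: absolute value, then average (mean absolute deviation)
mad = sum(abs(x - m) for x in t) / n

print(f"mean_diff={mean_diff}, var={var:.4f}, std={std:.4f}, mad={mad:.4f}")
```

To see the outlier sensitivity: the outlier 18 sits 11.75 away from the mean of 6.25, so it contributes 138.0625 of the 188.75 total squared distance (about 73%), but only 11.75 of the 23.5 total absolute distance (50%).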
But mathematicians and statisticians tend to work with standard deviation rather than mean absolute deviation because it makes their math proofs easier and that's the only reason. They'll tell you otherwise, but that's the only reason. So the mean absolute deviation is really underused and it actually is a really great measure to use and you should definitely get used to it.
There are a lot of places where I've noticed that if you replace things involving squares with things involving absolute values, the absolute value versions often just work better. It's a good tip to remember: there's this kind of long-held assumption that we have to use a squared thing everywhere, but it actually often doesn't work as well.
This is our definition of variance. Here's another way of writing it: the mean of the squares minus the square of the mean, E[X²] − (E[X])². Notice that it gives the same number. It's important because it's super handy, and it's super handy because with the first definition we have to go through the whole dataset once
to calculate the mean of the data, and then a second time to get the squares of the differences. This version is really nice because we only have to keep track of two running totals, the sum of the squares of the data and the sum of the data, and as you'll see shortly, this way of doing things is generally just easier to work with.
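A minimal plain-Python sketch comparing the two ways of computing the variance: the two-pass version straight from the definition, and a one-pass version that only keeps running totals (a count, a sum, and a sum of squares). The function names are made up for illustration.

```python
def variance_two_pass(xs):
    m = sum(xs) / len(xs)                            # pass 1: the mean
    return sum((x - m) ** 2 for x in xs) / len(xs)   # pass 2: mean squared difference

def variance_one_pass(xs):
    n, total, total_sq = 0, 0.0, 0.0
    for x in xs:                 # a single pass, keeping only running totals
        n += 1
        total += x
        total_sq += x * x
    mean = total / n
    return total_sq / n - mean ** 2   # E[X^2] - (E[X])^2

t = [1.0, 2.0, 4.0, 18.0]
print(variance_two_pass(t), variance_one_pass(t))
```

One caveat worth knowing: in floating point, the one-pass form can lose precision when the mean is large relative to the spread, because it subtracts two nearly equal numbers. Welford's algorithm is the usual fix when that matters.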
So even though this is kind of the definition of the variance that makes intuitive sense, this is the definition of variance that you normally want to implement. And so there it is in math. The other thing we see quite a bit is covariance and correlation. So if we take our same data set, let's now create a second data set which is double t times a little bit of random noise, so here's that plotted.
Let's now look at the difference between each item of t and its mean and multiply it by the difference between each item of u and its mean, so there are those values, and let's look at the mean of that. So what's this number? It's the average, over all the points, of how far the x value is from the mean of the x's, multiplied by how far the corresponding y value is from the mean of the y's.
Let's compare this number to the same number calculated with a different dataset, v, which is just some random numbers. Let's calculate the exact same product and the exact same mean. This number's much smaller than the first one. Why is it so much smaller? Well, if you think about it, if the points are all lined up nicely, then every time a point is higher than the average on the x-axis, it's also higher than the average on the y-axis.
So you have two big positive numbers or, vice versa, two big negative numbers. Either way, when you multiply them together, you end up with a big positive number. So this is adding up a whole bunch of big positive numbers. In other words, this number tells you how much these two things vary in the same way, kind of how lined up they are on this graph.
Whereas with the random dataset, when one is big, the other's not necessarily big, and when one is small, the other's not necessarily small. So this is the covariance, and you can also calculate it as E[XY] − E[X]E[Y], which might look somewhat similar to what we saw before with our alternative variance calculation.
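Here's a small sketch of the covariance calculated both ways, in plain Python. To keep it repeatable, the "random" noise and the unrelated dataset v are hard-coded made-up values rather than drawn from a random number generator.

```python
def mean(xs):
    return sum(xs) / len(xs)

def cov_direct(xs, ys):
    mx, my = mean(xs), mean(ys)
    # average of (x - mean_x) * (y - mean_y)
    return mean([(x - mx) * (y - my) for x, y in zip(xs, ys)])

def cov_shortcut(xs, ys):
    # E[XY] - E[X]E[Y]
    return mean([x * y for x, y in zip(xs, ys)]) - mean(xs) * mean(ys)

t = [1.0, 2.0, 4.0, 18.0]
noise = [1.02, 0.91, 0.97, 1.05]            # hard-coded "small random noise" (assumption)
u = [2 * x * e for x, e in zip(t, noise)]   # roughly double t: lined up with t
v = [0.3, -1.2, 0.8, -0.1]                  # unrelated values (assumption)

print(cov_direct(t, u), cov_shortcut(t, u))  # large
print(cov_direct(t, v), cov_shortcut(t, v))  # close to zero
```

Note that the covariance of t with itself is just its variance, which is the "replace the second X with a Y" relationship discussed shortly.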
And again, this is kind of the easier way to use it. So as I say here, from now on, I don't want you to ever look at an equation or type in an equation in LaTeX without typing it in Python, calculating some values and plotting them. Because this is the only way we get a sense in here of what these things mean.
And so in this case, we're going to take our covariance and divide it by the product of the standard deviations. This gives us a new number, which is called the correlation, or more specifically the Pearson correlation coefficient. We don't cover covariance and the Pearson correlation coefficient much in this course, but it's often a nice way to see how two things vary together.
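In the spirit of typing every equation into Python, here's a plain-Python sketch of the Pearson correlation coefficient: the covariance divided by the product of the standard deviations. The data values are made up for illustration.

```python
import math

def mean(xs):
    return sum(xs) / len(xs)

def cov(xs, ys):
    mx, my = mean(xs), mean(ys)
    return mean([(x - mx) * (y - my) for x, y in zip(xs, ys)])

def std(xs):
    return math.sqrt(cov(xs, xs))   # the variance is the covariance of xs with itself

def pearson(xs, ys):
    # covariance divided by the product of the standard deviations
    return cov(xs, ys) / (std(xs) * std(ys))

t = [1.0, 2.0, 4.0, 18.0]
u = [2 * x * e for x, e in zip(t, [1.02, 0.91, 0.97, 1.05])]  # assumed: t doubled, with noise
v = [0.3, -1.2, 0.8, -0.1]                                    # assumed: unrelated values

print(pearson(t, u), pearson(t, v))
```

Dividing by the standard deviations is what makes the result scale-free: multiplying u by 10 multiplies the covariance by 10, but also its standard deviation, so the correlation stays the same and always lies between −1 and 1.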
But remember, it's really telling you about how things vary linearly, right? So if you want to know how things vary non-linearly, you have to create something called a neural network and check the loss and the metrics. But it's kind of interesting to see how variance and covariance are much the same thing. In fact, you can think of it this way, right?
The variance is E of X squared minus the square of E of X. E of X squared is really E of X times X, the same thing twice, whereas the covariance uses two different things: E of X times Y, minus E of X times E of Y.
So if you replace the second X with a Y in each term of the variance formula, you get the covariance formula. So they're literally the same thing. And going the other way, if Y is the same as X, the covariance is just sigma squared, the variance. So the last thing I want to quickly talk about a little bit more is Softmax.
This was our final log Softmax definition from the other day, and this is the formula, the same thing as an equation. And this is our cross entropy loss, remember. So these are all important concepts we're going to be using a lot. So I just wanted to kind of clarify something that a lot of researchers that are published in big name conferences get wrong, which is when should you and shouldn't you use Softmax.
So this is our Softmax page from our entropy example spreadsheet, where we were looking at cat, dog, plane, fish, building. And so we had various outputs. This is just the activations that we might have gotten out of the last layer of our model. And this is just E to the power of each of those activations.
And this is just the sum of all of those E to the power ofs. And then this is E to the power of divided by the sum, which is Softmax. And of course, they all add up to one. This is like some image number one that gave these activations.
Here's some other image number two, which gave these activations, which are very different to these. But the Softmaxes are identical. That and that are identical. So that's weird. How has that happened? Well, it's happened because in every case, the E to the power of this divided by the sum of the E to the power ofs ended up in the same ratio.
So in other words, even though fish is only 0.63 here but about 2 here, once you take E to the power of, it's the same percentage of the sum, right? And so we end up with the same Softmax. Why does that matter? Well, in this model, it seems like being a fish is associated with having an activation of maybe like 2ish, right?
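Here's a plain-Python sketch of why this happens: adding the same constant to every activation multiplies every E to the power of by the same factor, which cancels in the division. The activations below are illustrative values chosen to match the description (image 2 is image 1 with every activation 1.44 lower), not necessarily the spreadsheet's exact numbers.

```python
import math

def softmax(xs):
    m = max(xs)                            # subtracting the max is itself the same shift trick
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

classes = ["cat", "dog", "plane", "fish", "building"]
image1 = [0.02, -2.49, -1.75, 2.07, 1.25]  # illustrative activations
image2 = [a - 1.44 for a in image1]        # every activation lower by the same amount

p1, p2 = softmax(image1), softmax(image2)
for c, a, b in zip(classes, p1, p2):
    print(f"{c:>8}: {a:.3f} {b:.3f}")
```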
And this is only like 0.6ish. So maybe there's no fish in this. But what may actually have happened is that there are no cats or dogs or planes or fishes or buildings at all. In the end, because Softmax has to add up to 1, it has to pick something, so it's fish that comes through.
And what's more, because we do this E to the power of, the thing that's a little bit higher gets pushed much higher, because it's exponential, right? So Softmax likes to pick one thing and make it big, and the outputs have to add up to 1. So the problem here is that I would guess that maybe image 2 doesn't have any of these things in it,
and we had to pick something, so it said, oh, I'm pretty sure there's a fish. Or maybe the problem is actually that this image had a cat and a fish and a building in it, but again, because the Softmax outputs have to add to 1, one of them is going to be much bigger than the others.
So I don't know exactly which of these happened, but it's definitely not true that they both have an equal probability of having a fish in them. So to put this another way, Softmax is a terrible idea unless you know that every one of your, if you're doing image recognition, every one of your images, or if you're doing audio or tabular or whatever, every one of your items has one, no more than one, and definitely at least one example of the thing you care about in it.
Because if an item doesn't have any of cat, dog, plane, fish, or building, it's still going to tell you with high probability that it has one of those things. And if it has more than one of cat, dog, plane, fish, or building, it'll pick one of them and tell you it's pretty sure it's got that one.
So what do you do if there could be none of these things, or more than one of them? Well, instead you use the binomial version, good old sigmoid, which is e to the x divided by one plus e to the x. It's exactly the same as Softmax if your two categories are "has the thing" and "doesn't have the thing", because they're like p and one minus p.
So you can convince yourself of that during the week. So in this case, let's take image one and let's go 1.02 divided by one plus 1.02. And ditto for each of our different ones. And then let's do the same thing for image two. And you can see now the numbers are different, as we would hope.
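If you'd like a head start on convincing yourself, here's a plain-Python check that a two-category softmax over "has the thing" (activation x) and "doesn't have the thing" (activation fixed at 0, an assumption for illustration) gives exactly the sigmoid of x. More generally, a two-way softmax over activations a and b equals sigmoid(a − b).

```python
import math

def sigmoid(x):
    # e^x / (1 + e^x), written in a form that avoids overflow for large |x|
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

for x in [-3.0, -0.5, 0.0, 0.02, 2.07]:
    p_has, p_not = softmax([x, 0.0])
    print(f"x={x:5.2f}  softmax={p_has:.4f}  sigmoid={sigmoid(x):.4f}  p_not={p_not:.4f}")
```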
And so for image one, it's kind of saying, oh, it looks like there might be a cat in it if we assume 0.5 is a cut off, there's probably a fish in it, and it seems likely that there's a building in it, right? Whereas for image two, it's saying, I don't think there's anything in there, but maybe a fish.
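Here's a sketch of that per-class calculation in plain Python, using illustrative activations consistent with the ones described (again, image 2 is image 1 shifted down by 1.44, so their softmaxes would be identical). Each class gets its own independent yes/no probability, with 0.5 as the assumed cutoff.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))   # same as e^x / (1 + e^x)

classes = ["cat", "dog", "plane", "fish", "building"]
image1 = [0.02, -2.49, -1.75, 2.07, 1.25]   # illustrative activations
image2 = [a - 1.44 for a in image1]         # same pattern, all 1.44 lower

def present(acts, threshold=0.5):
    # Each class is judged on its own: no requirement that the outputs sum to 1
    return [c for c, a in zip(classes, acts) if sigmoid(a) > threshold]

print([round(sigmoid(a), 3) for a in image1], present(image1))
print([round(sigmoid(a), 3) for a in image2], present(image2))
```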
And this is what we want, right? And so when you think about it, for image recognition, probably most of the time you don't want Softmax. So why do we always use Softmax? Because we all grew up with ImageNet. And ImageNet was specifically curated so that every image has exactly one of the ImageNet classes in it.
An alternative, if you want to be able to handle the what if none of these classes are in it case, is you could create another category called background or doesn't exist or null or missing. So let's say you created this missing category. So there's six, cat, dog, plane, fish, building or missing, nothing.
A lot of researchers have tried that, but it's actually a terrible idea, and it doesn't work. The reason it doesn't work is that to successfully predict missing, the penultimate layer activations have to contain features describing what a not cat, dog, plane, fish or building looks like.
So how do you describe a not cat, dog, plane, fish or building? What are the things that would activate high? Is it shininess? Is it fur? Is it sunshine? Is it edges? No. It's none of those things. There is no set of features that when they're all high is clearly a not cat, dog, plane, fish or building.
So that's just not a kind of object. A neural net can kind of try to hack its way around this by building a negative model of every other single type, creating a kind of "not any of those other things" detector. But that's very hard for it.
Whereas simply creating a binomial "does it or doesn't it have this?" for every one of the classes is really easy for it, right? Does it have a cat? Yes or no. Does it have a dog? Yes or no. And so forth. Lots and lots of well-regarded academic papers make this mistake.
So look out for it. If you do come across an academic paper that's using softmax, ask yourself whether that actually works with softmax. If you think maybe the answer is no, try replicating it without softmax, and you may just find you get a better result. An example of where softmax, or something like softmax, is obviously a good idea is language modeling.
What's the next word? It's definitely at least one word. It's definitely not more than one word, right? So you want softmax. So I'm not saying softmax is always a dumb idea, but it's often a dumb idea. So that's something to look out for. Okay. Next thing I want to do is I want to build a learning rate finder.
And to build a learning rate finder, we need to use this test callback idea, this ability to stop somewhere. The problem is, as you may have noticed, this I-want-to-stop-somewhere callback wasn't working in our new refactoring where we created the Runner class. And the reason it wasn't working is that we were returning True to mean cancel.
But even after we do that, it still goes on to do the next batch. And even if we set self.stop, it'll go on to the next epoch. So to actually stop it, you would have to return false from every single callback that's checked, to make sure it really stops, or you would have to add something that checks for self.stop in lots of places.
But it would be a real pain, right? And it's also not as flexible as we would like. So what I want to show you today is something which I think is really interesting, which is using exceptions as a kind of control flow statement. You may think of exceptions as just being a way of handling errors.
But actually exceptions are a very versatile way of writing very neat code that will be very helpful for your users. Let me show you what I mean. So let's start by just grabbing our MNIST data set as before and creating our data bunch as before. And here's our callback as before and our train eval callback as before.
But there are a couple of things I'm going to do differently. The first is a bit unrelated, but I think it's a useful refactoring: previously, inside Runner's dunder call, we went through each callback in order, and we checked to see whether that particular method exists in that callback.
And if it did, we called it and checked whether it returned true or false. It actually makes more sense for this to be inside the Callback class, because by putting it there, the Callback class now has a dunder call which takes a callback name, and it can do this stuff.
And what it means is that now your users who want to create their own callbacks, let's say they wanted to create a callback that printed out the callback name for every callback every time it was run. Or let's say they wanted to add a break point, like a set trace that happened every time the callback was run.
They could now create their own class that inherits from Callback and actually replace dunder call itself with something that adds the behavior they want. Or they could add something that looks at three or four different callback names and attaches to all of them. So this is a nice little extra piece of flexibility.
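Here's a minimal sketch of that refactoring: the Callback base class owns the "look up the method by name and call it" logic in its dunder call, the Runner just loops over callbacks, and a user can override dunder call to see every event. This is a simplified illustration, not the course notebook's exact code; `LogEverything` and `SkipBatch` are hypothetical example callbacks.

```python
class Callback:
    _order = 0

    def __call__(self, cb_name):
        # Look up e.g. `begin_batch` on this callback; call it if it exists.
        # A truthy return value is passed back to the runner.
        f = getattr(self, cb_name, None)
        if f and f():
            return True
        return False


class Runner:
    """Toy runner: only the dispatch logic, nothing else."""
    def __init__(self, cbs):
        self.cbs = cbs

    def __call__(self, cb_name):
        res = False
        for cb in sorted(self.cbs, key=lambda c: c._order):
            res = cb(cb_name) or res   # every callback is called; any True is kept
        return res


class LogEverything(Callback):
    """User-defined callback that overrides __call__ to see every event by name."""
    _order = 1

    def __init__(self):
        self.seen = []

    def __call__(self, cb_name):
        self.seen.append(cb_name)
        return super().__call__(cb_name)


class SkipBatch(Callback):
    def begin_batch(self):
        return True


logger = LogEverything()
run = Runner([logger, SkipBatch()])
print(run("begin_fit"), run("begin_batch"), logger.seen)
```

The same override point is where you could drop in a `set_trace()` to break into the debugger on every callback.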
It's not the key thing I wanted to show you, but it's an example of a nice little refactoring. The key thing I wanted to show you is that I've created three new types of exception. So an exception in Python is just a class that inherits from exception. And most of the time you don't have to give it any other behavior.
So to create a class that's just like its parent, but it just has a new name and no more behavior, you just say pass. So pass means this has all the same attributes and everything as the parent. But it's got a different name. So why do we do that?
Well, you might get a sense from the names: CancelTrainException, CancelEpochException, CancelBatchException. The idea is that we're going to let people's callbacks cancel at any one of these levels. So if they cancel a batch, it will keep going with the next batch, but not finish this one.
If they cancel an epoch, it will skip the rest of this epoch and keep going with the next one. Cancel train will stop the training altogether. So how would CancelTrainException work? Well, here's the same runner we had before. In fit, we already had a try/finally to make sure that our after fit and remove learner happened, even if there's an exception, and I've added one line of code:
except CancelTrainException. If that happens, then it optionally calls an after cancel train callback, but most importantly, no error occurs. It just keeps on going to the finally block and elegantly and happily finishes up. So we can cancel training. So now our test callback, after each step, just prints out what step we're up to.
And if it's greater than or equal to 10, we raise CancelTrainException. And so now when we say run.fit, it just prints out up to 10 and stops. There's no stack trace, there's no error. This is using exceptions as a control flow technique, not as an error handling technique.
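Here's a self-contained toy version of this exception-based control flow in plain Python, so you can step through it. "Batches" are just integers and a training "step" just logs them; the runner's structure (try/except/finally at the batch, epoch and fit levels) follows the description, but the class and callback names here are hypothetical.

```python
class CancelTrainException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass


class MiniRunner:
    """Toy runner: 'batches' are plain integers and a 'step' just logs them."""
    def __init__(self, cbs):
        self.cbs, self.log, self.events, self.n_iter = cbs, [], [], 0
        for cb in cbs:
            cb.run = self

    def __call__(self, cb_name):
        self.events.append(cb_name)
        for cb in self.cbs:
            f = getattr(cb, cb_name, None)
            if f:
                f()

    def one_batch(self, xb):
        try:
            self.xb = xb
            self("begin_batch")
            self.log.append(xb)          # stands in for forward, loss, backward, step
            self.n_iter += 1
            self("after_step")
        except CancelBatchException:
            self("after_cancel_batch")   # skip the rest of this batch only
        finally:
            self("after_batch")

    def all_batches(self, batches):
        try:
            for xb in batches:
                self.one_batch(xb)
        except CancelEpochException:
            self("after_cancel_epoch")   # skip the rest of this epoch only

    def fit(self, epochs, batches):
        try:
            self("begin_fit")
            for epoch in range(epochs):
                self.epoch = epoch
                self.all_batches(batches)
        except CancelTrainException:
            self("after_cancel_train")   # no error reaches the user
        finally:
            self("after_fit")            # clean-up always runs


class StopAfter10:
    def after_step(self):
        if self.run.n_iter >= 10:
            raise CancelTrainException()


class SkipOddBatches:
    def begin_batch(self):
        if self.run.xb % 2:
            raise CancelBatchException()


run = MiniRunner([StopAfter10()])
run.fit(5, range(4))      # would be 20 steps, but stops cleanly after 10
print(run.log)

run2 = MiniRunner([SkipOddBatches()])
run2.fit(1, range(6))
print(run2.log)
```

Notice that CancelTrainException raised deep inside `one_batch` passes straight through the batch- and epoch-level handlers (they only catch their own exception types) until `fit` catches it.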
So another example, inside all batches, I go through all my batches in a try block, except if there's a cancel epoch exception, in which case I optionally call an after cancel epoch callback and then continue to the next epoch. Or inside one batch, I try to do all this stuff for a batch, except if there's a cancel batch exception, I will optionally call the after cancel batch callback, and then continue to the next batch.
So this is a super neat way that we've allowed any callback writer to stop any one of these three levels of things happening. So in this case, we're using CancelTrainException to stop training. So we can now use that to create a learning rate finder. The basic approach of the learning rate finder is that there's something in begin batch which, just like our parameter scheduler, uses an exponential curve to set the learning rate.
So this is identical to parameter scheduler. And then after each step, it checks to see whether we've done more than the maximum number of iterations, which is defaulting to 100, or whether the loss is much worse than the best we've had so far. And if either of those happens, we will raise cancel train exception.
So to be clear, this neat exception-based approach to control flow isn't being used in fastai version 1 at the moment, but it's very likely that fastai 1.1 or 2 will switch to this approach because it's just so much more convenient and flexible. And then, assuming we haven't canceled, we just see if the loss is better than our best loss, and if it is, we set best loss to the loss.
So now we can create a learner, add the LR finder, and fit, and you can see that it does fewer than 100 iterations before it stops, because the loss got a lot worse. And so now we know that we want something about there for our learning rate.
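Here's a stand-alone plain-Python sketch of the learning rate finder logic: an exponential schedule from a minimum to a maximum learning rate, stopping via an exception after `max_iter` iterations or once the loss is more than 10 times the best loss so far. The defaults (100 iterations, 1e-6 to 10) are assumptions for illustration, and `toy_loss` is a made-up loss curve standing in for real training.

```python
import math

class CancelTrainException(Exception): pass


class LRFinderSketch:
    """Hypothetical stand-alone LR finder: exponential LR schedule + early stop."""
    def __init__(self, max_iter=100, min_lr=1e-6, max_lr=10.0):
        self.max_iter, self.min_lr, self.max_lr = max_iter, min_lr, max_lr
        self.best_loss = float("inf")
        self.n_iter = 0
        self.lrs, self.losses = [], []

    def begin_batch(self):
        pos = self.n_iter / self.max_iter
        # exponential interpolation from min_lr to max_lr
        self.lr = self.min_lr * (self.max_lr / self.min_lr) ** pos
        self.lrs.append(self.lr)

    def after_step(self, loss):
        self.n_iter += 1
        self.losses.append(loss)
        if self.n_iter >= self.max_iter or loss > self.best_loss * 10:
            raise CancelTrainException()
        if loss < self.best_loss:      # only reached if we didn't cancel
            self.best_loss = loss


def toy_loss(lr):
    # Made-up loss curve: lowest near lr = 1e-2, blows up for larger rates
    return (math.log10(lr) + 2) ** 2 + 0.1


finder = LRFinderSketch()
try:
    while True:
        finder.begin_batch()
        finder.after_step(toy_loss(finder.lr))
except CancelTrainException:
    pass

print(finder.n_iter, finder.lrs[-1], finder.best_loss)
```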
Okay, so now we have a learning rate finder. So let's go ahead and create a CNN, and specifically a CUDA CNN. So we'll keep doing the same stuff we've been doing: get our MNIST data, normalize it. Here's a nice little refactoring, because we very often want to normalize both the training set and the validation set using the training set's mean and standard deviation.
Let's create a function called normalize_to, which does that and returns the normalized training set and the normalized validation set. So we can now use that, make sure that it's behaved properly, and that looks good. Create our data bunch, and so now we're going to create a CNN model, and the CNN is just a sequential model that contains a bunch of stride-two convolutions.
And remember the input's 28 by 28, so after the first it'll be 14 by 14, then 7 by 7, then 4 by 4, then 2 by 2, then we'll do our average pooling, flatten it, and a linear layer, and then we're done. Now remember our original data is vectors of length 784, they're not 28 by 28, so we need to do an x.view to one channel by 28 by 28, because that's what nn.Conv2d expects, with minus one for the first dimension so the batch size remains whatever it was before.
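You can check those sizes with the standard convolution output-size formula, floor((n + 2·padding − kernel size) / stride) + 1. This plain-Python sketch assumes a 3 by 3 kernel with padding 1 and stride 2 at every layer.

```python
def conv_out_size(n, ks=3, stride=2, padding=None):
    if padding is None:
        padding = ks // 2          # assumed "same"-style padding: 1 for a 3x3 kernel
    return (n + 2 * padding - ks) // stride + 1

sizes = [28]
for _ in range(4):
    sizes.append(conv_out_size(sizes[-1]))
print(sizes)
```

Note the 7 by 7 to 4 by 4 step: with padding, an odd size rounds up rather than down. And 28 × 28 is where the 784 comes from.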
So we need to somehow include this function in our nn.Sequential. PyTorch doesn't support that by default. We could write our own class with a forward function, but nn.Sequential is convenient in lots of ways: it has a nice representation, and you can do all kinds of customizations with it. So instead we create a layer called Lambda, an nn.Module called Lambda: you just pass it a function, and the forward is simply to call that function.
And so now we can say Lambda(mnist_resize), and that will cause that function to be called. And here Lambda(flatten) causes flatten to be called, which removes those 1 by 1 axes at the end after the adaptive average pooling. So now we've got a CNN model, we can grab our callback functions, our optimizer and our runner and run it, and six seconds later we get back one epoch's result.
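Here's a minimal PyTorch sketch of the Lambda layer and a stride-2 CNN built around it. The filter counts (8, 16, 32, 32), the 5 by 5 first kernel and the 10 output classes are assumptions for illustration; the point is that a plain function can sit inside nn.Sequential once it's wrapped in a Module.

```python
import torch
from torch import nn

class Lambda(nn.Module):
    """Wraps any function so it can be used as a layer in nn.Sequential."""
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)

def mnist_resize(x):
    return x.view(-1, 1, 28, 28)      # -1 keeps the batch size; 1 channel, 28x28

def flatten(x):
    return x.view(x.shape[0], -1)     # (bs, c, 1, 1) -> (bs, c)

model = nn.Sequential(
    Lambda(mnist_resize),
    nn.Conv2d(1, 8, 5, padding=2, stride=2), nn.ReLU(),    # 14x14
    nn.Conv2d(8, 16, 3, padding=1, stride=2), nn.ReLU(),   # 7x7
    nn.Conv2d(16, 32, 3, padding=1, stride=2), nn.ReLU(),  # 4x4
    nn.Conv2d(32, 32, 3, padding=1, stride=2), nn.ReLU(),  # 2x2
    nn.AdaptiveAvgPool2d(1),                               # 1x1
    Lambda(flatten),
    nn.Linear(32, 10),
)

xb = torch.randn(64, 784)    # a fake batch of flattened MNIST-sized vectors
out = model(xb)
print(out.shape)
```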
So at this point that's getting a bit slow, so let's make it faster. So let's use CUDA, let's pop it on the GPU. We need to do two things. We need to put the model on the GPU, which specifically means the model's parameters. So remember a model contains two kinds of numbers: parameters, which are the things that you're updating, the things that it stores; and activations, which are the things that it's calculating.
So it's the parameters that we need to actually put on the GPU. And the inputs to the model and the loss function, so in other words the things that come out of the data loader we need to put those on the GPU. How do we do that? With a callback of course.
So here's a CUDA callback: when you initialize it, you pass it a device, and then when you begin fitting, you move the model to that device. So model.to is part of PyTorch: it moves something with parameters, or a tensor, to a device, and you can create a device by calling torch.device and passing it the string 'cuda' and whichever GPU number you want to use. If you only have one GPU, it's device 0.
Then when we begin a batch, let's go back and look at our runner. When we begin a batch, we've put xbatch and ybatch inside self.xb and self.yb. So that means we can change them. So let's set the runner's xb and the runner's yb to whatever they were before, but move to the device.
So that's it. That's going to run everything on CUDA. That's all we need. This is kind of flexible, because we can put things on any device we want. A simpler option is to just call torch.cuda.set_device once, and you don't even need to do that if you've only got one GPU.
And then .cuda() will send things to that device by default, so instead of saying .to(device), we can just say .cuda(). Since we're doing pretty much everything with just one GPU for this course, this is the one we're going to export. So just model.cuda(), xb.cuda(), yb.cuda().
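Here's a sketch of a device callback along the lines described, written so it also runs on a machine without a GPU (it falls back to the CPU). `FakeRunner` is a hypothetical stand-in for the course's Runner; all the callback assumes is that the runner has `model`, `xb` and `yb` attributes.

```python
import torch
from torch import nn

class CudaCallback:
    """Moves the model once at the start of fit, and every batch as it arrives."""
    def __init__(self, device):
        self.device = device

    def begin_fit(self):
        self.run.model.to(self.device)   # Module.to moves parameters in place

    def begin_batch(self):
        # Tensor.to returns a new tensor, so write the results back onto the runner
        self.run.xb = self.run.xb.to(self.device)
        self.run.yb = self.run.yb.to(self.device)


class FakeRunner:
    """Stand-in for the course's Runner: just holds model, xb and yb."""
    def __init__(self, model, xb, yb):
        self.model, self.xb, self.yb = model, xb, yb


# Use GPU 0 if there is one, otherwise fall back to the CPU so the sketch still runs
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
run = FakeRunner(nn.Linear(784, 10), torch.randn(64, 784), torch.randint(0, 10, (64,)))
cb = CudaCallback(device)
cb.run = run
cb.begin_fit()
cb.begin_batch()
out = run.model(run.xb)
```

The asymmetry in the comments is the reason the batch has to be reassigned: `model.to` changes the module itself, while `xb.to` gives you back a copy on the new device.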
So that's our CUDA callback. So let's add that to our callback functions, grab our model and our runner and fit, and now we can do three epochs in five seconds versus one epoch in six seconds. So that's a lot better. And for a much deeper model, it'll be dozens of times faster.
So this is literally all we need to use CUDA. So that was nice and easy. Now we want to make it easier to create different kinds of architectures. So the first thing we should do is recognize that we do a conv followed by a ReLU a lot. So let's pop that into a function called conv2d that just does a conv followed by a ReLU.
Since we use a kernel size of three and a stride of two in this MNIST model a lot, let's make those defaults. Also this model we can't reuse for anything except MNIST because it has a MNIST resize at the start. So we need to remove that. So if we're going to remove that, something else is going to have to do the resizing.
And of course the answer to that is a callback. So here's a callback which transforms the independent variable, the x, for a batch. And so you pass it some transformation function, which it stores away. And then begin batch simply replaces the batch with the result of that transformation function.
So now we can simply append another callback, which is the partial function application of that callback with this function. And this function is just to view something at one by 28 by 28. And you can see here we've used the trick we saw earlier of using underscore inner to define a function and then return it.
So this is something which creates a new view function that views it in this size. So for those of you that aren't that comfortable with closures and partial function application, this is a great piece of code to study, experiment, make sure that you feel comfortable with it. So using this approach, we now have the MNIST view resizing as a callback, which means we can remove it from the model.
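If closures and partial function application are new to you, here's a small PyTorch sketch to experiment with: a batch-transform callback, a `view_tfm` closure factory, and `functools.partial` used to make a callback factory. `FakeRunner` is a hypothetical stand-in for the course's Runner with just an `xb` attribute.

```python
from functools import partial
import torch

class BatchTransformXCallback:
    """Replaces the runner's xb with tfm(xb) at the start of every batch."""
    _order = 2

    def __init__(self, tfm):
        self.tfm = tfm

    def begin_batch(self):
        self.run.xb = self.tfm(self.run.xb)


def view_tfm(*size):
    # Closure: `size` is captured by _inner, which is returned as a new function
    def _inner(x):
        return x.view(*((-1,) + size))   # -1 keeps the batch dimension as it was
    return _inner


mnist_view = view_tfm(1, 28, 28)
cbfs = [partial(BatchTransformXCallback, mnist_view)]   # callback factory


class FakeRunner:
    xb = torch.zeros(64, 784)


cb = cbfs[0]()            # calling the partial creates a fresh callback
cb.run = FakeRunner()
cb.begin_batch()
print(cb.run.xb.shape)
```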
So now we can create a generic get_cnn_model function that returns a sequential model containing some arbitrary set of layers, with some arbitrary numbers of filters. So we're going to say, OK, these are the numbers of filters I have per layer: 8, 16, 32, 32. And so here is my get_cnn_layers.
And the last few layers are the average pooling, flattening, and the linear layer. The first few layers are, for each of those filter counts, a conv2d from that number of filters to the next. And then what's the kernel size? The kernel size depends: it's a kernel size of 5 for the first layer, or 3 otherwise.
Why is that? Well, the number of filters we had for the first layer was 8. And that's a pretty reasonable starting point for a small model to start with 8 filters. And remember, our image had a single channel. And imagine if we had a single channel, and we were using 3 by 3 filters.
So as that convolution kernel scrolls through the image, at each point in time, it's looking at a 3 by 3 window, and it's just one channel. So in total, there are 9 input activations that it's looking at. It puts those through 8 dot products, in other words a matrix multiplication by an 8 by 9 matrix.
And out of that will come a vector of length 8, because we said we wanted 8 filters. So that's what a convolution does. And this seems pretty pointless, because we started with 9 numbers and we ended with 8 numbers. So all we're really doing is just reordering them.
It's not really doing any useful computation. So there's no point making your first layer basically just shuffle the numbers into a different order. So what you'll find happens in, for example, most ImageNet models. Most ImageNet models are a little bit different, because they have three channels. So it's actually 3 by 3 by 3, which is 27.
But it's still kind of like-- quite often with ImageNet models, the first layer will be like 32 channels. So going from 27 to 32 is literally losing information. So most ImageNet models, they actually make the first layer 7 by 7, not 3 by 3. And so for a similar reason, we're going to make our first layer 5 by 5.
So we'll have 25 inputs for our 8 outputs. So this is the kind of things that you want to be thinking about when you're designing or reviewing an architecture, is like how many numbers are actually going into that little dot product that happens inside your CNN kernel. So that's something which can give us a CNN model.
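That check is simple enough to do in a couple of lines of Python: multiply the kernel size squared by the number of input channels, and compare it with the number of filters. The filter counts below are illustrative.

```python
def inputs_per_position(ks, in_channels):
    # How many numbers feed each filter's dot product at one position of the kernel
    return ks * ks * in_channels

# (description, kernel size, input channels, number of filters); filter counts are illustrative
cases = [
    ("1-channel input, 3x3 kernel", 3, 1, 8),
    ("1-channel input, 5x5 kernel", 5, 1, 8),
    ("3-channel input, 3x3 kernel", 3, 3, 32),
    ("3-channel input, 7x7 kernel", 7, 3, 32),
]
for desc, ks, c_in, n_filters in cases:
    n_in = inputs_per_position(ks, c_in)
    print(f"{desc}: {n_in} inputs -> {n_filters} outputs ({n_in / n_filters:.2f} inputs per output)")
```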
So let's pop it all together into a little function that just grabs an optimization function, grabs an optimizer, grabs a learner, grabs a runner. And at this point, if you can't remember what any of these things does, remember, we've built them all by hand from scratch. So go back and see what we wrote.
There's no magic here. And so let's look: if we say get_cnn_model, passing in 8, 16, 32, 32, here you can see 8, 16, 32, 32. Here's our 5 by 5, the rest are 3 by 3. They all have a stride of 2, and then a linear layer, and then train.
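Here's a PyTorch sketch of those helpers: `conv2d` as a conv followed by a ReLU with kernel size 3 and stride 2 as defaults, and `get_cnn_layers`/`get_cnn_model` using a 5 by 5 kernel for the first layer only. The signatures (passing the number of classes directly, and assuming one input channel) are assumptions for illustration, not necessarily the notebook's exact ones.

```python
import torch
from torch import nn

class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)

def flatten(x):
    return x.view(x.shape[0], -1)

def conv2d(ni, nf, ks=3, stride=2):
    # conv followed by ReLU; padding ks//2 keeps the size halving cleanly with stride 2
    return nn.Sequential(nn.Conv2d(ni, nf, ks, padding=ks // 2, stride=stride), nn.ReLU())

def get_cnn_layers(n_classes, nfs, in_channels=1):
    nfs = [in_channels] + nfs
    convs = [conv2d(nfs[i], nfs[i + 1], 5 if i == 0 else 3)   # 5x5 first, then 3x3
             for i in range(len(nfs) - 1)]
    return convs + [nn.AdaptiveAvgPool2d(1), Lambda(flatten), nn.Linear(nfs[-1], n_classes)]

def get_cnn_model(n_classes, nfs):
    return nn.Sequential(*get_cnn_layers(n_classes, nfs))

model = get_cnn_model(10, [8, 16, 32, 32])
out = model(torch.randn(2, 1, 28, 28))   # assumes inputs are already viewed as 1x28x28
print(model)
```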
So at this point, we've got a fairly general simple CNN creator that we can fit. And so let's try to find out what's going on inside. How do we make this number higher? How do we make it train more stably? How do we make it train more quickly? Well, we really want to see what's going on inside.
We know already that different ways of initializing changes the variance of different layers. How do we find out if it's saturating somewhere, if it's too small, if it's too big, what's going on? So what if we replace nn.Sequential with our own sequential model class? And if you remember back, we've already built our own sequential model class before, and it just had these two lines of code, plus return.
So let's keep the same two lines of code, but also add two more lines of code that grabs the mean of the outputs and the standard deviation of the outputs, and saves them away inside a bunch of lists. So here's a list for every layer, for means, and a list for every layer for standard deviations.
So let's calculate the mean and standard deviation and pop them inside those two lists. And so now it's a sequential model that also keeps track of what's going on, the telemetry of the model. So we can now create it in the same way as usual, fit it in the same way as usual, but now our model has two extra things in it.
It has act_means and act_stds. So let's plot the act means for every one of those lists that we had. And here it is, right? Here are all of the different means. And you can see it looks absolutely awful. What happens early in training is that in every layer, the means get exponentially bigger until they suddenly collapse, and then it happens again, and it suddenly collapses, and it happens again, and it suddenly collapses, until eventually it kind of starts training.
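Here's a sketch of such a telemetry-recording sequential model in PyTorch. It stores plain Python floats (via `.item()`) rather than tensors; that's a choice made here for the sketch. The tiny linear model is just for demonstration.

```python
import torch
from torch import nn

class SequentialModel(nn.Module):
    """Like nn.Sequential, but records the mean and std of every layer's output."""
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.act_means = [[] for _ in layers]   # one list per layer
        self.act_stds = [[] for _ in layers]

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            self.act_means[i].append(x.data.mean().item())
            self.act_stds[i].append(x.data.std().item())
        return x

    def __iter__(self):
        return iter(self.layers)


model = SequentialModel(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
for _ in range(3):                  # three forward passes -> three readings per layer
    model(torch.randn(8, 10))
print(model.act_means, model.act_stds)
```

Every forward pass gets recorded, including validation passes, so when you plot these you may want to keep that in mind.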
So you might think, well, it's eventually training, so isn't this okay? But my concern would be this thing where it kind of falls off a cliff. There's lots of parameters in our model. Are we sure that all of them are getting back into reasonable places, or is it just that a few of them have got back into a reasonable place?
Maybe the vast majority of them have zero gradients at this point. I don't know. It seems very likely that this awful training profile early in training is leaving our model in a really sad state. That's my guess, and we're going to check it later. But for now, we're just going to say, let's try to make this not happen.
And let's also look at the standard deviations, and you see exactly the same thing. This just looks really bad. So let's look at just the first 10 means, and they all look okay. They're all pretty close-ish to zero, which is about what we want. But more importantly, let's look at the standard deviations for the first 10 batches, and this is a problem.
The first layer has a standard deviation not too far away from one, but then, not surprisingly, the next layer is lower. The next layer is lower. As we would expect, because the first layer is less than one, the following layers are getting exponentially further away from one, until the last layer is really close to zero.
So now we can kind of see what was going on here: our final layers were getting basically no activations, so they were basically getting no gradients. So gradually, it was moving into spaces where they actually at least had some gradient. But by the time they kind of got there, things were moving so fast that they were kind of falling off a cliff and having to start again.
So this is the thing we're going to try and fix. And we think we already know how to, we can use some initialization. Yes, Rachel? Did you say that if we went from 27 numbers to 32, that we were losing information? And could you say more about what that means?
Yeah, that was poorly said. We're not losing information, we're wasting it: if you start with 27 numbers and do some matrix multiplication and end up with 32 numbers, you're now taking more space for the same information you started with.
And the whole point of a neural network layer is to pull out some interesting features. So you would expect to have fewer total activations, because you're trying to say, oh, in this area, I've pulled this set of pixels down into something that says how furry this is, or how much of a diagonal line it has, or whatever.
So increasing the actual number of activations we have for a particular position is a total waste of time. We're not doing anything useful; we're just wasting a lot of computation. We can talk more about that on the forum if that's still not clear. So this idea of creating telemetry for your model is really vital.
This approach, where you write a whole new class that can only do one kind of telemetry, is clearly stupid, so we need a better way to do it. And what's the better way? Callbacks, of course. Except we can't use our callbacks, because we don't have a callback that says "when you calculate this layer, call back to our code."
We have no way to do that, right? So we need to use a feature inside PyTorch that can call back into our code when a layer is calculated, in either the forward pass or the backward pass. And for reasons I can't begin to imagine, PyTorch doesn't call them callbacks; they're called hooks.
But it's the same thing. It's a callback, okay? So for any module, we can say register_forward_hook and pass in a function. This is a callback that will be called when this module's forward pass is calculated. Or you could say register_backward_hook, and that will call this function when this module's backward pass is calculated.
So to replace that previous thing with hooks, we can simply create a couple of global variables to store our means and standard deviations for every layer, and a function to call back to that calculates the mean and standard deviation. If you Google the documentation for register_forward_hook, you'll find that the callback will be called with three things:
the module that's doing the callback, the input to the module, and the output of that module (or, for a backward hook, the corresponding gradients). In our case, it's the output we want. And then we've got a fourth argument here, which is the layer number we're looking at, and we use partial to bind the appropriate layer number to the function for each layer.
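Here's a minimal sketch of that global-state version. It assumes a small stand-in nn.Sequential model rather than the lesson's CNN, and the names act_means, act_stds and append_stats are illustrative. The one fixed part is PyTorch's calling convention: a forward hook is called as hook(module, input, output), and partial supplies the extra layer index up front.

```python
from functools import partial

import torch
from torch import nn

# A small stand-in model (hypothetical, not the lesson's CNN).
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1, stride=2), nn.ReLU(),
)

# Global state: one list of means and one list of stds per layer.
act_means = [[] for _ in model]
act_stds = [[] for _ in model]

def append_stats(i, mod, inp, outp):
    # PyTorch calls a forward hook as hook(module, input, output);
    # partial() has already filled in i, the layer index.
    act_means[i].append(outp.detach().mean().item())
    act_stds[i].append(outp.detach().std().item())

handles = [m.register_forward_hook(partial(append_stats, i))
           for i, m in enumerate(model)]

for _ in range(3):                 # three fake "batches"
    model(torch.randn(4, 1, 28, 28))

for h in handles:                  # remove the hooks once we're done
    h.remove()
```

After the loop, each layer's list holds one mean per batch, and once the handles are removed further forward passes record nothing.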
So once we've done that, we can call fit and do exactly the same thing. So this is the same thing, just much more convenient. And because this is such a handy thing to be able to do, fastai has a Hook class. So we can create our own Hook class now, which lets us put the state inside the hook rather than having this kind of messy global state.
So let's create a class called Hook that, when you initialize it, registers a forward hook on some module with some function. That function is going to call back to this object, so we pass in self with the partial. That way, the function can get access to the hook.
We can pop our two empty lists inside it when we first call this, to store away our means and standard deviations, and then we can just append our means and standard deviations. So now we just create a Hook for each layer in the children of our model. And we'll just grab the first two layers, because I don't care so much about the linear layers.
It's really the conv layers that are most interesting. And so now this does exactly the same thing. Since we do this a lot, let's put that into a class too, called Hooks. So here's our Hooks class, which simply creates a Hook for every module in some list of modules. Now, something to notice is that when you're done using a hooked module, you should call hook.remove.
Because otherwise, if you keep registering more hooks on the same module, they're all going to get called, and eventually you're going to run out of memory. So one thing I did in our Hook class was to create a dunder del (__del__). This is called automatically when Python cleans up the object.
So when Python is done with your hook, it will automatically call remove, which in turn removes the hook. I then have a similar thing in Hooks: when Hooks is done, it calls self.remove, which in turn goes through every one of my registered hooks and removes them. You'll see that somehow I'm able to say for h in self, but I haven't defined any kind of iterator here.
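A sketch of the Hook and Hooks classes as described, with state stored on the hook instead of in globals. This is not the lesson's exact code: here Hooks defines its own __iter__ and __len__ (in the lesson those come from the list container base class), and the lazily created stats attribute and the small Linear model are assumptions for illustration.

```python
from functools import partial

import torch
from torch import nn

class Hook:
    "Registers a forward hook on module m; f is called as f(hook, module, input, output)."
    def __init__(self, m, f):
        self.hook = m.register_forward_hook(partial(f, self))
    def remove(self):
        self.hook.remove()
    def __del__(self):
        # Called when Python cleans up the Hook: detach it from the module.
        self.remove()

class Hooks:
    "One Hook per module; iterable, and removes all its hooks when cleaned up."
    def __init__(self, ms, f):
        self.hooks = [Hook(m, f) for m in ms]
    def __iter__(self):
        return iter(self.hooks)
    def __len__(self):
        return len(self.hooks)
    def remove(self):
        for h in self:
            h.remove()
    def __del__(self):
        self.remove()

def append_stats(hook, mod, inp, outp):
    # The state lives on the hook object itself.
    if not hasattr(hook, 'stats'):
        hook.stats = ([], [])
    means, stds = hook.stats
    means.append(outp.detach().mean().item())
    stds.append(outp.detach().std().item())

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
hooks = Hooks(model, append_stats)
for _ in range(2):
    model(torch.randn(8, 10))
hooks.remove()
```

Calling remove on a PyTorch hook handle more than once is harmless, which is why it's safe for both the explicit remove and __del__ to call it.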
And the trick is that I've created something called a ListContainer just above, which is super handy. It basically defines all the things you would expect to see in a list, using all of the various special dunder methods, and then some. It actually has some of the behavior of numpy as well.
We're not allowed to use numpy in our foundations, so we use this instead. And this actually works a bit better than numpy for this stuff, because numpy does some weird casting and has weird edge cases. So for example, this ListContainer has a dunder getitem (__getitem__), which is what gets called when you index into something with square brackets.
If you index into it with an int, we just pass it off to the enclosed list, because we gave it a list to enclose. If you send it a list of bools, like False, False, True, False, it will return all of the items where that's True. Or you can index into it with a list of ints, in which case it will return the items at those indices.
It's also got a length, which just passes off to len, an iterator that passes off to iter, and so forth. And we've also defined the representation for it, so that if you print it out, it just prints out the contents, unless there's more than 10 things in it, in which case it shows dot, dot, dot.
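A sketch of a list container along those lines. It's written from the description above, so treat the details (the assertion on mask length, the exact repr format) as my assumptions rather than the lesson's code.

```python
class ListContainer:
    "A small list wrapper with numpy-style indexing."
    def __init__(self, items):
        self.items = list(items)
    def __getitem__(self, idx):
        if isinstance(idx, (int, slice)):
            return self.items[idx]                  # ordinary list indexing
        idx = list(idx)
        if idx and isinstance(idx[0], bool):        # boolean mask
            assert len(idx) == len(self), "mask length must match"
            return [o for keep, o in zip(idx, self.items) if keep]
        return [self.items[i] for i in idx]         # list of integer indices
    def __len__(self):
        return len(self.items)
    def __iter__(self):
        return iter(self.items)
    def __setitem__(self, i, o):
        self.items[i] = o
    def __delitem__(self, i):
        del self.items[i]
    def __repr__(self):
        res = f'{self.__class__.__name__} ({len(self)} items)\n{self.items[:10]}'
        if len(self) > 10:
            res = res[:-1] + ', ...]'               # show only the first 10
        return res

lc = ListContainer(range(12))
print(lc[3])                              # 3
print(lc[[1, 5]])                         # [1, 5]
print(lc[[True, True] + [False] * 10])    # [0, 1]
print(lc)                                 # ListContainer (12 items) ... 9, ...]
```

Checking for slices as well as ints means lc[2:5] also just passes through to the underlying list.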
So you can create really useful little base classes like this in much less than a screenful of code, and we will use them everywhere from now on. So now we've created our own listy class that has hooks in it.
And so now we can just use it like this. We can just say hooks equals Hooks of everything in our model, with that append_stats function we had before, and we can print it out to see all the hooks. We can grab a batch of data, so now we've got one batch of data.
And we can check that its mean and standard deviation are about zero and one, as you would expect. We can pass it through the first layer of our model: model[0] is the first layer, which is the first convolution. And our mean is not quite zero, and our standard deviation is quite a lot less than one, so we kind of know what's going to happen.
So now we'll just go ahead and initialize it with Kaiming initialization. After that, our variance is quite close to one, and our mean, as expected, is quite close to 0.5 because of the ReLU. So now we can go ahead and create our hooks and do a fit. We can plot the first 10 means and standard deviations, and then we can plot all the means and standard deviations, and there it all is.
And this time we're doing it after we've initialized all the layers of our model. As you can see, we don't have that awful exponential crash, exponential crash, exponential crash. So this is looking much better, and you can see that early on in training our standard deviations all look much closer to one.
So this is looking super hopeful. Notice I've used a with block. A with block is something that will create this object, give it this name, and when it's finished, do something. The something it does is to call your dunder exit (__exit__) method, which here calls remove. So this is a nice way to ensure that things are cleaned up, for example that your hooks are removed.
That's why we have a dunder enter (__enter__), which is what happens when you start the with block, and a dunder exit, which is what happens when you finish the with block. So this is looking very hopeful, but it's not quite what we wanted to know. Really, the concern was: does this actually do something bad?
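A minimal sketch of the with-block pattern. To keep it short, this Hooks variant (an assumption, not the lesson's class) passes itself to the callback and keeps one shared stats list; the point is only that __enter__ returns the object bound by as, and __exit__ removes the hooks when the block ends, even if it raised.

```python
from functools import partial

import torch
from torch import nn

class Hooks:
    "Registers f as a forward hook on each module; usable in a with block."
    def __init__(self, ms, f):
        self.stats = []
        self.handles = [m.register_forward_hook(partial(f, self)) for m in ms]
    def __enter__(self):
        return self                  # this is what 'as hooks' binds
    def __exit__(self, *exc_info):
        self.remove()                # runs on normal exit and on exceptions
    def remove(self):
        for h in self.handles:
            h.remove()

def record_mean(hooks, mod, inp, outp):
    hooks.stats.append(outp.detach().mean().item())

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
with Hooks(model, record_mean) as hooks:
    model(torch.randn(2, 4))         # both modules fire: two stats recorded
model(torch.randn(2, 4))             # hooks already removed: nothing recorded
```

Because __exit__ returns None, any exception raised inside the block still propagates after the hooks are cleaned up.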
Or does it just train fine afterwards? Something bad is really more about how many of the activations are really, really small. How well is it actually getting everything activated nicely? So what we could do is adjust our append_stats, so that not only does it record a mean and a standard deviation, but also a histogram.
So we could create a histogram of the activations, popping them into 40 bins between 0 and 10. We don't need to go below 0, because we have a ReLU, so we know there's nothing below 0. So let's again run this, using our Kaiming initialization. And what we find is that even with that, if we make our learning rate really high, 0.9, we can still get this same behavior.
And so here's plotting the entire histogram. And I should say thank you to Stefano for the original code here from our San Francisco study group to plot these nicely. So you can see this kind of grow, collapse, grow, collapse, grow, collapse thing. The biggest concern for me though is this yellow line at the bottom.
Yellow is where most of the histogram is, and what I really care about is how much yellow there is. So let's say the first two histogram bins are 0 or nearly 0. Let's get the sum of how much is in those two bins and divide by the sum of all of the bins.
And so that's going to tell us what percentage of the activations are 0 or nearly 0. Let's plot that for each of the first four layers. And you can see that in the last layer, it's just as we suspected, over 90% of the activations are actually 0. So if you were training your model like this, it could eventually look like it's training nicely without you realizing that 90% of your activations were totally wasted.
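A sketch of the histogram telemetry and the "how much is nearly zero" ratio. It uses torch.histc, which counts values into equal-width bins over [min, max] and ignores anything outside that range. The helper names are hypothetical. Which bins count as "zero" depends on the range: the first two of 40 bins over 0 to 10 for a plain ReLU, or some middle bins for a symmetric range such as -7 to 7 (slice(19, 22) below is an illustrative choice).

```python
import torch

def append_hist(hists, acts, bins=40, lo=0., hi=10.):
    # histc counts values into `bins` equal-width bins spanning [lo, hi];
    # anything outside that range is ignored.
    hists.append(acts.detach().float().cpu().histc(bins, lo, hi))

def near_zero_fraction(hists, zero_bins=slice(0, 2)):
    "Fraction of activations falling in `zero_bins`, one value per recorded batch."
    h = torch.stack(hists).t().float()    # shape: (bins, n_batches)
    return h[zero_bins].sum(0) / h.sum(0)

# ReLU case: 40 bins over [0, 10]; the first two bins (below 0.5) count as "dead".
hists = []
append_hist(hists, torch.tensor([0., 0., 0., 0.1, 0.2, 0., 5., 9.]))
print(near_zero_fraction(hists))          # tensor([0.7500])

# Symmetric range for activations that can go negative: "zero" is in the middle.
hists2 = []
append_hist(hists2, torch.tensor([-0.3, 0.1, 0.5, -5., 6.]), lo=-7., hi=7.)
print(near_zero_fraction(hists2, zero_bins=slice(19, 22)))   # tensor([0.6000])
```

Recording one histogram per batch and dividing column-wise gives one "percentage dead" value per iteration, which is what gets plotted for each layer.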
And you're never going to get great results by wasting 90% of your activations. So let's try and fix it: let's try to be able to train at a nice high learning rate and not have this happen. The trick is, we're going to try a few things, but the main one is to use our better ReLU.
And so we've created a generalized ReLU class where now we can pass in things like an amount to subtract from the ReLU because remember we thought subtracting half from the ReLU might be a good idea. We can also use leaky ReLU and maybe things that are too big are also a problem.
So let's also optionally have a maximum value. So in this generalized ReLU, if you pass a leakiness, we'll use leaky ReLU; otherwise we'll use normal ReLU. You could very easily write a leaky ReLU by hand, but I'm just trying to make it run a little faster by taking advantage of PyTorch.
If you said you want to subtract something, then go ahead and subtract that. If you said there's some maximum value, go ahead and clamp at that maximum value. So here's our generalized ReLU. And now let's have our conv_layer and get_cnn_layers both take **kwargs and just pass them on through, so that eventually they end up passed to our generalized ReLU.
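A sketch of the generalized ReLU just described. The parameter names leak, sub and maxv are illustrative. It uses PyTorch's built-in F.leaky_relu and F.relu, then optionally shifts and clamps.

```python
import torch
import torch.nn.functional as F
from torch import nn

class GeneralRelu(nn.Module):
    def __init__(self, leak=None, sub=None, maxv=None):
        super().__init__()
        self.leak, self.sub, self.maxv = leak, sub, maxv

    def forward(self, x):
        # Leaky ReLU if a leak slope was given, otherwise a plain ReLU.
        x = F.leaky_relu(x, self.leak) if self.leak is not None else F.relu(x)
        if self.sub is not None:
            x = x - self.sub              # shift the output down
        if self.maxv is not None:
            x = x.clamp(max=self.maxv)    # cap very large activations
        return x

act = GeneralRelu(leak=0.1, sub=0.4, maxv=6.)
print(act(torch.tensor([-1., 0., 1., 10.])))   # roughly [-0.5, -0.4, 0.6, 6.0]
```

Using out-of-place x - self.sub rather than an in-place subtraction keeps it safe to use on tensors that autograd still needs.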
That way we'll be able to create a CNN and say what ReLU characteristics we want, nice and easily. And get_cnn_model will pass down kwargs as well. So now that our ReLU can go negative, because it's leaky and because it's subtracting stuff, we'll need to change our histogram so it goes from -7 to 7 rather than from 0 to 10.
We'll also need to change our definition of get_min so that it uses the middle few bins of the histogram, around 0, rather than the first two. And now we can just go ahead and train this model just like before and plot just like before. And this is looking pretty hopeful.
Let's keep looking at the rest. So here are the first, second, third, and fourth layers. Compared to before, which was expand, die, expand, die, expand, die, this is looking much better. Straight away, it's using the full richness of the possible activations. There's no death going on.
But our real question is how much is in this yellow line? There's a question. And let's see, in the final layer, look at that, less than 20%. So we're now using nearly all of our activations by being careful about our initialization and our ReLU, but we're still training at a nice high learning rate.
So this is looking great. Could you explain again how to read the histograms? Sure. Let's go back to the earlier one. The four histograms are simply the four layers: after the first, second, third, and fourth. The x-axis is the iteration, so each column is just one more iteration, as in most of our plots.
The y-axis is the activation value, from the lowest it can be at the bottom to the highest it can be at the top, and the color shows how many activations fall into each bin. So what this one here is showing us, for example, is that there are some activations at the max, some in the middle, and some at the bottom, whereas this one here is showing us that all of the activations are basically zero.
What this histogram shows us is that now we're going all the way from minus seven to plus seven, because we can have negatives. This is zero. It's showing us that most of them are around zero, because yellow means the most activations, but there are activations throughout the whole range, from the bottom to the top.
And a few are less than zero, as we would expect, because we have a leaky ReLU and we also subtract something. We're not doing minus 0.5, we're doing minus 0.4, because with a leaky ReLU we don't need to subtract half anymore; we subtract a bit less than half. And this line is telling us what percentage of them are zero or nearly zero.
And so this is one of those things which is good to run lots of experiments in the notebook yourself to get a sense of what's actually in these histograms. So you can just go ahead and have a look at each hook's stats. And the third thing in it will be the histograms, so you can see what shape is it and how is it calculated and so forth.
So now that we've done that, this is looking really good. So what actually happens if we train like this? Let's do a one-cycle training. We use the combine_scheds we built last week with two 50/50 phases and cosine annealing, so a gradual warm-up and a gradual cool-down, and then run it for eight epochs.
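As a reminder of what that schedule computes, here's a from-scratch sketch of a cosine scheduler and a combiner that runs each phase for its share of training. The names mirror last week's sched_cos and combine_scheds, but this version is re-sketched, and the learning rate values are purely illustrative.

```python
import math

def sched_cos(start, end):
    "Cosine annealing from `start` (pos=0) to `end` (pos=1)."
    return lambda pos: start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def combine_scheds(pcts, scheds):
    "Run each schedule for its fraction `pcts` of training, in order."
    assert abs(sum(pcts) - 1) < 1e-9
    bounds = [0.0]
    for p in pcts:
        bounds.append(bounds[-1] + p)
    def _inner(pos):
        # Find the phase that `pos` (overall progress, 0 to 1) falls into.
        idx = 0
        while idx < len(scheds) - 1 and pos >= bounds[idx + 1]:
            idx += 1
        local = (pos - bounds[idx]) / (bounds[idx + 1] - bounds[idx])
        return scheds[idx](local)
    return _inner

# Warm up 0.3 -> 0.6 for the first half, then cool down 0.6 -> 0.2.
sched = combine_scheds([0.5, 0.5], [sched_cos(0.3, 0.6), sched_cos(0.6, 0.2)])
print([round(sched(p), 3) for p in (0, 0.25, 0.5, 0.75, 1)])   # [0.3, 0.45, 0.6, 0.4, 0.2]
```

Each phase gets its own local position from 0 to 1, so any single-phase schedule can be reused as a building block.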
And there we go, we're doing really well; we're getting up to 98%. We were hardly really training anything, we were just trying to get something that looked good. And once we had something that looked good in terms of the telemetry, it trained really well.
One option I added, by the way, in init_cnn, was a uniform Boolean, which sets the initialization function to Kaiming normal if it's False, which is what we've been using so far, or Kaiming uniform if it's True. So now I've just trained the same model with uniform=True.
A lot of people think that uniform is better than normal, because a uniform random number is less often close to zero. And so the thinking is that maybe uniform initialization might give it a better richness of activations. I haven't studied this closely, and I'm not sure I've seen a careful analysis in a paper.
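A quick numerical check of the "less often close to zero" intuition, using PyTorch's nn.init.kaiming_normal_ and nn.init.kaiming_uniform_ with their defaults. The 1000-by-1000 shape and the "within 0.1 standard deviations" threshold are arbitrary choices for the sketch. With matched standard deviations, the normal puts noticeably more weights right next to zero.

```python
import torch
from torch import nn

torch.manual_seed(0)
fan_in = 1000
# Both default to mode='fan_in' and a ReLU-style gain of sqrt(2), so their stds match.
w_norm = nn.init.kaiming_normal_(torch.empty(fan_in, fan_in))
w_unif = nn.init.kaiming_uniform_(torch.empty(fan_in, fan_in))

def frac_near_zero(w, k=0.1):
    "Fraction of weights within k standard deviations of zero."
    return (w.abs() < k * w.std()).float().mean().item()

print(w_norm.std().item(), w_unif.std().item())        # both about sqrt(2 / 1000), ~0.0447
print(frac_near_zero(w_norm), frac_near_zero(w_unif))  # about 0.080 vs about 0.058
```

Whether that difference in the weights actually matters for training is the open question; here the two runs ended up very close.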
In this case, 98.22% versus 98.26%, so they're looking pretty similar, but that's just something else that's there to play with. So at this point we've got a pretty nice bunch of things you can look at, and your problem to play with during the week is: how accurate can you make a model,
just using the layers we've created so far? And for the ones with great accuracy, what does the telemetry look like? How can you tell whether it's going to be good? And then what insights can you gain from that to make it even better? So in the end, try to beat me: try to beat 98%.
You'll find you can beat it pretty easily with some playing around, but do some experiments. All right, so that's kind of about what we can do with initialization. You can go further, as we discussed, with SELU or with Fixup; there are these really finely tuned initialization methods that let you train 1,000 layers deep, but they're super fiddly.
So generally, I would use something like the layer-sequential unit-variance (LSUV) thing that we saw earlier in ... Oh, sorry, we haven't done that one yet. Okay, we're going to do that next. So, forget I said that. So that's about as far as we can get with basic initialization.
To go further, we really need to use normalization, of which the best-known approach is batch normalization. So let's look at batch normalization. Batch normalization has been around since, I think, about 2015. This is the paper. They first describe a bit about why they thought batch normalization was a good idea, and by about page 3, they provide the algorithm.
So it's one of those things that, if you don't read a lot of math, might look a bit scary. But when you look at it for a little bit longer, you suddenly notice that this is literally just the mean: the sum divided by the count. And this is the squared difference from the mean, and it's the mean of that.
Oh, that's just what we looked at: that's variance. And this is just subtracting the mean and dividing by the standard deviation. Oh, that's just normalization. So once you look at it a second time, you realize we've done all this; we've just done it with code, not with math. And so the only extra thing they do, after they've normalized it in the usual way, is multiply it by gamma and add beta.
What are gamma and beta? They are parameters to be learned. What does that mean? That's the most important line here. Remember that there are two types of numbers in a neural network, parameters and activations. Activations are things we calculate, parameters are things we learn. So these are just numbers that we learn.
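For reference, here is the batch norm transform from the paper's algorithm, for a mini-batch of m values of one feature, where epsilon is a small constant added for numerical stability:

```latex
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2, \qquad
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma\,\hat{x}_i + \beta
```

Only gamma and beta are learned; the mean and variance are computed from each mini-batch.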
So that's all the information we need to implement batch norm. So let's go ahead and do it. So first of all, we'll grab our data as before, create our callbacks as before. Here's our pre-batch norm version, 96.5%. And the highest I could get was a 0.4 learning rate this way.
And so now let's try batch norm. So here's batch norm. So let's look at the forward first. We're going to get the mean and the variance. And the way we do that is we call update stats, and the mean is just the mean. And the variance is just the variance.
Then we subtract the mean and divide by the square root of the variance. And then we multiply and add, but I didn't call them gamma and beta, because who remembers which one's gamma and which one's beta? Let's use English. The thing we multiply by, we'll call the mults, and the thing we add, we'll call the adds.
And so mults and adds are parameters. We multiply by a parameter that is initially just a bunch of ones, so it does nothing, and we add a parameter that is initially just a bunch of zeros, so it does nothing. But they're parameters, so they can learn, just like, remember, the original linear layer we created by hand.
In fact, if you think about it, adds is just a bias; it's identical to the bias we created earlier. So then there are a few extra little things we have to think about. One is what happens at inference time. During training, we normalize. But the problem is that if we normalize in the same way at inference time and we get a totally different kind of image, we might remove all of the things that are interesting about it.
So what we do is, while we're training, we keep an exponentially weighted moving average of the means and the variances. I'll talk more about what that means in a moment. But basically, we've got a running average of the last few batches' means and a running average of the last few batches' variances.
And so then when we're not training, in other words at inference time, we don't use the mean and variance of this mini-batch, we use that running average mean and variance that we've been keeping track of. So how do we calculate that running average? Well, we don't just create something called self.vars.
We call self.register_buffer, giving it the name vars, and that creates something called self.vars. So why didn't we just say self.vars = torch.ones(...)? Why do we call self.register_buffer? It's almost exactly the same as assigning self.vars directly, but it does a couple of nice things. The first is that if we move the model to the GPU, anything that's registered as a buffer will be moved to the GPU as well.
If we didn't do that, then when it tries to do this calculation down here, if the vars and means aren't on the GPU but everything else is, we'll get an error. It'll say, "Oh, you're trying to add this thing on the CPU to this thing on the GPU," and it'll fail. So that's one nice thing about register_buffer.
The other nice thing is that the variances and the means, these running averages, are part of the model. When we do inference, in order to calculate our predictions, we actually need to know what those numbers are. So if we save the model, we have to save those variances and means.
So register_buffer also causes them to be saved along with everything else in the model. So that's what register_buffer does. We start the variances out at ones and the means out at zeros. We then calculate the mean and variance of the minibatch, averaging over axes zero, two, and three.
In other words, we average over all the items in the batch and over all of the x and y coordinates, so all we're left with is a mean for each channel, or each filter. keepdim=True means that it leaves a unit axis in positions zero, two, and three, so it'll still broadcast nicely.
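A minimal sketch of the batch norm layer as described. It assumes 4D conv activations (batch, channel, height, width) and defaults of mom=0.1 and eps=1e-5, which happen to match nn.BatchNorm2d's defaults. I use the biased (divide-by-count) variance, as in the paper's formula, and I let gradients flow through the batch statistics. The comparison against nn.BatchNorm2d afterwards is just a sanity check.

```python
import torch
from torch import nn

class BatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        # Learnable scale ("mults", the paper's gamma) and shift ("adds", beta).
        self.mults = nn.Parameter(torch.ones(nf, 1, 1))
        self.adds = nn.Parameter(torch.zeros(nf, 1, 1))
        # Buffers: saved in state_dict and moved by .to()/.cuda(), but not trained.
        self.register_buffer('vars', torch.ones(1, nf, 1, 1))
        self.register_buffer('means', torch.zeros(1, nf, 1, 1))

    def update_stats(self, x):
        # One mean/variance per channel: average over batch (0), height (2), width (3).
        m = x.mean((0, 2, 3), keepdim=True)
        v = ((x - m) ** 2).mean((0, 2, 3), keepdim=True)
        with torch.no_grad():
            # Running averages: buf = (1 - mom) * buf + mom * new.
            self.means.lerp_(m.detach(), self.mom)
            self.vars.lerp_(v.detach(), self.mom)
        return m, v

    def forward(self, x):
        if self.training:
            m, v = self.update_stats(x)
        else:
            m, v = self.means, self.vars   # inference: use the running averages
        x = (x - m) / (v + self.eps).sqrt()
        return x * self.mults + self.adds

torch.manual_seed(0)
x = torch.randn(16, 3, 8, 8) * 5 + 2
bn, ref = BatchNorm(3), nn.BatchNorm2d(3)
print((bn(x) - ref(x)).abs().max().item())   # ~0: matches PyTorch in training mode
```

Because the running statistics are buffers rather than parameters, the optimizer never sees them, yet they're still saved with the model.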
Now we want to take a running average. Normally, if we've got a bunch of data points and we want a moving average, we'd grab five at a time and take the average of those five, then take the next five and take their average, and keep doing that a few at a time.
We don't want to do that here, though, because these batch norm statistics, every single activation has one. So it's giant. Models can have hundreds of millions of activations. We don't want to have to save a whole history of every single one of those, just so that we can calculate an average.
So there's a handy trick for this, which is instead to use an exponentially weighted moving average. And basically, what we do is we start out with this first point, and we say, okay, our first average is just the first point. So let's say, I don't know, that's three. And then the second point is five.
To take an exponentially weighted moving average, we first of all need some number, which we call momentum; let's say it's 0.9. For the first value, our exponentially weighted moving average, which we'll call mu, equals three. Then for the second one, we take mu one and multiply it by our momentum, and add our second value, five, multiplied by one minus our momentum.
So in other words, it's mainly whatever it used to be before, plus a little bit of the new thing. Then mu three equals mu two times 0.9, plus the new value (maybe this one here is four) times 0.1. So we keep saying, oh, it's mainly the thing before plus a little bit of the new one.
And so what you end up with is something where, like by the time we get to here, the amount of influence of each of the previous data points, once you calculate it out, it turns out to be exponentially decayed. So it's a moving average with an exponential decay, with the benefit that we only ever have to keep track of one value.
So that's what an exponentially weighted moving average is. This thing we do here, where we basically say we've got some function where we say it's some previous value times 0.9, say, plus some other value times one minus that thing. This is called a linear interpolation. It's a bit of this and a bit of this other thing, and the two together make one.
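The worked example above (3, then 5, then 4, with momentum 0.9) as a few lines of plain Python. The function name ewma is just a label for this sketch.

```python
def ewma(values, mom=0.9):
    "Exponentially weighted moving average: mu = mom * mu + (1 - mom) * new."
    mu = values[0]                   # the first average is just the first point
    out = [mu]
    for v in values[1:]:
        mu = mom * mu + (1 - mom) * v    # mainly the old value, a bit of the new one
        out.append(mu)
    return out

print(ewma([3, 5, 4]))   # roughly [3, 3.2, 3.28], up to float rounding
```

Unrolling the last step, mu three is 0.81 times 3 plus 0.09 times 5 plus 0.1 times 4: each older point's weight shrinks by another factor of 0.9, which is the exponential decay, and only the single value mu ever needs to be stored.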
Linear interpolation in PyTorch is spelled lerp. So we take the means and lerp them with our new mean, using this amount of momentum. Unfortunately, this uses the exact opposite of the normal sense of momentum: a momentum of 0.1 in batch norm actually means a momentum of 0.9 in normal-person speak.
This is actually how nn.BatchNorm works as well, so batch norm momentum is the opposite of what you would expect. I wish they'd given it a different name. They didn't, sadly, so this is what we're stuck with. So that's how we keep running averages of the means and variances, and now we can go ahead and use them.
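A small check of that naming, using torch.lerp, which computes start + weight * (end - start), and the running_mean update of PyTorch's nn.BatchNorm1d. The specific numbers are just for illustration.

```python
import torch
from torch import nn

old, new = torch.tensor(3.), torch.tensor(5.)
# torch.lerp(start, end, weight) = start + weight * (end - start)
print(torch.lerp(old, new, 0.1))   # 0.9 * 3 + 0.1 * 5, i.e. about 3.2

# nn.BatchNorm's momentum is the weight on the *new* value.
bn = nn.BatchNorm1d(1, momentum=0.1)
bn(torch.tensor([[1.], [3.]]))     # training mode; this batch's mean is 2
print(bn.running_mean)             # 0.9 * 0 + 0.1 * 2, i.e. about 0.2
```

So "momentum=0.1" keeps 90% of the old running value, which is what most people would call a momentum of 0.9.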
So now we can create a new conv layer, where you can optionally say whether you want batch norm. If you do, we append a batch norm layer, and if we append a batch norm layer, we remove the bias from the convolution, because remember I said that the adds in batch norm are just a bias.
So there's no point having a bias anymore, and we remove the unnecessary one. And so now we can go ahead and initialize our CNN. This is a slightly more convenient initialization, which recursively goes into every module inside our module and initializes the weights and the biases.
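A sketch of a conv layer with optional batch norm plus a recursive initializer. It uses PyTorch's built-in nn.BatchNorm2d and a plain nn.ReLU in place of the generalized ReLU from earlier, and the function names and defaults (stride 2, Kaiming normal) are assumptions for illustration.

```python
import torch
from torch import nn

def conv_layer(ni, nf, ks=3, stride=2, bn=True):
    # With batch norm, the conv bias is redundant: batch norm's additive term does that job.
    layers = [nn.Conv2d(ni, nf, ks, padding=ks // 2, stride=stride, bias=not bn), nn.ReLU()]
    if bn:
        layers.append(nn.BatchNorm2d(nf))
    return nn.Sequential(*layers)

def init_cnn_(m, f=nn.init.kaiming_normal_):
    # Recursively visit every submodule: init conv weights, zero any conv biases.
    if isinstance(m, nn.Conv2d):
        f(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    for child in m.children():
        init_cnn_(child, f)

model = nn.Sequential(conv_layer(1, 8), conv_layer(8, 16), conv_layer(16, 32, bn=False))
init_cnn_(model)
print(model(torch.randn(2, 1, 28, 28)).shape)   # torch.Size([2, 32, 4, 4])
```

Recursing through children means the initializer works no matter how deeply the convolutions are nested inside Sequential blocks.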
And then we'll train it with our hooks. You can see our mean starts at exactly 0 and our standard deviation starts at exactly 1. So this has entirely gotten rid of all of the exponential growth and sudden crash stuff that we had before. There's something interesting going on at the very end of training, and I don't quite know what that is.
I mean, when I say the end of training, we've only done one epoch. But this is looking a lot better than anything we've seen before. I mean, that's just a very nice-looking curve. And so we're now able to get up to learning rates up to 1. We've got 97% accuracy after just three epochs.
This is looking very encouraging. So now that we've built our own batch norm, we're allowed to use PyTorch's batch norm. And we get pretty much the same results. Sometimes it's 97, sometimes it's 98. This is just random variation. So now that we've got that, let's try going crazy. Let's try using our little one-cycle learning scheduler we had.
And let's try and go all the way up to a learning rate of 2. And look at that. We totally can, right? And we're now up towards nearly 99% accuracy. So batch norm really is quite fantastic. Batch norm has a bit of a problem, though, which is that you can't apply it to what we call online learning tasks.
In other words, if you have a batch size of 1, so you're getting a single item at a time and learning from that item, what's the variance of that batch? With a single item there's nothing to vary, so the variance is zero (or undefined, if you use the unbiased estimator). So we can't use batch norm in that case.
Well, what if we're doing something like a segmentation task where we can only have a batch size of 2 or 4, which we've seen plenty of times in part 1? That's going to be a problem too. Because across all of our layers, across all of our training, across all of the channels, with a batch size of 2, at some point those two values are going to be the same or nearly the same.
And so we then divide by the square root of that variance, which is about 0, and things can blow up. So we have this problem where any time you have a small batch size, you're going to get unstable or impossible training. It's also going to be really hard for RNNs, because for RNNs, remember, it looks something like this.
We have this hidden state, and we use the same weight matrix again and again and again. Remember, we can unroll it, and it looks like this. If you've forgotten, go back to lesson 7. And then we can even stack two RNNs together, with one RNN feeding into another RNN.
And if we unroll that, it looks like this. And remember, with these time-step-to-time-step state transitions, if we're doing IMDB with a movie review of 2,000 words, there are 2,000 of these. It's the same weight matrix each time, and the number of these circles will vary:
the number of time steps varies from document to document. So how would you do batch norm? How would you say what the running average of the means and variances is? You can't put a different one between each of these unrolled layers, because this is a for loop, remember?
So we can't have different values every time. So it's not at all clear how you would insert batch norm into an RNN. So batch norm has these two deficiencies. How do we handle very small batch sizes all the way down to a batch size of one? How do we handle RNNs?
So this paper called Layer Normalization suggests a solution. The layer normalization paper is from Jimmy Ba, Jamie Kiros, and Geoffrey Hinton, who just won the Turing Award (kind of the Nobel Prize of computer science) with Yoshua Bengio and Yann LeCun. Like many papers, when you read it, it looks reasonably terrifying, particularly once you start looking at all this stuff.
But actually, when we take this paper and convert it to code, it's this. Which is not to say the paper's garbage; it's just that the paper has lots of explanation about what's going on, what they found out, and what that means. But what actually is layer norm?
It's the same as batch norm, but rather than taking x.mean over axes 0, 2, 3, you take x.mean over axes 1, 2, 3, and you remove all the running averages. So this is layer norm with none of that running average stuff. And the reason we don't need the running averages anymore is that we're not taking the mean across all the items in the batch.
Every image has its own mean, every image has its own standard deviation. So there's no concept of having to average across things in a batch. And so that's all layer norm is. We also average over the channels. So we average over the channels and the x and the y for each image individually.
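A minimal sketch of that layer norm: the batch norm recipe with the averaging axes changed to (1, 2, 3) and the running averages dropped. The single scalar mult and add are an assumption of this sketch. The comparison with PyTorch's F.layer_norm (which uses the same biased variance and eps=1e-5 by default) is only a sanity check.

```python
import torch
import torch.nn.functional as F
from torch import nn

class LayerNorm(nn.Module):
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.mult = nn.Parameter(torch.tensor(1.))
        self.add = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        # Statistics per item: average over channels (1), height (2), width (3).
        m = x.mean((1, 2, 3), keepdim=True)
        v = ((x - m) ** 2).mean((1, 2, 3), keepdim=True)
        x = (x - m) / (v + self.eps).sqrt()
        return x * self.mult + self.add        # no running averages needed

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5) * 3 + 1
print((LayerNorm()(x) - F.layer_norm(x, x.shape[1:])).abs().max().item())   # ~0
```

Because nothing is shared across the batch, it works the same way with a batch size of 1.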
So we don't have to keep track of any running averages. The problem is that when we do that and train, even at a lower learning rate of 0.8, it doesn't work. Layer norm's not as good. It's a workaround we can use, but we don't have the running averages at inference time and, more importantly, we don't have a different normalization for each channel: we're just throwing them all together and pretending they're the same, and they're not.
So layer norm helps, but it's nowhere near as good as batch norm. But for RNNs, something like this is what you have to use. So here's a thought experiment. What if you're using layer norm on the actual input data and you're trying to distinguish between foggy days and sunny days?
Foggy days will have lower activations on average, because they're less bright, and they'll have less contrast; in other words, they have lower variance. So layer norm would cause the variances to be normalized to be the same and the means to be normalized to be the same. So now the sunny day picture and the foggy day picture would have the same overall level of activation and amount of contrast.
So the answer to this question is no, you couldn't. With layer norm, you would literally not be able to tell the difference between pictures of sunny days and pictures of foggy days. And it's not only if you put layer norm on the input data, which you wouldn't do; it's the same everywhere in the middle layers.
Anywhere the overall level of activation, or the amount of difference between activations, is part of what you care about, layer norm throws it away. It's designed to throw it away. Furthermore, if at inference time you're using things from a different distribution, where that difference is important, it throws that away too.
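The thought experiment in a few lines, as a sketch: a random tensor stands in for the scene, and "sunny" and "foggy" are the same content with different brightness and contrast (the numbers are made up). After PyTorch's F.layer_norm over each whole image, the two become essentially identical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scene = torch.randn(1, 3, 16, 16)   # stand-in for the underlying picture content
sunny = 1.0 + 2.0 * scene           # brighter, more contrast
foggy = 0.4 + 0.2 * scene           # dimmer, less contrast

def layer_norm(x):
    # Normalize each item over channels, height and width (no learned scale/shift).
    return F.layer_norm(x, x.shape[1:])

print(sunny.mean().item(), foggy.mean().item())                    # clearly different
print((layer_norm(sunny) - layer_norm(foggy)).abs().max().item())  # close to 0
```

The only thing that distinguished the two inputs was their overall level and spread, and that's exactly what the normalization removes.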
So layer norm's a partial hacky workaround for some very genuine problems. There's also something called instance norm, and instance norm is basically the same thing as layer norm. It's a bit easier to read in the paper because they actually lay out all the indexes. So a particular output for a particular batch for a particular channel for a particular x for a particular y is equal to the input for that batch and channel in x and y minus the mean for the batch and the channel.
So in other words, it's the same as layer norm, but now it's mean 2,3 rather than mean 1,2,3. So you can see how all these different papers, when you turn them into code, they're tiny variations, right? Instance norm, even at a learning rate of 0.1, doesn't learn anything at all.
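A sketch of instance norm in the same style: statistics over axes (2, 3) only, so every channel of every image is normalized separately. The per-channel mults and adds are an assumption of this sketch; the comparison with PyTorch's F.instance_norm is only a sanity check.

```python
import torch
import torch.nn.functional as F
from torch import nn

class InstanceNorm(nn.Module):
    def __init__(self, nf, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.mults = nn.Parameter(torch.ones(nf, 1, 1))
        self.adds = nn.Parameter(torch.zeros(nf, 1, 1))

    def forward(self, x):
        # One mean/variance per item AND per channel: average over height (2), width (3).
        m = x.mean((2, 3), keepdim=True)
        v = ((x - m) ** 2).mean((2, 3), keepdim=True)
        x = (x - m) / (v + self.eps).sqrt()
        return x * self.mults + self.adds

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8) * 2 + 1
out = InstanceNorm(3)(x)
print(out.mean((2, 3)).abs().max().item())   # ~0: every channel of every image has mean 0
```

That last line is the information loss in miniature: whatever differed between images in a channel's overall level is gone.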
Why can't it classify anything? Because we're now taking the mean, removing the difference in means and the difference in activations for every channel and for every image, which means we've literally thrown away all the things that allow us to classify. Does that mean that instance norm is stupid? No, certainly not.
It wasn't designed for classification. It was designed for style transfer, where the authors guessed that these differences in contrast and overall amount were not important, or something they should remove from trying to create things that looked like different types of pictures. It turned out to work really well. But you've got to be careful, right?
You can't just go in and say, "Oh, here's another normalization thing, I'll try it." You've got to actually know what it's for to know whether it's going to work. So then finally, there's a paper called Group Norm, which has this wonderful picture, and it shows the differences. Batch Norm is averaging over the batch, and the height, and the width, and is different for each channel.
Layer Norm is averaging over every channel, height, and width, and is different for each element of the batch. Instance Norm is averaging over height and width, and is different for each channel and each element of the batch. And then Group Norm is like Instance Norm, but they arbitrarily group a few channels together and average over each group.
So Group Norm is a more general way to do it. In the PyTorch docs, they point out that you can actually turn Group Norm into Instance Norm, or into Layer Norm, depending on how you group things up. So there are all kinds of attempts to work around the problems that we can't use small batch sizes, and we can't use RNNs, with Batch Norm.
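Here's a NumPy sketch of group norm, assuming NCHW input and leaving out the learnable per-channel affine parameters that PyTorch's nn.GroupNorm adds by default. It checks the point from the PyTorch docs: a single group covering every channel gives layer norm, and one group per channel gives instance norm.

```python
import numpy as np

def group_norm(x, num_groups, eps=1e-5):
    # x is (N, C, H, W); channels are split into `num_groups` contiguous groups,
    # and each (batch element, group) pair gets its own mean and variance.
    n, c, h, w = x.shape
    assert c % num_groups == 0, "channels must divide evenly into groups"
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)

def normalize(x, axes, eps=1e-5):
    # Plain normalization with statistics over `axes`, for comparison.
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.default_rng(0).normal(size=(2, 6, 4, 4))
# One group containing every channel is layer norm...
print(np.allclose(group_norm(x, num_groups=1), normalize(x, axes=(1, 2, 3))))  # True
# ...and one group per channel is instance norm.
print(np.allclose(group_norm(x, num_groups=6), normalize(x, axes=(2, 3))))     # True
```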
But none of them are as good as Batch Norm. So what do we do? Well, I don't know how to fix the RNN problem, but I think I know how to fix the batch size problem. So let's start by taking a look at the batch size problem in practice.
Let's create a new DataBunch with a batch size of 2. And so here's our conv layer, as before, with our Batch Norm. Let's use a learning rate of 0.4 and fit that. The first thing you'll notice is that it takes a long time. Small batch sizes take a long time because there are lots and lots of kernel launches on the GPU, which is a lot of overhead.
Something like this might even run faster on the CPU. And then you'll notice that it's only 26% accurate, which is awful. Why is it awful? Because of what I said: the small batch size is causing a huge problem. Quite often, there's one channel in one layer where the variance is really small, because those two numbers just happen to be really close, and so it blows the activations out to a billion, and everything falls apart.
There's one thing we could try that would fix this really easily, which is to use Epsilon. What's Epsilon? Let's go take a look at our code. Here's our Batch Norm. Look, we don't divide by the square root of the variance; we divide by the square root of the variance plus Epsilon, where Epsilon is 1e-5.
Epsilon is a Greek letter that computer scientists and mathematicians use very frequently to mean some very small number. In computer science, it's normally a small number that you add to avoid floating-point rounding problems and that kind of thing. So it's very common to see it on the bottom of a division, to avoid dividing by numbers so small that you can't calculate things properly in floating point.
But our view is that Epsilon is actually a fantastic hyperparameter that you should be using to train things better. And here's a great example. With Batch Norm, what if we didn't set Epsilon to 1e-5? What if we set it to 0.1? Then that would basically cap how much the overall activations could ever be multiplied by.
Actually, to cap it at 10 we'd want 0.01, because we're taking the square root. So say you set it to 0.01 and the variance was 0: we'd take the square root of 0 plus 0.01, which is 0.1, so we end up dividing by 0.1, which means multiplying by 10. So even in the worst case, it's not going to blow out.
I mean, it's still not great, because there actually are huge differences in variance between different channels and different layers, but at least this would cause it to not fall apart. So option number one would be to use a much higher Epsilon value. And we'll keep coming back to this idea that Epsilon appears in lots of places in deep learning, and we should use it as a hyperparameter we control and take advantage of.
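A quick check of the worst-case arithmetic above: when the variance is 0, the normalization multiplies deviations from the mean by 1/sqrt(eps). The function name is just for illustration.

```python
import math

def max_multiplier(eps):
    # Worst case for (x - mean) / sqrt(var + eps) is var == 0,
    # where every deviation from the mean is multiplied by 1 / sqrt(eps).
    return 1.0 / math.sqrt(0.0 + eps)

for eps in (1e-5, 0.1, 0.01):
    print(f"eps={eps:g}: deviations scaled by at most {max_multiplier(eps):.1f}x")
```

This is why 0.01, rather than 0.1, is the value that caps the multiplier at 10; 0.1 caps it at about 3.2, and the default 1e-5 allows roughly 316.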
But we think we have a better idea, which is a new algorithm we've built called running batch norm. And running batch norm, I think, is the first true solution to the small batch size batch norm problem. And like everything we do at fast.ai, it's ridiculously simple.
And I don't know why no one's done it before. Maybe they have and I've missed it. And the ridiculously simple thing is this. In the forward function for running batch norm, don't subtract the batch mean or divide by the batch standard deviation; instead, use the moving average statistics at training time as well.
Not just at inference time. Why does this help? Because let's say you're using a batch size of two. Then from time to time, in this particular layer, in this particular channel, you happen to get two values that are really close together and they have a variance really close to zero.
But that's fine, because you're only taking 0.1 of that and 0.9 of whatever you had before. That's how running averages work. So if previously the variance was 1, now it's not 1e-5, it's about 0.9. So in this way, as long as you don't get really unlucky and have the very first batch be dreadful, you never have this problem, because you're using this moving average.
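The same arithmetic in a few lines of Python. The numbers are illustrative, and lerp here is a hand-written helper following the torch.lerp convention, start + weight * (end - start), with a momentum of 0.1 as the weight on the new value.

```python
def lerp(start, end, weight):
    # Same convention as torch.lerp: start + weight * (end - start).
    return start + weight * (end - start)

mom = 0.1                  # weight on the newest batch's statistic
running_var = 1.0          # the variance we'd built up before this batch
unlucky_batch_var = 1e-5   # two values that happened to be almost identical

batch_only = unlucky_batch_var                       # plain batch norm divides by sqrt of this (plus eps)
running = lerp(running_var, unlucky_batch_var, mom)  # 0.9 * 1.0 + 0.1 * 1e-5

print(batch_only, running)
```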
So let's take a look. We'll look at the code in a moment, but let's do the same thing, with a learning rate of 0.4, using our running batch norm. We train it for one epoch, and instead of 26% accuracy, it's 91% accuracy. So it totally nails it, in one epoch, with a batch size of just two and a pretty high learning rate.
There's quite a few details we have to get right to make this work. But they're all details that we're going to see in lots of other places in this course. We're just kind of seeing them here for the first time. So I'm going to show you all of the details, but don't get overwhelmed.
We'll keep coming back to them. The first detail is something very simple. In normal batch norm, we take the running average of the variance, but that doesn't really make sense. It's a variance: you can't just average a bunch of variances, particularly because they might even come from different batch sizes, since batch size isn't necessarily constant.
Instead, as we learned earlier in the class, the way we want to calculate variance is like this: the expected value of X squared minus the square of the expected value of X, that is, E[X²] - (E[X])². So let's do that. As I mentioned, we can just keep track of the sums and the squares.
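Here's a small pure-Python check of why tracking sums and sums of squares works where averaging variances doesn't. The two mini-batches are hypothetical and deliberately have different sizes and different means.

```python
import statistics

batches = [[1.0, 2.0], [10.0, 11.0, 12.0, 13.0]]   # hypothetical mini-batches of different sizes

# Accumulate sums, sums of squares, and counts instead of variances.
total = sum(sum(b) for b in batches)
total_sq = sum(sum(v * v for v in b) for b in batches)
count = sum(len(b) for b in batches)

mean = total / count
var = total_sq / count - mean * mean          # E[X^2] - (E[X])^2

all_values = [v for b in batches for v in b]
naive = statistics.mean([statistics.pvariance(b) for b in batches])

print(var, statistics.pvariance(all_values), naive)
```

Pooling the sums, squares, and counts recovers the exact variance of all six values, while averaging the two batch variances badly underestimates it, because it ignores how far apart the batch means are.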
So we register a buffer called sums and a buffer called squares, and we just compute x.sum over dimensions 0, 2, and 3, and (x*x).sum over the same dimensions, which is the sum of squares. Then we take the lerp, the exponentially weighted moving average, of the sums and the squares. And then for the variance, we do squares divided by count minus the squared mean.
So it's that formula. That's detail number one that we have to be careful of. Detail number two is that the batch size could vary from mini-batch to mini-batch. So we should also register a buffer for the count and take an exponentially weighted moving average of the counts, that is, of the batch sizes.
So what do we need to divide by each time? The amount we need to divide by is the total number of elements in the mini-batch divided by the number of channels, which is basically height times width times batch size. So let's take an exponentially weighted moving average of the count, and that's what we'll divide by for both our means and our variances.
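Details one and two together might look something like this in NumPy, assuming NCHW inputs, a momentum of 0.1 as the weight on each new batch, and random made-up batches whose size changes each step. It keeps exponentially weighted moving averages of the per-channel sums, sums of squares, and element counts, and derives the means and variances from those.

```python
import numpy as np

def lerp(start, end, weight):
    # Same convention as torch.lerp: start + weight * (end - start).
    return start + weight * (end - start)

rng = np.random.default_rng(0)
nc, mom = 4, 0.1
sums = np.zeros((1, nc, 1, 1))    # running per-channel sums
sqrs = np.zeros((1, nc, 1, 1))    # running per-channel sums of squares
count = 0.0                       # running number of elements per channel

for bs in (2, 5, 3):              # the batch size changes from mini-batch to mini-batch
    x = rng.normal(loc=2.0, scale=1.0, size=(bs, nc, 6, 6))
    s = x.sum(axis=(0, 2, 3), keepdims=True)
    ss = (x * x).sum(axis=(0, 2, 3), keepdims=True)
    c = x.size / nc               # = batch size * height * width
    sums, sqrs, count = lerp(sums, s, mom), lerp(sqrs, ss, mom), lerp(count, c, mom)

means = sums / count
variances = sqrs / count - means * means
print(means.ravel(), variances.ravel())
```

Because the count is averaged with the same weights as the sums, sums / count comes out as a properly weighted mean even though the batch sizes differ.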
That's detail number two. Detail number three is that we need to do something called debiasing. We want to make sure that at every point, no observation is weighted too highly; we're going to look at this in more detail when we look at optimizers.
And the problem is that, with the normal way of doing moving averages, the very first point gets far too much weight, because it appears in the first moving average and the second and the third and the fourth. There's a really simple way to fix this: initialize both sums and squares to zeros, do a lerp in the usual way, and let's see what happens when we do this.
So let's say our first two values are 10 and then 20. Actually, for now we only need to look at the first value, 10. We initialize our mean to zero at the very start of training.
And then the value that comes in is 10, so we would expect the moving average to be 10. But our lerp formula says it's equal to our previous value, which is 0, times 0.9, plus our new value times 0.1: 0 + 1 = 1, which is 10 times too small.
That's very easy to correct for, because we know it's always going to be wrong by that amount. So we divide it by 0.1, and that fixes it. And then the second value has exactly the same problem: it's got too much zero in it. But this time it's going to be divided by... well, let's not call it 0.1.
Let's call it 1 minus 0.9, because when you work through the math, you'll see the second one is going to be divided by 1 minus 0.9 squared, and so forth. So this thing where we divide by that is called debiasing. It's going to appear again when we look at optimization.
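The worked example above, in Python: a moving average initialized to zero, updated with the values 10 and then 20, and debiased by dividing by 1 - 0.9**t at step t. The lerp helper is hand-written, following the torch.lerp convention, and the momentum of 0.1 is the weight on the new value.

```python
def lerp(start, end, weight):
    # Same convention as torch.lerp: start + weight * (end - start).
    return start + weight * (end - start)

mom = 0.1          # weight on each new value; the old average keeps 0.9
avg = 0.0          # the moving average starts at zero
raw, debiased = [], []
for t, value in enumerate([10.0, 20.0], start=1):
    avg = lerp(avg, value, mom)
    raw.append(avg)
    debiased.append(avg / (1 - (1 - mom) ** t))   # divide by 1 - 0.9**t

print(raw)       # roughly [1.0, 2.9]: biased toward the initial zero
print(debiased)  # roughly [10.0, 15.26]
```

The first debiased value comes out as 10 (up to floating-point rounding), as expected. The second is about 15.26, which is the average of 10 and 20 weighted 0.09 and 0.1, the weights the moving average actually gives them after two steps.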
So you can see what we do: we keep an exponentially weighted debiasing amount, where each step multiplies the previous 0.9 term by the momentum again. So it's initially 1 minus 0.9, then 1 minus 0.9 squared, then 1 minus 0.9 cubed, and so forth. Then we do what I just said: we divide by the debiasing amount.
And then there's just one more thing we do. Remember how I said you might get really unlucky and have a first mini-batch whose variance is really close to zero? We don't want that to destroy everything. So I just say: if you haven't seen more than a total of 20 items yet, clamp the variance to be no smaller than 0.01, just to avoid blowing everything out of the water.
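Putting the pieces together, here's a hypothetical NumPy sketch of a running batch norm forward pass. It is not the course notebook's code: the class name is made up, there's no autograd and no learnable scale and shift, and it assumes at least one training-mode call before any inference-mode call. It also skips the separate debiasing division: sums, squares, and count would all be divided by the same debiasing factor, which cancels in sums / count and squares / count, so leaving it out doesn't change this sketch's output.

```python
import numpy as np

class RunningBatchNormSketch:
    """Hypothetical NumPy sketch of running batch norm (no autograd, no learnable scale/shift)."""

    def __init__(self, nc, mom=0.1, eps=1e-5):
        self.mom, self.eps = mom, eps
        self.sums = np.zeros((1, nc, 1, 1))   # running per-channel sums
        self.sqrs = np.zeros((1, nc, 1, 1))   # running per-channel sums of squares
        self.count = 0.0                      # running elements per channel
        self.seen = 0                         # total number of items (images) seen so far

    def update_stats(self, x):
        bs, nc = x.shape[:2]
        s = x.sum(axis=(0, 2, 3), keepdims=True)
        ss = (x * x).sum(axis=(0, 2, 3), keepdims=True)
        c = x.size / nc
        # Exponentially weighted moving averages (lerp toward this batch's statistics).
        self.sums += self.mom * (s - self.sums)
        self.sqrs += self.mom * (ss - self.sqrs)
        self.count += self.mom * (c - self.count)
        self.seen += bs

    def __call__(self, x, training=True):
        if training:
            self.update_stats(x)
        assert self.count > 0, "call with training=True at least once first"
        # Normalize with the running statistics, at training time as well as inference time.
        means = self.sums / self.count
        variances = self.sqrs / self.count - means * means
        if self.seen < 20:
            # Guard against an unlucky early mini-batch with near-zero variance.
            variances = np.maximum(variances, 0.01)
        return (x - means) / np.sqrt(variances + self.eps)

rng = np.random.default_rng(0)
bn = RunningBatchNormSketch(nc=3)
for _ in range(50):
    batch = rng.normal(loc=5.0, scale=2.0, size=(2, 3, 8, 8))   # batch size of just 2
    out = bn(batch)
print(out.mean(), out.std())
```

Even with a batch size of 2, the output of the last call comes out with roughly zero mean and unit standard deviation, because the statistics come from the running averages rather than from the two images in the current batch.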
And then the last two lines are the same. So that's it, right? It's all pretty straightforward arithmetic. It's a very straightforward idea, but when we put it all together, it's shockingly effective. And so then we can try an interesting thought experiment, so here's another thing to try during the week.
What's the best accuracy you can get in a single epoch? So, say, run.fit(1). With this conv layer with running batch norm, a batch size of 32, and a linear schedule from 1 to 0.2, I got 97.5%. I only tried a couple of things, so this is definitely something I hope you can beat me at.
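The linear schedule can be thought of as a function of training progress. Here's a hypothetical helper, assuming progress runs from 0 at the start of training to 1 at the end, and assuming the schedule is applied to the learning rate.

```python
def linear_schedule(start, end):
    # Hypothetical helper: returns a function mapping training progress pos in [0, 1]
    # to a value interpolated linearly from `start` to `end`.
    def at(pos):
        return start + pos * (end - start)
    return at

lr_at = linear_schedule(1.0, 0.2)
print([round(lr_at(p), 2) for p in (0.0, 0.5, 1.0)])   # [1.0, 0.6, 0.2]
```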
But it's really good to create interesting little games to play. In research, we call them toy problems. Almost everything in research is basically toy problems: come up with toy problems and try to find good solutions to them. So another toy problem for this week is: what's the best accuracy you can get in one epoch, using whatever kind of normalization and whatever kind of architecture you like, as long as it only uses concepts we've used up to lesson 7?
So yeah, that's basically it. So what's the future of running batch norm? I mean, it's kind of early days. We haven't published this research yet. We haven't done all the kind of ablation studies and stuff we need to do yet. At this stage, though, I'm really excited about this.
Every time I've tried it on something, it's been working really well. The last time we had something in a lesson that we said was unpublished research we were excited about, it turned into ULMFiT, which is now a really widely used algorithm and was published at ACL.
So fingers crossed that this turns out to be something really terrific as well. But either way, you've kind of got to see the process, because literally building these notebooks was the process I used to create this algorithm. So you've seen the exact process that I used to build up this idea and do some initial testing of it.
So hopefully that's been fun for you, and see you next week. (audience applauds)