Lesson 10 (2019) - Looking inside the model
Chapters
0:00 Introduction
0:21 Don't worry
2:08 Where are we
2:49 Training loop
4:26 Faster Audio
5:45 Vision Topics
7:12 What is a callback
11:20 Creating a callback
13:31 lambda notation
17:23 partial function application
18:25 class as a callback
19:53 kwargs
24:13 change
27:22 dunder
29:34 editor tips
35:21 variance
40:15 covariance
44:59 softmax
52:44 learning rate finder
54:34 refactoring
00:00:00.000 |
Welcome to lesson 10, which I'd rather enthusiastically titled 'wrapping up our CNN', but looking at 00:00:10.020 |
how many things we want to cover, I've added a 'nearly' to the end, and I'm not actually sure we'll get through it all. 00:00:16.880 |
We'll probably have a few more things to cover next week as well. 00:00:22.680 |
I just wanted to remind you after hearing from a few folks during the week who are very 00:00:28.200 |
sad that they're not quite keeping up with everything. 00:00:34.480 |
As I mentioned in lesson one, I'm trying to give you enough here to keep you busy until the next course. 00:00:44.780 |
So you can dive into the bits you're interested in and go back and look over stuff and yeah, 00:00:52.200 |
don't feel like you have to understand everything within within a week of first hearing it. 00:00:57.580 |
But also if you're not putting in the time during the homework or you didn't put in the 00:01:02.440 |
time during the homework in the last part, you know, expect to have to go back and recover 00:01:07.160 |
things particularly because a lot of the stuff we covered in part one, I'm kind of assuming 00:01:11.880 |
that you're deeply comfortable with at this point. 00:01:16.100 |
Not because you're stupid if you're not, but just because it gives you the opportunity 00:01:20.000 |
to go back and re-study it and practice and experiment until you are deeply comfortable. 00:01:26.320 |
So yeah, if you're finding it whizzing along at a pace, that is because it is whizzing along. 00:01:32.880 |
Also it's covering a lot of more software engineering kind of stuff, which for the people 00:01:38.720 |
who are practicing software engineers, you'll be thinking this is all pretty straightforward. 00:01:42.640 |
And for those of you that are not, you'll be thinking, wow, there's a lot here. 00:01:47.880 |
Part of that is because I think data scientists need to be good software engineers. 00:01:52.080 |
So I'm trying to show you some of these things, but, you know, it's stuff which people can 00:01:58.420 |
And so hopefully this is the start of a long process, for those of you who haven't done software 00:02:02.520 |
engineering before, of becoming better software engineers. 00:02:09.200 |
So to remind you, we're trying to recreate fast AI and much of PyTorch from these foundations. 00:02:24.600 |
And today you'll actually see some bits - well, in fact, you've already seen some bits - that 00:02:29.800 |
I think the next version of fast AI will have: this new callback system, which I think is much better. 00:02:34.480 |
And today we're going to be showing you some new previously unpublished research, which 00:02:39.400 |
will be finding its way into fast AI, and maybe other libraries as well. 00:02:45.680 |
So we're going to try and stick to, and we will stick to, using nothing but these foundations. 00:02:51.100 |
And we're working through developing a modern CNN model, and we've got to the point where 00:02:55.180 |
we've done our training loop at this point, and we've got a nice flexible training loop. 00:03:01.960 |
So from here, the rest of it, when I say we're going to finish out a modern CNN model, it's 00:03:09.320 |
not just going to be some basic getting by model, but we're actually going to endeavor 00:03:13.660 |
to get something that is approximately state-of-the-art on ImageNet in the next week or two. 00:03:22.840 |
And in our testing at this point, we're feeling pretty good about showing you some stuff that 00:03:29.080 |
maybe hasn't been seen before on ImageNet results. 00:03:32.960 |
So that's where we're going to try and head as a group. 00:03:36.200 |
And so these are some of the things that we're going to be covering to get there. 00:03:43.700 |
One of the things you might not have seen before is this section called optimization. 00:03:48.480 |
The reason for this is that this was going to be some of the unpublished research we 00:03:51.520 |
were going to show you, which is a new optimization algorithm that we've been developing. 00:03:56.960 |
The framework is still going to be new, but actually the particular approach to using 00:04:01.260 |
it was published by Google two days ago, so we've kind of been scooped there. 00:04:07.000 |
So this is a cool paper, really great, and they introduce a new optimization algorithm 00:04:11.640 |
called LAMB, which we'll be showing you how to implement very easily. 00:04:18.600 |
And if you're wondering how we're able to do that so fast, it's because we've kind of 00:04:21.520 |
been working on the same thing ourselves for a few weeks now. 00:04:26.880 |
So then from next week, we'll start also developing a completely new fast.ai module called fastai.audio. 00:04:36.000 |
So you'll be seeing how to actually create modules and how to write Jupyter documentation 00:04:41.240 |
And we're going to be learning about audio, such as complex numbers and Fourier transforms, 00:04:45.480 |
which if you're like me at this point, you're going, oh, what, no, because I managed to 00:04:50.840 |
spend my life avoiding complex numbers and Fourier transforms on the whole. 00:04:58.680 |
It's actually not at all bad, or at least the bits we need to learn about are not at all bad. 00:05:03.000 |
And you'll totally get it even if you've never ever touched these before. 00:05:06.920 |
We'll be learning about audio formats and spectrograms, doing data augmentation on 00:05:12.560 |
things that aren't images, and some particular kinds of loss functions and architectures. 00:05:19.800 |
And as much as anything, it'll just be a great kind of exercise in: OK, I've got some different kind of data. 00:05:26.100 |
How do I build up all the bits I need to make it work? 00:05:31.000 |
Then we'll be looking at neural translation as a way to learn about sequence to sequence 00:05:34.620 |
with attention models, and then we'll be going deeper and deeper into attention models, looking 00:05:38.460 |
at the Transformer, and its even more fantastic descendant, Transformer-XL. 00:05:46.840 |
And then we'll wrap up our Python adventures with a deep dive into some really interesting 00:05:51.640 |
vision topics, which is going to require building some bigger models. 00:05:56.520 |
So we'll talk about how to build your own deep learning box, how to run big experiments 00:06:00.520 |
on AWS with a new library we've developed called fastec2. 00:06:06.560 |
And then we're going to see exactly what happened last course when we did that U-Net super resolution 00:06:11.240 |
image generation, what are some of the pieces there. 00:06:15.100 |
And we've actually got some really exciting new results to show you, which have been done 00:06:18.920 |
in collaboration with some really cool partners. 00:06:21.140 |
So I'm looking forward to showing you that. To give you a hint: 00:06:25.240 |
Generative video models is what we're going to be looking at. 00:06:28.720 |
And then we'll be looking at some interesting different applications: DeViSe, CycleGAN, and more. 00:06:40.960 |
So the Swift lessons are coming together nicely, really excited about them, and we'll be covering 00:06:48.400 |
as much of the same territory as we can, but obviously it'll be in Swift and it'll be in 00:06:52.060 |
only two lessons, so it won't be everything, but we'll try to give you enough of a taste 00:06:57.320 |
that you'll feel like you understand why Swift is important and how to get started building with it. 00:07:05.320 |
And maybe building out the whole thing in Swift will take the next 12 months. 00:07:13.920 |
So we're going to start today on notebook 05a_foundations. 00:07:19.880 |
And what we're going to do is we're going to go back over some of the software engineering 00:07:23.680 |
and math basics that we were relying on last week, going into a little bit more detail. 00:07:28.760 |
Specifically, we'll be looking at callbacks and variance and a couple of other Python concepts. 00:07:38.420 |
If you're familiar with those things, feel free to skip ahead if you're watching the video. 00:07:44.440 |
Callbacks, as I'm sure you've seen, are super important for fast AI, and in general they're 00:07:50.520 |
a really useful technique for software engineering. 00:07:55.040 |
And great for researchers because they allow you to build things that you can quickly adjust 00:07:59.800 |
and add things in and pull them out again, so really great for research as well. 00:08:10.120 |
So here's a function called f, which prints out hi. 00:08:14.760 |
And I'm going to create a button, and I'm going to create this button using IPY widgets, 00:08:21.400 |
which is a framework for creating GUI widgets in Python. 00:08:26.160 |
So if I run, if I say w, then it shows me a button which says 'Click me', and I can click it, but nothing happens. 00:12:39.160 |
Well, what I need to do is I need to pass a function to the IPY widgets framework to 00:08:47.120 |
say please run this function when you click on this button. 00:08:52.360 |
So the ipywidgets docs say that there's an on_click method which can register a function to be called when the button is clicked. 00:09:01.080 |
So let's try running that method, passing it f, my function. 00:09:06.560 |
OK, so now nothing happened, didn't run anything, but now if I click on here-- oh, hi, hi. 00:09:16.640 |
So what happened is I told W that when a click occurs, you should call back to my f function 00:09:30.080 |
So anybody who's done GUI programming will be extremely comfortable with this idea, and 00:09:34.280 |
if you haven't, this will be kind of mind bending. 00:09:55.120 |
We're passing the function itself to this method, and it says please call back to me 00:10:03.120 |
when something happens, and in this case, it's when I click. 00:10:09.400 |
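As a rough sketch of what that looks like in code (assuming ipywidgets is available in the notebook; the button label is just illustrative):

```python
import ipywidgets as widgets

def f(o):
    # ipywidgets passes the button instance to the handler; we ignore it here
    print('hi')

w = widgets.Button(description='Click me')
w.on_click(f)   # register f as the callback: "call back to f when a click happens"
w               # displaying the widget in Jupyter renders the button
```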
And these kinds of functions - these kinds of callbacks that are used in a GUI 00:10:16.520 |
framework in particular, when some event happens - are often called events. 00:10:19.840 |
So if you've heard of events, they're a kind of callback, and callbacks are a kind of function pointer. 00:10:26.800 |
I mean, they can be much more general than that, as you'll see, but it's basically a 00:10:30.520 |
way of passing in something to say call back to this when something happens. 00:10:34.800 |
Now, by the way, these widgets are really worth looking at if you're interested in building 00:10:44.160 |
Here's a great example from the Plotly documentation of the kinds of things you can create with 00:10:50.960 |
widgets, and it's not just for creating applications for others to use, but if you want to experiment 00:10:56.000 |
with different types of function or hyperparameters or explore some data you've collected, widgets 00:11:04.720 |
are a great way to do that, and as you can see, they're very, very easy to use. 00:11:11.600 |
In part one, you saw the image labeling stuff that was built with widgets like this. 00:11:21.040 |
So that's how you can use somebody else's callback. 00:11:26.680 |
So let's create a callback, and the event that it's going to call back on is after a calculation is complete. 00:11:34.840 |
So let's create a function called slow_calculation, and it's going to do five calculations. 00:11:41.120 |
It's going to add i squared to a result, and then it's going to take a second to do each one. 00:11:47.440 |
So this is kind of something like an epoch of deep learning. 00:11:53.360 |
So if we call slow_calculation, then it's going to take five seconds to calculate the result. 00:12:02.200 |
So I'd really like to know how's it going, get some progress. 00:12:06.740 |
So we could take that and we could add something that you pass in a callback, and we just add 00:12:14.080 |
one line of code that says if there's a callback, then call it and pass in the epoch number. 00:12:21.880 |
So then we could create a function called show_progress that prints out 'awesome, 00:12:25.560 |
we've finished epoch' and the epoch number. And look, it takes a parameter and we're passing a parameter, 00:12:31.840 |
so therefore we could now call slow_calculation and pass in show_progress, and it will call it after every epoch. 00:12:43.680 |
So there's our starting point for our callback. 00:12:46.640 |
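Here is a minimal sketch of that starting point, close to the notebook version (the message text is illustrative):

```python
import time

def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        res += i*i
        time.sleep(1)      # pretend each "epoch" takes a second
        if cb: cb(i)       # the one extra line: call back with the epoch number
    return res

def show_progress(epoch):
    print(f"Awesome! We've finished epoch {epoch}!")

slow_calculation(show_progress)   # prints a message after each of the 5 "epochs"
```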
Now, what will tend to happen, you'll notice with stuff that we do in fast.ai, we'll start 00:12:52.840 |
somewhere like this that's, for many of you, is trivially easy, and at some point during 00:12:58.440 |
the next hour or two, you might reach a point where you're feeling totally lost. 00:13:04.200 |
And the trick is to go back, if you're watching the video, to the point where it was trivially 00:13:08.520 |
easy and figure out the bit where you suddenly noticed you were totally lost, and find the 00:13:12.200 |
bit in the middle where you kind of missed a bit, because we're going to just keep building 00:13:15.760 |
up from trivially easy stuff, just like we did with that matrix multiplication, right? 00:13:20.440 |
So we're going to gradually build up from here and look at more and more interesting 00:13:23.200 |
callbacks, but we're starting with this wonderfully short and simple line of code. 00:13:32.560 |
So rather than defining a function just for the purpose of using it once, we can actually 00:13:38.560 |
define the function at the point we use it using lambda notation. 00:13:43.000 |
So lambda notation is just another way of creating a function. 00:13:45.960 |
So rather than saying def, we say lambda, and then rather than putting in parentheses 00:13:50.640 |
the arguments, we put them before a colon, and then we list the thing you want to do. 00:13:59.280 |
It's just a convenience for times where you want to define the callback at the same time 00:14:05.080 |
that you use it, can make your code a little bit more concise. 00:14:10.540 |
What if you wanted to have something where you could define what exclamation to use in the message? 00:14:20.060 |
So now show_progress takes an exclamation as well as an epoch - but we can't pass this show_progress to slow_calculation. 00:14:31.680 |
It tries to call back, and it calls, remember CB is now show progress, so it's passing show 00:14:36.560 |
progress and it's passing epoch as exclamation, and then epoch is missing. 00:14:42.480 |
We've called a function with two arguments with only one. 00:14:45.760 |
So we have to convert this into a function with only one argument. 00:14:49.960 |
So lambda o is a function with only one argument, and this function calls show_progress with a particular exclamation. 00:14:59.720 |
So we've converted something with two arguments into something with one argument. 00:15:04.880 |
We might want to make it really easy to allow people to create different progress indicators 00:15:10.840 |
So we could create a function called make show progress that returns that lambda. 00:15:17.240 |
So now we could say make show progress, so we could do that here. 00:15:37.880 |
This is a little bit awkward, so generally you might see it done like this instead. 00:15:44.220 |
We define the function inside it, but this is basically just the same as our lambda, 00:15:53.000 |
So this is kind of interesting, because you might think of defining a function as being 00:15:57.680 |
like a declarative thing, that as soon as you define it, that now it's part of the thing 00:16:02.200 |
that's compiled - if you know C++, that's how it works there - but in Python that's not quite what happens. 00:16:08.020 |
When you define a function, you're actually saying basically the same as this, which is 00:16:13.800 |
there's a variable with this name, which is a function. 00:16:19.640 |
And that's how come we can actually take something that's passed to this function and use it inside the function we define. 00:16:27.480 |
So this is actually, every time we call make show progress, it's going to create a new 00:16:32.160 |
function, underscore inner internally, with a different exclamation. 00:16:44.300 |
So this thing where you create a function that actually stores some information from 00:16:48.200 |
the external context, and like it can be different every time, that's called a closure. 00:16:52.240 |
So it's a concept you'll come across a lot, particularly if you're a JavaScript programmer. 00:16:57.300 |
So we could say f2 = make_show_progress("Terrific"), and then pass f2 to slow_calculation. 00:17:13.640 |
So it actually remembers what exclamation you passed it. 00:17:24.760 |
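A sketch of both versions - the lambda and the closure (the names follow the lesson; the exclamations are made up):

```python
def show_progress(exclamation, epoch):
    print(f"{exclamation}! We've finished epoch {epoch}!")

# A lambda turns the two-argument function into a one-argument callback:
slow_calculation(lambda o: show_progress("OK I guess", o))

# Or wrap that up so callers can choose their own exclamation -- a closure:
def make_show_progress(exclamation):
    def _inner(epoch):
        print(f"{exclamation}! We've finished epoch {epoch}!")
    return _inner

f2 = make_show_progress("Terrific")
slow_calculation(f2)   # _inner "remembers" the exclamation it was created with
```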
Because it's so often that you want to take a function that takes two parameters and turn 00:17:30.360 |
it into a function that takes one parameter, Python and most languages have a way to do 00:17:35.400 |
that, which is called partial function application. 00:17:39.160 |
So the standard library functools has this thing called partial. 00:17:42.480 |
So if you call partial and you pass it a function, and then you pass in some arguments 00:17:48.120 |
for that function, it returns a new function in which that parameter is always given. 00:17:57.720 |
So we could run it like that, or we could say f2 equals this partial function application. 00:18:06.400 |
And so if I say f2 and hit shift-tab, then you can see this is now a function that just takes one parameter. 00:18:12.920 |
It just takes epoch because show progress took two parameters. 00:18:18.960 |
So this now takes one parameter, which is what we need. 00:18:21.760 |
So that's why we could pass that as our callback. 00:18:26.760 |
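For example, roughly:

```python
from functools import partial

# partial fixes the first positional argument, returning a one-argument function
f2 = partial(show_progress, "OK I guess")
slow_calculation(f2)
```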
So we've seen a lot of those techniques already last week. 00:18:33.640 |
Most of what we saw last week, though, did not use a function as a callback, but used a class. 00:18:40.440 |
So we could do exactly the same thing here - pretty much any place you can use a closure, you can also use a class. 00:18:47.320 |
Instead of storing some state away inside the closure, we can store our state, in this 00:18:52.520 |
case the exclamation, inside self, passing it into init. 00:18:57.720 |
So here's exactly the same thing as we saw before, but as a class. 00:19:02.960 |
Dunder call - __call__ - is a special magic name which will be called if you take an object, so in this 00:19:12.600 |
case a progress showing callback object, and call it with parentheses. 00:19:17.160 |
So if I go cb high, you see I'm taking that object and I'm treating it as if it's a function. 00:19:27.600 |
If you've used other languages like in C++, this is called a functor. 00:19:34.400 |
More generally it's called a callable in Python, so it's kind of something that a lot of languages 00:19:42.680 |
All right, so now we can use that as a callback, just like before. 00:19:50.920 |
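A sketch of the class version (same behaviour as the closure above):

```python
class ProgressShowingCallback():
    def __init__(self, exclamation="Awesome"):
        self.exclamation = exclamation        # state lives on self instead of in a closure
    def __call__(self, epoch):                # makes the object callable like a function
        print(f"{self.exclamation}! We've finished epoch {epoch}!")

cb = ProgressShowingCallback("Just super")
slow_calculation(cb)   # the object is used exactly like the function callbacks above
```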
All right, the next thing to look at, for our callback, is we're going to use star args 00:19:59.640 |
and star star kwargs - *args and **kwargs. 00:20:03.920 |
For those of you that don't know what these mean, let's create a function that takes star 00:20:08.200 |
args and star star kwargs and prints out args and kwargs. 00:20:15.320 |
So if I call that function, I could pass it 3, 'a', thing1="hello", and you'll see that 00:20:22.080 |
all the things that are passed as positional arguments end up in a tuple called args, and 00:20:28.280 |
all the things passed as keyword arguments end up as a dictionary called kwargs. 00:20:38.000 |
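In code, that's simply:

```python
def f(*args, **kwargs):
    print(f"args: {args}; kwargs: {kwargs}")

f(3, 'a', thing1="hello")
# args: (3, 'a'); kwargs: {'thing1': 'hello'}
```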
And so PyTorch uses that, for example, when you create an nn.sequential, it takes what 00:20:44.360 |
you pass in as a star args, right, you just pass them directly and it turns it into a tuple. 00:20:53.080 |
There's a few reasons we use it, but one of the common ways to use it is if you kind of 00:20:58.200 |
want to wrap some other class or object, then you can take a bunch of stuff as **kwargs 00:21:04.760 |
and pass it off to some other function or object. 00:21:10.480 |
We're getting better at this, and we're removing a lot of the usages, but in the early days 00:21:13.980 |
of fast AI version 1, we actually were overusing kwargs. 00:21:18.240 |
So quite often, we would kind of -- there would be a lot of stuff that wasn't obviously 00:21:24.640 |
in the parameter list of a function that ended up in kwargs, and then we would pass it down 00:21:29.600 |
to, I don't know, the PyTorch data loader initializer or something. 00:21:33.880 |
And so we've been gradually removing those usages, because, like, it's mainly most helpful 00:21:39.440 |
for kind of quick and dirty throwing things together. 00:21:45.120 |
In R, they actually use an ellipsis for the same thing. 00:21:49.040 |
Quite often, it's hard to see what's going on. 00:21:51.000 |
You might have noticed in Matplotlib, a lot of times the thing you're trying to pass to 00:21:56.080 |
Matplotlib isn't there in the shift tab when you hit shift tab. 00:22:01.120 |
So there are some downsides to using it, but there are some places you really want to use 00:22:08.480 |
Let's take rewrite slow calculation, but this time we're going to allow the user to create 00:22:13.680 |
a callback that is called before the calculation occurs and after the calculation that occurs. 00:22:21.840 |
And the after calculation one's a bit tricky, because it's going to take two parameters 00:22:25.960 |
It's going to take both the epoch number and also what have we calculated so far. 00:22:32.920 |
So we can't just call CB parentheses I. We actually now have to assume that it's got 00:22:42.020 |
So here is, for example, a print step callback, which before calculation just says I'm about 00:22:49.240 |
to start and after calculation it says I'm done and there it's running. 00:22:55.800 |
So in this case, this callback didn't actually care about the epoch number or about the value. 00:23:02.120 |
And so, it just has *args, **kwargs in both places. 00:23:07.880 |
It doesn't have to worry about exactly what's being passed in, because it's not using them. 00:23:11.880 |
So this is quite a good kind of use of this, is to basically create a function that's going 00:23:19.280 |
to be used somewhere else and you don't care about one or more of the parameters or you 00:23:25.360 |
So in this case we don't get an error - but if we remove this, which looks like 00:23:32.400 |
we should be able to do because we don't use anything, here's the problem: 00:23:37.600 |
it tried to call before_calc(i), and before_calc doesn't take an i. 00:23:42.960 |
So if you put in both positional and keyword arguments, it'll always work everywhere. 00:23:54.400 |
So let's actually use epoch and value to print out those details. 00:24:00.180 |
So now you can see there it is printing them out. 00:24:02.000 |
And in this case, I've put **kwargs at the end, because maybe in the future, there'll 00:24:06.560 |
be some other things that are passed in and we want to make sure this doesn't break. 00:24:14.680 |
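A sketch of that, close to the notebook version (the printed text is illustrative):

```python
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        if cb: cb.before_calc(i)
        res += i*i
        time.sleep(1)
        if cb: cb.after_calc(i, val=res)
    return res

class PrintStatusCallback():
    def before_calc(self, epoch, **kwargs):       # **kwargs: ignore anything extra passed later
        print(f"About to start epoch {epoch}")
    def after_calc(self, epoch, val, **kwargs):
        print(f"After epoch {epoch}: {val}")

slow_calculation(PrintStatusCallback())
```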
The next thing we might want to do with callbacks is to actually change something. 00:24:21.080 |
So a couple of things that we did last week: one was we wanted to be able to cancel out of the loop early. 00:24:29.000 |
The other thing we might want to do is actually change the value of something. 00:24:38.280 |
And also the other thing we might want to do is say, well, what if you don't want to define every callback? 00:24:46.020 |
So we can actually check whether a callback is defined and only call it if it is. 00:24:52.120 |
And we could actually check the return value and then do something based on the return 00:24:57.480 |
So here's something which will cancel out of our loop if the value that's been calculated so far gets too big. 00:25:13.480 |
What if you actually want to change the way the calculation is being done? 00:25:18.960 |
So we could even change the way the calculation is being done by taking our calculation function and turning it into a class. 00:25:29.560 |
And so now the value that it's calculated is an attribute of the class. 00:25:36.260 |
And so now we could actually have a callback that reaches back inside the calculator and changes its result. 00:25:46.600 |
So this is going to double the result if it's less than three. 00:25:50.160 |
So if we run this, right, we now actually have to call this because it's a class, but 00:25:59.480 |
And so we're also taking advantage of this in the callbacks that we're using. 00:26:03.520 |
So this is kind of the ultimately flexible callback system. 00:26:10.040 |
And so you'll see in this case, we actually have to pass the calculator object to the callback. 00:26:21.340 |
So the way we do that is we've defined a callback method here, which checks to see whether it's 00:26:30.280 |
defined and if it is, it grabs it and then it calls it passing in the calculator object 00:26:40.520 |
And so what we actually did last week is we didn't call this callback. 00:26:45.720 |
We called this dunder call, which means we were able to do it like this, okay? 00:27:02.080 |
I mean, we had so many callbacks being called that I felt the extra noise of giving it a 00:27:09.880 |
On the other hand, you might feel that calling a callback isn't something you expect dunder 00:27:13.800 |
call to do, in which case you can do it that way. 00:27:17.480 |
So there's pros and cons, neither is right or wrong. 00:27:36.440 |
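A sketch of that calculator-as-a-class idea, roughly following the notebook (shown here with a named callback method; the dunder-call version simply renames callback to __call__):

```python
class SlowCalculator():
    def __init__(self, cb=None): self.cb, self.res = cb, 0
    def callback(self, cb_name, *args):
        if not self.cb: return                     # callbacks are entirely optional
        cb = getattr(self.cb, cb_name, None)       # only call it if it's defined
        if cb: return cb(self, *args)              # pass the calculator itself to the callback
    def calc(self):
        for i in range(5):
            self.callback('before_calc', i)
            self.res += i*i
            time.sleep(1)
            if self.callback('after_calc', i):
                print("stopping early"); break

class ModifyingCallback():
    def after_calc(self, calc, epoch):
        print(f"After {epoch}: {calc.res}")
        if calc.res > 10: return True              # returning True cancels the loop
        if calc.res < 3:  calc.res = calc.res * 2  # reach back in and change the state

calculator = SlowCalculator(ModifyingCallback())
calculator.calc()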
'Dunder' means double underscore - and in Python, a dunder thingy is special somehow. 00:27:42.480 |
Most languages kind of let you define special behaviors. 00:27:46.520 |
For example, in C++, there's an operator keyword where if you define a function that says operator 00:27:52.920 |
something like plus, you're defining the plus operator. 00:27:58.400 |
So most languages tend to have special magic names you can give things that make something 00:28:04.480 |
a constructor or a destructor or an operator. 00:28:08.000 |
I like in Python that all of the magic names actually look magic. 00:28:13.440 |
They all look like that, which I think is actually a really good way to do it. 00:28:19.080 |
So the Python docs have a data model reference where they tell you about all these special 00:28:30.000 |
And you can go through and you can see what are all the special things you can get your 00:28:34.920 |
Like you can override how it behaves with less than or equal to or et cetera, et cetera. 00:28:41.740 |
There's a particular list I suggest you know, and this is the list. 00:28:45.440 |
So you can go to those docs and see what these things do because we use all of these in this 00:28:57.640 |
You pass in some number that you're going to add up. 00:29:00.320 |
And then when you add two things together, it will give you the result of adding them 00:29:06.960 |
And that is called dunder add because that's what happens when you see plus. 00:29:11.480 |
This is called dunder init because this is what happens when an object gets constructed. 00:29:15.720 |
And this is called dunder repr because this is what gets called when you print it out. 00:29:19.320 |
So now I can create a one adder and a two adder, and I can plus them together and print the result. 00:29:26.120 |
So that's kind of an example of how these special dunder methods work. 00:29:42.920 |
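A minimal sketch of that idea:

```python
class Adder():
    def __init__(self, o): self.o = o                  # __init__: runs when constructed
    def __add__(self, b): return Adder(self.o + b.o)   # __add__: what "+" means for Adders
    def __repr__(self): return str(self.o)             # __repr__: what printing shows

Adder(1) + Adder(2)   # displays 3
```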
There's another bit of code stuff that I wanted to show you, which you'll need to be doing 00:29:47.280 |
a lot of, which is you need to be really good at browsing source code. 00:29:53.120 |
If you're going to be contributing stuff to fast AI or to the fast AI for Swift for TensorFlow 00:30:01.640 |
or just building your own more complex projects, you need to be able to jump around source 00:30:08.560 |
Or even just to find out how PyTorch does something, if you're doing some research, 00:30:12.120 |
you need to really understand what's going on under the hood. 00:30:16.920 |
This is a list of things you should know how to do in your editor of choice. 00:30:22.080 |
Any editor that can't do all of these things is worth replacing with one that can. 00:30:26.900 |
Most editors can do these things, emacs can, visual studio code can, sublime can, and the 00:30:35.080 |
editor I use most of the time, vim, can as well. 00:30:42.160 |
On the forums there are already some topics saying how to do these things in other editors. 00:30:47.440 |
If you don't find one that seems any good, feel free to create your own topic if you've 00:30:52.000 |
got some tips about how to do these things or other useful things in your editor of choice. 00:30:56.400 |
I'm going to show you in vim for no particular reason, just because I use vim. 00:31:08.640 |
One of the things I like about vim is I can use it in a terminal, which I find super helpful 00:31:12.860 |
because I'm working on remote machines all the time and I like to be at least as productive 00:31:24.760 |
The first thing you should be able to do is to jump to a symbol. 00:31:28.320 |
A symbol would be like a class or a function or something like that. 00:31:33.720 |
For example, I might want to be able to jump straight to the definition of create_cnn, but 00:31:45.960 |
I can't quite remember the exact name of the function. 00:31:49.680 |
I would go :tag create_ - I'm pretty sure it's create underscore something - and then 00:31:55.360 |
I'd press tab a few times and it would loop through; there it is, create_cnn, and I hit enter. 00:32:02.480 |
That's the first thing that your editor should do, is it should make it easy to jump to a 00:32:06.680 |
tag even if you can't remember exactly what it is. 00:32:09.720 |
The second thing it should do is that you should be able to click on something like 00:32:13.880 |
CNN learner and hit a button, which in vim's case is control right square bracket, and 00:32:19.320 |
it should take you to the definition of that thing. 00:32:22.060 |
Okay, so in this cnn_learner - what's this thing called a DataBunch? Control right square bracket, and it jumps to its definition. 00:32:30.320 |
You'll also see that my vim is folding things, classes, and functions to make it easier for 00:32:40.840 |
In some editors, this is called outlining, in some it's called folding, most editors 00:32:48.600 |
Then there should be a way to go back to where you were before; in vim that's control T. 00:32:56.280 |
So here's my CNN learner, here's my createCNN, and so you can see in this way it makes it 00:33:02.360 |
nice and easy to kind of jump around a little bit. 00:33:09.200 |
Something I find super helpful is to also be able to jump into the source code of libraries 00:33:14.520 |
So for example, here's kaiming normal - I've got my vim configured so if I hit control 00:33:20.040 |
right square bracket on that, it takes me to the definition of kaiming normal in the PyTorch source. 00:33:26.720 |
And I find docstrings kind of annoying, so I have mine folded up by default, but I can open them if I want. 00:33:32.640 |
If you use vim, the way to do that is to add additional tags for any packages that you want to be able to jump into. 00:33:45.120 |
I'm sure most editors will do something pretty similar. 00:33:48.600 |
Now that I've seen how kaiming normal works, I can use the same control T to jump back to where I was. 00:33:59.800 |
Then the only other thing that's particularly important to know how to do is to just do a search across the whole code base. 00:34:05.360 |
So let's say I wanted to find all the places that I've used Lambda, since we talked about 00:34:10.600 |
lambda today. I have a particular thing I use called ack; I can say Ack lambda, and here 00:34:19.360 |
is a list of all of the places I've used lambda, and I can click on one and it will jump straight there. 00:34:30.920 |
Again most editors should do something like that for you. 00:34:33.860 |
So I find with that basic set of stuff, you should be able to get around pretty well. 00:34:38.380 |
If you're a professional software engineer, I know you know all this. 00:34:41.980 |
If you're not, hopefully you're feeling pretty excited right now to discover that editors 00:34:47.440 |
And so sometimes people will jump on our GitHub and say, I don't know how to find out what 00:34:54.000 |
a function is that you're calling because you don't list all your imports at the top 00:34:57.960 |
of the screen, but this is a great place where you should be using your editor to tell you. 00:35:03.840 |
And in fact, one place that GUI editors can be pretty good is often if you actually just 00:35:08.560 |
point at something, they will pop up something saying exactly where that symbol is coming from. 00:35:14.080 |
I don't have that set up in vim, so I just have to hit the right square bracket to see its definition. 00:35:21.280 |
Okay, so that's some tips about stuff that you should be able to do when you're browsing 00:35:27.680 |
source code, and if you don't know how to do it yet, please Google or look at the forums 00:35:35.720 |
Something else we were looking at a lot last week and you need to know pretty well is variance. 00:35:40.880 |
So here's just a quick refresher on what variance is, for those of you who haven't studied it for a while. 00:35:48.720 |
Variance is the average of how far away each data point is from the mean. 00:35:53.140 |
So here's some data, right, and here's the mean of that data. 00:35:57.880 |
And so if we just take the average of each data point minus the mean - t, the data points, minus the mean - the positives and negatives cancel out and we get zero, 00:36:11.640 |
because the mean is defined as the thing which is in the middle, right? 00:36:17.800 |
So we need to do something else that doesn't let the positives and negatives cancel out. 00:36:25.680 |
One is by squaring each thing before we take the mean, like so. 00:36:32.500 |
The other is taking the absolute value of each thing. 00:36:36.600 |
So turning all the negatives into positives before we take the mean. 00:36:40.300 |
So they're both common fixes for this problem. 00:36:43.780 |
You can see though the first is now on a totally different scale, right? 00:36:47.240 |
The numbers were like 1, 2, 4, 18, and this is 47. 00:36:52.920 |
So after we've squared, we then take the square root at the end. 00:36:57.000 |
So here are two numbers that represent how far things are away from the mean. 00:37:04.480 |
If everything's pretty close to similar to each other, those two numbers will be small. 00:37:08.960 |
If they're wildly different to each other, those two numbers will be big. 00:37:14.240 |
This one here is called the standard deviation. 00:37:17.880 |
And it's defined as the square root of this one here which is called the variance. 00:37:23.000 |
And this one here is called the mean absolute deviation. 00:37:27.200 |
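In code, using the example numbers mentioned in this section (1, 2, 4, 18), that's roughly:

```python
import torch

t = torch.tensor([1., 2., 4., 18.])
m = t.mean()                      # 6.25

var = (t - m).pow(2).mean()       # variance: mean of squared distances from the mean (~47.19)
std = var.sqrt()                  # standard deviation (~6.87)
mad = (t - m).abs().mean()        # mean absolute deviation (~5.88)
```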
You could replace the mean in this with various other things, like the median, for example. 00:37:42.400 |
So in the case of the one where we took a square in the middle of it, this number is 00:37:47.120 |
higher because the square takes that 18 and makes it much bigger. 00:37:51.920 |
So in other words, standard deviation is more sensitive to outliers than mean absolute deviation. 00:37:58.880 |
So for that reason, the mean absolute deviation is very often the thing you want to be using 00:38:07.320 |
because in machine learning, outliers are more often a problem than a help. 00:38:13.840 |
But mathematicians and statisticians tend to work with standard deviation rather than 00:38:18.200 |
mean absolute deviation because it makes their math proofs easier and that's the only reason. 00:38:22.800 |
They'll tell you otherwise, but that's the only reason. 00:38:25.440 |
So the mean absolute deviation is really underused and it actually is a really great measure 00:38:32.720 |
to use and you should definitely get used to it. 00:38:40.360 |
There's a lot of places where I kind of notice that replacing things involving squares with 00:38:45.360 |
things involving absolute values, the absolute value things just often work better. 00:38:49.920 |
It's a good tip to remember, because there's this kind of long-held assumption that 00:38:54.920 |
we have to use a squared thing everywhere, but it actually often doesn't work as well. 00:39:15.920 |
The variance can also be written as the mean of the squares minus the square of the mean, E[X^2] - (E[X])^2; written in math it looks like this, and it's another way of writing the variance. 00:39:21.960 |
It's important because it's super handy, and it's super handy because with the first version 00:39:28.280 |
we have to go through the whole data set twice: once 00:39:30.680 |
to calculate the mean of the data, and then a second time to get the squares of the differences. 00:39:39.480 |
This is really nice because in this case, we only have to keep track of two numbers, 00:39:45.760 |
the sum of the squares of the data and the sum of the data, and as you'll see shortly, this kind 00:39:53.800 |
of way of doing things is generally therefore just easier to work with. 00:39:59.360 |
So even though this is kind of the definition of the variance that makes intuitive sense, 00:40:08.280 |
this is the definition of variance that you normally want to implement. 00:40:16.360 |
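Continuing the tensor t from the sketch above, a quick check of that identity:

```python
# Same variance as before, but computed from two running sums:
# the sum of the squares and the sum of the data
var2 = (t * t).mean() - t.mean().pow(2)   # equals (t - t.mean()).pow(2).mean()
```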
The other thing we see quite a bit is covariance and correlation. 00:40:21.480 |
So if we take our same data set, let's now create a second data set which is double t 00:40:31.800 |
times a little bit of random noise, so here's that plotted. 00:40:39.640 |
Let's now look at the difference between each item of t and its mean and multiply it by 00:40:47.520 |
each item of u and its mean, so there's those values and let's look at the mean of that. 00:41:00.440 |
So it's the average of how far away each x value is from the mean of the x's, 00:41:05.920 |
multiplied by how far away the corresponding y value is from the mean of the y's. 00:41:16.940 |
Let's compare this number to the same number calculated with this data set, where this 00:41:22.320 |
data set is just some random numbers compared to v. 00:41:28.840 |
And let's now calculate the exact same product, the exact same mean. 00:41:38.880 |
So if you think about it, if these are kind of all lined up nicely, then every time it's 00:41:45.880 |
higher than the average on the x-axis, it's also higher than the average on the y-axis. 00:41:52.520 |
So you have two big positive numbers and vice versa, two big negative numbers. 00:41:58.180 |
So in either case, you end up with, when you multiply them together, a big positive number. 00:42:03.280 |
So this is adding up a whole bunch of big positive numbers. 00:42:07.660 |
So in other words, this number tells you how much these two things vary in the same way, 00:42:20.240 |
And so this one, when one is big, the other's not necessarily big. 00:42:24.400 |
When one is small, the other's not necessarily very small. 00:42:29.120 |
So this is the covariance, and you can also calculate it in this way, which might look 00:42:39.960 |
somewhat similar to what we saw before with our different variance calculation. 00:42:44.840 |
And again, this is kind of the easier way to use it. 00:42:51.160 |
So as I say here, from now on, I don't want you to ever look at an equation or type in 00:42:59.040 |
an equation in LaTeX without typing it in Python, calculating some values and plotting them. 00:43:05.880 |
Because this is the only way we get a sense in here of what these things mean. 00:43:12.800 |
And so in this case, we're going to take our covariance and we're going to divide it by the product of the two standard deviations. 00:43:21.720 |
And this gives us a new number, and this is called correlation, or more specifically the Pearson correlation coefficient. 00:43:30.000 |
So we don't cover covariance and Pearson correlation coefficient too much in the course, but it 00:43:34.720 |
is one of these things which it's often nice to see how things vary. 00:43:38.480 |
But remember, it's telling you really about how things vary linearly, right? 00:43:42.800 |
So if you want to know how things vary non-linearly, you have to create something called a neural network. 00:43:50.520 |
But it's kind of interesting to see also how variance and covariance are much the same thing. 00:43:55.200 |
In fact, you can basically think of variance as the covariance of X with itself: 00:44:01.120 |
one of them is E of X squared - in other words X and X are kind of the same thing, it's E 00:44:07.320 |
of X times X - whereas this is two different things, E of X times Y, right? 00:44:11.960 |
And so where for the variance we had E of X squared minus the square of E of X, 00:44:19.960 |
if you replace the second X with a Y in each place, you get E of XY minus E of X times E of Y, which is the covariance. 00:44:28.760 |
And then again here, if X and X are the same, then this is just sigma squared, right? 00:44:36.040 |
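As a sketch (t and u are any two 1-D tensors of the same length, like the ones in this example):

```python
def std_pop(x):
    # population standard deviation: sqrt of the mean squared distance from the mean
    return (x - x.mean()).pow(2).mean().sqrt()

cov  = ((t - t.mean()) * (u - u.mean())).mean()   # covariance
# equivalently: cov = (t * u).mean() - t.mean() * u.mean()
corr = cov / (std_pop(t) * std_pop(u))            # Pearson correlation coefficient, in [-1, 1]
```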
So the last thing I want to quickly talk about a little bit more is Softmax. 00:44:43.540 |
This was our final log softmax definition from the other day, and this is the formula: the softmax of x_i is e^(x_i) divided by the sum over j of e^(x_j), and log softmax is just the log of that. 00:44:55.680 |
And this is our cross entropy loss, remember. 00:45:00.040 |
So these are all important concepts we're going to be using a lot. 00:45:02.040 |
So I just wanted to kind of clarify something that a lot of researchers who publish 00:45:06.960 |
in big-name conferences get wrong, which is when you should and shouldn't use softmax. 00:45:13.040 |
So this is our Softmax page from our entropy example spreadsheet, where we were looking 00:45:23.200 |
This is just the activations that we might have gotten out of the last layer of our model. 00:45:27.720 |
And this is just E to the power of each of those activations. 00:45:31.040 |
And this is just the sum of all of those E to the power ofs. 00:45:34.360 |
And then this is E to the power of divided by the sum, which is Softmax. 00:45:43.040 |
This is like some image number one that gave these activations. 00:45:46.760 |
Here's some other image, number two, which gave these activations, which are very different - and yet the softmax comes out exactly the same. 00:46:01.680 |
Well, it's happened because in every case, the E to the power of this divided by the 00:46:09.060 |
sum of the E to the power ofs ended up in the same ratio. 00:46:13.920 |
So in other words, even though the fish activation is quite different in the two images, once you take 00:46:20.640 |
e to the power of it, it's the same percentage of the sum, right? 00:46:30.120 |
Well, in this model, it seems like being a fish is associated with having an activation 00:46:47.240 |
But what's actually happened is there's no cats or dogs or planes or fishes or buildings. 00:46:54.800 |
So in the end then, because Softmax has to add to 1, it has to pick something. 00:47:09.280 |
And what's more is because we do this E to the power of, the thing that's a little bit 00:47:13.960 |
higher, it pushes much higher because it's exponential, right? 00:47:17.760 |
So Softmax likes to pick one thing and make it big. 00:47:24.480 |
So the problem here is that I would guess that maybe image 2 doesn't have any of these 00:47:31.560 |
So it said, oh, I'm pretty sure there's a fish. 00:47:34.040 |
Or maybe the problem actually is that this image had a cat and a fish and a building. 00:47:43.840 |
But again, because the softmax outputs have to add to 1, one of them is going to end up much bigger than the others. 00:47:49.960 |
So I don't know exactly which of these happened, but it's definitely not true that they both 00:47:55.880 |
have an equal probability of having a fish in them. 00:48:00.060 |
So to put this another way, Softmax is a terrible idea unless you know that every one of your, 00:48:06.840 |
if you're doing image recognition, every one of your images, or if you're doing audio or 00:48:11.600 |
tabular or whatever, every one of your items has one, no more than one, and definitely 00:48:17.920 |
at least one example of the thing you care about in it. 00:48:22.900 |
Because if it doesn't have any of cat, dog, plane, fish, or building, it's still going 00:48:27.160 |
to tell you with high probability that it has one of those things. 00:48:31.600 |
Even if it has more than just one of cat, dog, plane, fish, or building, it'll pick 00:48:35.080 |
one of them and tell you it's pretty sure it's got that one. 00:48:38.160 |
So what do you do if there could be no things or there could be more than one of these things? 00:48:45.840 |
Well, instead you use binomial, regular old binomial, which is e to the x divided by one plus e to the x. 00:48:56.000 |
It's exactly the same as Softmax if your two categories are, has the thing and doesn't 00:49:01.960 |
have the thing because they're like p and one minus p. 00:49:05.000 |
So you can convince yourself of that during the week. 00:49:07.840 |
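A small sketch of that contrast (the activations here are made up, not the spreadsheet's):

```python
import torch

acts = torch.tensor([0.02, -2.49, -1.75, 2.07, 1.25])   # pretend last-layer activations

softmax  = acts.exp() / acts.exp().sum()   # forced to sum to 1: it always "picks" something
binomial = acts.sigmoid()                  # e^x / (1 + e^x): an independent yes/no per class

# Adding the same constant to every activation leaves softmax unchanged,
# which is exactly why two very different images can get identical softmax outputs:
shifted = acts + 3
print(torch.allclose(softmax, shifted.exp() / shifted.exp().sum()))   # True
print(shifted.sigmoid())   # but the per-class binomial outputs do change
```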
So in this case, let's take image one and let's go 1.02 divided by one plus 1.02. 00:49:20.600 |
And then let's do the same thing for image two. 00:49:23.360 |
And you can see now the numbers are different, as we would hope. 00:49:28.000 |
And so for image one, it's kind of saying, oh, it looks like there might be a cat in 00:49:33.960 |
it if we assume 0.5 is a cut off, there's probably a fish in it, and it seems likely 00:49:44.080 |
Whereas for image two, it's saying, I don't think there's anything in there, but maybe 00:49:51.800 |
And so when you think about it, like for image recognition, probably most of the time you 00:50:05.040 |
And ImageNet was specifically curated, so each image only has one of the ImageNet classes in 00:50:11.400 |
it, and it always has one of those classes in it. 00:50:17.200 |
An alternative, if you want to be able to handle the what if none of these classes are 00:50:22.480 |
in it case, is you could create another category called background or doesn't exist or null 00:50:31.240 |
So let's say you created this missing category. 00:50:33.120 |
So there are six: cat, dog, plane, fish, building, or missing/nothing. 00:50:41.540 |
A lot of researchers have tried that, but it's actually a terrible idea and it doesn't work. 00:50:48.080 |
And the reason it doesn't work is because to be able to successfully predict missing, 00:50:54.320 |
the penultimate layer activations have to contain the features that describe what a not 00:51:00.800 |
cat, dog, plane, fish or building looks like. 00:51:04.200 |
So how do you describe a not cat, dog, plane, fish or building? 00:51:08.480 |
What are the things that would activate high? 00:51:17.360 |
There is no set of features that when they're all high is clearly a not cat, dog, plane, 00:51:25.840 |
So a neural net can kind of try to hack its way around it by creating a negative model 00:51:33.060 |
of every other single type and create a kind of not one of any of those other things. 00:51:39.720 |
Whereas simply creating a binomial 'does it or doesn't it have this' for every one of the classes is much easier for it to learn. 00:51:54.960 |
So lots and lots of well-regarded academic papers make this mistake. 00:52:06.640 |
And if you do come across an academic paper that's using softmax and you think, does that actually make sense here? 00:52:15.800 |
Try replicating it without softmax and you may just find you get a better result. 00:52:20.320 |
An example of somewhere where softmax, or something like softmax, is obviously a good idea is language modeling: you're predicting the next word, and it's definitely at least one word. 00:52:32.600 |
It's definitely not more than one word, right? 00:52:36.040 |
So I'm not saying softmax is always a dumb idea, but it's often a dumb idea. 00:52:45.240 |
Okay. Next thing I want to do is I want to build a learning rate finder. 00:52:53.000 |
And to build a learning rate finder, we need some way to stop training early, like this test callback kind of idea. 00:53:05.120 |
Problem is, as you may have noticed, this I want to stop somewhere callback wasn't working 00:53:12.560 |
in our new refactoring where we created this runner class. 00:53:16.080 |
And the reason it wasn't working is because we were returning true to mean cancel. 00:53:23.000 |
But even after we do that, it still goes on to do the next batch. 00:53:27.440 |
And even if we set self.stop, even after we do that, it'll go on to the next epoch. 00:53:32.080 |
So to like actually stop it, you would have to return false from every single callback 00:53:36.840 |
that's checked to make sure it like really stops, or you would have to add something that checks a stop flag in lots of places. 00:53:47.920 |
And it's also not as flexible as we would like. 00:53:50.840 |
So what I want to show you today is something which I think is really interesting, which 00:53:54.480 |
is using the idea of exceptions as a kind of control flow statement. 00:54:03.800 |
You may think of exceptions as just being a way of handling errors. 00:54:07.600 |
But actually exceptions are a very versatile way of writing very neat control flow code. 00:54:18.960 |
So let's start by just grabbing our MNIST data set as before and creating our data bunch 00:54:23.960 |
And here's our callback as before and our train eval callback as before. 00:54:29.080 |
But there's a couple of things I'm going to do differently. 00:54:32.800 |
The first is, and this is a bit unrelated, but I think it's a useful refactoring, is 00:54:36.760 |
previously, inside Runner's dunder call, we went through each callback in order, and we 00:54:45.840 |
checked to see whether that particular method exists in that callback. 00:54:51.480 |
And if it was, we called it and checked whether it returns true or false. 00:54:58.720 |
It actually makes more sense for this to be inside the callback class. 00:55:07.960 |
Because by putting it into the callback class, the callback class is now taking a -- has a 00:55:13.320 |
dunder call which takes a callback name, and it can do this stuff. 00:55:18.440 |
And what it means is that now your users who want to create their own callbacks, let's 00:55:25.400 |
say they wanted to create a callback that printed out the callback name for every callback every 00:55:32.600 |
Or let's say they wanted to add a break point, like a set trace that happened every time 00:55:40.200 |
They could now create their own inherit from callback and actually replace dunder call itself 00:55:48.400 |
with something that added this behavior they want. 00:55:52.000 |
Or they could add something that looks at three or four different callback names and 00:55:57.660 |
So this is like a nice little extra piece of flexibility. 00:56:02.280 |
It's not the key thing I wanted to show you, but it's an example of a nice little refactoring. 00:56:07.440 |
The key thing I wanted to show you is that I've created three new types of exception. 00:56:13.460 |
So an exception in Python is just a class that inherits from exception. 00:56:18.640 |
And most of the time you don't have to give it any other behavior. 00:56:21.880 |
So to create a class that's just like its parent, but with a new name and no new behavior, you just write pass in the body. 00:56:28.600 |
So pass means this has all the same attributes and everything as the parent. 00:56:40.680 |
Cancel train exception, cancel epoch exception, cancel batch exception. 00:56:45.360 |
The idea is that we're going to let people's callbacks cancel anything, you know, cancel 00:56:51.280 |
So if they cancel a batch, it will keep going with the next batch, but not finish this one. 00:56:56.560 |
If they cancel an epoch, it will stop this epoch and keep going with the next one. 00:57:02.080 |
Cancel train will stop the training altogether. 00:57:12.000 |
But now in fit, we already had try/finally to make sure that our after_fit and remove-learner steps 00:57:20.120 |
happened even if there's an exception; I've added one line of code: except CancelTrainException. 00:57:27.320 |
And if that happens, then optionally it could call some after cancel train callback. 00:57:35.920 |
It just keeps on going to the finally block and will elegantly and happily finish up. 00:57:46.680 |
So now our test callback has an after_step that just prints out what step we're up to. 00:57:53.840 |
And if it's greater than or equal to 10, we will raise cancel train exception. 00:57:58.640 |
And so now when we say run dot fit, it just prints out up to 10 and stops. 00:58:07.180 |
This is using exception as a control flow technique, not as an error handling technique. 00:58:13.280 |
So another example, inside all batches, I go through all my batches in a try block, 00:58:23.000 |
except if there's a cancel epoch exception, in which case I optionally call an after cancel 00:58:28.040 |
epoch callback and then continue to the next epoch. 00:58:34.080 |
Or inside one batch, I try to do all this stuff for a batch, except if there's a cancel 00:58:39.360 |
batch exception, I will optionally call the after cancel batch callback, and then continue 00:58:47.520 |
So this is like a super neat way that we've allowed any callback writer to stop any one 00:58:59.800 |
So in this case, we're using cancel train exception to stop training. 00:59:05.580 |
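Here is a toy sketch of the pattern - not the real Runner, just the shape of the control flow:

```python
class CancelTrainException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass

class MiniRunner():
    "Toy runner: shows exceptions used as control flow at three levels."
    def __init__(self, cbs): self.cbs = cbs
    def __call__(self, name):
        # call the named method on every callback that defines it
        for cb in self.cbs: getattr(cb, name, lambda: None)()
    def fit(self, epochs, batches=100):
        for cb in self.cbs: cb.run = self          # let callbacks see the runner's state
        try:
            for epoch in range(epochs):
                try:
                    for b in range(batches):
                        try:
                            self.n_iter = epoch * batches + b
                            self('after_step')
                        except CancelBatchException: self('after_cancel_batch')
                except CancelEpochException: self('after_cancel_epoch')
        except CancelTrainException: self('after_cancel_train')
        finally: self('after_fit')                 # cleanup always runs

class TestCallback():
    def after_step(self):
        print(self.run.n_iter)
        if self.run.n_iter >= 10: raise CancelTrainException()

MiniRunner([TestCallback()]).fit(2)   # prints 0..10, then stops cleanly
```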
So we can now use that to create a learning rate finder. 00:59:10.320 |
So the basic approach of the learning rate finder is that there's something in begin 00:59:14.500 |
batch which, just like our parameter scheduler, is using an exponential curve to set the learning rate. 00:59:28.000 |
And then after each step, it checks to see whether we've done more than the maximum number 00:59:32.920 |
of iterations, which is defaulting to 100, or whether the loss is much worse than the best loss we've seen so far. 00:59:42.240 |
And if either of those happens, we will raise cancel train exception. 00:59:46.600 |
So to be clear, this neat exception-based approach to control flow isn't being used in the fast 00:59:52.160 |
AI version one at the moment, but it's very likely that fast AI 1.1 or 2 will switch to 00:59:58.440 |
this approach because it's just so much more convenient and flexible. 01:00:02.880 |
And then assuming we haven't canceled, just see if the loss is better than our best loss, 01:00:06.520 |
and if it is, then set best loss to the loss. 01:00:13.280 |
So now we can create a learner, we can add the LR find, we can fit, and you can see that 01:00:21.600 |
it only does less than 100 iterations before it stops, because the loss got a lot worse. 01:00:29.720 |
And so now we know that we want something about there for our learning rate. 01:00:42.080 |
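The notebook's version looks roughly like this; it assumes the lesson's Callback/Runner machinery, where attributes like self.in_train, self.n_iter, self.opt and self.loss are forwarded from the runner:

```python
class LR_Find(Callback):
    _order = 1
    def __init__(self, max_iter=100, min_lr=1e-6, max_lr=10):
        self.max_iter, self.min_lr, self.max_lr = max_iter, min_lr, max_lr
        self.best_loss = 1e9
    def begin_batch(self):
        if not self.in_train: return
        pos = self.n_iter / self.max_iter
        lr = self.min_lr * (self.max_lr / self.min_lr) ** pos   # exponential LR schedule
        for pg in self.opt.param_groups: pg['lr'] = lr
    def after_step(self):
        # stop via exception if we've done enough iterations or the loss blew up
        if self.n_iter >= self.max_iter or self.loss > self.best_loss * 10:
            raise CancelTrainException()
        if self.loss < self.best_loss: self.best_loss = self.loss
```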
So let's go ahead and create a CNN, and specifically a CUDA CNN. 01:00:50.900 |
So we'll keep doing the same stuff we've been doing, get our MNIST data, normalize it. 01:00:57.680 |
Here's a nice little refactoring, because we very often want to normalize with the training set's statistics - 01:01:04.280 |
that is, normalize both datasets using the training dataset's mean and standard deviation. 01:01:08.180 |
Let's create a function called normalize_to, which does that and returns the normalized 01:01:12.440 |
training set and the normalized validation set. 01:01:15.640 |
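A sketch of that helper (the names follow the lesson's notebook):

```python
def normalize(x, m, s): return (x - m) / s

def normalize_to(train, valid):
    # normalize BOTH sets using the *training* set's statistics
    m, s = train.mean(), train.std()
    return normalize(train, m, s), normalize(valid, m, s)

# x_train, x_valid = normalize_to(x_train, x_valid)
```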
So we can now use that, make sure that it's behaved properly, that looks good. 01:01:22.800 |
Create our data bunch, and so now we're going to create a CNN model, and the CNN is just 01:01:29.280 |
a sequential model that contains a bunch of stride two convolutions. 01:01:35.920 |
And remember the input's 28 by 28, so after the first it'll be 14 by 14, then 7 by 7, 01:01:41.520 |
then 4 by 4, then 2 by 2, then we'll do our average pooling, flatten it, and a linear 01:01:51.540 |
Now remember our original data is vectors of length 784, they're not 28 by 28, so we 01:01:58.000 |
need to do a .view(-1, 1, 28, 28) - one channel by 28 by 28, because that's what nn.Conv2d expects - and 01:02:08.720 |
the minus one means the batch size remains whatever it was before. 01:02:12.720 |
So we need to somehow include this function in our nn.sequential, PyTorch doesn't support 01:02:19.280 |
that by default, we could write our own class with a forward function, but nn.sequential 01:02:25.080 |
is convenient for lots of ways, it has a nice representation, you can do all kinds of customizations 01:02:30.100 |
with it, so instead we create a layer called Lambda - an nn.Module that just 01:02:38.400 |
gets passed a function, and whose forward simply calls that function. 01:02:44.720 |
And so now we can say Lambda(mnist_resize), and that will cause that function to be called. 01:02:51.840 |
And here Lambda(flatten) simply causes this function to be called, which removes that one 01:02:58.400 |
by one axis at the end after the adaptive average pooling. 01:03:03.020 |
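A sketch of that Lambda module and the two little functions:

```python
import torch.nn as nn

class Lambda(nn.Module):
    "Wrap a plain function so it can live inside nn.Sequential."
    def __init__(self, func):
        super().__init__()
        self.func = func
    def forward(self, x): return self.func(x)

def mnist_resize(x): return x.view(-1, 1, 28, 28)   # flat vectors -> 1x28x28 images
def flatten(x):      return x.view(x.shape[0], -1)  # drop the trailing 1x1 after pooling

# e.g. nn.Sequential(Lambda(mnist_resize), ..., nn.AdaptiveAvgPool2d(1), Lambda(flatten), ...)
```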
So now we've got a CNN model, we can grab our callback functions and our optimizer and 01:03:10.120 |
our runner and we can run it, and six seconds later we get back one epochs result. 01:03:18.880 |
So that's at this point now getting a bit slow, so let's make it faster. 01:03:27.940 |
So we need to do two things: we need to put the model on the GPU, which specifically means putting the model's parameters on the GPU. 01:03:38.840 |
So remember a model contains two kinds of numbers, parameters, they're the things that 01:03:42.240 |
you're updating, the things that it stores, and there's the activations, the things that it's calculating. 01:03:48.000 |
So it's the parameters that we need to actually put on the GPU. 01:03:51.900 |
And the inputs to the model and the loss function, so in other words the things that come out 01:03:56.240 |
of the data loader we need to put those on the GPU. 01:04:02.800 |
So here's a CUDA callback, when you initialize it you pass it a device and then when you 01:04:09.240 |
begin fitting you move the model to that device. 01:04:13.920 |
So model.to(device) - .to is part of PyTorch - moves something with parameters, or a tensor, to a 01:04:22.760 |
device, and you can create a device by calling torch.device, passing it the string 'cuda' and 01:04:29.240 |
whatever GPU number you want to use; if you only have one GPU, it's device 0. 01:04:35.600 |
Then when we begin a batch, let's go back and look at our runner. 01:04:43.160 |
When we begin a batch, we've put xbatch and ybatch inside self.xb and self.yb. 01:04:54.520 |
So let's set the runner's xb and the runner's yb to whatever they were before, but moved onto that device. 01:05:09.360 |
This is kind of flexible because we can put things on any device we want. 01:05:14.000 |
Maybe easier is just to call this once, torch.cuda.set_device, and you don't 01:05:19.560 |
even need to do this if you've only got one GPU. 01:05:22.200 |
And then everything by default will now be sent to that device. 01:05:25.280 |
And then instead of saying .to(device), we can just say .cuda. 01:05:29.800 |
And so since we're doing pretty much everything with just one GPU for this course, this is what we'll do. 01:05:42.240 |
So let's add that to our callback functions, grab our model and our runner and fit, and 01:05:47.480 |
now we can do three epochs in five seconds versus one epoch in six seconds. 01:05:53.600 |
And for a much deeper model, it'll be dozens of times faster. 01:05:57.920 |
So this is literally all we need to use CUDA. 01:06:04.520 |
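A minimal sketch of the two approaches, assuming the Callback base class and Runner from the earlier notebooks (where a callback can read self.model, self.xb, self.yb and write back through self.run):

```python
import torch

device = torch.device('cuda', 0)          # GPU number 0

class CudaCallbackDevice(Callback):
    "Explicit-device version: flexible if you have several GPUs."
    def __init__(self, device): self.device = device
    def begin_fit(self):   self.model.to(self.device)      # moves the parameters
    def begin_batch(self):
        self.run.xb = self.xb.to(self.device)              # moves the inputs
        self.run.yb = self.yb.to(self.device)

torch.cuda.set_device(device)             # after this, .cuda() defaults to that GPU

class CudaCallback(Callback):
    "Single-GPU version: just call .cuda() on everything."
    def begin_fit(self):   self.model.cuda()
    def begin_batch(self): self.run.xb, self.run.yb = self.xb.cuda(), self.yb.cuda()
```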
Now we want to make it easier to create different kinds of architectures, make things a bit simpler. 01:06:10.280 |
So the first thing we should do is recognize that we do conv then ReLU a lot. 01:06:14.800 |
So let's pop that into a function called conv2d that just does conv then ReLU. 01:06:19.220 |
Since we use a kernel size of three and a stride of two in this MNIST model a lot, let's make those the defaults. 01:06:26.080 |
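Something along these lines, with kernel size 3 and stride 2 as the defaults (a sketch, not necessarily the exact signature used in the notebook):

```python
import torch.nn as nn

def conv2d(ni, nf, ks=3, stride=2):
    # the conv-then-ReLU pair we keep writing, with stride-2 3x3 as the default
    return nn.Sequential(
        nn.Conv2d(ni, nf, ks, padding=ks // 2, stride=stride),
        nn.ReLU())
```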
Also this model we can't reuse for anything except MNIST, because it has that MNIST resize baked into it. 01:06:37.960 |
So if we're going to remove that, something else is going to have to do the resizing. 01:06:41.880 |
And of course the answer to that is a callback. 01:06:45.600 |
So here's a callback which transforms the independent variable, the x, for a batch. 01:06:55.000 |
And so you pass it some transformation function, which it stores away. 01:07:00.000 |
And then begin batch simply replaces the batch with the result of that transformation function. 01:07:04.960 |
So now we can simply append another callback, which is the partial function application of this view function. 01:07:16.160 |
And this function is just to view something at one by 28 by 28. 01:07:23.040 |
And you can see here we've used the trick we saw earlier of using underscore inner to define a closure. 01:07:31.560 |
So this is something which creates a new view function that views it in this size. 01:07:37.620 |
So for those of you that aren't that comfortable with closures and partial function application, 01:07:44.200 |
this is a great piece of code to study, experiment, make sure that you feel comfortable with it. 01:07:52.160 |
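Here is a sketch of that callback and the closure, again assuming the Callback/Runner API from earlier in the lesson; view_tfm, _inner and cbfs mirror the notebook's names but are assumptions here:

```python
from functools import partial

class BatchTransformXCallback(Callback):
    "Replace the independent variable of each batch with tfm(xb)."
    def __init__(self, tfm): self.tfm = tfm
    def begin_batch(self):   self.run.xb = self.tfm(self.xb)

def view_tfm(*size):
    # closure: returns a function that views a batch as (-1, *size)
    def _inner(x): return x.view(*((-1,) + size))
    return _inner

mnist_view = view_tfm(1, 28, 28)
# partial function application: a callback factory to append to our callback functions
cbfs = [partial(BatchTransformXCallback, mnist_view)]
```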
So using this approach, we now have the MNIST view resizing as a callback, which means we can make the model itself completely generic. 01:08:02.580 |
So now we can create a generic get_cnn_model function that returns a sequential model containing 01:08:10.160 |
an arbitrary set of layers with an arbitrary set of filters. 01:08:16.500 |
So we're going to say, OK, this is the number of filters I have per layer, 8, 16, 32, 32. 01:08:28.080 |
And the last few layers are the average pooling, flattening, and the linear layer. 01:08:34.680 |
The first few layers are, for every consecutive pair in that list of filters, 01:08:38.800 |
a conv2d going from that number of filters to the next one. 01:08:51.160 |
It's a kernel size of 5 for the first layer, or 3 otherwise. 01:09:00.800 |
Well, the number of filters we had for the first layer was 8. 01:09:09.480 |
And that's a pretty reasonable starting point for a small model to start with 8 filters. 01:09:13.680 |
And remember, our image had a single channel. 01:09:18.560 |
And imagine if we had a single channel, and we were using 3 by 3 filters. 01:09:24.040 |
So as that convolution kernel scrolls through the image, at each point in time, it's looking at a 3 by 3 window of the input. 01:09:37.240 |
So in total, there's 9 input activations that it's looking at. 01:09:45.100 |
And then it spits those into a dot product with-- sorry, it spits those into 8 dot products, 01:09:55.480 |
so a matrix multiplication, I should say, 8 by 9. 01:10:04.840 |
And out of that will come a vector of length 8, which isn't really achieving much. 01:10:21.680 |
Because we started with 9 numbers, and we ended up with 8 numbers. 01:10:26.800 |
So all we're really doing is just reordering them. 01:10:29.980 |
It's not really doing any useful computation. 01:10:32.560 |
So there's no point making your first layer basically just shuffle the numbers into a different order. 01:10:42.120 |
So what you'll find happens in, for example, most ImageNet models. 01:10:46.760 |
Most ImageNet models are a little bit different, because they have three channels. 01:10:59.080 |
But it's still kind of like, quite often with ImageNet models, the first layer will have something like 32 filters, and with three channels, a 3 by 3 kernel only has 27 numbers coming in. 01:11:07.240 |
So going from 27 to 32 is literally losing information. 01:11:11.320 |
So most ImageNet models, they actually make the first layer 7 by 7, not 3 by 3. 01:11:19.560 |
And so for a similar reason, we're going to make our first layer 5 by 5. 01:11:28.840 |
So this is the kind of things that you want to be thinking about when you're designing 01:11:33.400 |
or reviewing an architecture, is like how many numbers are actually going into that 01:11:37.760 |
little dot product that happens inside your CNN kernel. 01:11:44.600 |
So that's something which can give us a CNN model. 01:11:48.560 |
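A generic model builder along these lines would reuse the conv2d, Lambda and flatten sketches above (the exact helper names and signature are assumptions):

```python
import torch.nn as nn

def get_cnn_layers(n_in, nfs, n_out=10):
    # nfs is the list of filters per layer, e.g. [8, 16, 32, 32]
    nfs = [n_in] + nfs
    convs = [conv2d(nfs[i], nfs[i + 1], ks=5 if i == 0 else 3)   # 5x5 first, 3x3 after
             for i in range(len(nfs) - 1)]
    return convs + [nn.AdaptiveAvgPool2d(1), Lambda(flatten), nn.Linear(nfs[-1], n_out)]

def get_cnn_model(nfs):
    return nn.Sequential(*get_cnn_layers(1, nfs))   # 1 input channel for MNIST

model = get_cnn_model([8, 16, 32, 32])
```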
So let's pop it all together into a little function that just grabs an optimization function, 01:11:52.760 |
grabs an optimizer, grabs a learner, grabs a runner. 01:11:56.560 |
And at this point, if you can't remember what any of these things does, remember, we've 01:12:08.400 |
built every one of them ourselves from scratch, so go back and have a look. So if we say get_cnn_model, passing in 8, 16, 32, 32, we get conv layers with those numbers of filters. 01:12:21.120 |
They all have a stride of 2, and then there's a linear layer at the end, and then we train. 01:12:26.200 |
So at this point, we've got a fairly general simple CNN creator that we can fit. 01:12:33.800 |
And so let's try to find out what's going on inside. 01:12:43.800 |
Well, we really want to see what's going on inside. 01:12:46.880 |
We know already that different ways of initializing changes the variance of different layers. 01:12:53.600 |
How do we find out if it's saturating somewhere, if it's too small, if it's too big, what's going on? 01:13:00.580 |
So what if we replace nn.Sequential with our own sequential model class? 01:13:07.960 |
And if you remember back, we've already built our own sequential model class before, and 01:13:11.800 |
it just had these two lines of code, plus return. 01:13:14.860 |
So let's keep the same two lines of code, but also add two more lines of code that grabs 01:13:19.880 |
the mean of the outputs and the standard deviation of the outputs, and saves them away inside the model. 01:13:28.920 |
So here's a list for every layer, for means, and a list for every layer for standard deviations. 01:13:37.280 |
So let's calculate the means and standard deviations, and pop them inside those two lists. 01:13:42.580 |
And so now it's a sequential model that also keeps track of what's going on, the telemetry 01:13:50.100 |
So we can now create it in the same way as usual, fit it in the same way as usual, but 01:13:56.060 |
It has an act means, and it acts standard deviations. 01:13:59.900 |
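A sketch of that telemetry-recording sequential model (attribute names are as described; treat the details as illustrative):

```python
import torch.nn as nn

class SequentialModel(nn.Module):
    "Like nn.Sequential, but records the mean and std of every layer's output."
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.act_means = [[] for _ in layers]
        self.act_stds  = [[] for _ in layers]

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            self.act_means[i].append(x.data.mean().item())
            self.act_stds [i].append(x.data.std().item())
        return x
```

After fitting, plotting each list in model.act_means gives the per-layer curves discussed next.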
So let's plot the act means for every one of those lists that we had. 01:14:15.520 |
What happens early in training is every layer, the means get exponentially bigger until they 01:14:23.520 |
suddenly collapse, and then it happens again, and it suddenly collapses, and it happens 01:14:28.260 |
again, and it suddenly collapses until eventually, it kind of starts training. 01:14:36.440 |
So you might think, well, it's eventually training, so isn't this okay? 01:14:44.320 |
But my concern would be this thing where it kind of falls off a cliff. 01:14:54.520 |
Are we sure that all of them are getting back into reasonable places, or is it just that 01:14:59.480 |
a few of them have got back into a reasonable place? 01:15:01.960 |
Maybe the vast majority of them have zero gradients at this point. 01:15:07.060 |
It seems very likely that this awful training profile early in training is leaving our model in a poor state. 01:15:16.120 |
That's my guess, and we're going to check it to see later. 01:15:18.520 |
But for now, we're just going to say, let's try to make this not happen. 01:15:23.680 |
And let's also look at the standard deviations, and you see exactly the same thing. 01:15:32.440 |
So let's look at just the first 10 means, and they all look okay. 01:15:37.640 |
They're all pretty close-ish to zero, which is about what we want. 01:15:41.840 |
But more importantly, let's look at the standard deviations for the first 10 batches, and this is where the problem shows up. 01:15:47.960 |
The first layer has a standard deviation not too far away from one, but then, not surprisingly, 01:16:00.760 |
As we would expect, because the first layer is less than one, the following layers are 01:16:05.240 |
getting exponentially further away from one, until the last layer is really close to zero. 01:16:14.600 |
So now we can kind of see what was going on here, is that our final layers were getting 01:16:21.720 |
basically no activations, they were basically getting no gradients. 01:16:25.240 |
So gradually, it was moving into spaces where they actually at least had some gradient. 01:16:30.340 |
But by the time they kind of got there, the gradient was so fast that they were kind of 01:16:36.760 |
falling off a cliff and having to start again. 01:16:40.760 |
So this is the thing we're going to try and fix. 01:16:43.360 |
And we think we already know how to, we can use some initialization. 01:16:51.120 |
Did you say that if we went from 27 numbers to 32, that we were losing information? 01:16:55.280 |
And could you say more about what that means? 01:16:58.320 |
Yeah, I guess we're not losing information; that was poorly said; we're 01:17:07.280 |
wasting information, I guess, in that if you start with 27 numbers and you do some 01:17:12.640 |
matrix multiplication and end up with 32 numbers, you are now taking more space for the same information. 01:17:22.280 |
And the whole point of a neural network layer is to pull out some interesting features. 01:17:28.360 |
So you would expect to have less total activations going on because you're trying to say, oh, 01:17:37.480 |
in this area, I've kind of pulled this set of pixels down into something that says how 01:17:43.920 |
furry this is, or how much of a diagonal line does this have, or whatever. 01:17:48.800 |
So increasing the actual number of activations we have for a particular position is a total waste of time. 01:18:01.200 |
We're not doing anything useful, we're just wasting a lot of calculation. 01:18:09.400 |
We can talk more about that on the forum if that's still not clear. 01:18:19.900 |
So this idea of creating telemetry for your model is really vital. 01:18:24.560 |
This approach to doing it, where you actually write a whole new class that can only do this one thing, is pretty clunky. 01:18:33.100 |
And so we clearly need a better way to do it. 01:18:39.680 |
Except we can't use our callbacks, because we don't have a callback that says when you calculate this layer, call back to our code. 01:18:50.620 |
So we actually need to use a feature inside PyTorch that can call back into our code when 01:18:58.880 |
a layer is calculated, either the forward pass or the backward pass. 01:19:03.280 |
And for reasons I can't begin to imagine, PyTorch doesn't call them callbacks; they're called hooks. 01:19:12.120 |
And so we can say for any module, we can say register forward hook and pass in a function. 01:19:22.680 |
It's a callback that will be called when this module's forward pass is calculated. 01:19:28.960 |
Or you could say register backward hook, and that will call this function when this module's backward pass is calculated. 01:19:37.600 |
So to replace that previous thing with hooks, we can simply create a couple of global variables 01:19:44.200 |
to store our means and standard deviations for every layer. 01:19:47.480 |
We can create a function to call back to to calculate the mean and standard deviation. 01:19:53.960 |
And if you Google for the documentation for register forward hook, you will find that 01:20:00.120 |
it will tell you that the callback will be called with three things. 01:20:06.040 |
The module that's doing the callback, the input to the module, and the output of that 01:20:11.680 |
module, either the forward or the backward pass is appropriate. 01:20:18.120 |
And then we've got a fourth thing here because this is the layer number we're looking at, 01:20:22.280 |
and we used partial to connect the appropriate closure with each layer. 01:20:28.880 |
So once we've done that, we can call fit, and we can do exactly the same thing. 01:20:33.080 |
So this is the same thing, just much more convenient. 01:20:36.280 |
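The hook version of the same telemetry might look like this; append_stats is the callback, and partial binds the layer index while PyTorch supplies the module, input and output (model here is the sequential model built above):

```python
from functools import partial

act_means = [[] for _ in model]     # one list per layer
act_stds  = [[] for _ in model]

def append_stats(i, mod, inp, outp):
    # called automatically after this module's forward pass
    act_means[i].append(outp.data.mean().item())
    act_stds [i].append(outp.data.std().item())

for i, m in enumerate(model):
    m.register_forward_hook(partial(append_stats, i))
```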
And because this is such a handy thing to be able to do, fast.ai has a hook class. 01:20:41.600 |
So we can create our own hook class now, which allows us to, rather than having this kind 01:20:46.560 |
of messy global state, we can instead put the state inside the hook. 01:20:52.720 |
So let's create a class called Hook that, when you initialize it, registers a forward hook. 01:21:01.320 |
And what it's going to do is it's going to call back to this object. 01:21:10.220 |
And so that way, we can get access to the hook. 01:21:13.660 |
We can pop inside it our two empty lists when we first call this, to store away our means and standard deviations. 01:21:22.040 |
And then we can just append our means and standard deviations. 01:21:24.360 |
So now, we just go hooks equals hook for layer in children of my model. 01:21:32.680 |
And we'll just grab the first two layers because I don't care so much about the linear layers. 01:21:36.640 |
It's really the conf layers that are most interesting. 01:21:39.340 |
And so now, this does exactly the same thing. 01:21:44.380 |
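A sketch of that Hook class; the state lives on the hook object rather than in globals, and append_stats now receives the hook as its first argument:

```python
from functools import partial

class Hook():
    def __init__(self, module, func):
        # call func(self, module, input, output) after each forward pass
        self.hook = module.register_forward_hook(partial(func, self))
    def remove(self):  self.hook.remove()
    def __del__(self): self.remove()

def append_stats(hook, mod, inp, outp):
    if not hasattr(hook, 'stats'): hook.stats = ([], [])   # (means, stds)
    means, stds = hook.stats
    means.append(outp.data.mean().item())
    stds .append(outp.data.std().item())

hooks = [Hook(layer, append_stats) for layer in list(model.children())[:2]]
```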
Since we do this a lot, let's put that into a class, too, called Hooks. 01:21:51.000 |
So here's our hooks class, which simply calls hook for every module in some list of modules. 01:22:01.480 |
Now something to notice is that when you're done using a hooked module, you should call its remove method. 01:22:11.200 |
Because otherwise, if you keep registering more hooks on the same module, they're all going to hang around. 01:22:16.640 |
And eventually, you're going to run out of memory. 01:22:19.040 |
So one thing I did in our Hook class was I created a dunder del. 01:22:23.720 |
This is called automatically when Python cleans up some memory. 01:22:28.400 |
So when it's done with your hook, it will automatically call remove, which in turn will remove the hook from the module. 01:22:41.620 |
So when Hooks is done, it calls self.remove, which in turn goes through every one of my registered hooks and removes them. 01:22:57.640 |
You'll see that somehow I'm able to go for H in self, but I haven't defined any kind of iterator in Hooks itself. 01:23:08.360 |
And the trick is I've created something called a list container just above, which is super handy. 01:23:14.880 |
It basically defines all the things you would expect to see in a list, using all of the various dunder methods. 01:23:25.880 |
It actually has some of the behavior of numpy as well. 01:23:30.000 |
We're not allowed to use numpy in our foundations, so we use this instead. 01:23:33.440 |
And this actually also works a bit better than numpy for this stuff, because numpy handles some of this indexing a bit differently. 01:23:41.200 |
So for example, with this list container, it's got dunder get item. 01:23:45.960 |
So that's the thing that gets called when you call something with square brackets. 01:23:51.000 |
So if you index into it with an int, then we just pass it off to the enclosed list, because that's just normal list indexing. 01:24:02.220 |
If you send it a list of bools, like false, false, true, false, 01:24:10.200 |
then it will return all of the things where that's true, or you can index into it with 01:24:16.760 |
a list of ints, in which case it will return all of the things that are indexed by those ints. 01:24:23.360 |
For instance, it's got a length which just passes off to the list's length, and an iterator that just iterates through the enclosed list. 01:24:31.400 |
And then we've also defined the representation for it such that if you print it out, it just 01:24:36.960 |
prints out the contents unless there's more than 10 things in it, in which case it shows the first 10 and an ellipsis. 01:24:43.680 |
So with a nice little base class like this, so you can create really useful little base 01:24:49.160 |
classes in much less than a screen full of code. 01:24:53.120 |
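Roughly, such a ListContainer base class looks like this (a condensed sketch of the behaviour described):

```python
class ListContainer():
    def __init__(self, items): self.items = list(items)
    def __getitem__(self, idx):
        if isinstance(idx, (int, slice)): return self.items[idx]   # plain list behaviour
        if isinstance(idx[0], bool):
            assert len(idx) == len(self)                            # boolean mask
            return [o for m, o in zip(idx, self.items) if m]
        return [self.items[i] for i in idx]                         # list of ints
    def __len__(self):           return len(self.items)
    def __iter__(self):          return iter(self.items)
    def __setitem__(self, i, o): self.items[i] = o
    def __delitem__(self, i):    del self.items[i]
    def __repr__(self):
        res = f'{self.__class__.__name__} ({len(self)} items)\n{self.items[:10]}'
        if len(self) > 10: res = res[:-1] + '...]'
        return res
```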
And then we can use them, and we will use them everywhere from now on. 01:24:56.280 |
So now we've created our own listy class that has hooks in it. 01:25:05.480 |
We can just say hooks equals Hooks of everything in our model, with that function we had before, 01:25:11.960 |
append_stats; we can print it out to see all the hooks. 01:25:22.040 |
We can grab a mini-batch of data and check its mean and standard deviation are about zero and one, as you would expect. 01:25:26.060 |
We can pass it through the first layer of our model. 01:25:29.000 |
Model zero is the first layer of our model, which is the first convolution. 01:25:32.880 |
And our mean is not quite zero, and our standard deviation is quite a lot less than one, as we'd expect before initializing properly. 01:25:41.080 |
So now we'll just go ahead and initialize it with Kaiming. 01:25:44.280 |
And after that, our variance is quite close to one. 01:25:48.120 |
And our mean, as expected, is quite close to 0.5 because of the ReLU. 01:25:57.280 |
So now we can go ahead and create our hooks and do a fit. 01:26:03.440 |
And we can plot the first 10 means and standard deviations, and then we can plot all the means 01:26:08.080 |
and standard deviations, and there it all is. 01:26:11.800 |
And this time we're doing it after we've initialized all the layers of our model. 01:26:18.000 |
And as you can see, we don't have that awful exponential crash, exponential crash, exponential 01:26:27.440 |
And you can see early on in training, our standard deviations are staying nice and close to one. 01:26:40.440 |
A with block is something that will create this object, give it this name, and when it's finished, it will do something. 01:26:51.040 |
The something it does is to call your dunder exit method here, which will call remove. 01:26:59.080 |
So here's a nice way to ensure that things are cleaned up. 01:27:08.520 |
Dunder enter is what happens when you start the with block, dunder exit when you finish the with block. 01:27:15.520 |
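So the Hooks class, with its context-manager methods, is roughly this (assuming the Hook, append_stats and ListContainer sketches above, plus the run and learn objects from earlier in the lesson):

```python
import matplotlib.pyplot as plt

class Hooks(ListContainer):
    "One Hook per module; removes them all when the with block (or the object) is done."
    def __init__(self, model, func):
        super().__init__([Hook(layer, func) for layer in model])
    def __enter__(self, *args): return self        # dunder enter: start of the with block
    def __exit__(self, *args):  self.remove()      # dunder exit: end of the with block
    def __del__(self):          self.remove()
    def remove(self):
        for h in self: h.remove()

with Hooks(model, append_stats) as hooks:
    run.fit(1, learn)                               # hooks record telemetry while we train
    for h in hooks:
        means, stds = h.stats
        plt.plot(means)                             # one curve per hooked layer
```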
So this is looking very hopeful, but it's not quite what we wanted to know. 01:27:20.720 |
Really the concern was, does this actually do something bad? 01:27:26.120 |
Is it actually, or does it just train fine afterwards? 01:27:29.520 |
So something bad really is more about how many of the activations are really, really small. 01:27:36.080 |
How well is it actually getting everything activated nicely? 01:27:40.280 |
So what we could do is we could adjust our append stats. 01:27:43.400 |
So not only does it have a mean and a standard deviation, but it's also got a histogram. 01:27:49.600 |
So we could create a histogram of the activations, pop them into 40 bins between 0 and 10. 01:27:58.040 |
We don't need to go underneath 0 because we have a ReLU. 01:28:14.040 |
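The adjusted stats function might look like this; histc does the binning, and stacking the per-iteration histograms gives the picture that gets plotted (get_hist is an assumed helper name):

```python
import torch

def append_stats(hook, mod, inp, outp):
    if not hasattr(hook, 'stats'): hook.stats = ([], [], [])
    means, stds, hists = hook.stats
    means.append(outp.data.mean().item())
    stds .append(outp.data.std().item())
    # 40 bins between 0 and 10; nothing below 0 because of the ReLU
    hists.append(outp.data.cpu().histc(40, 0, 10))

def get_hist(h):
    # (bins x iterations) image of the activation histogram over training
    return torch.stack(h.stats[2]).t().float().log1p()
```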
And what we find is that even with that, if we make our learning rate really high, 0.9, we get the same sort of problem. 01:28:30.560 |
And I should say thank you to Stefano for the original code here, from our San Francisco study group. 01:28:38.760 |
So you can see this kind of grow, collapse, grow, collapse, grow, collapse thing. 01:28:45.520 |
The biggest concern for me though is this yellow line at the bottom. 01:28:51.000 |
The yellow line, yellow is where most of the histogram is. 01:28:54.640 |
Actually, what I really care about is how much yellow is there. 01:28:59.440 |
So let's say the first two histogram bins are 0 or nearly 0. 01:29:07.180 |
So let's get the sum of how much is in those two bins and divide by the sum of all of the bins. 01:29:13.600 |
And so that's going to tell us what percentage of the activations are 0 or nearly 0. 01:29:21.800 |
Let's plot that for each of the first four layers. 01:29:25.320 |
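That near-zero fraction is just the first two bins over the total, something like this (again a sketch using the hook stats above):

```python
import torch

def get_min(h):
    h1 = torch.stack(h.stats[2]).t().float()
    # fraction of activations falling in the first two bins, i.e. zero or nearly zero
    return h1[:2].sum(0) / h1.sum(0)
```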
And you can see that in the last layer, it's just as we suspected: over 90% of the activations are zero or nearly zero. 01:29:35.560 |
So if you were training your model like this, it could eventually look like it's training 01:29:40.120 |
nicely without you realizing that 90% of your activations were totally wasted. 01:29:45.760 |
And so you're never going to get great results by wasting 90% of your activations. 01:29:52.800 |
Let's try and be able to train at a nice high learning rate and not have this happen. 01:29:57.340 |
And so the trick is, we're going to try a few things, but the main one is we're going to use a better kind of ReLU. 01:30:04.840 |
And so we've created a generalized ReLU class where now we can pass in things like an amount 01:30:11.200 |
to subtract from the ReLU, because remember we thought subtracting half from the ReLU might be a good idea. 01:30:16.280 |
We can also use leaky ReLU and maybe things that are too big are also a problem. 01:30:20.840 |
So let's also optionally have a maximum value. 01:30:23.960 |
So in this generalized ReLU, if you passed a leakiness, then we'll use leaky ReLU. 01:30:32.560 |
You could very easily write these leaky ReLU by hand, but I'm just trying to make it run 01:30:38.680 |
a little faster by taking advantage of PyTorch. 01:30:41.560 |
If you said I want to subtract something from it, then go ahead and subtract that from it. 01:30:46.620 |
If I said there's some maximum value, go ahead and clamp it at that maximum value. 01:30:54.040 |
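A sketch of that generalized ReLU (the notebook calls it something like GeneralRelu; the out-of-place ops here just keep the sketch simple):

```python
import torch.nn as nn
import torch.nn.functional as F

class GeneralRelu(nn.Module):
    "ReLU with optional leakiness, a constant subtraction, and a maximum clamp."
    def __init__(self, leak=None, sub=None, maxv=None):
        super().__init__()
        self.leak, self.sub, self.maxv = leak, sub, maxv

    def forward(self, x):
        x = F.leaky_relu(x, self.leak) if self.leak is not None else F.relu(x)
        if self.sub  is not None: x = x - self.sub          # e.g. subtract about 0.4
        if self.maxv is not None: x = x.clamp_max(self.maxv)
        return x
```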
And so now let's have our conv layer and get_cnn_layers both take **kwargs and just pass 01:31:00.800 |
them on through, so that eventually they end up passed to our generalized ReLU. 01:31:07.360 |
And so that way we're going to be able to create a CNN and say what ReLU characteristics we want. 01:31:16.280 |
And even get_cnn_model will pass down kwargs as well. 01:31:21.440 |
So now that our ReLU can go negative because it's leaky and because it's subtracting stuff, 01:31:26.120 |
we'll need to change our histogram so it goes from -7 to 7 rather than from 0 to 10. 01:31:32.900 |
So we'll also need to change our definition of get_min, so that the middle few bins of the histogram are the ones we count as nearly zero. 01:31:49.060 |
And now we can just go ahead and train this model just like before and plot just like 01:32:00.280 |
So here's the first one, two, three, four layers. 01:32:02.880 |
So compared to that, which was expand, then die, expand, die, expand, die, we're now seeing something much smoother. 01:32:16.200 |
It's using the full richness of the possible activations. 01:32:21.080 |
But our real question is how much is in this yellow line? 01:32:28.800 |
And let's see, in the final layer, look at that, less than 20%. 01:32:35.560 |
So we're now using nearly all of our activations by being careful about our initialization 01:32:44.440 |
and our ReLU, but we're still training at a nice high learning rate. 01:32:54.360 |
Could you explain again how to read the histograms? 01:33:00.120 |
So the four histogram, let's go back to the earlier one. 01:33:03.180 |
So the four histograms are simply the four layers. 01:33:16.100 |
So each one is just one more iteration as most of our plots show. 01:33:21.240 |
The y-axis shows the activation values, from the lowest they can be at the bottom to the highest at the top, and the colour shows how many activations are at each level. 01:33:33.700 |
So what this one here is showing us, for example, is that there are some activations that are 01:33:40.120 |
at the max and some activations are in the middle and some activations at the bottom, 01:33:44.680 |
whereas this one here is showing us that all of the activations are basically zero. 01:33:51.920 |
So what this shows us in this histogram is that now we're going all the way from plus 01:33:58.800 |
seven to minus seven because we can have negatives. 01:34:02.800 |
It's showing us that most of them are zero because yellow is the most energy. 01:34:11.160 |
There are activations throughout everything from the bottom to the top. 01:34:16.440 |
And a few less than zero, as we would expect, because we have a leaky ReLU and we also 01:34:21.960 |
have that minus; we're not doing minus 0.5, we're doing minus 0.4, because leaky ReLU 01:34:27.840 |
means that we don't need to subtract half anymore, we subtract a bit less than half. 01:34:34.560 |
And so then this line is telling us what percentage of them are zero or nearly zero. 01:34:45.520 |
And so this is one of those things which is good to run lots of experiments in the notebook 01:34:50.680 |
yourself to get a sense of what's actually in these histograms. 01:34:54.640 |
So you can just go ahead and have a look at each hook's stats. 01:34:58.720 |
And the third thing in it will be the histograms, so you can see what shape it has and how it changes over training. 01:35:07.880 |
So now that we've done that, this is looking really good. 01:35:12.760 |
So what actually happens if we train like this? 01:35:20.320 |
So use that combined scheduler we built last week, 50/50 two phases, cosine schedule up, cosine schedule down. 01:35:28.560 |
So gradual warm up, gradual cool down, and then run it for eight epochs. 01:35:33.940 |
And there we go, we're doing really well, we're getting up to 98%. 01:35:37.760 |
So this kind of, we hardly were really tuning anything here, we were just trying to get something where the telemetry looked good. 01:35:46.680 |
And once we had something that looked good in terms of the telemetry, it really does train nicely. 01:35:54.320 |
One option I added, by the way, in the CNN setup was a uniform Boolean, which will set 01:36:02.920 |
the initialization function to Kaiming normal if it's false, which is what we've been using, or 01:36:13.160 |
Kaiming uniform if it's true; so now I've just trained the same model with uniform equals true. 01:36:19.000 |
A lot of people think that uniform is better than normal, because a uniform random number is less likely to be close to zero. 01:36:27.960 |
And so the thinking is that maybe uniform random, uniform initialization might cause 01:36:34.520 |
it to kind of have a better richness of activations. 01:36:37.800 |
I haven't studied this closely, I'm not sure I've seen a careful analysis in a paper. 01:36:43.160 |
In this case, 98.22 versus 98.26, they're looking pretty similar, but that's just something you can experiment with. 01:36:51.380 |
So at this point, we've got a pretty nice bunch of things you can look at now, and so 01:36:57.800 |
you can see as your problem to play with during the week is how accurate can you make a model? 01:37:10.320 |
And for the ones that are great accuracy, what does the telemetry look like? 01:37:14.720 |
How can you tell whether it's going to be good? 01:37:17.040 |
And then what insights can you gain from that to make it even better? 01:37:20.300 |
So in the end, try to beat me, try to beat 98%. 01:37:26.200 |
You'll find you can beat it pretty easily with some playing around, but do some experiments. 01:37:33.320 |
All right, so that's kind of about what we can do with initialization. 01:37:47.480 |
You can go further, as we discussed with Selu or with Fixup, like there are these really 01:37:54.880 |
finely tuned initialization methods that you can do 1,000 layers deep, but they're super 01:38:00.820 |
So generally, I would use something like the layer-wise sequential unit variance, LSUV 01:38:07.760 |
thing that we saw earlier in ... Oh, sorry, we haven't done that one yet. 01:38:22.080 |
So that's kind of about as far as we can get with basic initialization. 01:38:29.560 |
To go further, we really need to use normalization, of which the most commonly known approach 01:38:36.000 |
to normalization in the model is batch normalization. 01:38:43.280 |
So batch normalization has been around since, I think, about 2015. 01:38:53.080 |
And they first of all describe a bit about why they thought batch normalization was a 01:38:57.320 |
good idea, and by about page 3, they provide the algorithm. 01:39:05.560 |
So it's one of those things that if you don't read a lot of math, it might look a bit scary. 01:39:10.760 |
But then when you look at it for a little bit longer, you suddenly notice that this 01:39:14.760 |
is literally just the mean, sum divided by the count. 01:39:18.840 |
And this is the mean of the difference to the mean squared, and it's the mean of that. 01:39:26.760 |
Oh, that's just what we looked at, that's variance. 01:39:29.440 |
And this is just subtract the mean, divide by the standard deviation. 01:39:34.720 |
So once you look at it a second time, you realize we've done all this. 01:39:41.520 |
And so then, the only thing they do is after they've normalized it in the usual way, is 01:39:46.360 |
that they then multiply it by gamma, and they add beta. So what are gamma and beta? 01:39:59.280 |
Remember that there are two types of numbers in a neural network, parameters and activations. 01:40:06.200 |
Activations are things we calculate, parameters are things we learn, and gamma and beta are parameters: things the model learns. 01:40:13.640 |
So that's all the information we need to implement batch norm. 01:40:19.360 |
So first of all, we'll grab our data as before, create our callbacks as before. 01:40:29.480 |
And the highest I could get was a 0.4 learning rate this way. 01:40:47.240 |
We're going to get the mean and the variance. 01:40:51.400 |
And the way we do that is we call update stats, and the mean is just the mean. 01:41:00.960 |
And then we subtract the mean, and we divide by the square root of the variance. 01:41:06.320 |
And then we multiply by, and then I didn't call them gamma and beta, because why use 01:41:11.320 |
Greek letters when who remembers which one's gamma and which one's beta? 01:41:15.660 |
The thing we multiply by, we'll call the mults, and the thing we add, we'll call the adds. 01:41:26.080 |
We multiply by a parameter that initially is just a bunch of ones, so it does nothing. 01:41:31.420 |
And we add a parameter which is initially just a bunch of zeros, so it does nothing. 01:41:39.400 |
Just like, remember, our original linear layer we created by hand looked just like this: a multiply and an add. 01:41:45.480 |
In fact, if you think about it, adds is just a bias. 01:41:49.760 |
It's identical to the bias we created earlier. 01:41:54.600 |
So then there's a few extra little things we have to think about. 01:41:58.120 |
One is what happens at inference time, right? 01:42:05.360 |
But the problem is that if we normalize in the same way at inference time, if we get 01:42:09.320 |
like a totally different kind of image, we might kind of remove all of the things that actually make it distinctive. 01:42:19.320 |
So what we do is while we're training, we keep an exponentially weighted moving average 01:42:29.360 |
I'll talk more about what that means in a moment. 01:42:31.400 |
But basically we've got a running average of the last few batches' means, and a running average of their variances. 01:42:39.120 |
And so then when we're not training, in other words at inference time, we don't use the 01:42:44.800 |
mean and variance of this mini-batch, we use that running average mean and variance that we've been keeping. 01:42:58.480 |
So how do we store those? Well, we don't just create something called self.vars; we register a buffer. 01:43:11.960 |
So why didn't we just say self.vars=torch.ones? 01:43:18.840 |
It's almost exactly the same as saying self.vars=torch.ones, but it does a couple of nice things. 01:43:25.840 |
The first is that if we move the model to the GPU, anything that's registered as a buffer gets moved to the GPU as well. 01:43:35.400 |
And if we didn't do that, then it's going to try and do this calculation down here. 01:43:39.680 |
And if the vars and means aren't on the GPU, but everything else is on the GPU, we'll get an error. 01:43:44.840 |
It'll say, "Oh, you're trying to add this thing on the CPU to this thing on the GPU, 01:43:49.280 |
So that's one nice thing about register buffer. 01:43:51.380 |
The other nice thing is that the variances and the means, these running averages, they're part of the model. 01:43:59.080 |
When we do inference, in order to calculate our predictions, we actually need to know 01:44:06.000 |
So if we save the model, we have to save those variances and means. 01:44:11.240 |
So register buffer also causes them to be saved along with everything else in the model. 01:44:19.600 |
So the variances, we start them out at ones, the means we start them out at zeros. 01:44:25.160 |
We then calculate the mean and variance of the minibatch, and we average out the axes 01:44:31.960 |
zero, two, and three. So in other words, we average over everything in the mini-batch, and we average over the x and y coordinates. 01:44:41.940 |
So all we're left with is a mean for each channel, or a mean for each filter. 01:44:50.880 |
keepdim equals true means that it's going to leave an empty unit axis in positions 01:44:57.640 |
zero, two, and three, so it'll still broadcast nicely. 01:45:08.140 |
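Putting the pieces together, a from-scratch batch norm along the lines described looks roughly like this (mults and adds are the learnable gamma and beta; like the version in the lesson, this sketch doesn't backpropagate through the batch statistics):

```python
import torch
import torch.nn as nn

class BatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf, 1, 1))    # starts as ones: does nothing
        self.adds  = nn.Parameter(torch.zeros(nf, 1, 1))    # starts as zeros: does nothing
        # buffers move to the GPU with the model and get saved along with it
        self.register_buffer('vars',  torch.ones (1, nf, 1, 1))
        self.register_buffer('means', torch.zeros(1, nf, 1, 1))

    def update_stats(self, x):
        m = x.mean((0, 2, 3), keepdim=True)    # one mean per channel
        v = x.var ((0, 2, 3), keepdim=True)    # one variance per channel
        self.means.lerp_(m, self.mom)          # running averages via lerp
        self.vars .lerp_(v, self.mom)
        return m, v

    def forward(self, x):
        if self.training:
            with torch.no_grad(): m, v = self.update_stats(x)
        else:
            m, v = self.means, self.vars       # at inference, use the running averages
        x = (x - m) / (v + self.eps).sqrt()
        return x * self.mults + self.adds
```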
So normally, if we want to take a moving average of a bunch of data points, 01:45:17.680 |
we would grab five at a time, and we would take the average of those five, and then we'd 01:45:24.880 |
take the next five, and we'd take their average, and we keep doing that a few at a time. 01:45:32.200 |
We don't want to do that here, though, because for these batch norm statistics, every single channel of every layer has its own. 01:45:43.600 |
Models can have hundreds of millions of activations. 01:45:47.800 |
We don't want to have to save a whole history of every single one of those, just so that we can average them. 01:45:54.840 |
So there's a handy trick for this, which is instead to use an exponentially weighted moving 01:46:02.380 |
And basically, what we do is we start out with this first point, and we say, okay, that's our starting average. 01:46:19.160 |
And what we do is to take an exponentially weighted moving average, we first of all need 01:46:24.320 |
some number, which we call momentum, let's say it's 0.9. 01:46:30.200 |
So for the first value, our exponentially weighted moving average mu one is just that first value. 01:46:38.840 |
And then for the second one, we take mu one, we multiply it by our momentum, and then we 01:46:50.480 |
add our second value, five, and we multiply it by one minus our momentum. 01:46:59.320 |
So in other words, it's mainly whatever it used to be before, plus a little bit of the 01:47:05.780 |
And then mu three equals mu two times 0.9, plus 0.1 times this one here. 01:47:21.120 |
So we're basically continuing to say, oh, it's mainly the thing before, plus a little bit of the new thing. 01:47:28.000 |
And so what you end up with is something where, like by the time we get to here, the amount 01:47:36.460 |
of influence of each of the previous data points, once you calculate it out, turns out to decay exponentially. 01:47:46.060 |
So it's a moving average with an exponential decay, with the benefit that we only ever have to keep track of a single value. 01:47:55.660 |
So that's what an exponentially weighted moving average is. 01:48:02.140 |
This thing we do here, where we basically say we've got some function where we say it's 01:48:08.980 |
some previous value times 0.9, say, plus some other value times one minus that thing; that's a linear interpolation. 01:48:28.080 |
It's a bit of this and a bit of this other thing, and the two together make one. 01:48:32.760 |
Linear interpolation in PyTorch is spelt lerp. 01:48:38.120 |
So we take the means, and then we lerp with our new mean using this amount of momentum. 01:48:49.320 |
Unfortunately, lerp uses the exact opposite of the normal sense of momentum. 01:48:58.080 |
So momentum of 0.1 in batch norm actually means momentum of 0.9 in normal person speak. 01:49:07.440 |
So this is actually how nn.batchnorm works as well. 01:49:14.560 |
So batch norm momentum is the opposite of what you would expect. 01:49:24.120 |
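A tiny worked example of lerp and the momentum convention (pure illustration):

```python
import torch

prev, new = torch.tensor(10.), torch.tensor(20.)

# "normal person" momentum of 0.9: mostly the old value, a little of the new one
mom = 0.9
ewma = prev * mom + new * (1 - mom)        # -> 11.0

# torch.lerp weights the *new* value, so the equivalent call passes 1 - mom
same = prev.lerp(new, 1 - mom)             # -> 11.0

# hence nn.BatchNorm2d(momentum=0.1) corresponds to momentum 0.9 in the usual sense
```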
So these are the running average means and variances. 01:49:32.200 |
So now we can create a new conv layer, where you can optionally say whether you want batch norm or not. 01:49:40.760 |
If we do append a batch norm layer, we remove the conv's bias. 01:49:44.400 |
Because remember I said that the adds in batch norm just are a bias. 01:49:49.160 |
So there's no point having a separate bias anymore. 01:49:56.360 |
And so now we can go ahead and initialize our CNN. 01:50:01.240 |
This is a slightly more convenient initialization now that's actually going to go in and recursively 01:50:08.480 |
initialize every module inside our module, so the weights end up with the right standard deviations. 01:50:08.480 |
And you can see our mean starts at 0 exactly. 01:50:29.400 |
And our standard deviation starts at 1 exactly. 01:50:35.400 |
So our training has entirely gotten rid of all of that exponential growth and sudden crash pattern. 01:50:46.840 |
There's something interesting going on at the very end of training, which I don't quite 01:50:52.040 |
I mean, when I say the end of training, we've only done one epoch. 01:50:56.240 |
But this is looking a lot better than anything we've seen before. 01:51:01.680 |
I mean, that's just a very nice-looking curve. 01:51:04.720 |
And so we're now able to get up to learning rates up to 1. 01:51:12.680 |
We've got 97% accuracy after just three epochs. 01:51:19.720 |
So now that we've built our own batch norm, we're allowed to use PyTorch's batch norm. 01:51:32.320 |
So now that we've got that, let's try going crazy. 01:51:34.800 |
Let's try using our little one-cycle learning scheduler we had. 01:51:39.880 |
And let's try and go all the way up to a learning rate of 2. 01:51:50.560 |
And we're now up towards nearly 99% accuracy. 01:52:00.360 |
Batch norm has a bit of a problem, though, which is that you can't apply it to what we call online learning tasks. 01:52:10.360 |
In other words, if you have a batch size of 1, right, so you're getting a single item 01:52:16.400 |
at a time and learning from that item, what's the variance of that batch? 01:52:21.760 |
The variance of a batch of 1 is infinite, right? 01:52:28.280 |
Well, what if we're doing like a segmentation task where we can only have a batch size of 01:52:32.800 |
2 or 4, which we've seen plenty of times in part 1? 01:52:37.960 |
Because across all of our layers, across all of our training, across all of the channels, 01:52:42.480 |
with a batch size of 2, at some point those two values are going to be the same or nearly the same. 01:52:49.840 |
And so we then divide by that variance, which is about 0. 01:52:55.800 |
So we have this problem where any time you have a small batch size, you're going to get this instability. The other place batch norm is a problem is RNNs. 01:53:07.720 |
Because for RNNs, remember, it looks something like this, right? 01:53:12.440 |
We have this hidden state, and we use the same weight matrix again and again and again. 01:53:19.360 |
Remember, we can unroll it, and it looks like this. 01:53:26.720 |
And then we can even stack them together into two RNNs. 01:53:35.360 |
And remember, these state, you know, time step to time step transitions, if we're doing IMDB 01:53:42.400 |
with a movie review with 2,000 words, there's 2,000 of these. 01:53:48.160 |
And this is the same weight matrix each time, and the number of these circles will vary. 01:53:53.560 |
It's the number of time steps will vary from document to document. 01:54:00.200 |
How would you say what's the running average of means and variances? 01:54:07.440 |
Because you can't put a different one between each of these unrolled layers, because, like, it's the same layer being reused each time. 01:54:13.820 |
So we can't have different values every time. 01:54:17.240 |
So it's not at all clear how you would insert batch norm into an RNN. 01:54:25.680 |
How do we handle very small batch sizes all the way down to a batch size of one? 01:54:32.520 |
So this paper called layer normalization suggests a solution to this. 01:54:40.920 |
And the layer normalization paper from Jimmy Ba and Kiros and Geoffrey Hinton, who just 01:54:49.880 |
won the Turing Award with Yoshua Bengio and Yann LeCun, which is kind of the Nobel 01:54:57.240 |
prize of computer science, they created this paper, which, like many papers, when you read 01:55:03.800 |
it, it looks reasonably terrifying, particularly once you start looking at all this stuff. 01:55:11.920 |
But actually, when we take this paper and we convert it to code, it's this. 01:55:18.440 |
Now, which is not to say the paper's garbage, it's just that the paper has lots of explanation 01:55:24.080 |
about what's going on and what do we find out and what does that mean. 01:55:29.760 |
It's the same as batch norm, but rather than saying x dot means 0, 2, 3, you say x dot 01:55:35.680 |
mean 1, 2, 3, and you remove all the running averages. 01:55:41.620 |
So this is layer norm with none of that running average stuff. 01:55:48.680 |
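In code, that's something like the following (a sketch; here a single learnable mult and add stand in for gamma and beta):

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    "Normalize each item over its channels and spatial positions; no running stats."
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps  = eps
        self.mult = nn.Parameter(torch.tensor(1.))
        self.add  = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        m = x.mean((1, 2, 3), keepdim=True)   # one mean per image
        v = x.var ((1, 2, 3), keepdim=True)   # one variance per image
        return (x - m) / (v + self.eps).sqrt() * self.mult + self.add
```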
And the reason we don't need the running averages anymore is because we're not taking the mean across the items in a batch. 01:55:57.460 |
Every image has its own mean, every image has its own standard deviation. 01:56:02.520 |
So there's no concept of having to average across things in a batch. 01:56:14.580 |
So we average over the channels and the x and the y for each image individually. 01:56:20.080 |
So we don't have to keep track of any running averages. 01:56:23.280 |
The problem is that when we do that and we train, even at a lower learning rate of 0.8, it doesn't train nearly as well. 01:56:33.560 |
So it's a workaround we can use, but because we don't have the running averages at inference 01:56:40.000 |
time and more importantly, because we don't have a different normalization for each channel, 01:56:48.640 |
we're just throwing them all together and pretending they're the same and they're not. 01:56:52.800 |
So layer norm helps, but it's nowhere near as good as batch norm. 01:56:58.000 |
But for RNNs, what you have to use is something like this. 01:57:06.880 |
What if you're using layer norm on the actual input data and you're trying to distinguish between, say, foggy days and sunny days? 01:57:16.920 |
So foggy days will have less activations on average, because they're less bright and they have less contrast. 01:57:29.320 |
So layer norm would cause the variances to be normalized to be the same, and the means to be the same. 01:57:39.580 |
So now the sunny day picture and the hazy day picture would have the same overall kind of level of activation. 01:57:48.560 |
And so the answer to this question is, no, you couldn't. 01:57:52.160 |
With layer norm, you would literally not be able to tell the difference between pictures 01:57:57.840 |
Now, it's not only if you put the layer norm on the input data, which you wouldn't do, 01:58:03.400 |
but everywhere in the middle layers, it's the same, right? 01:58:07.200 |
Anywhere where the overall level of activation or the amount of difference of activation 01:58:11.580 |
is something that is part of what you care about, it throws it away. 01:58:19.480 |
Furthermore, if your inference time is using things from kind of a different distribution 01:58:24.820 |
where that different distribution is important, it throws that away. 01:58:28.960 |
So layer norm's a partial hacky workaround for some very genuine problems. 01:58:35.480 |
There's also something called instance norm, and instance norm is basically the same thing taken a step further. 01:58:43.960 |
It's a bit easier to read in the paper because they actually lay out all the indexes. 01:58:48.140 |
So a particular output for a particular batch for a particular channel for a particular 01:58:51.960 |
x for a particular y is equal to the input for that batch and channel in x and y, minus the mean for that batch and channel, divided by the standard deviation for that batch and channel. 01:59:01.960 |
So in other words, it's the same as layer norm, but now it's mean 2,3 rather than mean 1,2,3. 01:59:07.800 |
So you can see how all these different papers, when you turn them into code, they're tiny variations. 01:59:16.400 |
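And instance norm is the same shape of code with the axes changed (another sketch):

```python
import torch
import torch.nn as nn

class InstanceNorm(nn.Module):
    "One mean and variance per image *and* per channel: average over (2, 3) only."
    def __init__(self, nf, eps=1e-5):
        super().__init__()
        self.eps   = eps
        self.mults = nn.Parameter(torch.ones (nf, 1, 1))
        self.adds  = nn.Parameter(torch.zeros(nf, 1, 1))

    def forward(self, x):
        m = x.mean((2, 3), keepdim=True)
        v = x.var ((2, 3), keepdim=True)
        return (x - m) / (v + self.eps).sqrt() * self.mults + self.adds
```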
Instance norm, even at a learning rate of 0.1, doesn't learn anything at all. 01:59:24.280 |
Because we're now taking the mean, removing the difference in means and the difference 01:59:29.960 |
in activations for every channel and for every image, which means we've literally thrown 01:59:35.920 |
away all the things that allow us to classify. 01:59:45.800 |
It was designed for style transfer, where the authors guessed that these differences 01:59:52.360 |
in contrast and overall amount were not important, or something they should remove from trying 01:59:58.600 |
to create things that looked like different types of pictures. 02:00:05.440 |
You can't just go in and say, "Oh, here's another normalization thing, I'll try it." 02:00:09.200 |
You've got to actually know what it's for to know whether it's going to work. 02:00:13.400 |
So then finally, there's a paper called Group Norm, which has this wonderful picture showing the differences. 02:00:20.080 |
Batch Norm is averaging over the batch, and the height, and the width, and is different for each channel. 02:00:33.320 |
Layer Norm is averaging for each channel, for each height, for each width, and is different for each item in the batch. 02:00:40.960 |
Instance Norm is averaging over height and width, and is different for each channel and for each item in the batch. 02:00:51.000 |
And then Group Norm is the same as Instance Norm, but they arbitrarily group a few channels together and average over those as well. 02:00:59.320 |
So Group Norm is a more general way to do it. 02:01:05.280 |
In the PyTorch docs, they point out that you can actually turn Group Norm into Instance 02:01:09.580 |
Norm, or Group Norm into Layer Norm, depending on how you group things up. 02:01:17.920 |
So there's all kinds of attempts to work around the problem that we can't use small batch 02:01:26.080 |
sizes, and we can't use RNNs with Batch Norm. 02:01:37.360 |
Well, I don't know how to fix the RNN problem, but I think I know how to fix the batch size problem. 02:01:46.320 |
So let's start by taking a look at the batch size problem in practice. 02:01:49.240 |
Let's create a new data bunch with a batch size of 2. 02:01:55.120 |
And so here's our conv layer, as before, with our Batch Norm. 02:01:59.580 |
And let's use a learning rate of 0.4 and fit that. 02:02:04.960 |
And the first thing you'll notice is that it takes a long time. 02:02:09.720 |
Small batch sizes take a long time, because it's just lots and lots of kernel launches on the GPU. 02:02:17.400 |
Something like this might even run faster on the CPU. 02:02:20.760 |
And then you'll notice that it's only 26% accurate, which is awful. 02:02:27.520 |
Because of what I said, the small batch size is causing a huge problem. 02:02:33.660 |
Because quite often, there's one channel in one layer where the variance is really small, 02:02:39.840 |
because those two numbers just happen to be really close, and so it blows out the activations 02:02:43.800 |
out to a billion, and everything falls apart. 02:02:49.400 |
There is one thing we could try to do to fix this really easily, which is to use Epsilon. 02:03:06.480 |
Look, we don't divide by the square root of variance. 02:03:12.240 |
We divide by the square root of variance plus Epsilon, where Epsilon is 1e-5. 02:03:19.660 |
Epsilon's a number that computer scientists and mathematicians, they use this Greek letter 02:03:25.600 |
very frequently to mean some very small number. 02:03:29.560 |
And in computer science, it's normally a small number that you add to avoid floating point problems. 02:03:39.100 |
So it's very common to see it on the bottom of a division to avoid dividing by such small 02:03:46.320 |
numbers that you can't calculate things in floating point properly. 02:03:51.880 |
But our view is that Epsilon is actually a fantastic hyperparameter that you should be taking advantage of. 02:04:03.720 |
With Batch Norm, what if we didn't set Epsilon to 1e-5? 02:04:11.340 |
If we set Epsilon to 0.1, then that basically would cause this to never make the overall 02:04:20.240 |
activations be multiplied by anything more than 10. 02:04:23.760 |
Sorry, that would be 0.01 because we're taking the square root. 02:04:27.800 |
So if you set it to 0.01, let's say the variance was 0, it would be 0 plus 0.01 square root. 02:04:37.400 |
So it ends up dividing by 0.1, which ends up multiplying by 10. 02:04:41.300 |
So even in the worst case, it's not going to blow out. 02:04:45.280 |
I mean, it's still not great because there actually are huge differences in variance 02:04:52.520 |
between different channels and different layers, but at least this would cause it to not fall apart completely. 02:04:57.240 |
So option number one would be use a much higher Epsilon value. 02:05:02.920 |
And we'll keep coming back to this idea that Epsilon appears in lots of places in deep learning 02:05:06.640 |
and we should use it as a hyper parameter we control and take advantage of. 02:05:17.900 |
We think we have a better idea, which is we've built a new algorithm called running batch norm. 02:05:24.000 |
And running batch norm, I think, is the first true solution to the small batch size batch norm problem. 02:05:36.120 |
And like everything we do at fast AI, it's ridiculously simple. 02:05:40.240 |
And I don't know why no one's done it before. 02:05:49.500 |
In the forward function for running batch norm, don't divide by the batch standard deviation. 02:05:59.320 |
Don't subtract the batch mean, but instead use the moving average statistics at training time too, not just at inference time. 02:06:14.840 |
Because let's say you're using a batch size of two. 02:06:20.200 |
Then from time to time, in this particular layer, in this particular channel, you happen 02:06:24.760 |
to get two values that are really close together and they have a variance really close to zero. 02:06:30.000 |
But that's fine, because you're only taking point one of that and point nine of whatever you had before. 02:06:38.720 |
So if previously the variance was one, now it's not 1e-5, it's just point nine. 02:06:47.400 |
So in this way, as long as you don't get really unlucky and have the very first batch be dreadful, 02:06:56.440 |
because you're using this moving average, you never have this problem. 02:07:02.960 |
We'll look at the code in a moment, but let's do the same thing, 0.4. 02:07:10.600 |
We train it for one epoch, and instead of 26% accuracy, it's 91% accuracy. 02:07:22.680 |
In one epoch, just a two batch size and a pretty high learning rate. 02:07:28.920 |
There's quite a few details we have to get right to make this work. 02:07:36.380 |
But they're all details that we're going to see in lots of other places in this course. 02:07:41.080 |
We're just kind of seeing them here for the first time. 02:07:44.720 |
So I'm going to show you all of the details, but don't get overwhelmed. 02:07:50.760 |
The first detail is something very simple, which is in normal batch norm, we take the 02:07:57.600 |
running average of variance, but you can't take the running average of variance. 02:08:03.140 |
It doesn't make sense to take the running average of variance. 02:08:06.800 |
You can't just average a bunch of variances, particularly because they might even be different 02:08:12.760 |
batch sizes, because batch size isn't necessarily constant. 02:08:16.600 |
Instead, as we learned earlier in the class, the way that we want to calculate variance 02:08:25.600 |
is like this: the mean of X squared, minus the square of the mean of X. 02:08:34.960 |
So instead, as I mentioned, we can just keep track of the sums and the sums of squares. 02:08:42.920 |
So we register a buffer called sums and we register a buffer called squares and we just 02:08:50.000 |
go x.sum over the 0, 2, 3 dimensions, and x times x .sum, so the sum of the squares. 02:09:01.040 |
And then we'll take the lerp, the exponentially weighted moving average of the sums and the 02:09:08.080 |
And then for the variance, we will do the squares divided by the count, minus the mean squared. 02:09:23.260 |
So that's detail number one that we have to be careful of. 02:09:29.400 |
Detail number two is that the batch size could vary from mini-batch to mini-batch. 02:09:36.080 |
So we should also register a buffer for count, and take an exponentially weighted moving average of that as well. 02:09:47.180 |
So that basically tells us, so what do we need to divide by each time? 02:09:53.000 |
The amount we need to divide by each time is the total number of elements in the mini-batch for each channel. 02:10:01.960 |
That's basically grid X times grid Y times batch size. 02:10:06.940 |
So let's take an exponentially weighted moving average of the count and then that's what 02:10:12.840 |
we will divide by for both our means and variances. 02:10:22.040 |
Detail number three is that we need to do something called debiasing. 02:10:32.280 |
We want to make sure that at every point, and we're going to look at this in more detail 02:10:38.400 |
when we look at optimizers, we want to make sure that at every point, no observation gets the wrong amount of weight. 02:10:47.160 |
And the problem is that the normal way of doing moving averages, the very first point 02:10:54.020 |
gets far too much weight, because it appears in the first moving average, and the second, and so on. 02:11:00.380 |
So there's a really simple way to fix this, which is that you initialize both sums and 02:11:05.440 |
squares to zeros, and then you do a lerp in the usual way, and let's see what happens when we do that. 02:11:35.000 |
So actually we only need to look at the first value. 02:11:38.540 |
So the value, so actually let's say the value is 10. 02:11:42.940 |
So we initialize our mean to zero at the very start of training. 02:11:53.740 |
So we would expect the moving average to be 10. 02:11:56.460 |
But our lerp formula says it's equal to our previous value, which is 0, times 0.9 plus 02:12:05.820 |
our new value times 0.1 equals 0 plus 1, equals 1, it's 10 times too small. 02:12:17.820 |
So that's very easy to correct for, because we know it's always going to be wrong by that factor. 02:12:24.620 |
So we then divide it by 0.1 and that fixes it. 02:12:31.540 |
And then the second value has exactly the same problem. 02:12:37.780 |
But this time it's actually going to be divided by, well, let's not call it 0.1. 02:12:48.240 |
Because when you work through the math, you'll see the second one is going to be divided by 0.19. 02:12:59.980 |
So this thing here where we divide by that, that's called debiasing. 02:13:03.620 |
It's going to appear again when we look at optimization. 02:13:08.700 |
So you can see what we do is we have an exponentially weighted debiasing amount, where we simply 02:13:16.380 |
keep multiplying momentum times the previous debiasing amount. 02:13:22.560 |
So initially it's just equal to momentum and then momentum squared and then momentum cubed 02:13:32.300 |
So then we do what I just said, we divide by the debiasing amount. 02:13:42.540 |
And then there's just one more thing we do, which is remember how I said you might get 02:13:48.060 |
really unlucky that your first mini-batch is just really close to zero, and we don't want that to blow everything up. 02:13:55.420 |
So I just say if you haven't seen more than a total of 20 items yet, just clamp the variance 02:14:01.780 |
to be no smaller than 0.01, just to avoid it blowing up. 02:14:01.780 |
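Pulling those details together, a simplified sketch of running batch norm looks like this. It keeps running sums, sums of squares and counts, lerps them, debiases, and clamps the variance early on; the lesson's actual version also handles some subtleties this sketch skips (for example, letting gradients flow through the current batch's contribution to the statistics, and adjusting the momentum for the batch size):

```python
import torch
import torch.nn as nn

class RunningBatchNorm(nn.Module):
    def __init__(self, nf, mom=0.1, eps=1e-5):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.mults = nn.Parameter(torch.ones (nf, 1, 1))
        self.adds  = nn.Parameter(torch.zeros(nf, 1, 1))
        self.register_buffer('sums',  torch.zeros(1, nf, 1, 1))  # running sum of x
        self.register_buffer('sqrs',  torch.zeros(1, nf, 1, 1))  # running sum of x*x
        self.register_buffer('count', torch.tensor(0.))          # running element count
        self.register_buffer('batch', torch.tensor(0.))          # total items seen so far
        self.register_buffer('dbias', torch.tensor(0.))          # debiasing denominator

    def update_stats(self, x):
        bs, nc, *_ = x.shape
        dims = (0, 2, 3)
        s  = x.sum(dims, keepdim=True)
        ss = (x * x).sum(dims, keepdim=True)
        c  = torch.tensor(x.numel() / nc, dtype=x.dtype, device=x.device)  # bs * grid_x * grid_y
        self.sums .lerp_(s,  self.mom)          # running averages of sums, squares, counts
        self.sqrs .lerp_(ss, self.mom)
        self.count.lerp_(c,  self.mom)
        self.dbias = self.dbias * (1 - self.mom) + self.mom   # 0.1, 0.19, 0.271, ...
        self.batch += bs

    def forward(self, x):
        if self.training:
            with torch.no_grad(): self.update_stats(x)
        sums, sqrs, c = self.sums / self.dbias, self.sqrs / self.dbias, self.count / self.dbias
        means = sums / c
        vars  = sqrs / c - means * means        # mean of x^2 minus the mean of x, squared
        if bool(self.batch < 20): vars = vars.clamp_min(0.01)  # guard against an unlucky start
        x = (x - means) / (vars + self.eps).sqrt()
        return x * self.mults + self.adds
```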
It's a very straightforward idea, but when we put it all together, it's shockingly effective. 02:14:29.120 |
And so then we can try an interesting thought experiment, so here's another thing to try 02:14:35.460 |
What's the best accuracy you can get in a single epoch? 02:14:45.380 |
And with this convolutional with running batch norm layer and a batch size of 32 and a linear 02:15:01.380 |
I only tried a couple of things, so this is definitely something that I hope you can beat 02:15:07.500 |
But it's really good to create interesting little games to play. 02:15:16.980 |
Almost everything in research is basically toy problems. 02:15:19.260 |
Come up with toy problems and try to find good solutions to them. 02:15:23.540 |
So another toy problem for this week is what's the best you can get using whatever kind of 02:15:31.300 |
normalization you like, whatever kind of architecture you like, as long as it only uses concepts 02:15:36.980 |
we've used up to lesson 7 to get the best accuracy you can in one epoch. 02:15:58.620 |
We haven't done all the kind of ablation studies and stuff we need to do yet. 02:16:02.940 |
At this stage, though, I'm really excited about this. 02:16:05.020 |
Every time I've tried it on something, it's been working really well. 02:16:10.220 |
The last time that we had something in a lesson that we said, this is unpublished research 02:16:15.100 |
that we're excited about, it turned into ULMFiT, which is now a really widely used algorithm for text classification. 02:16:28.700 |
So fingers crossed that this turns out to be something really terrific as well. 02:16:32.420 |
But either way, you've kind of got to see the process, because literally building these 02:16:38.220 |
notebooks was the process I used to create this algorithm. 02:16:41.500 |
So you've seen the exact process that I used to build up this idea and do some initial experiments with it. 02:16:50.380 |
So hopefully that's been fun for you, and see you next week.