
Lesson 10 (2019) - Looking inside the model


Chapters

0:00 Introduction
0:21 Don't worry
2:08 Where are we
2:49 Training loop
4:26 fastai.audio
5:45 Vision Topics
7:12 What is a callback
11:20 Creating a callback
13:31 lambda notation
17:23 partial function application
18:25 class as a callback
19:53 kwargs
24:13 change
27:22 dunder
29:34 editor tips
35:21 variance
40:15 covariance
44:59 softmax
52:44 learning rate finder
54:34 refactoring

Whisper Transcript

00:00:00.000 | Welcome to lesson 10, which I've rather enthusiastically titled "Wrapping up our CNN", but looking at
00:00:10.020 | how many things we want to cover, I've added a "nearly" to the end, and I'm not actually sure
00:00:14.800 | how nearly we'll get there.
00:00:15.880 | We'll see.
00:00:16.880 | We'll probably have a few more things to cover next week as well.
00:00:22.680 | I just wanted to remind you after hearing from a few folks during the week who are very
00:00:28.200 | sad that they're not quite keeping up with everything.
00:00:31.680 | That's totally okay.
00:00:33.280 | Don't worry.
00:00:34.480 | As I mentioned in lesson one, I'm trying to give you enough here to keep you busy until
00:00:39.920 | the next Part 2, next year.
00:00:44.780 | So you can dive into the bits you're interested in and go back and look over stuff and yeah,
00:00:52.200 | don't feel like you have to understand everything within a week of first hearing it.
00:00:57.580 | But also if you're not putting in the time during the homework or you didn't put in the
00:01:02.440 | time during the homework in the last part, you know, expect to have to go back and revisit
00:01:07.160 | things, particularly because a lot of the stuff we covered in part one I'm kind of assuming
00:01:11.880 | that you're deeply comfortable with at this point.
00:01:16.100 | Not because you're stupid if you're not, but just because it gives you the opportunity
00:01:20.000 | to go back and re-study it and practice and experiment until you are deeply comfortable.
00:01:26.320 | So yeah, if you're finding it whizzing along at a pace, that is because it is whizzing
00:01:31.680 | along at a pace.
00:01:32.880 | Also it's covering a lot more software engineering kind of stuff, which for the people
00:01:38.720 | who are practicing software engineers, you'll be thinking this is all pretty straightforward.
00:01:42.640 | And for those of you that are not, you'll be thinking, wow, there's a lot here.
00:01:47.880 | Part of that is because I think data scientists need to be good software engineers.
00:01:52.080 | So I'm trying to show you some of these things, but, you know, it's stuff which people can
00:01:56.360 | spend years learning.
00:01:58.420 | And so hopefully this is the start of a long process, for those of you that haven't done software
00:02:02.520 | engineering before, of becoming better software engineers.
00:02:05.160 | And there are some useful tips, hopefully.
00:02:09.200 | So to remind you, we're trying to recreate fastai and much of PyTorch from these foundations.
00:02:21.280 | And starting to make things even better.
00:02:24.600 | And today you'll actually see some bits, well, in fact, you've already seen some bits that
00:02:28.000 | are going to be even better.
00:02:29.800 | I think the next version of fastai will have this new callback system, which I think is
00:02:33.120 | better than the old one.
00:02:34.480 | And today we're going to be showing you some new previously unpublished research, which
00:02:39.400 | will be finding its way into fastai and maybe other libraries as well.
00:02:45.680 | So we're going to try and stick to, and we will stick to, using nothing but these foundations.
00:02:51.100 | And we're working through developing a modern CNN model, and we've got to the point where
00:02:55.180 | we've done our training loop at this point, and we've got a nice flexible training loop.
00:03:01.960 | So from here, the rest of it, when I say we're going to finish out a modern CNN model, it's
00:03:09.320 | not just going to be some basic getting by model, but we're actually going to endeavor
00:03:13.660 | to get something that is approximately state-of-the-art on ImageNet in the next week or two.
00:03:20.160 | So that's the goal.
00:03:22.840 | And in our testing at this point, we're feeling pretty good about showing you some stuff that
00:03:29.080 | maybe hasn't been seen before on ImageNet results.
00:03:32.960 | So that's where we're going to try and head as a group.
00:03:36.200 | And so these are some of the things that we're going to be covering to get there.
00:03:43.700 | One of the things you might not have seen before, in this section called optimization,
00:03:46.560 | is LAMB.
00:03:48.480 | The reason for this is that this was going to be some of the unpublished research we
00:03:51.520 | were going to show you, which is a new optimization algorithm that we've been developing.
00:03:56.960 | The framework is still going to be new, but actually the particular approach to using
00:04:01.260 | it was published by Google two days ago, so we've kind of been scooped there.
00:04:07.000 | So this is a cool paper, really great, and they introduce a new optimization algorithm
00:04:11.640 | called LAMB, and we'll be showing you how to implement it very easily.
00:04:18.600 | And if you're wondering how we're able to do that so fast, it's because we've kind of
00:04:21.520 | been working on the same thing ourselves for a few weeks now.
00:04:26.880 | So then from next week, we'll start also developing a completely new fastai module called fastai.audio.
00:04:36.000 | So you'll be seeing how to actually create modules and how to write Jupyter documentation
00:04:39.760 | and tests.
00:04:41.240 | And we're going to be learning about audio, such as complex numbers and Fourier transforms,
00:04:45.480 | which if you're like me at this point, you're going, oh, what, no, because I managed to
00:04:50.840 | spend my life avoiding complex numbers and Fourier transforms on the whole.
00:04:54.920 | But don't worry, it'll be OK.
00:04:58.680 | It's actually not at all bad, or at least the bits we need to learn about are not at
00:05:02.000 | all bad.
00:05:03.000 | And you'll totally get it even if you've never ever touched these before.
00:05:06.920 | We'll be learning about audio formats and spectrograms, doing data augmentation on
00:05:12.560 | things that aren't images, and some particular kinds of loss functions and architectures
00:05:17.280 | for audio.
00:05:19.800 | And as much as anything, it'll just be a great kind of exercise in, OK, I've got some different
00:05:24.100 | data type that's not in fastai.
00:05:26.100 | How do I build up all the bits I need to make it work?
00:05:31.000 | Then we'll be looking at neural translation as a way to learn about sequence to sequence
00:05:34.620 | with attention models, and then we'll be going deeper and deeper into attention models, looking
00:05:38.460 | at Transformer and its even more fantastic descendant, Transformer-XL.
00:05:46.840 | And then we'll wrap up our Python adventures with a deep dive into some really interesting
00:05:51.640 | vision topics, which is going to require building some bigger models.
00:05:56.520 | So we'll talk about how to build your own deep learning box, how to run big experiments
00:06:00.520 | on AWS with a new library we've developed called fastec2.
00:06:06.560 | And then we're going to see exactly what happened last course when we did that U-Net super-resolution
00:06:11.240 | image generation, what are some of the pieces there.
00:06:15.100 | And we've actually got some really exciting new results to show you, which have been done
00:06:18.920 | in collaboration with some really cool partners.
00:06:21.140 | So I'm looking forward to showing you that. To give you a hint:
00:06:25.240 | generative video models are what we're going to be looking at.
00:06:28.720 | And then we'll be looking at some interesting different applications: DeViSE, CycleGAN,
00:06:33.520 | and object detection.
00:06:36.320 | And then Swift, of course.
00:06:40.960 | So the Swift lessons are coming together nicely, really excited about them, and we'll be covering
00:06:48.400 | as much of the same territory as we can, but obviously it'll be in Swift and it'll be in
00:06:52.060 | only two lessons, so it won't be everything, but we'll try to give you enough of a taste
00:06:57.320 | that you'll feel like you understand why Swift is important and how to get started with building
00:07:03.560 | something similar in Swift.
00:07:05.320 | And maybe building out the whole thing in Swift will take the next 12 months.
00:07:09.560 | Who knows?
00:07:10.560 | We'll see.
00:07:13.920 | So we're going to start today on 05a_foundations.
00:07:19.880 | And what we're going to do is revisit some of the software engineering
00:07:23.680 | and math basics that we were relying on last week, going into a little bit more detail.
00:07:28.760 | Specifically, we'll be looking at callbacks and variance and a couple of other Python concepts
00:07:35.840 | like dunder special methods.
00:07:38.420 | If you're familiar with those things and you're watching the video, feel free to skip ahead
00:07:41.320 | until we get to the new material.
00:07:44.440 | Callbacks, as I'm sure you've seen, are super important for fastai, and in general they're
00:07:50.520 | a really useful technique for software engineering.
00:07:55.040 | And they're great for researchers, because they allow you to build things that you can quickly adjust
00:07:59.800 | and add things in and pull them out again.
00:08:04.320 | So what is a callback?
00:08:06.360 | Let's look at an example.
00:08:10.120 | So here's a function called f, which prints out hi.
00:08:14.760 | And I'm going to create a button, and I'm going to create this button using ipywidgets,
00:08:21.400 | which is a framework for creating GUI widgets in Python.
00:08:26.160 | So if I display w, then it shows me a button which says "Click me", and I can click
00:08:34.200 | on it, and nothing happens.
00:08:37.720 | So how do I get something to happen?
00:08:39.160 | Well, what I need to do is pass a function to the ipywidgets framework to
00:08:47.120 | say please run this function when you click on this button.
00:08:52.360 | So the ipywidgets docs say that there's an on_click method which can register a function to be
00:08:59.040 | called when the button is clicked.
00:09:01.080 | So let's try running that method, passing it f, my function.
00:09:06.560 | OK, so now nothing happened; it didn't run anything. But now if I click on here... oh, hi, hi.
00:09:16.640 | So what happened is I told w that when a click occurs, you should call back to my f function
00:09:29.080 | and run it.
00:09:30.080 | So anybody who's done GUI programming will be extremely comfortable with this idea, and
00:09:34.280 | if you haven't, this will be kind of mind bending.
00:09:37.800 | So f is a callback.
00:09:40.620 | It's not a particular class.
00:09:42.400 | It doesn't have a particular signature.
00:09:44.120 | It's not a particular library.
00:09:45.760 | It's a concept.
00:09:47.680 | It's a function that we treat as an object.
00:09:51.200 | So look, we're not calling the function.
00:09:53.480 | We don't have any parentheses after f.
00:09:55.120 | We're passing the function itself to this method, and it says please call back to me
00:10:03.120 | when something happens, and in this case, it's when I click.
00:10:07.960 | So there's our starting point.
00:10:09.400 | And these kinds of functions, these kinds of callbacks that are used in a GUI framework in particular
00:10:16.520 | when some event happens, are often called events.
00:10:19.840 | So if you've heard of events, they're a kind of callback, and then callbacks are a kind
00:10:24.840 | of what we would call a function pointer.
00:10:26.800 | I mean, they can be much more general than that, as you'll see, but it's basically a
00:10:30.520 | way of passing in something to say call back to this when something happens.
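The callback idea doesn't depend on ipywidgets. Here is a minimal sketch with a hypothetical `Button` class written just for illustration (it is not the ipywidgets class): it stores the functions passed to `on_click` and calls each one back, passing the button itself, when a click is simulated.

```python
class Button:
    """Hypothetical stand-in for a GUI button; not the ipywidgets class."""

    def __init__(self, description):
        self.description = description
        self._handlers = []  # functions to call back when clicked

    def on_click(self, fn):
        # Store the function object itself; note there are no parentheses after fn
        self._handlers.append(fn)

    def click(self):
        # Simulate the GUI event: call back to every registered function
        for fn in self._handlers:
            fn(self)


def f(o):
    print("hi")


w = Button("Click me")
w.on_click(f)  # registering prints nothing
w.click()      # now f is called back and prints "hi"
```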
00:10:34.800 | Now, by the way, these widgets are really worth looking at if you're interested in building
00:10:41.720 | some analytical GUIs.
00:10:44.160 | Here's a great example from the Plotly documentation of the kinds of things you can create with
00:10:50.960 | widgets, and it's not just for creating applications for others to use, but if you want to experiment
00:10:56.000 | with different types of function or hyperparameters or explore some data you've collected, widgets
00:11:04.720 | are a great way to do that, and as you can see, they're very, very easy to use.
00:11:11.600 | In part one, you saw the image labeling stuff that was built with widgets like this.
00:11:21.040 | So that's how you can use somebody else's callback.
00:11:24.000 | Now to create our own callback.
00:11:26.680 | So let's create a callback, and the event that it's going to call back on is after a
00:11:32.200 | calculation is complete.
00:11:34.840 | So let's create a function called slow_calculation, and it's going to do five calculations.
00:11:41.120 | It's going to add i squared to a result, and then it's going to take a second to do it
00:11:46.040 | because we're going to add a sleep there.
00:11:47.440 | So this is kind of something like an epoch of deep learning.
00:11:51.160 | It's some calculation that takes a while.
00:11:53.360 | So if we call slow_calculation, then it's going to take five seconds to calculate the
00:11:59.880 | sum of i squared, and there it's done it.
00:12:02.200 | So I'd really like to know how's it going, get some progress.
00:12:06.740 | So we could take that and add a parameter that lets you pass in a callback, and we just add
00:12:14.080 | one line of code that says if there's a callback, then call it and pass in the epoch number.
00:12:21.880 | So then we could create a function called show_progress that prints out "Awesome! We've
00:12:25.560 | finished epoch number {epoch}", and look, it takes a parameter and we're passing a parameter,
00:12:31.840 | so therefore we could now call slow_calculation and pass in show_progress, and it will call
00:12:39.920 | back to our function after each epoch.
00:12:43.680 | So there's our starting point for our callback.
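A minimal sketch of that starting point, assuming `time.sleep` stands in for a slow epoch. The `delay` parameter is our own addition (the lesson just sleeps for a second) so the example can run quickly.

```python
from time import sleep


def slow_calculation(cb=None, delay=1.0):
    """Sum i**2 for i in 0..4, pausing `delay` seconds per step to mimic an epoch."""
    res = 0
    for i in range(5):
        res += i * i
        sleep(delay)
        if cb:
            cb(i)  # call back with the epoch number
    return res


def show_progress(epoch):
    print(f"Awesome! We've finished epoch {epoch}!")


slow_calculation(show_progress, delay=0.1)
```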
00:12:46.640 | Now, what will tend to happen, you'll notice with stuff that we do in fast.ai, we'll start
00:12:52.840 | somewhere like this that, for many of you, is trivially easy, and at some point during
00:12:58.440 | the next hour or two, you might reach a point where you're feeling totally lost.
00:13:04.200 | And the trick is to go back, if you're watching the video, to the point where it was trivially
00:13:08.520 | easy and figure out the bit where you suddenly noticed you were totally lost, and find the
00:13:12.200 | bit in the middle where you kind of missed a bit, because we're going to just keep building
00:13:15.760 | up from trivially easy stuff, just like we did with that matrix multiplication, right?
00:13:20.440 | So we're going to gradually build up from here and look at more and more interesting
00:13:23.200 | callbacks, but we're starting with this wonderfully short and simple line of code.
00:13:32.560 | So rather than defining a function just for the purpose of using it once, we can actually
00:13:38.560 | define the function at the point we use it using lambda notation.
00:13:43.000 | So lambda notation is just another way of creating a function.
00:13:45.960 | So rather than saying def, we say lambda, and then rather than putting in parentheses
00:13:50.640 | the arguments, we put them before a colon, and then we list the thing you want to do.
00:13:55.820 | So this is identical to the previous one.
00:13:59.280 | It's just a convenience for times where you want to define the callback at the same time
00:14:05.080 | that you use it, and it can make your code a little bit more concise.
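A sketch of the lambda version, using a simplified `slow_calculation` with the sleep dropped so it runs instantly. The lambda and the `def` below define the same callback.

```python
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        res += i * i
        if cb:
            cb(i)
    return res


# Define the callback inline, at the point where it is used
slow_calculation(lambda o: print(f"Awesome! We've finished epoch {o}!"))


# The equivalent def-based callback
def show_progress(epoch):
    print(f"Awesome! We've finished epoch {epoch}!")
```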
00:14:10.540 | What if you wanted to have something where you could define what exclamation to use in
00:14:16.520 | the string as well?
00:14:17.520 | So we've now got two things.
00:14:20.060 | We can't pass this show_progress to slow_calculation.
00:14:26.000 | Let's try it.
00:14:28.700 | Right?
00:14:31.680 | It tries to call back, and remember cb is now show_progress, so it's calling show_progress
00:14:36.560 | and passing the epoch as exclamation, and then epoch is missing.
00:14:41.480 | So that's an error.
00:14:42.480 | We've called a function that takes two arguments with only one.
00:14:45.760 | So we have to convert this into a function with only one argument.
00:14:49.960 | So lambda o is a function with only one argument, and this function calls show_progress with
00:14:56.500 | a particular exclamation.
00:14:59.720 | So we've converted something with two arguments into something with one argument.
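A sketch of both halves of that step: passing the two-argument `show_progress` directly raises a `TypeError`, and a one-argument lambda that fixes the exclamation adapts it. The exclamation strings are just examples.

```python
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        res += i * i
        if cb:
            cb(i)  # supplies only one argument
    return res


def show_progress(exclamation, epoch):
    print(f"{exclamation}! We've finished epoch {epoch}!")


# cb(i) binds i to `exclamation`, so `epoch` is missing
try:
    slow_calculation(show_progress)
except TypeError as e:
    print("TypeError:", e)

# Wrap it in a one-argument lambda that fixes the exclamation
slow_calculation(lambda o: show_progress("OK I guess", o))
```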
00:15:04.880 | We might want to make it really easy to allow people to create different progress indicators
00:15:08.920 | with different exclamations.
00:15:10.840 | So we could create a function called make_show_progress that returns that lambda.
00:15:17.240 | So now we could say make_show_progress, so we could do that here.
00:15:26.360 | make_show_progress.
00:15:33.640 | And that's the same thing.
00:15:37.880 | This is a little bit awkward, so generally you might see it done like this instead.
00:15:42.200 | You see this in fastai all the time.
00:15:44.220 | We define the function inside it, but this is basically just the same as our lambda,
00:15:51.080 | and then we return that function.
00:15:53.000 | So this is kind of interesting, because you might think of defining a function as being
00:15:57.680 | like a declarative thing, that as soon as you define it, it's now part of the thing
00:16:02.200 | that's compiled; if you've seen C++, that's how they work.
00:16:06.200 | In Python, that's not how they work.
00:16:08.020 | When you define a function, you're actually saying basically the same as this, which is
00:16:13.800 | there's a variable with this name, which is a function.
00:16:19.640 | And that's how come then we can actually take something that's passed to this function and
00:16:25.440 | use it inside here.
00:16:27.480 | So this is actually, every time we call make show progress, it's going to create a new
00:16:32.160 | function, _inner internally, with a different exclamation.
00:16:37.440 | And so it'll work the same as before.
00:16:44.300 | So this thing where you create a function that actually stores some information from
00:16:48.200 | the external context, and like it can be different every time, that's called a closure.
00:16:52.240 | So it's a concept you'll come across a lot, particularly if you're a JavaScript programmer.
00:16:57.300 | So we could say f2 = make_show_progress("Terrific").
00:17:09.600 | And so that now contains that closure.
00:17:13.640 | So it actually remembers what exclamation you passed it.
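A sketch of the two closure styles described above. The `_lambda` suffix on the first factory is our own naming, used only to keep both versions in one block. Inspecting `__closure__` shows the captured exclamation stored with the returned function.

```python
def show_progress(exclamation, epoch):
    print(f"{exclamation}! We've finished epoch {epoch}!")


# Style 1: return a lambda that captures `exclamation`
def make_show_progress_lambda(exclamation):
    return lambda epoch: show_progress(exclamation, epoch)


# Style 2: define an inner function each time the factory is called
def make_show_progress(exclamation):
    def _inner(epoch):
        # `exclamation` comes from the enclosing call: this is a closure
        print(f"{exclamation}! We've finished epoch {epoch}!")
    return _inner


f2 = make_show_progress("Terrific")
f2(1)
f3 = make_show_progress("Amazing")
print(f2 is f3)                         # a new function object per call
print(f2.__closure__[0].cell_contents)  # the remembered exclamation
```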
00:17:24.760 | Because it's so common that you want to take a function that takes two parameters and turn
00:17:30.360 | it into a function that takes one parameter, Python and most languages have a way to do
00:17:35.400 | that, which is called partial function application.
00:17:39.160 | So the standard library module functools has this thing called partial.
00:17:42.480 | So if you call partial and pass it a function, and then you pass in some arguments
00:17:48.120 | for that function, it returns a new function in which those parameters are always given.
00:17:56.080 | So let's check it out.
00:17:57.720 | So we could run it like that, or we could say f2 equals this partial function application.
00:18:06.400 | And so if I type f2 and press Shift+Tab, then you can see this is now a function that just takes
00:18:11.920 | epoch.
00:18:12.920 | It just takes epoch because show_progress took two parameters.
00:18:17.480 | We've already passed it one.
00:18:18.960 | So this now takes one parameter, which is what we need.
00:18:21.760 | So that's why we could pass that as our callback.
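A sketch of the same adaptation with `functools.partial`. `inspect.signature` plays the role of Shift+Tab here, showing that the bound callable only needs `epoch`.

```python
from functools import partial
import inspect


def show_progress(exclamation, epoch):
    print(f"{exclamation}! We've finished epoch {epoch}!")


# Bind the first argument; the new callable only needs `epoch`
f2 = partial(show_progress, "OK I guess")
f2(3)                         # same as show_progress("OK I guess", 3)
print(inspect.signature(f2))  # (epoch)
```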
00:18:26.760 | So we've seen a lot of those techniques already last week.
00:18:33.640 | Most of what we saw last week, though, did not use a function as a callback, but used
00:18:38.000 | a class as a callback.
00:18:40.440 | So we could do exactly the same thing, but pretty much any place you can use a closure,
00:18:45.340 | you can also use a class.
00:18:47.320 | Instead of storing some state away inside the closure, we can store our state, in this
00:18:52.520 | case the exclamation, inside self, passing it into __init__.
00:18:57.720 | So here's exactly the same thing as we saw before, but as a class.
00:19:02.960 | __call__ is a special magic name which will be called if you take an object, so in this
00:19:12.600 | case a progress showing callback object, and call it with parentheses.
00:19:17.160 | So if I go cb("hi"), you see I'm taking that object and I'm treating it as if it's a function.
00:19:23.760 | And that will call __call__.
00:19:27.600 | If you've used other languages like C++, this is called a functor.
00:19:34.400 | More generally it's called a callable in Python, so it's kind of something that a lot of languages
00:19:40.640 | have.
00:19:42.680 | All right, so now we can use that as a callback, just like before.
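A sketch of the class-based callback: the exclamation lives on `self`, and `__call__` makes each instance callable like a function. The class name follows the spoken description; the simplified `slow_calculation` is repeated so the block runs on its own.

```python
class ProgressShowingCallback:
    def __init__(self, exclamation="Awesome"):
        # State is stored on the instance instead of in a closure
        self.exclamation = exclamation

    def __call__(self, epoch):
        # Runs when the instance is called like a function: cb(epoch)
        print(f"{self.exclamation}! We've finished epoch {epoch}!")


def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        res += i * i
        if cb:
            cb(i)
    return res


cb = ProgressShowingCallback("Just super")
cb("hi")              # treat the object as a function
slow_calculation(cb)  # and use it as a callback
```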
00:19:50.920 | All right, next thing to look at, for our callback, is that we're going to use *args and
00:19:59.640 | **kwargs, usually just called args and kwargs.
00:20:03.920 | For those of you that don't know what these mean, let's create a function that takes
00:20:08.200 | *args and **kwargs and prints out args and kwargs.
00:20:15.320 | So if I call that function, I could pass it 3, 'a', thing1="hello", and you'll see that
00:20:22.080 | all the things that are passed as positional arguments end up in a tuple called args, and
00:20:28.280 | all the things passed as keyword arguments end up as a dictionary called kwargs.
00:20:34.200 | That's literally all these things do.
00:20:38.000 | And so PyTorch uses that, for example, when you create an nn.Sequential: it takes what
00:20:44.360 | you pass in as *args, right, you just pass them directly and it turns them into a tuple.
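A sketch of that demo: positional arguments are collected into the `args` tuple and keyword arguments into the `kwargs` dict. The function also returns both so the result can be checked.

```python
def f(*args, **kwargs):
    print(f"args: {args}; kwargs: {kwargs}")
    return args, kwargs


f(3, "a", thing1="hello")
# args: (3, 'a'); kwargs: {'thing1': 'hello'}
```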
00:20:50.560 | So why do we use this?
00:20:53.080 | There's a few reasons we use it, but one of the common ways to use it is if you kind of
00:20:58.200 | want to wrap some other class or object, then you can take a bunch of stuff as **kwargs
00:21:04.760 | and pass it off to some other function or object.
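A sketch of the wrapping pattern, using a hypothetical `to_pretty_json` helper around the standard library's `json.dumps`. It sets one default and forwards every other keyword argument untouched. It also shows the downside discussed next: `sort_keys` never appears in the wrapper's own signature.

```python
import json


def to_pretty_json(obj, **kwargs):
    # Hypothetical wrapper: default to indent=2, forward everything else to json.dumps
    kwargs.setdefault("indent", 2)
    return json.dumps(obj, **kwargs)


print(to_pretty_json({"b": 1, "a": 2}, sort_keys=True))
```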
00:21:10.480 | We're getting better at this, and we're removing a lot of the usages, but in the early days
00:21:13.980 | of fastai version 1, we actually were overusing kwargs.
00:21:18.240 | So quite often, we would kind of -- there would be a lot of stuff that wasn't obviously
00:21:24.640 | in the parameter list of a function that ended up in kwargs, and then we would pass it down
00:21:29.600 | to, I don't know, the PyTorch DataLoader initializer or something.
00:21:33.880 | And so we've been gradually removing those usages, because it's mostly helpful
00:21:39.440 | for kind of quick and dirty throwing things together.
00:21:45.120 | In R, they actually use an ellipsis for the same thing.
00:21:47.780 | They kind of overuse it.
00:21:49.040 | Quite often, it's hard to see what's going on.
00:21:51.000 | You might have noticed in Matplotlib, a lot of times the thing you're trying to pass to
00:21:56.080 | Matplotlib isn't there when you hit Shift+Tab.
00:21:59.040 | It's the same thing.
00:22:00.040 | They're using kwargs.
00:22:01.120 | So there are some downsides to using it, but there are some places you really want to use it.
00:22:06.320 | For example, take a look at this.
00:22:08.480 | Let's rewrite slow_calculation, but this time we're going to allow the user to create
00:22:13.680 | a callback that is called before the calculation occurs and after the calculation occurs.
00:22:21.840 | And the after-calculation one's a bit tricky, because it's going to take two parameters.
00:22:25.960 | It's going to take both the epoch number and also what we've calculated so far.
00:22:32.920 | So we can't just call cb(i). We actually now have to assume that it's got
00:22:39.480 | some particular methods.
00:22:42.020 | So here is, for example, a PrintStepCallback, which before calculation just says I'm about
00:22:49.240 | to start and after calculation it says I'm done and there it's running.
00:22:55.800 | So in this case, this callback didn't actually care about the epoch number or about the value.
00:23:02.120 | And so it just has *args, **kwargs in both places.
00:23:07.880 | It doesn't have to worry about exactly what's being passed in, because it's not using them.
00:23:11.880 | So this is quite a good kind of use of this, is to basically create a function that's going
00:23:19.280 | to be used somewhere else and you don't care about one or more of the parameters or you
00:23:23.080 | want to make things more flexible.
00:23:25.360 | So in this case, we don't get an error. But if we remove these, which looks like
00:23:32.400 | we should be able to do because we don't use anything, here's the problem.
00:23:37.600 | It tried to call before_calc(i), and before_calc doesn't take an i.
00:23:42.960 | So if you accept both positional and keyword arguments, with *args and **kwargs, it'll always work everywhere.
00:23:51.600 | And so here we can actually use them.
00:23:54.400 | So let's actually use epoch and value to print out those details.
00:24:00.180 | So now you can see there it is printing them out.
00:24:02.000 | And in this case, I've put **kwargs at the end, because maybe in the future, there'll
00:24:06.560 | be some other things that are passed in and we want to make sure this doesn't break.
00:24:09.800 | So it kind of makes it more resilient.
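A sketch of the two-hook version. `PrintStepCallback` ignores its arguments through `*args, **kwargs`, and `PrintStatusCallback` (our name for the second example) uses `epoch` and `val` while `**kwargs` absorbs anything added later. This `slow_calculation` assumes a callback is always supplied.

```python
def slow_calculation(cb):
    res = 0
    for i in range(5):
        cb.before_calc(i)
        res += i * i
        cb.after_calc(i, val=res)
    return res


class PrintStepCallback:
    # Doesn't care about the arguments, so it accepts anything
    def before_calc(self, *args, **kwargs):
        print("About to start")

    def after_calc(self, *args, **kwargs):
        print("Done step")


class PrintStatusCallback:
    # Uses the arguments it needs; **kwargs keeps it working if more are passed later
    def before_calc(self, epoch, **kwargs):
        print(f"About to start: {epoch}")

    def after_calc(self, epoch, val, **kwargs):
        print(f"After {epoch}: {val}")


slow_calculation(PrintStepCallback())
slow_calculation(PrintStatusCallback())
```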
00:24:14.680 | The next thing we might want to do with callbacks is to actually change something.
00:24:21.080 | So a couple of things that we did last week, one was we wanted to be able to cancel out
00:24:26.440 | of a loop to stop early.
00:24:29.000 | The other thing we might want to do is actually change the value of something.
00:24:33.320 | So in order to stop early, we could check.
00:24:38.280 | And also the other thing we might want to do is say, well, what if you don't want to
00:24:41.720 | define before calc or after calc?
00:24:44.320 | We wouldn't want everything to break.
00:24:46.020 | So we can actually check whether a callback is defined and only call it if it is.
00:24:52.120 | And we could actually check the return value and then do something based on the return
00:24:56.480 | value.
00:24:57.480 | So here's something which will cancel out of our loop if the value that's been calculated
00:25:04.240 | so far is over 10.
00:25:06.960 | So here we stop.
00:25:10.720 | Okay?
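A sketch of the defensive version: each hook is called only if the callback defines it (checked with `hasattr`), and a truthy return value from `after_calc` stops the loop early. `PrintAfterCallback` is a hypothetical name for the stop-over-10 example.

```python
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        # Only call a hook if the callback actually defines it
        if cb and hasattr(cb, "before_calc"):
            cb.before_calc(i)
        res += i * i
        if cb and hasattr(cb, "after_calc"):
            # A truthy return value means "stop now"
            if cb.after_calc(i, res):
                print("stopping early")
                break
    return res


class PrintAfterCallback:
    # No before_calc at all: the hasattr check simply skips it
    def after_calc(self, epoch, val):
        print(f"After {epoch}: {val}")
        if val > 10:
            return True


slow_calculation(PrintAfterCallback())
```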
00:25:13.480 | What if you actually want to change the way the calculation is being done?
00:25:18.960 | So we could even change the way the calculation is being done by taking our calculation function,
00:25:27.200 | putting it into a class.
00:25:29.560 | And so now the value that it's calculated is an attribute of the class.
00:25:36.260 | And so now we could actually do something, a callback that reaches back inside the calculator
00:25:42.880 | and changes it, right?
00:25:46.600 | So this is going to double the result if it's less than three.
00:25:50.160 | So if we run this, right, we now actually have to call this because it's a class, but
00:25:56.720 | you can see it's giving a different value.
00:25:59.480 | And so we're also taking advantage of this in the callbacks that we're using.
00:26:03.520 | So this is kind of the ultimately flexible callback system.
00:26:10.040 | And so you'll see in this case, we actually have to pass the calculator object to the
00:26:20.340 | callback.
00:26:21.340 | So the way we do that is we've defined a callback method here, which checks to see whether it's
00:26:30.280 | defined and if it is, it grabs it and then it calls it passing in the calculator object
00:26:36.280 | itself so it's now available.
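A sketch of the class-based calculator. The `callback` method looks up the named hook with `getattr` and passes the calculator itself, so a callback can read or even change `res`. `SlowCalculator` and `ModifyingCallback` are illustrative names, and the `delay` parameter is our own addition so the example runs instantly.

```python
from time import sleep


class SlowCalculator:
    def __init__(self, cb=None, delay=0.0):
        self.cb, self.res, self.delay = cb, 0, delay

    def callback(self, cb_name, *args):
        # Call the named hook only if a callback exists and defines it,
        # passing the calculator itself so the hook can reach back inside
        if not self.cb:
            return None
        cb = getattr(self.cb, cb_name, None)
        if cb:
            return cb(self, *args)
        return None

    def calc(self):
        for i in range(5):
            self.callback("before_calc", i)
            self.res += i * i
            sleep(self.delay)
            if self.callback("after_calc", i):
                print("stopping early")
                break


class ModifyingCallback:
    def after_calc(self, calc, epoch):
        print(f"After {epoch}: {calc.res}")
        if calc.res > 10:
            return True               # ask the calculator to stop
        if calc.res < 3:
            calc.res = calc.res * 2   # change the calculator's state


calculator = SlowCalculator(ModifyingCallback())
calculator.calc()
print(calculator.res)
```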
00:26:40.520 | And so what we actually did last week is we didn't call this method callback.
00:26:45.720 | We called it __call__, which means we were able to do it like this, okay?
00:26:58.200 | Now you know, which do you prefer?
00:27:01.080 | It's kind of up to you, right?
00:27:02.080 | I mean, we had so many callbacks being called that I felt the extra noise of giving it a
00:27:07.640 | name was a bit messy.
00:27:09.880 | On the other hand, you might feel that calling a callback isn't something you expect __call__
00:27:13.800 | to do, in which case you can do it that way.
00:27:17.480 | So there are pros and cons; neither is right or wrong.
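A sketch of the `__call__` variant: the hook dispatcher is the same, but the calculator invokes it as `self("after_calc", i)`. `StopOver10` is a hypothetical callback added only for this demo.

```python
class SlowCalculator:
    def __init__(self, cb=None):
        self.cb, self.res = cb, 0

    def __call__(self, cb_name, *args):
        # Same lookup as a named `callback` method, invoked as self(name, ...)
        if not self.cb:
            return None
        cb = getattr(self.cb, cb_name, None)
        if cb:
            return cb(self, *args)
        return None

    def calc(self):
        for i in range(5):
            self("before_calc", i)
            self.res += i * i
            if self("after_calc", i):
                break


class StopOver10:
    # Hypothetical callback: stop once the running total exceeds 10
    def after_calc(self, calc, epoch):
        return calc.res > 10


calculator = SlowCalculator(StopOver10())
calculator.calc()
print(calculator.res)
```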
00:27:22.080 | Okay, so that's callbacks.
00:27:29.760 | We've been using dunder thingys a lot.
00:27:33.800 | Dunder thingys look like this.
00:27:36.440 | And in Python, a dunder thingy is special somehow.
00:27:42.480 | Most languages kind of let you define special behaviors.
00:27:46.520 | For example, in C++, there's an operator keyword where if you define a function that says operator
00:27:52.920 | something like plus, you're defining the plus operator.
00:27:58.400 | So most languages tend to have special magic names you can give things that make something
00:28:04.480 | a constructor or a destructor or an operator.
00:28:08.000 | I like in Python that all of the magic names actually look magic.
00:28:13.440 | They all look like that, which I think is actually a really good way to do it.
00:28:19.080 | So the Python docs have a data model reference where they tell you about all these special
00:28:28.760 | method names.
00:28:30.000 | And you can go through and you can see what are all the special things you can get your
00:28:33.920 | method to do.
00:28:34.920 | Like you can override how it behaves with less than or equal to or et cetera, et cetera.
00:28:41.740 | There's a particular list I suggest you know, and this is the list.
00:28:45.440 | So you can go to those docs and see what these things do because we use all of these in this
00:28:50.600 | course.
00:28:51.600 | So here's an example.
00:28:54.300 | Here's a SloppyAdder class.
00:28:57.640 | You pass in some number that you're going to add up.
00:29:00.320 | And then when you add two things together, it will give you the result of adding them
00:29:03.720 | up, but it will be wrong by 0.01.
00:29:06.960 | And that is called __add__ because that's what happens when you see plus.
00:29:11.480 | This is called __init__ because this is what happens when an object gets constructed.
00:29:15.720 | And this is called __repr__ because this is what gets called when you print it out.
00:29:19.320 | So now I can create a one adder and a two adder, and I can add them together and I can
00:29:24.200 | see the result.
00:29:26.120 | So that's kind of an example of how these special dunder methods work.
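A sketch of the sloppy adder as described: `__init__` stores the number, `__add__` defines `+` (deliberately off by 0.01), and `__repr__` controls how the result is displayed.

```python
class SloppyAdder:
    def __init__(self, o):
        # Runs when the object is constructed
        self.o = o

    def __add__(self, b):
        # Runs for `self + b`; deliberately wrong by 0.01
        return SloppyAdder(self.o + b.o + 0.01)

    def __repr__(self):
        # Controls how the object is displayed
        return str(self.o)


a = SloppyAdder(1)
b = SloppyAdder(2)
print(a + b)
```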
00:29:36.120 | So that's a bit of that Python stuff.
00:29:42.920 | There's another bit of code stuff that I wanted to show you, which you'll need to be doing
00:29:47.280 | a lot of, which is you need to be really good at browsing source code.
00:29:53.120 | If you're going to be contributing stuff to fastai or to fastai for Swift for TensorFlow
00:30:01.640 | or just building your own more complex projects, you need to be able to jump around source
00:30:07.560 | code.
00:30:08.560 | Or even just to find out how PyTorch does something, if you're doing some research,
00:30:12.120 | you need to really understand what's going on under the hood.
00:30:16.920 | This is a list of things you should know how to do in your editor of choice.
00:30:22.080 | Any editor that can't do all of these things is worth replacing with one that can.
00:30:26.900 | Most editors can do these things, emacs can, visual studio code can, sublime can, and the
00:30:35.080 | editor I use most of the time, vim, can as well.
00:30:38.000 | I'll show you what these things are in vim.
00:30:42.160 | On the forums there are already some topics saying how to do these things in other editors.
00:30:47.440 | If you don't find one that seems any good, feel free to create your own topic if you've
00:30:52.000 | got some tips about how to do these things or other useful things in your editor of choice.
00:30:56.400 | I'm going to show you in vim for no particular reason, just because I use vim.
00:31:06.360 | My editor, it's called vim.
00:31:08.640 | One of the things I like about vim is I can use it in a terminal, which I find super helpful
00:31:12.860 | because I'm working on remote machines all the time and I like to be at least as productive
00:31:17.900 | in a terminal as I am on my local computer.
00:31:24.760 | The first thing you should be able to do is to jump to a symbol.
00:31:28.320 | A symbol would be like a class or a function or something like that.
00:31:33.720 | For example, I might want to be able to jump straight to the definition of create_cnn, but
00:31:45.960 | I can't quite remember the exact name of the function.
00:31:49.680 | I would type :tag create, since I'm pretty sure it's create underscore something, and then
00:31:55.360 | I'd press tab a few times and it would loop through the matches. There it is, create_cnn, and then
00:32:00.920 | I'd hit enter.
00:32:02.480 | That's the first thing that your editor should do, is it should make it easy to jump to a
00:32:06.680 | tag even if you can't remember exactly what it is.
00:32:09.720 | The second thing it should do is that you should be able to click on something like
00:32:13.880 | cnn_learner and hit a button, which in vim's case is control right square bracket, and
00:32:19.320 | it should take you to the definition of that thing.
00:32:22.060 | Okay, let's try it on cnn_learner. What's this thing called a DataBunch? Control right square
00:32:27.560 | bracket, okay, there's DataBunch.
00:32:30.320 | You'll also see that my vim is folding things, classes, and functions to make it easier for
00:32:38.000 | me to see exactly what's in this file.
00:32:40.840 | In some editors, this is called outlining, in some it's called folding, most editors
00:32:45.600 | should do this.
00:32:48.600 | Then there should be a way to go back to where you were before, in vim that's control T for
00:32:54.280 | going back up the tag stack.
00:32:56.280 | So here's my cnn_learner, here's my create_cnn, and so you can see in this way it makes it
00:33:02.360 | nice and easy to kind of jump around a little bit.
00:33:09.200 | Something I find super helpful is to also be able to jump into the source code of libraries
00:33:13.520 | I'm using.
00:33:14.520 | So for example, here's kaiming_normal_. I've got my vim configured so that if I hit control
00:33:20.040 | right square bracket on that, it takes me to the definition of kaiming_normal_ in the PyTorch
00:33:25.400 | source code.
00:33:26.720 | And I find docstrings kind of annoying, so I have mine folded up by default, but I can
00:33:30.360 | always open them up.
00:33:32.640 | If you use vim, the way to do that is to add additional tags for any packages that you
00:33:43.760 | want to be able to jump to.
00:33:45.120 | I'm sure most editors will do something pretty similar.
00:33:48.600 | Now that I've seen how kaiming_normal_ works, I can use the same control T to jump back
00:33:52.960 | to where I was in my fast AI source code.
00:33:59.800 | Then the only other thing that's particularly important to know how to do is to just do
00:34:03.560 | more general searches.
00:34:05.360 | So let's say I wanted to find all the places that I've used lambda, since we talked about
00:34:10.600 | lambda today. I have a particular tool I use called ack, so I can say ack lambda, and here
00:34:19.360 | is a list of all of the places I've used lambda, and I can click on one, and it will jump
00:34:27.080 | to the code where it's used.
00:34:30.920 | Again most editors should do something like that for you.
00:34:33.860 | So I find with that basic set of stuff, you should be able to get around pretty well.
00:34:38.380 | If you're a professional software engineer, I know you know all this.
00:34:41.980 | If you're not, hopefully you're feeling pretty excited right now to discover that editors
00:34:45.380 | can do more than you realized.
00:34:47.440 | And so sometimes people will jump on our GitHub and say, I don't know how to find out what
00:34:54.000 | a function is that you're calling because you don't list all your imports at the top
00:34:57.960 | of the file, but this is a great place where you should be using your editor to tell you.
00:35:03.840 | And in fact, one place that GUI editors can be pretty good is often if you actually just
00:35:08.560 | point at something, they will pop up something saying exactly where is that symbol coming
00:35:13.080 | from.
00:35:14.080 | I don't have that set up in VIM, so I just have to hit control right square bracket to see
00:35:18.880 | where something's coming from.
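The same question, "where does this symbol come from?", can also be answered from Python itself with the standard library's `inspect` module. This is a sketch that uses `json.dumps` purely as a stand-in for whatever function you're puzzling over. In Jupyter, typing a name followed by `??` also shows its source.

```python
import inspect
import json

fn = json.dumps  # stand-in for any function whose origin you want to find

# Which module defines this object?
module_name = inspect.getmodule(fn).__name__
# Which file is its source code in?
source_file = inspect.getsourcefile(fn)
# The first line of its source: the `def` line with the full signature
first_line = inspect.getsource(fn).splitlines()[0]

print(module_name, source_file, first_line, sep="\n")
```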
00:35:21.280 | Okay, so that's some tips about stuff that you should be able to do when you're browsing
00:35:27.680 | source code, and if you don't know how to do it yet, please Google or look at the forums
00:35:32.240 | and practice.
00:35:35.720 | Something else we were looking at a lot last week and you need to know pretty well is variance.
00:35:40.880 | So just a quick refresher on what variance is, or for those of you who haven't studied
00:35:44.960 | it before, here's what variance is.
00:35:48.720 | Variance is the average of how far away each data point is from the mean.
00:35:53.140 | So here's some data, right, and here's the mean of that data.
00:35:57.880 | And so the average distance for each data point from the mean is t, the data points,
00:36:04.840 | minus m, the mean, and then take the mean of that.
00:36:06.840 | Oh, that's zero.
00:36:08.880 | That didn't work.
00:36:09.880 | Oh, well of course it didn't work.
00:36:11.640 | The mean is defined as the thing which is in the middle, right?
00:36:16.320 | So of course that's always zero.
00:36:17.800 | So we need to do something else that doesn't have the positives and negatives cancel out.
00:36:23.380 | So there's two main ways we fix it.
00:36:25.680 | One is by squaring each thing before we take the mean, like so.
00:36:32.500 | The other is taking the absolute value of each thing.
00:36:36.600 | So turning all the negatives into positives before we take the mean.
00:36:40.300 | So they're both common fixes for this problem.
00:36:43.780 | You can see though the first is now on a totally different scale, right?
00:36:47.240 | The numbers were like 1, 2, 4, 18, and this is 47.
00:36:50.640 | So we need to undo that squaring.
00:36:52.920 | So after we've squared, we then take the square root at the end.
00:36:57.000 | So here are two numbers that represent how far things are away from the mean.
00:37:02.260 | Or in other words, how much do they vary?
00:37:04.480 | If everything's pretty similar to each other, those two numbers will be small.
00:37:08.960 | If they're wildly different to each other, those two numbers will be big.
00:37:14.240 | This one here is called the standard deviation.
00:37:17.880 | And it's defined as the square root of this one here which is called the variance.
00:37:23.000 | And this one here is called the mean absolute deviation.
00:37:27.200 | You could replace this m with various other things, like the median, for example.
00:37:37.960 | So we have one outlier here, 18.
00:37:42.400 | So in the case of the one where we took a square in the middle of it, this number is
00:37:47.120 | higher because the square takes that 18 and makes it much bigger.
00:37:51.920 | So in other words, standard deviation is more sensitive to outliers than mean absolute deviation.
00:37:58.880 | So for that reason, the mean absolute deviation is very often the thing you want to be using
00:38:07.320 | because in machine learning outliers are more of a problem than a help a lot of the time.
00:38:13.840 | But mathematicians and statisticians tend to work with standard deviation rather than
00:38:18.200 | mean absolute deviation because it makes their math proofs easier and that's the only reason.
00:38:22.800 | They'll tell you otherwise, but that's the only reason.
00:38:25.440 | So the mean absolute deviation is really underused and it actually is a really great measure
00:38:32.720 | to use and you should definitely get used to it.
00:38:40.360 | There's a lot of places where I kind of notice that replacing things involving squares with
00:38:45.360 | things involving absolute values, the absolute value things just often work better.
00:38:49.920 | It's a good tip to remember that there's this kind of long-held assumption that
00:38:54.920 | we have to use a squared thing everywhere, but it actually often doesn't work as well.
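To see these measures side by side, here is a small sketch in plain Python on the data set with the outlier 18 (the values 1, 2, 4, 18 are the ones quoted above). It shows the naive mean of differences cancelling to zero, then the variance, standard deviation, and mean absolute deviation.

```python
import math

def mean(xs):
    return sum(xs) / len(xs)

t = [1., 2., 4., 18.]   # small data set with one outlier, 18
m = mean(t)             # 6.25

naive = mean([x - m for x in t])         # positives and negatives cancel out: always 0
var = mean([(x - m) ** 2 for x in t])    # variance: square first, then average
std = math.sqrt(var)                     # standard deviation: undo the squaring
mad = mean([abs(x - m) for x in t])      # mean absolute deviation

print(naive, var, std, mad)
```

The standard deviation comes out at about 6.87 against a mean absolute deviation of 5.875: squaring lets the single outlier pull the first measure up further.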
00:39:02.760 | This is our definition of variance.
00:39:06.280 | Notice that this is the same.
00:39:10.400 | So this is written in math.
00:39:15.920 | Written in math it looks like this, E of X squared minus the square of E of X, and it's another way of writing the variance.
00:39:21.960 | It's important because it's super handy and it's super handy because in this one here
00:39:28.280 | we have to go through the whole data set once
00:39:30.680 | to calculate the mean of the data, and then a second time to get the squares of the differences.
00:39:39.480 | This is really nice because in this case, we only have to keep track of two numbers,
00:39:45.760 | the sum of the squares of the data and the sum of the data, and as you'll see shortly, this kind
00:39:53.800 | of way of doing things is generally therefore just easier to work with.
00:39:59.360 | So even though this is kind of the definition of the variance that makes intuitive sense,
00:40:08.280 | this is the definition of variance that you normally want to implement.
00:40:13.580 | And so there it is in math.
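A sketch of the two ways of computing the variance in plain Python. The two-pass version follows the intuitive definition. The one-pass version only keeps running totals of the data and of its squares, and then uses E of X squared minus the square of E of X.

```python
def variance_two_pass(xs):
    n = len(xs)
    m = sum(xs) / n                             # pass 1: the mean
    return sum((x - m) ** 2 for x in xs) / n    # pass 2: average squared difference

def variance_one_pass(xs):
    n = s = sq = 0.0
    for x in xs:          # a single pass, keeping running totals
        n += 1
        s += x            # sum of the data
        sq += x * x       # sum of the squares of the data
    return sq / n - (s / n) ** 2                # E[X^2] - E[X]^2

t = [1., 2., 4., 18.]
print(variance_two_pass(t), variance_one_pass(t))
```

One caveat on this design: when the mean is very large relative to the spread, subtracting two nearly equal numbers in the one-pass form can lose floating-point precision. Welford's algorithm is the usual numerically stable single-pass alternative.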
00:40:16.360 | The other thing we see quite a bit is covariance and correlation.
00:40:21.480 | So if we take our same data set, let's now create a second data set which is double t
00:40:31.800 | times a little bit of random noise, so here's that plotted.
00:40:39.640 | Let's now look at the difference between each item of t and its mean and multiply it by
00:40:47.520 | the difference between each item of u and its mean, so there's those values, and let's look at the mean of that.
00:40:58.520 | So what's this number?
00:41:00.440 | So it's the average of how far away each x value is from the mean of the
00:41:05.920 | x's, multiplied by how far away the corresponding y value is from
00:41:13.760 | the mean of the y's.
00:41:16.940 | Let's compare this number to the same number calculated with this data set, where this
00:41:22.320 | data set, v, is just some random numbers.
00:41:28.840 | And let's now calculate the exact same product, the exact same mean.
00:41:32.680 | This number's much smaller than this number.
00:41:36.720 | Why is this number much smaller?
00:41:38.880 | So if you think about it, if these are kind of all lined up nicely, then every time it's
00:41:45.880 | higher than the average on the x-axis, it's also higher than the average on the y-axis.
00:41:52.520 | So you have two big positive numbers and vice versa, two big negative numbers.
00:41:58.180 | So in either case, you end up with, when you multiply them together, a big positive number.
00:42:03.280 | So this is adding up a whole bunch of big positive numbers.
00:42:07.660 | So in other words, this number tells you how much these two things vary in the same way,
00:42:17.120 | kind of how lined up are they on this graph.
00:42:20.240 | And so this one, when one is big, the other's not necessarily big.
00:42:24.400 | When one is small, the other's not necessarily very small.
00:42:29.120 | So this is the covariance, and you can also calculate it in this way, E of X times Y minus E of X times E of Y, which might look
00:42:39.960 | somewhat similar to what we saw before with our alternative variance calculation.
00:42:44.840 | And again, this is kind of the easier way to use it.
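Here is a plain-Python sketch of the covariance computed both ways. The noise model is an assumption, not the notebook's: u is roughly double t with a uniform wobble, and v is unrelated Gaussian noise.

```python
import random

def mean(xs):
    return sum(xs) / len(xs)

def cov(xs, ys):
    # definition: average product of each variable's distance from its own mean
    mx, my = mean(xs), mean(ys)
    return mean([(x - mx) * (y - my) for x, y in zip(xs, ys)])

def cov_alt(xs, ys):
    # equivalent form: E[XY] - E[X]E[Y]
    return mean([x * y for x, y in zip(xs, ys)]) - mean(xs) * mean(ys)

random.seed(0)
t = [1., 2., 4., 18.]
u = [2 * x * random.uniform(0.95, 1.05) for x in t]  # roughly double t, with a little noise
v = [random.gauss(0, 1) for _ in t]                   # unrelated random numbers

print(cov(t, u), cov_alt(t, u))   # large: t and u vary together
print(cov(t, v))                  # much smaller in magnitude
```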
00:42:51.160 | So as I say here, from now on, I don't want you to ever look at an equation or type in
00:42:59.040 | an equation in LaTeX without typing it in Python, calculating some values and plotting them.
00:43:05.880 | Because this is the only way we get a sense in here of what these things mean.
00:43:12.800 | And so in this case, we're going to take our covariance and we're going to divide it by
00:43:18.040 | the product of the standard deviations.
00:43:21.720 | And this gives us a new number, and this is called correlation, or more specifically Pearson
00:43:27.280 | correlation coefficient.
00:43:30.000 | So we don't cover covariance and Pearson correlation coefficient too much in the course, but it
00:43:34.720 | is one of these things which it's often nice to see how things vary.
00:43:38.480 | But remember, it's telling you really about how things vary linearly, right?
00:43:42.800 | So if you want to know how things vary non-linearly, you have to create something called a neural
00:43:46.040 | network and check the loss and the metrics.
00:43:50.520 | But it's kind of interesting to see also how variance and covariance, you can see, are
00:43:55.200 | much the same thing. In fact, you can basically think
00:44:00.120 | of it this way, right?
00:44:01.120 | One of them is E of X squared, in other words, X and X are kind of the same thing, it's E
00:44:07.320 | of X times X, whereas this one is two different things, E of X times Y, right?
00:44:11.960 | So in the variance formula we had E of X squared and the square of E of X.
00:44:19.960 | If you replace the second X in each with a Y, you get E of X times Y and E of X times E of Y, which is the covariance.
00:44:24.920 | So they're like literally the same thing.
00:44:28.760 | And then again, if Y is the same as X, then the covariance is just sigma squared, the variance, right?
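A plain-Python sketch of the Pearson correlation coefficient, covariance divided by the product of the standard deviations. Here the standard deviation is computed as the square root of the covariance of X with itself, which is exactly the "replace Y with X" point above. The last example is an assumption chosen to show that correlation only sees linear relationships: y is exactly determined by x, yet the correlation is zero.

```python
import math

def mean(xs):
    return sum(xs) / len(xs)

def cov(xs, ys):
    return mean([x * y for x, y in zip(xs, ys)]) - mean(xs) * mean(ys)   # E[XY] - E[X]E[Y]

def std(xs):
    return math.sqrt(cov(xs, xs))   # the variance is just the covariance of X with itself

def corr(xs, ys):
    # Pearson correlation coefficient
    return cov(xs, ys) / (std(xs) * std(ys))

t = [1., 2., 4., 18.]
print(corr(t, t))                        # ~1.0: X against itself
print(corr(t, [3 * x + 1 for x in t]))   # ~1.0: an increasing linear relationship
print(corr(t, [-x for x in t]))          # ~-1.0: a decreasing linear relationship

xs = [-2., -1., 0., 1., 2.]
print(corr(xs, [x * x for x in xs]))     # 0.0: y = x squared is non-linear, so correlation misses it
```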
00:44:36.040 | So the last thing I want to quickly talk about a little bit more is Softmax.
00:44:43.540 | This was our final log Softmax definition from the other day, and this is the formula,
00:44:52.160 | the same thing as an equation.
00:44:55.680 | And this is our cross entropy loss, remember.
00:45:00.040 | So these are all important concepts we're going to be using a lot.
00:45:02.040 | So I just wanted to kind of clarify something that a lot of researchers publishing
00:45:06.960 | in big-name conferences get wrong, which is when you should and shouldn't use Softmax.
00:45:13.040 | So this is our Softmax page from our cross-entropy example spreadsheet, where we were looking
00:45:18.580 | at cat, dog, plane, fish, building.
00:45:21.620 | And so we had various outputs.
00:45:23.200 | This is just the activations that we might have gotten out of the last layer of our model.
00:45:27.720 | And this is just E to the power of each of those activations.
00:45:31.040 | And this is just the sum of all of those E to the power ofs.
00:45:34.360 | And then this is E to the power of divided by the sum, which is Softmax.
00:45:38.600 | And of course, they all add up to one.
00:45:43.040 | This is like some image number one that gave these activations.
00:45:46.760 | Here's some other image number two, which gave these activations, which are very different
00:45:51.640 | to these.
00:45:53.880 | But the Softmaxes are identical.
00:45:57.440 | That and that are identical.
00:45:59.600 | So that's weird.
00:46:00.680 | How has that happened?
00:46:01.680 | Well, it's happened because in every case, the E to the power of this divided by the
00:46:09.060 | sum of the E to the power ofs ended up in the same ratio.
00:46:13.920 | So in other words, even though fish is only 0.63 here but it's two here, once you take
00:46:20.640 | E to the power of, it's the same percentage of the sum, right?
00:46:25.160 | And so we end up with the same Softmax.
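The spreadsheet effect can be reproduced in a few lines of plain Python. The activations below are hypothetical, not the spreadsheet's. Image 2 is image 1 with every activation lowered by the same amount, and softmax gives identical outputs because each exponential is scaled by the same factor, which cancels in the ratio.

```python
import math

def softmax(xs):
    exps = [math.exp(x) for x in xs]   # e to the power of each activation
    total = sum(exps)                  # the sum of all the e-to-the-power-ofs
    return [e / total for e in exps]   # each one's share of the total

classes = ["cat", "dog", "plane", "fish", "building"]
image1 = [0.02, -2.49, -1.75, 2.07, 1.25]   # hypothetical activations
image2 = [x - 1.44 for x in image1]         # every activation lower: a much "weaker" image

p1, p2 = softmax(image1), softmax(image2)
for name, a, b in zip(classes, p1, p2):
    print(f"{name:9s} {a:.3f} {b:.3f}")     # the two columns match
```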
00:46:28.120 | Why does that matter?
00:46:30.120 | Well, in this model, it seems like being a fish is associated with having an activation
00:46:38.680 | of maybe like 2ish, right?
00:46:41.920 | And this is only like 0.6ish.
00:46:43.920 | So maybe there's no fish in this.
00:46:47.240 | But what's actually happened is there's no cats or dogs or planes or fishes or buildings.
00:46:54.800 | So in the end then, because Softmax has to add to 1, it has to pick something.
00:47:07.600 | So it's fish that comes through.
00:47:09.280 | And what's more is because we do this E to the power of, the thing that's a little bit
00:47:13.960 | higher, it pushes much higher because it's exponential, right?
00:47:17.760 | So Softmax likes to pick one thing and make it big.
00:47:22.000 | And they have to add up to 1.
00:47:24.480 | So the problem here is that I would guess that maybe image 2 doesn't have any of these
00:47:28.720 | things in it.
00:47:30.280 | And we had to pick something.
00:47:31.560 | So it said, oh, I'm pretty sure there's a fish.
00:47:34.040 | Or maybe the problem actually is that this image had a cat and a fish and a building.
00:47:43.840 | But again, with Softmax they have to add to 1, and one of them is going to be much
00:47:48.440 | bigger than the others.
00:47:49.960 | So I don't know exactly which of these happened, but it's definitely not true that they both
00:47:55.880 | have an equal probability of having a fish in them.
00:48:00.060 | So to put this another way, Softmax is a terrible idea unless you know that every one of your,
00:48:06.840 | if you're doing image recognition, every one of your images, or if you're doing audio or
00:48:11.600 | tabular or whatever, every one of your items has one, no more than one, and definitely
00:48:17.920 | at least one example of the thing you care about in it.
00:48:22.900 | Because if it doesn't have any of cat, dog, plane, fish, or building, it's still going
00:48:27.160 | to tell you with high probability that it has one of those things.
00:48:31.600 | And if it has more than just one of cat, dog, plane, fish, or building, it'll pick
00:48:35.080 | one of them and tell you it's pretty sure it's got that one.
00:48:38.160 | So what do you do if there could be no things or there could be more than one of these things?
00:48:45.840 | Well, instead you use the binomial, regular old binomial, which is e to the x divided by one
00:48:54.400 | plus e to the x, also known as the sigmoid.
00:48:56.000 | It's exactly the same as Softmax if your two categories are, has the thing and doesn't
00:49:01.960 | have the thing because they're like p and one minus p.
00:49:05.000 | So you can convince yourself of that during the week.
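One quick way to convince yourself, sketched in plain Python: give "has the thing" an activation x and "doesn't have the thing" a fixed activation of 0. The two-category softmax then reproduces e to the x divided by one plus e to the x, and its other output is one minus that.

```python
import math

def sigmoid(x):
    # the "binomial": e^x / (1 + e^x)
    return math.exp(x) / (1 + math.exp(x))

def softmax(xs):
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

for x in [-3.0, 0.0, 0.7, 2.5]:
    p = sigmoid(x)
    # two categories: "has the thing" (activation x) and "doesn't" (activation 0)
    has, has_not = softmax([x, 0.0])
    print(x, p, has, 1 - p, has_not)   # p matches has; 1 - p matches has_not
```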
00:49:07.840 | So in this case, let's take image one and let's go 1.02 divided by one plus 1.02.
00:49:17.880 | And ditto for each of our different ones.
00:49:20.600 | And then let's do the same thing for image two.
00:49:23.360 | And you can see now the numbers are different, as we would hope.
00:49:28.000 | And so for image one, it's kind of saying, oh, it looks like there might be a cat in
00:49:33.960 | it if we assume 0.5 is a cut off, there's probably a fish in it, and it seems likely
00:49:41.160 | that there's a building in it, right?
00:49:44.080 | Whereas for image two, it's saying, I don't think there's anything in there, but maybe
00:49:48.040 | a fish.
00:49:49.040 | And this is what we want, right?
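A minimal multi-label sketch of that idea in plain Python. Each class gets its own independent yes/no probability, thresholded at 0.5 as in the example above. The activation values are hypothetical.

```python
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))   # same as e^x / (1 + e^x)

classes = ["cat", "dog", "plane", "fish", "building"]

def present(activations, threshold=0.5):
    # one independent probability per class; nothing forces them to sum to 1
    probs = [sigmoid(a) for a in activations]
    return [c for c, p in zip(classes, probs) if p > threshold]

image1 = [0.5, -2.0, -1.0, 2.0, 1.0]      # hypothetical activations
image2 = [-3.0, -4.0, -3.5, -0.5, -2.5]   # hypothetical: everything low

print(present(image1))   # ['cat', 'fish', 'building']
print(present(image2))   # []
```

For image 2, fish still has the highest probability (about 0.38), but nothing clears the threshold, so the model is free to say "maybe a fish, probably nothing".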
00:49:51.800 | And so when you think about it, like for image recognition, probably most of the time you
00:49:57.960 | don't want Softmax.
00:50:00.080 | So why do we always use Softmax?
00:50:02.840 | Because we all grew up with ImageNet.
00:50:05.040 | And ImageNet was specifically curated, so it only has one of the classes in ImageNet in
00:50:11.400 | it, and it always has one of those classes in it.
00:50:17.200 | An alternative, if you want to be able to handle the what if none of these classes are
00:50:22.480 | in it case, is you could create another category called background or doesn't exist or null
00:50:29.160 | or missing.
00:50:31.240 | So let's say you created this missing category.
00:50:33.120 | So there are six: cat, dog, plane, fish, building, or missing, meaning nothing.
00:50:41.540 | A lot of researchers have tried that, but it's actually a terrible idea and it doesn't
00:50:47.080 | work.
00:50:48.080 | And the reason it doesn't work is because to be able to successfully predict missing,
00:50:54.320 | the penultimate layer activations have to contain the features that describe what a not
00:51:00.800 | cat, dog, plane, fish or building looks like.
00:51:04.200 | So how do you describe a not cat, dog, plane, fish or building?
00:51:08.480 | What are the things that would activate high?
00:51:10.080 | Is it shininess?
00:51:11.080 | Is it fur?
00:51:12.080 | Is it sunshine?
00:51:13.080 | Is it edges?
00:51:15.080 | It's none of those things.
00:51:17.360 | There is no set of features that when they're all high is clearly a not cat, dog, plane,
00:51:22.240 | fish or building.
00:51:23.400 | So that's just not a kind of object.
00:51:25.840 | So a neural net can kind of try to hack its way around it by creating a negative model
00:51:33.060 | of every other single type and create a kind of not one of any of those other things.
00:51:37.680 | But that's very hard for it.
00:51:39.720 | Whereas simply creating a binomial, does it or doesn't it have this, for every one of the
00:51:46.480 | classes is really easy for it, right?
00:51:49.320 | Because it's just: does it have a cat?
00:51:50.960 | Yes or no.
00:51:51.960 | Does it have a dog?
00:51:52.960 | Yes or no.
00:51:53.960 | And so forth.
00:51:54.960 | So lots and lots of well-regarded academic papers make this mistake.
00:52:04.520 | So look out for it.
00:52:06.640 | And if you do come across an academic paper that's using softmax and you think, does that
00:52:11.280 | actually work with softmax?
00:52:14.140 | And you think maybe the answer is no.
00:52:15.800 | Try replicating it without softmax and you may just find you get a better result.
00:52:20.320 | An example of somewhere where softmax, or something like softmax, is obviously
00:52:25.440 | a good idea is language modeling.
00:52:28.400 | What's the next word?
00:52:30.280 | It's definitely at least one word.
00:52:32.600 | It's definitely not more than one word, right?
00:52:35.040 | So you want softmax.
00:52:36.040 | So I'm not saying softmax is always a dumb idea, but it's often a dumb idea.
00:52:41.920 | So that's something to look out for.
00:52:45.240 | Okay. Next thing I want to do is I want to build a learning rate finder.
00:52:53.000 | And to build a learning rate finder, we need to use this test callback kind of idea, this
00:53:01.320 | ability to stop somewhere.
00:53:05.120 | Problem is, as you may have noticed, this I want to stop somewhere callback wasn't working
00:53:12.560 | in our new refactoring where we created this runner class.
00:53:16.080 | And the reason it wasn't working is because we were turning true to mean cancel.
00:53:23.000 | But even after we do that, it still goes on to do the next batch.
00:53:27.440 | And even if we set self.stop, even after we do that, it'll go on to the next epoch.
00:53:32.080 | So to like actually stop it, you would have to return false from every single callback
00:53:36.840 | that's checked to make sure it like really stops or you would have to add something that
00:53:41.720 | checks for self.stop in lots of places.
00:53:44.680 | But it would be a real pain.
00:53:46.920 | Right?
00:53:47.920 | And it's also not as flexible as we would like.
00:53:50.840 | So what I want to show you today is something which I think is really interesting, which
00:53:54.480 | is using the idea of exceptions as a kind of control flow statement.
00:54:03.800 | You may think of exceptions as just being a way of handling errors.
00:54:07.600 | But actually exceptions are a very versatile way of writing very neat code that will be
00:54:13.120 | very helpful for your users.
00:54:15.080 | Let me show you what I mean.
00:54:18.960 | So let's start by just grabbing our MNIST data set as before and creating our data bunch
00:54:22.800 | as before.
00:54:23.960 | And here's our callback as before and our train eval callback as before.
00:54:29.080 | But there's a couple of things I'm going to do differently.
00:54:32.800 | The first is, and this is a bit unrelated, but I think it's a useful refactoring, is
00:54:36.760 | previously inside Runner's dunder call, we went through each callback in order, and we
00:54:45.840 | checked to see whether that particular method exists in that callback.
00:54:51.480 | And if it did, we called it and checked whether it returned true or false.
00:54:58.720 | It actually makes more sense for this to be inside the callback class.
00:55:07.960 | Because by putting it into the Callback class, the Callback class now has a
00:55:13.320 | dunder call which takes a callback name, and it can do this stuff.
00:55:18.440 | And what it means is that now your users who want to create their own callbacks, let's
00:55:25.400 | say they wanted to create a callback that printed out the callback name for every callback every
00:55:31.360 | time it was run.
00:55:32.600 | Or let's say they wanted to add a break point, like a set trace that happened every time
00:55:38.700 | the callback was run.
00:55:40.200 | They could now create their own class that inherits from Callback and actually replace dunder call itself
00:55:48.400 | with something that added this behavior they want.
00:55:52.000 | Or they could add something that looks at three or four different callback names and
00:55:56.200 | attaches to all of them.
00:55:57.660 | So this is like a nice little extra piece of flexibility.
00:56:02.280 | It's not the key thing I wanted to show you, but it's an example of a nice little refactoring.
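Here is a stripped-down sketch of that refactoring in plain Python. It is not the notebook's code: the Runner here has nothing but its dunder call, and `LogAllCallback` and `StopAtBatch` are hypothetical user callbacks. The Callback's own dunder call looks up the method by name. A subclass can override that dunder call to hook every event at once, for example to log names or drop into a debugger.

```python
class Callback:
    _order = 0

    def __call__(self, cb_name):
        # Look up a method such as "begin_batch"; if this callback defines it, run it.
        f = getattr(self, cb_name, None)
        if f and f():
            return True   # the callback asked to cancel
        return False


class Runner:
    def __init__(self, cbs):
        self.cbs = cbs

    def __call__(self, cb_name):
        # The runner no longer inspects callbacks itself; it just asks each one in order.
        res = False
        for cb in sorted(self.cbs, key=lambda c: c._order):
            res = cb(cb_name) or res
        return res


class LogAllCallback(Callback):
    # Hypothetical user callback: records every event name, then behaves as usual.
    # A set_trace() breakpoint could go here instead.
    def __init__(self):
        self.seen = []

    def __call__(self, cb_name):
        self.seen.append(cb_name)
        return super().__call__(cb_name)


class StopAtBatch(Callback):
    # Hypothetical callback that asks to cancel at the start of a batch
    def begin_batch(self):
        return True


logger = LogAllCallback()
run = Runner([logger, StopAtBatch()])
print(run("begin_fit"), run("begin_batch"), logger.seen)
```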
00:56:07.440 | The key thing I wanted to show you is that I've created three new types of exception.
00:56:13.460 | So an exception in Python is just a class that inherits from Exception.
00:56:18.640 | And most of the time you don't have to give it any other behavior.
00:56:21.880 | So to create a class that's just like its parent, but it just has a new name and no
00:56:25.920 | more behavior, you just say pass.
00:56:28.600 | So pass means this has all the same attributes and everything as the parent.
00:56:34.940 | But it's got a different name.
00:56:37.220 | So why do we do that?
00:56:38.220 | Well, you might get a sense from the names.
00:56:40.680 | Cancel train exception, cancel epoch exception, cancel batch exception.
00:56:45.360 | The idea is that we're going to let people's callbacks cancel anything, you know, cancel
00:56:50.280 | at one of these levels.
00:56:51.280 | So if they cancel a batch, it will keep going with the next batch, but not finish this one.
00:56:56.560 | If they cancel an epoch, it will keep going with the next epoch, but it will cancel this one.
00:57:02.080 | Cancel train will stop the training altogether.
00:57:06.480 | So how would cancel train exception work?
00:57:08.560 | Well, here's the same runner we had before.
00:57:12.000 | But now in fit, where we already had try/finally to make sure that our after_fit and remove learner
00:57:20.120 | happened even if there's an exception, I've added one line of code:
00:57:24.680 | except CancelTrainException.
00:57:27.320 | And if that happens, then optionally it could call some after cancel train callback.
00:57:32.240 | But most importantly, no error occurs.
00:57:35.920 | It just keeps on going to the finally block and will elegantly and happily finish up.
00:57:45.280 | So we can cancel training.
00:57:46.680 | So now, in our test callback's after step, we'll just print out what step we're up to.
00:57:53.840 | And if it's greater than or equal to 10, we will raise cancel train exception.
00:57:58.640 | And so now when we say run dot fit, it just prints out up to 10 and stops.
00:58:04.760 | There's no stack trace, there's no error.
00:58:07.180 | This is using exception as a control flow technique, not as an error handling technique.
00:58:13.280 | So another example, inside all batches, I go through all my batches in a try block,
00:58:23.000 | except if there's a cancel epoch exception, in which case I optionally call an after cancel
00:58:28.040 | epoch callback and then continue to the next epoch.
00:58:34.080 | Or inside one batch, I try to do all this stuff for a batch, except if there's a cancel
00:58:39.360 | batch exception, I will optionally call the after cancel batch callback, and then continue
00:58:45.440 | to the next batch.
00:58:47.520 | So this is like a super neat way that we've allowed any callback writer to stop any one
00:58:54.640 | of these three levels of things happening.
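To see the control flow in isolation, here is a toy version in plain Python. It is not the lesson's Runner. It keeps only the try/except structure of fit, all_batches, and one_batch described above, with a counter standing in for the real forward pass, backward pass, and optimizer step. The exception and event names follow the transcript. The `events` list and the `(self)` argument passed to callbacks are assumptions made to keep the sketch small.

```python
class CancelTrainException(Exception): pass   # stop training altogether
class CancelEpochException(Exception): pass   # abandon this epoch, go on to the next
class CancelBatchException(Exception): pass   # abandon this batch, go on to the next


class Runner:
    "Toy loop: the 'training' is just counting batches."

    def __init__(self, cbs):
        self.cbs, self.n_iter, self.events = cbs, 0, []

    def __call__(self, cb_name):
        self.events.append(cb_name)
        for cb in self.cbs:
            f = getattr(cb, cb_name, None)
            if f:
                f(self)

    def one_batch(self, xb):
        try:
            self.n_iter += 1          # forward, loss, backward, and step would go here
            self("after_step")
        except CancelBatchException:
            self("after_cancel_batch")

    def all_batches(self, batches):
        try:
            for xb in batches:
                self.one_batch(xb)
        except CancelEpochException:
            self("after_cancel_epoch")

    def fit(self, epochs, batches):
        try:
            for epoch in range(epochs):
                self.all_batches(batches)
        except CancelTrainException:
            self("after_cancel_train")
        finally:
            self("after_fit")         # always runs, cancelled or not


class TestCallback:
    def after_step(self, run):
        print(run.n_iter)
        if run.n_iter >= 10:
            raise CancelTrainException()


run = Runner([TestCallback()])
run.fit(3, range(5))   # 15 batches available, but training stops after 10, with no traceback
```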
00:58:59.800 | So in this case, we're using cancel train exception to stop training.
00:59:05.580 | So we can now use that to create a learning rate finder.
00:59:10.320 | So the basic approach of the learning rate finder is that there's something in begin
00:59:14.500 | batch which, just like our parameter scheduler, is using an exponential curve to set the learning
00:59:23.880 | rate.
00:59:24.880 | So this is identical to parameter scheduler.
00:59:28.000 | And then after each step, it checks to see whether we've done more than the maximum number
00:59:32.920 | of iterations, which is defaulting to 100, or whether the loss is much worse than the
00:59:40.200 | best we've had so far.
00:59:42.240 | And if either of those happens, we will raise cancel train exception.
00:59:46.600 | So to be clear, this neat exception-based approach to control flow isn't being used in the fast
00:59:52.160 | AI version one at the moment, but it's very likely that fast AI 1.1 or 2 will switch to
00:59:58.440 | this approach because it's just so much more convenient and flexible.
01:00:02.880 | And then assuming we haven't canceled, just see if the loss is better than our best loss,
01:00:06.520 | and if it is, then set best loss to the loss.
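A framework-free sketch of the learning rate finder logic just described. The default of 100 iterations comes from the transcript. The values min_lr=1e-6 and max_lr=10, the "more than ten times the best loss" rule for "much worse", and `toy_loss` are assumptions for illustration. There is no real model here: the made-up loss curve simply bottoms out near a learning rate of 0.1 and rises after that.

```python
import math

class CancelTrainException(Exception): pass


class LR_Find:
    "Sketch: grow the LR exponentially each batch; stop at max_iter or when the loss blows up."

    def __init__(self, max_iter=100, min_lr=1e-6, max_lr=10):
        self.max_iter, self.min_lr, self.max_lr = max_iter, min_lr, max_lr
        self.best_loss = 1e9
        self.n_iter = 0
        self.lrs, self.losses = [], []

    def begin_batch(self):
        # exponential schedule: min_lr at pos=0, rising to max_lr at pos=1
        pos = self.n_iter / self.max_iter
        self.lr = self.min_lr * (self.max_lr / self.min_lr) ** pos

    def after_step(self, loss):
        self.n_iter += 1
        self.lrs.append(self.lr)
        self.losses.append(loss)
        if self.n_iter >= self.max_iter or loss > self.best_loss * 10:
            raise CancelTrainException()
        if loss < self.best_loss:
            self.best_loss = loss


def toy_loss(lr):
    # made-up loss curve standing in for a real model: lowest near lr=0.1
    return (math.log10(lr) + 1) ** 2 + 0.1


finder = LR_Find()
try:
    while True:                     # stands in for the training loop
        finder.begin_batch()
        finder.after_step(toy_loss(finder.lr))
except CancelTrainException:
    pass                            # not an error: just the signal to stop

best_lr = finder.lrs[finder.losses.index(min(finder.losses))]
print(finder.n_iter, best_lr)
```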
01:00:13.280 | So now we can create a learner, we can add the LR find, we can fit, and you can see that
01:00:21.600 | it only does fewer than 100 iterations before it stops because the loss got a lot worse.
01:00:29.720 | And so now we know that we want something about there for our learning rate.
01:00:34.880 | Okay, so now we have a learning rate finder.
01:00:42.080 | So let's go ahead and create a CNN, and specifically a CUDA CNN.
01:00:50.900 | So we'll keep doing the same stuff we've been doing, get our MNIST data, normalize it.
01:00:57.680 | Here's a nice little refactoring, because we very often want to normalize both datasets
01:01:04.280 | using the training set's mean and standard deviation.
01:01:08.180 | Let's create a function called normalize_to, which does that and returns the normalized
01:01:12.440 | training set and the normalized validation set.
01:01:15.640 | So we can now use that, make sure that it's behaved properly, that looks good.
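Here is a sketch of that helper on plain Python lists rather than tensors. The pixel values are hypothetical, and using the sample standard deviation (`statistics.stdev`) is an assumption. The key design choice is that the validation set is transformed with the training set's statistics, so both sets go through exactly the same transformation.

```python
import statistics

def normalize(xs, m, s):
    return [(x - m) / s for x in xs]

def normalize_to(train, valid):
    # Both sets use the *training* set's mean and standard deviation.
    m, s = statistics.mean(train), statistics.stdev(train)
    return normalize(train, m, s), normalize(valid, m, s)

train = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]   # hypothetical pixel values
valid = [0.1, 0.5, 0.9]
train_n, valid_n = normalize_to(train, valid)
print(statistics.mean(train_n), statistics.stdev(train_n))   # close to 0 and 1
```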
01:01:22.800 | Create our data bunch, and so now we're going to create a CNN model, and the CNN is just
01:01:29.280 | a sequential model that contains a bunch of stride two convolutions.
01:01:35.920 | And remember the input's 28 by 28, so after the first it'll be 14 by 14, then 7 by 7,
01:01:41.520 | then 4 by 4, then 2 by 2, then we'll do our average pooling, flatten it, and a linear
01:01:47.520 | layer, and then we're done.
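The 28, 14, 7, 4, 2 sequence follows from the standard convolution output-size formula. Here is a quick check in plain Python, assuming kernel size 3 and padding 1 for every layer. The transcript doesn't state the kernel sizes, but a larger odd kernel with matching padding gives the same sizes.

```python
def conv_out(n, ks=3, stride=2, padding=1):
    # output size of a convolution along one spatial dimension
    return (n + 2 * padding - ks) // stride + 1

sizes = [28]
for _ in range(4):                  # four stride-2 convolutions
    sizes.append(conv_out(sizes[-1]))
print(sizes)                        # [28, 14, 7, 4, 2]
print(28 * 28)                      # 784: the length of each flattened MNIST vector
```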
01:01:51.540 | Now remember our original data is vectors of length 784, they're not 28 by 28, so we
01:01:58.000 | need to do an x.view to one channel by 28 by 28, because that's what nn.Conv2d expects, and
01:02:08.720 | the minus one means the batch size remains whatever it was before.
01:02:12.720 | So we need to somehow include this function in our nn.Sequential. PyTorch doesn't support
01:02:19.280 | that by default. We could write our own class with a forward function, but nn.Sequential
01:02:25.080 | is convenient in lots of ways: it has a nice representation, and you can do all kinds of customizations
01:02:30.100 | with it. So instead we create a layer called Lambda, an nn.Module called Lambda. You just
01:02:38.400 | pass it a function, and the forward is simply to call that function.
01:02:44.720 | And so now we can say Lambda(mnist_resize), and that will cause that function to be called.
01:02:51.840 | And here Lambda(flatten) simply causes this function to be called, which removes those
01:02:58.400 | 1 by 1 axes at the end after the adaptive average pooling.
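A sketch of the `Lambda` wrapper in PyTorch, used inside `nn.Sequential`. The layer list follows the description above: stride-2 convolutions, average pooling, flatten, linear. The exact kernel sizes, padding, and the 10-class output are assumptions for MNIST.

```python
import torch
from torch import nn

class Lambda(nn.Module):
    """Wrap an arbitrary function so it can sit inside nn.Sequential."""
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)

def mnist_resize(x):
    # (batch, 784) -> (batch, 1, 28, 28); -1 keeps the batch size.
    return x.view(-1, 1, 28, 28)

def flatten(x):
    # (batch, c, 1, 1) -> (batch, c) after adaptive average pooling.
    return x.view(x.shape[0], -1)

model = nn.Sequential(
    Lambda(mnist_resize),
    nn.Conv2d(1, 8, 5, padding=2, stride=2), nn.ReLU(),    # 14x14
    nn.Conv2d(8, 16, 3, padding=1, stride=2), nn.ReLU(),   # 7x7
    nn.Conv2d(16, 32, 3, padding=1, stride=2), nn.ReLU(),  # 4x4
    nn.Conv2d(32, 32, 3, padding=1, stride=2), nn.ReLU(),  # 2x2
    nn.AdaptiveAvgPool2d(1),
    Lambda(flatten),
    nn.Linear(32, 10),
)
out = model(torch.randn(64, 784))
```

Because `Lambda` is an ordinary module, the model keeps nn.Sequential's printed representation and indexing.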
01:03:03.020 | So now we've got a CNN model, we can grab our callback functions and our optimizer and
01:03:10.120 | our runner and we can run it, and six seconds later we get back one epoch's result.
01:03:18.880 | So at this point it's getting a bit slow, so let's make it faster.
01:03:24.300 | So let's use CUDA, let's pop it on the GPU.
01:03:27.940 | So we need to do two things, we need to put the model on the GPU, which specifically means
01:03:33.800 | the model's parameters on the GPU.
01:03:38.840 | So remember a model contains two kinds of numbers, parameters, they're the things that
01:03:42.240 | you're updating, the things that it stores, and there's the activations, there's the things
01:03:46.720 | that it's calculating.
01:03:48.000 | So it's the parameters that we need to actually put on the GPU.
01:03:51.900 | And the inputs to the model and the loss function, so in other words the things that come out
01:03:56.240 | of the data loader we need to put those on the GPU.
01:03:58.960 | How do we do that?
01:03:59.960 | With a callback of course.
01:04:02.800 | So here's a CUDA callback, when you initialize it you pass it a device and then when you
01:04:09.240 | begin fitting you move the model to that device.
01:04:13.920 | So model.to is part of PyTorch; it moves something with parameters, or a tensor, to a
01:04:22.760 | device, and you can create a device by calling torch.device, passing it the string 'cuda' and
01:04:29.240 | whatever GPU number you want to use. If you only have one GPU, it's device 0.
01:04:35.600 | Then when we begin a batch, let's go back and look at our runner.
01:04:43.160 | When we begin a batch, we've put xbatch and ybatch inside self.xb and self.yb.
01:04:52.080 | So that means we can change them.
01:04:54.520 | So let's set the runner's xb and the runner's yb to whatever they were before, but move
01:05:00.400 | to the device.
01:05:03.580 | So that's it.
01:05:05.120 | That's going to run everything on CUDA.
01:05:07.400 | That's all we need.
01:05:09.360 | This is kind of flexible because we can put things on any device we want.
01:05:14.000 | Maybe easier is just to call this once, torch.cuda.set_device, and you don't
01:05:19.560 | even need to do this if you've only got one GPU.
01:05:22.200 | And then everything by default will now be sent to that device.
01:05:25.280 | And then instead of saying .to(device), we can just say .cuda().
01:05:29.800 | And so since we're doing pretty much everything with just one GPU for this course, this is
01:05:33.880 | the one we're going to export.
01:05:35.720 | So just model.cuda(), xb.cuda(), yb.cuda().
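A sketch of the device-moving callback. `TinyRunner` is a hypothetical stand-in for the runner built earlier in the course. In this sketch the callback receives the runner explicitly rather than through attribute delegation. It falls back to the CPU when no GPU is present, so it runs anywhere.

```python
import torch
from torch import nn

class CudaCallback:
    """Sketch: move the model at fit start and each batch at batch start."""
    def __init__(self, device):
        self.device = device

    def begin_fit(self, run):
        # nn.Module.to moves parameters (and buffers) in place.
        run.model.to(self.device)

    def begin_batch(self, run):
        # Tensor.to returns a new tensor, so store it back on the runner.
        run.xb, run.yb = run.xb.to(self.device), run.yb.to(self.device)

class TinyRunner:
    # Hypothetical minimal runner holding a model and the current batch.
    def __init__(self, model, xb, yb):
        self.model, self.xb, self.yb = model, xb, yb

device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
run = TinyRunner(nn.Linear(784, 10), torch.randn(8, 784), torch.randint(0, 10, (8,)))
cb = CudaCallback(device)
cb.begin_fit(run)
cb.begin_batch(run)
```

After both hooks run, the model's parameters and the batch live on the same device, which is all the forward pass needs.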
01:05:40.580 | So that's our CUDA callback.
01:05:42.240 | So let's add that to our callback functions, grab our model and our runner and fit, and
01:05:47.480 | now we can do three epochs in five seconds versus one epoch in six seconds.
01:05:52.000 | So that's a lot better.
01:05:53.600 | And for a much deeper model, it'll be dozens of times faster.
01:05:57.920 | So this is literally all we need to use CUDA.
01:06:01.360 | So that was nice and easy.
01:06:04.520 | Now we want to make it easier to create different kinds of architectures, make things a bit
01:06:09.280 | easier.
01:06:10.280 | So the first thing we should do is recognize that we go conv, ReLU a lot.
01:06:14.800 | So let's pop that into a function called conv2d that just goes conv, ReLU.
01:06:19.220 | Since we use a kernel size of three and a stride of two in this MNIST model a lot, let's make
01:06:23.560 | those defaults.
01:06:26.080 | Also, we can't reuse this model for anything except MNIST, because it has an MNIST resize
01:06:34.840 | at the start.
01:06:36.100 | So we need to remove that.
01:06:37.960 | So if we're going to remove that, something else is going to have to do the resizing.
01:06:41.880 | And of course the answer to that is a callback.
01:06:45.600 | So here's a callback which transforms the independent variable, the x, for a batch.
01:06:55.000 | And so you pass it some transformation function, which it stores away.
01:07:00.000 | And then begin batch simply replaces the batch with the result of that transformation function.
01:07:04.960 | So now we can simply append another callback, which is the partial function application
01:07:09.840 | of that callback with this function.
01:07:16.160 | And this function is just to view something as one by 28 by 28.
01:07:23.040 | And you can see here we've used the trick we saw earlier of using underscore inner to
01:07:29.520 | define a function and then return it.
01:07:31.560 | So this is something which creates a new view function that views it in this size.
01:07:37.620 | So for those of you that aren't that comfortable with closures and partial function application,
01:07:44.200 | this is a great piece of code to study, experiment, make sure that you feel comfortable with it.
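Here is a sketch of the closure plus `partial` pattern. `view_tfm` returns an `_inner` function that remembers `size`. `partial` then pre-fills the callback's constructor so it can be built later. The runner is faked with `SimpleNamespace`, so the `run.xb` interface is an assumption of this example.

```python
from functools import partial
from types import SimpleNamespace
import torch

def view_tfm(*size):
    # The returned closure captures `size` from this call.
    def _inner(x):
        return x.view(*((-1,) + size))
    return _inner

class BatchTransformXCallback:
    """Sketch: replace the batch's x with tfm(x) at the start of each batch."""
    def __init__(self, tfm):
        self.tfm = tfm

    def begin_batch(self, run):
        run.xb = self.tfm(run.xb)

mnist_view = view_tfm(1, 28, 28)
make_cb = partial(BatchTransformXCallback, mnist_view)  # a callback factory

run = SimpleNamespace(xb=torch.randn(16, 784))
make_cb().begin_batch(run)
```

Because the resize now lives in a callback, the model itself no longer needs to know the input was flat.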
01:07:52.160 | So using this approach, we now have the MNIST view resizing as a callback, which means we
01:08:00.600 | can remove it from the model.
01:08:02.580 | So now we can create a generic get_cnn_model function that returns a sequential model containing
01:08:10.160 | some arbitrary set of layers, containing some arbitrary set of filters.
01:08:16.500 | So we're going to say, OK, this is the number of filters I have per layer, 8, 16, 32, 32.
01:08:24.200 | And so here is my get_cnn_layers.
01:08:28.080 | And the last few layers are the average pooling, flattening, and the linear layer.
01:08:34.680 | The first few layers are: for every one of those filters, up to the length of the filters list,
01:08:38.800 | it's a conv2d from that number of filters to the next one.
01:08:43.800 | And then what's the kernel size?
01:08:48.440 | The kernel size depends.
01:08:51.160 | It's a kernel size of 5 for the first layer, or 3 otherwise.
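Here is a PyTorch sketch of `conv2d` and `get_cnn_layers` as described: a 5 by 5 kernel for the first layer, 3 by 3 after that, all stride 2, then pooling, flatten, and linear. Taking `n_out` directly, instead of reading it from a data object, is a simplification. The input is already shaped (batch, 1, 28, 28) because resizing now happens in a callback.

```python
import torch
from torch import nn

def conv2d(ni, nf, ks=3, stride=2):
    # Conv followed by ReLU; padding=ks//2 keeps the "halve each time" sizes.
    return nn.Sequential(
        nn.Conv2d(ni, nf, ks, padding=ks // 2, stride=stride), nn.ReLU())

class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)

def flatten(x):
    return x.view(x.shape[0], -1)

def get_cnn_layers(n_out, nfs):
    nfs = [1] + nfs  # one input channel for MNIST
    return [conv2d(nfs[i], nfs[i + 1], 5 if i == 0 else 3)
            for i in range(len(nfs) - 1)] + [
        nn.AdaptiveAvgPool2d(1), Lambda(flatten), nn.Linear(nfs[-1], n_out)]

def get_cnn_model(n_out, nfs):
    return nn.Sequential(*get_cnn_layers(n_out, nfs))

model = get_cnn_model(10, [8, 16, 32, 32])
out = model(torch.randn(4, 1, 28, 28))
```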
01:08:58.600 | Why is that?
01:09:00.800 | Well, the number of filters we had for the first layer was 8.
01:09:09.480 | And that's a pretty reasonable starting point for a small model to start with 8 filters.
01:09:13.680 | And remember, our image had a single channel.
01:09:18.560 | And imagine if we had a single channel, and we were using 3 by 3 filters.
01:09:24.040 | So as that convolution kernel scrolls through the image, at each point in time, it's looking
01:09:29.760 | at a 3 by 3 window.
01:09:33.280 | And it's just one channel.
01:09:37.240 | So in total, there's 9 input activations that it's looking at.
01:09:45.100 | And then it spits those into a dot product with-- sorry, it spits those into 8 dot products,
01:09:55.480 | so a matrix multiplication, I should say, 8 by 9.
01:10:04.840 | And out of that will come a vector of length 8.
01:10:13.360 | Because we said we wanted 8 filters.
01:10:15.800 | So that's what a convolution does.
01:10:19.480 | And this seems pretty pointless.
01:10:21.680 | Because we started with 9 numbers, and we ended up with 8 numbers.
01:10:26.800 | So all we're really doing is just reordering them.
01:10:29.980 | It's not really doing any useful computation.
01:10:32.560 | So there's no point making your first layer basically just shuffle the numbers into a
01:10:39.760 | different order.
01:10:42.120 | So what you'll find happens in, for example, most ImageNet models.
01:10:46.760 | Most ImageNet models are a little bit different, because they have three channels.
01:10:50.640 | So it's actually 3 by 3 by 3, which is 27.
01:10:59.080 | But it's still kind of like-- quite often with ImageNet models, the first layer will
01:11:04.240 | be like 32 channels.
01:11:07.240 | So going from 27 to 32 is literally losing information.
01:11:11.320 | So most ImageNet models, they actually make the first layer 7 by 7, not 3 by 3.
01:11:19.560 | And so for a similar reason, we're going to make our first layer 5 by 5.
01:11:24.400 | So we'll have 25 inputs for our 8 outputs.
01:11:28.840 | So this is the kind of things that you want to be thinking about when you're designing
01:11:33.400 | or reviewing an architecture, is like how many numbers are actually going into that
01:11:37.760 | little dot product that happens inside your CNN kernel.
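Here is the "how many numbers go into each dot product" arithmetic from this passage, as a tiny helper. Each output value of a convolution is a dot product over a ks by ks window across all input channels. Keeping 32 filters for the 7 by 7 ImageNet case is an assumption for comparison.

```python
def inputs_per_output(ks, in_channels):
    # One output value = dot product over a ks x ks x in_channels window.
    return ks * ks * in_channels

cases = {
    "MNIST 3x3, 8 filters": (inputs_per_output(3, 1), 8),
    "MNIST 5x5, 8 filters": (inputs_per_output(5, 1), 8),
    "ImageNet 3x3, 32 filters": (inputs_per_output(3, 3), 32),
    "ImageNet 7x7, 32 filters": (inputs_per_output(7, 3), 32),
}
for name, (n_in, n_out) in cases.items():
    print(f"{name}: {n_in} inputs -> {n_out} outputs")
```

Only the larger first-layer kernels give each output more inputs than there are outputs, which is the reasoning behind the 5 by 5 first layer.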
01:11:44.600 | So that's something which can give us a CNN model.
01:11:48.560 | So let's pop it all together into a little function that just grabs an optimization function,
01:11:52.760 | grabs an optimizer, grabs a learner, grabs a runner.
01:11:56.560 | And at this point, if you can't remember what any of these things does, remember, we've
01:11:59.840 | built them all by hand from scratch.
01:12:02.160 | So go back and see what we wrote.
01:12:05.400 | There's no magic here.
01:12:08.400 | And so let's look: if we say get_cnn_model, passing in 8, 16, 32, 32,
01:12:14.680 | Here you can see 8, 16, 32, 32.
01:12:18.240 | Here's our 5 by 5, the rest are 3 by 3.
01:12:21.120 | They all have a stride of 2, and then a linear layer, and then train.
01:12:26.200 | So at this point, we've got a fairly general simple CNN creator that we can fit.
01:12:33.800 | And so let's try to find out what's going on inside.
01:12:37.120 | How do we make this number higher?
01:12:39.200 | How do we make it train more stably?
01:12:41.720 | How do we make it train more quickly?
01:12:43.800 | Well, we really want to see what's going on inside.
01:12:46.880 | We know already that different ways of initializing change the variance of different layers.
01:12:53.600 | How do we find out if it's saturating somewhere, if it's too small, if it's too big, what's
01:12:58.680 | going on?
01:13:00.580 | So what if we replace nn.Sequential with our own sequential model class?
01:13:07.960 | And if you remember back, we've already built our own sequential model class before, and
01:13:11.800 | it just had these two lines of code, plus return.
01:13:14.860 | So let's keep the same two lines of code, but also add two more lines of code that grabs
01:13:19.880 | the mean of the outputs and the standard deviation of the outputs, and saves them away inside
01:13:27.000 | a bunch of lists.
01:13:28.920 | So here's a list for every layer, for means, and a list for every layer for standard deviations.
01:13:37.280 | So let's calculate the mean and standard deviation, and pop them inside those two lists.
01:13:42.580 | And so now it's a sequential model that also keeps track of what's going on, the telemetry
01:13:48.160 | of the model.
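Here is a sketch of a sequential model that records telemetry. It runs each layer in turn and appends that layer's activation mean and standard deviation. This version overrides `forward` and stores plain Python floats via `.item()`; both are choices of this example. The small linear model is only for demonstration.

```python
import torch
from torch import nn

class SequentialModel(nn.Module):
    """Like nn.Sequential, but records per-layer activation stats."""
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.act_means = [[] for _ in layers]
        self.act_stds = [[] for _ in layers]

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # .item() turns each stat into a plain float, off the autograd graph.
            self.act_means[i].append(x.mean().item())
            self.act_stds[i].append(x.std().item())
        return x

model = SequentialModel(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
for _ in range(3):
    model(torch.randn(32, 10))
```

Each forward pass adds one entry per layer, so after three batches every list has three values.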
01:13:50.100 | So we can now create it in the same way as usual, fit it in the same way as usual, but
01:13:54.080 | now our model has two extra things in it.
01:13:56.060 | It has act_means and act_stds.
01:13:59.900 | So let's plot the act means for every one of those lists that we had.
01:14:06.840 | And here it is, right?
01:14:08.080 | Here's all of the different means.
01:14:10.920 | And you can see it looks absolutely awful.
01:14:15.520 | What happens early in training is every layer, the means get exponentially bigger until they
01:14:23.520 | suddenly collapse, and then it happens again, and it suddenly collapses, and it happens
01:14:28.260 | again, and it suddenly collapses until eventually, it kind of starts training.
01:14:36.440 | So you might think, well, it's eventually training, so isn't this okay?
01:14:44.320 | But my concern would be this thing where it kind of falls off a cliff.
01:14:50.040 | There's lots of parameters in our model.
01:14:54.520 | Are we sure that all of them are getting back into reasonable places, or is it just that
01:14:59.480 | a few of them have got back into a reasonable place?
01:15:01.960 | Maybe the vast majority of them have zero gradients at this point.
01:15:05.840 | I don't know.
01:15:07.060 | It seems very likely that this awful training profile early in training is leaving our model
01:15:14.000 | in a really sad state.
01:15:16.120 | That's my guess, and we're going to check it to see later.
01:15:18.520 | But for now, we're just going to say, let's try to make this not happen.
01:15:23.680 | And let's also look at the standard deviations, and you see exactly the same thing.
01:15:29.060 | This just looks really bad.
01:15:32.440 | So let's look at just the first 10 means, and they all look okay.
01:15:37.640 | They're all pretty close-ish to zero, which is about what we want.
01:15:41.840 | But more importantly, let's look at the standard deviations for the first 10 batches, and this
01:15:46.560 | is a problem.
01:15:47.960 | The first layer has a standard deviation not too far away from one, but then, not surprisingly,
01:15:56.680 | the next layer is lower.
01:15:59.160 | The next layer is lower.
01:16:00.760 | As we would expect, because the first layer is less than one, the following layers are
01:16:05.240 | getting exponentially further away from one, until the last layer is really close to zero.
01:16:14.600 | So now we can kind of see what was going on here, is that our final layers were getting
01:16:21.720 | basically no activations, they were basically getting no gradients.
01:16:25.240 | So gradually, it was moving into spaces where they actually at least had some gradient.
01:16:30.340 | But by the time they kind of got there, the gradients were so large that they were kind of
01:16:36.760 | falling off a cliff and having to start again.
01:16:40.760 | So this is the thing we're going to try and fix.
01:16:43.360 | And we think we already know how to, we can use some initialization.
01:16:46.800 | Yes, Rachel?
01:16:51.120 | Did you say that if we went from 27 numbers to 32, that we were losing information?
01:16:55.280 | And could you say more about what that means?
01:16:58.320 | Yeah, I guess we're not losing information; that was poorly said. We're
01:17:07.280 | wasting information, I guess. If you start with 27 numbers and you do some
01:17:12.640 | matrix multiplication and end up with 32 numbers, you are now taking more space for the same
01:17:20.400 | information you started with.
01:17:22.280 | And the whole point of a neural network layer is to pull out some interesting features.
01:17:28.360 | So you would expect to have less total activations going on because you're trying to say, oh,
01:17:37.480 | in this area, I've kind of pulled this set of pixels down into something that says how
01:17:43.920 | furry this is or how much of a diagonal line does this have or whatever.
01:17:48.800 | So increasing the actual number of activations we have for a particular position is a total
01:18:00.200 | waste of time.
01:18:01.200 | We're not doing anything useful; we're wasting a lot of calculation.
01:18:09.400 | We can talk more about that on the forum if that's still not clear.
01:18:19.900 | So this idea of creating telemetry for your model is really vital.
01:18:24.560 | This approach to doing it, where you actually write a whole new class that only can do one
01:18:28.960 | kind of telemetry is clearly stupid.
01:18:33.100 | And so we clearly need a better way to do it.
01:18:35.160 | And what's the better way to do it?
01:18:37.040 | It's callbacks, of course.
01:18:39.680 | Except we can't use our callbacks because we don't have a callback that says when you
01:18:44.040 | calculate this layer, call back to our code.
01:18:48.080 | We have no way to do that, right?
01:18:50.620 | So we actually need to use a feature inside PyTorch that can call back into our code when
01:18:58.880 | a layer is calculated, either the forward pass or the backward pass.
01:19:03.280 | And for reasons I can't begin to imagine PyTorch doesn't call them callbacks, they're called
01:19:08.440 | hooks, right?
01:19:09.440 | But it's the same thing.
01:19:10.520 | It's a callback, okay?
01:19:12.120 | And so we can say for any module, we can say register forward hook and pass in a function.
01:19:21.680 | This is a callback.
01:19:22.680 | It's a callback that will be called when this module's forward pass is calculated.
01:19:28.960 | Or you could say register backward hook and that will call this function when this module's
01:19:34.680 | backward pass is calculated.
01:19:37.600 | So to replace that previous thing with hooks, we can simply create a couple of global variables
01:19:44.200 | to store our means and standard deviations for every layer.
01:19:47.480 | We can create a function to call back to to calculate the mean and standard deviation.
01:19:53.960 | And if you Google for the documentation for register forward hook, you will find that
01:20:00.120 | it will tell you that the callback will be called with three things.
01:20:06.040 | The module that's doing the callback, the input to the module, and the output of that
01:20:11.680 | module, from either the forward or the backward pass as appropriate.
01:20:15.680 | In our case, it's the output we want.
01:20:18.120 | And then we've got a fourth thing here because this is the layer number we're looking at,
01:20:22.280 | and we used partial to connect the appropriate closure with each layer.
01:20:28.880 | So once we've done that, we can call fit, and we can do exactly the same thing.
01:20:33.080 | So this is the same thing, just much more convenient.
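Here is a sketch of the global-list version using PyTorch's `register_forward_hook`. PyTorch calls the hook as `hook(module, input, output)`, and `partial` pre-fills the layer index `i`. The small linear model is illustrative only. The returned handles are kept so the hooks can be removed afterwards.

```python
from functools import partial
import torch
from torch import nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
act_means = [[] for _ in model]
act_stds = [[] for _ in model]

def append_stat(i, mod, inp, outp):
    # i comes from partial; mod, inp, outp come from PyTorch.
    act_means[i].append(outp.detach().mean().item())
    act_stds[i].append(outp.detach().std().item())

handles = [m.register_forward_hook(partial(append_stat, i))
           for i, m in enumerate(model)]
model(torch.randn(32, 10))

for h in handles:
    h.remove()  # detach the hooks once we're done
model(torch.randn(4, 10))  # no longer recorded
```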
01:20:36.280 | And because this is such a handy thing to be able to do, fastai has a Hook class.
01:20:41.600 | So we can create our own hook class now, which allows us to, rather than having this kind
01:20:46.560 | of messy global state, we can instead put the state inside the hook.
01:20:52.720 | So let's create a class called Hook that, when you initialize it, registers a forward
01:20:56.560 | hook on some module with some function.
01:21:01.320 | And what it's going to do is call back to this object.
01:21:07.480 | So we pass in self with the partial.
01:21:10.220 | And so that way, we can get access to the hook.
01:21:13.660 | We can pop inside it our two empty lists when we first call this to store away our means
01:21:20.160 | and standard deviations.
01:21:22.040 | And then we can just append our means and standard deviations.
01:21:24.360 | So now, we just go hooks equals hook for layer in children of my model.
01:21:32.680 | And we'll just grab the first four layers, because I don't care so much about the linear layers.
01:21:36.640 | It's really the conv layers that are most interesting.
01:21:39.340 | And so now, this does exactly the same thing.
01:21:44.380 | Since we do this a lot, let's put that into a class too, called Hooks.
01:21:51.000 | So here's our hooks class, which simply calls hook for every module in some list of modules.
01:22:01.480 | Now something to notice is that when you're done using a hooked module, you should call
01:22:08.440 | hook.remove.
01:22:11.200 | Because otherwise, if you keep registering more hooks on the same module, they're all
01:22:15.640 | going to get called.
01:22:16.640 | And eventually, you're going to run out of memory.
01:22:19.040 | So one thing I did in our Hook class was I created a dunder del.
01:22:23.720 | This is called automatically when Python cleans up some memory.
01:22:28.400 | So when it's done with your hook, it will automatically call remove, which in turn will
01:22:35.080 | remove the hook.
01:22:37.080 | So I then have a similar thing in hooks.
01:22:41.620 | So when hooks is done, it calls self.remove, which in turn goes through every one of my
01:22:50.000 | registered hooks and removes them.
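Here is a sketch of the `Hook` and `Hooks` classes described above. Each hook keeps its own state, and both classes clean up in `__del__`. For brevity, `Hooks` subclasses `list` here instead of the list container class the lesson builds next. `RemovableHandle.remove()` can safely be called more than once, so an explicit `remove` followed by `__del__` is harmless.

```python
from functools import partial
import torch
from torch import nn

class Hook:
    """Registers f(hook, module, inp, outp) as a forward hook on module m."""
    def __init__(self, m, f):
        self.hook = m.register_forward_hook(partial(f, self))

    def remove(self):
        self.hook.remove()

    def __del__(self):
        self.remove()

class Hooks(list):
    # Simplification: a plain list subclass, not the lesson's list container.
    def __init__(self, ms, f):
        super().__init__(Hook(m, f) for m in ms)

    def remove(self):
        for h in self:
            h.remove()

    def __del__(self):
        self.remove()

def append_stats(hook, mod, inp, outp):
    # State lives on the hook object instead of in globals.
    if not hasattr(hook, 'stats'):
        hook.stats = ([], [])
    means, stds = hook.stats
    means.append(outp.detach().mean().item())
    stds.append(outp.detach().std().item())

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
hooks = Hooks(model, append_stats)
model(torch.randn(32, 10))
hooks.remove()
model(torch.randn(4, 10))  # not recorded: hooks were removed
```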
01:22:57.640 | You'll see that somehow I'm able to go for h in self, but I haven't defined any kind
01:23:06.880 | of iterator here.
01:23:08.360 | And the trick is I've created something called a list container just above, which is super
01:23:13.880 | handy.
01:23:14.880 | It basically defines all the things you would expect to see in a list using all of the various
01:23:22.600 | special dunder methods and then some.
01:23:25.880 | It actually has some of the behavior of numpy as well.
01:23:30.000 | We're not allowed to use numpy in our foundations, so we use this instead.
01:23:33.440 | And this actually also works a bit better than numpy for this stuff because numpy does
01:23:37.200 | some weird casting and weird edge cases.
01:23:41.200 | So for example, with this list container, it's got dunder get item.
01:23:45.960 | So that's the thing that gets called when you call something with square brackets.
01:23:51.000 | So if you index into it with an int, then we just pass it off to the enclosed list because
01:23:59.680 | we gave it a list to enclose.
01:24:02.220 | If you send it a list of bools, like false, false, false, false, false, false, false,
01:24:10.200 | then it will return all of the things where that's true, or you can index into it with
01:24:16.760 | a list of indices, in which case it will return all of the things that are indexed by
01:24:22.360 | that list.
01:24:23.360 | For instance, it's got a length which just passes off to length and an iterator that
01:24:26.760 | passes off to iterator and so forth.
01:24:31.400 | And then we've also defined the representation for it such that if you print it out, it just
01:24:36.960 | prints out the contents unless there's more than 10 things in it, in which case it shows
01:24:41.480 | dot, dot, dot.
01:24:43.680 | So you can create really useful little base classes like this one
01:24:49.160 | in much less than a screenful of code.
01:24:53.120 | And then we can use them, and we will use them everywhere from now on.
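Here is a plain-Python sketch of a list container with the behaviour described:

- integer and slice indexing
- boolean-mask indexing
- indexing with a list of indices
- `len` and iteration
- a `repr` that shows at most 10 items followed by `...`

It follows the description in the transcript; details such as the exact `repr` format are this example's choices.

```python
class ListContainer:
    """A small list wrapper with numpy-style indexing (sketch)."""
    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, idx):
        if isinstance(idx, (int, slice)):
            return self.items[idx]
        idx = list(idx)
        if idx and isinstance(idx[0], bool):
            # Boolean mask: keep the items where the mask is True.
            assert len(idx) == len(self)
            return [o for m, o in zip(idx, self.items) if m]
        # Otherwise treat idx as a list of integer indices.
        return [self.items[i] for i in idx]

    def __len__(self): return len(self.items)
    def __iter__(self): return iter(self.items)
    def __setitem__(self, i, o): self.items[i] = o
    def __delitem__(self, i): del self.items[i]

    def __repr__(self):
        res = f'{self.__class__.__name__} ({len(self)} items)\n{self.items[:10]}'
        if len(self) > 10:
            res = res[:-1] + '...]'  # signal that items were truncated
        return res

lc = ListContainer(range(12))
```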
01:24:56.280 | So now we've created our own listy class that has hooks in it.
01:25:04.160 | And so now we can just use it like this.
01:25:05.480 | We can just say hooks = Hooks(model, append_stats), so everything in our model gets hooked with that function we had before,
01:25:11.960 | append_stats. We can print it out to see all the hooks.
01:25:15.920 | We can grab a batch of data.
01:25:19.400 | So now we've got one batch of data.
01:25:22.040 | And check that its mean and standard deviation are about zero and one, as you would expect.
01:25:26.060 | We can pass it through the first layer of our model.
01:25:29.000 | model[0] is the first layer of our model, which is the first convolution.
01:25:32.880 | And our mean is not quite zero, and our standard deviation is quite a lot less than one, as
01:25:38.240 | we kind of know what's going to happen.
01:25:41.080 | So now we'll just go ahead and initialize it with Kaiming.
01:25:44.280 | And after that, our variance is quite close to one.
01:25:48.120 | And our mean, as expected, is quite close to 0.5 because of the ReLU.
01:25:57.280 | So now we can go ahead and create our hooks and do a fit.
01:26:03.440 | And we can plot the first 10 means and standard deviations, and then we can plot all the means
01:26:08.080 | and standard deviations, and there it all is.
01:26:11.800 | And this time we're doing it after we've initialized all the layers of our model.
01:26:18.000 | And as you can see, we don't have that awful exponential crash, exponential crash, exponential
01:26:24.520 | crash.
01:26:25.520 | So this is looking much better.
01:26:27.440 | And you can see early on in training, our variances all look, our standard deviations
01:26:32.400 | all look much closer to one.
01:26:34.800 | So this is looking super hopeful.
01:26:38.360 | I've used a with block.
01:26:40.440 | A with block is something that will create this object, give it this name, and when it's
01:26:48.040 | finished, it will do something.
01:26:51.040 | The something it does is to call your dunder exit method here, which will call remove.
01:26:59.080 | So here's a nice way to ensure that things are cleaned up.
01:27:02.680 | For example, your hooks are removed.
01:27:06.880 | So that's why we have a dunder enter.
01:27:08.520 | That's what happens when you start the with block, dunder exit when you finish the with
01:27:11.800 | block.
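Here is a plain-Python sketch of why `__enter__` and `__exit__` make cleanup reliable. `__exit__` runs when the `with` block ends, even if the body raises. `FakeHook` is a hypothetical stand-in for a registered PyTorch hook, so the example runs without a model.

```python
class FakeHook:
    # Hypothetical stand-in for a registered hook.
    def __init__(self, name, log):
        self.name, self.log = name, log

    def remove(self):
        self.log.append(f"removed {self.name}")

class Hooks:
    def __init__(self, names, log):
        self.hooks = [FakeHook(n, log) for n in names]

    def __enter__(self):
        return self        # bound to the name after `as`

    def __exit__(self, *args):
        self.remove()      # runs even if the body raised
        return False       # don't swallow the exception

    def remove(self):
        for h in self.hooks:
            h.remove()

log = []
try:
    with Hooks(["conv1", "conv2"], log) as hooks:
        log.append("training")
        raise RuntimeError("simulated failure")
except RuntimeError:
    pass
```

The simulated failure still propagates, but only after both hooks have been removed.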
01:27:15.520 | So this is looking very hopeful, but it's not quite what we wanted to know.
01:27:20.720 | Really the concern was, does this actually do something bad?
01:27:26.120 | Is it actually, or does it just train fine afterwards?
01:27:29.520 | So something bad really is more about how many of the activations are really, really
01:27:35.080 | small.
01:27:36.080 | How well is it actually getting everything activated nicely?
01:27:40.280 | So what we could do is we could adjust our append stats.
01:27:43.400 | So not only does it have a mean and a standard deviation, but it's also got a histogram.
01:27:49.600 | So we could create a histogram of the activations, pop them into 40 bins between 0 and 10.
01:27:58.040 | We don't need to go underneath 0 because we have a ReLU.
01:28:00.840 | So we know that there's none underneath 0.
01:28:03.600 | So let's again run this.
01:28:07.040 | We will use our Kaiming initialization.
01:28:14.040 | And what we find is that even with that, if we make our learning rate really high, 0.9,
01:28:26.640 | we can still get this same behavior.
01:28:28.600 | And so here's plotting the entire histogram.
01:28:30.560 | And I should say thank you to Stefano for the original code here from our San Francisco
01:28:35.840 | study group to plot these nicely.
01:28:38.760 | So you can see this kind of grow, collapse, grow, collapse, grow, collapse thing.
01:28:45.520 | The biggest concern for me though is this yellow line at the bottom.
01:28:51.000 | The yellow line, yellow is where most of the histogram is.
01:28:54.640 | Actually, what I really care about is how much yellow is there.
01:28:59.440 | So let's say the first two histogram bins are 0 or nearly 0.
01:29:07.180 | So let's get the sum of how much is in those two bins and divide by the sum of all of the
01:29:12.480 | bins.
01:29:13.600 | And so that's going to tell us what percentage of the activations are 0 or nearly 0.
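Here is a plain-Python sketch of the "fraction of activations that are zero or nearly zero" measurement. The helper `histc` mimics an equal-width histogram over [lo, hi] with 40 bins from 0 to 10, as described, so each bin is 0.25 wide. The measurement then divides the count in the first two bins by the total. In the lesson this is computed per iteration from stored histograms; here it is shown for one made-up batch of activations.

```python
def histc(values, bins=40, lo=0.0, hi=10.0):
    """Equal-width histogram over [lo, hi]; values outside are ignored."""
    counts = [0] * bins
    width = (hi - lo) / bins
    for v in values:
        if lo <= v <= hi:
            i = min(int((v - lo) / width), bins - 1)  # hi goes in the last bin
            counts[i] += 1
    return counts

def frac_near_zero(counts, n_bins=2):
    # Share of activations in the lowest n_bins bins (0 to 0.5 here).
    return sum(counts[:n_bins]) / sum(counts)

acts = [0.0] * 9 + [5.0]  # made-up activations: mostly dead
counts = histc(acts)
```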
01:29:21.800 | Let's plot that for each of the first four layers.
01:29:25.320 | And you can see that in the last layer, it's just as we suspected, over 90% of the activations
01:29:33.800 | are actually 0.
01:29:35.560 | So if you were training your model like this, it could eventually look like it's training
01:29:40.120 | nicely without you realizing that 90% of your activations were totally wasted.
01:29:45.760 | And so you're never going to get great results by wasting 90% of your activations.
01:29:50.600 | So let's try and fix it.
01:29:52.800 | Let's try and be able to train at a nice high learning rate and not have this happen.
01:29:57.340 | And so the trick is, is we're going to try a few things, but the main one is we're going
01:30:01.160 | to use our better ReLU.
01:30:04.840 | And so we've created a generalized ReLU class where now we can pass in things like an amount
01:30:11.200 | to subtract from the ReLU because remember we thought subtracting half from the ReLU
01:30:14.920 | might be a good idea.
01:30:16.280 | We can also use leaky ReLU and maybe things that are too big are also a problem.
01:30:20.840 | So let's also optionally have a maximum value.
01:30:23.960 | So in this generalized ReLU, if you passed a leakiness, then we'll use leaky ReLU.
01:30:30.560 | Otherwise we'll use normal ReLU.
01:30:32.560 | You could very easily write this leaky ReLU by hand, but I'm just trying to make it run
01:30:38.680 | a little faster by taking advantage of PyTorch.
01:30:41.560 | If you said I want to subtract something from it, then go ahead and subtract that from it.
01:30:46.620 | If I said there's some maximum value, go ahead and clamp it at that maximum value.
01:30:51.440 | So here's our generalized ReLU.
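Here is a PyTorch sketch of the generalized ReLU described above. It applies an optional leak, then an optional subtraction, then an optional maximum. This version uses out-of-place operations, which is a choice of this example. The values `leak=0.1`, `sub=0.4` and `maxv=6.0` are illustrative; the transcript mentions subtracting 0.4 later on.

```python
import torch
import torch.nn.functional as F
from torch import nn

class GeneralRelu(nn.Module):
    def __init__(self, leak=None, sub=None, maxv=None):
        super().__init__()
        self.leak, self.sub, self.maxv = leak, sub, maxv

    def forward(self, x):
        # Leaky ReLU if a leak is given, plain ReLU otherwise.
        x = F.leaky_relu(x, self.leak) if self.leak is not None else F.relu(x)
        if self.sub is not None:
            x = x - self.sub              # shift the output down
        if self.maxv is not None:
            x = x.clamp(max=self.maxv)    # cap very large activations
        return x

act = GeneralRelu(leak=0.1, sub=0.4, maxv=6.0)
out = act(torch.tensor([-2.0, 0.0, 1.0, 10.0]))
```

For -2 the leak gives -0.2, then subtracting gives -0.6. For 10 the subtraction gives 9.6, which is clamped to 6.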
01:30:54.040 | And so now let's have our conv layer function and get_cnn_layers both take **kwargs and just pass
01:31:00.800 | them on through so that eventually they end up passed to our generalized ReLU.
01:31:07.360 | And so that way we're going to be able to create a CNN and say what ReLU characteristics
01:31:12.600 | do we want nice and easily.
01:31:16.280 | And even get_cnn_model will pass down kwargs as well.
01:31:21.440 | So now that our ReLU can go negative because it's leaky and because it's subtracting stuff,
01:31:26.120 | we'll need to change our histogram so it goes from -7 to 7 rather than from 0 to 10.
01:31:32.900 | So we'll also need to change our definition of get_min so that the middle few bins of the
01:31:44.720 | histogram are 0 rather than the first two.
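Which bins count as "the middle few" once the histogram runs from -7 to 7? Here is a quick arithmetic check, assuming 40 equal-width bins as in the earlier histogram. Looking at bins 19 to 21 is a choice made for illustration: they straddle zero, covering roughly -0.35 to 0.7.

```python
def bin_edges(bins=40, lo=-7.0, hi=7.0):
    # (left, right) edges of each equal-width bin.
    width = (hi - lo) / bins
    return [(lo + i * width, lo + (i + 1) * width) for i in range(bins)]

edges = bin_edges()
for i in (19, 20, 21):
    print(i, [round(e, 2) for e in edges[i]])
```

With 14 units split into 40 bins, each bin is 0.35 wide, so bin 20 starts at zero.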
01:31:49.060 | And now we can just go ahead and train this model just like before and plot just like
01:31:52.840 | before.
01:31:55.320 | And this is looking pretty hopeful.
01:31:58.940 | Let's keep looking at the rest.
01:32:00.280 | So here's the first one, two, three, four layers.
01:32:02.880 | So compared to that, which was expand, die, expand, die, expand, die, we're now seeing
01:32:13.880 | this is looking much better.
01:32:14.880 | It's straight away.
01:32:16.200 | It's using the full richness of the possible activations.
01:32:19.520 | There's no death going on.
01:32:21.080 | But our real question is how much is in this yellow line?
01:32:25.880 | There's a question.
01:32:28.800 | And let's see, in the final layer, look at that, less than 20%.
01:32:35.560 | So we're now using nearly all of our activations by being careful about our initialization
01:32:44.440 | and our ReLU, but we're still training at a nice high learning rate.
01:32:50.720 | So this is looking great.
01:32:54.360 | Could you explain again how to read the histograms?
01:32:58.680 | Sure.
01:33:00.120 | So the four histograms, let's go back to the earlier one.
01:33:03.180 | So the four histograms are simply the four layers.
01:33:05.440 | So after the first, second, third, fourth.
01:33:10.920 | And the x-axis is the iteration.
01:33:16.100 | So each one is just one more iteration as most of our plots show.
01:33:21.240 | The y-axis is how big the activations are, from the highest they can be to the lowest they can be.
01:33:33.700 | So what this one here is showing us, for example, is that there are some activations that are
01:33:40.120 | at the max and some activations are in the middle and some activations at the bottom,
01:33:44.680 | whereas this one here is showing us that all of the activations are basically zero.
01:33:51.920 | So what this shows us in this histogram is that now we're going all the way from plus
01:33:58.800 | seven to minus seven because we can have negatives.
01:34:01.800 | This is zero.
01:34:02.800 | It's showing us that most of them are zero, because yellow means the most activations.
01:34:11.160 | There are activations throughout everything from the bottom to the top.
01:34:16.440 | And a few less than zero, as we would expect, because we have a leaky ReLU and we also
01:34:21.960 | have that minus: we're not doing minus 0.5, we're doing minus 0.4, because leaky ReLU
01:34:27.840 | means that we don't need to subtract half anymore, we subtract a bit less than half.
01:34:34.560 | And so then this line is telling us what percentage of them are zero or nearly zero.
01:34:45.520 | And so this is one of those things which is good to run lots of experiments in the notebook
01:34:50.680 | yourself to get a sense of what's actually in these histograms.
01:34:54.640 | So you can just go ahead and have a look at each hook's stats.
01:34:58.720 | And the third thing in it will be the histograms, so you can see what shape is it and how is
01:35:03.320 | it calculated and so forth.
01:35:07.880 | So now that we've done that, this is looking really good.
01:35:12.760 | So what actually happens if we train like this?
01:35:16.440 | So let's do a one cycle training.
01:35:20.320 | So we use that combine_scheds we built last week: two phases split 50/50, cosine scheduling, cosine
01:35:27.560 | annealing.
01:35:28.560 | So gradual warm-up, gradual cool-down, and then run it for eight epochs.
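As a reminder of what that schedule does, here's a pure-Python sketch that combines two cosine phases, split 50/50: warm up from a low learning rate to a peak, then anneal back down. It follows the shape of last week's combine_scheds and sched_cos helpers, but it is a simplified re-implementation, and the learning-rate values are hypothetical.

```python
import math

def sched_cos(start, end):
    """Cosine-anneal from start to end as pos goes from 0 to 1."""
    def _inner(pos):
        return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2
    return _inner

def combine_scheds(pcts, scheds):
    """Give each schedule its share (pcts) of training; pos is overall progress in [0, 1]."""
    assert abs(sum(pcts) - 1.0) < 1e-9
    bounds = [0.0]
    for p in pcts:
        bounds.append(bounds[-1] + p)          # cumulative start/end of each phase
    def _inner(pos):
        # last phase whose start is <= pos, so pos == 1.0 still uses the final phase
        idx = max(i for i in range(len(scheds)) if pos >= bounds[i])
        actual_pos = (pos - bounds[idx]) / (bounds[idx + 1] - bounds[idx])
        return scheds[idx](actual_pos)
    return _inner

# Two phases, 50/50: warm up 0.1 -> 0.6, then cool down 0.6 -> 0.05 (hypothetical values)
sched = combine_scheds([0.5, 0.5], [sched_cos(0.1, 0.6), sched_cos(0.6, 0.05)])
```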
01:35:33.940 | And there we go, we're doing really well, we're getting up to 98%.
01:35:37.760 | So we weren't really tuning the training at all; we were just trying to get something
01:35:43.080 | that looked good.
01:35:46.680 | And once we had something that looked good in terms of the telemetry, it's really training
01:35:51.320 | really well.
01:35:54.320 | One option I added, by the way, in init_cnn was a uniform Boolean, which will set
01:36:02.920 | the initialization function to Kaiming normal if it's false, which is what we've been using
01:36:09.200 | so far, or Kaiming uniform if it's true.
01:36:13.160 | So now I've just trained the same model with uniform equals true.
01:36:19.000 | A lot of people think that uniform is better than normal, because a uniform random number
01:36:25.640 | is less often close to zero.
01:36:27.960 | And so the thinking is that maybe uniform initialization might cause
01:36:34.520 | it to have a better richness of activations.
01:36:37.800 | I haven't studied this closely, I'm not sure I've seen a careful analysis in a paper.
01:36:43.160 | In this case, 98.22% versus 98.26%, they're looking pretty similar, but that's just something
01:36:48.800 | else that it's there to play with.
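The lesson only switches which PyTorch init function gets called. This plain-Python sketch puts the two distributions side by side. It assumes the standard Kaiming formulas, the ones documented for torch.nn.init.kaiming_normal_ and torch.nn.init.kaiming_uniform_, and fan_in = 64 is an arbitrary choice. Both draws have the same variance, but the uniform one puts fewer weights near zero.

```python
import math
import random

def kaiming_std(fan_in, a=0.0):
    """Target std for (leaky) ReLU layers: gain / sqrt(fan_in), gain = sqrt(2 / (1 + a^2))."""
    gain = math.sqrt(2.0 / (1.0 + a * a))
    return gain / math.sqrt(fan_in)

def kaiming_normal(n, fan_in, a=0.0, rng=random):
    """n weights drawn from N(0, std^2)."""
    std = kaiming_std(fan_in, a)
    return [rng.gauss(0.0, std) for _ in range(n)]

def kaiming_uniform(n, fan_in, a=0.0, rng=random):
    """n weights drawn from U(-bound, bound); U(-b, b) has variance b^2 / 3,
    so bound = sqrt(3) * std matches the variance of the normal version."""
    bound = math.sqrt(3.0) * kaiming_std(fan_in, a)
    return [rng.uniform(-bound, bound) for _ in range(n)]

def share_near_zero(weights, threshold):
    """Fraction of weights whose magnitude is below threshold."""
    return sum(abs(w) < threshold for w in weights) / len(weights)

rng = random.Random(0)
fan_in = 64
std = kaiming_std(fan_in)
w_normal = kaiming_normal(100_000, fan_in, rng=rng)
w_uniform = kaiming_uniform(100_000, fan_in, rng=rng)
# In expectation, about 8.0% of normal draws but only about 5.8% of uniform draws
# land within 0.1 std of zero.
```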
01:36:51.380 | So at this point, we've got a pretty nice bunch of things you can look at, and so
01:36:57.800 | your problem to play with during the week is: how accurate can you make a model
01:37:06.640 | just using the layers we've created so far?
01:37:10.320 | And for the ones that are great accuracy, what does the telemetry look like?
01:37:14.720 | How can you tell whether it's going to be good?
01:37:17.040 | And then what insights can you gain from that to make it even better?
01:37:20.300 | So in the end, try to beat me, try to beat 98%.
01:37:26.200 | You'll find you can beat it pretty easily with some playing around, but do some experiments.
01:37:33.320 | All right, so that's kind of about what we can do with initialization.
01:37:47.480 | You can go further, as we discussed, with SELU or with Fixup. There are these really
01:37:54.880 | finely tuned initialization methods that let you train models 1,000 layers deep, but they're super
01:37:59.440 | fiddly.
01:38:00.820 | So generally, I would use something like the layer-sequential unit-variance (LSUV)
01:38:07.760 | initialization that we saw earlier in ... Oh, sorry, we haven't done that one yet.
01:38:15.120 | Okay, we're going to do that next.
01:38:17.520 | Okay, so, forget I said that.
01:38:22.080 | So that's kind of about as far as we can get with basic initialization.
01:38:29.560 | To go further, we really need to use normalization within the model, and the most commonly known approach
01:38:36.000 | is batch normalization.
01:38:41.040 | So let's look at batch normalization.
01:38:43.280 | So batch normalization has been around since, I think, about 2015.
01:38:50.520 | This is the paper.
01:38:53.080 | And they first of all describe a bit about why they thought batch normalization was a
01:38:57.320 | good idea, and by about page 3, they provide the algorithm.
01:39:05.560 | So it's one of those things that if you don't read a lot of math, it might look a bit scary.
01:39:10.760 | But then when you look at it for a little bit longer, you suddenly notice that this
01:39:14.760 | is literally just the mean: the sum divided by the count.
01:39:18.840 | And this is the squared difference from the mean, averaged.
01:39:26.760 | Oh, that's just what we looked at, that's variance.
01:39:29.440 | And this is just subtract the mean, divide by the standard deviation.
01:39:32.720 | Oh, that's just normalization.
01:39:34.720 | So once you look at it a second time, you realize we've done all this.
01:39:37.360 | We've just done it with code, not with math.
01:39:41.520 | And so then, the only thing they do is after they've normalized it in the usual way, is
01:39:46.360 | that they then multiply it by gamma, and they add beta.
01:39:50.400 | What are gamma and beta?
01:39:51.960 | They are parameters to be learned.
01:39:55.320 | What does that mean?
01:39:56.320 | That's the most important line here.
01:39:59.280 | Remember that there are two types of numbers in a neural network, parameters and activations.
01:40:06.200 | Activations are things we calculate, parameters are things we learn.
01:40:09.960 | So these are just numbers that we learn.
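For reference, here is the transform just described, written out for a mini-batch of m values of one activation. It uses the paper's gamma and beta, plus the small epsilon the paper adds inside the square root:

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i \\
\sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2 \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \\
y_i &= \gamma\,\hat{x}_i + \beta
\end{aligned}
```

Here gamma and beta are the learned parameters. The batch mean and variance are computed from the activations of each batch.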
01:40:13.640 | So that's all the information we need to implement batch norm.
01:40:17.300 | So let's go ahead and do it.
01:40:19.360 | So first of all, we'll grab our data as before, create our callbacks as before.
01:40:24.200 | Here's our pre-batch norm version, 96.5%.
01:40:29.480 | And the highest I could get was a 0.4 learning rate this way.
01:40:39.680 | And so now let's try batch norm.
01:40:40.920 | So here's batch norm.
01:40:43.160 | So let's look at the forward first.
01:40:47.240 | We're going to get the mean and the variance.
01:40:51.400 | And the way we do that is we call update_stats, and the mean is just the mean.
01:40:57.880 | And the variance is just the variance.
01:41:00.960 | And then we subtract the mean, and we divide by the square root of the variance.
01:41:06.320 | And then we multiply and add, but I didn't call them gamma and beta, because why use
01:41:11.320 | Greek letters when nobody remembers which one's gamma and which one's beta?
01:41:14.660 | Let's use English.
01:41:15.660 | The thing we multiply by, we'll call the mults, and the thing we add, we'll call the adds.
01:41:21.280 | And so mults and adds are parameters.
01:41:26.080 | We multiply by a parameter that initially is just a bunch of ones, so it does nothing.
01:41:31.420 | And we add a parameter which is initially just a bunch of zeros, so it does nothing.
01:41:37.000 | But they're parameters, so they can learn.
01:41:39.400 | Just like our original linear layer that we created by hand, remember, which looked like
01:41:44.480 | this.
01:41:45.480 | In fact, if you think about it, adds is just a bias.
01:41:49.760 | It's identical to the bias we created earlier.
01:41:54.600 | So then there's a few extra little things we have to think about.
01:41:58.120 | One is what happens at inference time, right?
01:42:02.040 | So during training, we normalize.
01:42:05.360 | But the problem is that if we normalize in the same way at inference time, if we get
01:42:09.320 | like a totally different kind of image, we might kind of remove all of the things that
01:42:17.560 | are interesting about it.
01:42:19.320 | So what we do is while we're training, we keep an exponentially weighted moving average
01:42:27.080 | of the means and the variances.
01:42:29.360 | I'll talk more about what that means in a moment.
01:42:31.400 | But basically we've got a running average of the last few batches means and a running
01:42:36.560 | average of the last few batches variances.
01:42:39.120 | And so then when we're not training, in other words at inference time, we don't use the
01:42:44.800 | mean and variance of this mini-batch, we use that running average mean and variance that
01:42:50.600 | we've been keeping track of.
01:42:55.120 | So how do we calculate that running average?
01:42:58.480 | Well, we don't just create something called self.vars.
01:43:04.520 | We call self.register_buffer('vars', ...).
01:43:07.400 | Now that creates something called self.vars.
01:43:11.960 | So why didn't we just say self.vars = torch.ones?
01:43:15.960 | Why do we say self.register_buffer?
01:43:18.840 | It's almost exactly the same as saying self.vars = torch.ones, but it does a couple of nice things.
01:43:25.840 | The first is that if we move the model to the GPU, anything that's registered as a buffer
01:43:32.560 | will be moved to the GPU as well.
01:43:35.400 | And if we didn't do that, then it's going to try and do this calculation down here.
01:43:39.680 | And if the vars and means aren't on the GPU, but everything else is on the GPU, we'll get
01:43:43.840 | an error.
01:43:44.840 | It'll say, "Oh, you're trying to add this thing on the CPU to this thing on the GPU,
01:43:48.000 | and it'll fail."
01:43:49.280 | So that's one nice thing about register buffer.
01:43:51.380 | The other nice thing is that the variances and the means, these running averages, they're
01:43:56.400 | part of the model, right?
01:43:59.080 | When we do inference, in order to calculate our predictions, we actually need to know
01:44:04.600 | what those numbers are.
01:44:06.000 | So if we save the model, we have to save those variances and means.
01:44:11.240 | So register buffer also causes them to be saved along with everything else in the model.
01:44:17.040 | So that's what register buffer does.
01:44:19.600 | So the variances we start out at ones, and the means we start out at zeros.
01:44:25.160 | We then calculate the mean and variance of the mini-batch, and we average over axes
01:44:31.960 | zero, two, and three. So in other words, we average over all the items in the batch, and we average
01:44:37.080 | over all of the x and y coordinates.
01:44:41.940 | So all we're left with is a mean for each channel, or a mean for each filter.
01:44:50.880 | keepdim=True means that it's going to leave a unit axis in positions
01:44:57.640 | zero, two, and three, so it'll still broadcast nicely.
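Putting those pieces together, here is a minimal NumPy sketch of the layer for input shaped (batch, channels, height, width). It stands in for the PyTorch module and is not the lesson's code. In PyTorch, mults and adds would be nn.Parameters, and means and vars would be registered with register_buffer so they follow the model to the GPU and get saved with it. The class name BatchNormSketch is hypothetical.

```python
import numpy as np

class BatchNormSketch:
    """Hypothetical NumPy stand-in for a batch norm layer on (batch, channels, height, width) input."""

    def __init__(self, nf, mom=0.1, eps=1e-5):
        self.mom, self.eps = mom, eps
        # Learnable in the real layer: ones and zeros, so they do nothing at first.
        self.mults = np.ones((1, nf, 1, 1))
        self.adds = np.zeros((1, nf, 1, 1))
        # Running statistics (buffers in PyTorch): variances start at 1, means at 0.
        self.vars = np.ones((1, nf, 1, 1))
        self.means = np.zeros((1, nf, 1, 1))

    def update_stats(self, x):
        # Average over batch (0), height (2) and width (3); keepdims leaves unit axes,
        # so m and v have shape (1, nf, 1, 1) and broadcast against x.
        m = x.mean(axis=(0, 2, 3), keepdims=True)
        v = x.var(axis=(0, 2, 3), keepdims=True)
        # Exponentially weighted moving averages, written as a lerp towards the new values.
        self.means += self.mom * (m - self.means)
        self.vars += self.mom * (v - self.vars)
        return m, v

    def __call__(self, x, training=True):
        if training:
            m, v = self.update_stats(x)        # this mini-batch's statistics
        else:
            m, v = self.means, self.vars       # inference: use the running averages
        x = (x - m) / np.sqrt(v + self.eps)
        return x * self.mults + self.adds
```

During training you would call bn(x). At inference you would call bn(x, training=False), which uses the running averages instead of the current batch.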
01:45:02.560 | So now, we want to take a running average.
01:45:08.140 | So normally, if we want to take a moving average, if we've got a bunch of data points, we want
01:45:16.680 | a moving average.
01:45:17.680 | We would grab five at a time, and we would take the average of those five, and then we would
01:45:24.880 | take the next five, and we'd take their average, and we'd keep doing that a few at a time.
01:45:32.200 | We don't want to do that here, though, because these batch norm statistics, every single
01:45:39.040 | activation has one.
01:45:40.920 | So it's giant.
01:45:43.600 | Models can have hundreds of millions of activations.
01:45:47.800 | We don't want to have to save a whole history of every single one of those, just so that
01:45:52.160 | we can calculate an average.
01:45:54.840 | So there's a handy trick for this, which is instead to use an exponentially weighted moving
01:46:01.380 | average.
01:46:02.380 | And basically, what we do is we start out with this first point, and we say, okay, our
01:46:09.920 | first average is just the first point.
01:46:12.280 | So let's say, I don't know, that's three.
01:46:16.080 | And then the second point is five.
01:46:19.160 | And what we do is to take an exponentially weighted moving average, we first of all need
01:46:24.320 | some number, which we call momentum, let's say it's 0.9.
01:46:30.200 | So for the first value, our exponentially weighted moving average,
01:46:35.680 | which we'll call mu, equals three.
01:46:38.840 | And then for the second one, we take mu one, we multiply it by our momentum, and then we
01:46:50.480 | add our second value, five, and we multiply it by one minus our momentum.
01:46:59.320 | So in other words, it's mainly whatever it used to be before, plus a little bit of the
01:47:04.040 | new thing.
01:47:05.780 | And then mu three equals mu two times 0.9 plus the new value, let's say this one here
01:47:15.440 | is four, times 0.1.
01:47:21.120 | So we're basically continuing to say, oh, it's mainly the thing before plus a little
01:47:26.500 | bit of the new one.
01:47:28.000 | And so what you end up with is something where, like by the time we get to here, the amount
01:47:36.460 | of influence of each of the previous data points, once you calculate it out, it turns
01:47:43.340 | out to be exponentially decayed.
01:47:46.060 | So it's a moving average with an exponential decay, with the benefit that we only ever
01:47:51.040 | have to keep track of one value.
01:47:55.660 | So that's what an exponentially weighted moving average is.
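The worked example above (3, then 5, then 4, with momentum 0.9) can be checked in a few lines of plain Python. The second call feeds in a single 1 surrounded by zeros, which shows that one point's influence decays geometrically: 0.1, then 0.09, then 0.081. The function name ewma is just for illustration.

```python
def ewma(values, momentum=0.9):
    """Exponentially weighted moving average. The first average is the first value;
    after that, each step keeps `momentum` of the old average and adds
    (1 - momentum) of the new value."""
    history = []
    mu = None
    for v in values:
        mu = v if mu is None else momentum * mu + (1 - momentum) * v
        history.append(mu)
    return history   # only `mu` is needed to carry on; history just shows each step

steps = ewma([3, 5, 4])        # mu1, mu2, mu3 from the example
decay = ewma([0, 1, 0, 0])     # how one observation's weight fades over later steps
```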
01:48:02.140 | This thing we do here, where we basically say we've got some function where we say it's
01:48:08.980 | some previous value times 0.9, say, plus some other value times one minus that thing.
01:48:25.320 | This is called a linear interpolation.
01:48:28.080 | It's a bit of this and a bit of this other thing, and the two weights add up to one.
01:48:32.760 | Linear interpolation in PyTorch is spelt lerp.
01:48:38.120 | So we take the means, and then we lerp with our new mean using this amount of momentum.
01:48:49.320 | Unfortunately, lerp uses the exact opposite of the normal sense of momentum.
01:48:58.080 | So momentum of 0.1 in batch norm actually means momentum of 0.9 in normal person speak.
01:49:07.440 | So this is actually how nn.batchnorm works as well.
01:49:14.560 | So batch norm momentum is the opposite of what you would expect.
01:49:18.920 | I wish they'd given it a different name.
01:49:20.680 | They didn't, sadly.
01:49:21.840 | So this is what we're stuck with.
01:49:24.120 | So that's how we keep the running averages of the means and the variances.
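Here is the same update written as a lerp in plain Python, using the formula that torch.lerp documents: start + weight * (end - start). It uses the toy values 3 and 5 from the earlier example. The point is that a batch-norm-style momentum of 0.1 is the weight on the new value. That is the same as a "normal" momentum of 0.9 on the old value, and it is the convention nn.BatchNorm2d's default momentum of 0.1 follows.

```python
def lerp(start, end, weight):
    # Same formula as torch.lerp: start + weight * (end - start)
    return start + weight * (end - start)

running_mean, batch_mean = 3.0, 5.0
bn_momentum = 0.1   # batch norm's "momentum": the weight given to the NEW value
updated = lerp(running_mean, batch_mean, bn_momentum)

# ...which is a "normal" momentum of 0.9 on the OLD value
usual = 0.9 * running_mean + 0.1 * batch_mean
```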
01:49:28.280 | So now we can go ahead and use that.
01:49:32.200 | So now we can create a new conv layer, which you can optionally say whether you want batch
01:49:36.480 | norm.
01:49:37.640 | If you do, we append a batch norm layer.
01:49:40.760 | If we do append a batch norm layer, we remove the bias from the conv layer.
01:49:44.400 | Because remember I said that the adds in batch norm is just a bias.
01:49:49.160 | So there's no point having a bias in the conv layer anymore.
01:49:53.160 | So we'll remove the unnecessary bias.
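A quick NumPy check of why the conv bias is redundant: adding any per-channel constant before batch-norm-style normalization only shifts that channel's mean, and the mean subtraction removes it again. The variance is unchanged. The shapes and bias values are made up for illustration.

```python
import numpy as np

def normalize_per_channel(x, eps=1e-5):
    """Batch-norm-style normalization: statistics per channel over batch, height and width."""
    m = x.mean(axis=(0, 2, 3), keepdims=True)
    v = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - m) / np.sqrt(v + eps)

rng = np.random.default_rng(0)
conv_out = rng.normal(size=(8, 3, 5, 5))                  # a conv output without bias
bias = np.array([0.5, -2.0, 10.0]).reshape(1, 3, 1, 1)    # a hypothetical per-channel conv bias

# The bias only shifts each channel's mean, which normalization subtracts away
with_bias = normalize_per_channel(conv_out + bias)
without_bias = normalize_per_channel(conv_out)
```

The learned adds after normalization can provide any offset the bias would have given, so keeping both only adds redundant parameters.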
01:49:56.360 | And so now we can go ahead and initialize our CNN.
01:50:01.240 | This is a slightly more convenient initialization, which is actually going to go in and recursively
01:50:08.480 | initialize the weights of every module inside our model.
01:50:20.680 | And then we will train it with our hooks.
01:50:25.280 | And you can see our mean starts at 0 exactly.
01:50:29.400 | And our standard deviation starts at 1 exactly.
01:50:35.400 | So our training has entirely gotten rid of all of the exponential growth and sudden crash
01:50:44.200 | stuff that we had before.
01:50:46.840 | There's something interesting going on at the very end of training, which I don't quite
01:50:51.040 | know what that is.
01:50:52.040 | I mean, when I say the end of training, we've only done one epoch.
01:50:56.240 | But this is looking a lot better than anything we've seen before.
01:51:01.680 | I mean, that's just a very nice-looking curve.
01:51:04.720 | And so we're now able to get up to learning rates up to 1.
01:51:12.680 | We've got 97% accuracy after just three epochs.
01:51:17.320 | This is looking very encouraging.
01:51:19.720 | So now that we've built our own batch norm, we're allowed to use PyTorch's batch norm.
01:51:24.840 | And we get pretty much the same results.
01:51:27.200 | Sometimes it's 97, sometimes it's 98.
01:51:28.880 | This is just random variation.
01:51:32.320 | So now that we've got that, let's try going crazy.
01:51:34.800 | Let's try using our little one-cycle learning scheduler we had.
01:51:39.880 | And let's try and go all the way up to a learning rate of 2.
01:51:47.760 | And look at that.
01:51:48.760 | We totally can, right?
01:51:50.560 | And we're now up towards nearly 99% accuracy.
01:51:55.080 | So batch norm really is quite fantastic.
01:52:00.360 | Batch norm has a bit of a problem, though, which is that you can't apply it to what we
01:52:08.520 | call online learning tasks.
01:52:10.360 | In other words, if you have a batch size of 1, right, so you're getting a single item
01:52:16.400 | at a time and learning from that item, what's the variance of that batch?
01:52:21.760 | The variance of a batch of 1 is zero (or undefined, if you use the unbiased estimate), so normalizing by it blows up, right?
01:52:25.780 | So we can't use batch norm in that case.
01:52:28.280 | Well, what if we're doing like a segmentation task where we can only have a batch size of
01:52:32.800 | 2 or 4, which we've seen plenty of times in part 1?
01:52:36.840 | That's going to be a problem, right?
01:52:37.960 | Because across all of our layers, across all of our training, across all of the channels,
01:52:42.480 | with a batch size of 2, at some point those two values are going to be the same or nearly
01:52:48.760 | the same.
01:52:49.840 | And so we then divide by that variance, which is about 0.
01:52:53.480 | We get something close to infinity, right?
01:52:55.800 | So we have this problem where any time you have a small batch size, you're going to get
01:53:01.680 | unstable or impossible training.
01:53:04.880 | It's also going to be really hard for RNNs.
01:53:07.720 | Because for RNNs, remember, it looks something like this, right?
01:53:12.440 | We have this hidden state, and we use the same weight matrix again and again and again.
01:53:18.360 | Right?
01:53:19.360 | Remember, we can unroll it, and it looks like this.
01:53:23.760 | If you've forgotten, go back to lesson 7.
01:53:26.720 | And then we can even stack two RNNs together.
01:53:30.040 | One RNN feeds into another RNN.
01:53:31.920 | And if we unroll that, it looks like this.
01:53:35.360 | And remember, these hidden-state transitions from time step to time step: if we're doing IMDB
01:53:42.400 | with a movie review with 2,000 words, there's 2,000 of these.
01:53:48.160 | And this is the same weight matrix each time, and the number of these circles will vary.
01:53:53.560 | The number of time steps varies from document to document.
01:53:57.280 | So how would you do batch norm, right?
01:54:00.200 | How would you say what's the running average of means and variances?
01:54:07.440 | Because you can't put a different one between each of these unrolled layers, because, like,
01:54:12.360 | this is a for loop, remember?
01:54:13.820 | So we can't have different values every time.
01:54:17.240 | So it's not at all clear how you would insert batch norm into an RNN.
01:54:22.880 | So batch norm has these two deficiencies.
01:54:25.680 | How do we handle very small batch sizes all the way down to a batch size of one?
01:54:29.260 | How do we handle RNNs?
01:54:32.520 | So this paper called Layer Normalization suggests a solution to this.
01:54:40.920 | And the Layer Normalization paper is from Jimmy Ba, Jamie Kiros, and Geoffrey Hinton, who just
01:54:49.880 | won the Turing Award with Yoshua Bengio and Yann LeCun, which is kind of the Nobel
01:54:57.240 | Prize of computer science. They created this paper, which, like many papers, when you read
01:55:03.800 | it, it looks reasonably terrifying, particularly once you start looking at all this stuff.
01:55:11.920 | But actually, when we take this paper and we convert it to code, it's this.
01:55:18.440 | Now, which is not to say the paper's garbage, it's just that the paper has lots of explanation
01:55:24.080 | about what's going on and what do we find out and what does that mean.
01:55:27.840 | But the actual, what's layer norm?
01:55:29.760 | It's the same as batch norm, but rather than saying x.mean((0,2,3)), you say
01:55:35.680 | x.mean((1,2,3)), and you remove all the running averages.
01:55:41.620 | So this is layer norm with none of that running average stuff.
01:55:48.680 | And the reason we don't need the running averages anymore is because we're not taking the mean
01:55:54.040 | across all the items in the batch.
01:55:57.460 | Every image has its own mean, every image has its own standard deviation.
01:56:02.520 | So there's no concept of having to average across things in a batch.
01:56:07.720 | And so that's all layer norm is.
01:56:12.240 | We also average over the channels.
01:56:14.580 | So we average over the channels and the x and the y for each image individually.
01:56:20.080 | So we don't have to keep track of any running averages.
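A minimal NumPy sketch of layer norm, assuming (batch, channels, height, width) input. The only change from batch norm is the axes, and there are no running statistics, so a batch size of 1 is fine. The single scalar mult and add are an assumption made for illustration.

```python
import numpy as np

def layer_norm(x, mult=1.0, add=0.0, eps=1e-5):
    """Normalize each item over its channels, height and width (axes 1, 2, 3)."""
    m = x.mean(axis=(1, 2, 3), keepdims=True)   # one mean per item in the batch
    v = x.var(axis=(1, 2, 3), keepdims=True)    # one variance per item in the batch
    return (x - m) / np.sqrt(v + eps) * mult + add

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(1, 8, 5, 5))   # a batch of just one item
y = layer_norm(x)
```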
01:56:23.280 | The problem is that when we do that and we train, even at a lower learning rate of 0.8,
01:56:28.680 | it doesn't work.
01:56:30.760 | Layer norm's not as good.
01:56:33.560 | So it's a workaround we can use, but it's not as good: we don't have the running averages at inference
01:56:40.000 | time and, more importantly, we don't have a different normalization for each channel;
01:56:48.640 | we're just throwing them all together and pretending they're the same, and they're not.
01:56:52.800 | So layer norm helps, but it's nowhere near as good as batch norm.
01:56:58.000 | But for RNNs, what you have to use is something like this.
01:57:04.360 | So here's a thought experiment.
01:57:06.880 | What if you're using layer norm on the actual input data and you're trying to distinguish
01:57:13.640 | between foggy days and sunny days? Could you do it?
01:57:16.920 | So foggy days will have lower activations on average, because they're less bright, and they
01:57:25.080 | will have less contrast.
01:57:26.880 | In other words, they have lower variance.
01:57:29.320 | So layer norm would cause the variances to be normalized to be the same and the means
01:57:36.800 | to be normalized to be the same.
01:57:39.580 | So now the sunny day picture and the hazy day picture would have the same overall kind
01:57:46.560 | of activations and amount of contrast.
01:57:48.560 | And so the answer to this question is, no, you couldn't.
01:57:52.160 | With layer norm, you would literally not be able to tell the difference between pictures
01:57:55.520 | of sunny days and pictures of foggy days.
01:57:57.840 | Now, it's not only if you put the layer norm on the input data, which you wouldn't do,
01:58:03.400 | but everywhere in the middle layers, it's the same, right?
01:58:07.200 | Anywhere where the overall level of activation or the amount of difference of activation
01:58:11.580 | is something that is part of what you care about, it throws it away.
01:58:17.320 | It's designed to throw it away.
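A quick NumPy check of the thought experiment, using made-up pixel values. The "sunny" image is the "foggy" one made brighter and given twice the contrast. After each image is normalized over its channels, height, and width, the two come out the same, so nothing downstream could tell them apart.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # per-image statistics over channels, height and width
    m = x.mean(axis=(1, 2, 3), keepdims=True)
    v = x.var(axis=(1, 2, 3), keepdims=True)
    return (x - m) / np.sqrt(v + eps)

rng = np.random.default_rng(0)
foggy = rng.uniform(100, 140, size=(1, 3, 8, 8))   # dim and low contrast (hypothetical values)
sunny = 2 * foggy - 40                              # brighter (mean ~200) and twice the contrast

print(foggy.mean(), sunny.mean())                          # clearly different inputs
print(np.allclose(layer_norm(foggy), layer_norm(sunny)))   # True: the difference is gone
```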
01:58:19.480 | Furthermore, if at inference time you're using things from a different distribution,
01:58:24.820 | where that difference in distribution is important, it throws that away.
01:58:28.960 | So layer norm's a partial hacky workaround for some very genuine problems.
01:58:35.480 | There's also something called instance norm, and instance norm is basically the same thing
01:58:42.960 | as layer norm.
01:58:43.960 | It's a bit easier to read in the paper because they actually lay out all the indexes.
01:58:48.140 | So a particular output for a particular batch item, for a particular channel, for a particular
01:58:51.960 | x, for a particular y, is equal to the input for that batch item and channel at x and y, minus
01:58:57.240 | the mean for that batch item and channel, divided by the standard deviation for the same.
01:59:01.960 | So in other words, it's the same as layer norm, but now it's x.mean((2,3)) rather than
01:59:06.800 | x.mean((1,2,3)).
01:59:07.800 | So you can see how all these different papers, when you turn them into code, they're tiny
01:59:12.600 | variations, right?
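The same kind of NumPy sketch for instance norm, where the statistics are computed per item and per channel over height and width only. The check gives every item and channel its own made-up brightness (shift) and contrast (scale), and instance norm removes all of it. That is why it can't classify using those cues.

```python
import numpy as np

def instance_norm(x, eps=1e-5):
    """Normalize each (item, channel) pair over its height and width only (axes 2, 3)."""
    m = x.mean(axis=(2, 3), keepdims=True)   # shape (batch, channels, 1, 1)
    v = x.var(axis=(2, 3), keepdims=True)
    return (x - m) / np.sqrt(v + eps)

rng = np.random.default_rng(0)
x = rng.normal(scale=10.0, size=(4, 3, 6, 6))
shift = rng.uniform(-5, 5, size=(4, 3, 1, 1))    # hypothetical per-item, per-channel brightness
scale = rng.uniform(0.5, 3, size=(4, 3, 1, 1))   # hypothetical per-item, per-channel contrast

# Every brightness and contrast difference disappears after instance norm
same = np.allclose(instance_norm(x), instance_norm(x * scale + shift))
```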
01:59:16.400 | Instance norm, even at a learning rate of 0.1, doesn't learn anything at all.
01:59:21.280 | Why can't it classify anything?
01:59:24.280 | Because we're now taking the mean, removing the difference in means and the difference
01:59:29.960 | in activations for every channel and for every image, which means we've literally thrown
01:59:35.920 | away all the things that allow us to classify.
01:59:38.840 | Does that mean that instance norm is stupid?
01:59:41.040 | No, certainly not.
01:59:42.680 | It wasn't designed for classification.
01:59:45.800 | It was designed for style transfer, where the authors guessed that these differences
01:59:52.360 | in contrast and overall amount were not important, or something they should remove from trying
01:59:58.600 | to create things that looked like different types of pictures.
02:00:02.280 | It turned out to work really well.
02:00:03.640 | But you've got to be careful, right?
02:00:05.440 | You can't just go in and say, "Oh, here's another normalization thing, I'll try it."
02:00:09.200 | You've got to actually know what it's for to know whether it's going to work.
02:00:13.400 | So then finally, there's a paper called Group Norm, which has this wonderful picture, and
02:00:18.600 | it shows the differences.
02:00:20.080 | Batch Norm is averaging over the batch, and the height, and the width, and is different
02:00:31.040 | for each channel.
01:59:33.320 | Layer Norm is averaging over the channels, the height, and the width, and is different
02:00:38.600 | for each element of the batch.
02:00:40.960 | Instance Norm is averaging over height and width, and is different for each channel and
02:00:49.400 | each batch.
02:00:51.000 | And then Group Norm is the same as Instance Norm, but they arbitrarily group a few channels
02:00:56.560 | together and do that.
02:00:59.320 | So Group Norm is a more general way to do it.
02:01:05.280 | In the PyTorch docs, they point out that you can actually turn Group Norm into Instance
02:01:09.580 | Norm, or Group Norm into Layer Norm, depending on how you group things up.
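A NumPy sketch of group norm that checks both special cases. With one group it matches layer norm's axes (1, 2, 3), and with one group per channel it matches instance norm's axes (2, 3). This mirrors what the PyTorch docs describe for nn.GroupNorm(num_groups, num_channels), although this is not PyTorch's implementation, and the learnable per-channel scale and shift are left out.

```python
import numpy as np

def group_norm(x, num_groups, eps=1e-5):
    """Split channels into num_groups groups; normalize each (item, group) over
    the group's channels, height and width."""
    n, c, h, w = x.shape
    assert c % num_groups == 0, "channels must divide evenly into groups"
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    m = g.mean(axis=(2, 3, 4), keepdims=True)
    v = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - m) / np.sqrt(v + eps)).reshape(n, c, h, w)

def normalize_over(x, axes, eps=1e-5):
    """Generic normalization over the given axes, for comparison."""
    m = x.mean(axis=axes, keepdims=True)
    v = x.var(axis=axes, keepdims=True)
    return (x - m) / np.sqrt(v + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 6, 4, 4))
as_layer_norm = np.allclose(group_norm(x, 1), normalize_over(x, (1, 2, 3)))
as_instance_norm = np.allclose(group_norm(x, 6), normalize_over(x, (2, 3)))
```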
02:01:17.920 | So there's all kinds of attempts to work around the problem that we can't use small batch
02:01:26.080 | sizes, and we can't use RNNs with Batch Norm.
02:01:32.360 | But none of them are as good as Batch Norm.
02:01:35.920 | So what do we do?
02:01:37.360 | Well, I don't know how to fix the RNN problem, but I think I know how to fix the batch size
02:01:43.880 | problem.
02:01:46.320 | So let's start by taking a look at the batch size problem in practice.
02:01:49.240 | Let's create a new data bunch with a batch size of 2.
02:01:55.120 | And so here's our conv layer, as before, with our Batch Norm.
02:01:59.580 | And let's use a learning rate of 0.4 and fit that.
02:02:04.960 | And the first thing you'll notice is that it takes a long time.
02:02:09.720 | Small batch sizes take a long time, because it's just lots and lots of kernel launches
02:02:14.640 | on the GPU, it's just a lot of overhead.
02:02:17.400 | Something like this might even run faster on the CPU.
02:02:20.760 | And then you'll notice that it's only 26% accurate, which is awful.
02:02:25.920 | Why is it awful?
02:02:27.520 | Because of what I said, the small batch size is causing a huge problem.
02:02:33.660 | Because quite often, there's one channel in one layer where the variance is really small,
02:02:39.840 | because those two numbers just happen to be really close, and so it blows out the activations
02:02:43.800 | out to a billion, and everything falls apart.
02:02:49.400 | There is one thing we could try to do to fix this really easily, which is to use Epsilon.
02:02:57.120 | What's Epsilon?
02:02:58.520 | Let's go take a look at our code.
02:03:04.840 | Here's our Batch Norm.
02:03:06.480 | Look, we don't divide by the square root of variance.
02:03:12.240 | We divide by the square root of variance plus Epsilon, where Epsilon is 1e-5.
02:03:19.660 | Epsilon is a Greek letter that computer scientists and mathematicians use
02:03:25.600 | very frequently to mean some very small number.
02:03:29.560 | And in computer science, it's normally a small number that you add to avoid floating point
02:03:36.400 | rounding problems and stuff like that.
02:03:39.100 | So it's very common to see it on the bottom of a division to avoid dividing by such small
02:03:46.320 | numbers that you can't calculate things in floating point properly.
02:03:51.880 | But our view is that Epsilon is actually a fantastic hyperparameter that you should be
02:03:59.700 | using to train things better.
02:04:02.280 | And here's a great example.
02:04:03.720 | With Batch Norm, what if we didn't set Epsilon to 1e-5?
02:04:08.160 | But what if we set it to 0.1?
02:04:11.340 | If we set Epsilon to 0.1, then that basically would cause this to never make the overall
02:04:20.240 | activations be multiplied by anything more than 10.
02:04:23.760 | Sorry, that would be 0.01 because we're taking the square root.
02:04:27.800 | So if you set it to 0.01, let's say the variance was 0, it would be the square root of 0 plus 0.01, which is 0.1.
02:04:37.400 | So it ends up dividing by 0.1, which ends up multiplying by 10.
02:04:41.300 | So even in the worst case, it's not going to blow out.
02:04:45.280 | I mean, it's still not great because there actually are huge differences in variance
02:04:52.520 | between different channels and different layers, but at least this would cause it to not fall
02:04:56.240 | apart.
02:04:57.240 | So option number one would be use a much higher Epsilon value.
02:05:02.920 | And we'll keep coming back to this idea that Epsilon appears in lots of places in deep learning
02:05:06.640 | and we should use it as a hyperparameter we control and take advantage of.
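The arithmetic above in a few lines of plain Python. When a channel's variance is zero, batch norm divides by sqrt(eps), so the activations get multiplied by 1/sqrt(eps). That factor is about 316 with the default 1e-5, 10 with 0.01, and about 3.2 with 0.1.

```python
import math

def worst_case_scale(eps, var=0.0):
    """Factor that batch norm multiplies by when it divides by sqrt(var + eps)."""
    return 1.0 / math.sqrt(var + eps)

for eps in (1e-5, 1e-2, 1e-1):
    print(f"eps={eps:g}: a zero-variance channel is scaled by {worst_case_scale(eps):.1f}")
```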
02:05:15.040 | But we have a better idea.
02:05:17.900 | We think we have a better idea, which is we've built a new algorithm called running batch
02:05:22.720 | norm.
02:05:24.000 | And running batch norm, I think, is the first true solution to the small batch size batch
02:05:33.780 | norm problem.
02:05:36.120 | And like everything we do at fast.ai, it's ridiculously simple.
02:05:40.240 | And I don't know why no one's done it before.
02:05:42.300 | Maybe they have and I've missed it.
02:05:45.160 | And the ridiculously simple thing is this.
02:05:49.500 | In the forward function for running batch norm, don't divide by the batch standard deviation.
02:05:59.320 | Don't subtract the batch mean, but instead use the moving average statistics at training
02:06:05.840 | time as well.
02:06:08.920 | Not just at inference time.
02:06:12.920 | Why does this help?
02:06:14.840 | Because let's say you're using a batch size of two.
02:06:20.200 | Then from time to time, in this particular layer, in this particular channel, you happen
02:06:24.760 | to get two values that are really close together and they have a variance really close to zero.
02:06:30.000 | But that's fine because you're only taking 0.1 of that and 0.9 of whatever
02:06:35.240 | you had before.
02:06:36.240 | Like that's how running averages work.
02:06:38.720 | So if previously the variance was one, now it's not 1e-5, it's about 0.9.
02:06:47.400 | So in this way, as long as you don't get really unlucky and have the very first batch be dreadful,
02:06:56.440 | because you're using this moving average, you never have this problem.
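The numbers from that example in plain Python. An unlucky batch with variance 1e-5 would make the batch statistics scale activations by more than 200. The running variance only moves from 1.0 to about 0.9, so the scale stays close to 1. The values are the toy ones from the explanation.

```python
import math

mom = 0.1          # batch-norm-style momentum: the weight on the new batch's statistic
eps = 1e-5
running_var = 1.0  # what we had accumulated from earlier batches
batch_var = 1e-5   # an unlucky batch of two almost-identical values

# Normalizing with this batch's variance alone would blow the activations up...
scale_batch = 1 / math.sqrt(batch_var + eps)
# ...but the running average barely moves: 0.9 * 1.0 + 0.1 * 1e-5
running_var = running_var + mom * (batch_var - running_var)
scale_running = 1 / math.sqrt(running_var + eps)
```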
02:07:01.960 | So let's take a look.
02:07:02.960 | We'll look at the code in a moment, but let's do the same thing, 0.4.
02:07:07.720 | We're going to use our running batch norm.
02:07:10.600 | We train it for one epoch, and instead of 26% accuracy, it's 91% accuracy.
02:07:19.380 | So it totally nails it.
02:07:22.680 | In one epoch, with a batch size of just two and a pretty high learning rate.
02:07:28.920 | There's quite a few details we have to get right to make this work.
02:07:36.380 | But they're all details that we're going to see in lots of other places in this course.
02:07:41.080 | We're just kind of seeing them here for the first time.
02:07:44.720 | So I'm going to show you all of the details, but don't get overwhelmed.
02:07:47.520 | We'll keep coming back to them.
02:07:50.760 | The first detail is something very simple, which is in normal batch norm, we take the
02:07:57.600 | running average of variance, but you can't take the running average of variance.
02:08:03.140 | It doesn't make sense to take the running average of variance.
02:08:05.400 | It's a variance.
02:08:06.800 | You can't just average a bunch of variances, particularly because they might even be different
02:08:12.760 | batch sizes, because batch size isn't necessarily constant.
02:08:16.600 | Instead, as we learned earlier in the class, the way that we want to calculate variance
02:08:25.600 | is like this: the expected value of X squared minus the square of the expected value of X.
02:08:33.960 | So let's do that.
02:08:34.960 | As I mentioned, we can keep track of the sums and the sums of squares.
02:08:42.920 | So we register a buffer called sums and a buffer called squares, and we just
02:08:50.000 | compute x.sum over dimensions 0, 2 and 3, and (x*x).sum over those same dimensions, which gives the sum of squares.
02:09:01.040 | And then we'll take the lerp, that is, the exponentially weighted moving average, of the sums and the
02:09:07.080 | squares.
02:09:08.080 | And then for the variance, we do squares divided by count, minus the squared mean.
02:09:16.840 | So it's that formula.
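Here is a small pure-Python sketch of that formula. The function names are hypothetical, and nested lists indexed as x[n][c][h][w] stand in for a PyTorch tensor of shape batch × channels × height × width. It shows that per-channel sums and sums of squares, together with a count, are enough to recover each channel's variance.

```python
def channel_stats(x):
    """Per-channel sums, sums of squares and element count for x indexed as x[n][c][h][w].

    Reducing over batch, height and width corresponds to summing over dims (0, 2, 3).
    """
    bs, nc = len(x), len(x[0])
    height, width = len(x[0][0]), len(x[0][0][0])
    sums = [0.0] * nc
    sqrs = [0.0] * nc
    for sample in x:
        for c, plane in enumerate(sample):
            for row in plane:
                for v in row:
                    sums[c] += v
                    sqrs[c] += v * v
    count = bs * height * width  # total elements divided by the number of channels
    return sums, sqrs, count


def variance_from_stats(s, ss, count):
    """Variance as E[x^2] - (E[x])^2, using only the sum and the sum of squares."""
    mean = s / count
    return ss / count - mean * mean


# Two samples, two channels, a 1x2 grid: channel 0 holds 1, 2, 3, 4; channel 1 holds 0, 0, 2, 2.
x = [
    [[[1.0, 2.0]], [[0.0, 0.0]]],
    [[[3.0, 4.0]], [[2.0, 2.0]]],
]
sums, sqrs, count = channel_stats(x)
variances = [variance_from_stats(s, ss, count) for s, ss in zip(sums, sqrs)]
print(variances)  # [1.25, 1.0]
```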
02:09:23.260 | So that's detail number one that we have to be careful of.
02:09:29.400 | Detail number two is that the batch size could vary from mini-batch to mini-batch.
02:09:36.080 | So we should also register a buffer for count and take an exponentially weighted moving
02:09:42.400 | average of the counts, of the batch sizes.
02:09:47.180 | That tells us what we need to divide by each time.
02:09:53.000 | The amount we need to divide by each time is the total number of elements in the mini-batch
02:09:59.920 | divided by the number of channels.
02:10:01.960 | That's basically the grid width times the grid height times the batch size.
02:10:06.940 | So let's take an exponentially weighted moving average of the count and then that's what
02:10:12.840 | we will divide by for both our means and variances.
02:10:18.200 | That's detail number two.
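To see why sums, sums of squares and counts are tracked instead of averaging variances, here is a minimal sketch with two batches of different sizes. For simplicity it uses plain sums rather than an exponentially weighted moving average, and the helper names are hypothetical. Averaging the two batch variances gives 0.5, while combining the sums, squares and counts recovers the true variance of all six values.

```python
def variance(values):
    """Population variance computed directly from the values."""
    m = sum(values) / len(values)
    return sum((v - m) ** 2 for v in values) / len(values)


def pooled_variance(batches):
    """Combine batches by accumulating sums, sums of squares and counts."""
    s = sum(sum(b) for b in batches)
    ss = sum(sum(v * v for v in b) for b in batches)
    n = sum(len(b) for b in batches)
    mean = s / n
    return ss / n - mean * mean


small = [0.0, 2.0]                # a batch of 2
large = [10.0, 10.0, 10.0, 10.0]  # a batch of 4
naive = (variance(small) + variance(large)) / 2
pooled = pooled_variance([small, large])
# naive is 0.5; pooled matches the direct variance of all six values (about 18.33)
print(naive, pooled, variance(small + large))
```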
02:10:22.040 | Detail number three is that we need to do something called debiasing.
02:10:27.760 | So debiasing is this.
02:10:32.280 | We want to make sure that at every point (and we're going to look at this in more detail
02:10:38.400 | when we look at optimizers) no observation
02:10:44.740 | is weighted too highly.
02:10:47.160 | The problem is that with the normal way of doing moving averages, the very first point
02:10:54.020 | gets far too much weight, because it appears in the first moving average and the second
02:10:58.020 | and the third and the fourth.
02:11:00.380 | So there's a really simple way to fix this: you initialize both sums and
02:11:05.440 | squares to zeros and then you do a lerp in the usual way. Let's see what happens when
02:11:17.060 | we do this.
02:11:21.080 | So let's say our values are 10 and then 20.
02:11:32.180 | These are the first two values we get.
02:11:35.000 | Actually, we only need to look at the first value,
02:11:38.540 | so let's say the value is 10.
02:11:42.940 | So we initialize our mean to zero at the very start of training.
02:11:50.660 | And then the value that comes in is 10.
02:11:53.740 | So we would expect the moving average to be 10.
02:11:56.460 | But our lerp formula says it's equal to our previous value, which is 0, times 0.9, plus
02:12:05.820 | our new value times 0.1: 0 × 0.9 + 10 × 0.1 = 0 + 1 = 1, which is 10 times too small.
02:12:17.820 | So that's very easy to correct for because we know it's always going to be wrong by that
02:12:23.540 | amount.
02:12:24.620 | So we then divide it by 0.1 and that fixes it.
02:12:31.540 | And then the second value has exactly the same problem.
02:12:35.140 | It's got too much zero in it.
02:12:37.780 | But this time it's actually going to be divided by, well, let's not call it 0.1.
02:12:44.620 | Let's call it 1 minus 0.9.
02:12:48.240 | Because when you work through the math, you'll see the second one is going to be divided
02:12:51.620 | by 1 minus 0.9², and so forth.
02:12:59.980 | So this thing here where we divide by that, that's called debiasing.
02:13:03.620 | It's going to appear again when we look at optimization.
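The debiasing arithmetic above can be written out in general. The symbols are introduced here for illustration: m_t is the running average after t values, it starts at m_0 = 0, and β = 0.9 is the decay, so the lerp weight on each new value is 1 - β = 0.1.

```latex
\begin{aligned}
m_t &= \beta\, m_{t-1} + (1-\beta)\, x_t, \qquad m_0 = 0, \quad \beta = 0.9 \\
m_1 &= 0.1\, x_1 \\
m_2 &= 0.09\, x_1 + 0.1\, x_2, \qquad 0.09 + 0.1 = 0.19 = 1 - 0.9^2 \\
m_t &= (1-\beta) \sum_{i=1}^{t} \beta^{t-i} x_i, \qquad (1-\beta) \sum_{i=1}^{t} \beta^{t-i} = 1 - \beta^{t} \\
\hat{m}_t &= \frac{m_t}{1 - \beta^{t}}
\end{aligned}
```

With x_1 = 10, this gives m_1 = 1 and a debiased value of 1 / 0.1 = 10, as in the example above.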
02:13:08.700 | So you can see what we do is we have an exponentially weighted debiasing amount where we simply
02:13:16.380 | keep multiplying momentum times the previous debiasing amount.
02:13:22.560 | So initially it's just equal to momentum and then momentum squared and then momentum cubed
02:13:28.220 | and so forth.
02:13:32.300 | So then we do what I just said, we divide by the debiasing amount.
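Here is a minimal Python sketch of a debiased running average, assuming a lerp weight of 0.1 (so the decay is 0.9). The class name `DebiasedEMA` and its attributes are hypothetical. It starts the average at zero, keeps the power of the decay by repeatedly multiplying it in, and divides the raw average by one minus that power.

```python
class DebiasedEMA:
    """Exponentially weighted moving average started at zero, with debiasing.

    mom is the lerp weight for each new value, so the decay per step is 1 - mom.
    """

    def __init__(self, mom=0.1):
        self.mom = mom
        self.avg = 0.0        # raw running average, initialized to zero
        self.decay_pow = 1.0  # (1 - mom) ** steps, updated by repeated multiplication

    def update(self, value):
        self.avg = self.avg * (1 - self.mom) + value * self.mom
        self.decay_pow *= 1 - self.mom
        return self.avg / (1 - self.decay_pow)  # debiased estimate


ema = DebiasedEMA()
first = ema.update(10.0)   # raw average is 1.0; debiased is about 10
second = ema.update(20.0)  # (0.09 * 10 + 0.1 * 20) / 0.19
print(first, second)
```

With a constant input, the debiased estimate equals that constant from the first step onward, whereas the raw average only creeps toward it.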
02:13:42.540 | And then there's just one more thing we do. Remember how I said you might get
02:13:48.060 | really unlucky and have a first mini-batch whose variance is really close to zero? We don't
02:13:53.460 | want that to destroy everything.
02:13:55.420 | So I just say: if you haven't seen more than a total of 20 items yet, clamp the variance
02:14:01.780 | to be no smaller than 0.01, just to avoid blowing everything out of the water.
02:14:08.080 | And then the last two lines are the same.
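Putting these pieces together, here is a single-channel sketch in pure Python. It is not the notebook's code. The class name is hypothetical, a batch is a flat list of one channel's activations, eps = 1e-5 is an assumed default, and "items" is read as individual values. The sketch lerps the sums, squares and count, computes the variance as squares / count minus the squared mean, and clamps the variance to at least 0.01 until more than 20 items have been seen.

```python
import math


class RunningNormSketch:
    """Single-channel running batch norm sketch (hypothetical class, for illustration only).

    Each batch is a flat list of one channel's activations (batch size * height * width values).
    """

    def __init__(self, mom=0.1, eps=1e-5, min_items=20, min_var=0.01):
        self.mom, self.eps = mom, eps
        self.min_items, self.min_var = min_items, min_var
        self.sums = 0.0   # running average of per-batch sums
        self.sqrs = 0.0   # running average of per-batch sums of squares
        self.count = 0.0  # running average of per-batch element counts
        self.seen = 0     # total number of values seen so far

    def _lerp(self, old, new):
        # Keep (1 - mom) of the old statistic and add mom of the new one.
        return old * (1 - self.mom) + new * self.mom

    def update_stats(self, batch):
        self.sums = self._lerp(self.sums, sum(batch))
        self.sqrs = self._lerp(self.sqrs, sum(v * v for v in batch))
        self.count = self._lerp(self.count, len(batch))
        self.seen += len(batch)

    def __call__(self, batch):
        self.update_stats(batch)
        mean = self.sums / self.count
        var = self.sqrs / self.count - mean * mean  # E[x^2] - (E[x])^2
        if self.seen <= self.min_items:
            var = max(var, self.min_var)  # guard against an unlucky early batch
        std = math.sqrt(var + self.eps)
        return [(v - mean) / std for v in batch]


bn = RunningNormSketch()
print(bn([1.0, 1.002]))  # a variance of about 1e-6 is clamped to 0.01, so the outputs stay small
```

If sums, squares and count are all divided by the same debiasing factor, that factor cancels in sums / count and squares / count, so this sketch leaves debiasing out.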
02:14:11.700 | So that's it, right?
02:14:15.140 | It's all pretty straightforward arithmetic.
02:14:18.940 | It's a very straightforward idea, but when we put it all together, it's shockingly effective.
02:14:29.120 | And so then we can try an interesting thought experiment; here's another thing to try
02:14:33.060 | during the week.
02:14:35.460 | What's the best accuracy you can get in a single epoch?
02:14:41.300 | So, say, run.fit(1).
02:14:45.380 | With this convolutional layer with running batch norm, a batch size of 32, and a linear
02:14:56.620 | schedule from 1 to 0.2, I got 97.5%.
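A linear schedule from 1 to 0.2 can be sketched as linear interpolation over the course of training, assuming that the learning rate is the value being scheduled. The helper name `sched_lin` and the five-step example are assumptions for illustration. In practice, the position would advance with each batch of the single epoch.

```python
def sched_lin(start, end, pos):
    """Linear interpolation from start to end as the training position pos goes from 0 to 1."""
    return start + pos * (end - start)


total_steps = 5
lrs = [sched_lin(1.0, 0.2, step / (total_steps - 1)) for step in range(total_steps)]
print(lrs)  # approximately 1.0, 0.8, 0.6, 0.4, 0.2
```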
02:15:01.380 | I only tried a couple of things, so this is definitely something that I hope you can beat
02:15:06.240 | me at.
02:15:07.500 | But it's really good to create interesting little games to play.
02:15:14.060 | In research, we call them toy problems.
02:15:16.980 | Almost everything in research is basically toy problems.
02:15:19.260 | Come up with toy problems and try to find good solutions to them.
02:15:23.540 | So another toy problem for this week: what's the best accuracy you can get in one epoch using whatever kind of
02:15:31.300 | normalization you like and whatever kind of architecture you like, as long as it only uses concepts
02:15:36.980 | we've covered up to lesson 7?
02:15:44.340 | So yeah, that's basically it.
02:15:50.140 | So what's the future of running batch norm?
02:15:53.100 | I mean, it's kind of early days.
02:15:56.620 | We haven't published this research yet.
02:15:58.620 | We haven't done all the kind of ablation studies and stuff we need to do yet.
02:16:02.940 | At this stage, though, I'm really excited about this.
02:16:05.020 | Every time I've tried it on something, it's been working really well.
02:16:10.220 | The last time that we had something in a lesson that we said, this is unpublished research
02:16:15.100 | that we're excited about, it turned into ULMFiT, which is now a really widely used algorithm
02:16:22.380 | and was published at ACL.
02:16:28.700 | So fingers crossed that this turns out to be something really terrific as well.
02:16:32.420 | But either way, you've kind of got to see the process, because literally building these
02:16:38.220 | notebooks was the process I used to create this algorithm.
02:16:41.500 | So you've seen the exact process that I used to build up this idea and do some initial
02:16:49.180 | testing of it.
02:16:50.380 | So hopefully that's been fun for you, and see you next week.
02:16:53.660 | (audience applauds)