
Lesson 16: Deep Learning Foundations to Stable Diffusion


Chapters

0:00 The Learner
2:22 Basic Callback Learner
7:57 Exceptions
8:55 Train with Callbacks
12:15 Metrics class: accuracy and loss
14:54 Device Callback
17:28 Metrics Callback
24:41 Flexible Learner: @contextmanager
31:03 Flexible Learner: Train Callback
32:34 Flexible Learner: Progress Callback (fastprogress)
37:31 TrainingLearner subclass, adding momentum
43:37 Learning Rate Finder Callback
49:12 Learning Rate scheduler
53:56 Notebook 10
54:58 set_seed function
55:36 Fashion-MNIST Baseline
57:37 1 - Look inside the Model
62:50 PyTorch hooks
68:52 Hooks class / context managers
72:17 Dummy context manager, Dummy list
74:45 Colorful Dimension: histogram


00:00:00.000 | Hi there and welcome to lesson 16 where we are working on building our first
00:00:07.840 | flexible training framework the learner and I've got some very good news which
00:00:14.640 | is that I have thought of a way of doing it a little bit more gradually and
00:00:19.400 | simply actually than last time so that should that should make things a bit
00:00:25.680 | easier so we're going to take it a bit more step by step so we're working in
00:00:30.080 | the 09 learner notebook today and we've seen already this this basic callbacks
00:00:43.480 | learner and so the idea is that we've seen so far this learner which wasn't
00:00:56.040 | flexible at all but it had all the basic pieces which is we've got a fit method
00:01:02.160 | we're hard-coding that we can only calculate accuracy and average loss, we're
00:01:08.320 | hard-coding that we're putting things on a default device, hard-coding a single
00:01:14.760 | learning rate but the basic idea is here we go through each epoch and call one
00:01:20.280 | epoch to train or evaluate depending on this flag and then we loop through each
00:01:27.680 | batch in the data loader and one batch is going to grab the X and Y parts of
00:01:33.640 | the batch call the model call the loss function and if we're training do the
00:01:39.880 | backward pass and then print out we'll calculate the statistics for our
00:01:47.880 | accuracy and then at the end of an epoch print that out so it wasn't very
00:01:52.320 | flexible but it did do something so that's good so what we're going to do
00:01:58.760 | now is we're going to do is an intermediate step we're going to look at
00:02:01.400 | a but I'm calling a basic callbacks learner and it actually has nearly all
00:02:06.040 | the functionality of the full thing the way we're going to after we look at this
00:02:09.600 | basic callbacks learner we're then going to after creating some callbacks and
00:02:15.840 | metrics we're going to look at something called the flexible learner so let's go
00:02:20.000 | step by step so the basic callbacks learner looks very similar to the
00:02:26.680 | previous learner it it's got a fit function which is going to go through
00:02:36.880 | each epoch calling one epoch with training on and then training off and
00:02:43.520 | then one epoch will go through each batch and call one batch and one batch
00:02:52.480 | will call the model the loss function and if we're training it will do the
00:02:58.680 | backward step so that's all pretty similar but there's a few more things
00:03:03.160 | going on here for example if we have a look at fit you'll see that after
00:03:10.040 | creating the optimizer so we call self dot optfunk so optfunk here defaults to
00:03:18.840 | SGD so we instantiate an SGD object passing in our models parameters and
00:03:24.440 | the requested learning rate and then before we start looping through one
00:03:29.560 | epoch at a time now we've set epochs here we first of all call self dot
00:03:36.240 | callback and passing in before fit now what does that do self dot callback is
00:03:43.040 | here and it takes a method names in this case it's before fit and it calls a
00:03:49.440 | function called run callbacks it passes in a list of our callbacks and the
00:03:54.240 | method name in this case before fit so run callbacks is something that's going
00:04:01.120 | to go for each callback and it's going to sort them in order of their order
00:04:09.080 | attribute, and so there's a base class for our callbacks which has an order
00:04:13.600 | of zero, so our callbacks are all going to have the same order of zero unless
00:04:16.720 | you ask otherwise. So here's an example of a callback. So before we look
00:04:24.160 | at how callbacks work let's just let's just run a callback so we can create a
00:04:29.760 | ridiculously simple callback called completion callback which before we start
00:04:34.580 | fitting a new model it will set its count attribute to zero after each batch it
00:04:40.000 | will increment that and after completing the fitting process it will print out
00:04:44.960 | how many batches we've done so before we even train a model we could just run
00:04:49.200 | manually before fit after batch and after fit using this run cbs and you can
00:05:00.400 | see it's ended up saying completed one batches so what did that do so it went
00:05:06.880 | through each of the cbs in this list there's only one so it's going to look at
00:05:12.560 | the one cb and it's going to try to use getattr to find an attribute with this
00:05:20.320 | name which is before fit so if we try that manually so this is the kind of
00:05:26.360 | thing I want you to do if you find anything difficult to understand is do
00:05:29.520 | it all manually so create a callback set it to cbs zero just like you're doing in
00:05:34.880 | a loop right and then find out what happens if we call this and pass in this
00:05:47.240 | and you'll see it's returned a method and then what happens to that method it
00:05:55.680 | gets called so let's try calling it there yeah so that's what happened when
00:06:03.080 | we call the before fit which doesn't do anything very interesting but if we then
00:06:06.400 | call after batch, and then we call after fit, and there it is.
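For reference, here is a minimal, self-contained sketch of the callback machinery being described. The names (Callback, CompletionCB, run_cbs) follow the discussion above, but the details are an approximation of the notebook code rather than the exact source:

```python
class Callback:
    order = 0  # default sort order; callbacks run lowest-order first

class CompletionCB(Callback):
    def before_fit(self): self.count = 0       # reset the counter when fitting starts
    def after_batch(self): self.count += 1     # count every batch processed
    def after_fit(self): print(f'Completed {self.count} batches')

def run_cbs(cbs, method_nm):
    # call `method_nm` on every callback that defines it, sorted by `order`
    for cb in sorted(cbs, key=lambda c: c.order):
        method = getattr(cb, method_nm, None)
        if method is not None: method()

cbs = [CompletionCB()]
run_cbs(cbs, 'before_fit')
run_cbs(cbs, 'after_batch')
run_cbs(cbs, 'after_fit')   # prints "Completed 1 batches"
```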
00:06:19.480 | So, yeah, make sure you don't just run code, but really understand it by experimenting with
00:06:26.720 | it and I don't always experiment with it myself in these classes often I'm
00:06:30.400 | leaving that to you but sometimes I'm trying to give you a sense of how I
00:06:33.840 | would experiment with code if I was learning it so then having done that I
00:06:37.520 | would then go ahead and delete those cells but you can see I'm using this
00:06:40.520 | interactive notebook environment to to explore and learn and understand and so
00:06:48.160 | now we've got and if I haven't created a simple example of something to make it
00:06:53.260 | really easy to understand you should do that right don't just use what I've
00:06:56.440 | already created or what somebody else has already created so we've now got
00:07:00.160 | something that works totally independently we can see how it works
00:07:02.720 | this is what a callback does so a callback is something which will look at
00:07:07.320 | a class a callback is a class where you can define one or more of before after
00:07:14.720 | fit before after batch and before after epoch so it's going to go through and
00:07:23.040 | run all the callbacks that have a before fit method before we start fitting then
00:07:30.200 | it'll go through each epoch and call one epoch with training and one epoch with
00:07:34.780 | evaluation and then when that's all done it will call after fit callbacks and one
00:07:43.560 | epoch will before it starts on enumerating through the batches it will
00:07:51.120 | call before epoch and when it's done it will call after epoch the other thing
00:07:57.400 | you'll notice is that there's a try except immediately before every before
00:08:04.020 | method and immediately after every after method there's a try and there's an
00:08:09.320 | accept and each one has a different thing to look for cancel fit exception
00:08:13.320 | cancel epoch exception and cancel batch exception so here's the bit which goes
00:08:19.280 | through each batch calls before batch processes the batch calls after batch
00:08:24.480 | and if there's an exception that's of type cancel batch exception it gets
00:08:30.000 | ignored so what's that for so the reason we have this is that any of our
00:08:37.240 | callbacks could call could raise any one of these three exceptions to say I don't
00:08:48.200 | want to do this batch please. So maybe we'll look at an example of that in a
00:08:53.760 | moment so we can now train with this so let's call create a little get model
00:09:01.280 | function that creates a sequential model with just some linear layers and then
00:09:07.440 | we'll call fit and it's not telling us anything interesting because the only
00:09:11.900 | callback we added was the completion callback that's fine it's it's training
00:09:18.840 | it's doing something and we now have a trained model just didn't print out any
00:09:22.600 | metrics or anything because we don't have any callbacks for that that's the
00:09:25.720 | basic idea so we could create a maybe we could call it a single batch callback
00:09:38.760 | which, in after batch, i.e. after a single batch, raises a CancelFitException.
00:09:51.720 | so that's a pretty I mean that could be kind of useful actually if you want to
00:09:56.760 | just run one batch through a model to make sure it works. So we could try that, so
00:10:07.680 | now we're going to add to our list of callbacks the single batch callback
00:10:14.320 | let's try it and in fact you know we probably want this let's just have a
00:10:22.920 | think here oh that's fine let's run it there we go so it ran and nothing
00:10:31.720 | happened and the reason nothing happened is because this canceled before this ran
00:10:39.280 | so we could make this run second by setting its order to be higher and we
00:10:45.800 | could say just order equals 1 because the default order is 0 and we sort in
00:10:52.920 | order of the order attribute actually let's use cancel epoch exception there
00:11:04.680 | we go that way it'll run the final fit there we are so it did one batch for the
00:11:15.240 | it did one batch for the training and one batch for the evaluation so it's a
00:11:21.000 | total of two batches.
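A sketch of that single-batch callback, assuming the Callback base class above and the three cancel exceptions described earlier (which are just empty Exception subclasses):

```python
class CancelFitException(Exception): pass
class CancelBatchException(Exception): pass
class CancelEpochException(Exception): pass

class SingleBatchCB(Callback):
    order = 1  # run after the default-order callbacks, as discussed above
    def after_batch(self): raise CancelEpochException()  # stop after one batch per epoch
```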
00:11:28.080 | Remember, callbacks are not a special magic part of the Python language or anything; it's just a name we use to refer to these
00:11:33.000 | functions or classes, or callables more accurately, that we pass into
00:11:39.880 | something that will then call back to that callable at particular times and I
00:11:45.540 | think these are kind of interesting kinds of callbacks because these
00:11:48.120 | callbacks have multiple methods in them so is each method a callback is each
00:11:54.840 | class with all those methods of callback I don't know I tend to think of the
00:11:58.020 | class with all the methods in as a single callback I'm not sure if we have
00:12:02.400 | great nomenclature for this all right so let's actually try to get this doing
00:12:08.680 | something more interesting by not modifying the learner at all but just by
00:12:11.680 | adding callbacks because that's the great hope of callbacks right so it
00:12:17.700 | would be very nice if it told us the accuracy and the loss so to do that it
00:12:24.520 | would be great to have a class that can keep track of a metric so I've created
00:12:29.440 | here a metric class and maybe before we look at it we'll see how it works you
00:12:37.600 | could create for example an accuracy metric by defining the calculation
00:12:42.600 | necessary to calculate the accuracy metric, which is the mean of how often
00:12:49.200 | the inputs equal the targets, and the idea is you could then create an accuracy
00:12:53.960 | metric object you could add a batch of inputs and targets and add another batch
00:13:00.000 | of inputs and targets and get the value and there you would get the point four
00:13:06.960 | five accuracy or another way you could do it would be just to create a metric
00:13:13.080 | which simply gets the weighted average, for example, of your loss. So you
00:13:16.640 | could add point six as the loss with a batch size of 32 point nine as a loss in
00:13:22.280 | a batch size of two and then that's going to give us a weighted average loss
00:13:28.000 | of point six two which is equal to this weighted average calculation so that's
00:13:34.240 | like one way we could kind of make it easy to calculate metrics so here's the
00:13:38.880 | class basically we're going to keep track of all of the actual values that
00:13:44.940 | we're averaging and the number in each mini batch and so when you add a mini
00:13:49.960 | batch we call calculate which for example for accuracy remember this is
00:13:56.800 | going to override the parent classes calculate so it does the calculation
00:14:01.520 | here and then we'll add that to our list of values we will add to our list of
00:14:08.680 | batch sizes the current batch size and then when you calculate the value we
00:14:16.480 | will calculate the weighted sum sorry the weighted mean weighted average now
00:14:24.320 | notice that here value I didn't have to put parentheses after it and that's
00:14:28.640 | because it's a property I think we've seen this before so just remind you
00:14:32.040 | property just means you don't have to put parentheses after it to
00:14:36.760 | get the calculation to happen. All right, so just let me know if anybody's got any questions up to here.
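Here is a rough sketch of such a Metric class and the two usages just described; treat it as an approximation of the notebook code rather than the exact source:

```python
import torch

class Metric:
    def __init__(self): self.reset()
    def reset(self): self.vals, self.ns = [], []
    def add(self, inp, targ=None, n=1):
        # store this batch's value and its size so `value` can compute a weighted mean
        self.vals.append(self.calc(inp, targ))
        self.ns.append(n)
    @property
    def value(self):
        ns = torch.tensor(self.ns, dtype=torch.float)
        vals = torch.tensor([float(v) for v in self.vals])
        return (vals * ns).sum() / ns.sum()
    def calc(self, inp, targ): return inp   # base class: just average whatever you add

class Accuracy(Metric):
    def calc(self, inp, targ): return (inp == targ).float().mean()

acc = Accuracy()
acc.add(torch.tensor([0, 1, 2, 0, 1]), torch.tensor([0, 1, 1, 2, 1]))
acc.add(torch.tensor([1, 1, 2, 0, 1]), torch.tensor([0, 1, 1, 2, 1]))
print(acc.value)      # mean accuracy over both batches

loss = Metric()
loss.add(0.6, n=32)
loss.add(0.9, n=2)
print(loss.value)     # ~0.62, the weighted average loss described above
```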
00:14:42.440 | Of course. So we now need some way to use this metric in a
00:14:52.300 | callback to actually print out the metrics. The first thing I'm going to do is
00:14:56.160 | create one more useful callback first, a very simple one, just two lines of code,
00:15:01.080 | called the device callback, and that is something which is going to allow us to
00:15:05.600 | use CUDA or for the Apple GPU or whatever without the complications we
00:15:14.880 | had before of you know how do we have multiple processes and our data loader
00:15:19.760 | and also use our device and not have everything fall over so the way we could
00:15:23.960 | do it is we could say before fit put the model onto the default device and before
00:15:34.600 | each batch is run put that batch onto the device because look what happened in
00:15:43.000 | in the this is really really important in the learner absolutely everything is
00:15:48.680 | put inside self dot which means it's all modifiable so we go for self dot
00:15:54.280 | iteration number comma self dot the batch itself enumerating the data loader
00:15:59.400 | and then we call one batch but before it we call the callback so we can modify
00:16:05.800 | this now how does the callback get access to the learner well what actually
00:16:11.520 | happens is we go through each of our call backs and put set an attribute
00:16:16.920 | called learn equal to the learner and so that means in the callback itself we can
00:16:23.720 | say self dot learn dot model and actually we could make this a bit better
00:16:29.000 | I think so make it like maybe you don't want to use a default device so this is
00:16:32.800 | where I would be inclined to add a constructor and set device and we could
00:16:38.840 | default it to the default device of course and then we could use that
00:16:47.560 | instead and that would give us a bit more flexibility so if you wanted to
00:16:51.480 | train on some different device then you could; I think that might be a slight improvement.
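A sketch of that device callback, with the constructor just discussed; def_device and to_device are small helper assumptions written inline here so the example is self-contained:

```python
import torch

# assumed helpers: pick a default device and recursively move tensors onto it
def_device = ('mps' if torch.backends.mps.is_available()
              else 'cuda' if torch.cuda.is_available() else 'cpu')

def to_device(x, device=def_device):
    if isinstance(x, dict): return {k: to_device(v, device) for k, v in x.items()}
    if isinstance(x, (list, tuple)): return type(x)(to_device(o, device) for o in x)
    return x.to(device) if hasattr(x, 'to') else x

class DeviceCB(Callback):
    def __init__(self, device=def_device): self.device = device  # allow overriding the device
    def before_fit(self): self.learn.model.to(self.device)       # move the model once, up front
    def before_batch(self): self.learn.batch = to_device(self.learn.batch, self.device)
```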
00:16:59.560 | Okay, so there's a callback we can use to put things on the device, and we
00:17:04.640 | could check that it works by just quickly going back to our old learner
00:17:09.120 | here remove the single batch CB and replace it with device CB yep still
00:17:23.560 | works so that's a good sign okay so now let's do our metrics now of course we
00:17:33.680 | couldn't use metrics until we built them by hand the good news is we don't have to
00:17:38.920 | write every single metric now by hand because they already exist in a fairly
00:17:43.480 | new project called torch eval which is an official PyTorch project and so torch
00:17:49.360 | eval is something that gives us actually I came across it after I had created my
00:17:56.200 | own metric class but it actually looks pretty similar to the one that I built
00:18:01.280 | earlier. So you can install it with pip; I'm not sure if it's on conda yet, but
00:18:08.280 | it probably will be soon by the time you see the video. I think it's pure
00:18:12.080 | Python anyway so it doesn't matter how you install it and yeah it has a pretty
00:18:17.680 | similar pretty similar approach where you call dot update and you call dot
00:18:27.720 | compute so they're slightly different names but they're basically super
00:18:30.320 | similar to the thing that we just built but there's a nice good list of metrics
00:18:36.280 | to pick from so because we've already built our own now that means we're
00:18:41.280 | allowed to use theirs so we can import the multi-class accuracy metric and the
00:18:47.220 | Mean metric, and just to show you, they look very, very similar: if we call
00:18:53.200 | multi-class accuracy and we can pass in a mini batch of inputs and targets and
00:18:58.880 | compute and that all works nicely now these in fact it's exactly the same as
00:19:05.000 | what I wrote we both added this thing called reset which basically well resets
00:19:10.840 | it and it's obviously we're going to be wanting to do that probably at the start
00:19:15.880 | of each epoch, and so if you reset it and then try to compute you'll get NaN,
00:19:22.800 | because accuracy is meaningless when you don't have any data yet.
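A small usage sketch of the torcheval API just described (the numbers are purely illustrative):

```python
import torch
from torcheval.metrics import MulticlassAccuracy

metric = MulticlassAccuracy()
preds = torch.tensor([0, 2, 1, 1])   # predicted classes for a mini-batch
targs = torch.tensor([0, 1, 1, 1])   # targets
metric.update(preds, targs)
print(metric.compute())              # tensor(0.7500)

metric.reset()                       # e.g. at the start of each epoch
print(metric.compute())              # nan: no data has been added yet
```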
00:19:26.560 | Okay, so let's create a metrics callback so we can print out our
00:19:33.000 | metrics. I've got some ideas to improve this, which maybe I'll do this week, but
00:19:37.980 | here's a basic working version slightly hacky but it's not too bad so generally
00:19:45.560 | speaking one thing I noticed actually is I don't know if this is considered a
00:19:49.540 | bug, but a lot of the metrics didn't seem to work correctly in torcheval when I
00:19:54.600 | had tensors that were on the GPU and had requires grad so I created a little to
00:20:07.120 | CPU function, which I think is very useful, and that's just going to detach
00:20:12.720 | the tensor. So detach takes the tensor and removes all the gradient history, the
00:20:17.800 | computation history used to calculate a gradient, and puts it on the CPU. It'll
00:20:22.760 | do the same for dictionaries of tensors, lists of tensors and tuples of tensors.
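A sketch of such a to_cpu helper; the notebook's version may differ slightly (for example it may also convert half-precision tensors to float):

```python
from collections.abc import Mapping
import torch

def to_cpu(x):
    # recursively detach tensors from the autograd graph and move them to the CPU,
    # handling dicts, lists and tuples of tensors as well as bare tensors
    if isinstance(x, Mapping): return {k: to_cpu(v) for k, v in x.items()}
    if isinstance(x, list): return [to_cpu(o) for o in x]
    if isinstance(x, tuple): return tuple(to_cpu(o) for o in x)
    return x.detach().cpu() if isinstance(x, torch.Tensor) else x
```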
00:20:28.600 | so our metrics callback basically here's how we're going to use it so let's run
00:20:34.480 | it so here we're creating a metrics callback object and saying we want to
00:20:40.800 | create a metric called accuracy that's what's going to print out and this is
00:20:46.560 | the metrics object we're going to use to calculate accuracy and so then we just
00:20:51.400 | pass that in as one of our callbacks and so you can see what it's going to do is
00:20:57.880 | it's going to print out the epoch number whether it's training or evaluating so
00:21:03.560 | training set or validation set and it'll print out our metrics and our current
00:21:10.780 | status actually we can simplify that we don't need to print those bits because
00:21:18.000 | it's all in the dictionary now let's do that there we go um so let's take a look
00:21:30.640 | at how this works so we are going to be creating with for the callback we're
00:21:37.760 | going to be passing in the names and object metric objects for the metrics to
00:21:44.360 | track and print. So here it is: **metrics. We've seen star star
00:21:51.400 | before and as a little shortcut I decided that it might be nice if you
00:21:57.720 | didn't want to write accuracy equals you could just remove that and run it and if
00:22:04.440 | you do that then it will give it a name and I'll just use the same name as the
00:22:08.000 | class. And so that's why you can either pass in *ms, which will be a tuple;
00:22:15.280 | I mean, it's just passing a list of positional
00:22:18.800 | arguments which we turn into a tuple or you can pass in named arguments that'll
00:22:23.560 | be turned into a dictionary if you pass in positional arguments then I'm going
00:22:27.560 | to turn them into named arguments in the dictionary by just grabbing the name
00:22:32.200 | from their type so that's where this comes from that's all that's going on
00:22:36.800 | here just a little shortcut bit of convenience so we store that away and
00:22:42.600 | this is yeah this is a bit I think I can simplify a little bit but I'm just
00:22:46.440 | adding manually an additional metric which is I'm going to call the loss and
00:22:51.760 | that's just going to be the weighted the weighted average of the losses so
00:22:57.480 | before we start fitting we we're going to actually tell the learner that we are
00:23:04.480 | the metrics callback and so you'll see later where we're going to actually use
00:23:08.320 | this before each epoch we will reset all of our metrics after each epoch we will
00:23:17.480 | create a dictionary of the keys and values which are the actual strings that
00:23:22.560 | we want to print out and we will call log which for now we'll just print them
00:23:28.840 | and then after each batch this is the key thing we're going to actually grab
00:23:34.880 | the input and target we're going to put them on the CPU and then we're going to
00:23:43.840 | go through each of our metrics and call that update so remember the update in
00:23:51.520 | the metric is the thing that actually says here's a batch of data right so
00:23:58.320 | we're passing in the batch of data which is the predictions and the targets and
00:24:04.200 | then we'll do the same thing for our special loss metric passing in the
00:24:09.320 | actual loss and the size of the batch. And so that's how we're able to get this
00:24:17.720 | actually running on the Nvidia GPU and showing our metrics, and obviously
00:24:26.720 | there's a lot of room to improve how this is displayed, but all the
00:24:30.080 | information we need is here, and it's just a case of changing that function.
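Putting that together, here is a rough sketch of the metrics callback as described, assuming the Callback base class and to_cpu helper from above and torcheval's Mean metric; the learner attribute names (batch, preds, loss, epoch) follow the discussion but may not match the notebook exactly:

```python
from torcheval.metrics import Mean

class MetricsCB(Callback):
    def __init__(self, *ms, **metrics):
        # positional metrics get named after their class; keyword ones keep their given name
        for m in ms: metrics[type(m).__name__] = m
        self.metrics = metrics
        self.loss = Mean()                      # weighted running average of the loss

    def _log(self, d): print(d)                 # kept separate so another callback can replace it

    def before_fit(self): self.learn.metrics = self   # let e.g. a progress callback find us
    def before_epoch(self):
        for m in [*self.metrics.values(), self.loss]: m.reset()
    def after_epoch(self):
        log = {k: f'{m.compute():.3f}' for k, m in self.metrics.items()}
        log['loss'] = f'{self.loss.compute():.3f}'
        log['epoch'] = self.learn.epoch
        log['train'] = 'train' if self.learn.model.training else 'eval'
        self._log(log)
    def after_batch(self):
        x, y = to_cpu(self.learn.batch)         # detach and move off the GPU before updating
        for m in self.metrics.values(): m.update(to_cpu(self.learn.preds), y)
        self.loss.update(to_cpu(self.learn.loss), weight=len(x))
```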
00:24:35.560 | Okay, so that's our kind of intermediate-complexity learner. We can
00:24:42.040 | make it more sophisticated but it's still exactly it's still going to fit in
00:24:48.800 | a single screen of codes this is kind of my goal here was to keep everything in a
00:24:51.920 | single screen of code this first bit is exactly the same as before but you'll see
00:24:59.360 | that the one epoch and fit and batch has gone from let's see but what it was
00:25:08.000 | before it's gone from quite a lot of code all this to much less code and the
00:25:19.680 | trick to doing that is I decided to use a context manager we're going to learn
00:25:23.240 | more about context managers in the next notebook but basically I originally last
00:25:27.600 | week I was saying I was going to do this as a decorator but I realized a context
00:25:30.080 | manager is better basically what we're going to do is we're going to call our
00:25:33.680 | before and after callbacks in a try accept block and to say that we want to
00:25:43.960 | use the callbacks in the try accept block we're going to use a with statement so
00:25:49.760 | in Python a with statement says everything in that block call our
00:25:56.680 | context manager before and after it now there's a few ways to do that but one
00:26:02.200 | really easy one is using this contextmanager decorator, and everything up
00:26:08.000 | to the yield statement is called before your code; where it says
00:26:13.440 | yield it then calls your code and then everything after the yield is called
00:26:18.760 | after your code so in this case it's going to be try self dot callback before
00:26:24.520 | name where name is fit and then it will call for self dot epoch etc because
00:26:34.280 | that's where the yield is and then it'll call self dot callback after fit
00:26:39.920 | except. Okay, and now we need to catch the CancelFitException, so all of the
00:26:48.320 | variables that you have in Python all live inside a special dictionary called
00:26:52.880 | globals so this dictionary contains all of your variables so I can just look up
00:26:57.320 | in that dictionary the variable called cancel fit with a capital F exception
00:27:04.480 | so this is except cancel fit exception so this is exactly the same then as this
00:27:12.160 | code except the nice thing is now I only have to write it once rather than at
00:27:18.080 | least three times and I'm probably going to want more of them so you know I tend
00:27:21.880 | to think it's worth refactoring code when you
00:27:29.320 | have duplicate code particularly here we had the same code three times so that's
00:27:34.440 | going to be more of a maintenance headache we're probably going to want to
00:27:36.960 | add callbacks to more things later so by putting it into a context manager just
00:27:43.200 | once I think we're going to reduce our maintenance burden I know we do because
00:27:47.800 | I've had a similar thing in fast AI for some years now and it's been quite
00:27:51.760 | convenient. So that's what this context manager is about.
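Here is a tiny standalone demonstration of the pattern: a context manager built with contextlib.contextmanager that calls the before/after callbacks around the body of a with block, and looks up the matching cancel exception in globals(). The class and method names here are illustrative, not the notebook's exact code:

```python
from contextlib import contextmanager

class CancelFitException(Exception): pass

class DemoLearner:
    def callback(self, method_nm): print(f'calling {method_nm}')

    @contextmanager
    def cb_ctx(self, nm):
        # everything before `yield` runs when the `with` block is entered,
        # everything after it runs on a normal exit; the except clause catches
        # the matching cancel exception, looked up by name in globals()
        try:
            self.callback(f'before_{nm}')
            yield
            self.callback(f'after_{nm}')
        except globals()[f'Cancel{nm.title()}Exception']: pass

    def fit(self):
        with self.cb_ctx('fit'):
            print('the training loop body would run here')

DemoLearner().fit()
```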
00:28:02.080 | Other than that, the code's exactly the same: we create our optimizer and then, with our callback
00:28:06.960 | context manager for fit go through each epoch call one epoch set it to training
00:28:14.320 | or non training mode based on the argument we pass in grab the training or
00:28:19.600 | validation set based on the argument we pass in, and then using the context
00:28:24.120 | manager for epoch go through each batch in the data loader and then for each
00:28:30.360 | batch in the data loader using the batch context now this is where something gets
00:28:37.500 | quite interesting: we call predict, get_loss and, if we're training, backward, step
00:28:43.000 | and zero_grad. But previously we actually called self dot model etc., self dot loss
00:28:50.880 | function etc. So we go through each batch and call before batch, do the batch...
00:29:05.680 | oh wait, that's our slow version, what are we doing... oh yes,
00:29:11.720 | we're meant to be over here, okay, here we are. Yes, so previously we were
00:29:25.480 | calling the model, calling the loss function, calling loss
00:29:29.880 | dot backward, opt dot step, opt dot zero grad, but now we are instead calling
00:29:40.280 | self dot predict, self dot get_loss, self dot backward, and how on earth is that
00:29:45.280 | working because they are not defined here at all and so the reason I've
00:29:49.800 | decided to do this is it gives us a lot of flexibility we can now actually create
00:29:55.240 | our own way of doing predict get lost backward step and zero grad in different
00:30:01.280 | situations and we're going to see some of those situations so what happens if
00:30:09.400 | we call self dot predict and it doesn't exist well it doesn't necessarily cause
00:30:14.240 | an error what actually happens is it calls a special magic method in Python
00:30:19.840 | called dunder getattr, as we've seen before, and what I'm doing here is I'm
00:30:24.840 | saying okay well if it's one of these special five things don't raise an
00:30:29.680 | attribute error, which is the default thing it does, but instead call
00:30:36.800 | a callback, or actually I should say call self dot callback, passing in that
00:30:43.840 | name. So it's actually going to call self dot callback, quote, predict, and self dot
00:30:51.480 | callback is exactly the same as before. And so what that means now is, to make
00:30:55.200 | this work exactly the same as it did before, I need a callback which does
00:30:59.260 | these five things, and here it is; I'm going to call it the train callback. So here
00:31:05.640 | are the five things: predict, get_loss, backward, step and zero_grad, so there
00:31:10.760 | they are: predict, get_loss, backward, step and zero_grad. Okay, so they're almost
00:31:23.760 | exactly the same as what they looked like in our intermediate learner except
00:31:27.560 | now I just need to have self dot learn in front of everything because we
00:31:31.320 | remember this is a callback it's not the learner and so for a callback the
00:31:35.000 | callback can access the learner using self dot learn so self dot learn dot
00:31:38.040 | preds is self dot learn dot model passing in self dot learn dot batch and
00:31:42.040 | just the independent variables; ditto for the loss, which calls the loss function;
00:31:48.180 | and backward, step and zero_grad call the corresponding loss and optimizer methods.
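A sketch of those two pieces, the __getattr__ delegation on the learner and the train callback; only the relevant fragment of the learner is shown, and the attribute names are an approximation of the notebook code:

```python
from functools import partial

class Learner:
    # ... fit / one_epoch / one_batch and the callback method as described above ...
    def __getattr__(self, name):
        # __getattr__ is only called when normal attribute lookup fails: delegate
        # these five names to the callbacks, raise AttributeError for anything else
        if name in ('predict', 'get_loss', 'backward', 'step', 'zero_grad'):
            return partial(self.callback, name)
        raise AttributeError(name)

class TrainCB(Callback):
    def predict(self):   self.learn.preds = self.learn.model(self.learn.batch[0])
    def get_loss(self):  self.learn.loss = self.learn.loss_func(self.learn.preds, self.learn.batch[1])
    def backward(self):  self.learn.loss.backward()
    def step(self):      self.learn.opt.step()
    def zero_grad(self): self.learn.opt.zero_grad()
```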
00:31:58.160 | So at this point this isn't doing anything it wasn't doing before, but the nice thing is now, if you want to use Hugging Face
00:32:03.400 | Accelerate, or you want something that works on Hugging Face datasets-style
00:32:07.320 | dictionary things or whatever you can actually change exactly how it behaves
00:32:14.400 | by just call passing by creating a callback for training and if you want
00:32:20.260 | everything except one thing to be the same you can inherit from train CB so
00:32:23.760 | this is I've I've not tried this before I haven't seen this done anywhere else
00:32:28.280 | so it's a bit of an experiment so I would sit here how you go with it and then
00:32:34.520 | finally I thought it'd be nice to have a progress bar so let's create a progress
00:32:38.720 | callback and the progress bar is going to show on it our current loss and going to
00:32:44.600 | put create a plot of it so I'm going to use a project that we created called
00:32:52.160 | fast progress mainly created by the wonderful Sylvain and basically fast
00:33:02.400 | progress is yeah very nice way to create a very flexible progress bars and so let
00:33:10.760 | me show you what it looks like first so let's get the model and train and as you
00:33:15.520 | can see it actually in real time updates the graph and everything there you go
00:33:21.240 | that's pretty cool so that's the that's the progress bar the metrics callback the
00:33:27.720 | device callback and the training callback all in action so before we fit we
00:33:33.720 | actually have to set self dot learn dot epochs now that might look a little bit
00:33:39.960 | weird but self dot learn dot epochs is the thing that we loop through for self
00:33:46.800 | dot epoch in so we can change that so it's not just a normal range but instead
00:33:52.280 | it is a progress bar around a range we can then check remember I told you that
00:34:01.080 | the learner is going to have the metrics attribute applied we can then say oh if
00:34:04.400 | the learner has a metrics attribute then let's replace the underscore log method
00:34:11.000 | there with ours and our one instead will write to the progress bar now this is
00:34:17.460 | pretty simple it looks very similar to before but we could easily replace this
00:34:21.000 | for example with something that creates an HTML table which is another thing
00:34:24.840 | fast progress does or other stuff like that so you can see we can modify the
00:34:29.520 | nice thing is we can modify how our metrics are displayed so that's a very
00:34:35.080 | powerful thing that Python lets us do is actually replace one piece of code with
00:34:38.960 | another and that's the whole purpose of why the metrics callback had this
00:34:46.760 | underscore log separately so why didn't I just say print here that's because this
00:34:51.800 | way classes can replace how the metrics are displayed so we could change that to
00:34:58.340 | like send them over to weights and biases for example or you know create
00:35:03.800 | visualizations or so forth so before epoch we do a very similar thing the
00:35:12.480 | self dot learn dot DL iterator we change it to have a progress bar wrapped around
00:35:17.400 | it, and then after each batch we set the progress bar's comment to be the
00:35:26.240 | loss; it's just going to show the loss on the progress bar as it
00:35:29.320 | goes and if we've asked for a plot then we will append the losses to a list of
00:35:36.040 | losses and we will update the graph with the losses and the batch numbers.
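A rough sketch of what such a progress callback might look like with fastprogress; the learner attribute names (epochs, dl, loss, metrics) follow the flexible learner described above, and the details may differ from the notebook:

```python
from fastprogress import master_bar, progress_bar

class ProgressCB(Callback):
    order = 1                                 # run after the metrics callback (default order 0)
    def __init__(self, plot=False): self.plot = plot
    def _log(self, d): print(d)               # could instead render an HTML table on the master bar
    def before_fit(self):
        self.losses = []
        self.mbar = master_bar(self.learn.epochs)
        self.learn.epochs = self.mbar         # the fit loop now iterates over the master bar
        if hasattr(self.learn, 'metrics'): self.learn.metrics._log = self._log
    def before_epoch(self):
        # wrap the current dataloader in a progress bar attached to the master bar
        self.learn.dl = progress_bar(self.learn.dl, leave=False, parent=self.mbar)
    def after_batch(self):
        self.learn.dl.comment = f'{self.learn.loss:.3f}'
        if self.plot and self.learn.model.training:
            self.losses.append(self.learn.loss.item())
            self.mbar.update_graph([[range(len(self.losses)), self.losses]])
```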
00:35:51.200 | So there we have it: a nice working learner, which is I think the most
00:36:00.640 | flexible training loop that, I hope, has ever been
00:36:05.080 | written because I think the fast AI 2 one was the most flexible that had ever
00:36:08.880 | been written before and this is more flexible and the nice thing is you can
00:36:14.400 | make this your own you know you can you know fully understand this training loop
00:36:20.760 | so it's kind of like you can use a framework but it's a framework in which
00:36:25.720 | you're totally in control of it and you can make it work exactly how you want to
00:36:29.280 | ideally not by changing the change in the learner itself ideally by creating
00:36:33.880 | callbacks but if you want to you could certainly like look at that the whole
00:36:38.000 | learner fits on a single screen so you could certainly change that we haven't
00:36:44.320 | added inference yet although that shouldn't be too much to add I guess we
00:36:47.880 | have to do that at some point okay now interestingly I love this about Python
00:36:55.440 | it's so flexible when when we said self dot predict self dot get lost I said if
00:37:03.500 | they don't exist then it's going to use getattr and it's going to try to find
00:37:08.440 | those in the callbacks, and in fact you could have multiple callbacks that define
00:37:14.420 | these things and then they would chain them together which would be kind of
00:37:17.040 | interesting but there's another way we could make these exist which is which is
00:37:24.520 | that we could subclass this so let's not use train CB just to just to show us how
00:37:30.840 | this would work and instead we're going to use a subclass so here in a subclass
00:37:38.040 | learner and I'm going to override the five well it's not exactly overriding I
00:37:43.680 | didn't have any definition of them before so I'm going to define the five
00:37:46.200 | directly in the learner subclass so that way it's never going to end up going to
00:37:50.640 | getattr, because getattr is only called if something doesn't exist. So
00:38:02.160 | here it's basically all these five are exactly the same as in our train
00:38:07.440 | callback except we don't need self dot learn anymore we can just use self
00:38:10.560 | because we're now in the learner but I've changed zero grad to do something a
00:38:15.120 | bit crazy I'm not sure if this has been done before I haven't seen it but maybe
00:38:20.120 | it's an old trick that I just haven't come across but it occurred to me zero
00:38:24.680 | grad which remember is the thing that we call after we take the optimizer step
00:38:30.240 | doesn't actually have to zero the gradients at all what if instead of
00:38:35.240 | zeroing the gradients we multiplied them by some number like say 0.85 well what
00:38:47.360 | would that do well what it would do is it would mean that your previous
00:38:53.080 | gradients would still be there but they would be reduced a bit and remember what
00:39:00.680 | happens in PyTorch is PyTorch always adds the gradients to the existing
00:39:07.100 | gradients and that's why we normally have to call zero grad but if instead we
00:39:12.160 | multiply the gradients by some number I mean we should really make this a
00:39:15.600 | parameter let's do that shall we so let's create a parameter so probably the
00:39:25.840 | few ways we could do this well let's do it properly we've got a little bit of
00:39:34.960 | time so we could say well maybe just copy and paste all those over here and
00:39:51.000 | we'll add momentum momentum equals 0.85 self momentum equals momentum and then
00:40:08.200 | super so make sure you call the super classes passing in all the stuff we
00:40:18.480 | could use delegates for this and quags that would be possibly another great way
00:40:22.400 | of doing it but let's just do this for now okay and then so there we wouldn't
00:40:27.480 | make it 0.85, we would make it self dot momentum. So you'll see it now still trains,
00:40:35.560 | but there's no TrainCB callback anymore in my list; I don't need one because I've
00:40:41.880 | defined the five methods in the subclass.
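A sketch of that subclass; the Learner constructor signature used here is an assumption, but the zero_grad trick (scaling the existing gradients down instead of zeroing them) is the point:

```python
import torch
from torch import optim

class MomentumLearner(Learner):
    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD, mom=0.85):
        self.mom = mom
        super().__init__(model, dls, loss_func, lr, cbs, opt_func)

    # the five "missing" methods defined directly on the subclass, so __getattr__
    # never fires and no TrainCB is needed
    def predict(self):  self.preds = self.model(self.batch[0])
    def get_loss(self): self.loss = self.loss_func(self.preds, self.batch[1])
    def backward(self): self.loss.backward()
    def step(self):     self.opt.step()
    def zero_grad(self):
        # instead of zeroing, scale the existing grads down: because PyTorch *adds*
        # new gradients onto .grad, this leaves an exponentially weighted moving
        # average of past gradients behind, i.e. momentum stored in .grad itself
        with torch.no_grad():
            for p in self.model.parameters(): p.grad *= self.mom
```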
00:40:47.800 | Now, this is training at the same learning rate for the same time, but the accuracy improves by more. Let's run
00:40:54.100 | them all. Yeah, this is a lot like a gradient accumulation callback, but
00:41:02.080 | kind of cooler, I think. Okay so, let's see: the loss has gone from 0.8 to 0.55
00:41:16.480 | and the accuracy is gone from about 0.7 to about 0.8 so they've improved why is
00:41:24.400 | that well we're going to be learning a lot more about this pretty shortly but
00:41:32.220 | basically what's happening here but basically what's happening here is we
00:41:40.600 | have just implemented in a very interesting way which I haven't seen done
00:41:44.160 | before something called momentum and basically what momentum does is it say
00:41:49.200 | like imagine you are you know you're trying you've got some kind of complex
00:41:55.240 | contoured loss surface, right, and you know, so imagine these are hills with a marble,
00:42:05.880 | and your marble's up here. What would normally happen with
00:42:10.920 | gradient descent is it would go you know in the direction downhill which is this
00:42:16.960 | way so we'll go whoa over here and then whoa over here right very slow what
00:42:23.880 | momentum does is it's is the first steps the same and then the second step says oh
00:42:29.160 | I wanted to go this way but I'm going to add together the previous direction plus
00:42:35.280 | the new direction but reduce the previous direction a bit so that would
00:42:38.800 | actually make me end up about here and then the second one does the same thing
00:42:44.560 | and so momentum basically makes you much more quickly go to your destination so
00:42:52.600 | normally momentum is done it the reason I did it this way partly to show you is
00:42:57.100 | just a bit of fun a bit of interest but it's very it's very useful because
00:43:00.940 | normally momentum you have to store a complete copy basically of all the
00:43:05.800 | gradients the momentum version of the gradients so that you can kind of keep
00:43:10.080 | track of that that that running exponentially weighted moving average but
00:43:14.320 | using this trick you're actually using the dot grad themselves to store the
00:43:20.680 | exponentially weighted moving average so anyway there's a little bit of fun which
00:43:25.440 | hopefully particularly those of you who are interested in accelerated optimizers
00:43:31.360 | and memory saving might find a bit inspiring all right there's one more
00:43:40.220 | callback I'm going to show before the break, which is the wonderful learning
00:43:46.200 | rate finder. I'm assuming that anybody who's watching this is already familiar
00:43:50.760 | with the learning rate finder from fastai; if you're not, there's lots of videos
00:43:55.760 | and tutorials around about it it's an idea that comes from a paper by Leslie
00:44:01.300 | Smith from a few years ago and the basic idea is that we will increase the
00:44:08.640 | learning rate I should have put titles on this the the x-axis here is learning
00:44:12.320 | rate the y-axis here is loss we increase the learning rate gradually over time
00:44:19.320 | and we plot the loss against the learning rate and we find how high can
00:44:24.940 | we bring the learning rate up before the loss starts getting worse you kind of
00:44:30.160 | want roughly where about the steepest slope is so probably here it would be
00:44:34.060 | about 0.1. So it'd be nice to create a learning rate finder, so here's a
00:44:41.200 | learning rate finder callback. So what does a learning rate finder need to do? Well, you
00:44:46.720 | have to tell it how much to multiply the learning rate by each batch let's say we
00:44:50.440 | add 30% to the learning rate each batch and so we'll store that so before we fit
00:44:55.920 | we obviously need to keep track of the learning rates and we need to keep track
00:45:00.040 | of the losses because those are the things that we put on a plot the other
00:45:05.800 | thing we have to do is decide when do we stop training so when is it clearly gone
00:45:10.120 | off the rails and I decided that if the loss is three times higher than the
00:45:16.200 | minimum loss we've seen then we should stop so we're going to keep track of
00:45:20.520 | their minimum loss and so let's just initially set that to infinity it's a
00:45:25.340 | nice big number well not quite a number but a number ish like thing so then
00:45:31.360 | after every batch first of all let's check that we're training okay if we're
00:45:35.200 | not training then we don't want to do anything we don't use the learning rate
00:45:39.480 | finder during validation so here's a really handy thing just raise cancel
00:45:43.680 | epoch exception and that stops it from doing that epoch entirely so just to see
00:45:49.220 | how that works you can see here one epoch does with the call back context
00:45:58.120 | manager epoch and that will say oh it's got cancelled so it goes straight to the
00:46:04.000 | accept which is going to go all the way to the end of that code and it's going
00:46:09.640 | to skip it so it's you can see that we're using exceptions as control
00:46:14.960 | structures which is actually a really powerful programming technique that is
00:46:24.000 | really underutilized in my opinion like a lot of things I do it's actually
00:46:29.280 | somewhat controversial some people think it's a bad idea but I find it actually
00:46:35.480 | makes my code more concise and more maintainable and more powerful so I like
00:46:40.920 | it so let's see yeah so that's we've got our cancel epoch exception so then we're
00:46:49.980 | just going to keep track of our learning rates the learning rates we're going to
00:46:54.200 | learn a lot more about optimizers shortly so I won't worry too much about
00:46:57.280 | this but basically the learning rates are stored by PyTorch inside the
00:47:01.020 | optimizer and they're actually stored in things called param groups parameter
00:47:04.400 | groups so don't worry too much about the details but we can grab the learning
00:47:08.240 | rate from that dictionary and we'll learn more about that shortly we've got
00:47:12.480 | to keep track of the loss appended to our list of losses and if it's less than
00:47:17.320 | the minimum we've seen then recorded as the minimum and if it's greater than a
00:47:22.000 | three times the minimum then look at this it's really cool cancel fit
00:47:25.240 | exception so this will stop everything in a very nice clean way no need for lots
00:47:35.520 | of returns and conditionals or and stuff like that just raise the cancel fit
00:47:39.760 | exception and yeah and then finally we've got to actually update our learning
00:47:47.920 | rate to 1.3 times the previous one and so basically the way you do it in PyTorch
00:47:54.080 | is you have to go through each parameter group and grab the learning rate in the
00:47:58.160 | dictionary and multiply it by lrmult so yeah you've already seen it run and we
00:48:05.480 | can at the end of running you will find that there is now a the callback will
00:48:12.280 | now contain an LR's and a losses so for this callback I can't just add it
00:48:17.600 | directly to the callback list I need to instantiate it first and the reason I
00:48:22.780 | need to instantiate it first is because I need to be able to grab its learning
00:48:25.800 | rates and its losses and in fact you know we could grab that whole thing and
00:48:30.280 | move it in here there's no reason callbacks only have to have the callback
00:48:34.120 | things, right? So we could do this, and now that's just going to become self. There
00:48:50.280 | we go. And so then we can train it again, and we could just call lrfind dot plot.
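A sketch of the learning rate finder callback as described, with the plot method moved inside it; assumes the Callback base class, the cancel exceptions and the to_cpu helper from earlier:

```python
import math
import matplotlib.pyplot as plt

class LRFinderCB(Callback):
    def __init__(self, lr_mult=1.3): self.lr_mult = lr_mult
    def before_fit(self):
        self.lrs, self.losses = [], []
        self.min = math.inf                    # "not quite a number, but a number-ish thing"
    def after_batch(self):
        if not self.learn.model.training: raise CancelEpochException()  # skip validation
        self.lrs.append(self.learn.opt.param_groups[0]['lr'])
        loss = to_cpu(self.learn.loss)
        self.losses.append(loss)
        if loss < self.min: self.min = loss
        if loss > self.min * 3: raise CancelFitException()              # stop the whole fit
        for g in self.learn.opt.param_groups: g['lr'] *= self.lr_mult   # grow the learning rate
    def plot(self):                            # called manually after fitting
        plt.plot(self.lrs, self.losses)
        plt.xscale('log')
```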
00:48:57.440 | So callbacks can really be, you know, quite self-contained, nice things, as you
00:49:03.640 | can see so there's a more sophisticated callback and I think it's doing a lot of
00:49:08.080 | really nice stuff here you might have come across something in PyTorch called
00:49:14.880 | learning rate schedulers and in fact we could implement this whole thing with a
00:49:20.360 | learning rate scheduler it won't actually save that much time but I just
00:49:24.160 | want to show you when you use stuff in PyTorch like learning rate schedulers
00:49:28.400 | you're actually using things that are extremely simple the learning rate
00:49:31.720 | scheduler basically does this one line of code for us so I'm going to now create a
00:49:35.960 | new LR find a CB and this time I'm going to use the PyTorch's exponential LR
00:49:42.880 | scheduler which is here so this is now it's interesting that actually the
00:49:51.560 | documentation of this is kind of actually wrong it claims that it decays
00:49:56.600 | the learning rate of each parameter group by gamma so gamma is just some
00:50:00.320 | number you pass in I don't know why this has to be a Greek letter but it sounds
00:50:04.600 | more fancy than multiplying by an LR multiplier it says every epoch but it's
00:50:11.080 | not actually done every epoch at all what actually happens is in PyTorch the
00:50:17.560 | schedulers have a step method and the decay happens each time you call step
00:50:23.240 | and if you set gamma which is actually LR mult to a number bigger than one it's
00:50:29.300 | not a decay it's an increase so the difference now I guess I'll copy and
00:50:34.400 | paste the previous version okay so the previous versions on the top so the main
00:50:43.520 | difference here is that before fit we're going to create something called a self
00:50:47.360 | dot sched equal to the scheduler, and the scheduler, because it's going to be
00:50:53.300 | adjusting the learning rates it actually needs access to the optimizer so we pass
00:50:56.760 | in the optimizer and the learning rate multiplier. And so then in after batch,
00:51:04.200 | rather than having this line of code we replace it with this line of code self
00:51:09.320 | dot sched dot step. So that's the only difference, and you know, I mean, we're not
00:51:14.040 | gaining much as I said by using the PyTorch exponential LR scheduler but I
00:51:20.040 | mainly wanted to do it so you can see that these things like PyTorch
00:51:23.440 | schedulers are not doing anything magic; they're just doing that one line of code for us.
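A sketch of the scheduler-based variant; the subclass name is made up, and the only real change is creating the ExponentialLR scheduler in before_fit and calling sched.step() instead of multiplying the learning rates by hand:

```python
from torch.optim.lr_scheduler import ExponentialLR

class SchedLRFinderCB(LRFinderCB):
    def before_fit(self):
        super().before_fit()
        # gamma > 1 means the "decay" applied on every .step() is actually an increase
        self.sched = ExponentialLR(self.learn.opt, self.lr_mult)
    def after_batch(self):
        if not self.learn.model.training: raise CancelEpochException()
        self.lrs.append(self.learn.opt.param_groups[0]['lr'])
        loss = to_cpu(self.learn.loss)
        self.losses.append(loss)
        if loss < self.min: self.min = loss
        if loss > self.min * 3: raise CancelFitException()
        self.sched.step()   # replaces the manual `g['lr'] *= self.lr_mult` line
```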
00:51:27.960 | And so I run it again using this new version... oopsy daisy, oh, I forgot to
00:51:36.880 | run this line of code there we go and I guess I should also add the nice little
00:51:44.920 | plot method maybe we'll just move it to the bottom there
00:51:54.720 | lrfind dot plot there we go and put that one back to how it was all right perfect
00:52:09.520 | timing so we added a few very important things in here so make sure we export
00:52:14.820 | and we'll be able to use them shortly all right let's have an eight minute break
00:52:23.120 | let's just have a 10 minute break, so I'll see you back here at eight past. All
00:52:33.400 | right, welcome back. One suggestion, which I quite like, is we could rename
00:52:37.320 | plot to after fit which I really like because that means we should be able to
00:52:47.520 | then just call learn dot fit and delete the next one and let's see it didn't
00:52:54.520 | work why not oh no that doesn't work does it because the hmm you know what I
00:53:07.200 | think the callback here could go into a finally block actually that would
00:53:19.600 | actually allow us to always call the callback even if we've cancelled I think
00:53:26.400 | that's reasonable that might have its own confusions anyway we could try it
00:53:31.640 | for now because that would let us put this after fit in there we go so that
00:53:41.760 | automatically runs then so that's an interesting idea I think I quite like it
00:53:56.040 | cool so let's now look at notebook 10 so I feel like this is the the next big
00:54:06.640 | piece we need so we've got a pretty good system now for training models what I
00:54:14.600 | think we're really missing though is a way to identify how our models are
00:54:20.080 | training and so try identify how our models are training we need to be able
00:54:24.880 | to look inside them and see what's going on while they train we don't currently
00:54:29.080 | have any way to do that and therefore it's very hard for us to diagnose and
00:54:34.360 | fix problems most people have no way of looking inside their models and so most
00:54:39.280 | people have no way to properly diagnose and fix models and that's why most
00:54:42.920 | people when they have a problem with training their model randomly try things
00:54:46.200 | until something starts hopefully working we're not going to do that we're going
00:54:50.920 | to do it properly so we can import the stuff that we just created in the
00:54:56.320 | learner and the first thing I'm going to do introduce now is a set seed function
00:55:02.440 | we've been using torch manual seed before we know all about RNGs random
00:55:08.360 | number generators; we've actually got three of them, PyTorch's, NumPy's and
00:55:14.640 | Python's, so let's seed all of them. And also in PyTorch you can use a flag to
00:55:22.520 | ask it to use deterministic algorithms so things should be reproducible as we
00:55:26.840 | discussed before you shouldn't always just make things reproducible but for
00:55:30.720 | lessons I think this is useful so here's a function that lets you set a
00:55:35.000 | reproducible seed.
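A sketch of such a set_seed function:

```python
import random
import numpy as np
import torch

def set_seed(seed, deterministic=False):
    # seed all three RNGs in play (PyTorch's, NumPy's and Python's) and optionally
    # ask PyTorch to use deterministic algorithms
    torch.use_deterministic_algorithms(deterministic)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
```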
00:55:39.440 | All right, let's use the same dataset as before, the Fashion-MNIST dataset; we'll load it up in the same way, and let's create a model that
00:55:45.440 | looks very similar to our previous models this one might be a bit bigger
00:55:49.400 | might not I didn't actually check okay so let's use multiclass accuracy again
00:55:59.000 | same callbacks that we used before we'll use the train CB version for no
00:56:04.200 | particular reason and generally speaking we want to train as fast as possible not
00:56:12.480 | just because we don't like wasting time but actually more importantly because
00:56:16.240 | the fact the higher the learning rate you train at the more the more you're
00:56:20.800 | able to find a often more generalizable set of weights and also oh training
00:56:35.080 | quickly also means that we can look at each batch let each item in the data
00:56:40.880 | less often so we're going to have less issues with overfitting and generally
00:56:44.880 | speaking being if we can train at a high learning rate then that means that we're
00:56:49.200 | learning to train in a stable way and stable training is is very good so let's
00:56:56.840 | try setting up a high learning rate of point six and see what happens so here's
00:57:02.600 | a function that's just going to create a learner with our callbacks and fit it
00:57:07.600 | and return the learner in case we want to use it and it's training oh and then
00:57:14.320 | it suddenly fell apart so it's going well for a while and then it stopped
00:57:19.000 | training nicely so one nice thing about this graph is that we can immediately
00:57:22.640 | see when it stops training well which is very useful so what happened there why
00:57:30.960 | did it go badly I mean we can guess that it might have been because of our high
00:57:33.760 | learning rate but what's really going on so let's try to look inside it so one
00:57:39.440 | way to look inside it would be we could create our own sequential model which
00:57:46.040 | just like the sequential model we've built before with do you remember we
00:57:49.920 | created one using nn.module list in a previous lesson if you've forgotten go
00:57:54.160 | back and check that out and when we call that model we go through each layer and
00:58:01.840 | just call the layer and what we could do is though we could add something in
00:58:07.120 | addition which is at each layer we could also get the mean of that layer and the
00:58:17.280 | standard deviation of that layer and append them to a couple of different
00:58:25.000 | lists and activation means and activation standard deviations this is
00:58:29.240 | going to contain after we call this model it's going to contain the means
00:58:34.560 | and standard deviations for each layer and then we could define dunder iter
00:58:42.280 | which makes this into an iterator, so that when you
00:58:46.400 | iterate through this model you iterate through the layers.
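A sketch of that recording sequential model, assuming the to_cpu helper from the learner notebook:

```python
import torch.nn as nn

class SequentialModel(nn.Module):
    # like nn.Sequential, but records every layer's output mean and std on each forward pass
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.act_means = [[] for _ in layers]
        self.act_stds  = [[] for _ in layers]

    def __call__(self, x):
        for i, l in enumerate(self.layers):
            x = l(x)
            self.act_means[i].append(to_cpu(x).mean())
            self.act_stds[i].append(to_cpu(x).std())
        return x

    def __iter__(self): return iter(self.layers)
```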
00:58:52.400 | So we can then train this model in the usual way, and this is going to give us exactly the
00:58:56.080 | same outcome as before because I'm using the same seed so you can see it looks
00:58:59.440 | identical but the difference is instead of using nn.sequential we've now
00:59:05.080 | used something that's actually saved the means and standard deviations of each
00:59:08.600 | layer and so therefore we can plot them okay so here we've plotted the
00:59:20.240 | activation means and notice that we've done it for every batch so that's why
00:59:27.880 | along the x-axis here we have batch number and on the y-axis we have the
00:59:32.520 | activation means and then we have it for each layer so rather than starting at one
00:59:37.160 | because we're starting at zero, so the first layer is blue, the second
00:59:40.960 | layer is orange, third layer green, fourth layer red, and the fifth layer is, whatever that
00:59:48.600 | mauve-y kind of color is. And look what's happened: the activations have started
00:59:57.200 | pretty small close to zero and have increased at an exponentially increasing
01:00:01.640 | rate and then have crashed and then have increased again an exponential rate and
01:00:07.280 | crashed again it increased again crashed again and each time they've not gone up
01:00:12.160 | they've gone up even higher and they've crashed in this case even lower and what
01:00:18.040 | happens well wait what's happening here when our activations are really close to
01:00:22.040 | zero well when your activations are really close to zero that means that the
01:00:26.200 | inputs to each layer are numbers very close to zero as a result of which of
01:00:31.680 | course the outputs are very close to zero because we're doing just matrix
01:00:36.000 | multiplies and so this is a disaster when activations are very close to zero
01:00:41.920 | they're dead units, they're not able to do anything, and you
01:00:47.880 | can see for ages here it's not training at all. So this is the
01:00:52.680 | activation means; the standard deviations tell an even stronger story.
01:00:55.880 | so you want generally speaking you want the means of the activations to be about
01:01:03.400 | zero and the standard deviations to be about one mean of zero is fine as long
01:01:09.720 | as they're spread around zero but a standard deviation of close to zero is
01:01:14.120 | terrible because that means all of the activations are about the same so here
01:01:18.760 | after batch 30 all all of the activations are close to zero and all of
01:01:24.600 | their standard deviations are close to zero so all the numbers are about the
01:01:27.680 | same and they're about zero so nothing's going on and you can see the same things
01:01:33.560 | happening with standard deviations we start with not very much variety in the
01:01:37.220 | weights it exponentially increases how much variety there is and then it crashes
01:01:41.200 | again exponentially increases crashes again this is a classic shape of bad
01:01:48.080 | behavior and with these two plots you can really understand what's going on in
01:01:55.200 | your model and if you train a model and at the end of it you kind of think well
01:02:00.200 | I wonder if this is any good if you haven't looked at this plot you don't
01:02:03.800 | know because you haven't checked to see whether it's training nicely maybe it
01:02:07.920 | could it could be a lot better if you can get something we'll see some nicer
01:02:12.240 | training pictures later but generally speaking you want something where your
01:02:16.560 | mean is always about zero and your standard
01:02:21.880 | deviation is always about one, and if you see that then there's a pretty good chance
01:02:26.840 | you're training properly; if you don't see that you're almost certainly not
01:02:30.160 | training properly okay so what I'm going to do in the rest of this part of the
01:02:34.800 | lesson is explain how to do this in a more elegant way because as I say being
01:02:40.800 | able to look inside your models is such a critically important thing for building
01:02:45.240 | and debugging models we don't have to do it manually we don't have to create our
01:02:49.440 | own sequential model we can actually use a pytorch thing called hooks so as it
01:02:55.680 | says here a hook is called when a layer that it's registered to is executed
01:03:00.920 | during the forward pass - that's called a forward hook - or the backward pass, and
01:03:04.880 | that's called a backward hook and so the key thing about hooks is we don't have
01:03:08.280 | to rewrite the model we can add them to any existing model so we can just use
01:03:13.320 | standard nn.sequential passing in our layers which were these ones here and so
01:03:23.840 | we're still going to have something to keep track of the activation means and
01:03:26.720 | standard deviation so just create an empty list for now for each layer in the
01:03:31.940 | model and let's create a little function that's going to be called, because a hook
01:03:38.040 | is going to call a function during the forward pass for a forward
01:03:43.040 | hook or the backward pass for a backward hook. So here's a function called
01:03:46.160 | append_stats; it's going to be passed the layer number, the
01:03:51.440 | module, the input and the output, so we're going to be grabbing the output's
01:03:58.240 | mean and putting it in activation means, and the output's standard deviation and
01:04:02.600 | putting it in activation standard deviations. So here's how you do it: we've
01:04:07.440 | got a model, you go through each layer of the model and you call on it
01:04:13.120 | register_forward_hook - that's part of PyTorch, so we don't need to write it ourselves,
01:04:16.960 | although we already did, right? It's just doing the same thing as this, basically.
01:04:20.240 | And what function is going to be called? The function that's going to be
01:04:28.240 | called is the append_stats function - remember, partial is the
01:04:33.760 | equivalent of saying append_stats passing in i as the first
01:04:38.940 | argument.
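[A minimal sketch of that hook registration, assuming a plain nn.Sequential model and globally defined lists much as described above; exact names in the notebook may differ:]

    from functools import partial

    act_means = [[] for _ in model]   # one list per layer
    act_stds  = [[] for _ in model]

    def append_stats(i, mod, inp, outp):
        # called by PyTorch on every forward pass of layer i
        act_means[i].append(outp.mean().item())
        act_stds[i].append(outp.std().item())

    for i, layer in enumerate(model):
        # partial fixes i as the first argument
        layer.register_forward_hook(partial(append_stats, i))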
01:04:46.600 | So if we now fit that model, it trains in the usual way, but after each layer it's going to call this, and so you can see we get exactly the
01:04:54.120 | same thing as before. One question we get here is what's the difference
01:04:58.680 | between a hook and a callback? Nothing at all - hooks and callbacks are the same
01:05:03.740 | thing, it's just that PyTorch defines hooks and they call them hooks instead
01:05:09.800 | of callbacks. They are less flexible than the callbacks that we used in the
01:05:18.920 | learner, because you don't have access to all the available state and you can't
01:05:22.400 | change things, but they are, you know, a particular kind of callback:
01:05:26.000 | it's just setting a piece of code that's going to be run for us when
01:05:33.620 | something happens, and in this case the something that happens is that
01:05:36.400 | either a layer in the forward pass is called or a layer in the backward pass
01:05:39.800 | is called. I guess you could describe the function that's being called back as the
01:05:47.180 | callback and the thing that's doing the calling back as the hook - I'm not sure if
01:05:52.560 | that level of distinction is important, but maybe that's a distinction you could make. Okay,
01:05:57.320 | so anyway this is a little bit fussy, kind of like creating globals and
01:06:02.080 | appending to them and stuff like that, so let's try to simplify this a little bit
01:06:06.160 | so what I did here was I created a class called hook so this class when we create
01:06:22.160 | it we're going to pass in the module that we're hooking, so we call m dot
01:06:26.280 | register_forward_hook and we pass the function that we
01:06:31.080 | want to be called, and we're also going to
01:06:41.540 | pass the Hook object itself into the function. Let's also define a remove, because this
01:06:46.120 | is actually the thing that removes the hook -
01:06:52.080 | we don't want it sitting around forever. Dunder del is called by
01:06:56.720 | Python when an object is freed, so when that happens we should also make sure
01:07:05.120 | that we remove the hook. Okay, so append_stats we're now going to replace: it's
01:07:13.700 | going to get passed the hook instead, because that's what we asked to
01:07:21.980 | be passed, and if there's no .stats attribute in there yet then let's
01:07:29.780 | create one, and then we're going to be passed the activations, so put them on the
01:07:35.060 | CPU and append the mean and the standard deviation. And now the nice thing
01:07:35.060 | is that the stats are actually inside this object which is convenient so now
01:07:40.480 | we can do exactly the same thing as before but we don't have to set any of
01:07:44.560 | that global stuff or whatever; we can just say okay, our hooks is a Hook with
01:07:49.440 | that layer and that function, for all of the model's layers, and so we're just
01:07:58.640 | creating it - it has called register_forward_hook for us. So now when we fit, it's
01:08:05.960 | going to run with the hooks - there we go, it trains.
01:08:22.800 | Okay, so it trains and we get exactly the same shape as usual and we
01:08:33.120 | get back the same results as usual, but as we can see we're gradually making
01:08:35.920 | this more convenient, which is nice.
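[A sketch of the Hook class and the hook-based append_stats just described; minor details may differ from the notebook:]

    from functools import partial

    class Hook():
        def __init__(self, m, f):
            # register the forward hook and keep the handle so we can remove it later
            self.hook = m.register_forward_hook(partial(f, self))
        def remove(self): self.hook.remove()
        def __del__(self): self.remove()

    def append_stats(hook, mod, inp, outp):
        # the stats now live on the hook object itself, not in globals
        if not hasattr(hook, 'stats'): hook.stats = ([], [])
        acts = outp.detach().cpu()
        hook.stats[0].append(acts.mean())
        hook.stats[1].append(acts.std())

    hooks = [Hook(l, append_stats) for l in model]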
01:08:44.760 | We can make it nicer still, because generally speaking we're going to be adding multiple hooks, and this stuff -
01:08:48.880 | you know, this list comprehension and whatever - is a bit inconvenient, so
01:08:52.800 | let's create a Hooks class. So first of all we'll see how the Hooks class works
01:08:57.520 | in practice so in the hooks class the way we're going to use it is we're going
01:09:02.520 | to call with Hooks, pass in the model, pass in the function to use as the hook
01:09:09.920 | and then we'll fit the model and that's it it's going to be literally just one
01:09:15.000 | extra line of code to set up the whole thing, and then you can
01:09:19.360 | go through each hook and plot the mean and standard deviation of each layer. So
01:09:25.180 | that's how the Hooks class is going to make things much easier.
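[Roughly what that usage looks like - a sketch assuming the Hooks class shown a little further on, the append_stats above, and a fit function standing in for however you actually train the model:]

    import matplotlib.pyplot as plt

    with Hooks(model, append_stats) as hooks:
        fit(model)                          # placeholder for your training call
        fig, axs = plt.subplots(1, 2, figsize=(10, 4))
        for h in hooks:
            for i in 0, 1:
                axs[i].plot(h.stats[i])     # means, then standard deviations
        plt.legend(range(len(hooks)))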
01:09:29.020 | So the Hooks class, as you can see, we're making it a context manager, and we want
01:09:38.640 | to be able to loop through it, we want to be able to index into it, so it's quite a lot
01:09:43.200 | of behavior we want. Believe it or not, all that behavior is in this tiny little
01:09:48.040 | thing, and we're going to use the most flexible, general way of creating context
01:09:54.360 | managers. Now context managers are things that we can say "with" on; the general way of
01:09:58.960 | creating a context manager is to create a class and define two special methods,
01:10:03.560 | dunder enter and dunder exit. Dunder enter is a function that's going to be
01:10:08.600 | called when it hits the with statement, and if you add an "as blah" after it then
01:10:16.440 | the contents of this variable will be whatever is returned from dunder enter,
01:10:21.600 | and as you can see we just return the object itself, so the Hooks object is
01:10:27.240 | going to be stored in hooks now interestingly the hooks class inherits
01:10:35.440 | from list you can do this you can actually inherit from stuff like list in
01:10:39.660 | Python. So the Hooks object is a list, and therefore we need to call the
01:10:45.360 | superclass's constructor, and we're going to pass in that list comprehension we
01:10:51.120 | saw that list of hooks where it's going to hook into each module in the list of
01:10:56.480 | modules we asked to hook into now we're passing in a model here but because the
01:11:04.280 | model is an nn dot sequential you can actually loop through an nn dot
01:11:07.640 | sequential and it returns each of the layers so this is actually very very
01:11:12.080 | nice and concise and convenient. So that's the constructor. Dunder enter just
01:11:18.000 | returns it; dunder exit is what's called automatically at the end of the whole
01:11:24.360 | block so when this whole thing's finished it's going to remove the hooks
01:11:29.440 | and removing the hooks is just going to go through each hook and remove it the
01:11:34.760 | reason we can do "for h in self" is because, remember, this is a list. And then
01:11:43.560 | finally we've got a dunder del like before, and I also added a dunder delitem -
01:11:48.840 | this is the thing that lets you delete a single hook from the list, which
01:11:53.280 | will remove that one hook and call the list's delitem as well. So there's our whole
01:12:01.640 | thing. This one's optional - it's the one that lets
01:12:05.000 | us remove a single hook rather than all of them.
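[The whole Hooks class is roughly this - a sketch; the notebook's version may differ in small details:]

    class Hooks(list):
        def __init__(self, ms, f):
            # ms can be anything iterable over modules, e.g. an nn.Sequential
            super().__init__([Hook(m, f) for m in ms])
        def __enter__(self, *args): return self
        def __exit__(self, *args): self.remove()
        def __del__(self): self.remove()
        def __delitem__(self, i):
            # remove that one hook, then delete it from the list as usual
            self[i].remove()
            super().__delitem__(i)
        def remove(self):
            for h in self: h.remove()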
01:12:12.200 | So let's just understand some of what's going on there. Here's a dummy context manager - as you can see, it's
01:12:18.240 | got a dunder enter which is going to return itself and it's going to print
01:12:22.960 | something so you can see here I call with dummy context manager and so
01:12:28.400 | therefore it prints "let's go" first. The second thing it's going to do is call
01:12:34.120 | this code inside the context manager - so we've got "as dcm", so that's itself - and
01:12:40.640 | so it's going to call hello, which prints "hello", so here it is. And then finally
01:12:48.160 | it's going to automatically call dunder exit, which prints "all done", so here's
01:12:53.920 | "all done".
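[A minimal dummy context manager like the one being described - the class and method names here are illustrative:]

    class DummyCtxMgr:
        def __enter__(self):
            print("let's go!")
            return self
        def __exit__(self, *args):
            print("all done!")
        def hello(self):
            print("hello")

    with DummyCtxMgr() as dcm:
        dcm.hello()
    # prints: let's go!   then   hello   then   all done!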
01:12:58.160 | So again, if you haven't used context managers before, you want to be creating little samples like this yourself and getting them to work.
01:13:02.120 | This is your key homework for this week: anything in the lesson where we're
01:13:09.160 | using a part of Python you're not a hundred percent familiar with, it's for you
01:13:13.280 | to create, from scratch, some simple kind of dummy version that fully
01:13:18.480 | explores what it's doing. If you're familiar with all the Python pieces, then
01:13:24.080 | it's to do the same thing with the
01:13:29.000 | PyTorch pieces, like with hooks and so forth. And so I just wanted to show
01:13:35.640 | you also what it's like to inherit from list. So here I'm inheriting from
01:13:40.120 | list and I can redefine how dunder delitem works, so now I can create a
01:13:46.000 | dummy list and it looks exactly the same as usual, but now if I delete an item
01:13:54.040 | from the list it's going to call my overridden version and then it will
01:14:01.680 | call the original version, and so the list has now removed that item and did
01:14:06.840 | this at the same time. So you can see you can actually, yeah, modify how Python
01:14:11.360 | works, or create your own things that get all the behavior and the convenience of
01:14:16.720 | Python classes like this one and add stuff to them.
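[Something like this - a sketch; the notebook's exact names may differ:]

    class DummyList(list):
        def __delitem__(self, i):
            print(f"hi, deleting item {i}!")   # my overridden version
            super().__delitem__(i)             # then the original behaviour

    dl = DummyList([1, 3, 2])
    del dl[0]     # prints the message and removes the item
    print(dl)     # [3, 2]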
01:14:22.600 | So that's what's happening there. Okay, so that's our Hooks class. The next bit was largely
01:14:34.600 | developed the last time I think it was that we did a part 2 course in San
01:14:39.520 | Francisco with Stefano so many thanks to him for helping get this next bit
01:14:44.000 | looking great we're going to create my favorite single image explanations of
01:14:52.040 | what's going on inside a model - we call it the colorful dimension - which is
01:14:58.160 | histograms. We're going to take our same append_stats - these are all the same as
01:15:04.080 | before we're going to add an extra line of code which is to get a histogram of
01:15:09.040 | the absolute values of the activations. A histogram, to remind
01:15:15.600 | you, is something that takes a collection of numbers and tells you how
01:15:21.680 | frequent each group of numbers is, and we're going to create 50 bins for our
01:15:28.640 | histogram so we will use our hooks that we just created and we're going to use
01:15:40.000 | this new version of append_stats, so it's going to train as before, but now we're
01:15:43.840 | going to have, in addition, this extra thing in stats, which is just going to
01:15:48.640 | contain a histogram.
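[The new append_stats is the old one plus one line, roughly like this - the 0 to 10 histogram range is an assumption here:]

    def append_stats(hook, mod, inp, outp):
        if not hasattr(hook, 'stats'): hook.stats = ([], [], [])
        acts = outp.detach().cpu()
        hook.stats[0].append(acts.mean())
        hook.stats[1].append(acts.std())
        # the extra line: a 50-bin histogram of the absolute values
        hook.stats[2].append(acts.abs().histc(50, 0, 10))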
01:15:56.240 | And so with that we're now going to create this amazing plot. Now what this plot is showing is, for the first, second, third and fourth
01:16:02.960 | layers what does the training look like and you can immediately see the basic
01:16:07.120 | idea is that we're seeing this same pattern but what is this pattern showing
01:16:15.000 | what exactly is going on in these pictures so I think it might be best if
01:16:20.120 | we try and draw a picture of this. So
01:16:35.320 | let's take a normal histogram, where we basically
01:16:44.000 | have grouped all the data into bins and then we have counts of how much is in
01:16:51.280 | each bin so for example this will be like the value of the activations and it
01:16:59.600 | might be say from 0 to 10 and then from 10 to 20 and from 20 to 30 and these are
01:17:09.160 | generally equally spaced bins okay and then here is the count so that's the
01:17:22.360 | number of items with that range of values so this is called a histogram
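[A tiny illustration of that idea, not from the notebook:]

    import torch

    t = torch.tensor([1., 2., 3., 12., 14., 25.])
    # three equally spaced bins covering 0-30: [0,10), [10,20), [20,30]
    print(torch.histc(t, bins=3, min=0, max=30))   # tensor([3., 2., 1.])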
01:17:28.440 | okay so what Stefano and I did was we actually turned that whole
01:17:42.720 | histogram into a single column of pixels. So if I take one column of pixels,
01:17:50.240 | that's actually one histogram, and the way we do it is we take these numbers - so
01:17:58.040 | let's say it's like 14, then 2, 7, 9, 11, 3, 2, 4, 2, say - and
01:18:10.880 | then what we do is we turn it into a single column, and so in this case we've
01:18:18.960 | got 1, 2, 3, 4, 5, 6, 7, 8, 9 groups, right? So we would create our nine groups -
01:18:29.680 | sorry, they were meant to be evenly spaced but I didn't draw them very well - got our nine groups, and
01:18:34.200 | so we take the first group it's 14 and what we do is we color it with a
01:18:41.120 | gradient and a color according to how big that number is so 14 is a real big
01:18:46.360 | number, so depending on, you know, what gradient we use, maybe red is really really
01:18:50.160 | big, and the next one's really small, which might be like green, and then the
01:18:55.120 | next one's quite big, in the middle, which is like blue, and the next one's getting
01:19:00.920 | bigger still, so maybe it goes back
01:19:04.480 | to more red, and the next one's bigger still so it's even more red, and so
01:19:11.140 | forth. So basically we're taking the histogram and turning it into a color-
01:19:17.000 | coded single-column plot, if that makes sense. And so what that means is that -
01:19:25.160 | let's take layer number two here - we can take the
01:19:32.080 | very first column, and in the color scheme that matplotlib
01:19:36.680 | picked here yellow is the most common and then light green is less common and
01:19:41.880 | then light blue is less common and then dark blue is zero so you can see the
01:19:46.160 | vast majority is zero and there's a few with slightly bigger numbers which is
01:19:50.600 | exactly the same as what we saw for the index one layer - here it is, right? -
01:19:57.480 | the average is pretty close to zero, the standard deviation is pretty small. This
01:20:03.720 | is giving us more information, however. So as we train, at
01:20:15.920 | this point here there are quite a few activations that are a lot larger, as you
01:20:22.320 | can see, and still the vast majority of them are very small - there's a few big
01:20:26.680 | ones, but we've still got a bright yellow bar at the bottom. The other thing to notice
01:20:31.680 | here is what's happened is we've taken those histograms, we've
01:20:36.960 | stacked them all up into a single tensor and then we've taken their log. Now log
01:20:41.620 | 1p is just log of the number plus 1; that's because we've got zeros here, and
01:20:47.480 | so just taking the log is going to kind of let us see the full range more
01:20:54.560 | clearly, so that's what the log's for.
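[The plot itself is roughly this - a sketch; the notebook wraps it in small helper functions, so details may differ:]

    import torch
    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(2, 2, figsize=(11, 5))
    for ax, h in zip(axs.flat, hooks):
        hist = torch.stack(h.stats[2]).t().float()   # shape: bins x batches
        ax.imshow(hist.log1p(), origin='lower', aspect='auto', cmap='viridis')
        ax.axis('off')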
01:21:02.000 | So basically what we'd really ideally like to see here is that this whole thing should be more like a
01:21:09.240 | rectangle: the maximum shouldn't be changing very much, there
01:21:14.480 | shouldn't be a thick yellow bar at the bottom, but instead it should be a nice
01:21:17.520 | even gradient matching a normal distribution - each single column of pixels
01:21:23.600 | wants to be kind of like a normal distribution, so, you know, gradually
01:21:28.080 | decreasing the number of activations. That's what we're aiming for. There's
01:21:36.600 | another really important and actually easier to read version of this, which is:
01:21:43.480 | what if we just took those two bottom pixels
01:21:48.360 | and counted up how many activations were in them? In
01:21:53.560 | the bottom two pixels we've got the two smallest equally sized groups of
01:21:59.480 | activations, and we don't want there to be too many of them, because those are
01:22:08.280 | basically dead or nearly dead activations - they're much, much smaller
01:22:18.360 | than the big ones and so taking the ratio between those bottom two groups
01:22:24.200 | and the total basically tells us what percentage have zero or near zero or
01:22:33.140 | extremely small magnitudes - and remember that these are absolute values.
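[That ratio can be computed and plotted roughly like this - a sketch, assuming the histograms stored as above:]

    fig, axs = plt.subplots(2, 2, figsize=(11, 5))
    for ax, h in zip(axs.flat, hooks):
        hist = torch.stack(h.stats[2]).t().float()    # bins x batches
        # fraction of activations falling in the two smallest-magnitude bins
        ax.plot(hist[:2].sum(0) / hist.sum(0))
        ax.set_ylim(0, 1)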
01:22:44.520 | So if we plot those you can see how bad this is, and in particular, for example,
01:22:50.400 | at the final layer, nearly from the very start,
01:22:55.320 | nearly all of the activations are just about entirely disabled. So
01:23:03.700 | this is bad news, and if you've got a model where most of your activations are
01:23:08.920 | close to zero then most of your model is doing no work, and so it's really
01:23:15.680 | not working. So it may look like at the very end things were
01:23:22.720 | improving, but as you can see from this chart that's not true, right? The
01:23:28.360 | vast majority are still inactive. Generally speaking, I've found that if, early
01:23:33.840 | in training you see this rising crash rising crash at all you should stop and
01:23:39.920 | restart training, because your model will probably never recover - too many of
01:23:48.600 | the activations have gone off the rails so we want it to look kind of like this
01:23:54.760 | the whole time but with less of this very thick yellow bar which is showing
01:24:00.280 | us most are inactive okay so that's our activations so we've got really now all
01:24:21.920 | of the kind of key pieces I think we need to be able to flexibly change how
01:24:28.800 | we train models and to understand what's going on inside our models and so from
01:24:35.040 | this point we've kind of like drilled down as deep as we need to go and we can
01:24:41.780 | now start to come back up again and put together the pieces, building up what
01:24:49.980 | are all of the things that are going to help us train models reliably and quickly
01:24:56.280 | and then hopefully we're going to be able to yeah successfully create from
01:25:00.400 | scratch some really high quality generative models and other models along
01:25:04.640 | the way okay I think that's everything for this class but next class we're
01:25:13.160 | going to start looking at things like initialization it's a really important
01:25:16.560 | topic if you want to do some revision before then just make sure that you're
01:25:22.680 | very comfortable with things like standard deviations and stuff like that
01:25:27.760 | because we're using that quite a lot for next time and yeah thanks for joining me
01:25:33.960 | look forward to the next lesson see you again