Lesson 23: Deep Learning Foundations to Stable Diffusion
Hi everybody, today we're covering lesson 23, and we're here with Jono and Tanishk. How are you both doing? Doing well, excited for another lesson. Yeah, likewise. Great. I shamefully have to start by admitting to a bug, which messed things up a bit, but I think what actually happened is really interesting. The bug was in notebook 23, the Karras notebook, and it's about measuring the FID. To recall, FID measures how similar a bunch of samples from a model are to a bunch of samples of real images, where that similarity is defined as a kind of distance between the distributions of the features from a classifier or some other model. So to get FID we have to load a model and pass it some data loaders, so it can calculate what the features of real images look like.

Now, the problem is that the data loaders I was passing had images whose pixels were between -0.5 and +0.5, but you might recall that the model I trained has pixels between -1 and 1. So what this ImageEval class would have seen, and specifically the classifier model we're getting the features from, is a whole bunch of unusually low-contrast images. They wouldn't really have looked like anything in the dataset, because in the dataset, particularly for Fashion MNIST, things are pretty consistently normalized to go all the way from 0 to 1, or -1 to 1 (0 to 255 in the original). As a result, I think the features that came out would have been kind of weird: they wouldn't necessarily have consistently said "these are t-shirt features" and "these are shoe features", but rather "this is a weird low-contrast image feature".

Then the shame continues, in that I added another bug on top of this one. When I then did the sampling, I didn't multiply by two, and the data I trained on came from the same data loaders, or rather the same transform (not Noisify, the transform before it), which previously scaled things to -0.5 to 0.5. So I trained the model using this restricted input space as well, and therefore it was spitting out things between -0.5 and 0.5. The FID then said: wow, these are so similar; the samples consistently produce features of low-contrast things, and all of the real samples look like low-contrast things, so they match. And that's how we got really low numbers. Those low numbers were wrong. I had been a bit surprised that the Karras model was doing so much better, and it had made me a big believer in the Karras approach, but actually it's not doing so much better: once we fix this, the Karras FIDs are actually around five or six, whereas with the cosine model we were getting pretty good results, around three to four depending on how many DDIM steps we were doing.

So the result is this somewhat odd situation where the cosine model, which we accidentally scaled to -0.5 to 0.5 and then multiplied by two after sampling (so we're not cheating the way the Karras one used to), is working better than Karras. That's a surprise to me, because I was thinking Karras was, in theory, optimally scaling things; but the truth is it was scaling things to unit variance, and there's nothing particularly to say that's the optimal scaling. So empirically we've found, kind of accidentally, a better way to scale things. Also, our dependent variable is different: it's not that Karras-style mixed combination, it's just the noise, the unit-variance noise before it gets scaled.
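Since the fix boils down to making sure the generated samples and the reference images are in the same pixel range the feature extractor was trained on, here is a minimal sketch of that idea. The helper names (`ie`, `sample`) are placeholders for the notebook's own objects, not its actual API.

```python
import torch

def to_eval_range(x):
    "Rescale a batch from [-0.5, 0.5] to the [-1, 1] range the FID feature extractor expects."
    return x * 2

# samples = sample(model, ...)            # hypothetical: model outputs in [-0.5, 0.5]
# fid = ie.fid(to_eval_range(samples))    # compare like with like before computing FID
```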
Okay, so that's the bug. Anyway, I promised last time we would stop looking at Fashion MNIST for a while, so let's move on to Tiny ImageNet. The reason we're going to do this is that we're going to try to create U-Nets today, and I wanted to show an example of a nice U-Net that combines a lot of the ideas we've been looking at. It's going to be a super-resolution U-Net, and doing super-resolution on Fashion MNIST isn't very interesting because the maximum training size we have is 28 by 28. So I thought we'd go a little bit bigger than that, to Tiny ImageNet, which is 64 by 64.

I found it quite difficult to find the Tiny ImageNet data, actually, but eventually I discovered that it's still on the Stanford servers where it was originally created; it's just not linked anywhere. If this disappears, we'll keep our forum and website up to date with other places to find it. For now we can grab the URL from there and unpack it. shutil is a very handy little module in the Python standard library, and one of the things it has is a very handy unpack_archive, which can handle zip files; it's going to put the data in our data directory.
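As a rough sketch of that step using only the standard library (the exact URL and paths here are placeholders for whatever the notebook actually uses):

```python
from pathlib import Path
from urllib.request import urlretrieve
import shutil

url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'   # placeholder location
path_data = Path('data')
path_zip = path_data/'tiny-imagenet-200.zip'

path_data.mkdir(exist_ok=True)
if not path_zip.exists(): urlretrieve(url, path_zip)        # download the archive
shutil.unpack_archive(str(path_zip), str(path_data))        # handles .zip (and .tar.gz, ...)

path = path_data/'tiny-imagenet-200'
```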
There are a few different ways we could process this, and I thought we might experiment with some of them, but it wouldn't be a bad idea to try doing things the reasonably manual way, just to see what that looks like. Often this is the easiest way to do things, because it's a very well-defined set of steps. Step one is to create a dataset. A dataset is literally just something that has a length and that you can index into, so it has to have those two things defined. You don't have to inherit from anything; you just have to define those two methods. Broadly speaking, in Python you generally don't have to inherit from things, you just have to provide the methods that are expected.

Our dataset is in a directory called tiny-imagenet-200, and there's a train directory and a val directory for the training and validation sets. The train directory is the pretty classic layout: each category has its images in a separate folder, specifically in an images subfolder. So what I wanted to do first was grab all of the image files in path/train. The Python standard library has a glob function which, if asked, searches recursively for everything that matches a specification. The specification here is path followed by star-star, slash, star dot JPEG, and the double star is a bit weird: to be recursive you both have to say recursive=True and also put the star-star before the slash. That gives us a list of all files inside path/train, and so then if we index into that training dataset with zero, that calls __getitem__ passing an i of zero, and we return a tuple. The first element is self.files[i], which is the file, and the second is the label for it: the parent's parent's name, so the grandparent directory's name.
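Here's a minimal sketch of a dataset along those lines. The class name and exact glob pattern are mine, so treat the details as assumptions rather than the notebook's exact code.

```python
from glob import glob
from pathlib import Path

class TinyDS:
    def __init__(self, path):
        # '**' plus recursive=True makes glob descend into the images subfolders
        self.files = glob(str(Path(path)/'**/*.JPEG'), recursive=True)
    def __len__(self): return len(self.files)
    def __getitem__(self, i):
        f = Path(self.files[i])
        # .../train/<wnid>/images/<file>.JPEG -> the label is the grandparent folder name
        return str(f), f.parent.parent.name

tds = TinyDS(path/'train')
tds[0]   # ('.../train/n.../images/....JPEG', 'n...')
```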
Okay, so there's a dataset that returns two strings when you index into it, a tuple of two strings: the first is the path of the image file, and the second is the name of the category it's in. These weird names are WordNet categories: codes that basically indicate concepts in English.

One of the reasons I used this particular dataset is that it's going to force us to do some more data processing, which I think is good practice. That's because, weirdly, in the validation set (it's in tiny-imagenet-200/val, which is the not-weird part), the images are not in subdirectories organized by label. Instead there's a separate val_annotations.txt file, which says, for each file name, what category it's in. It also has the bounding box of whereabouts the object is, but we're not going to use that today.

So I decided to create a dictionary that tells us, for each file, what category it's in. Here I'm doing something exactly like a list comprehension, but because it's not in square brackets it's a generator comprehension, so it streams out the results: we go through each line in the file and split on tab, which gives us the file name, then the category, then the rest, and we grab the first two. If you pass a list of lists (or tuples, or whatever) to dict, it creates a dictionary using those pairs as keys and values. So if we have a look, there it is; that's quite a nice, neat way to do it. If you're not sure, you can just type dict, open brackets, and hit Shift-Tab a couple of times, and it'll show you the various options. You can see here I'm using dict(iterable), because my generator is an iterable, and the docs say that's exactly as if you'd created a dictionary and then gone "for k, v in iterable: d[k] = v". So there's a nice little trick.

Okay, now we need a dataset that works just like TinyDS, but whose __getitem__ labels things differently. So I just inherited from TinyDS, which means we don't need to define __init__ or __len__ again, and then __getitem__ again returns the i-th file, but this time the label isn't the parent's parent's name; instead we look up the name of the file in the annotations dictionary. And that works, and we can check the length works.
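Sketched in the same spirit, here is the annotations dictionary and the validation dataset. Again, the names are mine and the exact parsing is an assumption, but the shape of the code matches what's described above.

```python
# each line of val_annotations.txt is: <filename>\t<wnid>\t<bbox fields...>
# we keep only the first two fields as a (key, value) pair
anno = dict(line.split('\t')[:2]
            for line in (path/'val'/'val_annotations.txt').read_text().splitlines())

class TinyValDS(TinyDS):
    def __getitem__(self, i):
        f = Path(self.files[i])
        return str(f), anno[f.name]   # label comes from the annotations dict, not the folder

vds = TinyValDS(path/'val')
len(vds), vds[0]
```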
So then a fairly generally useful thing I thought we'd create is something that lets us transform any dataset. Here's a class: you pass it a dataset, you can pass it a transformation for the x (the independent variable), and you can pass a transformation for the y, and both default to noop, no operation, so by default it doesn't change anything at all. For this transformed dataset, the length is just the length of the original dataset, but when we call __getitem__ it grabs the tuple from the dataset we passed in and returns that tuple with the x transform and the y transform applied. Does that make sense so far?
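A minimal sketch of that wrapper (noop simply returns its argument unchanged; the class name matches how it's referred to later, but take it as an assumption):

```python
def noop(x): return x

class TfmDS:
    "Wrap any dataset, applying tfmx to the independent item and tfmy to the dependent item."
    def __init__(self, ds, tfmx=noop, tfmy=noop):
        self.ds, self.tfmx, self.tfmy = ds, tfmx, tfmy
    def __len__(self): return len(self.ds)
    def __getitem__(self, i):
        x, y = self.ds[i]
        return self.tfmx(x), self.tfmy(y)
```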
Great. Okay, so I don't like working with these n03... codes, but the dataset luckily has... well, if I just open this one up, it's actually not quite going to help us yet: it's just a list of all of the WordNet IDs they have images for. We could have got this by simply listing the train directory, which would have told us all the IDs, but they've also provided a text file containing all of them. And that is useful, because we're going to want to change n03... and so on into an int, and the way we do that is by simply saying: we'll call this one zero, this one one, and so forth. So the int-to-string (or ID-to-string) version is literally just this list, so index zero will be that first ID. For the string-to-int version, which you do all the time, it's basically enumerate: that gives us the index and the value for everything in the list, and those are going to be our keys and values, but we invert them to become value: key, and that's what str2id will be. Note that here we have a dictionary comprehension; you can tell because it's got curly brackets and a colon. We could have used a dictionary comprehension for the annotations as well. There are lots of ways of doing things, and none of them is particularly better or worse than any other.
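Roughly, the ID lookups look like this (wnids.txt is the list file that ships with Tiny ImageNet; the variable names are assumptions):

```python
# wnids.txt lists one WordNet id per line, in a fixed order
id2str = (path/'wnids.txt').read_text().splitlines()
str2id = {v: k for k, v in enumerate(id2str)}   # invert index/value into {wnid: int}

str2id[id2str[0]]   # 0: the first id maps back to index 0
```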
Those word tags, do we have the names for them, or is that something else? Yes, the names I'm going to get to shortly; there's a words.txt for that.

All right. I grabbed one batch of data and took its mean and standard deviation, and then I just copied and pasted them in here for normalizing. So my x transform is going to read the image (if you read it as RGB, that forces it to be three channels, because some of them are only one channel), divide by 255 so it's between zero and one, and then normalize. For our y's, we go through str2id to get the ID and just use that as our tensor. So doing it manually is actually pretty straightforward, because now we just pass those to our TfmDS, our transformed dataset. The y we get back is a tensor, but we can look it up to get its value, and the x is an image tensor with three channels, channel by height by width, as is normal for PyTorch. For showing images it's nice to denormalize them, so that's just a denormalize function, and if we show the image we just grabbed, it's a water jug, I guess.
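As a sketch of those transforms, with xmean and xstd standing in for the per-channel statistics copied from one batch (the numbers below are placeholders, not the real statistics):

```python
import torch
import torchvision.transforms.functional as TF
from PIL import Image

xmean = torch.tensor([0.47, 0.48, 0.45])[:, None, None]   # placeholder per-channel mean
xstd  = torch.tensor([0.28, 0.27, 0.29])[:, None, None]   # placeholder per-channel std

def tfmx(f):
    img = Image.open(f).convert('RGB')      # force 3 channels even for greyscale files
    x = TF.to_tensor(img)                   # float tensor in [0, 1], shape (3, H, W)
    return (x - xmean) / xstd               # normalize

def tfmy(wnid): return torch.tensor(str2id[wnid])

def denorm(x): return (x * xstd + xmean).clamp(0, 1)   # undo normalization for display

train_ds = TfmDS(tds, tfmx, tfmy)
valid_ds = TfmDS(vds, tfmx, tfmy)
```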
All right, so now we can create a data loader for our training set: it contains our transformed training dataset, and we pass in a batch size; this one has to be shuffled. I'm not sure why I put num_workers=0 there; generally eight is pretty good if you've got at least eight cores. So we can now grab an x batch and a y batch and look at a denormalized image from there, and we've got a nice little kitty cat. I think this is already looking better than Fashion MNIST.
Yeah, so there's this file words.txt that they've also provided, and it's actually a list of the entire WordNet hierarchy. The top of the hierarchy is "entity", one of the entity types is a physical entity or an abstract entity, entities can be things, and so forth; that's how WordNet is organized, so it's quite a big file. If we go through each line of that file and again split on tabs (that's what backslash-t means), we get the WordNet ID and then the name. Now we can go through all of those (they call them synsets), and if the key is in our list of the 200 we want, we keep it. We don't really want something like "causal agent, cause, causal agency"; the first one generally seems to be the most normal, so I just split on comma and grab the first one. All right, so we can then go through our y batch, turn each of those numbers into strings, look each one up in our synsets, and join them up, and then use those as titles: we see our Egyptian cat, our cliff, our guacamole.
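A sketch of that lookup, building on the str2id/id2str mappings above (variable names are assumptions):

```python
# words.txt maps every WordNet id to a comma-separated list of names;
# keep only the 200 ids we use, and only the first (most common) name
all_names = dict(line.split('\t', 1) for line in (path/'words.txt').read_text().splitlines())
synsets = {k: v.split(',')[0] for k, v in all_names.items() if k in str2id}

titles = [synsets[id2str[int(y)]] for y in yb]   # yb: a batch of integer labels
```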
A monarch butterfly, and so forth. And you can see this is going to be quite tricky: this one, for instance, is a "cliff dwelling", which could be quite complicated. I have a feeling that for this dataset they intentionally took about a hundred of the categories from normal ImageNet and then picked another hundred designed to be particularly difficult, or something like that, if memory serves correctly.

All right, so then we could define a tfm_batch function with the same basic idea, which just transforms the x and the y in a batch. Oh, actually we're about to use that; I should move it down a bit, because we're not quite there yet. Before that, we can create our data loaders: we created a get_dls function back in an earlier lesson, which simply turns one dataset into a data loader and the other into a data loader, and this one gets shuffled and that one doesn't, and so forth. Oh, I see, this is where we set our num_workers. Cool.
All right, so then we want to add our data augmentation. I noticed, training a Tiny ImageNet model (which is a much harder thing to do than Fashion MNIST), that overfitting was actually a real challenge, I guess because 64 by 64 isn't that many pixels. So I found I really needed data augmentation to make much progress at all.

Now, a very common data augmentation is called random resized crop, which basically picks one area inside the image, zooms into it, and makes that your image. But for such low-resolution images that tends to work really poorly, because it introduces a lot of blurring artifacts. So instead, for small images, I think it's better to add a bit of padding around them and then randomly pick a 64 by 64 area from that padded image. It's just going to shift them slightly; it's not a lot of augmentation, but it's something. Then we do our random horizontal flips, and then we'll use that RandErase thing we created earlier, which was just something I was experimenting with.

So now we can use the batch transform callback, using tfm_batch and passing in those transforms. These are torchvision transforms (this capital T is torchvision.transforms), and because they're all nn.Modules, you can pass them to nn.Sequential to have each of them called one at a time, in a row. There's nothing magic about this; it's just doing function composition, and we could easily create our own. In fact there's also transforms.Compose, which does the same thing. Yeah, I was going to say, we've got a fastcore compose which, as you can see, basically just says "for f in funcs: x = f(x)". I think torchvision's Compose might be the older way to do it; I have a feeling the nn.Sequential approach is maybe considered better now because it's scriptable, though I'm not promising that. Either way, it does basically the same thing.
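A sketch of that batch-level augmentation pipeline. RandErase is the random-erasing module built earlier in the course (so its presence here is an assumption), and the padding and crop settings are illustrative:

```python
import torch.nn as nn
import torchvision.transforms as T

aug_tfms = nn.Sequential(
    T.Pad(4, padding_mode='reflect'),   # small border so the crop can shift the image a few pixels
    T.RandomCrop(64),
    T.RandomHorizontalFlip(),
    # RandErase(),                      # the course's own random-erasing module
)

# the same thing as plain function composition:
def compose_tfms(x, funcs):
    for f in funcs: x = f(x)
    return x
```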
Okay, so now we can create a model as usual. I basically copied get_dropmodel from our earlier Fashion MNIST work, starting with a kernel-size-5 convolution and then a bunch of res blocks; all stuff we're used to seeing. We can take a look, and as quite often seems to be the case, we accidentally end up with no random erasing visible, so let's run it again. It really doesn't want to do random erasing... here we go, now we can see it. There's this very small border, which you can hardly see sometimes, and a bit of random erasing, and it's been done to the whole batch: all of the batch is being transformed, or augmented, in the same way. Which is kind of okay; it's certainly faster. It can be a bit of a problem if you have one batch that has lots and lots of augmentation done to it and becomes really hard to recognize, because that could make the loss very large for that batch, and if you've been training for ages, it could jump you out of the smooth part of the loss surface. That's the one downside. So I'm not going to say it's always a good idea to do augmentation at the batch level, but it can certainly speed things up a lot if you don't have heaps of CPUs.

All right, so we can use that summary thing we created; there's our model. And because we're doubling the number of channels as we're decreasing the grid size, our number of megaflops per layer is constant, so that's a pretty good sign that we're using compute throughout.
So then we can train it with AdamW, mixed precision, and our augmentations. I did the learning rate finder, and got to nearly 60%, 59%, and I have to admit this took quite a while to get close to 60%. You can see the training accuracy is already up to 91, so we're kind of on the verge of overfitting.

Okay, so then I thought: how do we do better? I wanted a sense of how much better we could get, and I tend to look at Papers with Code, which is a site that shows papers with their code and also how good the results were. So this is image classification on Tiny ImageNet, and at first I was pretty disheartened to see all these 90-plus results. But as I looked at the papers I realized a few things. The first is that these ticks represent extra training data, so those are actually pre-trained models that are only fine-tuned on Tiny ImageNet, which is a total cheat. Then I looked more closely at this one, and it's also using pre-trained data, so Papers with Code is actually incorrect there. The first result I could clearly replicate and make sense of was this one, so the highest one I'm confident of is this 72. So then I wanted to get a sense of how much work there is to get from, say, 60 to 70, and how good this is. I opened up the paper; it turns out to be about a new type of mixup data augmentation, with the normal kind of mixup and their special kind, and on a ResNet-18 they're getting around 63, 64, 65 with various types of mixup, and around 64 or 65 for their special one. If they use much bigger models than we're using, they can get up to 66-ish. So that made me think: okay, this classifier is not bad, but there's clearly room to improve it. And I can't help myself, I always have to try to do better.
So this is a good opportunity to learn about a trick that's used in real ResNets, which is that in a real ResNet we don't just say how many filters (or channels, or activations) there are per layer and then go through doing a stride-2 conv each time. Instead, you also say the number of res blocks per downsampling stage. So this would say: do three res blocks and then downsample, or do three res blocks the last of which is the downsample, then two res blocks with a downsample, then two res blocks with a downsample, and so on. This still has a total of five downsamples, but rather than having five res blocks it's going to have about nine, so it's nearly twice as deep. The way we do that is we just replace the places that said ResBlock with res_blocks, which is just a Sequential that goes through the requested number of blocks and creates a ResBlock for each. You can do it a couple of ways; in this case I said if it's the last one in the group then make it stride 2, otherwise stride 1, so it downsamples at the end of each set of res blocks. So that's the only thing I changed: I changed ResBlock to res_blocks and passed in the number of blocks per group.
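A minimal sketch of that helper, assuming the ResBlock module defined in earlier lessons (its exact signature here is an assumption):

```python
import torch.nn as nn

def res_blocks(n_bk, ni, nf, stride=1, ks=3, act=nn.ReLU, norm=None):
    "A group of n_bk ResBlocks; only the last block in the group changes stride/resolution."
    return nn.Sequential(*[
        ResBlock(ni if i == 0 else nf, nf,
                 stride=stride if i == n_bk - 1 else 1,   # downsample at the end of the group
                 ks=ks, act=act, norm=norm)
        for i in range(n_bk)])

# e.g. a model might use groups like (3, 2, 2, 1, 1) blocks across its five
# downsampling stages instead of a single ResBlock per stage (illustrative numbers).
```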
So the number of megaflops is now more than double, which should give it more opportunity to learn stuff, and also more opportunity to overfit. Again we do our LR find, and let's do 25 epochs; I didn't actually add more augmentation. That got up to nearly 62%, so that was a good improvement. Interestingly, it's not overfitting more; if anything, less, so there's something about its ability to actually learn this which is slowing it down, or something.
So I thought it'd be nice to train it for longer, and to do that I decided to add more augmentation, using something called TrivialAugment, which is not a very well known approach but deserves to be. It comes from Frank Hutter's lab; Frank Hutter is somebody who consistently creates extremely practical, useful improvements, with much less of the nonsense we often see from some of the huge, well-funded labs. This one is a bit of a reaction to some previous approaches such as AutoAugment and RandAugment (they might both have come from Google Brain, I'm not quite sure), which used many thousands of TPU hours to optimize how each set of images is augmented. What these folks said is: what if we don't do that, and instead just randomly pick a different augmentation for each image? And that's what they did; their Algorithm 1 is literally "pick an augmentation, pick an amount, apply it". I feel like they're almost trying to make a point by writing that algorithm out. And they basically find this is at least as good as, and often better than, the incredibly resource-intensive approaches, which also tend to require a different version for every dataset, which is why they describe theirs as tuning-free. Rather nicely, and surprisingly to me, it's actually built into PyTorch: if we go to PyTorch's website and look at TrivialAugmentWide, they show some examples of it.
We can create our own as well. Now, the thing is, I found that doing this at a batch level worked poorly, and I think the reason is what I described earlier: sometimes it picks a really challenging augmentation that totally messes up the loss, and if every single image in the batch is like that, it'll shoot you off into some distant part of the weight space. Which is a good excuse for me to show how to do augmentations at a per-item level.

Now, these transforms, or some of them, require a PIL image (the Python Imaging Library image), not a tensor. So I had to change things around: we import Image from PIL, and we change our tfm_x so that we do the augmentations in there instead, for the training set only. We use one function for both, passing in a flag saying whether you want to do augmentations: for the training set we pass aug=True, and for the validation set we don't. Image.open is how you create a PIL image object; then, if we want augmentations, we do those augmentations, and then convert the image into a tensor (torchvision has a to_tensor we can call) and normalize it. Actually I decided just to use torchvision's Normalize; either way is fine, both work well. And then, again only if you want augmentation, do your RandErase. If you remember, our RandErase was designed to use (0,1)-distributed Gaussian noise, so you want it to happen after normalization, which is why it comes last. So now we don't need the batch transform thing at all; we're doing it all directly in the dataset. You can see you can do data augmentation in very simple ways with almost no framework help; in fact we're really not getting anything from a framework here, it's just this little TfmDS we made. So now we just pass that into get_dls for our data loaders, and we don't need any augmentation callback.
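A sketch of that per-item transform. TrivialAugmentWide is a real torchvision transform; RandErase is the course's own module, so it's shown commented out as an assumption, and the aug flag mirrors the train/validation split described above:

```python
from functools import partial
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image

aug = T.TrivialAugmentWide()   # picks one augmentation and one magnitude per image

def tfmx(f, augment=False):
    img = Image.open(f).convert('RGB')
    if augment: img = aug(img)                 # TrivialAugment operates on the PIL image
    x = (TF.to_tensor(img) - xmean) / xstd     # to [0, 1] tensor, then normalize
    # if augment: x = rand_erase(x)            # RandErase goes after normalization
    return x

train_ds = TfmDS(tds, partial(tfmx, augment=True), tfmy)   # training set augments
valid_ds = TfmDS(vds, tfmx, tfmy)                          # validation set doesn't
```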
All right, so now we can keep improving things by using something called pre-activation ResNets. If we go back to our original ResNet, we have this conv block which consists of two convolutions in a row, the second one with no activation. To remind you what conv is: first we do a conv, then optionally a normalization, then optionally our activation function. So what this is saying is: convolution, norm, activation, convolution, norm (that's what self.convs is), and then there's the identity path, which does nothing at all if there's no downsampling or change of channels, and then we apply the final activation function to the whole thing.

That's how the original res block was designed, which in our case was a bit of an accident, because to be honest, when I wrote it I didn't bother looking at the paper; I just did whatever seemed reasonable in my head. But looking into it further, there's a slightly later paper by the same author as the ResNet paper, Kaiming He, and He et al. drew this version here on the left: as you can see it's conv, norm, relu, conv, norm, add, relu. He basically pointed out that maybe that's not great, because the relu is applied after the addition, so there isn't actually an identity path at all; wouldn't it be nice to have a pure identity path? To do that, he proposed reordering things to go norm, relu, conv, norm, relu, conv, add. This is called a pre-act, or pre-activation, res block.

So that means I had to redefine conv to do norm, then act, and then conv, so my Sequential now has the activation in both places; and then of course there's no activation happening in the res block itself, because it's all happening in the convs. Yeah, makes sense. Cool. So this is now exactly the same as before, except we need an activation and a batch norm after all those blocks, because previously the network finished with a norm and activation and now each block starts with them instead, so we have to put those at the end. It also means we can't start with a res block any more, because if we started with a res block it would have an activation function right at the start, which would throw away half of our data, and that would be a bad idea. So you've got to be a bit careful with some of the details.
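A sketch of the pre-activation conv, reusing the signature of the earlier conv helper (so the exact arguments are an assumption):

```python
import torch.nn as nn

def conv(ni, nf, ks=3, stride=1, act=nn.ReLU, norm=None, bias=True):
    "Pre-activation ordering: norm -> act -> conv, instead of conv -> norm -> act."
    layers = []
    if norm: layers.append(norm(ni))   # note: normalizes the *input* channels now
    if act:  layers.append(act())
    layers.append(nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias))
    return nn.Sequential(*layers)

# The network then has to finish with its own norm + activation (they no longer come at the
# end of each block), and must not *start* with a pre-act ResBlock, since the very first
# activation would throw away half the input.
```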
Now you can see that each image is getting its own augmentation: this one's been sheared (it looks like it's a door or something; it's really hard to tell what it is), this one's been shifted, it looks like this one's also been sheared, and you can see they've got different amounts of random erasing on them. So I thought I'd try training that for 50 epochs, and that got us to 65%, which is nearly as good as the normal mixup results that paper was getting even on a ResNet-50, so this is looking really good.

I won't spend time on this, but I'll just mention I was kind of curious: one thing I should also mention is that they trained all of theirs for 400 epochs, so I was curious what would happen if we trained a bit longer. I wasn't patient enough to train for 400 epochs, but I thought I could do 200, so I just duplicated that last run and made it 200 epochs, and the result is better than any of their non-special mixups. So I think it just goes to show you can get genuinely state-of-the-art results this way. It would be interesting to try their special mixup as well and see if we can match their results there, but we've built all of this from scratch. We didn't do the data augmentation from scratch, because it's not very interesting, but other than that, I think that's really cool.
So I know you did some other experiments with the pre-activation? Oh right, yeah. When I saw the pre-activation success I was quite enthusiastic about it, so I actually thought maybe I should go back and use it everywhere. But weirdly enough (well, I think it's weird), it was worse for Fashion MNIST, and worse with less data augmentation. Maybe it's not that weird, because when He et al. introduced it they said this is for training deeper models: there's a more pure identity path, and that should let the gradients flow through more easily, so there should be a smoother loss surface. So I guess it makes sense that you don't really see the benefits on less deep models.

The bit I'm surprised by... could you elaborate? It seems like that justification should be true for smaller models too, right? Well, yeah, it does, but smaller models are going to have a less bumpy loss surface anyway; they've just got fewer dimensions to be bumpy in, and more importantly they're less deep, so there's less room for gradients to explode exponentially, so they're not as sensitive. I can see why pre-activation doesn't necessarily help as much, but I don't have any idea why it was worse, and it was quite consistently worse. Yeah, I find that quite interesting too. Yeah, it's quite curious. And it's interesting that when we do these experiments on things that nowadays are considered pretty fundamental and foundational, you discover things all the time that nobody seems to have noticed or written about. There's plenty of room, as a more experimental researcher, to do experiments, go "oh, that's interesting", and then try to figure out what's going on. I think a lot of researchers go in the opposite direction: they start with theoretical assumptions and then test them. When I think about it, I feel like maybe a lot of the more successful folks, in terms of people who build stuff that actually gets used, are more experiment-first.
Shall we have a five-minute break, since we're kind of on the hour? Sure.

All right, so let's now look at notebook 25, the super-resolution one. I've just copied a few things from the previous notebook: some transforms, our datasets, our denorm, and our tfm_batch and tfm_x. Let me check where we're using tfm_batch... we're not even using tfm_batch, so let's get rid of that, because it's just confusing.
Okay, so let's figure out what we're doing here. We've got our two datasets. The goal of this notebook is super-resolution, not classification, so let's talk about what that means. The independent variable will be the image scaled down to 32 by 32 pixels, and the dependent variable will be the original image. And to do a random crop within a padded image, plus random flips, both the independent and the dependent variable need to have had exactly the same random cropping and exactly the same flipping; otherwise you can't say "this is how you do super-resolution to go from the 32 by 32 to the 64 by 64", because the target might be flipped around or moved around relative to the input. So for this kind of image-reconstruction task, it's important to make sure your augmentation is done the same way on the independent and the dependent variable, and that's why we've put it into our dataset. This is something people often get confused about and don't know how to do, but it's actually pretty straightforward if we do it this way: we just put it straight in the dataset, and it doesn't require any framework fanciness.

Then what I did is add random erasing just to the training set, and the reason is that I wanted to make the super-resolution task a bit more difficult: sometimes it doesn't just do super-resolution, it also has to replace some of the deleted pixels with proper pixels. That gives it a little bit more to do, which can be quite helpful; it's a data augmentation technique, and it also gives the model more of an opportunity to learn what the pictures really look like. So with that in place, both datasets do the padding, random cropping, and flipping; the training set also adds random erasing; and then we create data loaders from those.
Would it make sense to use TrivialAugment here? Maybe, yeah. I don't particularly see a reason not to, or rather, only if you found that overfitting was a problem; and if you did do it, you would apply it to both the independent and the dependent variables.

So here you can see an example: the independent variables (in this case all of them, actually) have some random erasing, and the dependent doesn't, so the model has to figure out how to replace that region. You can also see that this one is very blocky and this one is less blocky; that's because this one has been taken down to 32 by 32 pixels and this one is still at 64 by 64. In fact, once you go down that far the cat has lost its eyes entirely, so it's going to be quite challenging. Super-resolution is quite a good task for getting a model to learn what pictures look like, because it has to figure out how to draw an eye, how to draw a cat's whiskers, and things like that. Were you going to say something, Jono? Oh, I was just going to point out that the datasets are also simpler, because you don't have to load the labels, so there's no difference between the train and the validation dataset now; it's just finding the images. Good point.
Yeah, because the label, the dependent variable, is actually just the picture. Okay, so because TfmDS has a tfm_x which is only applied to the independent variable, the independent variable gets this pair of operations applied to it: resize to 32 by 32, and then interpolate back up. What that actually does is end up with a 64 by 64 image, but one where the pixels are all doubled up. So it's still doing super-resolution, but rather than literally going from 32 by 32 to 64 by 64, it's going from a 64 by 64 image where every pixel is a 2-by-2 block. It's just a little bit easier that way: we could certainly create a U-Net that goes from 32 to 64, but if the input and output images are the same size it makes the code a little bit simpler. I originally started doing it without this interpolate step, then decided it was getting a bit confusing, and there's no reason not to do it this way, frankly. Okay, so that's our task.
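A sketch of that independent-variable transform, assuming the x coming out of the shared crop/flip pipeline is a 3x64x64 tensor (the interpolation modes are illustrative choices):

```python
import torch.nn.functional as F

def tfmx_lowres(x):
    "Downsample to 32x32, then blow back up to 64x64 so input and output sizes match."
    x = F.interpolate(x[None], size=32, mode='bilinear', align_corners=False)  # shrink
    x = F.interpolate(x, scale_factor=2, mode='nearest')   # every pixel becomes a 2x2 block
    return x[0]
```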
The idea is that if it does a good job of this, you could pass 64 by 64 images into it and hopefully it might turn them into 128 by 128 images. Particularly if you trained it on a few different resolutions, you'd expect it to get pretty good at resizing things to a bunch of different resolutions, and you could even call it multiple times. But anyway, for this I was just doing it to demonstrate. We have, in previous courses, trained bigger ones for longer with larger images, and one of the interesting things is that they tend not only to do super-resolution but often to make the images look better, because the pixels the model fills in get filled in with what that image looks like on average, which tends to average out imperfections. So often these super-resolution models actually improve image quality as well, funnily enough.
Okay, so let's first consider the dumb way to do things. We've seen a kind of dumb way before, which is an autoencoder, so go in with low expectations here, because the autoencoder we did before was so bad it actually inspired us to create the Learner, if you remember; that was back in notebook eight. Basically we're going to have a model that looks a lot like previous models: it starts with a res block of kernel size 5, and then it's got a bunch of res blocks of stride 2. But then we're going to have an equal number of up blocks, and what an up block does is, sequentially, first an UpsamplingNearest2d, which is actually identical to the pixel-doubling interpolate we just saw (it just doubles all the pixels), and then a res block. So it's basically a res block with a stride of a half, if you like: it's undoing a stride-2, upsampling rather than downsampling. Then we have an extra res block at the end to get it down to three channels, which is what we need. Okay, so we can do our learning rate finder on that.
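A minimal sketch of such an up block, again assuming the course's ResBlock module and illustrative channel sizes:

```python
import torch.nn as nn

def up_block(ni, nf, ks=3, act=nn.ReLU, norm=None):
    "Undo a stride-2 stage: double the resolution, then apply a ResBlock."
    return nn.Sequential(
        nn.UpsamplingNearest2d(scale_factor=2),
        ResBlock(ni, nf, ks=ks, act=act, norm=norm))

# e.g. an autoencoder-style super-res model (channel counts illustrative):
#   down: ResBlock(3,32,ks=5), ResBlock(32,64,stride=2), ResBlock(64,128,stride=2)
#   up:   up_block(128,64),    up_block(64,32),          ResBlock(32,3)
```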
And I just train it pretty briefly, for five epochs. So this model is basically trying to take the image we start with, really squeeze it into a small representation, and then try to bring that small representation back up to the full super-res output? Yeah, exactly right, Tanishk. And we could have done it without any of the stride-2 layers; I guess we could have just had a whole bunch of stride-1 layers. There are a few reasons not to do it that way, though. One is obviously that the computation requirements are very high, because the convolution has to scan over the image, and if you keep it at 64 by 64 that's a lot of scanning. Another is that you're never forcing it to learn higher-level abstractions by making it use more channels on a smaller grid size to represent things. It's the same reason that in classifiers we don't leave everything at stride 1 the whole time: you end up with something that's inefficient and generally not as good. Exactly, yeah. Thanks for clarifying, Tanishk.

Okay, so the loss goes down, and the loss function I'm using here is just MSE: how similar is each pixel to the pixel it's meant to be. Then I can call capture_preds to get the predictions, the targets, and the inputs (or probabilities, targets, and inputs; I can't quite remember now). So here are our input images, and, oh dear, here are our predicted images: pretty terrible. So why is that? Well, basically it's the problem we had with our earlier autoencoder: it's really difficult to go from a 2-by-2 or 4-by-4 or whatever representation back up to a 64 by 64 image. We're asking it to do something that's just really challenging, and that would require a much bigger model trained for a much longer time. I'm sure it's possible, and in fact latent diffusion, as we've talked about, has a model that does pretty much exactly that. But in our case there's no need to make it so complicated.
We can actually do something dramatically easier, which is to create a U-Net. U-Nets were originally developed in 2015, originally for medical imaging, but they've been used very widely since. I was involved in medical imaging at the time they came out, and they quite quickly got recognized there; they took a little longer to get recognized elsewhere, but nowadays they're pretty universal, and they are used in Stable Diffusion. Some of the details of the original paper don't matter here, so let's focus on the broad idea. This thing here is what we'll call the downsampling path: in this case they started with 572 by 572, one-channel images, and then, as we've seen, they took them down to 284 by 284 by 128, then 140 by 140 by 256, then 68 by 68 by 512, then 32 by 32 by 1024. That's the downsampling path. Then the upsampling path is exactly what we've just seen: we upsample and apply some convs (in the original they didn't use res blocks, just convs, but the idea is the same).

The trick is these extra arrows across the middle, labelled "copy and crop". During the upsampling, we've got a 512-channel thing here; we can upsample it to another 512-channel thing, put it through a conv to make it a 256-channel thing, and then copy across the activations from the corresponding point in the downsampling path. They actually do things in a slightly odd way: on the way down they had 136 by 136, and over here they have 104 by 104, so they crop out the centre bit. That's just because of the slightly unusual way they did it (they weren't padding their convolutions), and nowadays we don't have to worry about that cropping. So what we do is literally copy over these activations and then either concatenate or add. You can see in this case they're concatenating (see how there's the white bit and the blue bit), so they've concatenated the two lots together. I think what they did here is go from a 52 by 52 by 512 up to a 104 by 104 by 256 (that's what this little blue rectangle is), then copy out the cropped 104 by 104 by 256 from the downsampling path and put the two together to get a 104 by 104 by 512. So half of these activations are from the upsampling path and half are from the downsampling path, from earlier in the whole process.

It might be easiest to understand why that's interesting when we get all the way back up to the top, where we've got this 392 by 392 thing: the thing we're copying across now is just two convolutions away from the original image. For super-resolution, for example, we want the output to look a lot like the original image, so in this case we actually get an almost-complete copy of something very much like the original image that we can include in these final convolutions. And ditto further down: here we have the somewhat-downsampled version we can use, and here the more-downsampled version.
So yeah, that's how the U-Net works. Do either of you have anything to add: things you found helpful for understanding it, or anything surprising? I guess it's a fascinating thing that these days a lot of people tend to just add: the outputs from the down layer are the same shape as the inputs to the corresponding up block, so they just add them together. Yeah, and particularly for super-resolution, adding might make more sense than concatenating, because you're literally saying this little 2-by-2 bit is basically the right pixels; it just has to be slightly modified around the edges. It also makes me think of a boosting sort of thing: if a lot of the information from the original image is being passed all the way across at that highest skip connection, then the rest of the network can effectively be producing an update to that, rather than having to recreate the whole image. Or to put it another way, it's like a ResNet, but the skip connections jump from the start to the end, and from a bit after the start to a bit before the end. And I guess a ResNet is a bit like boosting too. Yeah, I was going to say basically the same thing: compared to the autoencoder, where the results we saw were, I guess, even worse than the original image, here the worst it can really do is basically reproduce the original image, so it's a similar intuition to the ResNet and how that works. Yeah, although it could be worse, if these convs at the end are incapable of undoing what the earlier convs did, which is one argument for why maybe there should also be a connection from the input right over to the output, and maybe a few more convs after that; that's something I'm kind of interested in and not enough people do, in my opinion. Another thing to consider is that they've only got two convs down here at the bottom, but at this point you have the benefit of only being 28 by 28, so why not do more computation at this point? So there are a couple of things that people sometimes consider, but maybe not enough.
um uh so let me try to remember what i did um so in my unit here 01:06:23.360 |
which is a list of res blocks now a module list is just like a sequential except it doesn't 01:06:33.520 |
actually do anything so then in the forward we have to go through the down path and x equals lx 01:06:41.760 |
each time so it's basically yeah a sequential that doesn't actually do anything um and so the up 01:06:49.200 |
path is exactly the same as we saw before it's a bunch of up blocks um and then like we saw before 01:06:55.680 |
the final one's going to have to go to three channel um but now for our forward what we're 01:07:06.480 |
since we're going to be copying this over here and copying this over here we have to save it 01:07:14.880 |
during the down sampling path so we're going to save it in a something called layers 01:07:24.800 |
so i actually decided to do the little trick i mentioned which is to save the very first input 01:07:30.000 |
um so i saved the very first input i then put it through the very first res block 01:07:36.880 |
and then we go through each in the downward path 01:07:43.200 |
there's actually no need at all for there to be an i, l here it doesn't have to be enumerated because 01:07:51.840 |
we don't use i okay so we go through the downward path so for this l for layer so for each layer 01:07:58.160 |
in the downward path append the activations so that again as we go through each one we're going 01:08:04.880 |
to be able to copy them over by saving them for later and then call the layer okay so how many 01:08:12.560 |
layers have we got there's n layers that we stored away so now we're going to go through the up 01:08:18.480 |
sampling path and again we're going to call each one but before we do we're going to actually 01:08:23.760 |
do the thing that jono mentioned which is to add rather than concatenate unless 01:08:29.280 |
this is the very first layer because for the very first up sampling layer there's nothing to copy 01:08:34.160 |
right so unless this is the very first up sampling layer we just add the saved activations 01:08:43.440 |
and then call the layer um and then right at the very end we'll add back the very first layer 01:08:51.760 |
and then pass it through the very last res block 01:08:57.760 |
all right maybe that last one should be concatenated i'm not sure anyhow um this is what i did um 01:09:12.960 |
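For reference, here's a minimal self-contained PyTorch sketch of the forward pass being described: a ModuleList down path that stashes its activations, an up path that adds (rather than concatenates) them back in, and the raw input saved at the start and added back right at the end. The ResBlock and up block below are simplified stand-ins for the notebook's own blocks, and the channel sizes are illustrative, not the lesson's exact ones.

    import torch
    from torch import nn

    def conv(ni, nf, ks=3, stride=1, act=True):
        layers = [nn.Conv2d(ni, nf, ks, stride=stride, padding=ks // 2)]
        if act: layers.append(nn.SiLU())
        return nn.Sequential(*layers)

    class ResBlock(nn.Module):
        def __init__(self, ni, nf, stride=1):
            super().__init__()
            self.convs = nn.Sequential(conv(ni, nf, stride=stride), conv(nf, nf, act=False))
            self.idconv = nn.Identity() if ni == nf and stride == 1 else nn.Conv2d(ni, nf, 1, stride=stride)
            self.act = nn.SiLU()
        def forward(self, x): return self.act(self.convs(x) + self.idconv(x))

    def up_block(ni, nf):
        # nearest-neighbour upsample followed by a res block
        return nn.Sequential(nn.Upsample(scale_factor=2), ResBlock(ni, nf))

    class TinyUnet(nn.Module):
        def __init__(self, nfs=(32, 64, 128, 256)):
            super().__init__()
            self.start = ResBlock(3, nfs[0])
            self.dn = nn.ModuleList([ResBlock(nfs[i], nfs[i + 1], stride=2) for i in range(len(nfs) - 1)])
            ups = [up_block(nfs[i + 1], nfs[i]) for i in reversed(range(1, len(nfs) - 1))]
            ups.append(up_block(nfs[1], 3))            # the final up block goes back to 3 channels
            self.up = nn.ModuleList(ups)
            self.end = ResBlock(3, 3)

        def forward(self, x):
            inp = x                          # "the little trick": save the raw input
            x = self.start(x)
            saved = []
            for l in self.dn:                # down path: stash activations for the skips
                saved.append(x)
                x = l(x)
            for i, l in enumerate(self.up):  # up path
                if i != 0:
                    x = x + saved.pop()      # add (not concat) the matching saved activation
                x = l(x)
            # (the start block's output isn't used as a skip here; the raw input plays that role)
            return self.end(x + inp)         # add back the raw input, then the final res block

    x = torch.randn(2, 3, 64, 64)
    print(TinyUnet()(x).shape)               # torch.Size([2, 3, 64, 64])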
now the next thing that i wondered about was like how to um initialize this and basically what i 01:09:22.480 |
wanted to do is initialize it so that when it's untrained 01:09:28.240 |
the output of the model would be identical to the input because like a reasonable starting point 01:09:34.480 |
for what this should look like following super resolution 01:09:40.000 |
is the input itself you know that's a reasonable starting point so um i just created this little zero weights 01:09:46.960 |
thing which zeros out the weights and biases of a layer right so i created the model and then i said 01:09:54.720 |
okay um let's look at the very end of the up sampling path and we'll call that the last res block 01:10:06.240 |
and so let's zero out the very last convolutions 01:10:11.840 |
and also the idconv and so that means that whatever it does for all this at the very end 01:10:22.720 |
it's going to have um nothing in there this will be zero so that means that this will be 01:10:32.240 |
equal to layer zero um and then that means we also want to make sure that this doesn't change 01:10:38.400 |
anything so then we can just zero out the weights there um that's probably not quite right is it 01:10:47.760 |
i guess i should have actually set those to like an identity matrix 01:10:53.200 |
maybe i'll try to do that later um but at least it's something that would be very easy for it to learn 01:10:59.920 |
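As a sketch of that initialisation, reusing the TinyUnet sketch above: zero out the convolutions (and identity conv) of the res block inside the final up block, so its output is zero and the x + input addition just passes the input through. As noted, zeroing the very last block's convs doesn't make it an exact identity, so this is only approximately output-equals-input.

    def zero_weights(m):
        # zero the weights and biases of a single layer, if it has them
        if getattr(m, 'weight', None) is not None: nn.init.zeros_(m.weight)
        if getattr(m, 'bias', None) is not None: nn.init.zeros_(m.bias)

    model = TinyUnet()
    last_res = model.up[-1][1]            # the res block inside the final up block
    last_res.convs.apply(zero_weights)    # zero its convolutions...
    zero_weights(last_res.idconv)         # ...and its identity-path conv
    model.end.convs.apply(zero_weights)   # the final block: zeroed, though an identity would be better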
i have a question jeremy yeah this zero weights i see a lot of people do a thing where they 01:11:05.360 |
instead like multiply by one e minus three or one e minus four to make the weights really small 01:11:11.360 |
but not completely zero and i don't have a good intuition whether it's like you know in some sense 01:11:17.600 |
having everything set to zero fires off some warnings that maybe this is going to be like 01:11:22.320 |
perfectly balanced on some saddle point or it's not going to have any signal to work with 01:11:26.400 |
or whether very small but not quite zero random weights might be better yeah do you have an 01:11:30.560 |
intuition about that i think so or not so much intuition but more empirical like or both um i don't i don't 01:11:38.320 |
think it's an issue um and i think it comes from like a lot of people's phd supervisors and stuff 01:11:43.200 |
you know come from back in an era when they were doing like linear regression with one layer or 01:11:47.440 |
whatever and in those cases yeah if all the weights are the same then no learning can happen because 01:11:54.160 |
every weight update is identical but in this case all the previous weights are different 01:11:59.040 |
so there's they all have different gradients and there's definitely nothing to worry about 01:12:04.160 |
i mean multiplying it by a small number would work too like it's not a problem but um yeah 01:12:12.400 |
setting it to zeros and honestly i have to stop myself i mean not that it's a problem but 01:12:20.000 |
i just always have this natural inclination to not want to set them to zeros because of years 01:12:26.320 |
of being told not to but there's no reason that should be a problem um all right i was just 01:12:37.680 |
going to say again that U-Net code is very concise and it's very very interesting to see 01:12:45.920 |
the basic ideas you know very simple and oh yeah to see that i guess yeah yeah it's helpful i think 01:12:52.800 |
to just get it into a little bit of code isn't it yeah thanks um that's very simple code too 01:12:59.760 |
okay so we do an lr find and then we train and you can see that previously our loss even after 01:13:11.440 |
five epochs was 0.207 and in this case our loss after one epoch is 0.086 so it's obviously 01:13:21.840 |
much easier and we end up at 0.073 okay so we can take a look there's our inputs 01:13:32.880 |
and there's our outputs so it's actually better rather than dramatically worse now so that's good 01:13:39.760 |
um yeah some of it's actually not bad at all i would say 01:13:45.600 |
this car definitely looks like i think it's like a little over smoothed you know i think you could 01:13:56.240 |
say so if we look at the other guy's eyes the kid's eyes still aren't great like in the original he's 01:14:02.160 |
actually got proper pupils um so yeah it's definitely not recreated the original but 01:14:09.520 |
you know given limited compute and limited data like the basic idea is not bad um 01:14:18.880 |
i do worry that the poor koala like it it didn't have eyes here but like it ought to have known 01:14:26.240 |
there should be eyes in a sense and it didn't create any and maybe it should have done a better 01:14:30.800 |
job on the eyes so um my feeling is um and this is pretty common way of thinking about this is 01:14:38.320 |
that when you use mean squared error mse as your loss function on these kinds of models 01:14:42.960 |
you tend to get rather blurry results because if the model's not sure what to do it's just 01:14:49.200 |
going to predict kind of the average you know um so one good way to fix that is to use perceptual 01:14:56.880 |
loss and um i think it was johnno who taught us about perceptual loss wasn't it when we did 01:15:03.440 |
the style transfer stuff um so perceptual loss is this idea that we could look it's kind of similar 01:15:11.120 |
as well to the fid idea um we could look at some intermediate layer of a pre-trained model 01:15:20.080 |
and try to make sure that our output images have the same features as the real images and in this 01:15:31.680 |
case it ought to be saying like the real image you know if we went to kind of midway through a resnet 01:15:37.440 |
it should be saying like there should be an eye here you know and in this case this would not 01:15:42.720 |
represent an eye very well so that would should give it some useful feedback to improve how it 01:15:49.440 |
draws an eye here um so that's the basic idea um so to do perceptual loss we need a classifier model 01:15:57.600 |
so i just used the little 25 epoch one i don't know why i guess maybe that's 01:16:03.200 |
all i had trained at that time um so let's use the little 25 epoch model um 01:16:13.520 |
so then um yeah just grab a batch of validation set and then we can just try it out by 01:16:20.320 |
calling the classifier model um and here i'm doing it 01:16:27.680 |
in fp16 just keeping my memory use down um um i don't think this dot half would be necessary 01:16:40.000 |
since i've got autocast anyway never mind um okay this is the same code we had before for the syn 01:16:46.480 |
sets um so here are our images um so what we've got here 01:17:07.120 |
huh i'm just looking at some of them they're a bit weird aren't they i mean koalas are 01:17:11.280 |
fine you know i wouldn't have picked this as a parking meter i wouldn't have picked this as 01:17:15.200 |
a bow tie um so yeah so basically what this is doing here is it's um 01:17:20.960 |
showing us the predictions so the predictions are not amazing um trolley bus that looks right 01:17:35.760 |
um this is weird it's called this one a neck brace and this one a basketball that looks 01:17:40.720 |
more like a neck brace the labrador retriever it's got right the tractor it's got right 01:17:44.240 |
centerpiece right mushrooms right those probably aren't punching bags okay so you know you can see 01:17:49.520 |
our classifier it's okay but it's not amazing i think this was one with like a 60% accuracy 01:17:54.240 |
um but the important thing is it's like it's got enough features to be able to like 01:18:00.640 |
do an okay job i have no idea what this is so i'm pretty sure it's not a goose 01:18:09.440 |
the model is very simple just a bunch of res blocks 01:18:18.800 |
um three four five and then at the end we've got our pooling flatten dropout linear batch norm 01:18:30.720 |
um so we don't need yeah so what we're going to do is just to keep things simple we're just 01:18:40.960 |
going to grab um i think the output at the end of res block three and so a simple way to do that is 01:18:52.080 |
we'll just go from index four to the end of the model and delete those layers so if we do that 01:18:59.600 |
and then look at the model again you can now see i've got zero one two three and that's it 01:19:09.920 |
so this model um is going to yeah return the kind of the activations after the fourth res block 01:19:20.720 |
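One way to do what's just been described, sketched; here pmodel stands for the pretrained classifier loaded earlier and is assumed to be an nn.Sequential of res blocks followed by the pooling/flatten/linear head (the name and the index-4 cut-off are assumptions based on the description).

    from copy import deepcopy

    cmodel = deepcopy(pmodel).eval()
    del cmodel[4:]                      # keep layers 0..3, drop everything from index 4 on

    with torch.no_grad():
        feats = cmodel(yb)              # yb: a batch of target images
    print(feats.shape)                  # e.g. torch.Size([1024, 256, 8, 8])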
um so for perceptual losses i think we talked about you could like pick a couple of different 01:19:27.680 |
places like there's various ways to do it this is just the simplest i didn't even have to use 01:19:32.960 |
hooks or anything we can just call c model and um in fact if we do it 01:19:38.960 |
um so just to take a look at this looks like and again we're going to use 01:19:44.800 |
mixed precision here um we can grab our y batch as before put it through our classifier model 01:19:56.320 |
and so now that we've done this this is now going to give us those intermediate level features 01:20:02.320 |
so the features what's the shape of them it's batch size 1024 by the number of channels of that layer 01:20:10.720 |
by the height and width of that layer so these are 8 by 8 by 256 features we're going to be using for 01:20:17.040 |
the perceptual loss um and so when i was doing this i kind of wanted to like check with things 01:20:23.600 |
were vaguely looking reasonable um so i would have expected that these features um from the 01:20:33.440 |
actual y should be similar to if i um use our model 01:20:48.240 |
so something then i did i thought okay if we if we took that model that we trained then we would 01:20:55.280 |
hope that the features were at least of the same sign um from you know from the result of the model 01:21:04.720 |
as they are in the real images um so this is just me comparing that and it's like oh yeah 01:21:10.320 |
they are generally the same sign so this is just little checks that i was doing along the way 01:21:14.640 |
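That sanity check might look something like this; xb and yb stand for a batch of degraded inputs and their targets, and model and cmodel for the super-res U-Net and the truncated classifier (all names assumed, and everything assumed to be on the same device).

    with torch.no_grad():
        same_sign = ((cmodel(model(xb)) > 0) == (cmodel(yb) > 0)).float().mean()
    print(same_sign)   # close to 1.0 means the feature signs mostly agree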
and then i also took a look at the mse loss along the way um yeah so there's no need to 01:21:22.800 |
keep all those in there it was just stuff i was kind of doing to like debug as i went well not 01:21:27.600 |
even debug to like identify any problems ahead of time um so now we can create 01:21:33.440 |
our loss function so our loss function is going to be the mse loss just like before between 01:21:41.920 |
the input and the target which is just what's being passed in here plus the mse loss 01:21:47.200 |
between the features we get out of c model for our output 01:21:54.880 |
and the features we get from the actual target image and so the features um we can calculate 01:22:07.840 |
for the target image now the target image we're not going to be modifying that at all so we do 01:22:15.360 |
that bit with no gradient um but we do want to be able to modify the thing that's generating our 01:22:22.080 |
input that's the model we're trying to actually optimize so we do have gradient for that so in 01:22:26.400 |
each case we're calling the classifier model one on the target and one on the input and so those 01:22:32.720 |
are giving us our features now then we add them together but they're not particularly similar 01:22:40.960 |
numerically like they're very different scales and we wouldn't want it to focus entirely on one or 01:22:47.200 |
the other so i just ran it um for epoch or two checked what the losses were looked like and i 01:22:53.200 |
noticed that the feature loss was about 10 times bigger so my very hacky way was just to divide 01:22:58.080 |
it by 10 um but honestly like that detail doesn't tend to matter very much in my opinion which 01:23:04.960 |
there's nothing wrong with doing it in a rather hacky way um there are papers which suggest more 01:23:13.120 |
elegant ways to handle it um which isn't a bad idea to save you a bit of time if you're doing 01:23:18.560 |
a lot of messing around with this jeremy i don't know if you know it but the um the new vae decoder 01:23:18.560 |
from stability ai for the stable diffusion auto encoder they trained it some with just 01:23:34.320 |
mean squared error and some with mean squared error combined with the perceptual loss 01:23:37.680 |
and they had a scaling factor of you know times 0.1 so exactly there you go so the answer is 0.1 01:23:44.960 |
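Putting those pieces together, a sketch of the combined loss described above: cmodel is the truncated classifier from before, the target's features are computed under no_grad since the target is never modified, and the feature term is scaled by the rough 0.1 factor just discussed.

    import torch
    import torch.nn.functional as F

    def comb_loss(inp, tgt):
        with torch.no_grad():
            tgt_feat = cmodel(tgt)      # features of the real image: no gradients needed
        inp_feat = cmodel(inp)          # features of the model output: gradients flow here
        return F.mse_loss(inp, tgt) + 0.1 * F.mse_loss(inp_feat, tgt_feat)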
that's that's the official answer and andrej karpathy says that the correct learning rate to use is 01:23:44.960 |
always 3e-4 so we're getting all this sorted out now that's good all right so for my U-Net we're 01:23:50.560 |
going to do the same stuff as before in terms of initializing it do our lr find train it for 20 epochs 01:24:07.280 |
and obviously the loss is not comparable because this is lost now incorporates the perceptual loss 01:24:11.760 |
as well and so this is one of the challenges with these things it's like is it better or worse well 01:24:16.400 |
we just tend to have to take a look and compare i guess and maybe i should copy over our previous 01:24:22.000 |
model's images so we can compare okay there's our inputs there's our outputs and yeah look he's got 01:24:31.760 |
pupils now which he didn't used to have koala still doesn't quite have eyeballs but i'd like 01:24:39.680 |
it's definitely less you know out of focus looking um there's some flipping that's going on 01:24:50.720 |
yeah so some of them are going to be flipped because this is copied from earlier 01:24:54.480 |
so yeah there's flipping and cropping going on so they won't be identical 01:25:05.520 |
yeah you can also see like the background like was all just blurred before where else now it's 01:25:11.200 |
got texture which if we look at the real the real has texture you know so yeah clearly the 01:25:20.320 |
perceptual loss has improved matters quite significantly there's an interesting thing 01:25:27.680 |
here which is that there's not really any metric we can use now right because if we did mean square 01:25:31.040 |
error the one that's trained with mean squared error would probably do better but visually it looks worse 01:25:35.920 |
yeah and if we use like an fid well that's based on the features of the pre-trained network so 01:25:40.400 |
that would probably be biased by the one that's trained using those features the perceptual loss 01:25:44.560 |
and so you get back to this very old school thing of like well actually how we're choosing is just 01:25:48.320 |
looking and evaluating right um and when you speak to someone like jason antic who's made a whole 01:25:53.040 |
career out of you know image restoration and super resolution and colorization that is like 01:25:58.160 |
a big part of his process even now is still like looking at a bunch of images to decide whether 01:26:04.720 |
something is better rather than relying on these yeah some phd student yelled at me on twitter a 01:26:10.560 |
few weeks ago for like saying like look at this cool thing our student made look how don't they 01:26:14.480 |
look better and he was like don't you know there's rigorous ways to measure these things this is not 01:26:19.120 |
a rigorous approach at all it's like phd students man they got all the answers can't have a human 01:26:27.200 |
looking at a picture and deciding if they like it or not that's insane well i'm a phd student i agree 01:26:27.200 |
though that we should be with me so yeah okay some phd students are better than others that's that's 01:26:38.080 |
fair enough um what's this oh right okay so talking of cheating let's do that um 01:26:57.360 |
so we're going to do something which is kind of fast ai's favorite trick and has been since we 01:27:07.920 |
first launched which is gradually unfreezing pre-trained networks um so in a sense it seems 01:27:16.960 |
a bit funny to initialize all of this downpath randomly because we already have a model that's 01:27:28.240 |
perfectly capable of doing something useful on tiny image net images which is this 01:27:36.160 |
so yeah what if we um took our U-Net right and for the model dot start which to remind you 01:27:51.360 |
is the res block right at the front why don't we use the actual 01:28:02.320 |
weights of the pre-trained model and then for each of the bits in the down sampling path 01:28:09.280 |
why don't we use the actual weights that we used from that as well and so this is a useful 01:28:16.000 |
way to understand how we can um copy over weights which is that any part of a module 01:28:28.160 |
an nn.Module is itself an nn.Module and an nn.Module has a state_dict which is a thing 01:28:35.520 |
you can then pass to load_state_dict to put it somewhere else so this is going to fill in the 01:28:40.560 |
whole res block called model.start with the whole res block which is pmodel[0] 01:28:46.400 |
so here's how we can copy across yeah that starting one and then all the down blocks are 01:28:52.320 |
going to have the rest of it so this is basically going to copy into our model rather than having 01:28:57.280 |
random weights we're going to have all the weights from our pre-trained model 01:29:01.680 |
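A sketch of that weight copying, assuming pmodel is the pretrained classifier and that its res blocks line up one-to-one with model.start and model.dn (the indices here are an assumption based on the description).

    model = TinyUnet()
    model.start.load_state_dict(pmodel[0].state_dict())        # the very first res block
    for i, dn_block in enumerate(model.dn):
        dn_block.load_state_dict(pmodel[i + 1].state_dict())    # each down block in turn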
um and then since they're they're good at doing something they're not good at doing super 01:29:11.040 |
resolution but they're good at doing something why don't we assume that they're good at doing 01:29:16.400 |
super resolution so turn off requires_grad and so what that means is if we now train it's not going to 01:29:23.600 |
update any of the parameters in the down blocks i guess i should have actually done model dot 01:29:28.880 |
start requires_grad false too now that i think about it um and so this is uh the classic fine 01:29:36.320 |
tune approach from fastai the library um we're going to do one epoch of just the up sampling path 01:29:49.120 |
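The freeze-then-unfreeze workflow being described might look like this (the training calls are omitted; how you fit is up to whatever training loop or Learner you're using).

    model.start.requires_grad_(False)
    model.dn.requires_grad_(False)      # freeze the copied-in down path
    # ... fit for one epoch: only the randomly initialised up path learns ...
    model.start.requires_grad_(True)
    model.dn.requires_grad_(True)       # then unfreeze everything
    # ... and fit for the full 20 epochs ...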
and that gets us to a loss of 0.255 now our um loss function hasn't changed so that's totally comparable 01:29:55.760 |
so previously our one epoch was 0.385 and in fact after one epoch with frozen weights for the down 01:30:04.960 |
path we've beaten this now this is in a sense totally cheating but in a sense it's totally 01:30:13.840 |
not it's totally cheating because the thing we're trying to do is to generate for the perceptual 01:30:22.240 |
loss intermediate layer activations which are the same as this and so we're literally using that to 01:30:32.880 |
create intermediate layer activations so obviously that's going to work but why is it okay to be 01:30:41.200 |
cheating well because that's actually what we want like to be able to do super resolution we need 01:30:47.360 |
something that can like recognize there's an eye here so we already have something that knows that 01:30:53.520 |
there's an eye there and in fact interestingly this thing trained a lot more quickly than this 01:31:00.800 |
thing and it turns out it's better at super resolution than that thing even though it wasn't 01:31:07.040 |
trained to do super resolution and I think that's because that the signal which is just like 01:31:11.840 |
what is this is a really simple signal to use so yeah so we do that and then we can basically 01:31:22.080 |
go through and set requires grad equals true again and so the basic idea being here that 01:31:26.960 |
yeah when you've got a bunch of random weights which is the whole up sampling path and a bunch 01:31:34.160 |
of pre-trained weights the down sampling path don't start then fine-tuning the whole thing 01:31:39.280 |
because at the start it's going to be crap you know so and so just train the random weights 01:31:44.960 |
for at least an epoch and then set everything to unfrozen and then we'll do our 20 epochs on the 01:31:52.160 |
whole thing and so we go from 0.255 to 0.249 0.207 0.198 so it's improved a lot so to verify with 01:32:08.240 |
using these weights and comparing that to the perceptual loss the perceptual loss is looking 01:32:15.600 |
at the upsampled data the super resolution images while we're also incorporating those weights 01:32:22.240 |
in the down sampling path and so that's looking at I guess the original 01:32:26.560 |
downgraded image right although we are just adding them so if you have zeros in the up sampling path then 01:32:34.160 |
it's going to be the same so it is very easy for it to get the correct activations in the up sampling 01:32:42.000 |
path and then yeah I mean then it's kind of a bit weird because it goes all the way back to the top 01:32:49.440 |
creates the image and then goes into the c model the classifier again 01:32:53.440 |
but I think it's going to create basically the same activations 01:32:59.520 |
it's a bit confusing and weird so yeah I mean it's not totally cheating but it's 01:33:07.520 |
it's certainly a much easier problem to solve yeah okay so let's get our 01:33:18.640 |
yeah so that's looking pretty impressive so the kid has a you know yeah definitely looks 01:33:30.400 |
pretty reasonable now car looks pretty reasonable we still don't have eyes for 01:33:38.880 |
the koala such as life but definitely the background textures look way better 01:33:43.120 |
the candy store looks much better than it did 01:33:48.640 |
medicine looks a lot better than it did so yeah it's really I think it looks great 01:33:59.360 |
okay so then we can get better still this is not part of the original unet but 01:34:04.800 |
you know making better models is often about like where can we squeeze in more computation 01:34:12.720 |
give it opportunities to do things and like there's nothing particularly that says that this 01:34:17.760 |
down sampling thing is exactly the right thing you need here right it's being used for two things 01:34:23.040 |
one is this conv and one is this conv but those are two different things and so it's kind of having 01:34:29.280 |
to like learn to squeeze both purposes into one thing so I had this idea probably I'm sure lots 01:34:35.520 |
of people have had this idea but whatever I had this idea which is why don't we put some res blocks 01:34:40.560 |
in here which I called cross connections or cross convs so I decided that a cross conv 01:34:49.040 |
is going to be just a res block followed by a conv and so the U-Net I just copied and pasted 01:34:58.800 |
but now as well as the downs I've also got crosses and so the crosses are cross convs 01:35:04.400 |
so now rather than just adding the layer I add the cross conv applied to the layer 01:35:14.640 |
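A sketch of the cross conv idea on top of the earlier TinyUnet sketch: each saved skip activation goes through a res block plus a conv before being added in the up path. The exact placement and channel choices here are assumptions.

    def cross_conv(nf):
        return nn.Sequential(ResBlock(nf, nf), conv(nf, nf))

    class CrossUnet(TinyUnet):
        def __init__(self, nfs=(32, 64, 128, 256)):
            super().__init__(nfs)
            # one cross conv per skip that gets added in the up path
            self.crosses = nn.ModuleList([cross_conv(nf) for nf in reversed(nfs[1:-1])])

        def forward(self, x):
            inp = x
            x = self.start(x)
            saved = []
            for l in self.dn:
                saved.append(x)
                x = l(x)
            for i, l in enumerate(self.up):
                if i != 0:
                    x = x + self.crosses[i - 1](saved.pop())   # process the skip, then add
                x = l(x)
            return self.end(x + inp)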
yeah I really should have added a cross conv for this one as well now that I think about it this 01:35:22.320 |
is probably the one that wants it the most oh well never mind another time um okay so now 01:35:29.840 |
yeah again we can definitely compare loss functions so this is 0.198 so everything else was 01:35:36.960 |
the same so I did the same thing of because you know the down sampling is the same so we can still 01:35:43.040 |
copy in the state dict set requires_grad and it's better 0.189 quite a lot better really 01:35:53.200 |
because you know this is these are hard to get improvements uh let's see if we can notice 01:35:57.840 |
anything hey look it's got an eye yeah so how about that um at this point it's almost 01:36:10.720 |
quite difficult to see whether it's an improvement or not but I think there's a bit of an eye on the 01:36:15.280 |
koala I think is encouraging yeah so that's uh super res uh oh man the bad news is we're out of time 01:36:37.040 |
okay we didn't promise to do the diffusion U-Net this lesson so 01:36:42.960 |
we built a U-Net we built a U-Net yes we did and we did super resolution with it and it 01:36:52.960 |
looks pretty good so um I gotta admit I haven't thought about like exercises for people to do 01:37:00.800 |
what would be useful things for people to try with like maybe they could create a U-Net they 01:37:06.960 |
could learn about segmentation create a U-Net for segmentation or oh you know there are a couple of 01:37:13.760 |
points where you well I was just gonna say there were a couple of ways we said oh I should have 01:37:18.880 |
tried this and should have tried that I think that's obviously yeah basically yeah I think 01:37:24.560 |
that's obviously a good next step I was gonna say um style transfer is a good idea to do I 01:37:30.560 |
think with a U-Net so style transfer you can actually set up a loss function so that you 01:37:36.400 |
can create a U-Net that learns to create images that look like van Gogh you know for example 01:37:41.840 |
um it's a totally different approach it's a it's a tricky one I think I think when I was playing 01:37:49.280 |
with that it almost helped to not have the skip connections at the highest resolutions otherwise 01:37:57.040 |
it just really wants to copy the input and modify it slightly interesting um maybe seeing whether 01:38:03.760 |
adding or concatenating would be better there too oh yes that's a good point yeah 01:38:10.400 |
oh well we'll put some stuff up on the website about yeah you know ideas and I'm sure some 01:38:17.920 |
students hopefully by the time you watch this will have some ideas on the forum or things they've 01:38:22.480 |
tried to yeah all right yeah the colorization is nice because it's um so colorization right 01:38:30.480 |
the transform is just to grayscale and back oh yes and then that's yeah that's already actually 01:38:36.640 |
okay so there's all kinds of decrapification you could do isn't there so if you want to keep it a 01:38:40.880 |
bit more simple yes rather than doing these two lines of code you could um um yeah just turn it 01:38:53.200 |
into black and white that's a great point um or um you could delete the center every time you know 01:39:04.720 |
to create like a something that learns how to fill in or maybe delete the left hand side 01:39:10.880 |
and that way that would be something where you can give it a photo and it'll 01:39:13.920 |
invent a little bit more to the left yeah and then you could keep running it to generate a panorama 01:39:21.120 |
another one you could do would be to like um in memory or something save it as a really uh 01:39:30.400 |
highly compressed jpeg and so then it would be something that would learn to remove jpeg 01:39:36.160 |
artifacts which then for your like old photos that you saved with crappy jpeg compression you could 01:39:43.760 |
bring them back to life uh you could probably do like yeah you could do like i guess drawing to 01:39:51.840 |
painting or something like this by taking some paintings and then like passing them through some 01:39:55.600 |
edge detection and using that as your starting point sounds interesting oh uh what about 01:40:01.760 |
watermark removal you could um you know use pil or whatever to draw watermarks text whatever over 01:40:09.760 |
the top which is quite useful for like you know radiology images and stuff sometimes have 01:40:15.440 |
personally identifiable information written on them and you can just like learn to delete it 01:40:20.720 |
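A couple of the "decrapification" transforms mentioned above, sketched with PIL; each takes and returns a PIL image, so they could slot into whatever transform pipeline you're already using (the function names are mine).

    import io
    from PIL import Image

    def to_grayscale(img):
        # colorisation: grayscale input, converted back to 3 channels so shapes still match
        return img.convert('L').convert('RGB')

    def jpeg_degrade(img, quality=10):
        # jpeg-artifact removal: round-trip through a heavily compressed in-memory jpeg
        buf = io.BytesIO()
        img.save(buf, format='JPEG', quality=quality)
        buf.seek(0)
        return Image.open(buf).convert('RGB')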
yeah okay so lots of things people can do that's awesome thanks for your ideas basically any image 01:40:27.680 |
to image task is super all right um or just make the super res better um try it on full imagenet 01:40:38.640 |
if you like um if you've got lots of hard drive space thanks jono thanks tanishk