Lesson 19: Deep Learning Foundations to Stable Diffusion
Chapters
0:00 Introduction and quick update from last lesson
2:08 Dropout
12:07 DDPM from scratch - Paper and math
40:17 DDPM - The code
41:16 U-Net Neural Network
43:41 Training process
56:07 Inheriting from miniai TrainCB
60:22 Using the trained model: denoising with “sample” method
69:09 Inference: generating some images
74:56 Notebook 17: Jeremy’s exploration of Tanishq’s notebook
84:09 Make it faster: Initialization
87:41 Make it faster: Mixed Precision
89:40 Change of plans: Mixed Precision goes to Lesson 20
00:00:00.000 |
Okay, hi everybody, and this is Lesson 19 with extremely special guests Tanishk and Jono. 00:00:13.640 |
And it's New Year's Eve 2022, finishing off 2022 with a bang, or at least a really cool 00:00:25.920 |
And most of this lesson's going to be Tanishk and Jono, but I'm going to start with a quick 00:00:39.840 |
What I wanted to show you is that Christopher Thomas on the forum 00:00:49.360 |
came up with a better winning result for our challenge, 00:00:58.960 |
the Fashion MNIST Challenge, which we are tracking here. 00:01:04.520 |
And be sure to check out this forum thread for the latest results. 00:01:11.240 |
And he found that he was able to get better results with Dropout. 00:01:17.920 |
Then Peter on the forum noticed I had a bug in my code, and the bug in my code for ResNets, 00:01:27.360 |
actually I won't show you, I'll just tell you, is that in the res block I was not passing 00:01:32.400 |
along the batch norm parameter, and as a result, all the results I had were without batch norm. 00:01:38.120 |
So then when I fixed batch norm and added Dropout at Christopher's suggestion, I got 00:01:45.040 |
better results still, and then Christopher came up with a better Dropout and got better results still. 00:01:52.400 |
So let me show you the 93.2 for 5 epochs improvement. 00:01:59.640 |
I won't show the change to batch norm because that's actually, that'll just be in the repo 00:02:08.880 |
So I'm going to tell you about what Dropout is and then show that to you. 00:02:13.740 |
So Dropout is a simple but powerful idea where what we do with some particular probability, 00:02:22.480 |
so here that's a probability of 0.1, we randomly delete some activations. 00:02:30.100 |
And when I say delete, what I actually mean is we change them to zero. 00:02:36.080 |
So one easy way to do this is to create a binomial distribution object where the probabilities 00:03:00.180 |
Of course, randomly, that's not always going to be the case. 00:03:03.240 |
But since I asked for 10 samples and 0.1 of the time, it should be 0, I so happened to 00:03:11.020 |
And so if we took a tensor like this and multiplied it by our activations, that will set about 00:03:21.360 |
a tenth of them to 0 because multiplying by 0 gives you 0. 00:03:28.400 |
So you pass it and you say what probability of Dropout there is, store it away. 00:03:33.440 |
Now we're only going to do this during training time. 00:03:37.900 |
So at evaluation time, we're not going to randomly delete activations. 00:03:43.440 |
But during training time, we will create our binomial distribution object. 00:03:52.200 |
And then you say, how many binomial trials do you want to run, so how many coin tosses 00:03:58.760 |
or dice rolls or whatever each time, and so it's just one. 00:04:04.000 |
If you put that one onto your accelerator, GPU or MPS or whatever, it's actually going 00:04:10.520 |
to create a binomial distribution that runs on the GPU. 00:04:13.720 |
That's a really cool trick that not many people know about. 00:04:18.800 |
And so then if I sample and I make a sample exactly the same size as my input, then that's 00:04:25.680 |
going to give me a bunch of ones and zeros and a tensor the same size as my activations. 00:04:32.240 |
And then another cool trick is this is going to result in activations that are on average 00:04:41.400 |
So if I multiply by 1 over 1 minus 0.9, so I multiply in this case by that, then that's 00:05:02.280 |
And Jeremy, in the line above where you have props equals 1 minus p, should that be 1 minus 00:05:15.940 |
Not that it matters too much because, yeah, you can always just use nn.dropout at this 00:05:21.080 |
point and use point one, which is why I didn't even see that. 00:05:24.440 |
So as you can see, I'm not even bothering to export this because I'm just showing how 00:05:27.720 |
to repeat what's already available in PyTorch. 00:05:40.640 |
So if we're in evaluation mode, it's just going to return the original. 00:05:45.880 |
If p equals 0, then these are all going to be just ones anyway. 00:05:53.440 |
So we'll be multiplying by 1 divided by 1, so there's nothing to change. 00:06:01.080 |
Yeah, and otherwise it's going to kind of zero out some of our activations. 00:06:05.240 |
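Putting those pieces together, a from-scratch Dropout along the lines just described might look like this - a sketch only, since in practice you'd just use nn.Dropout(p):

```python
import torch
from torch import nn
from torch.distributions.binomial import Binomial

class Dropout(nn.Module):
    def __init__(self, p=0.1):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training or self.p == 0.: return x
        # Putting the "1 trial" count on x's device makes the sampling run on the GPU.
        dist = Binomial(total_count=torch.tensor(1.0, device=x.device), probs=1-self.p)
        mask = dist.sample(x.size())       # 1 with probability 1-p, 0 with probability p
        return x * mask / (1-self.p)       # rescale so the expected activation is unchanged
```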
So a pretty common place to add dropout is before your last linear layer. 00:06:15.720 |
So yeah, if I run the exact same epochs, I get 93.2, which is a very slight improvement. 00:06:24.200 |
And so the reason for that is that it's not going to be able to memorize the data or the 00:06:34.880 |
activations, because there's a little bit of randomness. 00:06:39.400 |
So it's going to force it to try to identify just the actual underlying differences. 00:06:45.500 |
There's a lot of different ways of thinking about this. 00:06:47.480 |
You can almost think of it as a bagging thing, a bit like a random forest. 00:06:50.880 |
Each time it's giving a slightly different kind of random subset. 00:07:03.120 |
I also added a drop2d layer right at the start, which is not particularly common. 00:07:10.320 |
This is also what Christopher Thomas tried as well, although he didn't use dropout2d. 00:07:14.920 |
What's the difference between dropout2d and dropout? 00:07:17.960 |
So this is actually something I'd like you to do to implement yourself is an exercise 00:07:24.480 |
The difference is that with dropout2d, rather than using x.size as our tensor of ones and 00:07:36.400 |
So in other words, potentially dropping out every single batch, every single channel, 00:07:40.400 |
every single x, y independently, instead, we want to drop out an entire kind of grid 00:07:53.080 |
So if any of them are zero, then they're all zero. 00:07:57.640 |
So you can look up the docs for dropout2d for more details about exactly what that looks 00:08:04.280 |
But yeah, so the exercise is to try and implement that from scratch and come up with a way to test it. 00:08:12.640 |
So like actually check that it's working correctly, because it's a very easy thing to think that it's working when it actually isn't. 00:08:23.480 |
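As a starting point for that exercise, one possible sketch (which you should check against nn.Dropout2d before trusting) is to make a single keep/drop decision per channel and broadcast it over the whole spatial grid:

```python
import torch
from torch import nn
from torch.distributions.binomial import Binomial

class Dropout2d(nn.Module):
    def __init__(self, p=0.1):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training or self.p == 0.: return x
        n, c = x.shape[:2]
        dist = Binomial(total_count=torch.tensor(1.0, device=x.device), probs=1-self.p)
        mask = dist.sample((n, c, 1, 1))   # one 0/1 per (item, channel), broadcast over H and W
        return x * mask / (1-self.p)
```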
So then yeah, Christopher Thomas actually found that if you remove this entirely and 00:08:31.120 |
only keep this, then you end up with better results for 50 epochs. 00:08:39.000 |
So I feel like we should insert some kind of animation or trumpet sounds or something 00:08:48.160 |
I'm not sure if I'm clever enough to do that in the video editor, but I'll see how I go. 00:09:00.400 |
Did you guys have any other things to add about dropout, how to understand it or what 00:09:06.080 |
Oh, I did have one more thing before, but you go ahead if you've got anything to mention. 00:09:12.120 |
And I was going to ask just because I think the standard is to set it like remove the 00:09:18.520 |
But I was wondering if there's anyone you know of, or if it works to use it for some 00:09:28.800 |
Did you see this or were you just like, okay, test time dropout callback. 00:09:36.440 |
So yeah, before Epoch, if you remember in learner, we put it into training mode, which 00:09:48.200 |
actually what it does is it puts every individual layer into training mode. 00:09:53.160 |
So that's why for the module itself, we can check whether that module is in training mode. 00:09:58.800 |
So what we can actually do is after that's happened, we can then go back in this callback 00:10:04.000 |
and apply a lambda that says if this is a dropout, then put it in training mode all 00:10:26.280 |
And so then you can run it multiple times just like we did for TTA with this callback. 00:10:26.280 |
Now that's very unlikely to give you a better result because it's not kind of showing at 00:10:45.720 |
different versions or anything like that, like TTA does that are kind of meant to be 00:10:53.120 |
But what it does do is it gives you some sense of how confident it is. 00:11:04.400 |
If it kind of has no idea, then that little bit of dropouts quite often going to lead 00:11:12.920 |
So this is a way of kind of doing some kind of confidence measure. 00:11:16.920 |
You'd have to calibrate it by kind of looking at things that it should be confident about 00:11:22.120 |
and not confident about and seeing how that test time dropout changes. 00:11:27.020 |
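A sketch of that test-time dropout callback, assuming a miniai-style learner that calls before_epoch(learn) on its callbacks (the class name here is illustrative, not the notebook's exact code):

```python
from torch import nn

class TestTimeDropoutCB:
    def before_epoch(self, learn):
        # The learner has already set train/eval mode on every layer; flip just the
        # dropout layers back to training mode so they stay stochastic at inference.
        learn.model.apply(lambda m: m.train() if isinstance(m, (nn.Dropout, nn.Dropout2d)) else None)
```

Running prediction several times with this callback active and looking at the spread of the outputs gives the rough confidence measure being described.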
But the basic idea, it's been used in medical models before. 00:11:35.040 |
I wouldn't say it's totally popular, which is why I didn't even bother to show it being 00:11:41.840 |
used, but I just want to add it here because I think it's an interesting idea and maybe 00:11:47.160 |
could be more used than it is, or at least more studied than it has been. 00:11:54.900 |
A lot of stuff that gets used in the medical world is less well-known in the rest of the world. 00:12:07.840 |
So I will stop my sharing and we're going to switch to Tanishk, who's going to do something 00:12:13.880 |
much more exciting, which is to show that we are now at a point where we can do DDPM from 00:12:21.680 |
scratch or at least everything except the model. 00:12:26.120 |
And so to remind you, DDPM doesn't have the latent VAE thing and we're not going to do 00:12:32.720 |
conditional, so we're not going to get to tell it what to draw. 00:12:40.000 |
And the UNET model itself is the one bit we're not going to do today. 00:12:46.040 |
We're going to do that next lesson, but other than the UNET, it's going to be unconditional 00:13:05.620 |
You may notice people look a little bit different. 00:13:09.880 |
So we have a couple of days have passed and we're back again. 00:13:14.600 |
And then Jono recorded his bit before we do Tanishk's bit. 00:13:18.520 |
And then we're going to post them in backwards. 00:13:20.360 |
So hopefully there's not too many confusing continuity problems as a result. 00:13:24.640 |
And it all goes smoothly, but it's time to turn it over to Tanishk to talk about DDPM. 00:13:30.800 |
So we've reached the point where we have this mini AI framework and I guess it's time to 00:13:42.760 |
now start using it to build more, I guess, sophisticated models. 00:13:48.440 |
And as we'll see here, we can start putting together a diffusion model from scratch using 00:13:54.600 |
the mini AI library, and we'll see how it makes our life a lot easier. 00:13:59.040 |
And also, it'd be very nice to see how, you know, the equations in the papers correspond 00:14:05.420 |
So I have here, of course, the notebook that we'll be watching from. 00:14:13.800 |
The paper, which, you know, we have the diffusion model paper, the Denoising Diffusion Probabilistic 00:14:20.720 |
Models, which is the paper that was published in 2020, it was one of the original diffusion 00:14:27.200 |
model papers that kind of set off the entire trend of diffusion models and is a good starting 00:14:38.200 |
And also, I have some diagrams and drawings that I will also show later on. 00:14:45.880 |
But yeah, basically, let's just get started with the code here, and of course, the paper. 00:14:53.240 |
So just to provide some context with this paper, you know, this paper that was published 00:14:58.600 |
from this group in UC Berkeley, I think a few of them have gone on now to work at Google. 00:15:10.500 |
And so diffusion models were actually originally introduced in 2015. 00:15:15.520 |
But this paper in 2020 greatly simplified the diffusion models and made it a lot easier 00:15:20.040 |
to work with and, you know, got these amazing results, as you can see here, when they trained 00:15:25.000 |
on faces and in this case, CIFAR-10, and, you know, this really was very kind of a big leap 00:15:34.160 |
in terms of the progress of diffusion models. 00:15:37.960 |
And so just to kind of briefly provide, I guess, kind of an overview. 00:15:42.560 |
If I could just quickly just mention something, which is, you know, when we started this course, 00:15:52.040 |
we talked a bit about how perhaps the diffusion part of diffusion models is not actually all 00:16:01.480 |
Everybody's been talking about diffusion models, particularly because that's the open source 00:16:07.920 |
But this week, actually a model that appears to be quite a lot better than stable diffusion 00:16:15.040 |
was released that doesn't use diffusion at all. 00:16:20.320 |
Having said that, the basic ideas, like most of the stuff that Tanishk talks about today, 00:16:28.440 |
will still appear in some kind of form, you know, but a lot of the details will be different. 00:16:35.160 |
But strictly speaking, actually, I don't even know if I've got a word anymore for the kind 00:16:41.900 |
of like modern generative model things we're doing. 00:16:46.760 |
So in some ways when we're talking about diffusion models, you should maybe replace it in your 00:16:51.960 |
head with some other word, which is more general and includes this paper that Tanishk is looking 00:17:01.760 |
I'm sure by the time people watch this video, probably, you know, somebody will have decided 00:17:11.920 |
Yeah, yeah, this is the paper that first Jeremy was talking about. 00:17:16.000 |
And yeah, every week there seems to be another state of the art model. 00:17:21.260 |
But yeah, like Jeremy said, a lot of the principles are the same, but you know, the details can 00:17:30.040 |
And just to, I just want to again, also, like Jeremy was saying, kind of zoom back a little 00:17:34.360 |
bit and kind of talk about a little bit about what, you know, I just kind of provide a review 00:17:42.680 |
So let me just, yeah, so with this task, we were trying to, in this case, I would try to 00:17:54.880 |
do image generation, of course, it could be other forms of generation like text generation 00:18:01.080 |
And the general idea is that, of course, we have some, you know, data points, you know, 00:18:07.360 |
in this case, we have some images of dogs, and we want to produce more like these data 00:18:12.400 |
So in this case, maybe the dog image generation or something like this. 00:18:17.120 |
And so the overall idea that a lot of these approaches take for image, you know, for some 00:18:24.000 |
sort of generative modeling task is they have, they tried to, not over there, here, they tried 00:18:30.960 |
to, yeah, so let me use it in a bit, P of X, which is our, which is basically the sort 00:18:45.160 |
of likelihood of data point X, of X. So let's say X is some image, then P of X tells us 00:19:01.560 |
like, what is the probability that you would see that image in real life. 00:19:07.320 |
And like, if we can take like a simpler example, which may be easier to think about of like 00:19:12.760 |
a one dimensional data point, like height, for example. 00:19:17.080 |
And if we were to look at height, of course, we know like we have a data distribution that's 00:19:22.240 |
And, you know, you have maybe some, you know, mean height, which is like something like 00:19:27.160 |
5'9", 5'10", yeah, I guess 5'9", 10 inches or something like that, or 5'9", whatever. 00:19:34.920 |
And of course, we have some is like, you have some on more unlikely points, but that is 00:19:39.920 |
Like, for example, we have seven feet, or where you have something that's maybe not 00:19:44.320 |
as likely, just like, you know, three feet or something like this. 00:19:46.960 |
So here's my X axis is height, and the Y axis is the probability of some random person you 00:19:57.760 |
So you know, yeah, this is basically the probability. 00:20:02.160 |
And so of course, you have this sort of peak, which is where, you know, you have higher 00:20:05.760 |
probability. And so those are the sorts of, you know, values that you would see more often. 00:20:11.200 |
So this is this is our what we do call our P of X. And like, the important part about 00:20:19.360 |
P of X is that you can use this now to sample new values, if you know what P of X is, or 00:20:25.320 |
even if you have some sort of information about P of X. 00:20:27.480 |
So for example, here, you can think of like, if you were to like, say, maybe have some, 00:20:33.040 |
let's say you have some game, and you have some human characters in the game, and you 00:20:36.600 |
just want to randomly generate a height for this human character, you know, you could, 00:20:42.780 |
you wouldn't want to, of course, select a random height between three and seven, that's 00:20:45.800 |
kind of uniformly distributed, you would instead maybe want to, you would want to have the 00:20:52.960 |
height dependent on this sort of function where you would more likely sample values, 00:20:58.760 |
you know, in the middle, and less likely sample the source of extreme points. So it's dependent 00:21:03.760 |
on this function to P of X. So having some information about P of X will allow you to 00:21:09.760 |
sample more data points. And so that's kind of the overall goal of generative modeling 00:21:15.260 |
is to get some information about P of X, that then allows us to sample new points and, you 00:21:22.040 |
know, create new generations. So that's kind of a high level kind of description of what 00:21:28.520 |
we're trying to do when we're doing generative modeling. And of course, there are many different 00:21:34.120 |
approaches. We, you know, we have our famous GANs, which, you know, used to be the common 00:21:39.800 |
method back in the day before diffusion models, you know, we have VAEs, which I think we'll 00:21:45.760 |
probably talk a little bit more about that later as well. 00:21:48.120 |
We'll be talking about both of those techniques later here. 00:21:51.520 |
Yeah. And so there are many different other techniques. There are also some niche techniques 00:21:55.320 |
that are out there as well. But of course, now the popular one is, are these diffusion 00:22:00.320 |
models? Or, you know, as we talked about, maybe a better term might be iterative 00:22:04.200 |
refinement or whatever, you know, whatever the term ends up being. But yeah, so there are 00:22:12.160 |
many different techniques. And yeah, so let's just, so this is kind of the general diagram 00:22:20.160 |
that shows what diffusion models are. And if we can look at the paper here, which let's 00:22:26.440 |
pull up the paper. Yeah, you see here, this is the sort of, they call it a directed 00:22:31.960 |
graphical model. It's a very complicated term. It's just kind of showing what's going on 00:22:36.080 |
in this, you know, in this process. And, you know, there's a lot of complicated math here, 00:22:42.680 |
but we'll highlight some of the key variables and equations here. 00:22:48.520 |
So basically the idea is that, okay, so let's see here. So this is like, so this is an image 00:22:54.720 |
that we want to generate, right? And so X0 is basically, you know, these are actually 00:23:03.440 |
the samples that we want. So we want to, X0 is what we want to generate. And, you know, 00:23:11.280 |
these would be, yeah, these are images. And we start out with pure noise. So that's the, 00:23:19.440 |
that's what, X uppercase T, pure noise. And the whole idea is that we have two processes. 00:23:30.680 |
We have this process where we're going from pure noise to our image. And we have this 00:23:38.240 |
process where we're going from image to pure noise. So the process where we're going from 00:23:41.880 |
our image to pure noise, this is called the forward process. Sorry, my typing 00:23:50.820 |
- or rather my handwriting - is not so good here. So hopefully it's clear enough. Let me 00:23:56.640 |
know if it's not. So we have the forward process, which is mostly just used for our training. 00:24:02.460 |
Then we also have our reverse process. So this is the reverse process, right up here. Reverse 00:24:13.320 |
process. So this is a bit of a summary, I guess, of what you and Waseem talked about 00:24:20.840 |
in lesson 9b. And just, it's just mostly to highlight now what, what are the different 00:24:29.520 |
variables as we look at the, the code and see, you know, the different variables in 00:24:35.120 |
the code. Okay. So we'll be focusing today on the code, but the code we'll be referring 00:24:40.360 |
to things by name and those names won't make sense very much unless we see the, what they're 00:24:47.680 |
used for in the math. Okay. And I won't dive too much into the math. I just want to focus 00:24:53.520 |
on the sorts of variables and equations that we see in the code. So basically the general 00:25:01.200 |
idea is that, you know, we do these in multiple different steps. You know, we have here from 00:25:06.680 |
time step zero all the way to time steps, time steps, uppercase T. And so there's some 00:25:13.120 |
fixed number of steps, but then we have this intermediate process where we're going, you 00:25:17.200 |
know, from some particular, yeah, some particular time step. Yeah. We have this time step, lowercase 00:25:26.200 |
t, which is some noisy image. And, and yes, we're transitioning between these 00:25:35.120 |
two different noisy images. So we have this, what is sometimes called the transition. We 00:25:39.880 |
have this one here. This is like sometimes called the transition kernel or yeah, whatever 00:25:44.440 |
it is, it basically is just telling us, you know, how do we go from, you know, one in 00:25:50.120 |
this case, we're going from a less noisy image to a more noisy image. And then going backwards 00:25:53.840 |
is going from a more noisy image to a less noisy image. So let's look at the equations. 00:25:59.040 |
So the forward direction is trivially easy - to make something more noisy, you just 00:26:02.640 |
add a bit more noise to it. And the reverse direction is incredibly difficult, which is 00:26:07.480 |
to particularly to go from the far left to the far right is strictly speaking impossible 00:26:11.840 |
because none of that person's face exists anymore. But somewhere in between, you could 00:26:18.880 |
certainly go from something that's partially noisy to less noisy by a learned model. 00:26:26.160 |
Exactly. And that's like what I'm going to write down right now in terms of, you know, 00:26:31.000 |
in terms of, I guess, the symbols in the map. So yeah, basically, I'm just trying to pull 00:26:35.880 |
out the, just to write down the equations here. So we have, let me zoom in a bit. Yeah, 00:26:45.040 |
so we have our q, oops, let's see here. q of x t given x t minus one. Or actually, you 00:26:58.880 |
know what, maybe it's just better if I just snip. Yeah, it's a snippet from here. So the 00:27:08.980 |
one that is going from our, the one that is going from our forward process is this, this, 00:27:19.400 |
this equation here. So I'll just make that a little smaller for you guys. So right there. 00:27:26.040 |
So that is going, and basically, to explain, we have this, we have this sort of script, 00:27:36.880 |
a little bit of a, maybe a little bit confusing notation here. But basically, this is referring 00:27:41.120 |
to a normal distribution or Gaussian distribution. And instead of saying, okay, this is a Gaussian 00:27:48.080 |
distribution that's describing this particular variable. So it's just saying, okay, you know, 00:27:54.400 |
n is our normal or Gaussian distribution, and it's representing this variable x of t, 00:28:01.360 |
or x, sorry, x t. And then we have here is, is the mean. And this is the variance. So 00:28:15.280 |
just to again, clarify, I think we've talked about this before as well. But like, you know, 00:28:20.280 |
this is a, you know, this is, of course, a bad drawing of a Gaussian. But, you know, 00:28:24.040 |
our mean is just, our mean is just, you know, this, you know, the middle point here is the 00:28:30.680 |
mean, and the variance just kind of describes the sort of spread of, of the Gaussian distribution. 00:28:37.760 |
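For reference, the single-step forward transition being described here, from the DDPM paper, is:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t \mathbf{I}\right)$$

So the mean is a slightly shrunk copy of the previous image, and beta t is the variance of the noise added at that step.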
So if you think about this a little further, you have this beta, which is one of the important 00:28:43.800 |
variables that kind of describe the diffusion process, beta sub t. So you'll see the beta 00:28:51.040 |
t in the code. And basically, beta t increases as t increases. So basically, your beta t 00:29:02.000 |
will be greater than your beta t minus one. So if you think about that a little bit more 00:29:07.960 |
carefully, you can see that, okay, so at, you know, t minus one, at this time point here, 00:29:16.560 |
and then you're going to the next time point, you're going to increase your beta t, so increasing 00:29:20.880 |
the variance, but then you have this one minus beta t and take the square root of that on 00:29:27.160 |
and multiply it by x t minus one. So as your t is increasing, this term actually decreases. 00:29:35.300 |
So your mean is actually decreasing, and you're getting less of the original image, because 00:29:39.880 |
the original image is going to be part of x t minus one. So as you, 00:29:44.540 |
And just to let you know, just like, you know, we, we can't see your pointer. So if you want 00:29:52.160 |
to point at things, you would need to highlight them or something. 00:29:56.080 |
Yeah, so yeah, I'll just, let's see. Yeah. Or yeah, basically, I mean, I don't particularly 00:30:05.080 |
point at anything specific, I was just saying that, yeah, basically, if we have our, our, 00:30:11.360 |
our x of t here, as, as the time step increases, you know, you're getting less contribution 00:30:17.800 |
from your x of x t minus one. And so that means your mean is going towards zero. And 00:30:24.240 |
so you've got to have a mean of zero, and, you know, the variance keeps increasing, and 00:30:28.240 |
basically, you just have a Gaussian distribution and you lose any contribution from the original 00:30:32.520 |
image as your time step increases. So that's why when we start out from x of zero and go 00:30:38.160 |
all the way to our x of t here, this becomes pure noise, it's because we're doing this 00:30:43.600 |
iterative process where we keep adding noise, we lose that contribution from the original 00:30:47.240 |
image, and, and, you know, that that leads to, that leads to the image having pure noise 00:30:54.960 |
at the end of the process. So just something I find useful here is to consider one extreme, 00:31:03.800 |
which is that is to consider x one. So at x one, the mean is going to be one minus beta 00:31:13.560 |
t times x naught. And the reason that's interesting is x naught is the original image. So we're 00:31:19.160 |
taking the original image. And at this point, one minus beta t will be pretty close to one. 00:31:28.240 |
So at x one, we're going to have something that's the mean is very close to the image. 00:31:34.280 |
And the variance will be very small. And so that's why we will have a image that just 00:31:39.600 |
has a tiny bit of noise. Right, right. And then another thing that sometimes is easier 00:31:46.120 |
to write out is sometimes you can write out, in this case, you can write out q of x t directly. 00:31:56.160 |
Because these are all independent in terms of like, q of x t is only dependent of x t 00:32:01.880 |
minus one, and then x t minus one is only dependent on x t minus two. And you can you 00:32:06.560 |
can use this independent, each of these steps are independent. So based on the different 00:32:11.920 |
laws of probability, you can get your q of x t in closed form. So yeah, that's what's 00:32:19.000 |
shown here - q of x t given the original image. So this is also another way of kind of seeing 00:32:24.360 |
this more clearly where you can see you can see that. Anyways, some going back here. Yeah, 00:32:34.760 |
so this is another way to see here more directly. So this is, of course, our clean image. And 00:32:45.280 |
this is our noisy image. And so you can also see again, now, alpha bar t is 00:32:59.160 |
dependent on beta t, basically, it's like one minus like the cumulative. This is, I mean, 00:33:07.600 |
we'll see the code for it, I guess. So maybe, yes, yes. So it might be clear to see that 00:33:12.600 |
this is alpha bar t or something like this. But basically, basically, the idea is that 00:33:17.360 |
alpha bar t, alpha bar t is going to be, again, less, this is what is going to be less than 00:33:27.520 |
alpha bar t minus one. So basically, alpha, this keeps decreasing, right? This decreases 00:33:34.400 |
as as time step increases. And on the other hand, this is going to be increasing as time 00:33:39.380 |
step increases. So again, you can see the contribution from the original image decreases 00:33:45.080 |
as time step increases, while the noise, you know, as shown by the variance is increasing, 00:33:50.280 |
while, you know, the time step is increased. Anyway, so that hopefully clarifies the 00:33:55.800 |
forward process. 00:34:02.600 |
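Written out, the closed-form forward process being described is:

$$q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\,\mathbf{I}\right), \qquad \alpha_t = 1-\beta_t,\quad \bar\alpha_t = \prod_{s=1}^{t}\alpha_s$$

so the clean image's contribution shrinks and the noise grows as alpha bar t decreases with t.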
And then the reverse process is basically a neural network, as Jeremy had mentioned. And yeah, this is a screenshot - this is 00:34:16.760 |
our reverse process. And basically, the idea is, well, this is a neural network. 00:34:23.340 |
And this is also a neural network - a neural network that we learn during the training 00:34:31.660 |
of the model. But the nice thing about this particular diffusion model paper that 00:34:37.560 |
made it so simple was actually, we completely ignored this and actually set it to a constant. 00:34:47.400 |
We can't see what you're pointing at. So I think it's important to mention what this 00:34:51.560 |
This term here. So this one, we just kind of ignore and it's just a constant dependent 00:35:04.680 |
on beta t. So you only have one neural network that you need to train, which is basically 00:35:11.800 |
referring to this mean. And the nice thing about this diffusion model paper is 00:35:19.040 |
that it also re-parameterizes this mean into this easier form, where you do a lot of complicated 00:35:26.800 |
math, which we will not get into here. But basically, you get this kind of simplified training objective 00:35:34.600 |
where, see here. Yeah, you see the simplified training objective, you instead have this 00:35:43.080 |
epsilon theta function. And let me just screenshot that again. This is our loss function 00:35:53.920 |
that we train and we have this epsilon theta function. 00:36:00.200 |
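Written out, the reverse-process step and the simplified training objective being described are:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2\,\mathbf{I}\right)$$

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,\,x_0,\,\epsilon}\Big[\big\lVert\, \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\big) \big\rVert^2\Big]$$

so the only thing the neural network epsilon theta has to do is predict the noise epsilon that was mixed into the image.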
You could see it's a very simple loss function, right? This is just a, you can just write 00:36:05.280 |
this down. This is just an MSE loss. And we have this epsilon theta function here. That 00:36:11.800 |
is our... I mean, to folks like me who are less mathy. 00:36:14.440 |
It might not be obvious that it's a simple thing because it looks quite complicated to 00:36:17.800 |
me, but once we see it in code, it'll be simple. Yes, yes. Basically, you're just doing like 00:36:25.040 |
an... Yeah, you'll see in code how simple it is. But this is like just an MSE loss. 00:36:30.440 |
So we've seen MSE loss before, but you'll see how... Yeah, this is basically MSE. So 00:36:35.560 |
the nice... So just to kind of take a step back again, what is this epsilon theta? Because 00:36:41.160 |
this is like a new thing that seems a little bit confusing. Basically, epsilon... You can 00:36:48.280 |
see here, basically... Yeah, absolutely. This here is saying... This is actually equivalent 00:36:58.800 |
to this equation here. These two are equivalent. This is just another way of saying that because 00:37:05.200 |
basically it's saying that's X of t. So this is giving X of t just in a different way. 00:37:13.680 |
But epsilon is actually this normal distribution with a mean of zero and a variance of one. 00:37:24.520 |
And then you have all these scaling terms that changes the mean to be the same as this 00:37:30.040 |
equation that we have over here. So this is our X of t. And so what epsilon is, it's actually 00:37:38.480 |
the noise that we're adding to our image to make it into a noisy image. And what this 00:37:43.840 |
neural network is doing is trying to predict that noise. So what this is actually doing 00:37:49.600 |
is this is actually a noise predictor. And it is predicting the noise in the image. And 00:37:59.400 |
why is that important? Basically, the general idea is if we were to think about our distribution 00:38:11.460 |
of data, let's just think about it in a 2D space. Just here, each data point here represents 00:38:21.840 |
an image. And they're in this blob area, which represents a distribution. So this is in distribution. 00:38:36.080 |
And this is out of the distribution. Out of distribution. And basically, the idea is that, 00:38:47.640 |
OK, if we take an image and we want to generate some random image, if we were to take a random 00:38:56.040 |
data point, it would most likely be noisy images. So if we take some random data point, 00:39:03.240 |
the way to generate a random data point, it's going to be just noise. But we want to keep 00:39:08.880 |
adjusting this data point to make it look more like an image from your distribution. 00:39:14.040 |
That's the whole idea of this iterative process that we're doing in our diffusion model. So 00:39:18.800 |
the way to get that information is actually to take images from your data set and actually 00:39:26.660 |
add noise to it. So that's what we try to do in this process. So we have an image here, 00:39:32.280 |
and we add noise to it. And then what we do is we try to train a neural network to predict 00:39:37.440 |
the noise. And by predicting the noise and subtracting it out, we're going back to the 00:39:42.240 |
distribution. So adding the noise takes you away from the distribution, and then predicting 00:39:47.020 |
the noise brings you back to distribution. So then if we know at any given point in this 00:39:54.440 |
space how much noise to remove, that tells you how to keep going towards the data distribution 00:40:04.160 |
and get a point that lies within the distribution. So that's why we have noise prediction, and 00:40:10.560 |
that's the importance of doing this noise prediction is to be able to then do this iterative 00:40:15.260 |
process where we can start out at a random point, which would be, for example, pure noise 00:40:20.080 |
and keep predicting and removing that noise and walking towards the data distribution. 00:40:26.000 |
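In code, the two pieces just described - noisify an image at a random timestep, then regress the predicted noise against the actual noise - come down to just a few lines. This is a rough sketch, not the notebook's exact code; `model` is assumed here to return the noise prediction directly, and `alphabar` is the cumulative product of (1 - beta) defined a bit later:

```python
import torch
import torch.nn.functional as F

def ddpm_loss(model, x0, alphabar):
    n = x0.shape[0]
    t = torch.randint(0, len(alphabar), (n,), device=x0.device)   # random timestep per image
    eps = torch.randn_like(x0)                                    # the noise we will try to predict
    abar = alphabar.to(x0.device)[t].reshape(-1, 1, 1, 1)
    xt = abar.sqrt()*x0 + (1-abar).sqrt()*eps                     # noisify: x_t from x_0 in one step
    return F.mse_loss(model(xt, t), eps)                          # MSE between predicted and actual noise
```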
Okay. Okay. So yeah, let's get started with the code. And so here, we of course have our 00:40:35.480 |
imports, and we're going to load our dataset. We're going to work with our Fashion MNIST 00:40:43.400 |
dataset, which is what we've been working with for a while already. And yeah, this is just 00:40:51.160 |
basically similar code that we've seen from before in terms of loading the dataset. And 00:40:57.920 |
then we have our model. So we remove the noise from the image. So what our model is going 00:41:02.660 |
to take in is it's going to take in the previous image, the noisy image, and predict the noise. 00:41:10.520 |
So the shapes of the input and the output are the same. They're going to be in the shape 00:41:14.800 |
of an image. So what we use is we use a U-Net neural network, which takes in kind of an input 00:41:22.840 |
image. And we do see your pointer now, by the way, so feel free to point at things. 00:41:27.680 |
Yeah. So yeah, it takes an input image. And in this case, a U-Net fits our purpose, but they 00:41:38.640 |
can also be used for any sort of image-to-image task, where we're going from an input image 00:41:43.920 |
and then outputting some other image of some sort. 00:41:48.000 |
And we'll talk about a new architecture, which we haven't learned about yet, and we will 00:41:51.760 |
be learning about in the next lesson. But broadly speaking, those gray arrows going 00:41:58.880 |
from left to right are very much like ResNet skip connections, but they're being used in 00:42:07.840 |
a different way. Everything else is stuff that we've seen before. So basically, we can 00:42:16.360 |
pretend that those don't exist for now. It's a neural network that the output is the same 00:42:25.280 |
size or a similar size to the input, and therefore you can use it to learn how to go from one 00:42:34.880 |
Yeah. So that's what the U-Net is. And yeah, like Jeremy said, we'll talk about it more. 00:42:45.280 |
The sort of U-Nets that are used for diffusion models also tend to have some additional tricks, 00:42:50.520 |
which again, we'll talk about them later on as well. But yeah, for the time being, we 00:42:57.080 |
will just import a U-Net from the Diffusers library, which is the Hugging Face library for 00:43:05.740 |
diffusion models. So they have a U-Net implementation, and we'll just be using that for now. And 00:43:13.120 |
so, yeah, of course, strictly speaking, we're cheating at this point because we're using 00:43:17.560 |
something we haven't written from scratch, but we're only cheating temporarily because 00:43:24.360 |
Yeah. And yeah, so and then of course, we're working with one channel images - our Fashion 00:43:30.720 |
MNIST images are one channel images. So we just have to specify that. And then of course, 00:43:35.880 |
the channels of the different blocks within the U-Net are also specified. 00:43:42.240 |
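For reference, creating that U-Net from Diffusers looks roughly like this. The block_out_channels values are illustrative rather than necessarily the notebook's, and the images generally need a spatial size that divides cleanly through the U-Net's repeated downsampling (e.g. padding 28x28 Fashion MNIST up to 32x32):

```python
from diffusers import UNet2DModel

# One input and one output channel, since we're predicting noise the same shape as the image.
model = UNet2DModel(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 128))
```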
And then let's go into the training process. So basically, the general idea of course, is we want to 00:43:52.200 |
train with this MSE loss. What we do is we select a random time step, and then we add 00:44:00.880 |
noise to our image based on that time step. So of course, if we have a very high time 00:44:06.480 |
step, we're adding a lot of noise. If we have a lower time step, then we're adding very 00:44:14.840 |
little noise. So we're going to randomly choose a time step. And then yeah, we add the noise 00:44:22.120 |
accordingly to the image. And then we pass the noisy image to a model as well as the 00:44:28.400 |
time step. And we are trying to predict the amount of noise that was in the image. And 00:44:34.960 |
we predict it with the MSE loss. So we can see all the-- 00:44:37.440 |
I have some pictures of some of these variables I could share if that would be useful. So 00:44:46.440 |
I have a version. So I think Tanishk is sharing notebook number 15. Is that right? And I've 00:44:52.320 |
got here notebook number 17. And so I took Tanishk's notebook and just as I was starting 00:44:58.000 |
to understand it, I like to draw pictures for myself to understand what's going on. 00:45:02.720 |
So I took the things which are in Tanishk's class and just put them into a cell. So I 00:45:10.960 |
just copied and pasted them, although I replaced the Greek letters with English written out 00:45:17.680 |
versions. And then I just plotted them to see what they look like. So in Tanishk's class, 00:45:24.680 |
he has this thing called beta, which is just lin space. So that's just literally a line. 00:45:33.120 |
So beta, there's going to be 1,000 of them. And they're just going to be equally spaced 00:45:37.960 |
from 0.001 to 0.02. And then there's something called sigma, which is the square root of 00:45:50.360 |
that. So that's what sigma is going to look like. And then he's also got alpha bar, which 00:46:00.640 |
is a cumulative product of 1 minus this. And there's what alpha bar looks like. So you 00:46:10.160 |
can see here, as Tanishka was describing earlier, that when T is higher, this is T on the x 00:46:19.600 |
axis, beta is higher. And when T is higher, alpha bar is lower. So yeah, so if you want 00:46:30.840 |
to remind yourself, so each of these things, beta, sigma, alpha bar, they're each got 1,000 00:46:40.400 |
things in them. And this is the shape of those 1,000 things. So this is the amount of variance, 00:46:51.120 |
I guess, added at each step. This is the square root of that. So it's the standard deviation 00:46:57.760 |
added at each step. And then if we do 1 minus that, it's just the exact opposite. And then 00:47:07.880 |
this is what happens if you multiply them all together up to that point. And the reason 00:47:10.840 |
you do that is because if you add noise to something, you add noise to something that 00:47:15.120 |
you add noise to something that you add noise to something, then you have to multiply together 00:47:18.120 |
all that amount of noise to say how much noise you would get. So yeah, those are my pictures. 00:47:26.080 |
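Those plots can be reproduced with a few lines; the exact beta range is an assumption here (the usual DDPM values are 1e-4 to 0.02 over 1000 steps):

```python
import torch
import matplotlib.pyplot as plt

n_steps = 1000
beta = torch.linspace(0.0001, 0.02, n_steps)   # variance added at each step
sigma = beta.sqrt()                            # standard deviation added at each step
alpha = 1. - beta
alphabar = alpha.cumprod(dim=0)                # cumulative product of (1 - beta)

for name, v in [("beta", beta), ("sigma", sigma), ("alphabar", alphabar)]:
    plt.plot(v); plt.title(name); plt.show()
```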
Yep, good to see the diagrams and see the actual values and how they change over time. 00:47:35.600 |
So yeah, let's see here, sorry. Yeah, so like Jeremy was showing, we have our linspace 00:47:44.280 |
for our beta. In this case, we're using kind of more of the Greek letters. So you can see 00:47:49.720 |
the Greek letters that we see in the paper as well as now we have it here in the code 00:47:54.720 |
as well. And we have our linspace from our minimum value to our maximum value. And we 00:48:01.080 |
have some number of steps. So this is the number of time steps. So here, we use a thousand 00:48:06.880 |
time steps, but that can depend on the type of model that you're training. And that's 00:48:11.720 |
one of the parameters of your model or high parameters of your model. 00:48:16.280 |
And this is the callback you've got here. So this callback is going to be used to set 00:48:24.720 |
up the data, I guess, so that you're going to be using this to add the noise so that 00:48:30.760 |
the models then got the data that we're trying to get it to learn to then denoise. 00:48:37.600 |
Yeah, so the callback, of course, makes life a lot easier in terms of, yeah, setting up 00:48:45.560 |
everything and still being able to use, I guess, the mini AI learner with maybe some 00:48:51.560 |
of these more complicated and maybe a little bit more unique training loops. So yeah, in 00:48:58.240 |
this case, we're just able to use the callback in order to set up the data, the data, I guess, 00:49:08.600 |
the batch that we are passing into our learner. 00:49:11.520 |
I just wanted to mention, when you first did this, you wrote out the Greek letters in English, 00:49:19.600 |
alpha and beta and so forth. And at least for my brain, I was finding it difficult to 00:49:25.760 |
read because they were literally going off the edge of the page and I couldn't see it 00:49:29.840 |
all at once. And so we did a search and replace to replace it with the actual Greek letters. 00:49:36.520 |
I still don't know how I feel about it. I'm finding it easier to read because I can see 00:49:43.600 |
it all at once without scrolling and I don't get as overwhelmed. But when I need to 00:49:49.960 |
edit the code, I kind of just tend to copy and paste the Greek letters, which is why 00:49:56.300 |
we use the actual word beta in the init parameter list so that somebody using this never has 00:50:03.520 |
to type a Greek letter. But I don't know, Johnno or Tanishka, if you had any thoughts 00:50:09.080 |
over the last week or two since we made that change about whether you guys like having 00:50:18.160 |
I like it for this demo in particular. I don't know that I do this in my code, but because 00:50:22.400 |
we're looking back and forth between the paper and the implementation here, I think it works 00:50:29.520 |
Yeah, I agree. I think it's good for like, yeah, when you're trying to study something 00:50:35.800 |
or trying to implement something, having the Greek letters is very useful to be able to, 00:50:41.560 |
I guess, match the math more closely and it's just easy just to do it. Take the equation 00:50:47.640 |
and put it into code or, you know, vice versa, looking at the code and trying to match it 00:50:52.000 |
to the equation. So I think for like, yeah, educational purpose, I tend to like, I guess, 00:50:58.480 |
the Greek letters. So yeah, so, you know, we have our initialization, but we're just 00:51:08.640 |
defining all these variables. We'll get to the predict in just a moment. But first, I 00:51:15.200 |
just want to go over the before batch, where we're ever setting up our batch to pass into 00:51:23.200 |
the model. So remember that the model is taking in our noisy image and the time step. And 00:51:32.840 |
of course the target is the actual amount of noise that we are adding to the image. 00:51:39.640 |
So basically, we generate that noise. So that's what... epsilon is that target. So epsilon 00:51:45.920 |
is the amount of noise, not the amount of, is the actual noise. Yes, epsilon is the actual 00:51:51.560 |
noise that we're adding. And that's the target as well, because our model is a noise-predicting 00:51:56.880 |
model. It's predicting the noise in the image. And so our target should be the noise that 00:52:03.320 |
we're adding to the image during training. So we have our epsilon and we're generating 00:52:08.700 |
it with this random function. It's the random normal distribution with a mean of zero, variance 00:52:15.320 |
of one. So that's what that's doing and adding the appropriate shape and device. 00:52:21.200 |
Then the batch that we get originally will contain the clean images. These are the original 00:52:30.200 |
images from our dataset. So that's x0. And then what we want to do is we want to add 00:52:36.400 |
noise. So we have our alpha bar and we have a random time step that we select. And then 00:52:43.280 |
we just simply follow that equation, which again, I'll just show in just a moment. 00:52:47.880 |
That equation, you can make a tiny bit easier to read, I think. If you were to double-click 00:52:51.480 |
on that first alpha bar underscore t, cut it and then paste it, sorry, in the xt equals 00:52:57.840 |
torch dot square root, take the thing inside the square root, double-click it and paste 00:53:03.200 |
it over the top of the word torch. That would be a little bit easier to read, I think. And 00:53:13.560 |
then you'll do the same for the next one. There we go. 00:53:24.560 |
Yeah, so basically, yeah, so yeah, I guess let's just pull up the equation. So let's 00:53:37.400 |
see. There's a section in the paper that has the nice algorithm. Let's see if I can find 00:53:44.520 |
it. No, here. I think earlier. Yes, training. Right, so we're just following the same sort 00:53:54.360 |
of training steps here, right? We select a clean image that we take from our data set. 00:54:01.560 |
This fancy kind of equation here is just saying take an image from your data set, take a random 00:54:08.380 |
time step between this range. Then this is our epsilon that we're getting, just to say 00:54:15.120 |
get some epsilon value. And then we have our equation for x of t, right? This is the equation 00:54:22.360 |
here. You can see that square root of alpha bar t x0 plus square root of 1 minus alpha 00:54:27.840 |
bar t times epsilon. So that's the same equation that we have right here, right? And then what 00:54:34.520 |
we need to do is we need to pass this into our model. So we have x t and t. So we set 00:54:40.200 |
up our batch accordingly. So this is the two things that we pass into our model. And of 00:54:45.200 |
course, we also have our target, which is our epsilon. And so that's what this is showing 00:54:49.160 |
here. We pass in our x of t as well as our t here, right? And we pass that into a model. 00:54:56.280 |
The model is represented here as epsilon theta. And theta is often used to represent like 00:55:01.120 |
this is a neural network with some parameters and the parameters are represented by theta. 00:55:05.520 |
So epsilon theta is just representing our noise predicting model. So this is our neural 00:55:09.400 |
network. So we have passed in our x of t and our t into a neural network. And we are comparing 00:55:15.040 |
it to our target here, which is the actual epsilon. And so that's what we're doing here. 00:55:20.820 |
We have our batch where we have our x of t and t and epsilon. And then here we have our 00:55:28.100 |
prediction function. And because we actually have, I guess, in this case, we have two things 00:55:33.640 |
that are in a tuple that we need to pass into our model. So we just kind of get those elements 00:55:40.620 |
from our tuple with this. Yeah, we get the elements from the tuple, pass it into the 00:55:45.720 |
model, and then hugging face has its own API in terms of getting the output. So you need 00:55:51.360 |
to call dot sample in order to get the predictions from your model. So we just do that. And then 00:55:57.360 |
we do, yeah, we have learned dot preds. And that's what's going to be used later than 00:56:02.920 |
when we're trying to do our loss function calculation. 00:56:05.880 |
So, I mean, it's just worth looking at that a little bit more since we haven't quite seen 00:56:11.440 |
something like this before. And it's something which I'm not aware of any other framework 00:56:15.760 |
that would let you do this. Literally replace how prediction works. And miniai is kind of 00:56:21.800 |
really fun for this. So because you've inherited from TrainCB, TrainCB has predict 00:56:28.320 |
already defined, and you have defined a new version. So it's not going to use the TrainCB version 00:56:33.240 |
anymore. It's going to use your version. And what you're doing is instead of passing learned 00:56:40.320 |
dot batch zero to the model, you've got a star in front of it. So the key thing is that 00:56:46.720 |
star is going to, and we know actually learned dot batch zero has two things in it because 00:56:52.680 |
that learned dot batch that you showed at the end of the before batch method has two 00:56:56.920 |
things in learned dot zero. So that star will unpack them and send each one of those as 00:57:02.680 |
a separate argument. So our model needs to take two things, which the Diffusers U-Net 00:57:08.820 |
does take two things. So that's the main interesting point. And then something I find a bit awkward 00:57:15.280 |
honestly about a lot of Hugging Face stuff including Diffusers is that generally their 00:57:22.280 |
models don't just return the result, but they put it inside some name. And so that's what 00:57:27.960 |
happens here. They put it inside something called sample. So that's why Tanishk added 00:57:33.360 |
dot sample at the end of the predict because of this somewhat awkward thing, which Hugging Face 00:57:40.320 |
like to do for some reason. But yeah, now that you know, I mean, this is something that 00:57:44.920 |
people often get stuck on. I see on Kaggle and stuff like that. It's like, how on earth 00:57:50.080 |
do I use these models? Because they take things in weird forms and they give back things with 00:57:54.360 |
weird forms. Well, this is how, you know, if you inherit from TrainCB, you can change 00:58:01.000 |
predict to do whatever you want, which I think is quite sweet. 00:58:07.120 |
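Putting the pieces together, the callback being described looks roughly like this - a sketch assuming the course's miniai TrainCB is in scope and that callback methods receive the learner; exact names and details may differ from the notebook:

```python
import torch

class DDPMCB(TrainCB):
    def __init__(self, n_steps, beta_min, beta_max):
        super().__init__()
        self.n_steps = n_steps
        self.beta = torch.linspace(beta_min, beta_max, n_steps)   # noise schedule
        self.alpha = 1. - self.beta
        self.alpha_bar = self.alpha.cumprod(dim=0)
        self.sigma = self.beta.sqrt()

    def before_batch(self, learn):
        x0 = learn.batch[0]                                       # clean images
        eps = torch.randn(x0.shape, device=x0.device)             # the target noise
        t = torch.randint(0, self.n_steps, (x0.shape[0],), device=x0.device)
        abar_t = self.alpha_bar.to(x0.device)[t].reshape(-1, 1, 1, 1)
        xt = abar_t.sqrt()*x0 + (1-abar_t).sqrt()*eps             # noisify
        learn.batch = ((xt, t), eps)                              # (model inputs, target)

    def predict(self, learn):
        # unpack (xt, t) with * and grab .sample from the Diffusers output object
        learn.preds = learn.model(*learn.batch[0]).sample
```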
Yep. So yeah, that's the training loop. And then of course you have your regular training 00:58:14.840 |
loop that's implemented in miniai where you are going to have, yeah, you have your loss 00:58:21.400 |
function calculation, with the predictions, learn.preds, and of course the target is our 00:58:31.400 |
learn.batch[1], which is our epsilon. So, you know, we have those and we pass it into the 00:58:37.880 |
loss function. It calculates the loss function and does the back propagation. So I'll just 00:58:42.480 |
go over that. We'll get back to the sampling in just a moment. But just to show the training 00:58:49.520 |
loop. So most of this is copied from our, I think it's 14 augment notebook. The way you've 00:58:58.880 |
got the tmax and the sched. The only thing I think you've added here is the DDPM callback, 00:59:09.200 |
And the changed loss function. Yes. So basically we have to initialize our 00:59:14.560 |
DDPM callback with the appropriate arguments. So like the number of time steps and the minimum 00:59:21.840 |
beta and maximum beta. And then of course we're using an MSE loss as we talked about. It just 00:59:32.480 |
becomes a regular training loop. And everything else is from before. Yeah. We have your scheduler, 00:59:39.840 |
your progress bar, all of that we've seen before. 00:59:42.240 |
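Wiring it into the training loop is then just the usual miniai-style setup. This is a hypothetical sketch: the Learner signature, the other callbacks, and the learning rate and epoch count are placeholders based on earlier notebooks.

```python
from torch import nn

ddpm_cb = DDPMCB(n_steps=1000, beta_min=0.0001, beta_max=0.02)
learn = Learner(model, dls, nn.MSELoss(), lr=4e-3,
                cbs=[ddpm_cb, DeviceCB(), ProgressCB(plot=True)])
learn.fit(5)
```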
I think that's really cool that we're using basically the same code to train a diffusion 00:59:46.880 |
model as we've used to train a classifier just with, yeah, an extra callback. 00:59:51.840 |
Yeah. Yeah. Yeah. So I think callbacks are very powerful for being, you know, for allowing 00:59:57.720 |
us to do such things. It's like pretty, you can take all this code and now we have a diffusion 01:00:05.080 |
training loop and we can just call learn.fit. And yeah, you can see we've got a nice training 01:00:11.720 |
loop, nice loss curve. We can save our model. Saving functionality to be able to save our 01:00:19.880 |
model and we could load it in. But now that we have our trained model, then the question 01:00:25.800 |
is what can we do to use it to sample, you know, the dataset? 01:00:32.080 |
So the basic idea, of course, was that, you know, we have like basically we're here, right? 01:00:44.160 |
We have, let's see here. Okay. So we have a basic idea is that we start out with a random 01:00:50.200 |
data point. And of course that's not going to be within the distribution at first, but 01:00:55.160 |
now we've learned how to move from, you know, at one point towards the data distribution. 01:01:03.520 |
That's what our noise predicting function does. It basically tells you in what 01:01:09.220 |
direction and how much to... So the basic idea is that, yeah, I guess I'll start from 01:01:17.740 |
a new drawing here. Again, we have distribution is and we have a random point. And we use 01:01:28.480 |
our noise predicting model that we have trained to tell us which direction to move. So it 01:01:33.400 |
tells us some direction. Or I guess, let's... It tells us some direction to move. At first 01:01:46.460 |
that direction is not going to be like you cannot follow that direction all the way to 01:01:50.160 |
get the correct data point. Because basically what we were doing is we're trying to reverse 01:01:55.560 |
the path that we were following when we were adding noise. So like, because we had originally 01:01:59.640 |
a data point and we kept adding noise to the data point and maybe, you know, it followed 01:02:03.440 |
some path like this. And we want to reverse that path to get to... So our noise predicting 01:02:12.760 |
function will give us an original direction which would be some kind of... It's going 01:02:18.720 |
to be kind of tangential to the actual path at that location. So what we would do is we 01:02:25.040 |
would maybe follow that data point all the way towards... We're just going to keep following 01:02:30.920 |
that data point. We're going to try to predict the fully denoised image by following this 01:02:39.040 |
noise prediction. But our fully denoised image is also not going to be a real image. So what 01:02:45.540 |
we... So let me... I'll show you an example of that over here in the paper on where they 01:02:51.080 |
show this a little bit more carefully. Let's see here. So X zero... Yeah. So basically, 01:03:00.600 |
you can see the different... You can see the different data points here. It's not going 01:03:08.520 |
to look anything like a real image. So you can see all these points. You know, it doesn't 01:03:12.080 |
look anything. What we would do is we actually add a little bit of noise back to it and we 01:03:23.320 |
start... We have a new point where then we could maybe estimate a better... Get a better 01:03:27.480 |
estimate of which direction to move. Follow that all the way again. We follow a new point. 01:03:33.920 |
and add back a little bit of noise. You get a new estimate. You make a new estimate 01:03:38.760 |
of this noise prediction and removing the noise. Follow that all again completely and 01:03:44.840 |
add a little bit of noise again to the image, and eventually converge onto an image. So that's kind 01:03:52.640 |
of what we're showing here as well. That's a lot like SGD. With SGD, we don't take the 01:03:57.080 |
gradient and jump all the way. We use a learning rate to go some of the way because each of 01:04:01.280 |
those estimates of where we want to go are not that great, but we just do it slowly. 01:04:08.680 |
Exactly. And at the end of the day, that's what we're doing with this noise prediction. 01:04:12.800 |
We are predicting the gradient of this p of x, but of course, we need to keep making estimates 01:04:21.320 |
of that gradient as we're progressing. So we have to keep evaluating our noise prediction 01:04:27.040 |
function to get updated and better estimates of our gradient in order to finally converge 01:04:33.920 |
onto our image. And then you can see that here where we have maybe this fully predicted 01:04:40.520 |
denoised image, which at the beginning doesn't look anything like a real image, but then 01:04:45.740 |
as we continue throughout the sampling process, we finally converge on something that looks 01:04:50.880 |
like an actual image. Again, these are CIFAR-10 images and still a little bit maybe unclear 01:04:55.880 |
about how realistic these images, these very small images look, but that's kind of the 01:05:00.880 |
general principle I would say. And so that's what I can show in the code. This idea of 01:05:12.400 |
we're going to start out basically with a random image, right? And this random image 01:05:17.480 |
is going to be like a pure noise image and it's not going to be part of the data distribution. 01:05:23.440 |
It's not anything like a real image, it's just a random image. And so this is going 01:05:27.040 |
to be our x, I guess, x uppercase T, right? That's what we start out with. And we want 01:05:32.960 |
to go from x uppercase T all the way to x0. So what we do is we go through each of the 01:05:39.520 |
time steps and we have to put it in this sort of batch format because that's what our neural 01:05:48.600 |
network expects. So we just have to format it appropriately. And we'll get to Z in just 01:05:55.760 |
a moment. I'll explain that in just a moment. But of course, we just again have similar 01:06:00.000 |
alpha bar, beta bar, which is getting those variables that we -- 01:06:06.840 |
And we faked beta bar because we couldn't figure out how to type it, so we used b bar 01:06:11.320 |
Yeah, we were unable to get beta bar to work, I guess. But anyway, at each step, what we're 01:06:20.260 |
trying to do is to try to predict what direction we need to go. And that direction is given 01:06:24.320 |
by our noise predicting model, right? So what we do is we pass in x of t and our current 01:06:30.480 |
time step into our model and we get this noise prediction and that's the direction that we 01:06:35.280 |
need to move in. So basically, we take x of t, we first attempt to completely remove the 01:06:41.600 |
noise, right? That's what this is doing. That's what x0 hat is. That's completely removing 01:06:45.880 |
the noise. And of course, as we said, that estimate at the beginning won't be very accurate. 01:06:53.160 |
And so now what we do is we have some coefficients here, where we have a coefficient for how much 01:06:58.420 |
we keep of this estimate of our denoised image and how much of the original noisy 01:07:06.800 |
image we keep. And on top of that, we're going to add in some additional noise. So that's 01:07:13.160 |
what we do here. We have x0 hat and we multiply by its coefficient and we have x of t, we 01:07:24.320 |
multiply by some coefficient and we also add some additional noise. That's what the z is. 01:07:29.480 |
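For reference, written out in the DDPM paper's notation, and assuming the alpha, alpha-bar, beta and sigma definitions from earlier in the lesson, the two steps just described look roughly like this (a transcription of the standard formulas, not of the notebook's exact code):

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\hat\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}},
\qquad
x_{t-1} = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\hat{x}_0
        + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t
        + \sigma_t z
```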
It's basically a weighted average of the two plus the additional noise. And then the whole 01:07:37.940 |
idea is that as we get closer and closer to a time step equals to zero, our estimate of 01:07:48.120 |
x0 will be more and more accurate. So our x0 coefficient will get closer and closer to one as we're going 01:07:58.520 |
through the process, and our x t coefficient will get closer and closer to zero. So basically 01:08:05.200 |
we're going to be weighting more and more of the x0 hat estimate and less and less of 01:08:10.600 |
the x t as we're getting closer and closer to our final time step. And so at the end 01:08:15.620 |
of the day, we will have our estimated generated image. So that's kind of an overview of the 01:08:23.540 |
sampling process. So yeah, basically the way I implemented it here was I had the sample 01:08:34.720 |
function that's part of our callback and it will take in the model and the kind of shape 01:08:43.800 |
that you want for your images that you're producing. So if you want to specify how many 01:08:48.600 |
images you produce, that's going to be part of your batch size or whatever. And you'll 01:08:52.400 |
just see that in a moment. But yeah, it's just part of the callback. So then we basically 01:08:57.600 |
have our DDPM callback and then we could just call the sample method of our DDPM callback 01:09:07.720 |
and we pass in our model. And then here you can see we're going to produce, for example, 01:09:11.960 |
16 images and it just has to be a one channel image of shape 32 by 32. And we get our samples. 01:09:20.720 |
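For readers following along in text, here is a minimal sketch of a sampling loop like the one just described. The argument names, and the assumption that the model can be called as model(x_t, t) and returns the noise estimate directly, are mine rather than the notebook's exact interface:

```python
import torch

@torch.no_grad()
def sample(model, sz, alpha, alphabar, sigma, n_steps):
    # Start from pure noise x_T and walk backwards through the timesteps.
    device = next(model.parameters()).device
    x_t = torch.randn(sz, device=device)
    preds = []
    for t in reversed(range(n_steps)):
        t_batch = torch.full((sz[0],), t, device=device)             # batch of timesteps
        z = torch.randn(sz, device=device) if t > 0 else torch.zeros(sz, device=device)
        abar_t1 = alphabar[t - 1] if t > 0 else torch.tensor(1., device=device)
        eps = model(x_t, t_batch)                                    # predicted noise
        # estimate of the fully denoised image
        x0_hat = (x_t - (1 - alphabar[t]).sqrt() * eps) / alphabar[t].sqrt()
        # weighted average of x0_hat and x_t, plus a little fresh noise z
        x0_coeff = abar_t1.sqrt() * (1 - alpha[t]) / (1 - alphabar[t])
        xt_coeff = alpha[t].sqrt() * (1 - abar_t1) / (1 - alphabar[t])
        x_t = x0_coeff * x0_hat + xt_coeff * x_t + sigma[t] * z
        preds.append(x_t.cpu())
    return preds

# e.g. sixteen 1-channel 32x32 samples, keeping every intermediate step:
# preds = sample(model, (16, 1, 32, 32), alpha, alphabar, sigma, 1000); samples = preds[-1]
```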
And one thing I forgot to note was that I am collecting the x of t at each of the time steps. 01:09:27.560 |
So in the predictions here, you can see that there are a thousand of them. We want 01:09:34.680 |
the last one because that is our final generation. So we want the last one and that's what we 01:09:40.120 |
have. They're not bad actually. Yeah. And things have come a long way since DDPM. So this is like 01:09:46.860 |
slower and less great than it could be. But considering that, except for the U-Net, we've done 01:09:52.400 |
this from scratch, you know, literally from matrix multiplication. I think those are pretty 01:09:58.160 |
decent. Yeah. And we're only trained for about five epochs. It took like, you know, maybe 01:10:05.360 |
like four minutes to train this model, something like that. It's pretty quick. And this is 01:10:09.880 |
what we could get with very little training. And it's yeah, pretty decent. You can see 01:10:15.200 |
there are some clear shirts and shoes and pants and whatever else. So yeah. And you can see 01:10:21.480 |
fabric and it's got texture and things have buckles and yeah. You know, something to compare: 01:10:29.480 |
we did generative modeling the first time we did part two, back in the days when 01:10:37.320 |
something called Wasserstein GAN was just new, which was actually created by the same 01:10:41.760 |
guy that created PyTorch, or one of the two guys, Soumith. And we trained for hours and 01:10:47.160 |
hours and hours and got things that I'm not sure were any better than this. So things 01:10:54.320 |
have come a long way. Yeah. And of course, then we can see how this 01:11:06.800 |
sampling progresses over time, over the multiple time steps. So that's what I'm showing here 01:11:12.440 |
because I collected during the sampling process, we are collecting at each time step what that 01:11:17.080 |
estimate looks like. And you can kind of see here. And so this is an estimate of 01:11:23.360 |
the noisy image over the time steps. Oops. And I guess I had to pause for a minute. Yeah, 01:11:28.480 |
you can kind of see. But you'll notice that actually, so we actually what we did is like, 01:11:33.200 |
okay, so we selected an image, which is like the ninth image. So that's this image here. 01:11:38.560 |
So we're looking at this image, particularly here. And we're going over, yeah, we have 01:11:43.640 |
a function here that's showing the i-th time step during the sampling process of that image. 01:11:51.280 |
And we're just getting the images. And what we are doing is we're only showing basically 01:11:57.560 |
from time step 800 to 1000. And here, we're just having it where 01:12:04.040 |
we're looking at maybe every five steps as we're going from 800 to 1000. 01:12:09.160 |
And this kind of makes it a little bit visually easier to see the transition. But what you'll 01:12:14.400 |
notice is I didn't start all the way from zero, I started from 800. And the reason we 01:12:19.320 |
do that is because actually, between zero and 800, there's very little change in terms 01:12:25.560 |
of like, it's just mostly a noisy image. And it turns out, yeah, as I make 01:12:32.040 |
a note of this here, it's actually a limitation of the noise schedule that is used in the 01:12:37.000 |
original DDPM paper. And especially when applied to some of these smaller images when 01:12:42.200 |
we're working with images of like size 32 by 32 or whatever. And so there are some other 01:12:48.880 |
papers like the improved DDPM paper that propose other sorts of noise schedules. And 01:12:54.320 |
what I mean by noise schedule is basically how beta is defined. So, you know, 01:13:00.520 |
we had this definition of torch.linspace for our beta, but people have different ways of 01:13:05.840 |
defining beta that lead to different properties. So, you know, things like that, people have 01:13:11.800 |
come up with different improvements, and those sorts of improvements work well when we're 01:13:15.120 |
working with these smaller images. And basically, the point is like, if we are working from 01:13:19.840 |
0 to 800, and it's just mostly just noise that entire time, you know, we're not actually 01:13:24.680 |
making full use of all those time steps. So, it would be nice if we could actually make 01:13:28.680 |
full use of those time steps and actually have it do something during that time period. 01:13:32.960 |
So there are some papers that examine this a little bit more carefully. 01:13:37.160 |
And it would be kind of interesting for maybe some of you folks to also look at these papers 01:13:41.600 |
and see if you can try to implement, you know, those sorts of models with this notebook as 01:13:47.400 |
a starting point. And it should be a fairly simple change in terms of the noise schedule. 01:13:52.680 |
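To make the noise-schedule idea concrete, here is a hedged sketch comparing a linear schedule (the constants 1e-4 to 0.02 over 1000 steps are my recollection of the original DDPM paper) with the cosine schedule proposed in the improved DDPM paper:

```python
import math, torch

# linear schedule, as in the original DDPM paper (to the best of my recollection)
betas_linear = torch.linspace(1e-4, 0.02, 1000)

def cosine_betas(n_steps, s=0.008, max_beta=0.999):
    # cosine schedule from the improved DDPM paper: define alpha-bar directly,
    # then recover beta_t = 1 - alphabar(t)/alphabar(t-1), clipped for stability
    def abar(t): return math.cos((t / n_steps + s) / (1 + s) * math.pi / 2) ** 2
    return torch.tensor([min(1 - abar(t + 1) / abar(t), max_beta) for t in range(n_steps)])

betas_cosine = cosine_betas(1000)
```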
So I actually think, you know, this is the start of our next journey, you know, which 01:13:57.200 |
is our previous journey was, you know, going from being totally rubbish at Fashion MNIST 01:14:04.280 |
classification to being really good at it. I would say now we're like a little bit rubbish 01:14:10.640 |
at doing Fashion MNIST generation. And yeah, I think, you know, we should all now work 01:14:20.080 |
from here over the next few lessons and so forth and people, you know, trying things 01:14:26.020 |
at home and all of us trying to make better and better generative models, you know, initially 01:14:34.520 |
on Fashion MNIST, and hopefully we'll get to the point where we're so good at that, that 01:14:37.800 |
we're like, oh, this is too easy. And then we'll pick something harder. 01:14:44.840 |
And eventually that'll take us to stable diffusion and beyond, I imagine. 01:14:55.800 |
That's cool. I got some stuff to show you guys. If you're interested, I tried to, you 01:15:07.600 |
know, better understand what was going on in Tanishk's notebook and tried doing it in 01:15:13.080 |
a thousand different ways and also see if I could just start to make it a bit faster. 01:15:18.680 |
So that's what's in notebook 17, which I will share. So we've already seen the start of 01:15:31.120 |
notebook 17. Well, one thing I did do is just draw a picture for myself, partly just 01:15:36.800 |
to remind myself what the real ones look like, and they definitely have more detail than 01:15:43.480 |
the samples that Tanishk was showing. But they're not, you know, they're just 28 by 28. I mean, 01:15:51.040 |
they're not super amazing images and they're just black and white. So even if we're fantastic 01:15:55.360 |
at this, they're never going to look great because we're using a small, simple dataset. 01:16:01.200 |
As you always should, when you're doing any kind of R and D or experiments, you should 01:16:06.480 |
always use a small and simple dataset up until you're so good at it that it's not challenging 01:16:11.800 |
anymore. And even then, when you're exploring new ideas, you should explore them on small, simple datasets. 01:16:19.480 |
Yeah. So after I drew the various things, what I like to do is... One thing I found challenging 01:16:27.200 |
about working with your class, Tanishk, is I find when stuff is inside a class, it's 01:16:32.160 |
harder for me to explore. So I copied and pasted the before_batch contents and called it Noisify. 01:16:43.800 |
And so one of the things that's fun about doing that is it forces you to figure out what 01:16:48.040 |
the actual parameters to it are. And so now 01:16:52.360 |
I've got all of my various things to work with; these are the three parameters to the DDPM 01:17:01.560 |
callback's init. So then these things we can calculate from that. So with those, then actually 01:17:07.440 |
all we need is, yeah, what's the image that we're going to Noisify and then what's the 01:17:16.760 |
alpha bar, which I mean, we can get from here, but it sort of would be more general if you 01:17:20.560 |
can pass in your alpha bar. So yeah, this is just copying and pasting from the class. 01:17:27.240 |
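As a rough illustration of what such a noisify function can look like, here is a minimal sketch; the signature (an image batch plus an alpha-bar tensor indexed by timestep) follows the description above, but the notebook's exact code may differ:

```python
import torch

def noisify(x0, alphabar):
    # A sketch: pick a random timestep per image, mix the clean image with
    # Gaussian noise according to alpha-bar, and return ((noised image, t), noise)
    # as (input, target). Assumes x0 and alphabar live on the same device.
    device = x0.device
    n = len(x0)
    t = torch.randint(0, len(alphabar), (n,), device=device)     # a different random t per image
    eps = torch.randn_like(x0)                                   # the noise the model will learn to predict
    abar_t = alphabar[t].reshape(n, 1, 1, 1)                     # broadcast over channels and pixels
    x_t = abar_t.sqrt() * x0 + (1 - abar_t).sqrt() * eps
    return (x_t, t), eps

# e.g. (xt, t), eps = noisify(xb[:25], alphabar), then show xt using the t values as titles
```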
But the nice thing is then I could experiment with it. So I can call Noisify on my first 01:17:32.320 |
25 images and with a random T, each one's got a different random T. And so I can print 01:17:41.160 |
out the T and then I could actually use those as titles. And so this lets me, I thought 01:17:45.960 |
this was quite nice. I might actually rerun this because actually none of these look like 01:17:51.440 |
anything because as it turns out in this particular case, all of the Ts are over 200. And as Tanishk 01:17:58.280 |
mentioned, once you're over 200, it's almost impossible to see anything. So let me just 01:18:03.240 |
rerun this and see if we get a better, there we go. There's a better one. So with a T of 01:18:12.200 |
0, right? So remember T equals 0 is the pure image. So T equals 7, it's just a slightly 01:18:19.560 |
speckled image. And by 67, it's a pretty bad image. And by 94, it's very hard to see what 01:18:26.480 |
it is at all. And by 293, maybe I can see a pair of pants. I'm not sure I can see anything. 01:18:36.680 |
So yeah, by the way, there's a handy little, so I think we've looked at map before in the 01:18:45.520 |
course. There's an extended version of map in fastcore. And one of the nice things is 01:18:49.640 |
you can pass it a string and it basically just calls this format string if you pass 01:18:55.200 |
it a string rather than a function. And so this is going to stringify everything using 01:18:59.560 |
its representations. This is how I got the titles out of it, just by the way. So yeah, 01:19:06.720 |
I found this useful to be able to draw a picture of everything. And then I wanted to, yeah, 01:19:14.280 |
look at what else can I do. So then I took, you won't be surprised to see, I took the 01:19:19.640 |
sample method and turned that into a function. And I actually decided to pass everything 01:19:24.560 |
that it needs, even though, I mean, you could actually calculate pretty much all of these. But I 01:19:24.560 |
thought since I'd calculated them before, I'd just pass them in. So this is all copied 01:19:29.840 |
and pasted from Tanishk's version. And so that means the callback now is tiny, right? 01:19:32.800 |
Because before_batch is just noisify and the sample method just calls the sample function. 01:19:39.420 |
Now what I did do is I decided just to, yeah, I wanted to try like as many different ways 01:19:52.240 |
of doing this as possible, partly as an exercise to help everybody like see all the different 01:19:59.800 |
ways we can work with our framework, you know. So I decided not to inherit from TrainCB, 01:20:05.160 |
but instead I inherited from Callback. So that means I can't use Tanishk's nifty trick 01:20:14.080 |
of replacing predict. So instead I now need some way to pass in the two parts of the first 01:20:21.800 |
element of the tuple as separate things to the model and return the sample. So how 01:20:28.280 |
else could we do that? Well, what we could do is we could actually inherit from 01:20:33.040 |
UNet2DModel, which is what Tanishk used directly, and we could replace the model. 01:20:40.400 |
And so we could replace specifically the forward function. That's the thing that gets called. 01:20:45.080 |
And we could just call the original forward function, but rather than passing an x, 01:20:50.320 |
passing star x, and rather than returning that, we'll return that dot sample. Okay. So 01:20:56.840 |
if we do that, then we don't need the TrainCB anymore and we don't need the predict. 01:21:03.120 |
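In code, that wrapper is tiny; something along these lines, where the class name is arbitrary but UNet2DModel and the .sample attribute are the diffusers interface being described:

```python
from diffusers import UNet2DModel

class UNet(UNet2DModel):
    # the learner passes the (x_t, t) tuple as a single argument, so unpack it,
    # and return the tensor rather than the diffusers output object
    def forward(self, x): return super().forward(*x).sample
```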
So if you're not working with something as beautifully flexible as miniai, you can always 01:21:09.440 |
do this, you know, to replace your model so that it has the interface that you 01:21:16.600 |
need it to have. So now, again, we did the same as Tanishk had of creating the callback. 01:21:23.160 |
And now when we create the model, we'll use our UNet class, which we just created. I wanted 01:21:28.800 |
to see if I could make things faster. I tried dividing all of Tanishk's channels by two, 01:21:35.460 |
and I found it worked just as well. One thing I noticed is that it uses group_norm in the 01:21:41.720 |
U-Net, which we have briefly learned about before. And in group_norm, it splits the channels 01:21:47.400 |
up into a certain number of groups. And I needed to make sure that those groups had more than 01:21:56.320 |
one thing in. So you can actually pass in how many groups do you want to use in the 01:22:00.800 |
normalization. So that's what this is for. You've got to be a little bit careful 01:22:06.280 |
of these things; I didn't think of it at first. And I ended up, I think the num_groups might 01:22:10.720 |
have been 32, and I got an error saying you can't split 16 things into 32 groups. But 01:22:18.200 |
it also made me realize, actually, even in Tanishk's, you probably had 32 channels in the 01:22:22.400 |
first layer with 32 groups, and so maybe the group_norm wouldn't have been working as well. So there are 01:22:27.120 |
little subtle things to look out for. So now that we're not using anything inherited from 01:22:34.840 |
TrainCB, that means we either need to use TrainCB itself or just use our TrainLearner, 01:22:40.200 |
and then everything else is the same as what Tanishk had. 01:22:44.280 |
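As a sketch of the two model tweaks just mentioned (halved channel counts and a group-norm group count small enough for the smallest channel count), constructing the diffusers UNet2DModel might look roughly like this; the specific channel numbers are illustrative, not the notebook's exact values:

```python
from diffusers import UNet2DModel

# hypothetical halved channels; norm_num_groups must divide every block's channel count,
# so 8 groups works even for the 16-channel block
model = UNet2DModel(in_channels=1, out_channels=1,
                    block_out_channels=(16, 32, 64, 64), norm_num_groups=8)
```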
So then I wanted to look at the results of Noisify here, and we've seen this trick before, 01:22:51.640 |
which is we call fit, but don't call the training part of the fit, and use the SingleBatchCB 01:22:58.320 |
callback that we created way back when we first created learner. And now learn.batch 01:23:03.560 |
will contain the tuple of tuples, which we can then use that trick to show. So, I mean, 01:23:12.520 |
obviously we'd expect it to look the same as before, but it's nice. I always like to 01:23:16.480 |
draw pictures of everything all along the way, because, very, very often, I mean, 01:23:21.000 |
the first six or seven times I do anything, I do it wrong. So given that I know that, 01:23:26.600 |
I might as well draw a picture to try and see how it's wrong until it's fixed. It also 01:23:30.280 |
tells me when it's not wrong. Isn't there a show_batch function now that does something 01:23:37.320 |
similar? Yes, you wrote that show_image_batch, didn't you? I can't quite remember. Yeah, 01:23:46.240 |
we should remind ourselves how that worked. That's a good point. Thanks for reminding 01:23:55.480 |
me. Okay, so then I'll just go ahead and do the same thing that Tanishk did. But then the 01:24:02.160 |
next thing I looked at was I looked at the, how am I going to make this train faster? 01:24:07.160 |
I want a higher learning rate. And I realized, oddly enough, the diffusers code does not 01:24:16.960 |
initialize anything at all. They use the defaults, which just goes to show, like, even the experts 01:24:24.440 |
at Hugging Face don't necessarily think like, oh, maybe the PyTorch defaults aren't 01:24:33.680 |
perfect for my model. Of course they're not, because they depend on what activation function 01:24:37.960 |
you have and what res blocks you have and so forth. So I wasn't exactly sure how to 01:24:45.840 |
initialize it. Partly by chatting to Kat Crowson, who's the author of k-diffusion, and partly 01:24:55.040 |
by looking at papers and partly by thinking about my own experience, I ended up doing 01:25:00.240 |
a few things. One is I did do the thing that we talked about a while ago, which is to take 01:25:06.640 |
every second convolutional layer and zero it out. You could do the same thing with using 01:25:12.080 |
batch norm, which is what we tried. And since we've got quite a deep network, that seemed 01:25:16.200 |
like it might, it helps basically by having the non-identity path in the resnets do nothing 01:25:25.560 |
at first. So they can't cause problems. We haven't talked about orthogonalized weights 01:25:33.280 |
before and we probably won't because you would need to take our computational linear algebra 01:25:40.580 |
course to learn about that, which is a great course. Rachel Thomas did a fantastic job 01:25:45.040 |
of it. I highly recommend it, but I don't want to make it a prerequisite. But Kat mentioned 01:25:49.320 |
she thought that using orthogonal weights for the downsamplers was a good idea. And 01:25:57.000 |
then for the up blocks, they also set the second convs to zero. And something Kat mentioned 01:26:03.840 |
she found useful, which is also from, I think it's from the Dhariwal Google paper, is 01:26:10.660 |
to also zero out the weights of basically the very last layer. And so it's going to 01:26:15.960 |
start by predicting zero as the noise, which is something that can't hurt. So that's how 01:26:24.160 |
I initialized the weights. So I call init_ddpm on my model. Something that I found made a 01:26:31.400 |
huge difference is I replaced the normal Adam optimizer with one that has an epsilon of 01:26:36.200 |
1e-5. The default I think is 1e-8. And so to remind you, this is 01:26:44.000 |
when we divide by the kind of exponentially weighted moving average of the squared gradients, 01:26:51.400 |
when we divide by that, if that's a very, very small number, then it makes the effective 01:26:58.880 |
learning rate huge. And so we add this to it to make it not too huge. And it's nearly 01:27:04.680 |
always a good idea to make this bigger than the default. I don't know why the default 01:27:08.400 |
is so small. And I found until I did this, anytime I tried to use a reasonably large 01:27:13.320 |
learning rate somewhere around the middle of the one cycle training, it would explode. 01:27:20.120 |
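Pulling those points together, here is a rough sketch of the kind of initialization and optimizer change being described. The attribute names assume the diffusers UNet2DModel layout (down_blocks and up_blocks containing resnets and downsamplers, plus a final conv_out), and the helper name init_ddpm is just a label for this sketch; check the details against your own model:

```python
import torch
from torch import nn

def init_ddpm(model):
    for blk in model.down_blocks:
        for res in blk.resnets:
            nn.init.zeros_(res.conv2.weight)          # zero the second conv in each res block
        for ds in (blk.downsamplers or []):
            nn.init.orthogonal_(ds.conv.weight)       # orthogonal weights for the downsamplers
    for blk in model.up_blocks:
        for res in blk.resnets:
            nn.init.zeros_(res.conv2.weight)
    nn.init.zeros_(model.conv_out.weight)             # start by predicting zero as the noise

init_ddpm(model)
# a larger eps than the default 1e-8 stops the effective learning rate from blowing up
opt = torch.optim.Adam(model.parameters(), eps=1e-5)
```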
So that makes a big difference. So this way, yeah, I could train, I could get 0.016 after 01:27:31.000 |
five epochs and then sampling. So it looks all pretty similar. We've got some pretty 01:27:38.480 |
nice textures, I think. So then I was thinking, how do I get faster? So one way we can make 01:27:44.640 |
it faster is we can take advantage of something called mixed precision. So currently we're 01:27:53.720 |
using 32 bit floating point values. That's the default, also known as single precision. 01:28:02.040 |
And GPUs are pretty fast at doing 32 bit floating point values, but they're much, much, much, 01:28:09.240 |
much faster at doing 16 bit floating point values. So 16 bit floating point values aren't 01:28:17.120 |
able to represent a very wide range of numbers or much precision in the differences between 01:28:23.000 |
numbers. And so they're quite difficult to use, but if you can, you'll get a huge benefit 01:28:28.820 |
because modern GPUs, modern Nvidia GPUs specifically have special units that do matrix multiplies 01:28:37.920 |
of 16 bit values extremely quickly. You can't just cast everything to 16 bit because then 01:28:46.280 |
you, there's not enough precision to calculate gradients and stuff properly. So we have to 01:28:51.280 |
use something called mixed precision. Depending on how enthusiastic I'm feeling, I guess we 01:29:01.000 |
ought to do this from scratch as well. We'll see. We do have an implementation from scratch 01:29:09.160 |
because we actually implemented this before Nvidia implemented it in an earlier version 01:29:14.360 |
of fastai. Anyway, we'll see. So basically the idea is that we use 32 bit for things 01:29:23.400 |
where we need 32 bit and we use 16 bit for things where 16 bit is enough. So that's what we're 01:29:27.800 |
going to do: we're going to use this mixed precision. But for now, we're going to use 01:29:31.720 |
Nvidia's, you know, semi-automatic or fairly automatic code to do that for us. 01:29:40.600 |
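For the curious, PyTorch's fairly automatic mixed precision looks roughly like this; autocast and GradScaler are the standard torch.cuda.amp API, while the training-loop names (dataloader, model, opt, loss_fn) are placeholders rather than miniai code:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()                      # scales the loss so fp16 gradients don't underflow

for xb, yb in dataloader:                  # placeholder loop, not the miniai Learner
    opt.zero_grad()
    with autocast():                       # ops inside run in fp16 where that's safe
        loss = loss_fn(model(xb), yb)
    scaler.scale(loss).backward()          # backward on the scaled loss
    scaler.step(opt)                       # unscales gradients, then steps the optimizer
    scaler.update()
```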
Actually, we had a slight change of plan at this point, when we realized this lesson was going to 01:29:45.160 |
be over three hours in length and we should actually split it into two. So we're going 01:29:49.840 |
to wrap up this lesson here, and we're going to come back and implement this mixed precision next time.