Lesson 22: Deep Learning Foundations to Stable Diffusion
Chapters
0:00 Intro
0:30 Cosine Schedule (22_cosine)
6:05 Sampling
9:37 Summary / Notation
10:42 Predicting the noise level of noisy Fashion MNIST images
12:57 Why .logit() when predicting alpha bar t
14:50 Random baseline
16:40 mse_loss: why .flatten()
17:30 Model & results
19:03 Why are we trying to predict the noise level?
20:10 Training diffusion without t - first attempt
22:58 Why isn't it working?
27:02 Debugging (summary)
29:29 Bug in DDPM - paper that cast some light on the issue
38:40 Karras (Elucidating the Design Space of Diffusion-Based Generative Models)
49:47 Picture of target images
52:48 Scaling problem (scalings)
59:42 Training and predictions of modified model
63:49 Sampling
66:05 Sampling: Problems of composition
67:40 Sampling: Rationale for rho selection
69:40 Sampling: Denoising
75:26 Sampling: Heun's method (FID: 0.972)
79:00 Sampling: LMS sampler
80:00 Karras summary
83:00 Comparison of different approaches
85:00 Next lessons
00:00:00.000 |
All right. Hi, gang. And here we are in Lesson 22, joined by the legends themselves, Jono and Tanishk. 00:00:12.940 |
And today, you'll be shocked to hear that we are going to look at a Jupyter notebook. 00:00:25.160 |
We're going to look at notebook 22. This is pretty quick. Just, you know, improvement, 00:00:37.520 |
pretty simple improvement to our DDPM/DDIM implementation for fashion MNIST. And this 00:00:53.000 |
is all the same so far, but what I've done is I've made one quite significant change. 00:01:03.120 |
And some of the changes we'll be making today are all about making life simpler. And they're 00:01:08.320 |
kind of reflecting the way the papers have been taking things. And it's interesting to 00:01:14.000 |
see how the papers have not only made things better, they made things simpler. And so one 00:01:21.880 |
of the things that I've noticed in recent papers is that there's no longer a concept 00:01:29.200 |
of n steps, which is something we've always had before, and it always bothered me a bit, 00:01:36.000 |
this capital T thing. You know, this T over T, it's basically saying this is time step 00:01:44.480 |
number, say 500 out of 1000, so it's time step 0.5. Why not just call it 0.5? And the 00:01:55.920 |
answer is, well, we can. So we talked last time about the cosine scheduler. We didn't 00:02:01.400 |
end up using it because I came up with an idea which was, you know, simpler and nearly 00:02:08.760 |
the same, which is just to change our beta max. But in this next notebook, I decided let's 00:02:14.880 |
use the cosine scheduler, but let's try to get rid of the n steps thing and the capital 00:02:21.560 |
T thing. So here is alpha bar again. And now I've got rid of the capital T. So now I'm going 00:02:30.120 |
to assume that your time step is between 0 and 1, and it basically represents what percentage 00:02:37.320 |
of the way through the diffusion process are you? So 0 would be all noise, and 1 would be 00:02:42.840 |
- well, no, sorry, the other way around - 0 would be all clean, and 1 would be all noise. 00:02:47.400 |
So how far through the forward diffusion process? So other than that, this is exactly the same 00:02:53.540 |
equation we've already seen. And I realized something else, which is kind of fun, which 00:02:56.800 |
is you can take the inverse of that. So you can calculate T. So we would basically first 00:03:10.060 |
take the square root, and we would then take the inverse cos, and we would then divide 00:03:21.680 |
by pi over 2, or equivalently multiply by 2 over pi. So it's interesting - now the 00:03:30.840 |
alpha bar is not something we look up in a list, it's something we calculate with a function 00:03:37.160 |
from a float. And so yeah, interestingly, that means we can also calculate T from an 00:03:42.280 |
alpha bar. So Noisify has changed a little. So now when we get the alpha bar through our 00:03:52.120 |
time step, we don't look it up, we just call the function. And now the time step is a random 00:04:00.920 |
float between 0 and 1, actually between 0 and 0.999, which actually I'm sure there's a function 00:04:09.920 |
I could have chosen to get a float in this range, but I just capped it because I was lazy. Couldn't 00:04:14.960 |
be bothered hooking it up. Other than that, Noisify is exactly the same. So we're still 00:04:22.440 |
returning the xt, the time step, which is now a float, and the noise. That's the thing 00:04:30.960 |
we're going to try and predict - the dependent variable - and this tuple there is our inputs to 00:04:35.520 |
the model. All right, so here is what that looks like. So now when we look at our input 00:04:44.600 |
to our unit training process, you can see we've got a T of 0.05, so 5% of the way through 00:04:51.840 |
the forward diffusion process, it looks like this, and 65% through it looks like this. 00:04:58.000 |
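The continuous-time schedule and noisify step just described can be sketched like this (a minimal sketch, not the notebook's exact code; the function and variable names are illustrative):

```python
import math
import torch

# Continuous-time cosine schedule: t in [0, 1], where 0 = clean, 1 = all noise.
def abar(t):
    return torch.cos(t * math.pi / 2) ** 2

# Inverse: recover t from alpha-bar (sqrt, then arccos, then divide by pi/2).
def inv_abar(ab):
    return torch.acos(torch.sqrt(ab)) * 2 / math.pi

def noisify(x0):
    # One random t per image; capped just below 1, as mentioned in the lesson.
    t = torch.rand(len(x0), device=x0.device).clamp(0, 0.999)
    eps = torch.randn_like(x0)
    ab = abar(t).reshape(-1, 1, 1, 1)
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps
    return (xt, t), eps  # model inputs, and the noise we try to predict
```

Note that because `abar` is just a function of a float, `inv_abar(abar(t))` recovers `t` exactly, which is the "kind of fun" inverse mentioned above.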
So now the time step and basically the process is more of a kind of a continuous time step 00:05:08.200 |
and a continuous process. Rather before we were having these discrete time steps here, 00:05:13.280 |
we get just any random value that could be between 0 and 1. I think, yeah, that's also 00:05:19.720 |
Yeah, it's kind of more convenient, you know, to have... 00:05:23.240 |
To have a function to call. Yeah, I find this makes life a little bit easier. So the model's the 00:05:31.640 |
same, the callbacks are the same, the fitting process is the same. And so something which 00:05:39.400 |
is kind of fun is that we could now, we can, when we do now, create a little denoise function. 00:05:46.680 |
So we can take, you know, this batch of data that we generated, the noisified data, so 00:05:54.160 |
here it is again, and we can denoise it. So we know the T for each element, obviously. 00:06:03.880 |
So remember T is different for each element now. And we can therefore calculate the alpha 00:06:11.720 |
bar for each element. And then we can just undo the noisification to get the denoised 00:06:19.280 |
version. And so if we do that, there's what we get. And so this is great, right? It shows 00:06:25.360 |
you what actually happens when we run a single step of the model on varyingly partially 00:06:33.960 |
noised images. And this is something you don't see very often because I guess not many people 00:06:39.120 |
are working in these kind of interactive notebook environments where it's really easy to do 00:06:42.640 |
this kind of thing. But I think this is really helpful to get a sense of like, okay, if you're 00:06:46.600 |
25% of the way through the forward diffusion process, this is what it looks like when you 00:06:51.800 |
undo that. If you're 95% of the way through it, this is what happens when you undo that. 00:06:59.800 |
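That one-step "undo" can be written by solving the noising equation for x0 (a sketch under the assumption, from the course's convention, that the model takes `(xt, t)` and returns predicted noise):

```python
import math
import torch

# Same continuous-time cosine schedule as before.
def abar(t):
    return torch.cos(t * math.pi / 2) ** 2

def denoise(model, xt, t):
    # xt = sqrt(abar)*x0 + sqrt(1-abar)*eps, so solve for x0 using the
    # model's noise estimate in place of the true eps.
    ab = abar(t).reshape(-1, 1, 1, 1)
    eps_hat = model((xt, t))
    return (xt - (1 - ab).sqrt() * eps_hat) / ab.sqrt()
```

If the model predicted the noise perfectly, this would return the clean image exactly; in practice it returns the one-step estimates shown in the notebook.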
So you can see here, it's basically like, oh, I don't really know what the hell's going 00:07:03.160 |
on, so it's left a noisy mess. Yeah, I guess my feeling from looking at this is, I'm impressed, 00:07:15.720 |
you know, like this 45% noise thing, it looks all noise to me. It's found the long-sleeved 00:07:24.400 |
top. And yeah, it's actually pretty close to the real one. I looked it up, or you might 00:07:31.400 |
see it later, it's a little bit more of a pattern here, but it even gives a sense of 00:07:35.320 |
the pattern. So it shows you how impressive this is. So this is 35%. You can kind of see 00:07:41.040 |
there's a shoe there, but it's really picked up the shoe nicely. So these are very impressive 00:07:47.040 |
models in one step, in my opinion. So, okay, so sampling is basically the same, except 00:08:01.640 |
now rather than starting with using the range function to create our timesteps, we use 00:08:09.360 |
lin space to create our timesteps. So our timesteps start at, you know, if we did 1000, it would 00:08:17.560 |
be 0.999, and they end at 0, and then they're just linearly spaced with this number of steps. 00:08:24.520 |
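That sampling grid can be sketched as follows (illustrative values for 100 steps):

```python
import torch

steps = 100
# Timesteps now live on [0, 1): start just below 1 (all noise) and step
# linearly down to 0 (clean), decreasing by 1/steps each time.
ts = torch.linspace(1 - 1/steps, 0, steps)
# ts[0] = 0.99, ts[1] = 0.98, ..., ts[-1] = 0.0; at each step the "next"
# alpha-bar is evaluated at t - 1/steps.
```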
So other than that, you know, a bar we now calculate, and the next a bar is going to 00:08:33.160 |
be whatever the current step is, minus 1 over steps. So if you're doing 100 steps, then 00:08:40.520 |
you'd be minus 0.01. So this is just stepping through linearly. And yeah, that's actually 00:08:51.680 |
it for changes. So if we just do DDIM for 100 steps, you know, that works really well. 00:09:04.000 |
We get a FID of 3, which is actually quite a bit better than we had on 100 steps for 00:09:13.160 |
our previous DDIM. So this definitely seems like a good sampling approach. 00:09:21.280 |
And I know Jono's going to talk a bit more shortly about, you know, some of the things 00:09:26.760 |
that can make better sampling approaches, but yeah, definitely we can see it making 00:09:31.960 |
a difference here. Did you guys have anything you wanted to say about this before we move 00:09:37.600 |
on? No, but it is a nice transition towards some of the other things we'll be looking at to 00:09:42.880 |
start thinking about how do we frame this. And it's also good, like the idea, so the 00:09:48.280 |
original DDPM paper has this 1,000 time steps, and a lot of people followed that. But the 00:09:53.680 |
idea that you don't have to be bound to that, and maybe it is worth breaking that convention. 00:09:57.520 |
I know Tanishk made that meme about, you know, the '15 competing standards' comic. 00:10:03.160 |
But yeah, sometimes it's helpful to reframe it, okay, time goes from 0 to 1. That can 00:10:07.400 |
simplify some things. It complicates others, but yeah, it's nice to think how you can reframe 00:10:13.480 |
In fact, where we will head today by the time we get to notebook 23, we will see, you know, 00:10:21.560 |
even simpler notation. And yeah, simpler notation generally comes. I think what happens is over 00:10:27.600 |
time people understand better what's the essence of the problem and the approach, and then the notation gets pared down to match. 00:10:40.680 |
So okay, so the next one I wanted to share is something which is an idea we've been working 00:10:47.120 |
on for a while, and it's some new research. So partly, I guess this is an interesting 00:10:52.880 |
like insight into how we do research. So this is 22 noise pred. And the basic idea of this 00:11:02.440 |
was, well, actually, I've got to take you through it to see what the basic idea is. 00:11:06.480 |
So what I'm going to do is I'm going to create, okay, so fashion MNIST as before, but I'm 00:11:13.120 |
going to create a different kind of model. I'm not going to create a model that predicts 00:11:17.480 |
the noise given the noised image in t. Instead, I'm going to try to create a model which predicts 00:11:26.600 |
t given the noised image. So why did I want to do that? Well, partly, well, entirely because 00:11:34.000 |
I was curious. I felt like when I looked at something like this, I thought it was pretty 00:11:41.380 |
obvious roughly how much noise each image had. And so I thought, why are we passing 00:11:50.280 |
noise when we call the model? Why are we passing in the noised image and the amount of noise 00:11:55.640 |
or the t? Given that I would have thought the model could figure out how much noise 00:11:59.560 |
there is. So I wanted to check my contention, which is that the model could figure out how 00:12:04.520 |
much noise there is. So I thought, okay, well, let's create a model that would try and figure 00:12:08.200 |
out how much noise there is. So I created a different noisify now, and this noisify grabs 00:12:17.160 |
an alpha bar t randomly. And it's just a random number between 0 and 1, one per 00:12:32.600 |
item in the batch. And so then after randomly grabbing an alpha bar t, we then 00:12:38.800 |
noisify in the usual way. But now our independent variable is the noised image and the dependent 00:12:45.080 |
variable is alpha bar t. And so we've got to try to create a model that can predict 00:12:48.920 |
alpha bar t given a noised image. Okay, so everything else is the same as usual. And 00:12:57.280 |
so we can see an example. You've got alpha_bar_t.squeeze().logit(). Oh, yeah, that's true. 00:13:04.320 |
So the alpha bar t goes between 0 and 1. So we've got a choice. Like, I mean, we don't 00:13:12.080 |
have to do anything. But normally, if you've got something between 0 and 1, you might consider 00:13:16.400 |
putting a sigmoid at the end of your model. But I felt like the difference between 0.999 00:13:24.080 |
and 0.99 is very significant. So if we do log it, then we don't need the sigmoid at the 00:13:33.440 |
end anymore. It will naturally cover the full range of kind of-- it ought to be centered 00:13:39.280 |
at 0. It would have covered all the normal kind of range of numbers. And it also will 00:13:44.060 |
treat equal ratios as equally important at both ends of the spectrum. So that was my 00:13:51.800 |
hypothesis was that using logit would be better. I did test it and it was actually very dramatically 00:13:56.880 |
better. So without this logit here, my model didn't work well at all. And so this is like 00:14:01.760 |
an example of where thinking about these details is really important. Because if I hadn't 00:14:07.040 |
have done this, then I would have come away from this bit of research thinking like, oh, 00:14:10.720 |
I was wrong. We can't predict noise amount. Yeah. So thanks for pointing that out, Tanishk. 00:14:17.280 |
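A quick illustration of why the logit target helps (the numbers come straight from the definition logit(p) = log(p/(1-p)), not from the notebook):

```python
import torch

ab = torch.tensor([0.5, 0.99, 0.999])
lg = ab.logit()
# logit(0.5) = 0; logit(0.99) ~ 4.6; logit(0.999) ~ 6.9. The "small"
# difference between 0.99 and 0.999 becomes a large, learnable gap, the
# target is centred at 0, and no sigmoid is needed on the model output.
```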
Yeah. So that's why in this example of a mini batch, you can see that the numbers can be 00:14:23.280 |
negative or positive. So 0 would represent an alpha bar of 0.5. So here, 3.05 is not very 00:14:31.120 |
noisy at all, whereas negative 1 is pretty noisy. So the idea is that, yeah, given this 00:14:40.400 |
image, you would have to try to predict 3.05. So one thing I was kind of curious about is 00:14:49.680 |
like, it's always useful to know is like, what's the baseline? Like, what counts as good? Because 00:14:54.120 |
often people will say to me like, oh, I created a model and the MSE was 2.6. And I'll be like, 00:14:59.760 |
well, is that good? Well, it's the best I can do, but is it good? Like, or is it better 00:15:05.440 |
than random or is it better than predicting the average? So in this case, I was just like, 00:15:12.080 |
okay, well, what if we just predicted? Actually, this is slightly out of date. I should have 00:15:16.440 |
said 0 here rather than 0.5, but never mind close enough. So this is before I did the 00:15:24.040 |
logit thing. So I basically was looking at like, what's the loss if you just always predicted 00:15:34.760 |
a constant, which as I said, I should have put 0 here, haven't updated it. And so it's 00:15:41.760 |
like, oh, that would give you a loss of 3.5. Or another way to do it is you could just 00:15:49.560 |
put MSE here and then look at the MSE loss between 0.5 and your various, just a single 00:16:00.320 |
mini batch of alpha bar logits. Yeah, so we wanted to get some baseline - if we're getting 00:16:10.880 |
something that's about 3, then we basically haven't done any better than random. And so 00:16:16.760 |
in this case, this model, it doesn't actually have anything to learn. It always returns 00:16:23.260 |
the same thing. So we can just call fit with train=False just to find the loss. 00:16:28.480 |
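One way to get that baseline number without training anything: the loss of always predicting 0 is just the mean squared target, i.e. the variance of logit(U(0,1)), which is pi^2/3 ~ 3.29 (a standard logistic variable) - matching the "about 3" figure mentioned here. A sketch:

```python
import torch

# Baseline: always predict 0, the centre of the logit range. The MSE is then
# the mean squared target -- the variance of a standard logistic variable,
# pi^2/3 ~ 3.29. Any trained model should beat this comfortably.
ab = torch.rand(100_000).clamp(1e-6, 1 - 1e-6)  # clamp avoids logit(0) = -inf
baseline = (ab.logit() ** 2).mean()
```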
So this is just a couple of ways of getting quickly finding a loss for a baseline naive 00:16:34.400 |
model. One thing that thankfully PyTorch will warn you about is if you try to use MSE and 00:16:48.800 |
your inputs and targets have different shapes, it will broadcast and give you probably not 00:16:53.720 |
the result you would expect, and it will give you a warning. So one way to avoid that is 00:16:58.480 |
just to use dot flatten on each. So this kind of flattened MSE is useful to avoid the warning 00:17:09.200 |
and also avoid getting weird errors or weird results. So we use that for our loss. So the 00:17:15.560 |
model's the model that we always use. So it's kind of nice. We just use our same old model. 00:17:22.320 |
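The broadcasting pitfall described above, and the flattened loss that avoids it, can be sketched like this (the helper name is illustrative):

```python
import torch
import torch.nn.functional as F

def mse_flat(preds, targs):
    # Flatten both so a (n,1)-vs-(n,) shape mismatch can't silently
    # broadcast into an (n,n) matrix of pairwise errors.
    return F.mse_loss(preds.flatten(), targs.flatten())

preds = torch.tensor([[0.], [1.], [2.], [3.]])  # shape (4, 1), e.g. model output
targs = torch.tensor([0., 1., 2., 3.])          # shape (4,)
# mse_flat(preds, targs) is 0 as expected, whereas F.mse_loss(preds, targs)
# broadcasts (with a UserWarning) and returns a nonzero, meaningless value.
```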
No changes, even though we're doing something totally different. Oh, well, okay, that's 00:17:31.360 |
not quite true. One difference is that our output, we just have one output now, because 00:17:35.600 |
this is now a regression model. It's just trying to predict a single number. And so 00:17:39.880 |
our learner now uses MSE as a loss. Everything else is the same as usual. So we could go 00:17:48.240 |
ahead and train it. And you can see, okay, the loss is already much better than 3, so 00:17:51.680 |
we're definitely learning something. And we end up with a 0.075 mean squared error. That's 00:18:01.080 |
pretty good considering, you know, there's a pretty wide range of numbers we're trying 00:18:05.800 |
to predict here. So I've got to save that as noise prediction on sigma. So save that 00:18:14.920 |
model. And so we can take a look at how it's doing by grabbing our one batch of noise images, 00:18:25.120 |
putting it through our T model. Actually, it's really an alpha bar model, but never 00:18:29.960 |
mind, call it a T model. And then we can take a look to see what it's predicted for each 00:18:35.120 |
one. And we can compare it to the actual for each one. And so you can see here it said, 00:18:42.240 |
oh, I think this is about 0.91. And actually, it is 0.91. I said, oh, here it looks like 00:18:47.760 |
about 0.36. And yeah, it is actually 0.36. So, you know, you can see overall 0.72. It's 00:18:54.120 |
actually 0.72. Well, those are exactly right. This one's 0.02 off. But yeah, my hypothesis 00:18:59.760 |
was correct, which is that we, you know, we can predict the thing that we were putting 00:19:06.480 |
in manually as input. So there's a couple of reasons I was interested in checking this 00:19:12.820 |
out. You know, the first was just like, well, yeah, wouldn't it be simpler if we weren't 00:19:19.040 |
passing in the T each time? You know, why not pass in the T each time? But it also felt 00:19:24.600 |
like it would open up a wider range of kind of how we can do sampling. The idea of doing 00:19:31.700 |
sampling by like precisely controlling the amount of noise that you try to remove each 00:19:37.720 |
time and then assuming you can remove exactly that amount of noise each time feels limited 00:19:45.000 |
to me. So I want to try to remove this constraint. So having, yeah, built this model, I thought, 00:19:56.560 |
okay, well, you know, which is basically like, okay, I think we don't need to pass T in. 00:20:01.180 |
Let's try it. So what I then did is I replicated the 22 cosine notebook. I just copied it, 00:20:06.720 |
pasted it in here. But I made a couple of changes. The first is that Noisify doesn't 00:20:16.440 |
return T anymore. So there's no way to cheat. We don't know what T is. And so that means 00:20:25.440 |
that the unit now doesn't have T, so it's actually going to pass zero every time. So 00:20:33.400 |
it has no ability to learn from T because it doesn't get T. So it doesn't really matter 00:20:38.840 |
what we pass in. We could have changed the unit to like remove the conditioning on T. 00:20:46.260 |
But for research, this is just as good, you know, for finding out. And it's good to be 00:20:51.600 |
lazy when doing research. There's no point doing something a fancy way where you can 00:20:55.080 |
do it a quick and easy way before you even know if it's going to work. So yeah, that's 00:21:00.700 |
the only change. So we can then train the model and we can check the loss. So the loss 00:21:06.880 |
here is 0.034. And previously it was 0.033. So interestingly, you know, maybe it's a tiny 00:21:21.080 |
bit worse at that, you know, but it's very close. Okay, so we'll save that model. And 00:21:31.600 |
then for sampling, I've got exactly the same DDIM step as usual. And my sampling is exactly 00:21:43.360 |
the same as usual, except now, when I call the model, I have no T to pass in. So we just 00:21:53.720 |
pass in this. I mean, I still know T because I'm still using the usual sampling approach, 00:22:02.560 |
but I'm not passing it to the model. And yeah, we can sample. And what happens is actually 00:22:11.360 |
pretty garbage. 22 is our FID. And as you can see here, you know, some of the images 00:22:24.720 |
are still really noisy. So I totally failed. And so that's always a little discouraging 00:22:34.720 |
when you think something's going to work and it doesn't. But my reaction to that is like, 00:22:39.280 |
if I think something's going to work and it doesn't is to think, well, I'm just going 00:22:43.040 |
to have to do a better job of it. You know, it ought to work. So I tried something different, 00:22:51.840 |
which is I thought like, okay, since we're not passing in the T, then we're basically 00:23:02.700 |
saying like, how much noise should you be removing? It doesn't know exactly. So it might 00:23:08.320 |
remove a little bit more noise that we want or a little bit less noise than we want. And 00:23:13.000 |
we know from the testing we did that sometimes it's out by like, in this case, 0.02. And 00:23:23.000 |
I guess if you're out consistently, sometimes it's got to end up not removing all the noise. 00:23:29.860 |
So the change I made was to the DDIM step, which is here. And let me just copy this and 00:23:39.880 |
get rid of the commented out sections just to make it a bit easier to read. Okay. So 00:23:50.320 |
the DDIM step, this is the normal DDIM step. Okay. And so step one is the same. So don't 00:23:56.320 |
worry about that because it's the same as we've seen before. But what I did was I actually 00:24:02.240 |
used my T model. So I passed the noised image into my T model, which is actually an alpha 00:24:08.900 |
bar model, to get the predicted alpha bar. And this is remember the predicted alpha bar 00:24:15.640 |
for each image, because we know from here that sometimes, so sometimes it did a pretty 00:24:20.860 |
good job, right? But sometimes it didn't. So I felt like, okay, we need a predicted alpha 00:24:26.480 |
bar for each image. What I then discovered is sometimes that could be really too low. 00:24:39.100 |
So what I wanted to make sure is it wasn't too crazy. So I then found the median for 00:24:43.760 |
a mini batch of all the predicted alpha bars, and I clamped it to not be too far away from 00:24:49.360 |
the median. And so then what I did when I did my X naught hat is rather than using alpha 00:24:58.400 |
bar T, I used the estimated alpha bar T for each image, clamped to be not too far away 00:25:06.720 |
from the median. And so this way it was updating it based on the amount of noise that actually 00:25:12.520 |
seems to be left behind, rather than the assumed amount of noise that should be left behind 00:25:19.240 |
you know, if we assume it's removed the correct amount. And then everything else is the same. 00:25:26.800 |
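The per-image alpha-bar estimate with median clamping can be sketched like this (illustrative names; the clamping width is an assumption, not the notebook's exact value):

```python
import torch

def estimated_abar(tmodel, xt, max_dist=0.5):
    # tmodel was trained to predict logit(alpha_bar), so undo that with a
    # sigmoid, then clamp each image's estimate to stay near the batch
    # median so a single wild prediction can't wreck the update.
    ab_est = tmodel(xt).sigmoid()
    med = ab_est.median().item()
    return ab_est.clamp(med - max_dist, med + max_dist)
# The DDIM step then uses this estimate, rather than the schedule's
# alpha_bar(t), when forming x0_hat.
```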
So when I did that, wow, it made all the difference. And here it is. They are beautiful 00:25:38.360 |
pieces of clothing. So 3.88 versus 3.2. That's possibly close enough, like I'd have to run 00:25:53.760 |
it a few times, you know, my guess is maybe it's a tiny bit worse, but it's pretty close. 00:26:00.360 |
But like this definitely gives me some encouragement that, you know, even though this is like something 00:26:10.400 |
I just did in a couple of days, whereas the kind of with-T approaches have been 00:26:13.980 |
developed since 2015, and we're now in 2023. You know, I would expect it's quite likely 00:26:21.280 |
that these kind of like, no, no T approaches could eventually surpass the T based approaches. 00:26:35.080 |
And like one thing that definitely makes me think there's room to improve is if I plot 00:26:39.880 |
the FID or the KID, for each sample during the reverse diffusion process, it actually 00:26:45.080 |
gets worse for a while. I'm like, okay, well, that's, that's a bad sign. I have no idea 00:26:50.320 |
why that's happening. But it's a sign that, you know, if we could improve each step that 00:26:55.600 |
one would assume we could get better than 3.8. So yeah, Tanishk, do you have any thoughts 00:27:03.740 |
about that, or questions or comments? Maybe just to, like, highlight the research 00:27:12.440 |
process a little bit: it wasn't like this linear thing of like, oh, here's this issue. 00:27:17.280 |
It didn't work as well as we thought. Oh, here's the fix. We just scrapped this. You know, 00:27:22.300 |
this was like multiple days of like, discussing and like Jeremy saying, like, you know, I'm 00:27:26.760 |
tearing my hair out. Do you guys have any ideas? And oh, what about this? And oh, I just 00:27:30.360 |
in the team paper, they do this clamping, maybe that'll help. You know, so there's 00:27:33.240 |
a lot of back and forth. And also a lot of like, you saw that code that was commented 00:27:36.640 |
out there, prints, xt.min, xt.max, alphabar, pred, you know, just like seeing, oh, okay, 00:27:43.760 |
you know, my average prediction is about what I would expect. But sometimes the middle of 00:27:46.520 |
the max goes, you know, 2, 3, 8, 16, 150, 212 million, infinity, you know, maybe like 00:27:54.640 |
one or two little values that would just skyrocket out. Yeah, and so that kind of like, debugging 00:28:00.160 |
and exploring and printing things out. And actually, our initial discussions about this 00:28:05.800 |
idea, I kind of said to you guys, before lesson one of part two, I said, like, it feels to 00:28:12.600 |
me like we shouldn't need the t thing. And so it's actually been like, mumbling away 00:28:17.520 |
in the background for the months. Yeah, yeah. And I guess I mean, we should also mention 00:28:24.960 |
we have tried this, like a friend of ours trained a no T version of stable diffusion 00:28:28.960 |
for us. And we did the same sort of thing. I trained a pretty bad T predictor and it 00:28:34.160 |
sort of generates samples. So we're not like focusing on that large scale stuff yet. But 00:28:39.840 |
it is fun to like, every now and again, got this idea from fashion in this, we are trying 00:28:44.480 |
these out on some bigger models and seeing, okay, this does seem like maybe it'll work. 00:28:48.920 |
And to down the line that future plan is to say that's actually, you know, spend the time 00:28:52.400 |
train a proper model, and see, yeah, see how well that does. If it's interesting, you say 00:28:57.120 |
a friend of ours, we can be more specific. It's Robert, one of the two lead authors of 00:29:01.640 |
the stable diffusion paper who actually has been fine tuning a real stable diffusion 00:29:07.480 |
model, which is without T and it's looking super encouraging. So yeah, that'll be fun 00:29:16.600 |
to play with with this new, you know, we'll have to train a T predictor for that. See 00:29:22.160 |
how it looks. Yeah. All right. So I guess the other area we've been talking about kind 00:29:32.080 |
of doing some research on is this weird thing that came up over the last two weeks where 00:29:39.680 |
a bug in our DDPM implementation, where we accidentally weren't going from minus 00:29:45.600 |
one to one for the input range, it turned out that actually being from minus one to 00:29:50.760 |
one wasn't a very good idea anyway. And so we ended up centering it as being from minus 00:29:58.460 |
point five to point five, and John O and Tanishk have managed to actually find a paper. Well, 00:30:06.800 |
I say find a paper, a paper has come out in the last 24 hours, which has coincidentally 00:30:15.160 |
cast some light on this and has also cited a paper that we weren't aware of, which was 00:30:21.080 |
not released in the last 24 hours. So John O, are you going to tell us a bit about that? 00:30:25.120 |
Yeah, sure. I can do that. So it's funny, this was such perfect timing because I actually 00:30:31.600 |
got up early this morning planning to run with the different input scalings and the 00:30:36.720 |
cosine schedule that Jeremy was showing and some of the other schedulers we look at. I 00:30:40.000 |
thought it might be nice for the lesson to have a little plot of like, what is the fit 00:30:44.440 |
with these different solvers and input scalings, but it was going to be a lot of work. I'm 00:30:48.080 |
not looking forward to doing the groundwork. And then Tanishk sent me this paper, which 00:30:52.120 |
AK had just tweeted out because he reviews anything that comes up on archive every day 00:30:56.640 |
on the importance of noise scheduling for diffusion models. And this is by a researcher 00:31:01.000 |
at the Google Brain team, who's also done a really cool recent paper on something called 00:31:05.560 |
a recurrent interface network outside of the scope of this lesson, but also worth checking 00:31:09.680 |
out. Yeah, so this paper they're hoping to study this noise scheduling and the strategies 00:31:16.340 |
that you take for that. And they want to show that, number one, noise scheduling is crucial 00:31:20.320 |
for performance and the optimal one depends on the tasks. When increasing the image size, 00:31:25.320 |
the noise scheduling that you want changes and scaling the input data by some factor 00:31:31.960 |
is a good strategy for working with this. And that's the big thing we've been talking 00:31:35.800 |
about, right? Yeah, that's what we've been doing where we said, oh, do we scale from 00:31:39.320 |
minus 0.5 to 0.5 or minus 1 to 1 or do we normalize? And so they demonstrate the effectiveness 00:31:45.160 |
by training a really good high resolution model on ImageNet, so a class-conditional model. 00:31:51.320 |
That's correct. Yeah, amazing samples. They'll show one later. So I really like this paper. 00:31:56.760 |
It's very short and concise, and it just gets all the information across. And so they introduced 00:32:01.960 |
us here. We have this noising process, or noisify function, where we have square root of something 00:32:07.600 |
times x plus square root of 1 minus that something times the noise. And here they use gamma, 00:32:13.800 |
gamma of t, which is often used for the continuous time case. So instead of the alpha bar and 00:32:18.400 |
the beta schedules for 1,000 timesteps, there'll be some function gamma of t that 00:32:22.720 |
tells you what your alpha bar should be. Okay, so our function is actually called 00:32:27.400 |
abar, but it's the same thing. Yeah, same thing. It takes in a time step from 0 to 1, 00:32:32.720 |
and then that's used to noise the image. Interestingly, what they're showing here actually 00:32:37.040 |
is something that we had discovered, and I've been complaining that my DDIMs with an 00:32:44.200 |
eta of less than one weren't working, which is to say when I added extra noise to the 00:32:49.200 |
image, it wasn't working. And what they're showing here is like, oh yeah, duh, if you 00:32:55.560 |
use a smaller image, then adding extra noise is probably not a good idea. 00:33:01.240 |
Yeah. And so they use a lot of reference in this paper to like information being destroyed 00:33:07.520 |
and signal to noise ratios. And that's really helpful for thinking about because it's not 00:33:11.920 |
something that's obvious, but at 64 by 64 pixels, adjacent pixels might have much less 00:33:17.600 |
in common versus the same amount of noise added at a much higher resolution, the noise 00:33:22.800 |
kind of averages out and you can still see a lot of the image. So yeah, that's one thing 00:33:26.800 |
they highlight is that the same noise level for different image sizes, it might be a harder 00:33:32.200 |
or easier task. And so they investigate some strategies for this. They look at the different 00:33:37.600 |
noise schedule functions. So we've seen the original version from the DDPM paper. We've 00:33:44.080 |
seen the cosine schedule and we've seen, I think we might look at, or the next thing 00:33:50.600 |
that Jeremy's going to show us, a sigmoid based schedule. So they show the continuous 00:33:55.680 |
time versions of that and they plot how you can change various parameters to get these 00:33:59.660 |
different gamma functions or in our case, the alpha bar, where we're starting at all image, 00:34:08.560 |
no noise at t equals zero, moving to all noise, no image at t equals one. But the path that 00:34:14.280 |
you take is going to be different for these different classes of functions and parameters. 00:34:19.760 |
And the signal to noise ratio, that's what this or the log signal to noise ratio is going 00:34:24.980 |
to change over that time as well. And so that's one of the knobs we can tweak. We're saying 00:34:30.040 |
our diffusion model isn't training that well, we think it might be related to the noise 00:34:33.780 |
schedule and so on. One of the things you could do is try different noise schedules, 00:34:37.080 |
either changing the parameters in one class of noise schedule or switching from a linear 00:34:42.240 |
to a cosine to a sigmoid. And then the second strategy is kind of what we were doing in 00:34:47.640 |
those experiments, which is just to add some scaling factor to x zero. 00:34:59.040 |
Exactly. And so that's the second dial that you can tweak is to say keeping your noise 00:35:03.960 |
schedule fixed, maybe you just scale x zero, which is going to change the ratio of signal to noise. 00:35:09.680 |
And I think that's what figure 4c there is: what we were accidentally doing. 00:35:14.600 |
Yes. Yeah, exactly. And so let's see if we can get to... Oh, yeah. So that again changes the signal 00:35:23.560 |
to noise for different scalings you get. And so that's fine. So they have a compound 00:35:28.840 |
strategy that combines some of those things. And this is the important part, they 00:35:32.320 |
do their experiments. And so they have a nice table of investigating different schedules, 00:35:39.520 |
cosine schedules and sigmoid schedules. And in bold are the best results. And you can 00:35:43.640 |
see for 64 by 64 images versus 128 versus 256, the best schedule is not necessarily always 00:35:50.200 |
the same. And so that's like important finding number one, depending on what your data looks 00:35:56.640 |
like using a different noise schedule might be optimal. There's no one true best schedule. 00:36:01.840 |
There's no one value of, you know, beta min and beta max, that's just magically the 00:36:06.280 |
best. Likewise, for this input scaling at different sizes, with whatever schedules they tested, 00:36:16.520 |
and different values were kind of optimal. And so, yeah, it's just a really great illustration, 00:36:25.000 |
I guess that this is another design choice that's implicit or explicitly part of your 00:36:30.720 |
diffusion model training and sampling is how are you dealing with this, this noise schedule, 00:36:34.960 |
what schedule are you following, what scaling are you doing with your inputs. And by using 00:36:39.480 |
this thinking and doing these experiments, and they come up with a kind of rule of thumb 00:36:43.800 |
for how to scale the image based on image size, they show that they can, as they increase 00:36:49.400 |
the resolution, they can still maintain really good performance. Where previously it was 00:36:53.600 |
quite hard to train a really large resolution pixel space model, and they're able to do 00:37:00.280 |
that, they get some advantage from their fancy recurrent interface network, but still, it's 00:37:06.080 |
kind of cool that they can say, look, we get state of the art, high quality, 512 by 512 00:37:13.360 |
or 1024 by 1024 samples on class-conditioned image net. And using this approach to really 00:37:20.760 |
like consider how well do you train, how many steps do we need to take, one of the other 00:37:24.720 |
things in this table is that they compare to previous approaches. Oh, we used, you know, 00:37:29.440 |
a third of the training steps and for the same other settings, and we get better performance. 00:37:35.420 |
Just because we've chosen that input scaling better. And yeah, so that's the paper, really, 00:37:41.080 |
really nice, great work to the team. And that was very useful. 00:37:45.440 |
I love that you got up in the morning and thought, oh, it's going to be a hassle training 00:37:53.720 |
all these different models I need to train for different input scalings and different 00:37:58.600 |
sampling approaches. I just look at Twitter first, and then you looked at Twitter and 00:38:04.120 |
there was a paper saying like, hey, we just did a bunch of experiments for different noise 00:38:09.080 |
schedules and input scaling. Yeah, does your life always work that way? It seems 00:38:15.760 |
Yeah, it's very lucky like that. Yeah. You wait long enough, someone else will do it. 00:38:22.600 |
That's why, you know, the time when the UK starts posting on Twitter is like my 00:38:26.460 |
favourite hour of the day for all the papers to be posted. 00:38:32.040 |
Oh, well, thank you for that. So let me switch to notebook 23. Because this notebook is actually 00:38:54.200 |
an implementation of some ideas from this paper that everybody tends to just call the 00:39:01.120 |
Karras paper, even though there are other authors. But I will do it anyway: the Karras paper. And the reason 00:39:11.560 |
we're going to look at this is because in this paper, the authors actually take a much 00:39:22.840 |
more explicit look at the question of input scaling. Their approach was not apparently 00:39:32.000 |
to accidentally put a bug in their code, and then take it out, find it worked worse, and 00:39:36.800 |
then just put it back in again. Their approach was actually to think, how should things be? 00:39:42.920 |
So that's an interesting approach to doing things, and I guess it works for them. So 00:39:49.920 |
Yeah, exactly. Our approach is much more fun because you never quite know what's going 00:39:54.000 |
to happen. And so, yeah, in their approach, they actually tried to say, like, OK, given 00:40:01.960 |
all the things that are coming into our model, how can we have them all nicely balanced? 00:40:09.520 |
So we will skip back and forth between the notebook and the paper. So the start of this 00:40:16.760 |
is all the same, except now we are actually going to do it minus one to one because we're 00:40:23.080 |
not going to rely on accidental bugs anymore, but instead we're going to rely on the Karras paper. 00:40:33.080 |
I say that, except that I put a bug in this notebook as well. One of the things that's 00:40:43.200 |
in the Karras paper is what is the standard deviation of the actual data, which I calculated 00:40:50.760 |
for a batch. However, this used to say minus 0.5. I used to do the minus 0.5 to 0.5 thing. 00:40:58.720 |
And so this is actually the standard deviation of the data from before, when it was still minus 00:41:03.480 |
0.5 to 0.5. So this is actually half the real standard deviation. For reasons I don't yet understand, 00:41:10.720 |
this is giving me better scaled results. So this actually should be 0.66. So there's still 00:41:18.040 |
a bug here and the bug still seems to work better. So we still got some mysteries involved. 00:41:22.720 |
So we're going to leave this. So it's actually, it's actually not 0.33, it's actually 0.66. 00:41:27.640 |
Okay, so the basic idea of this, actually I'll come back. Well, let me have a little 00:41:40.400 |
think. Yeah, okay. Now we'll start here. The basic idea of this paper is to say, you know 00:41:49.600 |
what, sometimes maybe predicting the noise is a bad idea. And so like you can either 00:42:01.880 |
try and predict the noise or you can try and predict the clean image and each of those 00:42:08.120 |
can be a better idea in different situations. If you're given something which is nearly 00:42:12.520 |
pure noise, you know, the model's given something which is nearly pure noise and is then asked 00:42:18.160 |
to predict the noise. That's basically a waste of time, because the whole thing's noise. 00:42:26.140 |
If you do the opposite, which is you try to get it predict the clean image. Well, then 00:42:30.640 |
if you give it a clean image that's nearly clean and try to predict the clean image, 00:42:33.680 |
that's nearly a waste of time as well. So you want something which is like, regardless 00:42:38.440 |
of how noisy the image is, you want it to be kind of like an equally difficult problem 00:42:42.640 |
to solve. And so what Karras et al. do is they basically use this new thing called CSKIP, 00:42:57.440 |
which is a number, which is basically saying like, you know what we should do for the training 00:43:04.040 |
target is not just predict the noise all the time, not just predict the clean image all 00:43:10.000 |
the time, but predict kind of a lerped version of one or the other depending on how noisy 00:43:16.840 |
it is. So here y is the clean image and n is the noise. So y plus n is the noised image. 00:43:35.000 |
And so if CSKIP was 0, then we would be predicting the clean image. And if CSKIP was 1, we would 00:43:44.800 |
be predicting y minus the noised input, in other words the noise. And so you can decide by picking 00:43:53.080 |
a different CSKIP whether you're predicting the clean image or the noise. And so, as you 00:43:58.440 |
can see from the way they've written it, they make this a function. They make it a function 00:44:02.080 |
of sigma. Now, this is where we got to a point now where we've kind of got a fairly, a much 00:44:08.760 |
simpler notation. There's no more alpha bars, no more alphas, no more betas, no more beta 00:44:13.600 |
bars. There's just a single thing called sigma. Unfortunately, sigma is the same thing as 00:44:19.980 |
alpha bar used to be, right? So we've simplified it, but we've also made things more confusing 00:44:24.920 |
by using an existing symbol for something totally different. So this is alpha bar. Okay. 00:44:30.680 |
So there's going to be a function that says, depending on how much noise there is, we'll 00:44:37.280 |
either predict the noise or we'll predict the clean image or we'll predict something 00:44:42.720 |
between the two. So in the paper, they showed this chart where they basically said like, 00:44:53.960 |
okay, let's look at the loss to see how good are we with a trained model at predicting 00:45:02.840 |
when sigma is really low. So when there's very small alpha bar, or when sigma is in 00:45:12.040 |
the middle or when sigma is really high. And they basically said, you know what, when it's 00:45:17.520 |
nearly all noise or nearly no noise, you know, we're basically not able to do anything at 00:45:24.240 |
all. You know, we're basically good at doing things when there's a medium amount of noise. 00:45:32.640 |
So when deciding, okay, what, what sigmas are we going to send to this thing? The first 00:45:38.040 |
thing we need to do is to, is to figure out some sigmas. And they said, okay, well, let's 00:45:44.440 |
pick a distribution of sigmas that matches this red curve here, as you can see. And so 00:45:51.840 |
this is a normally distributed curve where this is on a log scale. So this is actually 00:46:01.000 |
a log normal curve. So to get the sigmas that they're going to use, they picked a normally 00:46:06.980 |
distributed random number and then they exponentiated it. And this is called a log normal distribution. 00:46:14.400 |
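As an aside, this log-normal sampling of noise levels can be sketched in a few lines. This is a rough sketch rather than the notebook's exact code, using the mean and standard deviation discussed next:

```python
import torch

def sample_sigmas(n, loc=-1.2, scale=1.2):
    # draw log(sigma) from a normal distribution, then exponentiate:
    # that gives a log-normal distribution of noise levels
    return (torch.randn(n) * scale + loc).exp()
```

Sampling a large batch and looking at the log of the results recovers the normal distribution you started with.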
And so they used a mean of minus 1.2 and a standard deviation of 1.2. So that means that 00:46:24.120 |
about one sixth of the time, they're going to be getting a number that's bigger than 00:46:31.080 |
zero here. And e to the zero is one. So about one sixth of the time, they're going to be 00:46:38.620 |
picking sigmas that are bigger than one. And so here's a histogram I drew of the sigmas 00:46:47.640 |
that we're going to be using. And so it's nearly always less than five, but sometimes 00:46:55.960 |
it's way out here. And so it's quite hard to read these histograms. So this really nice 00:47:01.440 |
library called Seaborn, which is built on top of Matplotlib, has some more sophisticated 00:47:07.560 |
and often nicer looking plots. And one of them they have is called a KDE plot, which 00:47:12.200 |
is a kernel density plot. It's a histogram, but it's smooth. And so I clipped it at 10 00:47:20.040 |
so that you could see it better. So you can basically see that the vast majority of the 00:47:23.680 |
time it's going to be somewhere about 0.4 or 0.5, but sometimes it's going to be really 00:47:29.560 |
big. So our Noisify is going to pick a sigma using that log-normal distribution. And then 00:47:42.840 |
we're going to get the noise as usual, but now we're going to calculate C skip, right? 00:47:51.120 |
Because we're going to do that thing we just saw. We're going to find something between 00:47:56.520 |
the plain image and the noised input. So what do we use for C skip? We calculate it here. 00:48:10.560 |
And so what we do is we say what's the total amount of variance at some level of sigma? 00:48:17.640 |
Well it's going to be sigma squared, that's the definition of the variance of the noise, 00:48:23.000 |
but we also have the sigma of the data itself, right? So if we add those two together we'll 00:48:28.880 |
get the total variance. And so what the Karras paper said to do is to do the variance of 00:48:39.600 |
the data divided by the total variance and use that for C skip. So that means that if 00:48:48.520 |
your total variance is really big, so in other words it's got a lot of noise, then C skip's 00:48:54.300 |
going to be really small. So if you've got a lot of noise then this bit here will be 00:49:00.000 |
really small. So that means if there's a lot of noise try to predict the original image, 00:49:06.360 |
right? That makes sense because predicting the noise would be too easy. If there's hardly 00:49:10.900 |
any noise then this will be, total variance will be really small, right? So C skip will 00:49:17.520 |
be really big and so if there's hardly any noise then try to predict the noise. And so 00:49:25.040 |
that's basically what this C skip does. So it's a kind of slightly weird idea, that 00:49:31.760 |
our target, the thing we're trying to predict, is actually not the original 00:49:38.960 |
image, it's not the noise, but it's somewhere between the two. And I've found the easiest 00:49:43.080 |
way to understand that is to draw a picture of it. So here is some examples of noised 00:49:49.560 |
input, right? With various sigma's, remember sigma is alpha bar, right? So here's an example 00:49:58.480 |
with very little noise, 0.06. And so in this case the target is predict the noise, right? 00:50:07.560 |
So that's the hard thing to do is predict the noise. Whereas here's an example, 4.53 00:50:14.400 |
which is nearly all noise. So for nearly all noise the target is predict the image, right? 00:50:22.560 |
And then for something which is a little bit between the two like here, 0.64, the target 00:50:28.240 |
is predict some of the noise and some of the image. So that's the idea of Karras. And so 00:50:38.080 |
what this does is it's making the problem to be solved by the unit equally difficult 00:50:47.400 |
regardless of what sigma is. It doesn't solve our input scaling problem, it solves our kind 00:50:55.440 |
of difficulty scaling problem. To solve the input scaling problem they do it. 00:51:02.160 |
I just want to make one quick note. And so this sort of idea of interpolating 00:51:08.720 |
between the noise and the image is similar to what's called the V-objective as well. 00:51:15.440 |
So it's, yeah, it's quite similar to what Karras et al. have, 00:51:19.120 |
and that's also now been used in a lot of different models. Like for example 00:51:24.400 |
Stable Diffusion 2.0 was trained with this sort of V-objective. So people are using this 00:51:29.440 |
sort of methodology and getting good results. And yeah, so it's an actual practical thing 00:51:35.600 |
that people are doing. So yeah, I just want to make a note of that. 00:51:38.840 |
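For reference, the v-objective mentioned here comes from Salimans & Ho's progressive distillation paper and can be written down in a couple of lines. This is a sketch in the usual alpha/sigma notation, not code from the lesson:

```python
import torch

def v_target(x0, eps, alpha_t, sigma_t):
    # with x_t = alpha_t*x0 + sigma_t*eps, the v-objective trains the model
    # to predict v = alpha_t*eps - sigma_t*x0, which blends predicting the
    # noise and predicting (minus) the image depending on the noise level
    return alpha_t * eps - sigma_t * x0
```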
Yeah, as is the case for basically all papers created by Nvidia researchers, of which this 00:51:46.360 |
is one. It flies under the radar and everybody ignores it. The V-objective paper came from 00:51:54.280 |
Google; the senior author was Tim Salimans. Yeah. And so anything from 00:52:00.160 |
Google and OpenAI everybody listens to. So yeah, although Karras I think has done the 00:52:04.880 |
more complete version of this, and in fact the V-objective was almost like mentioned 00:52:12.680 |
in passing in the distillation paper. But yeah, that's the one that everybody has ended 00:52:17.520 |
up looking at. But I think this is the more... 00:52:19.960 |
Yeah, I think what happened with the V-objective is not many people paid attention to it. I 00:52:24.960 |
think folks like Kat and Robin and these sorts of folks are actually paying attention to 00:52:29.280 |
that V-objective in that Google brain paper. But then also this paper did a much more principled 00:52:35.040 |
analysis of this sort of thing. So yeah, I think it's very interesting how sometimes 00:52:40.160 |
even these sort of side notes in papers that maybe people don't pay much attention to can end up mattering, 00:52:47.120 |
Yeah. Yeah. So, okay. So the noised input as usual is the input image plus the noise 00:52:54.000 |
times the sigma. But then, and then as we discussed, we decide how to kind of decide 00:53:01.760 |
what our target is. But then we actually take that noised input and we scale it up or down 00:53:09.520 |
by this number. And the target, we also scale up or down by this number. And those are both 00:53:18.440 |
calculated in this thing as well. So here's C out and here's C in. 00:53:28.080 |
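Putting the three numbers together, the scalings function works out roughly like this (a sketch, not the notebook's exact code; sig_data is the measured data standard deviation discussed earlier):

```python
def scalings(sig, sig_data=0.66):
    # total variance at noise level sig: noise variance plus data variance
    totvar = sig**2 + sig_data**2
    c_skip = sig_data**2 / totvar        # high noise -> small c_skip -> target is the image
    c_out  = sig * sig_data / totvar**0.5
    c_in   = 1 / totvar**0.5             # scales the noised input to unit variance
    return c_skip, c_out, c_in
```

Note how c_skip goes to zero as sigma grows, which is exactly the "predict the image when it's mostly noise" behaviour described above.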
Now I just wanted to show one example of where these numbers come from because for a while 00:53:32.200 |
they all seem pretty mysterious to me and I felt like I'd never be smart enough to understand 00:53:36.320 |
them, particularly because they were explained in the mathematical appendix of this paper, 00:53:41.760 |
which are always the bits I don't understand, until I actually try to and then it tends 00:53:45.960 |
to turn out they're not so bad after all, which was certainly the case here, which? 00:53:56.000 |
I think it was B something, I think. So B6, I think? Is that the one? 00:54:04.880 |
So in appendix B6, which does look pretty terrifying, but if you actually look at, for example, 00:54:14.600 |
what we're just looking at, C in, it's like how do they calculate? So C in is this. Now 00:54:21.360 |
this is the variance of the noise, this is the variance of the data, add them together 00:54:26.240 |
to get the total variance, square roots, the total standard deviation. So it's just the 00:54:31.240 |
inverse of the total standard deviation, which is what we have here. Where does that come 00:54:36.440 |
from? Well, they just said, you know what? The inputs for a model should have unit variance. 00:54:45.000 |
Now we know that. We've done that to death in this course. So they said, right. So well, 00:54:53.760 |
the inputs to the model is the, the clean data plus the noise times some number we're 00:55:03.520 |
going to calculate and we want that to be one. Okay. So the variance of the clean images 00:55:14.660 |
plus the noise is equal to the variance of the clean images plus the variance of the 00:55:20.600 |
noise. Okay. So if we want that to be, if we want variance to be one, then divide both 00:55:31.040 |
sides by this and take the square root. And that tells us that our multiplier has to be 00:55:36.840 |
one over this. That's it. So it's like literally, you know, classical math. The only bit you 00:55:45.440 |
have to know is that the variance of two independent things added together is the variance of 00:55:51.960 |
one plus the variance of the other, which is not rocket science either. 00:55:57.320 |
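That derivation is easy to check numerically. Here's a rough sanity check, assuming Gaussian data with the stated standard deviation:

```python
import torch

torch.manual_seed(0)
sig, sig_data = 2.0, 0.66
x0 = torch.randn(100_000) * sig_data        # fake "clean data" with std sig_data
noised = x0 + torch.randn(100_000) * sig    # add noise with std sig
c_in = (sig**2 + sig_data**2) ** -0.5       # one over the total standard deviation
scaled = c_in * noised
print(scaled.std())                          # close to 1.0
```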
And in this context, like why we want to do this: when we looked at those sigmas that 00:56:02.120 |
you're plotting, like the distribution, you've got some that are fairly low, but you've also 00:56:05.440 |
got some where the standard deviation sigma is like 40, right? So the variance is super 00:56:09.760 |
high. Yes. And so we don't want to feed something with standard deviation 40 into our model. 00:56:14.800 |
You would like it to be closer to unit variance. So we're thinking, okay, well, if you divide 00:56:18.360 |
by roughly 40, that would scale it down. But then we've also got some extra variance from 00:56:22.920 |
our data. It's like 40 plus a little bit from the variance of the data. We want to scale 00:56:30.040 |
back down by that to get unit variance. Yeah. I mean, I love this paper because it's basically 00:56:34.440 |
just doing what we spent weeks doing of like, I feel like everything that we've done that's 00:56:41.920 |
improved every model has always been one thing, which is, can we get mean zero variance one 00:56:51.360 |
inputs to our model and for all of our activations? And then the only other thing is include enough 00:57:00.200 |
compute by adding enough layers and enough activations. Those two things seem to be all 00:57:05.960 |
that matters. Basically, well, I guess ResNet's added an extra cool little thing to that, 00:57:11.960 |
which is to make it even smoother by giving this kind of like identity path. So yeah, 00:57:21.600 |
basically trying to make things as smooth as possible and as equal everywhere as possible. 00:57:30.880 |
So yeah, this is what they've done. So they did that for the inputs, and then they've 00:57:34.360 |
also done it for the outputs and for the outputs, it's basically the same idea. They have basically 00:57:43.840 |
the same kind of analysis to show that. And so with this, so now, yeah, we've basically 00:57:51.920 |
we've got our noised input, we've got the linear version somewhere between X0 and the 00:58:00.120 |
noised input, we've got the scaling of the output and we've got the scaling of the input. 00:58:05.440 |
So now for the inputs to our model, we're going to have the scaled noise, we're going 00:58:11.680 |
to have the sigma and we're going to have the target, which is somewhere between the 00:58:17.920 |
image and the noise. And so, yeah, so I've never seen anybody draw a picture of this 00:58:26.200 |
before. So it was really cool when being in a notebook, being able to see like, oh, that's 00:58:30.800 |
what they're doing. So yeah, have a good look at this notebook to see exactly what's going 00:58:36.680 |
on because I think it gives you a really good intuition around what problem it's trying 00:58:41.400 |
to solve. So then I actually checked the noised input has a standard deviation of 1; the mean's 00:58:48.640 |
not 0 and of course, why would it be? We didn't do anything. The only thing Karras cared about 00:58:53.440 |
was having the variance 1. We could easily adjust the input and output to have a mean 00:58:59.780 |
of 0 as well. And that's something I think we or somebody should try because I think 00:59:04.400 |
it does seem to help a bit as we saw with that generalised value stuff we did, but it's 00:59:09.120 |
less important than the variance. And so same with the target, it's got a standard deviation of 1. And yeah, 00:59:14.320 |
this is where if I change this to the correct value, which is 0.66, then actually it's slightly 00:59:21.440 |
further away from 1, both here and here, quite a lot further away. And maybe that's because 00:59:26.440 |
actually the data is, well, we know the data is not Gaussian distributed. Pixel data definitely 00:59:32.240 |
isn't Gaussian distributed. So this bug turned out better. Okay. So the UNet's the same, 00:59:41.080 |
the initialisation's the same. This is all the same. Train it for a while. We can't compare 00:59:47.200 |
the losses because our target's different. But what we can do is we can create a denoise 00:59:55.440 |
that just takes the thing that, as per usual, the thing we had in noisify, right? And so 01:00:04.600 |
for x0, it's going to multiply by c out and then add c skip times the noised input. Here it is, 01:00:10.440 |
multiply by c out, add c skip times the noised input. Okay, so we can denoise. So let's grab our 01:00:21.360 |
sigmas from the actual batch we had. Let's calculate c skip c out and c in for the sigmas 01:00:27.840 |
in our mini batch. Let's use the model to predict the target given the noised input 01:00:36.920 |
and the sigmas, and then denoise it. And so here's our noised input, which we've already 01:00:44.120 |
seen, and here's our predictions. And these are absolutely remarkable, in my opinion. Yeah. 01:00:56.480 |
Like this one here, I can barely see it. You know, it's really found, look at the shirt. 01:01:01.720 |
There's a shirt here. It's actually really finding the little thing on the front. And 01:01:04.960 |
let me show you, here's what it should look like, right? And in cases where the sigma's 01:01:12.520 |
pretty high, like here, you can see it's really like saying, like, I don't know, maybe it's 01:01:18.080 |
shoes, but it could be something else. Is it shoes? Yeah, it wasn't shoes. But at least 01:01:23.600 |
it's kind of got the, you know, the bulk of the pixels in the right spot. Yeah, something 01:01:30.400 |
like this one is 4.5, has no idea what it is. It's like, oh, maybe it's shoes, maybe 01:01:35.240 |
it's pants. You know, it turns out it is shoes. Yeah. So I think that's fascinating how well 01:01:43.920 |
it can do. And then the other thing I did, which I thought was fun, was I just created, 01:01:53.080 |
so I just used a sigma of 80, which is actually what they do when they're doing sampling 01:01:57.520 |
from pure noise. That's what they consider the pure noise level. So I just created some 01:02:03.020 |
pure noise and denoised it just for one step. And so here's what happens when you denoise 01:02:11.400 |
it for one step. And you can see it's kind of overlaid all the possibilities. It's like, 01:02:16.560 |
I can see a pair of shoes here, a pair of pants here at top here. And sometimes it's 01:02:22.320 |
kind of like more confident that the noise is actually a pair of pants. And sometimes 01:02:27.440 |
it's more confident that it's actually shoes. But you can really get a sense of how like 01:02:32.520 |
from pure noise, it starts to make a call about like what this noise is actually covering 01:02:39.760 |
up. And this is also the bit which I feel is like, I'm the least convinced about when 01:02:48.160 |
it comes to diffusion models. This first step of going from like pure noise to something 01:02:56.960 |
and like trying to have a good mix of all the possible somethings, I'm, I don't know, 01:03:02.160 |
it feels a bit hand-waving to me. It clearly works quite well, but I'm not sure if it's 01:03:05.680 |
like we're getting the full range of possibilities. And I feel like some of the papers we're starting 01:03:12.520 |
to see are starting to say like, you know what, maybe this is not quite the right approach. 01:03:17.240 |
And maybe later in the course, we'll look at some of the ones that look at what we call 01:03:21.880 |
VQ models and tokenized stuff. Anyway, I thought this is pretty interesting to see these pictures, 01:03:30.560 |
which I don't think, yeah, I've never seen any pictures like this before. So I think 01:03:35.080 |
this is a fun result from doing all this stuff in notebooks step by step. 01:03:42.080 |
Okay, so sampling. So one of the nice things with this is the sampling becomes much, much, 01:03:52.160 |
much simpler. And so, and I should mention a lot of the code that I'm using, particularly 01:04:01.360 |
in the sampling section is heavily inspired by, and some of it's actually copied and pasted 01:04:06.400 |
from Kat's K-diffusion repo, which is, I think I mentioned before, some of the nicest generative 01:04:17.160 |
modeling code or maybe the nicest generative modeling code I've ever seen. It's really 01:04:22.480 |
great. So before we talk about the actual sampling, 01:04:27.600 |
the first thing we need to talk about is what sigma do we use at each reverse time step? 01:04:34.400 |
And in the past, we've always, well, nearly always done something which I think I've always 01:04:39.200 |
felt is sketchy as all hell, which is we've just linearly gone down the sigmas or the alpha 01:04:46.480 |
bars or the t's. So here, when we're sampling in the previous notebook, we used linspace. 01:04:52.340 |
So I always felt like that was questionable. And I felt like at the start, you probably, 01:04:59.040 |
like it was just noise anyway. So who cared? Who cares? So I, in DDPMv3, I experimented 01:05:07.880 |
with something that I thought intuitively made more sense. I don't know if you remember this 01:05:13.360 |
one, but I actually said, oh, let's, for the first 100 time steps, let's actually only 01:05:20.160 |
run the model every 10 steps. And then for the next 100, let's run it every nine. The 01:05:24.520 |
next 100, let's run it every eight. So basically at the start, be much less careful. 01:05:30.360 |
And so Karras actually ran a whole bunch of experiments and they said, yeah, you know what? 01:05:41.040 |
At the start of sampling, you know, you can start with a high sigma, but then like step 01:05:47.200 |
to a much lower sigma in the next step and then a much lower sigma in the next step. 01:05:51.680 |
And then the further you go, step by smaller and smaller steps so that you spend 01:05:57.440 |
a lot more time fine-tuning carefully at the end and not very much time at the start. Now, 01:06:07.860 |
this has its own problems. And in fact, a paper just came out today, which we probably 01:06:12.060 |
won't talk about today, but maybe another time, which talked about the problems is that in 01:06:16.320 |
these very early steps, this is the bit where you're trying to create a composition that 01:06:21.600 |
makes sense. Now for fashion MNIST, we don't have much composing to do. It's just a piece 01:06:26.400 |
of clothing. But if you're trying to do an astronaut riding a horse, you know, you've 01:06:31.520 |
got to think about how all those pieces fit together. And this is where that happens. 01:06:36.040 |
And so I do worry that with the Karras approach, it's maybe not giving that enough time. 01:06:41.160 |
But as I've said, that's really the same as this step. That whole piece feels a bit wrong 01:06:47.240 |
to me. But aside from that, I think this makes a lot of sense, which is that, yeah, the sampling, 01:06:54.080 |
you should jump, you know, by big steps early on and small steps later on and make sure 01:06:59.720 |
that the fine details are just so. So that's what this function does, is it creates this 01:07:06.800 |
plot. Now it's this schedule of reverse diffusion sigma steps. It's a bit of a weird function 01:07:17.680 |
in that it's the rho-th root of sigma, where rho is seven. So the seventh root of sigma 01:07:26.580 |
is basically what it's scaling on. But the answer to why it's that is because they tried 01:07:33.280 |
it and it turned out to work pretty well. Do you guys remember where this was? 01:07:48.280 |
Right. So this image here, so thanks for telling me where this is, shows FID as 01:08:03.640 |
a function of rho. So it's basically which root are we taking? And they basically 01:08:10.680 |
said, like, if you take the fifth root or higher, it seems to work well, basically. So, yeah, 01:08:19.600 |
so that's a perfectly good way to do things is just to try things and see what works. 01:08:24.400 |
And you'll notice they tried things just like we love on small data sets. Not as small as 01:08:28.520 |
us because we're the king of small data sets, but smallish: CIFAR-10, ImageNet 64. 01:08:35.200 |
That's the way to do things. I saw, like, it might have even been the CEO of Hugging Face 01:08:40.980 |
the other day, tweet something saying only people with huge amounts of GPUs can do research 01:08:45.360 |
now. And I think it totally misunderstands how research is done, which is research is 01:08:49.240 |
done on very small data sets. That's the actual research. And then when you're 01:08:55.400 |
all done, you scale it up at the end. I think we're kind of pushing the envelope in terms 01:09:02.220 |
of like, yeah, how how much can you do? And yeah, we've, like, we covered this kind of 01:09:12.360 |
main substantive path of diffusion models history, step by step, showing every improvement 01:09:19.340 |
and seeing clear improvements across all the papers using nothing but Fashion MNIST running 01:09:24.160 |
on a single GPU in like 15 minutes of training or something per model. So, yeah, definitely 01:09:30.080 |
you don't need lots of GPUs. Anyway, OK, so this is the sigma we're going to jump to. 01:09:37.960 |
So the denoising is going to involve calculating the C skip, C out and C in and calling our 01:09:44.400 |
model with the C in scaled data and the sigma and then scaling it with C out and then doing 01:09:53.120 |
the C skip. OK, so that's just undoing the Noisify. So check this out. This is all that's 01:10:00.360 |
required to do one step of denoising for the simplest kind of scheduler, which sorry, the 01:10:07.240 |
simplest kind of sampler, which is called Euler. So we basically say, OK, what's the 01:10:11.400 |
sigma at time step I? What's the sigma 2 at time step I? And now when I'm talking about 01:10:19.120 |
at time step, I'm really talking about like the step from this function. Right. So this 01:10:24.560 |
is this is the sampling step. Yeah. OK, so then denoise using the function and then we 01:10:35.840 |
say, OK, well, just send back whatever you were given plus move a little bit in the direction 01:10:43.680 |
of the denoised image. So the direction is X minus denoised. So that's the noise. That's 01:10:49.880 |
the gradient as we discussed right back in the first lesson of this part. So we'll take 01:10:54.840 |
the noise. If we divide it by sigma, we get a slope. It's how much noise is there per 01:11:01.400 |
sigma. And then the amount that we're stepping is sigma 2 minus sigma 1. So take that slope 01:11:08.960 |
and multiply it by the change. Right. So that's the distance to travel towards the noise at 01:11:18.880 |
this fraction. You know, or you could also think of it this way. And I know this is a 01:11:23.080 |
very obvious algebraic change. But if we move this over here, you could also think of this 01:11:31.040 |
as being, oh, of the total amount of noise, the change in sigma we're doing, what percentage 01:11:37.840 |
is that? OK, well, that's the amount we should step. So there's two ways of thinking about 01:11:43.760 |
the same thing. So again, this is just, you know, high school math. Well, I mean, actually, 01:11:56.160 |
my seven-year-old daughter has done all these things. It's plus, minus, divide and times. 01:12:03.840 |
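In code, the denoising and the Euler step described above might look roughly like this. It's a sketch, not the notebook's exact code: I'm assuming unit-variance data (sigma_data = 1, as the lesson emphasizes) and a model that takes the c_in-scaled input plus sigma; the helper names are mine.

```python
import torch

sigma_data = 1.  # assumes inputs were normalized to unit variance

def scalings(sig):
    # Karras et al. preconditioning: c_skip, c_out, c_in.
    totvar = sig**2 + sigma_data**2
    return sigma_data**2/totvar, sig*sigma_data/totvar**0.5, 1/totvar**0.5

def denoise(model, x, sig):
    # Undo Noisify: call the model on the c_in-scaled input, scale its
    # output by c_out, and add back c_skip times the noisy input.
    c_skip, c_out, c_in = scalings(sig)
    return model(x*c_in, sig)*c_out + x*c_skip

def euler_step(x, sig, sig2, denoised):
    d = (x - denoised)/sig        # slope: noise per unit of sigma
    return x + d*(sig2 - sig)     # move along that slope from sig to sig2
```

Note that `x + (x - denoised)/sig * (sig2 - sig)` is exactly the "what fraction of the total noise is this sigma change" reading too; it's the same algebra either way.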
So we're going to need to do this once per sampling step. So here's a thing called sample, 01:12:13.000 |
which does that. It's going to go through each sampling step, call our sampler, which initially 01:12:21.520 |
we're going to do sample Euler. Right. With that information, add it to our list of results 01:12:30.920 |
and do it again. So that's it. That's all the sampling is. And of course, we need to grab 01:12:37.160 |
our list of sigmas to start with. So I think that's pretty cool. And at the very start, 01:12:43.960 |
we need to create our pure noise image. And so the amount of noise we start with is got 01:12:49.000 |
a sigma of 80. OK, so if we call sample using sample Euler and we get back some very nice 01:13:02.400 |
looking images and, believe it or not, our FID is 1.98. So this extremely simple sampler, 01:13:13.240 |
three lines of code plus a loop, has given us a FID of 1.98, which is clearly substantially 01:13:30.520 |
better than our cosine schedule. Now we can improve it from there. So one potential improvement 01:13:38.640 |
is this: you might have noticed we added no new noise at all. This is a deterministic sampler. 01:13:46.520 |
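The outer loop that drives any of these samplers can be sketched as below; any stepper with a `(model, x, sigmas, i)` signature plugs in. The starting sigma of 80 comes from the lesson, but the function and argument names are my own:

```python
import torch

@torch.no_grad()
def sample(sampler, model, sigmas, sz):
    # Start from pure noise at the largest sigma (80 in the lesson).
    x = torch.randn(sz) * sigmas[0]
    preds = []
    for i in range(len(sigmas) - 1):
        x = sampler(model, x, sigmas, i)  # one denoise-and-step per sigma pair
        preds.append(x)
    return preds
```

So `sample(sample_euler, model, sigs, (n, 1, 32, 32))` would run the whole deterministic Euler chain and keep every intermediate result.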
There's no rand anywhere here. So we can do something called an ancestral Euler sampler, 01:13:54.280 |
which does add rand. So we basically do the denoising in the usual way, but then we also 01:14:02.600 |
add some rand. And so what we do need to make sure is given that we're adding a certain 01:14:07.080 |
amount of randomness, we need to remove that amount of randomness from the step that we 01:14:13.240 |
take. So I won't go into the details, but basically there's a way of calculating how 01:14:22.160 |
much new randomness and how much just going back in the existing direction do we do. And 01:14:29.080 |
so there's the amount in the existing direction and there's the amount in the new random direction. 01:14:35.240 |
And you can just pass in eta, which is just going to, when we pass it into here, is going 01:14:44.040 |
to scale that. So if we scale it by half, so basically half of it is new noise and half 01:14:53.720 |
of it is going in the direction that we thought we should go, that makes it better still. 01:14:59.800 |
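A sketch of that split, assuming a precomputed `denoised` from the model; the sigma_up/sigma_down formulas follow the standard ancestral-sampling derivation, and eta scales how much fresh noise gets injected:

```python
import torch

def euler_ancestral_step(x, sig, sig2, denoised, eta=1.):
    # Split the target sigma2 into a deterministic part (sigma_down)
    # and brand-new noise (sigma_up), so the total variance still matches.
    sigma_up = min(sig2, eta * (sig2**2 * (sig**2 - sig2**2) / sig**2)**0.5)
    sigma_down = (sig2**2 - sigma_up**2)**0.5
    # Deterministic Euler move toward the denoised image...
    x = x + (x - denoised)/sig * (sigma_down - sig)
    # ...then add sigma_up worth of fresh random noise.
    return x + torch.randn_like(x) * sigma_up
```

With eta = 0 this collapses back to the plain deterministic Euler step; eta = 0.5 is the "half new noise" setting mentioned above.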
Again with 100 steps. And just make sure I'm comparing to the same, yep, 100 steps. Okay, 01:15:06.880 |
so that's fair, like with like. Okay, so that's adding a bit of extra noise. Now then, something 01:15:17.120 |
that I think we might have mentioned back in the first lesson of this part is something 01:15:24.840 |
called Heun's method. And Heun's method does something which we can pictorially see here 01:15:34.640 |
to decide where to go, which is basically we say, okay, where are we right now? What's 01:15:40.160 |
the, you know, at our current point, what's the direction? So we take the tangent line, 01:15:46.080 |
the slope, right? That's basically all it does is it takes a slope. It says, oh, here's 01:15:49.960 |
a slope, you know. Okay, and so if we take that slope, and that would take us to a new 01:16:05.880 |
spot, and then at that new spot, we can then calculate a slope at the new spot as well. 01:16:16.840 |
And at the new spot, the slope is something else. So that's it here, right? And then you 01:16:25.960 |
say, like, okay, well, let's go halfway between the two. And let's actually follow that line. 01:16:32.520 |
And so basically, it's saying, like, okay, each of these slopes is going to be inaccurate. 01:16:38.080 |
But what we could do is calculate the slope of where we are, the slope of where we're 01:16:41.400 |
going, and then go halfway between the two. It's, I actually found it easier to look at 01:16:47.200 |
in code personally. I'm just going to delete a whole bunch of stuff that's totally irrelevant 01:16:53.280 |
to this conversation. So take a look at this compared to Euler. So here's our Euler, right? 01:17:05.400 |
So we're going to do the same first line exactly the same, right? Then the denoising is exactly 01:17:10.960 |
the same. And then this step here is exactly the same. I've actually just done it in multiple 01:17:17.120 |
steps for no particular reason. And then you say, okay, well, if this is the last step, 01:17:25.720 |
then we're done. So actually, the last step is Euler. But then what we do is we then say, 01:17:31.640 |
well, that's okay, for an Euler step, this is where we'd go. Well, what does that look 01:17:39.080 |
like if we denoise it? So this calls the model the second time, right? And where would that 01:17:44.080 |
take us if we took an Euler step there? And so here, if we took an Euler step there, what's 01:17:50.420 |
the slope? And so what we then do is we say, oh, okay, well, it's just, just like in the 01:17:55.640 |
picture, let's take the average. Okay, so let's take the average and then use that for 01:18:06.880 |
the step. So that's all the Heun sampler does is just take the average of the slope where 01:18:13.280 |
we're at and the slope where the Euler method would have taken us. And so if we now so notice 01:18:20.240 |
that it called the model twice for a single step. So to be fair, since we've been taking 01:18:25.080 |
100 steps with Euler, we should take 50 steps with Heun, right? Because it's going to call 01:18:29.400 |
the model twice. And still that is now whoa, we beat one, which is pretty amazing. And 01:18:38.360 |
so we could keep going, check this out, we could even go down to 20. This is actually 01:18:41.520 |
doing 40 model evaluations and this is better than our best Euler, which is pretty crazy. 01:18:47.600 |
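That two-slope averaging can be sketched like this; I'm passing in a `denoise_fn(x, sig)` callable (standing in for the model-plus-preconditioning described earlier) rather than reproducing the notebook's exact signature:

```python
import torch

def heun_step(denoise_fn, x, sig, sig2):
    # First slope, exactly as in the Euler step.
    d = (x - denoise_fn(x, sig)) / sig
    x2 = x + d * (sig2 - sig)        # where Euler would land
    if sig2 == 0:
        return x2                    # final step is plain Euler (sigma=0)
    # Second slope at the Euler destination (second model call),
    # then step from the original point using the average of the two.
    d2 = (x2 - denoise_fn(x2, sig2)) / sig2
    return x + (d + d2) / 2 * (sig2 - sig)
```

Two model calls per step is exactly why a 50-step Heun run is the fair comparison against a 100-step Euler run.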
Now, something which you might have noticed is kind of weird about this or kind of silly 01:18:52.600 |
about this is we're calling the model twice just in order to average them. But we already 01:19:00.120 |
have two model results, like without calling it twice, because we could have just looked 01:19:04.760 |
at the previous time step. And so something called the LMS sampler does that instead. 01:19:16.920 |
And so the LMS sampler, if I call it with 20, it actually literally does 20 evaluations 01:19:22.640 |
and actually it beats Euler with 100 evaluations. And so LMS, I won't go into the details too 01:19:29.560 |
much. It didn't actually fit into my little sampling framework very well. So I basically largely copied 01:19:33.480 |
and pasted Kat's code. But the key thing it does is, look, it gets the current sigma, 01:19:39.280 |
it does the denoising, it calculates the slope, and it stores the slope in a list, right? 01:19:48.640 |
And then it grabs the first one from the list. So it's kind of keeping a list of up to, in 01:19:58.960 |
this case, four at a time. And so it then uses up to the last four to basically estimate, yes, kind 01:20:07.920 |
of the curvature of this and take the next step. So that's pretty smart. And yeah, I think 01:20:20.160 |
if you wanted to do super fast sampling, it seems like a pretty good way to do it. And 01:20:26.360 |
I think, Johnno, you're telling me that, or maybe it's Pedro was saying that currently 01:20:32.120 |
people have started to move away. This was very popular, but people started to move towards 01:20:36.280 |
a new sampler, which is a bit similar called the DPM++ sampler, something like that. Yeah. 01:20:43.320 |
Yeah. Yeah. But I think it's the same idea. I think, as I said, it kind of keeps 01:20:51.200 |
a list of recent results and uses that. I'll have to check it more closely. 01:20:57.480 |
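I won't reproduce the whole LMS sampler here either, but the coefficient trick can be sketched like this, following the pattern of Kat's k-diffusion code (the scipy dependency and exact epsrel are assumptions): each stored slope gets a weight from integrating its Lagrange basis polynomial over the next sigma step.

```python
from scipy import integrate

def linear_multistep_coeff(order, sigs, i, j):
    # Weight for the j-th stored slope: integrate its Lagrange basis
    # polynomial (built on the last `order` sigmas) from sigs[i] to sigs[i+1].
    def fn(tau):
        prod = 1.
        for k in range(order):
            if j == k:
                continue
            prod *= (tau - sigs[i - k]) / (sigs[i - j] - sigs[i - k])
        return prod
    return integrate.quad(fn, sigs[i], sigs[i + 1], epsrel=1e-4)[0]
```

The step is then roughly `x = x + sum(c*d for c, d in zip(coeffs, reversed(ds)))` over the stored slopes `ds`, so no extra model calls are needed beyond one per step.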
I'll have to look at the code. Yeah. That's a similar idea. It's like, if it's done more 01:21:01.360 |
than one step, then it's using some history for the next thing. Yeah. This Heun 01:21:07.760 |
thing doesn't make a huge amount of sense, I guess, from that perspective. I mean, it still 01:21:14.080 |
works very well. This makes more sense. So then we can compare: if we use an actual mini-batch 01:21:18.920 |
of data, we get about 0.5. So yeah, I feel like this is quite a stunning result 01:21:33.240 |
to get close to, very close to, real data in terms of FID, really with 40 model 01:21:44.160 |
evaluations and the entire, nearly the entire thing here is by making sure we've got unit 01:21:54.440 |
variance, inputs, unit variance, outputs, and kind of equally difficult problems to 01:21:59.760 |
solve in our loss function. Yeah. Plus having that different schedule for sampling that's 01:22:06.240 |
completely unrelated to the training schedule. I think that was one of the big things with 01:22:09.760 |
Karras et al.'s paper was they also could apply this to, like, oh, existing diffusion models 01:22:15.360 |
that have been trained by other papers. We can use our sampler and in fewer steps get 01:22:19.160 |
better results without any of the other changes. And yeah, I mean, they do a little bit of 01:22:25.680 |
rearranging equations to get the other papers' versions into their C skip, C in and C out framework. 01:22:34.240 |
But then yeah, it's really nice that these ideas can be applied to, so for example, I 01:22:38.640 |
think Stable Diffusion, especially version one, was trained with DDPM-style training, epsilon 01:22:44.720 |
objective, whatever. But you can now get these different samplers and different sometimes 01:22:50.400 |
schedules and things like that and use that to sample it and do it in 15, 20 steps and 01:22:56.680 |
get pretty nice samples. Yeah. And another nice thing about this paper is they, in fact, 01:23:08.400 |
the name of the paper: Elucidating the Design Space of Diffusion-Based Generative Models. They looked 01:23:14.960 |
at various different papers and approaches and trying to say, like, oh, you know what? 01:23:19.080 |
These are all doing the same thing when we kind of parameterize things in this way. And 01:23:25.320 |
if you fill in these parameters, you get this paper and these parameters, you get that paper. 01:23:29.760 |
And then they found a better set of parameters, which was very nice to code because, you 01:23:42.360 |
know, it really actually ended up simplifying things a whole lot. And so if you look through 01:23:48.520 |
the notebook carefully, which I hope everybody will, you'll see, you know, that the code 01:23:56.240 |
is really clean and simple compared to all the previous ones, in my opinion. Like, I 01:24:04.680 |
feel like every notebook we've done from DDPM onwards, the code's got easier to understand. 01:24:13.320 |
And just to again clarify like how this connects with some of the previous papers that we've 01:24:19.200 |
looked at. So like, for example, with the DDIM, the deterministic, that's again, the sort 01:24:24.520 |
of deterministic approach that's similar to the Euler method sampler that we were just 01:24:30.800 |
looking at, which was completely deterministic. And then some of something like the Euler 01:24:35.520 |
ancestral that we were looking at is similar to the standard DDPM approach with that was 01:24:43.680 |
kind of a more stochastic approach. So again, there's just all those sorts of connections 01:24:48.760 |
that then are kind of nice to see, again, the sorts of connections between the different 01:24:52.880 |
papers and how they change it, how they can be expressed in this common framework. 01:24:57.840 |
Yeah. Thanks, Tanish. So we definitely now are at the point where we can show you the 01:25:09.040 |
U-Net next time. And so I think, unless any of us come up with interesting new insights 01:25:18.060 |
on the unconditional diffusion sampling, training and sampling process, we might be putting 01:25:28.160 |
that aside for a while, and instead we're going to be looking at creating a good quality 01:25:34.880 |
U-Net from scratch. And we're going to look at a different data set to do that as we start 01:25:46.600 |
to scale things up a bit, as Jono mentioned in the last lesson. So we're going to be using 01:25:50.720 |
a 64 by 64 pixel ImageNet subset called Tiny ImageNet. So we'll start looking at some 01:25:59.120 |
three channel images. So I'm sure we're all sick of looking at black and white shoes. 01:26:05.800 |
So now we get to look at cliff dwellings and trolley buses and koala bears and yeah, 200 01:26:17.980 |
different things. So that'll be nice. Yeah. All right. Well, thank you, Jono. Thank you, 01:26:24.800 |
Tanish. That was fun as always. And yeah, next time we'll be lesson 22. Bye. 01:26:32.160 |
Oh, listen to me. Hey, this was lesson 22. Oh, no way. Okay. You're right. See ya. Bye.