
Lesson 22: Deep Learning Foundations to Stable Diffusion


Chapters

0:00 Intro
0:30 Cosine Schedule (22_cosine)
6:05 Sampling
9:37 Summary / Notation
10:42 Predicting the noise level of noisy Fashion MNIST images
12:57 Why .logit() when predicting alpha bar t
14:50 Random baseline
16:40 mse_loss: why .flatten()
17:30 Model & results
19:03 Why are we trying to predict the noise level?
20:10 Training diffusion without t - first attempt
22:58 Why isn't it working?
27:02 Debugging (summary)
29:29 Bug in DDPM - paper that cast some light on the issue
38:40 Karras (Elucidating the Design Space of Diffusion-Based Generative Models)
49:47 Picture of target images
52:48 Scaling problem (scalings)
59:42 Training and predictions of modified model
63:49 Sampling
66:05 Sampling: Problems of composition
67:40 Sampling: Rationale for rho selection
69:40 Sampling: Denoising
75:26 Sampling: Heun's method, FID: 0.972
79:00 Sampling: LMS sampler
80:00 Karras Summary
83:00 Comparison of different approaches
85:00 Next lessons

Whisper Transcript

00:00:00.000 | All right. Hi, gang. And here we are in Lesson 21, joined by the legends themselves, Johnno
00:00:09.200 | and Tanishk. Hello.
00:00:11.720 | Hello.
00:00:12.940 | And today, you'll be shocked to hear that we are going to look at a Jupyter notebook.
00:00:25.160 | We're going to look at notebook 22. This is pretty quick. Just, you know, improvement,
00:00:37.520 | pretty simple improvement to our DDPM/DDIM implementation for fashion MNIST. And this
00:00:53.000 | is all the same so far, but what I've done is I've made one quite significant change.
00:01:03.120 | And some of the changes we'll be making today are all about making life simpler. And they're
00:01:08.320 | kind of reflecting the way the papers have been taking things. And it's interesting to
00:01:14.000 | see how the papers have not only made things better, they made things simpler. And so one
00:01:21.880 | of the things that I've noticed in recent papers is that there's no longer a concept
00:01:29.200 | of n steps, which is something we've always had before, and it always bothered me a bit,
00:01:36.000 | this capital T thing. You know, this little t over capital T, it's basically saying this is time step
00:01:44.480 | number, say 500 out of 1000, so it's time step 0.5. Why not just call it 0.5? And the
00:01:55.920 | answer is, well, we can. So we talked last time about the cosine scheduler. We didn't
00:02:01.400 | end up using it because I came up with an idea which was, you know, simpler and nearly
00:02:08.760 | the same, which is just to change our beta max. But in this next notebook, I decided let's
00:02:14.880 | use the cosine scheduler, but let's try to get rid of the n steps thing and the capital
00:02:21.560 | T thing. So here is A bar again. And now I've got rid of the capital T. So now I'm going
00:02:30.120 | to assume that your time step is between 0 and 1, and it basically represents what percentage
00:02:37.320 | of the way through the diffusion process are you? So 0 would be all noise, and 1 would be
00:02:42.840 | - well, no, sorry, the other way around - 0 would be all clean, and 1 would be all noise.
00:02:47.400 | So how far through the forward diffusion process? So other than that, this is exactly the same
00:02:53.540 | equation we've already seen. And I realized something else, which is kind of fun, which
00:02:56.800 | is you can take the inverse of that. So you can calculate T. So we would basically first
00:03:10.060 | take the square root, and we would then take the inverse cos, and we would then divide
00:03:21.680 | by pi over 2, or times 2 over pi. So we can both - so it's interesting now we don't - the
00:03:30.840 | alpha bar is not something we look up in a list, it's something we calculate with a function
00:03:37.160 | from a float. And so yeah, interestingly, that means we can also calculate T from an
00:03:42.280 | alpha bar. So Noisify has changed a little. So now when we get the alpha bar through our
00:03:52.120 | time step, we don't look it up, we just call the function. And now the time step is a random
00:04:00.920 | float between 0 and 1, actually between 0 and 0.999, which actually I'm sure there's a function
00:04:09.920 | I could have chosen to do a float in this range, but I just capped it because I was lazy. Couldn't
00:04:14.960 | be bothered looking it up. Other than that, Noisify is exactly the same. So we're still
00:04:22.440 | returning the xt, the time step, which is now a float, and the noise. That's the thing
00:04:30.960 | we're going to try and predict, dependent variable, this tuple there as our inputs to
00:04:35.520 | the model. All right, so here is what that looks like. So now when we look at our input
00:04:44.600 | to our unit training process, you can see we've got a T of 0.05, so 5% of the way through
00:04:51.840 | the forward diffusion process, it looks like this, and 65% through it looks like this.
00:04:58.000 | So now the time step and basically the process is more of a kind of a continuous time step
00:05:08.200 | and a continuous process. Rather before we were having these discrete time steps here,
00:05:13.280 | we get just any random value that could be between 0 and 1. I think, yeah, that's also
00:05:18.720 | something.
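The continuous-time schedule and Noisify just described can be sketched as follows (my reconstruction from the description, not the notebook's exact code; the 0.999 cap is the lazy clamp Jeremy mentions):

```python
import torch

def abar(t):
    # cosine schedule: alpha-bar as a function of t in [0, 1]
    return (t * torch.pi / 2).cos() ** 2

def inv_abar(ab):
    # invert: take the square root, then arccos, then divide by pi/2
    return ab.sqrt().acos() * 2 / torch.pi

def noisify(x0):
    # a continuous time step per item, roughly uniform in [0, 1)
    n = len(x0)
    t = torch.rand(n, device=x0.device).clamp(0, 0.999)
    eps = torch.randn_like(x0)
    ab = abar(t).reshape(-1, 1, 1, 1)
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps
    # inputs to the model are (noised image, float time step); target is the noise
    return (xt, t), eps
```

Because `abar` is now a plain function of a float rather than a lookup table, `inv_abar` gives us t back from any alpha-bar.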
00:05:19.720 | Yeah, we should kind of get this more convenient, you know, to have...
00:05:22.240 | Yeah, it is convenient.
00:05:23.240 | To have a function to call. Yeah, I find it makes life a little bit easier. So the model's the
00:05:31.640 | same, the callbacks are the same, the fitting process is the same. And so something which
00:05:39.400 | is kind of fun is that we can now create a little denoise function.
00:05:46.680 | So we can take, you know, this batch of data that we generated, the noisified data, so
00:05:54.160 | here it is again, and we can denoise it. So we know the T for each element, obviously.
00:06:03.880 | So remember T is different for each element now. And we can therefore calculate the alpha
00:06:11.720 | bar for each element. And then we can just undo the noisification to get the denoised
00:06:19.280 | version. And so if we do that, there's what we get. And so this is great, right? It shows
00:06:25.360 | you what actually happens when we run a single step of the model on varyingly, partially
00:06:33.960 | noised images. And this is something you don't see very often because I guess not many people
00:06:39.120 | are working in these kind of interactive notebook environments where it's really easy to do
00:06:42.640 | this kind of thing. But I think this is really helpful to get a sense of like, okay, if you're
00:06:46.600 | 25% of the way through the forward diffusion process, this is what it looks like when you
00:06:51.800 | undo that. If you're 95% of the way through it, this is what happens when you undo that.
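The one-step denoise just described falls straight out of inverting the noisify equation; a minimal sketch (assuming the model's noise prediction `eps_pred` is already in hand, and repeating the cosine `abar` so this stands alone):

```python
import torch

def abar(t):
    # cosine schedule, as earlier in the lesson
    return (t * torch.pi / 2).cos() ** 2

def denoise(xt, t, eps_pred):
    # undo the noisification: x0 = (xt - sqrt(1 - abar) * eps) / sqrt(abar)
    ab = abar(t).reshape(-1, 1, 1, 1)
    return (xt - (1 - ab).sqrt() * eps_pred) / ab.sqrt()
```

Since t is different for each element of the batch, the per-element alpha-bar is what lets one call show the single-step denoising across many noise levels at once.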
00:06:59.800 | So you can see here, it's basically like, oh, I don't really know what the hell's going
00:07:03.160 | on, so it just leaves a noisy mess. Yeah, I guess my feeling from looking at this is, I'm impressed,
00:07:15.720 | you know, like this 45% noise thing, it looks all noise to me. It's found the long-sleeved
00:07:24.400 | top. And yeah, it's actually pretty close to the real one. I looked it up, or you might
00:07:31.400 | see it later, it's a little bit more of a pattern here, but it even gives a sense of
00:07:35.320 | the pattern. So it shows you how impressive this is. So this is 35%. You can kind of see
00:07:41.040 | there's a shoe there, but it's really picked up the shoe nicely. So these are very impressive
00:07:47.040 | models in one step, in my opinion. So, okay, so sampling is basically the same, except
00:08:01.640 | now rather than starting with using the range function to create our timesteps, we use
00:08:09.360 | linspace to create our timesteps. So our timesteps start at, you know, if we did 1000, it would
00:08:17.560 | be 0.999, and they end at 0, and then they're just linearly spaced with this number of steps.
00:08:24.520 | So other than that, you know, a bar we now calculate, and the next a bar is going to
00:08:33.160 | be whatever the current step is, minus 1 over steps. So if you're doing 100 steps, then
00:08:40.520 | you'd be minus 0.01. So this is just stepping through linearly. And yeah, that's actually
00:08:51.680 | it for changes. So if we just do DDIM for 100 steps, you know, that works really well.
00:09:04.000 | We get a FID of 3, which is actually quite a bit better than we had on 100 steps for
00:09:13.160 | our previous DDIM. So this definitely seems like a good sampling approach.
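The linspace-based time stepping can be sketched like this (a reconstruction from the description; the DDIM update itself is unchanged from before, so it is elided here):

```python
import torch

def abar(t):
    # cosine schedule, as earlier in the lesson
    return (torch.as_tensor(t) * torch.pi / 2).cos() ** 2

def timesteps(steps):
    # e.g. steps=100 gives 0.99, 0.98, ..., 0.00: linearly spaced,
    # no capital-T anywhere
    return torch.linspace(1 - 1/steps, 0, steps)

steps = 100
for t in timesteps(steps):
    ab_t = abar(t)
    # the "next" alpha-bar is just the schedule evaluated 1/steps earlier
    # in forward-diffusion time (clamped so the final step lands on abar=1)
    ab_next = abar((t - 1/steps).clamp(min=0.))
    # ...DDIM update using ab_t and ab_next as before...
```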
00:09:21.280 | And I know Jono's going to talk a bit more shortly about, you know, some of the things
00:09:26.760 | that can make better sampling approaches, but yeah, definitely we can see it making
00:09:31.960 | a difference here. Did you guys have anything you wanted to say about this before we move on?
00:09:37.600 | No, but it is a nice transition towards some of the other things we'll be looking at to
00:09:42.880 | start thinking about how do we frame this. And it's also good, like the idea, so the
00:09:48.280 | original DDPM paper has this 1,000 time steps, and a lot of people followed that. But the
00:09:53.680 | idea that you don't have to be bound to that, and maybe it is worth breaking that convention.
00:09:57.520 | I know Tanishk made that meme about, you know, the 15 competing standards situation.
00:10:03.160 | But yeah, sometimes it's helpful to reframe it, okay, time goes from 0 to 1. That can
00:10:07.400 | simplify some things. It complicates others, but yeah, it's nice to think how you can reframe
00:10:12.480 | stuff sometimes.
00:10:13.480 | In fact, where we will head today by the time we get to notebook 23, we will see, you know,
00:10:21.560 | even simpler notation. And yeah, simpler notation generally comes. I think what happens is over
00:10:27.600 | time people understand better what's the essence of the problem and the approach, and then
00:10:35.040 | that gets reflected in the notation.
00:10:40.680 | So okay, so the next one I wanted to share is something which is an idea we've been working
00:10:47.120 | on for a while, and it's some new research. So partly, I guess this is an interesting
00:10:52.880 | like insight into how we do research. So this is 22 noise pred. And the basic idea of this
00:11:02.440 | was, well, actually, I've got to take you through it to see what the basic idea is.
00:11:06.480 | So what I'm going to do is I'm going to create, okay, so fashion MNIST as before, but I'm
00:11:13.120 | going to create a different kind of model. I'm not going to create a model that predicts
00:11:17.480 | the noise given the noised image in t. Instead, I'm going to try to create a model which predicts
00:11:26.600 | t given the noised image. So why did I want to do that? Well, partly, well, entirely because
00:11:34.000 | I was curious. I felt like when I looked at something like this, I thought it was pretty
00:11:41.380 | obvious roughly how much noise each image had. And so I thought, why are we passing
00:11:50.280 | noise when we call the model? Why are we passing in the noised image and the amount of noise
00:11:55.640 | or the t? Given that I would have thought the model could figure out how much noise
00:11:59.560 | there is. So I wanted to check my contention, which is that the model could figure out how
00:12:04.520 | much noise there is. So I thought, okay, well, let's create a model that would try and figure
00:12:08.200 | out how much noise there is. So I created a different noisify now, and this noisify grabs
00:12:17.160 | an alpha bar t randomly. And it's just a random number between 0 and 1. We've got one per
00:12:32.600 | item in the batch. And so then after just randomly grabbing an alpha bar t, we then
00:12:38.800 | noisify in the usual way. But now our independent variable is the noised image and the dependent
00:12:45.080 | variable is alpha bar t. And so we've got to try to create a model that can predict
00:12:48.920 | alpha bar t given a noised image. Okay, so everything else is the same as usual. And
00:12:57.280 | so we can see an example. You've got alpha_bar_t.squeeze().logit(). Oh, yeah, that's true.
00:13:04.320 | So the alpha bar t goes between 0 and 1. So we've got a choice. Like, I mean, we don't
00:13:12.080 | have to do anything. But normally, if you've got something between 0 and 1, you might consider
00:13:16.400 | putting a sigmoid at the end of your model. But I felt like the difference between 0.999
00:13:24.080 | and 0.99 is very significant. So if we do logit, then we don't need the sigmoid at the
00:13:33.440 | end anymore. It will naturally cover the full range of kind of-- it ought to be centered
00:13:39.280 | at 0. It would have covered all the normal kind of range of numbers. And it also will
00:13:44.060 | treat equal ratios as equally important at both ends of the spectrum. So that was my
00:13:51.800 | hypothesis was that using logit would be better. I did test it and it was actually very dramatically
00:13:56.880 | better. So without this logit here, my model didn't work well at all. And so this is like
00:14:01.760 | an example of where thinking about these details is really important. Because if I hadn't
00:14:07.040 | have done this, then I would have come away from this bit of research thinking like, oh,
00:14:10.720 | I was wrong. We can't predict noise amount. Yeah. So thanks for pointing that out, Tanishk.
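The noise-level-prediction setup can be sketched as below; the function name `noisify_t` and the clamp bounds are my own choices, but the logit target follows the discussion:

```python
import torch

def noisify_t(x0):
    # one random alpha-bar per item in the batch; clamp away from exactly
    # 0 and 1 so the logit stays finite (the clamp values are an assumption)
    ab = torch.rand(len(x0), device=x0.device).clamp(1e-4, 1 - 1e-4)
    eps = torch.randn_like(x0)
    ab_ = ab.reshape(-1, 1, 1, 1)
    xt = ab_.sqrt() * x0 + (1 - ab_).sqrt() * eps
    # target is logit(alpha-bar): 0.99 vs 0.999 become clearly separated
    # targets, 0 corresponds to alpha-bar = 0.5, and no final sigmoid is
    # needed on the model
    return xt, ab.logit()
```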
00:14:17.280 | Yeah. So that's why in this example of a mini batch, you can see that the numbers can be
00:14:23.280 | negative or positive. So 0 would represent alpha bar of 0.5. So here, 3.05 is not very
00:14:31.120 | noisy at all, whereas negative 1 is pretty noisy. So the idea is that, yeah, given this
00:14:40.400 | image, you would have to try to predict 3.05. So one thing I was kind of curious about is
00:14:49.680 | like, it's always useful to know is like, what's the baseline? Like, what counts as good? Because
00:14:54.120 | often people will say to me like, oh, I created a model and the MSE was 2.6. And I'll be like,
00:14:59.760 | well, is that good? Well, it's the best I can do, but is it good? Like, or is it better
00:15:05.440 | than random or is it better than predicting the average? So in this case, I was just like,
00:15:12.080 | okay, well, what if we just predicted? Actually, this is slightly out of date. I should have
00:15:16.440 | said 0 here rather than 0.5, but never mind close enough. So this is before I did the
00:15:24.040 | logit thing. So I basically was looking at like, what's the loss if you just always predicted
00:15:34.760 | a constant, which as I said, I should have put 0 here, haven't updated it. And so it's
00:15:41.760 | like, oh, that would give you a loss of 3.5. Or another way to do it is you could just
00:15:49.560 | put MSE here and then look at the MSE loss between 0.5 and your various, just a single
00:16:00.320 | mini-batch of alpha bar logits. Yeah, so if we're getting
00:16:10.880 | something that's about 3, then we basically haven't done any better than random. And so
00:16:16.760 | in this case, this model, it doesn't actually have anything to learn. It always returns
00:16:23.260 | the same thing. So we can just call fit with train=False just to find the loss.
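A quick baseline check like the one described might look like this (a sketch; `mse_flat` is the flatten trick discussed just after, and predicting a constant 0 corresponds to always guessing alpha-bar = 0.5 in logit space):

```python
import torch
import torch.nn.functional as F

def mse_flat(pred, targ):
    # flatten both so differing shapes can't silently broadcast
    return F.mse_loss(pred.flatten(), targ.flatten())

# baseline: always predict 0, i.e. alpha-bar = 0.5
torch.manual_seed(42)
ab = torch.rand(4096).clamp(1e-4, 1 - 1e-4)
targ = ab.logit()
baseline = mse_flat(torch.zeros_like(targ), targ)
# the logit of a uniform variable is logistic-distributed, so this comes
# out around pi^2/3 ~= 3.3 - anything near 3 means "no better than random"
```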
00:16:28.480 | So this is just a couple of ways of getting quickly finding a loss for a baseline naive
00:16:34.400 | model. One thing that thankfully PyTorch will warn you about is if you try to use MSE and
00:16:48.800 | your inputs and targets have different shapes, it will broadcast and give you probably not
00:16:53.720 | the result you would expect, and it will give you a warning. So one way to avoid that is
00:16:58.480 | just to use dot flatten on each. So this kind of flattened MSE is useful to avoid the warning
00:17:09.200 | and also avoid getting weird errors or weird results. So we use that for our loss. So the
00:17:15.560 | model's the model that we always use. So it's kind of nice. We just use our same old model.
00:17:22.320 | Same changes. Even though we're doing something totally different. Oh, well, okay, that's
00:17:31.360 | not quite true. One difference is that our output, we just have one output now, because
00:17:35.600 | this is now a regression model. It's just trying to predict a single number. And so
00:17:39.880 | our learner now uses MSE as a loss. Everything else is the same as usual. So we could go
00:17:48.240 | ahead and train it. And you can see, okay, the loss is already much better than 3, so
00:17:51.680 | we're definitely learning something. And we end up with a 0.075 mean squared error. That's
00:18:01.080 | pretty good considering, you know, there's a pretty wide range of numbers we're trying
00:18:05.800 | to predict here. So I've got to save that as noise prediction on sigma. So save that
00:18:14.920 | model. And so we can take a look at how it's doing by grabbing our one batch of noise images,
00:18:25.120 | putting it through our T model. Actually, it's really an alpha bar model, but never
00:18:29.960 | mind, call it a T model. And then we can take a look to see what it's predicted for each
00:18:35.120 | one. And we can compare it to the actual for each one. And so you can see here it said,
00:18:42.240 | oh, I think this is about 0.91. And actually, it is 0.91. I said, oh, here it looks like
00:18:47.760 | about 0.36. And yeah, it is actually 0.36. So, you know, you can see overall 0.72. It's
00:18:54.120 | actually 0.72. Well, those are exactly right. This one's 0.02 off. But yeah, my hypothesis
00:18:59.760 | was correct, which is that we, you know, we can predict the thing that we were putting
00:19:06.480 | in manually as input. So there's a couple of reasons I was interested in checking this
00:19:12.820 | out. You know, the first was just like, well, yeah, wouldn't it be simpler if we weren't
00:19:19.040 | passing in the T each time? You know, why not pass in the T each time? But it also felt
00:19:24.600 | like it would open up a wider range of kind of how we can do sampling. The idea of doing
00:19:31.700 | sampling by like precisely controlling the amount of noise that you try to remove each
00:19:37.720 | time and then assuming you can remove exactly that amount of noise each time feels limited
00:19:45.000 | to me. So I want to try to remove this constraint. So having, yeah, built this model, I thought,
00:19:56.560 | okay, well, you know, which is basically like, okay, I think we don't need to pass T in.
00:20:01.180 | Let's try it. So what I then did is I replicated the 22 cosine notebook. I just copied it,
00:20:06.720 | pasted it in here. But I made a couple of changes. The first is that Noisify doesn't
00:20:16.440 | return T anymore. So there's no way to cheat. We don't know what T is. And so that means
00:20:25.440 | that the unit now doesn't have T, so it's actually going to pass zero every time. So
00:20:33.400 | it has no ability to learn from T because it doesn't get T. So it doesn't really matter
00:20:38.840 | what we pass in. We could have changed the unit to like remove the conditioning on T.
00:20:46.260 | But for research, this is just as good, you know, for finding out. And it's good to be
00:20:51.600 | lazy when doing research. There's no point doing something a fancy way where you can
00:20:55.080 | do it a quick and easy way before you even know if it's going to work. So yeah, that's
00:21:00.700 | the only change. So we can then train the model and we can check the loss. So the loss
00:21:06.880 | here is 0.034. And previously it was 0.033. So interestingly, you know, maybe it's a tiny
00:21:21.080 | bit worse at that, you know, but it's very close. Okay, so we'll save that model. And
00:21:31.600 | then for sampling, I've got exactly the same DDIM step as usual. And my sampling is exactly
00:21:43.360 | the same as usual, except now, when I call the model, I have no T to pass in. So we just
00:21:53.720 | pass in this. I mean, I still know T because I'm still using the usual sampling approach,
00:22:02.560 | but I'm not passing it to the model. And yeah, we can sample. And what happens is actually
00:22:11.360 | pretty garbage. 22 is our FID. And as you can see here, you know, some of the images
00:22:24.720 | are still really noisy. So I totally failed. And so that's always a little discouraging
00:22:34.720 | when you think something's going to work and it doesn't. But my reaction to that is like,
00:22:39.280 | if I think something's going to work and it doesn't is to think, well, I'm just going
00:22:43.040 | to have to do a better job of it. You know, it ought to work. So I tried something different,
00:22:51.840 | which is I thought like, okay, since we're not passing in the T, then we're basically
00:23:02.700 | saying like, how much noise should you be removing? It doesn't know exactly. So it might
00:23:08.320 | remove a little bit more noise that we want or a little bit less noise than we want. And
00:23:13.000 | we know from the testing we did that sometimes it's out by like, in this case, 0.02. And
00:23:23.000 | I guess if you're out consistently, sometimes it's got to end up not removing all the noise.
00:23:29.860 | So the change I made was to the DDIM step, which is here. And let me just copy this and
00:23:39.880 | get rid of the commented out sections just to make it a bit easier to read. Okay. So
00:23:50.320 | the DDIM step, this is the normal DDIM step. Okay. And so step one is the same. So don't
00:23:56.320 | worry about that because it's the same as we've seen before. But what I did was I actually
00:24:02.240 | used my T model. So I passed the noised image into my T model, which is actually an alpha
00:24:08.900 | bar model, to get the predicted alpha bar. And this is remember the predicted alpha bar
00:24:15.640 | for each image, because we know from here that sometimes, so sometimes it did a pretty
00:24:20.860 | good job, right? But sometimes it didn't. So I felt like, okay, we need a predicted alpha
00:24:26.480 | bar for each image. What I then discovered is sometimes that could be really too low.
00:24:39.100 | So what I wanted to make sure is it wasn't too crazy. So I then found the median for
00:24:43.760 | a mini batch of all the predicted alpha bars, and I clamped it to not be too far away from
00:24:49.360 | the median. And so then what I did when I did my X naught hat is rather than using alpha
00:24:58.400 | bar T, I used the estimated alpha bar T for each image, clamped to be not too far away
00:25:06.720 | from the median. And so this way it was updating it based on the amount of noise that actually
00:25:12.520 | seems to be left behind, rather than the assumed amount of noise that should be left behind
00:25:19.240 | you know, if we assume it's removed the correct amount. And then everything else is the same.
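A rough sketch of the modified DDIM step (the sigmoid to undo the logit output, the clamp width, and the function names are my assumptions; `tmodel` is the trained noise-level predictor):

```python
import torch

def ddim_step_no_t(xt, eps_pred, abar_next, tmodel, clamp_w=0.1):
    # estimate alpha-bar for each image from the noised image itself;
    # tmodel is assumed to output logit(alpha-bar), hence the sigmoid
    ab_est = tmodel(xt).sigmoid().reshape(-1, 1, 1, 1)
    # clamp the per-image estimates to lie near the batch median, so one
    # wild estimate can't blow up the update (the width is my guess)
    med = ab_est.median().item()
    ab_est = ab_est.clamp(med - clamp_w, med + clamp_w)
    # x0-hat uses the *estimated* noise level actually left in the image,
    # not the amount the schedule assumes should be left
    x0_hat = (xt - (1 - ab_est).sqrt() * eps_pred) / ab_est.sqrt()
    # then step to the next scheduled alpha-bar as in plain DDIM (eta=0)
    return abar_next.sqrt() * x0_hat + (1 - abar_next).sqrt() * eps_pred
```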
00:25:26.800 | So when I did that, whoa, it made all the difference. And here it is. They are beautiful
00:25:38.360 | pieces of clothing. So 3.88 versus 3.2. That's possibly close enough, like I'd have to run
00:25:53.760 | it a few times, you know, my guess is maybe it's a tiny bit worse, but it's pretty close.
00:26:00.360 | But like this definitely gives me some encouragement that, you know, even though this is like something
00:26:10.400 | I just did in a couple of days, whereas the kind of with-T approaches have been
00:26:13.980 | developed since 2015, and we're now in 2023. You know, I would expect it's quite likely
00:26:21.280 | that these kind of like, no, no T approaches could eventually surpass the T based approaches.
00:26:35.080 | And like one thing that definitely makes me think there's room to improve is if I plot
00:26:39.880 | the FID or the KID, for each sample during the reverse diffusion process, it actually
00:26:45.080 | gets worse for a while. I'm like, okay, well, that's, that's a bad sign. I have no idea
00:26:50.320 | why that's happening. But it's a sign that, you know, if we could improve each step that
00:26:55.600 | one would assume we could get better than 3.8. So yeah, Tanishk, do you have any thoughts
00:27:03.740 | about that, or questions or comments? Or maybe just to highlight the research
00:27:12.440 | process a little bit, it wasn't like this linear thing of like, oh, here's this issue.
00:27:17.280 | It's not working as well as we thought. Oh, here's the fix. We just scrapped this. You know,
00:27:22.300 | this was like multiple days of like, discussing and like Jeremy saying, like, you know, I'm
00:27:26.760 | tearing my hair out. Do you guys have any ideas? And oh, what about this? And oh, I just
00:27:30.360 | in the team paper, they do this clamping, maybe that'll help. You know, so there's
00:27:33.240 | a lot of back and forth. And also a lot of like, you saw that code that was commented
00:27:36.640 | out there, prints, xt.min, xt.max, alphabar, pred, you know, just like seeing, oh, okay,
00:27:43.760 | you know, my average prediction is about what I would expect. But sometimes the middle of
00:27:46.520 | the max goes, you know, 2, 3, 8, 16, 150, 212 million, infinity, you know, maybe like
00:27:54.640 | one or two little values that would just skyrocket out. Yeah, and so that kind of like, debugging
00:28:00.160 | and exploring and printing things out. And actually, our initial discussions about this
00:28:05.800 | idea, I kind of said to you guys, before lesson one of part two, I said, like, it feels to
00:28:12.600 | me like we shouldn't need the t thing. And so it's actually been like, mumbling away
00:28:17.520 | in the background for the months. Yeah, yeah. And I guess I mean, we should also mention
00:28:24.960 | we have tried this, like a friend of ours trained a no T version of stable diffusion
00:28:28.960 | for us. And we did the same sort of thing. I trained a pretty bad T predictor and it
00:28:34.160 | sort of generates samples. So we're not like focusing on that large scale stuff yet. But
00:28:39.840 | it is fun to like, every now and again, take this idea from Fashion MNIST, we are trying
00:28:44.480 | these out on some bigger models and seeing, okay, this does seem like maybe it'll work.
00:28:48.920 | And down the line the future plan is to actually, you know, spend the time to
00:28:52.400 | train a proper model, and see, yeah, see how well that does. If it's interesting, you say
00:28:57.120 | a friend of ours, we can be more specific. It's Robert, one of the two lead authors of
00:29:01.640 | the stable diffusion paper who actually has been fine tuning a real stale stable diffusion
00:29:07.480 | model, which is without T and it's looking super encouraging. So yeah, that'll be fun
00:29:16.600 | to play with with this new, you know, we'll have to train a T predictor for that. See
00:29:22.160 | how it looks. Yeah. All right. So I guess the other area we've been talking about kind
00:29:32.080 | of doing some research on is this weird thing that came up over the last two weeks where
00:29:39.680 | our bug in the DDPM implementation, where we accidentally weren't doing it from minus
00:29:45.600 | one to one for the input range, it turned out that actually being from minus one to
00:29:50.760 | one wasn't a very good idea anyway. And so we ended up centering it as being from minus
00:29:58.460 | point five to point five, and John O and Tanishk have managed to actually find a paper. Well,
00:30:06.800 | I say find a paper, a paper has come out in the last 24 hours, which has coincidentally
00:30:15.160 | cast some light on this and has also cited a paper that we weren't aware of, which was
00:30:21.080 | not released in the last 24 hours. So John O, are you going to tell us a bit about that?
00:30:25.120 | Yeah, sure. I can do that. So it's funny, this was such perfect timing because I actually
00:30:31.600 | got up early this morning planning to run with the different input scalings and the
00:30:36.720 | cosine schedule that Jeremy was showing and some of the other schedulers we look at. I
00:30:40.000 | thought it might be nice for the lesson to have a little plot of like, what is the fit
00:30:44.440 | with these different solvers and input scalings, but it was going to be a lot of work. I'm
00:30:48.080 | not looking forward to doing the groundwork. And then Tanishk sent me this paper, which
00:30:52.120 | AK had just tweeted out because he reviews anything that comes up on archive every day
00:30:56.640 | on the importance of noise scheduling for diffusion models. And this is by a researcher
00:31:01.000 | at the Google Brain team, who's also done a really cool recent paper on something called
00:31:05.560 | a recurrent interface network outside of the scope of this lesson, but also worth checking
00:31:09.680 | out. Yeah, so this paper they're hoping to study this noise scheduling and the strategies
00:31:16.340 | that you take for that. And they want to show that number one, those scheduling is crucial
00:31:20.320 | for performance and the optimal one depends on the tasks. When increasing the image size,
00:31:25.320 | the noise scheduling that you want changes and scaling the input data by some factor
00:31:31.960 | is a good strategy for working with this. And that's the big thing we've been talking
00:31:35.800 | about, right? Yeah, that's what we've been doing where we said, oh, do we scale from
00:31:39.320 | minus 0.5 to 0.5 or minus 1 to 1 or do we normalize? And so they demonstrate the effectiveness
00:31:45.160 | by training a really good high resolution model on ImageNet, so a class-conditional model.
00:31:51.320 | That's correct. Yeah, amazing samples. They'll show one later. So I really like this paper.
00:31:56.760 | It's very short and concise, and it just gets all the information across. And so they introduced
00:32:01.960 | us here. We have this noising process on noiseifier function where we have square root of something
00:32:07.600 | times x plus square root of 1 minus that something times the noise. And here they use gamma,
00:32:13.800 | gamma of t, which is often used for the continuous time case. So instead of the alpha bar and
00:32:18.400 | the beta bar scheduled for 1,000 time steps, there'll be some function gamma of t that
00:32:22.720 | tells you what your alpha bar should be. Okay, so that's our function is actually called
00:32:27.400 | a bar, but it's the same thing. Yeah, same thing. It takes in a time set from 0 to 1,
00:32:32.720 | and then that's used to noise the image. Interestingly, what they're showing here actually
00:32:37.040 | is something that we had discovered, and I've been complaining about that my DTAMs with an
00:32:44.200 | eater of less than one weren't working, which is to say when I added extra noise to the
00:32:49.200 | image, it wasn't working. And what they're showing here is like, oh yeah, duh, if you
00:32:55.560 | use a smaller image, then adding extra noise is probably not a good idea.
00:33:01.240 | Yeah. And so they use a lot of reference in this paper to like information being destroyed
00:33:07.520 | and signal to noise ratios. And that's really helpful for thinking about because it's not
00:33:11.920 | something that's obvious, but at 64 by 64 pixels, adjacent pixels might have much less
00:33:17.600 | in common versus the same amount of noise added at a much higher resolution, the noise
00:33:22.800 | kind of averages out and you can still see a lot of the image. So yeah, that's one thing
00:33:26.800 | they highlight is that the same noise level for different image sizes, it might be a harder
00:33:32.200 | or easier task. And so they investigate some strategies for this. They look at the different
00:33:37.600 | noise schedule functions. So we've seen the original version from the DDPM paper. We've
00:33:44.080 | seen the cosine schedule and we've seen, I think we might look at, or the next thing
00:33:50.600 | that Jeremy's going to show us, a sigmoid-based schedule. So they show the continuous
00:33:55.680 | time versions of that and they plot how you can change various parameters to get these
00:33:59.660 | different gamma functions or in our case, the alpha bar, where we're starting at all image,
00:34:08.560 | no noise at t equals zero, moving to all noise, no image at t equals one. But the path that
00:34:14.280 | you take is going to be different for these different classes of functions and parameters.
00:34:19.760 | And the signal to noise ratio, that's what this or the log signal to noise ratio is going
00:34:24.980 | to change over that time as well. And so that's one of the knobs we can tweak. We're saying
00:34:30.040 | our diffusion model isn't training that well, we think it might be related to the noise
00:34:33.780 | schedule and so on. One of the things you could do is try different noise schedules,
00:34:37.080 | either changing the parameters in one class of noise schedule or switching from a linear
00:34:42.240 | to a cosine to a sigmoid. And then the second strategy is kind of what we were doing in
00:34:47.640 | those experiments, which is just to add some scaling factor to x zero.
00:34:51.600 | Well, we were accidentally using b of 0.5.
00:34:59.040 | Exactly. And so that's the second dial that you can tweak is to say keeping your noise
00:35:03.960 | schedule fixed, maybe you just scale x zero, which is going to change the ratio of signal
00:35:08.680 | to noise.
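To make the effect of that x0 scaling concrete, here's a small sketch (the helper name and the scaling factor `b` are just for illustration): for b·sqrt(abar)·x0 + sqrt(1 − abar)·noise, the log signal-to-noise ratio is log(b²·abar / (1 − abar)), so scaling the image down shifts the whole log-SNR curve down.

```python
import math

def log_snr(abar, b=1.0):
    # signal variance is b^2 * abar, noise variance is 1 - abar
    return math.log(b * b * abar / (1 - abar))

# with b = 0.5 the log-SNR drops by log(4) at every noise level,
# i.e. the same schedule effectively becomes noisier
shift = log_snr(0.5) - log_snr(0.5, b=0.5)
```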
00:35:09.680 | And I think that figure there shows what we were accidentally doing.
00:35:14.600 | Yes. Yeah, exactly. And so that again changes the signal
00:35:23.560 | to noise for the different scalings you get. And so that's fine. So they have a compound
00:35:28.840 | strategy that combines some of those things. And this is the important part, they
00:35:32.320 | do their experiments. And so they have a nice table of investigating different schedules,
00:35:39.520 | cosine schedules and sigmoid schedules. And in bold are the best results. And you can
00:35:43.640 | see for 64 by 64 images versus 128 versus 256, the best schedule is not necessarily always
00:35:50.200 | the same. And so that's like important finding number one, depending on what your data looks
00:35:56.640 | like using a different noise schedule might be optimal. There's no one true best schedule.
00:36:01.840 | There's no one value of, you know, beta min and beta max that's just magically the
00:36:06.280 | best. Likewise, for this input scaling at different sizes, with whatever schedules they tested,
00:36:16.520 | different values were kind of optimal. And so, yeah, it's just a really great illustration,
00:36:25.000 | I guess that this is another design choice that's implicit or explicitly part of your
00:36:30.720 | diffusion model training and sampling is how are you dealing with this, this noise schedule,
00:36:34.960 | what schedule are you following, what scaling are you doing with your inputs. And by using
00:36:39.480 | this thinking and doing these experiments, and they come up with a kind of rule of thumb
00:36:43.800 | for how to scale the image based on image size, they show that they can, as they increase
00:36:49.400 | the resolution, they can still maintain really good performance. Where previously it was
00:36:53.600 | quite hard to train a really large resolution pixel space model, and they're able to do
00:37:00.280 | that, they get some advantage from their fancy recurrent interface network, but still, it's
00:37:06.080 | kind of cool that they can say, look, we get state of the art, high quality, 512 by 512
00:37:13.360 | or 1024 by 1024 samples on class-conditioned image net. And using this approach to really
00:37:20.760 | like consider how well do you train, how many steps do we need to take, one of the other
00:37:24.720 | things in this table is that they compare to previous approaches. Oh, we used, you know,
00:37:29.440 | a third of the training steps and for the same other settings, and we get better performance.
00:37:35.420 | Just because we've chosen that input scaling better. And yeah, so that's the paper, really,
00:37:41.080 | really nice, great work to the team. And that was very useful.
00:37:45.440 | I love that you got up in the morning and thought, oh, it's going to be a hassle training
00:37:53.720 | all these different models I need to train for different input scalings and different
00:37:58.600 | sampling approaches. I just look at Twitter first, and then you looked at Twitter and
00:38:04.120 | there was a paper saying like, hey, we just did a bunch of experiments for different noise
00:38:09.080 | schedules and input scaling. Yeah, does your life always work that way? It seems
00:38:14.760 | quite blessed.
00:38:15.760 | Yeah, it's very lucky like that. Yeah. You wait long enough, someone else will do it.
00:38:22.600 | That's why the time when arXiv starts posting on Twitter is like my
00:38:26.460 | favourite hour of the day for all the papers to be posted.
00:38:32.040 | Oh, well, thank you for that. So let me switch to notebook 23. Because this notebook is actually
00:38:54.200 | an implementation of some ideas from this paper that everybody tends to just call
00:39:01.120 | the Karras paper, even though there are other authors. But I will do it anyway: the Karras paper. And the reason
00:39:11.560 | we're going to look at this is because in this paper, the authors actually take a much
00:39:22.840 | more explicit look at the question of input scaling. Their approach was not apparently
00:39:32.000 | to accidentally put a bug in their code, and then take it out, find it worked worse, and
00:39:36.800 | then just put it back in again. Their approach was actually to think, how should things be?
00:39:42.920 | So that's an interesting approach to doing things, and I guess it works for them. So
00:39:46.480 | that's fine.
00:39:47.480 | I think our approach is more fun.
00:39:49.920 | Yeah, exactly. Our approach is much more fun because you never quite know what's going
00:39:54.000 | to happen. And so, yeah, in their approach, they actually tried to say, like, OK, given
00:40:01.960 | all the things that are coming into our model, how can we have them all nicely balanced?
00:40:09.520 | So we will skip back and forth between the notebook and the paper. So the start of this
00:40:16.760 | is all the same, except now we are actually going to do it minus one to one because we're
00:40:23.080 | not going to rely on accidental bugs anymore, but instead we're going to rely on the Karras
00:40:26.520 | paper's carefully designed scaling.
00:40:33.080 | I say that, except that I put a bug in this notebook as well. One of the things that's
00:40:43.200 | in the Karras paper is what is the standard deviation of the actual data, which I calculated
00:40:50.760 | for a batch. However, this used to say minus 0.5. I used to do the minus 0.5 to 0.5 thing.
00:40:58.720 | And so this is actually the standard deviation of the data before I, when it was still minus
00:41:03.480 | 0.5. So this is actually half the real standard deviation. For reasons I don't yet understand,
00:41:10.720 | this is giving me better scaled results. So this actually should be 0.66. So there's still
00:41:18.040 | a bug here and the bug still seems to work better. So we still got some mysteries involved.
00:41:22.720 | So we're going to leave this. So it's actually, it's actually not 0.33, it's actually 0.66.
00:41:27.640 | Okay, so the basic idea of this, actually I'll come back. Well, let me have a little
00:41:40.400 | think. Yeah, okay. Now we'll start here. The basic idea of this paper is to say, you know
00:41:49.600 | what, sometimes maybe predicting the noise is a bad idea. And so like you can either
00:42:01.880 | try and predict the noise or you can try and predict the clean image and each of those
00:42:08.120 | can be a better idea in different situations. If you're given something which is nearly
00:42:12.520 | pure noise, you know, the model's given something which is nearly pure noise and is then asked
00:42:18.160 | to predict the noise. That's basically a waste of time, because the whole thing's noise.
00:42:26.140 | If you do the opposite, which is you try to get it to predict the clean image. Well, then
00:42:30.640 | if you give it an image that's nearly clean and try to predict the clean image,
00:42:33.680 | that's nearly a waste of time as well. So you want something which is like, regardless
00:42:38.440 | of how noisy the image is, you want it to be kind of like an equally difficult problem
00:42:42.640 | to solve. And so what Karras et al. do is they basically use this new thing called c skip,
00:42:57.440 | which is a number, which is basically saying like, you know what we should do for the training
00:43:04.040 | target is not just predict the noise all the time, not just predict the clean image all
00:43:10.000 | the time, but predict kind of a lerped version of one or the other depending on how noisy
00:43:16.840 | it is. So here y is the clean image and n is the noise. So y plus n is the noised image.
00:43:35.000 | And so if c skip was 0, then we would be predicting the clean image. And if c skip was 1, we would
00:43:44.800 | be predicting y minus (y plus n), which is minus the noise, so we would be predicting the noise. And so you can decide by picking
00:43:53.080 | a different c skip whether you're predicting the clean image or the noise. And so, as you
00:43:58.440 | can see from the way they've written it, they make this a function. They make it a function
00:44:02.080 | of sigma. Now, this is where we got to a point now where we've kind of got a fairly, a much
00:44:08.760 | simpler notation. There's no more alpha bars, no more alphas, no more betas, no more beta
00:44:13.600 | bars. There's just a single thing called sigma. Unfortunately, sigma is the same thing as
00:44:19.980 | alpha bar used to be, right? So we've simplified it, but we've also made things more confusing
00:44:24.920 | by using an existing symbol for something totally different. So this is alpha bar. Okay.
00:44:30.680 | So there's going to be a function that says, depending on how much noise there is, we'll
00:44:37.280 | either predict the noise or we'll predict the clean image or we'll predict something
00:44:42.720 | between the two. So in the paper, they showed this chart where they basically said like,
00:44:53.960 | okay, let's look at the loss to see how good are we with a trained model at predicting
00:45:02.840 | when sigma is really low. So when there's very small alpha bar, or when sigma is in
00:45:12.040 | the middle or when sigma is really high. And they basically said, you know what, when it's
00:45:17.520 | nearly all noise or nearly no noise, you know, we're basically not able to do anything at
00:45:24.240 | all. You know, we're basically good at doing things when there's a medium amount of noise.
00:45:32.640 | So when deciding, okay, what, what sigmas are we going to send to this thing? The first
00:45:38.040 | thing we need to do is to, is to figure out some sigmas. And they said, okay, well, let's
00:45:44.440 | pick a distribution of sigmas that matches this red curve here, as you can see. And so
00:45:51.840 | this is a normally distributed curve where this is on a log scale. So this is actually
00:46:01.000 | a log normal curve. So to get the sigmas that they're going to use, they picked a normally
00:46:06.980 | distributed random number and then they exponentiate it. And this is called a log normal distribution.
00:46:14.400 | And so they used a mean of minus 1.2 and a standard deviation of 1.2. So that means that
00:46:24.120 | about one sixth of the time, they're going to be getting a number that's bigger than
00:46:31.080 | zero here. And e to the zero is one. So about one sixth of the time, they're going to be
00:46:38.620 | picking sigmas that are bigger than one. And so here's a histogram I drew of the sigmas
00:46:47.640 | that we're going to be using. And so it's nearly always less than five, but sometimes
00:46:55.960 | it's way out here. And so it's quite hard to read these histograms. So this really nice
00:47:01.440 | library called Seaborn, which is built on top of Matplotlib, has some more sophisticated
00:47:07.560 | and often nicer looking plots. And one of them they have is called a KDE plot, which
00:47:12.200 | is a kernel density plot. It's a histogram, but it's smooth. And so I clipped it at 10
00:47:20.040 | so that you could see it better. So you can basically see that the vast majority of the
00:47:23.680 | time it's going to be somewhere about 0.4 or 0.5, but sometimes it's going to be really
00:47:29.560 | big. So our Noisify is going to pick a sigma using that log-normal distribution. And then
00:47:42.840 | we're going to get the noise as usual, but now we're going to calculate C skip, right?
00:47:51.120 | Because we're going to do that thing we just saw. We're going to find something between
00:47:56.520 | the plain image and the noised input. So what do we use for C skip? We calculate it here.
00:48:10.560 | And so what we do is we say what's the total amount of variance at some level of sigma?
00:48:17.640 | Well it's going to be sigma squared, that's the definition of the variance of the noise,
00:48:23.000 | but we also have the sigma of the data itself, right? So if we add those two together we'll
00:48:28.880 | get the total variance. And so what the Karras paper said to do is to do the variance of
00:48:39.600 | the data divided by the total variance and use that for C skip. So that means that if
00:48:48.520 | your total variance is really big, so in other words it's got a lot of noise, then C skip's
00:48:54.300 | going to be really small. So if you've got a lot of noise then this bit here will be
00:49:00.000 | really small. So that means if there's a lot of noise try to predict the original image,
00:49:06.360 | right? That makes sense because predicting the noise would be too easy. If there's hardly
00:49:10.900 | any noise then this will be, total variance will be really small, right? So C skip will
00:49:17.520 | be really big and so if there's hardly any noise then try to predict the noise. And so
00:49:25.040 | that's basically what this C skip does. So it's a kind of slightly weird idea is that
00:49:31.760 | our target, the thing we're trying to do actually is not the input image, sorry the original
00:49:38.960 | image, it's not the noise but it's somewhere between the two. And I've found the easiest
00:49:43.080 | way to understand that is to draw a picture of it. So here is some examples of noised
00:49:49.560 | input, right? With various sigmas, remember sigma is alpha bar, right? So here's an example
00:49:58.480 | with very little noise, 0.06. And so in this case the target is predict the noise, right?
00:50:07.560 | So that's the hard thing to do is predict the noise. Whereas here's an example, 4.53
00:50:14.400 | which is nearly all noise. So for nearly all noise the target is predict the image, right?
00:50:22.560 | And then for something which is a little bit between the two like here, 0.64, the target
00:50:28.240 | is predict some of the noise and some of the image. So that's the idea of Karras. And so
00:50:38.080 | what this does is it's making the problem to be solved by the unit equally difficult
00:50:47.400 | regardless of what sigma is. It doesn't solve our input scaling problem, it solves our kind
00:50:55.440 | of difficulty scaling problem. To solve the input scaling problem they do it.
00:51:02.160 | I just want to make one quick note. And so this sort of idea of interpolating
00:51:08.720 | between the noise and the image is similar to what's called the v-objective as well.
00:51:15.440 | So there's also a similar kind of, yeah, it's quite similar to what Karras et
00:51:19.120 | al. have, and that's also now been used in a lot of different models. Like for example
00:51:24.400 | Stable Diffusion 2.0 was trained with this sort of v-objective. So people are using this
00:51:29.440 | sort of methodology and getting good results. And yeah, so it's an actual practical thing
00:51:35.600 | that people are doing. So yeah, I just want to make a note of that.
00:51:38.840 | Yeah, as is the case of basically all papers created by Nvidia researchers, of which this
00:51:46.360 | is one. It flies under the radar and everybody ignores it. The V-objective paper came from
00:51:54.280 | the senior author was Tim Salimans, who is at Google, right? Yeah. And so anything from
00:52:00.160 | Google and OpenAI everybody listens to. So yeah, although Karras I think has done the
00:52:04.880 | more complete version of this, and in fact the V-objective was almost like mentioned
00:52:12.680 | in passing in the distillation paper. But yeah, that's the one that everybody has ended
00:52:17.520 | up looking at. But I think this is the more...
00:52:19.960 | Yeah, I think what happened with the v-objective is not many people paid attention to it. I
00:52:24.960 | think folks like Kat and Robin and these sorts of folks are actually paying attention to
00:52:29.280 | that V-objective in that Google brain paper. But then also this paper did a much more principled
00:52:35.040 | analysis of this sort of thing. So yeah, I think it's very interesting how sometimes
00:52:40.160 | even these sort of side notes in papers that maybe people don't pay much attention to,
00:52:45.080 | they can actually be quite important.
00:52:47.120 | Yeah. Yeah. So, okay. So the noised input as usual is the input image plus the noise
00:52:54.000 | times the sigma. But then, and then as we discussed, we decide how to kind of decide
00:53:01.760 | what our target is. But then we actually take that noised input and we scale it up or down
00:53:09.520 | by this number. And the target, we also scale up or down by this number. And those are both
00:53:18.440 | calculated in this thing as well. So here's C out and here's C in.
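A minimal sketch of how those pieces might fit together in noisify (pure Python, per pixel, with hypothetical names; the real notebook works on batches of tensors, and 0.33 is the accidentally-halved sigma-data value the lesson says it kept):

```python
import math, random

SIG_DATA = 0.33  # standard deviation of the data (the halved value from the lesson)

def scalings(sig):
    # c_skip, c_out, c_in from the Karras paper, as functions of sigma
    totvar = sig ** 2 + SIG_DATA ** 2
    return (SIG_DATA ** 2 / totvar,                # c_skip
            sig * SIG_DATA / math.sqrt(totvar),    # c_out
            1 / math.sqrt(totvar))                 # c_in

def noisify(x0):
    # sigma drawn from the log-normal: exp(N(-1.2, 1.2))
    sig = math.exp(random.gauss(-1.2, 1.2))
    noised = x0 + random.gauss(0, 1) * sig
    c_skip, c_out, c_in = scalings(sig)
    # the target lerps between image and noise depending on sigma
    target = (x0 - c_skip * noised) / c_out
    return noised * c_in, sig, target
```

With no noise, c skip is 1 (predict the noise); with a huge sigma, c skip goes to 0 (predict the image).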
00:53:28.080 | Now I just wanted to show one example of where these numbers come from because for a while
00:53:32.200 | they all seem pretty mysterious to me and I felt like I'd never be smart enough to understand
00:53:36.320 | them, particularly because they were explained in the mathematical appendix of this paper,
00:53:41.760 | which are always the bits I don't understand, until I actually try to and then it tends
00:53:45.960 | to turn out they're not so bad after all, which was certainly the case here, which?
00:53:56.000 | I think it was B something, I think. So B6, I think? Is that the one?
00:54:04.880 | So in appendix B6, which does look pretty terrifying, but if you actually look at, for example,
00:54:14.600 | what we're just looking at, C in, it's like how do they calculate? So C in is this. Now
00:54:21.360 | this is the variance of the noise, this is the variance of the data, add them together
00:54:26.240 | to get the total variance, square roots, the total standard deviation. So it's just the
00:54:31.240 | inverse of the total standard deviation, which is what we have here. Where does that come
00:54:36.440 | from? Well, they just said, you know what? The inputs for a model should have unit variance.
00:54:45.000 | Now we know that. We've done that to death in this course. So they said, right. So well,
00:54:53.760 | the inputs to the model is the, the clean data plus the noise times some number we're
00:55:03.520 | going to calculate and we want that to be one. Okay. So the variance of the plain images
00:55:14.660 | plus the noise is equal to the variance of the clean images plus the variance of the
00:55:20.600 | noise. Okay. So if we want that to be, if we want variance to be one, then divide both
00:55:31.040 | sides by this and take the square root. And that tells us that our multiplier has to be
00:55:36.840 | one over this. That's it. So it's like literally, you know, classical math. The only bit you
00:55:45.440 | have to know is that the variance of two things added together is the variance of the
00:55:51.960 | one plus the variance of the other, which is not rocket science either.
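Written out, that little derivation is just:

```latex
\operatorname{Var}(y + n) = \sigma_{\mathrm{data}}^2 + \sigma^2,
\qquad
c_{\mathrm{in}}^2\,(\sigma_{\mathrm{data}}^2 + \sigma^2) = 1
\;\Longrightarrow\;
c_{\mathrm{in}} = \frac{1}{\sqrt{\sigma^2 + \sigma_{\mathrm{data}}^2}}
```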
00:55:57.320 | And in this context, like why we want to do this, when we looked at those sigmas that
00:56:02.120 | you're putting like the distribution, you've got some that are fairly low, but you've also
00:56:05.440 | got some where the standard deviation sigma is like 40, right? So the variance is super
00:56:09.760 | high. Yes. And so we don't want to feed something with standard deviation 40 into our model.
00:56:14.800 | You would like it to be closer to unit variance. So we're thinking, okay, well, if you divide
00:56:18.360 | by roughly 40, that would scare it down. But then we've also got some extra variance from
00:56:22.920 | our data. It's just like 40 plus space of the data of a little bit. We want to scale
00:56:30.040 | back down by that to get unit variance. Yeah. I mean, I love this paper because it's basically
00:56:34.440 | just doing what we spent weeks doing of like, I feel like everything that we've done that's
00:56:41.920 | improved every model has always been one thing, which is, can we get mean zero variance one
00:56:51.360 | inputs to our model and for all of our activations? And then the only other thing is include enough
00:57:00.200 | compute by adding enough layers and enough activations. Those two things seem to be all
00:57:05.960 | that matters. Basically, well, I guess ResNet's added an extra cool little thing to that,
00:57:11.960 | which is to make it even smoother by giving this kind of like identity path. So yeah,
00:57:21.600 | basically trying to make things as smooth as possible and as equal everywhere as possible.
00:57:30.880 | So yeah, this is what they've done. So they did that for the inputs, and then they've
00:57:34.360 | also done it for the outputs and for the outputs, it's basically the same idea. They have basically
00:57:43.840 | the same kind of analysis to show that. And so with this, so now, yeah, we've basically
00:57:51.920 | we've got our noised input, we've got the linear version somewhere between X0 and the
00:58:00.120 | noised input, we've got the scaling of the output and we've got the scaling of the input.
00:58:05.440 | So now for the inputs to our model, we're going to have the scaled noise, we're going
00:58:11.680 | to have the sigma and we're going to have the target, which is somewhere between the
00:58:17.920 | image and the noise. And so, yeah, so I've never seen anybody draw a picture of this
00:58:26.200 | before. So it was really cool when being in a notebook, being able to see like, oh, that's
00:58:30.800 | what they're doing. So yeah, have a good look at this notebook to see exactly what's going
00:58:36.680 | on because I think it gives you a really good intuition around what problem it's trying
00:58:41.400 | to solve. So then I actually checked the noised input has a standard deviation of 1, the mean's
00:58:48.640 | not 0 and of course, why would it be? We didn't do anything. The only thing Karras cared about
00:58:53.440 | was having the variance 1. We could easily adjust the input and output to have a mean
00:58:59.780 | of 0 as well. And that's something I think we or somebody should try because I think
00:59:04.400 | it does seem to help a bit as we saw with that generalised ReLU stuff we did, but it's
00:59:09.120 | less important than the variance. And so same with the target, it's got standard deviation 1. And yeah,
00:59:14.320 | this is where if I change this to the correct value, which is 0.66, then actually it's slightly
00:59:21.440 | further away from 1, both here and here, quite a lot further away. And maybe that's because
00:59:26.440 | actually the data is, well, we know the data is not Gaussian distributed. Pixel data definitely
00:59:32.240 | isn't Gaussian distributed. So this bug turned out better. Okay. So the UNet's the same,
00:59:41.080 | the initialisation's the same. This is all the same. Train it for a while. We can't compare
00:59:47.200 | the losses because our target's different. But what we can do is we can create a denoise
00:59:55.440 | that just takes the thing that, as per usual, the thing we had in noisify, right? And so
01:00:04.600 | for x0, it's going to multiply by c out and then add c skip times the noised input. Here it is,
01:00:10.440 | multiply by c out, add noised input times c skip. Okay, so we can denoise. So let's grab our
01:00:21.360 | sigmas from the actual batch we had. Let's calculate c skip c out and c in for the sigmas
01:00:27.840 | in our mini batch. Let's use the model to predict the target given the noised input
01:00:36.920 | and the sigmas, and then denoise it. And so here's our noised input, which we've already
01:00:44.120 | seen, and here's our predictions. And these are absolutely remarkable, in my opinion. Yeah.
01:00:56.480 | Like this one here, I can barely see it. You know, it's really found, look at the shirt.
01:01:01.720 | There's a shirt here. It's actually really finding the little thing on the front. And
01:01:04.960 | let me show you, here's what it should look like, right? And in cases where the sigma's
01:01:12.520 | pretty high, like here, you can see it's really like saying, like, I don't know, maybe it's
01:01:18.080 | shoes, but it could be something else. Is it shoes? Yeah, it wasn't shoes. But at least
01:01:23.600 | it's kind of got the, you know, the bulk of the pixels in the right spot. Yeah, something
01:01:30.400 | like this one is 4.5, has no idea what it is. It's like, oh, maybe it's shoes, maybe
01:01:35.240 | it's pants. You know, it turns out it is shoes. Yeah. So I think that's fascinating how well
01:01:43.920 | it can do. And then the other thing I did, which I thought was fun, was I just created,
01:01:53.080 | so I just, you did a sigma of 80, which is actually what they do when they're doing sampling
01:01:57.520 | from pure noise. That's what they consider the pure noise level. So I just created some
01:02:03.020 | pure noise and denoised it just for one step. And so here's what happens when you denoise
01:02:11.400 | it for one step. And you can see it's kind of overlaid all the possibilities. It's like,
01:02:16.560 | I can see a pair of shoes here, a pair of pants here at top here. And sometimes it's
01:02:22.320 | kind of like more confident that the noise is actually a pair of pants. And sometimes
01:02:27.440 | it's more confident that it's actually shoes. But you can really get a sense of how like
01:02:32.520 | from pure noise, it starts to make a call about like what this noise is actually covering
01:02:39.760 | up. And this is also the bit which I feel is like, I'm the least convinced about when
01:02:48.160 | it comes to diffusion models. This first step of going from like pure noise to something
01:02:56.960 | and like trying to have a good mix of all the possible somethings, I'm, I don't know,
01:03:02.160 | it feels a bit hand-waving to me. It clearly works quite well, but I'm not sure if it's
01:03:05.680 | like we're getting the full range of possibilities. And I feel like some of the papers we're starting
01:03:12.520 | to see are starting to say like, you know what, maybe this is not quite the right approach.
01:03:17.240 | And maybe later in the course, we'll look at some of the ones that look at what we call
01:03:21.880 | VQ models and tokenized stuff. Anyway, I thought this is pretty interesting to see these pictures,
01:03:30.560 | which I don't think, yeah, I've never seen any pictures like this before. So I think
01:03:35.080 | this is a fun result from doing all this stuff in notebooks step by step.
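A quick sketch of the denoise step just shown, reusing the Karras scalings (per-pixel pure Python with hypothetical names; `model` would normally be the trained UNet, so the lambda below is just a stand-in):

```python
import math

SIG_DATA = 0.33  # data standard deviation, as used in the lesson

def scalings(sig):
    totvar = sig ** 2 + SIG_DATA ** 2
    return (SIG_DATA ** 2 / totvar,                # c_skip
            sig * SIG_DATA / math.sqrt(totvar),    # c_out
            1 / math.sqrt(totvar))                 # c_in

def denoise(model, noised, sig):
    # undo noisify: prediction * c_out + noised_input * c_skip
    c_skip, c_out, c_in = scalings(sig)
    return model(noised * c_in, sig) * c_out + noised * c_skip

# with a stand-in model that predicts 0, the output is just c_skip * noised
out = denoise(lambda x, s: 0.0, noised=1.0, sig=80.0)
```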
01:03:42.080 | Okay, so sampling. So one of the nice things with this is the sampling becomes much, much,
01:03:52.160 | much simpler. And so, and I should mention a lot of the code that I'm using, particularly
01:04:01.360 | in the sampling section is heavily inspired by, and some of it's actually copied and pasted
01:04:06.400 | from Kat's K-diffusion repo, which is, I think I mentioned before, some of the nicest generative
01:04:17.160 | modeling code or maybe the nicest generative modeling code I've ever seen. It's really
01:04:22.480 | great. So before we talk about the actual sampling,
01:04:27.600 | the first thing we need to talk about is what sigma do we use at each reverse time step?
01:04:34.400 | And in the past, we've always, well, nearly always done something, which I think has always
01:04:39.200 | felt is sketchy as all hell, which is we've just linearly gone down the sigmas or the alpha
01:04:46.480 | bars or the t's. So here, when we're sampling in the previous notebook, we used lin space.
01:04:52.340 | So I always felt like that was questionable. And I felt like at the start, you probably,
01:04:59.040 | like it was just noise anyway. So who cared? Who cares? So I, in DDPMv3, I experimented
01:05:07.880 | with something that I thought intuitively made more sense. I don't know if you remember this
01:05:13.360 | one, but I actually said, oh, let's, for the first 100 time steps, let's actually only
01:05:20.160 | run the model every 10 steps. And then for the next 100, let's run it every nine. The
01:05:24.520 | next 100, let's run it every eight steps. So basically at the start, be much less careful.
01:05:30.360 | And so Karras actually ran a whole bunch of experiments and they said, yeah, you know what?
01:05:41.040 | At the start of sampling, you know, you can start with a high sigma, but then like step
01:05:47.200 | to a much lower sigma in the next step and then a much lower sigma in the next step.
01:05:51.680 | And then the further you go, step by smaller and smaller steps so that you spend
01:05:57.440 | a lot more time fine-tuning carefully at the end and not very much time at the start. Now,
01:06:07.860 | this has its own problems. And in fact, a paper just came out today, which we probably
01:06:12.060 | won't talk about today, but maybe another time, which talked about the problems is that in
01:06:16.320 | these very early steps, this is the bit where you're trying to create a composition that
01:06:21.600 | makes sense. Now for fashion MNIST, we don't have much composing to do. It's just a piece
01:06:26.400 | of clothing. But if you're trying to do an astronaut riding a horse, you know, you've
01:06:31.520 | got to think about how all those pieces fit together. And this is where that happens.
01:06:36.040 | And so I do worry that the Karras approach is maybe not giving that enough time.
01:06:41.160 | But as I've said, that's really the same as this step. That whole piece feels a bit wrong
01:06:47.240 | to me. But aside from that, I think this makes a lot of sense, which is that, yeah, the sampling,
01:06:54.080 | you should jump, you know, by big steps early on and small steps later on and make sure
01:06:59.720 | that the fine details are just so. So that's what this function does, is it creates this
01:07:06.800 | plot. Now it's this schedule of reverse diffusion sigma steps. It's a bit of a weird function
01:07:17.680 | in that it's the rho-th root of sigma, where rho is seven. So the seventh root of sigma
01:07:26.580 | is basically what it's scaling on. But the answer to why it's that is because they tried
01:07:33.280 | it and it turned out to work pretty well. Do you guys remember where this was?
01:07:42.480 | This is a truncation error analysis, D1.
01:07:48.280 | So this image here, thanks for telling me where this is, shows FID as
01:08:03.640 | a function of rho. So it's basically, what root are we taking? And they basically
01:08:10.680 | said, like, if you take the fifth root and up, it seems to work well, basically. So, yeah,
01:08:19.600 | so that's a perfectly good way to do things is just to try things and see what works.
01:08:24.400 | And you'll notice they tried things just like we love on small data sets. Not as small as
01:08:28.520 | us because we're the king of small data sets, but small-ish: CIFAR-10 and ImageNet 64.
01:08:35.200 | That's the way to do things. I saw, like, it might have even been the CEO of Hugging Face
01:08:40.980 | the other day, tweet something saying only people with huge amounts of GPUs can do research
01:08:45.360 | now. And I think it totally misunderstands how research is done, which is research is
01:08:49.240 | done on very small data sets. That's that's the actual research. And then when you're
01:08:55.400 | all done, you scale it up at the end. I think we're kind of pushing the envelope in terms
01:09:02.220 | of like, yeah, how much can you do? And yeah, we've, like, we covered this kind of
01:09:12.360 | main substantive path of diffusion models history, step by step, showing every improvement
01:09:19.340 | and seeing clear improvements across all the papers using nothing but Fashion MNIST running
01:09:24.160 | on a single GPU in like 15 minutes of training or something per model. So, yeah, definitely
01:09:30.080 | you don't need lots of GPUs. Anyway, OK, so this is the sigma we're going to jump to.
01:09:37.960 | So the denoising is going to involve calculating the C skip, C out and C in and calling our
01:09:44.400 | model with the C in scaled data and the sigma and then scaling it with C out and then doing
01:09:53.120 | the C skip. OK, so that's just undoing the Noisify. So check this out. This is all that's
01:10:00.360 | required to do one step of denoising for the simplest kind of scheduler, which sorry, the
01:10:07.240 | simplest kind of sampler, which is called Euler. So we basically say, OK, what's the
01:10:11.400 | sigma at time step I? What's the sigma 2 at time step I? And now when I'm talking about
01:10:19.120 | at time step, I'm really talking about like the step from this function. Right. So this
01:10:24.560 | is the sampling step. Yeah. OK, so then we denoise using the function, and then we
01:10:35.840 | say, OK, well, just send back whatever you were given plus move a little bit in the direction
01:10:43.680 | of the denoised image. So the direction is X minus denoised. So that's the noise. That's
01:10:49.880 | the gradient as we discussed right back in the first lesson of this part. So we'll take
01:10:54.840 | the noise. If we divide it by sigma, we get a slope. It's how much noise is there per
01:11:01.400 | sigma. And then the amount that we're stepping is sigma 2 minus sigma 1. So take that slope
01:11:08.960 | and multiply it by the change. Right. So that's the distance to travel towards the noise at
01:11:18.880 | this fraction. You know, or you could also think of it this way. And I know this is a
01:11:23.080 | very obvious algebraic change. But if we move this over here, you could also think of this
01:11:31.040 | as being, oh, of the total amount of noise, the change in sigma we're doing, what percentage
01:11:37.840 | is that? OK, well, that's the amount we should step. So there's two ways of thinking about
01:11:43.760 | the same thing. So again, this is just, you know, high school math. Well, I mean, actually,
01:11:56.160 | my seven-year-old daughter has done all these things. It's plus, minus, divide, and times.
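As a rough sketch of the pieces just described — the c_skip/c_out/c_in scalings, the denoising, and one Euler step. Here sigma_data=0.66 is an assumed value for normalized Fashion MNIST, and the "model" is a zero-output stand-in rather than the trained network, just to exercise the arithmetic:

```python
# Karras-style scalings; sigma_data=0.66 is an assumption for normalized
# Fashion MNIST, not a value quoted in this lecture.
def scalings(sig, sigma_data=0.66):
    totvar = sig**2 + sigma_data**2
    return sigma_data**2 / totvar, sig * sigma_data / totvar**0.5, 1 / totvar**0.5

# Undo Noisify: call the model on c_in-scaled data, scale by c_out, add c_skip*x.
def denoise(model, x, sig):
    c_skip, c_out, c_in = scalings(sig)
    return model(c_in * x, sig) * c_out + x * c_skip

# One Euler step: (x - denoised)/sig is the slope (noise per unit sigma),
# and we travel along it by the change in sigma.
def euler_step(x, sig, sig2, denoised):
    return x + (sig2 - sig) * (x - denoised) / sig

clean, noise, sig, sig2 = 1.0, 0.5, 10.0, 5.0
x = clean + sig * noise                        # a noised-up 1-D "pixel"
model = lambda inp, s: 0.0                     # hypothetical stand-in network
d = denoise(model, x, sig)                     # = c_skip * x for a zero output
x2 = euler_step(x, sig, sig2, denoised=clean)  # step with a perfect denoiser
```

With a perfect denoiser the Euler step lands exactly on clean + sig2*noise: the clean part is untouched and the noise is scaled down from sigma to sigma2.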
01:12:03.840 | So we're going to need to do this once per sampling step. So here's a thing called sample,
01:12:13.000 | which does that. It's going to go through each sampling step, call our sampler, which initially
01:12:21.520 | we're going to do sample Euler. Right. With that information, add it to our list of results
01:12:30.920 | and do it again. So that's it. That's all the sampling is. And of course, we need to grab
01:12:37.160 | our list of sigmas to start with. So I think that's pretty cool. And at the very start,
01:12:43.960 | we need to create our pure noise image. And so the amount of noise we start with has got
01:12:49.000 | a sigma of 80. OK, so if we call sample using sample Euler and we get back some very nice
01:13:02.400 | looking images and believe it or not, our FID is 1.98. So this extremely simple sampler,
01:13:13.240 | three lines of code plus a loop, has given us a FID of 1.98, which is clearly substantially
01:13:30.520 | better than our cosine-schedule results. Now we can improve it from there. So one potential improvement
01:13:38.640 | is that, you might have noticed, we added no new noise at all. This is a deterministic sampler.
01:13:46.520 | There's no rand anywhere here. So we can do something called an ancestral Euler sampler,
01:13:54.280 | which does add rand. So we basically do the denoising in the usual way, but then we also
01:14:02.600 | add some rand. And so what we do need to make sure is given that we're adding a certain
01:14:07.080 | amount of randomness, we need to remove that amount of randomness from the step that we
01:14:13.240 | take. So I won't go into the details, but basically there's a way of calculating how
01:14:22.160 | much new randomness and how much just going back in the existing direction do we do. And
01:14:29.080 | so there's the amount in the existing direction and there's the amount in the new random direction.
01:14:35.240 | And you can just pass in eta, which, when we pass it in here, is going
01:14:44.040 | to scale that. So if we scale it by half, so basically half of it is new noise and half
01:14:53.720 | of it is going in the direction that we thought we should go, that makes it better still.
01:14:59.800 | Again with 100 steps. And just make sure I'm comparing to the same, yep, 100 steps. Okay,
01:15:06.880 | so that's fair, like with like. Okay, so that's adding a bit of extra noise. Now then, something
01:15:17.120 | that I think we might have mentioned back in the first lesson of this part is something
01:15:24.840 | called Heun's method. And Heun's method does something which we can pictorially see here
01:15:34.640 | to decide where to go, which is basically we say, okay, where are we right now? What's
01:15:40.160 | the, you know, at our current point, what's the direction? So we take the tangent line,
01:15:46.080 | the slope, right? That's basically all it does is it takes a slope. It says, oh, here's
01:15:49.960 | a slope, you know. Okay, and so if we take that slope, and that would take us to a new
01:16:05.880 | spot, and then at that new spot, we can then calculate a slope at the new spot as well.
01:16:16.840 | And at the new spot, the slope is something else. So that's it here, right? And then you
01:16:25.960 | say, like, okay, well, let's go halfway between the two. And let's actually follow that line.
01:16:32.520 | And so basically, it's saying, like, okay, each of these slopes is going to be inaccurate.
01:16:38.080 | But what we could do is calculate the slope of where we are, the slope of where we're
01:16:41.400 | going, and then go halfway between the two. It's, I actually found it easier to look at
01:16:47.200 | in code personally. I'm just going to delete a whole bunch of stuff that's totally irrelevant
01:16:53.280 | to this conversation. So take a look at this compared to Euler. So here's our Euler, right?
01:17:05.400 | So we're going to do the same first line exactly the same, right? Then the denoising is exactly
01:17:10.960 | the same. And then this step here is exactly the same. I've actually just done it in multiple
01:17:17.120 | steps for no particular reason. And then you say, okay, well, if this is the last step,
01:17:25.720 | then we're done. So actually, the last step is Euler. But then what we do is we then say,
01:17:31.640 | well, that's okay, for an Euler step, this is where we'd go. Well, what does that look
01:17:39.080 | like if we denoise it? So this calls the model the second time, right? And where would that
01:17:44.080 | take us if we took an Euler step there? And so here, if we took an Euler step there, what's
01:17:50.420 | the slope? And so what we then do is we say, oh, okay, well, it's just, just like in the
01:17:55.640 | picture, let's take the average. Okay, so let's take the average and then use that,
01:18:06.880 | the step. So that's all the Heun sampler does is just take the average of the slope where
01:18:13.280 | we're at and the slope where the Euler method would have taken us. And so if we now so notice
01:18:20.240 | that it called the model twice for a single step. So to be fair, since we've been taking
01:18:25.080 | 100 steps with Euler, we should take 50 steps with Heun, right? Because it's going to call
01:18:29.400 | the model twice. And still, whoa, we beat 1 in FID, which is pretty amazing. And
01:18:38.360 | so we could keep going, check this out, we could even go down to 20. This is actually
01:18:41.520 | doing 40 model evaluations and this is better than our best Euler, which is pretty crazy.
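Heun's method as just described, sketched with a stand-in denoiser (a hypothetical perfect one for a known clean value) so the control flow is easy to follow:

```python
# Heun: take an Euler step, measure the slope at the predicted point too,
# then redo the step using the average of the two slopes.
def heun_step(x, sig, sig2, denoise):
    d1 = (x - denoise(x, sig)) / sig               # slope where we are
    x_euler = x + (sig2 - sig) * d1                # where plain Euler would land
    if sig2 == 0: return x_euler                   # final step is plain Euler
    d2 = (x_euler - denoise(x_euler, sig2)) / sig2 # slope at that new spot
    return x + (sig2 - sig) * (d1 + d2) / 2        # step with the average slope

clean = 1.0
perfect = lambda x, s: clean                       # hypothetical perfect denoiser
x = clean + 10.0 * 0.5                             # clean + sigma * noise
x2 = heun_step(x, 10.0, 5.0, perfect)              # two model calls per step
```

Two denoiser calls per step is why the fair comparison above halves the step count.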
01:18:47.600 | Now, something which you might have noticed is kind of weird about this or kind of silly
01:18:52.600 | about this is we're calling the model twice just in order to average them. But we already
01:19:00.120 | have two model results, like without calling it twice, because we could have just looked
01:19:04.760 | at the previous time step. And so something called the LMS sampler does that instead.
01:19:16.920 | And so the LMS sampler, if I call it with 20, it actually literally does 20 evaluations
01:19:22.640 | and actually it beats Euler with 100 evaluations. And so LMS, I won't go into the details too
01:19:29.560 | much. It didn't actually fit into my little sampling setup very well. So I basically largely copied
01:19:33.480 | and pasted Kat's (Katherine Crowson's) code. But the key thing it does is look, it gets the current sigma,
01:19:39.280 | it does the denoising, it calculates the slope, and it stores the slope in a list, right?
01:19:48.640 | And then it drops the first one from the list. So it's kind of keeping a list of up to, in
01:19:58.960 | this case, four at a time. And so it then uses up to the last four to basically estimate, yes, kind
01:20:07.920 | of, the curvature of this and take the next step. So that's pretty smart. And yeah, I think
01:20:20.160 | if you wanted to do super fast sampling, it seems like a pretty good way to do it. And
01:20:26.360 | I think, Johnno, you were telling me that, or maybe it was Pedro who was saying, that currently
01:20:32.120 | people have started to move away. This was very popular, but people started to move towards
01:20:36.280 | a new sampler, which is a bit similar called the DPM++ sampler, something like that. Yeah.
01:20:43.320 | Yeah. Yeah. Yeah. But I think it's the same idea. I think it kind of keeps a, I said,
01:20:51.200 | keep a list of recent results and use that. I'll have to check it more closely.
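As a much-simplified sketch of that "reuse stored slopes" mechanism: this is not the actual LMS sampler, whose coefficients come from integrating Lagrange polynomials over each sigma interval; here the last two slopes are just blended with Adams-Bashforth-2 weights to show the idea of reusing history instead of calling the model twice.

```python
# Keep a short history of slopes; after the first (Euler) step, combine the
# current and previous slope (weights 3/2, -1/2) instead of calling the
# model a second time per step. Simplified illustration only.
def lms2_step(x, sig, sig2, denoised, ds):
    d = (x - denoised) / sig           # current slope, one model call's worth
    ds.append(d)
    if len(ds) > 2: ds.pop(0)          # keep only the most recent slopes
    step_d = d if len(ds) == 1 else 1.5 * ds[-1] - 0.5 * ds[-2]
    return x + (sig2 - sig) * step_d

ds = []                                # the history buffer
x = 6.0                                # clean=1.0 plus 10.0*0.5 of noise
x = lms2_step(x, 10.0, 5.0, denoised=1.0, ds=ds)  # first step: plain Euler
x = lms2_step(x, 5.0, 0.0, denoised=1.0, ds=ds)   # second step reuses history
```

The point is that each step still costs only one model evaluation, which is why 20 LMS steps can compete with 100 Euler steps.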
01:20:57.480 | I'll have to look at the code. Yeah. That's a similar idea. It's like, if it's done more
01:21:01.360 | than one step, then it's using some history for the next thing. Yeah. The Heun
01:21:07.760 | approach doesn't make a huge amount of sense, I guess, from that perspective. I mean, it still
01:21:14.080 | works very well. This makes more sense. So then we can compare: if we use an actual mini-
01:21:18.920 | batch of real data, we get about 0.5. So yeah, I feel like this is quite a stunning result,
01:21:33.240 | to get very close to real data in terms of FID, really, with 40 model
01:21:44.160 | evaluations, and nearly the entire thing here comes from making sure we've got unit-
01:21:54.440 | variance inputs, unit-variance outputs, and kind of equally difficult problems to
01:22:06.240 | solve in our loss function. Yeah. Plus having that different schedule for sampling that's
01:22:06.240 | completely unrelated to the training schedule. I think that was one of the big things with
01:22:09.760 | Karras et al.'s paper was they also could apply this to, like, oh, existing diffusion models
01:22:15.360 | that have been trained by other papers. We can use our sampler and in fewer steps get
01:22:19.160 | better results without any of the other changes. And yeah, I mean, they do a little bit of
01:22:25.680 | rearranging equations to get the other papers' versions into their c_skip, c_in and c_out framework.
01:22:34.240 | But then yeah, it's really nice that these ideas can be applied to, so for example, I
01:22:38.640 | think stable diffusion, especially version one was trained DDPM style training, Epsilon
01:22:44.720 | objective, whatever. But you can now get these different samplers and different sometimes
01:22:50.400 | schedules and things like that and use that to sample it and do it in 15, 20 steps and
01:22:56.680 | get pretty nice samples. Yeah. And another nice thing about this paper is they, in fact,
01:23:08.400 | the name of the paper is Elucidating the Design Space of Diffusion-Based Generative Models. They looked
01:23:14.960 | at various different papers and approaches and trying to say, like, oh, you know what?
01:23:19.080 | These are all doing the same thing when we kind of parameterize things in this way. And
01:23:25.320 | if you fill in these parameters, you get this paper and these parameters, you get that paper.
01:23:29.760 | And then they found a better set of parameters, which was very nice to code because, you
01:23:42.360 | know, it really actually ended up simplifying things a whole lot. And so if you look through
01:23:48.520 | the notebook carefully, which I hope everybody will, you'll see, you know, that the code
01:23:56.240 | is really clear and simple compared to all the previous ones, in my opinion. Like, I
01:24:04.680 | feel like every notebook we've done from DDPM onwards, the code's got easier to understand.
01:24:13.320 | And just to again clarify like how this connects with some of the previous papers that we've
01:24:19.200 | looked at. So like, for example, with the DDIM, the deterministic, that's again, the sort
01:24:24.520 | of deterministic approach that's similar to the Euler method sampler that we were just
01:24:30.800 | looking at, which was completely deterministic. And then something like the Euler
01:24:35.520 | ancestral that we were looking at is similar to the standard DDPM approach with that was
01:24:43.680 | kind of a more stochastic approach. So again, there's just all those sorts of connections
01:24:48.760 | that then are kind of nice to see, again, the sorts of connections between the different
01:24:52.880 | papers and how they change it, how they can be expressed in this common framework.
01:24:57.840 | Yeah. Thanks, Tanish. So we definitely now are at the point where we can show you the
01:25:09.040 | UNet next time. And so I think we're, unless any of us come up with interesting new insights
01:25:18.060 | on the unconditional diffusion sampling, training and sampling process, we might be putting
01:25:28.160 | that aside for a while, and instead we're going to be looking at creating a good quality
01:25:34.880 | UNet from scratch. And we're going to look at a different data set to do that as we start
01:25:46.600 | to scale things up a bit, as Jono mentioned in the last lesson. So we're going to be using
01:25:50.720 | a 64 by 64 pixel image net subset called tiny image net. So we'll start looking at some
01:25:59.120 | three channel images. So I'm sure we're all sick of looking at black and white shoes.
01:26:05.800 | So now we get to look at cliff dwellings and trolley buses and koala bears and yeah, 200
01:26:17.980 | different things. So that'll be nice. Yeah. All right. Well, thank you, Jono. Thank you,
01:26:24.800 | Tanish. That was fun as always. And yeah, next time we'll be lesson 22. Bye.
01:26:32.160 | Oh, listen to me. Hey, this was lesson 22. Oh, no way. Okay. You're right. See ya. Bye.