
Lesson 9A 2022 - Stable Diffusion deep dive


Chapters

0:00 Introduction
0:40 Replicating the sampling loop
1:17 The Auto-Encoder
3:55 Adding Noise and image-to-image
8:43 The Text Encoding Process
15:15 Textual Inversion
18:36 The UNET and classifier free guidance
24:41 Sampling explanation
36:30 Additional guidance

Whisper Transcript

00:00:00.000 | Hello, everyone. My name is Jonathan, and today I'm going to be taking you through this
00:00:06.540 | stable diffusion deep dive notebook, looking into the code behind the kind of popular high
00:00:13.240 | level APIs and libraries and tools and so on to see what exactly does the generation
00:00:18.000 | process look like and how can we modify that? How do each of the individual components work?
00:00:23.640 | So feel free to run along with me. If you haven't run it before, this might take a little while
00:00:28.880 | to run, just because it's downloading these large models, if they aren't already downloaded,
00:00:32.720 | and loading them up. So we're going to start by just kind of recreating what it looks like
00:00:37.640 | to generate an image using, say, one of the existing pipelines in Hugging Face. So we're
00:00:42.500 | going to basically have copied the code from the core method of the default stable diffusion
00:00:47.920 | pipeline. So if you go and view that here, you'll see that we're going to be basically
00:00:52.000 | replicating this code. But now we'll be doing it in our own notebook, and then we'll
00:00:58.760 | slowly understand what each of these different parts is doing. So we've got some setup, we've
00:01:03.080 | got some sort of loop running through a number of sampling timesteps, and we're generating
00:01:06.920 | an image. So this is supposed to be a watercolor picture of an otter. And it's very, very cool
00:01:11.760 | that this model can just do this. But now we want to know how does that actually work?
00:01:15.960 | What's going on? So the first component is the autoencoder. Now, Stable Diffusion
00:01:21.200 | is a latent diffusion model. And what that means is that it doesn't operate on pixels,
00:01:25.480 | it operates in the latent space of some other autoencoder model. In this case, a variational
00:01:31.240 | autoencoder that's been trained on a large number of images to compress them down into
00:01:35.440 | this latent representation and bring them back up again. So I have some functions to
00:01:39.400 | do that. We're going to look at what it's like in action, just downloading a picture
00:01:43.480 | from the internet, opening it up with PIL. So we have this 512 by 512 pixel image, and
00:01:50.240 | we're going to load it in. And then we're going to use our function defined above to
00:01:54.520 | encode that into some latent representation. And what this is doing is calling the VAE.encode
00:01:59.840 | method on a tensor version of the image. And that gives us a distribution. And so we're sampling
00:02:04.440 | from that distribution, and we're scaling by this factor because that's what the authors of the model
00:02:09.320 | did. They scaled the latents down before they fed them to the model. And so we have to do
00:02:14.400 | that scaling, and then the reverse when we're decoding, just to be consistent with that.
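In code, those two helpers look roughly like this - a minimal sketch of the notebook's pil_to_latent and latents_to_pil functions, assuming the vae loaded above (0.18215 is the scale factor used for the Stable Diffusion v1 VAE):

```python
import torch
from PIL import Image
from torchvision import transforms as tfms

def pil_to_latent(input_im):
    # PIL image (0-255) -> tensor in [-1, 1] -> VAE latent distribution -> sample
    with torch.no_grad():
        x = tfms.ToTensor()(input_im).unsqueeze(0).to(vae.device) * 2 - 1
        latent = vae.encode(x)
    return 0.18215 * latent.latent_dist.sample()  # scale down, as the authors did

def latents_to_pil(latents):
    # Undo the scaling, decode, and map back to PIL images in 0-255
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    return [Image.fromarray((im * 255).round().astype("uint8")) for im in image]
```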
00:02:19.440 | But the key idea is that we go from this big image down to this 4 by 64 by 64 latent representation.
00:02:26.360 | So we've gone from this much larger image down. And if we visualize what the four channels
00:02:30.440 | here, these four different 64 by 64 channels, look like, we'll see that it's
00:02:35.560 | capturing something of the image. You can sort of see the same shapes and things there.
00:02:39.880 | But it's not quite a direct mapping or anything. For example, there's this weirdness going
00:02:43.440 | on around the beak. Some of the channels look slightly stranger than the others. So there's some
00:02:47.940 | sort of rich information captured there. And if we decode this back, what we'll see is
00:02:53.160 | that the decoded image looks really good. You really have to look closely to tell the
00:02:57.480 | difference between our input image here and the decoded version. So very, very impressive
00:03:02.800 | compression, right? This is a factor of 8 in each dimension. So 512 by 512 down to 64
00:03:09.320 | by 64. It's like a factor of 64 reduction in data, but it's still somehow capturing
00:03:16.220 | most of that information. It's a very information-rich representation. And this is going to be great
00:03:20.960 | because now we can work with that with our diffusion model and get nice high resolution
00:03:25.520 | results even though we're only working with these 64 by 64 latents. Now, it doesn't have
00:03:30.640 | to be 64 by 64. You can go and modify this to say what if this is 640 and encode that
00:03:37.920 | down and you'll see that it's just that same factor of 8 reduction. And there we go. Now
00:03:44.320 | we have 80 by 64. This just has to be a multiple of 8. Otherwise, you'll get, I think, an error.
00:03:51.000 | Okay. So we have our encoded version of this image, and that's pretty great. The next component
00:03:56.600 | we're going to look at is the scheduler, and I'll look more closely at this later. But
00:04:01.240 | for now, we're going to focus on this idea of adding noise, right? So during training,
00:04:05.040 | we add some noise to an image, and then the model tries to predict what that noise is,
00:04:09.280 | and we're going to do that to different amounts. So here we're going to recreate the same type
00:04:13.960 | of schedule, and you can try different schedulers from the library. Oops. And these parameters
00:04:18.960 | here, beta start, beta end, beta schedule, that's how much noise was added at different
00:04:23.000 | time steps and how many time steps are used during training. For sampling, we don't want
00:04:28.080 | to have to do a thousand steps, so we can set a new number of timesteps, and then we'll
00:04:32.760 | see, via the scheduler.timesteps attribute, how these correspond to the original training timesteps.
00:04:41.360 | So here we're going to have 15 sampling steps, and that's going to be equivalent to starting
00:04:45.360 | at time step 999 and just moving linearly down to time steps zero. We can also look
00:04:51.480 | at the actual amount of noise present with the sigmas attribute. So again, starting
00:04:55.680 | high, moving down, and if you want to see what that schedule looks like, we can plot
00:05:00.240 | that here. And if you want to see the time steps, you'll see that it's just a linear
00:05:04.000 | relationship. So there we go. We're going to start at a very high noise value, and we're
00:05:10.000 | going to slowly, slowly try and reduce this down until ideally we get an image out.
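As a sketch, that setup looks something like this (these beta values are the ones used to train Stable Diffusion v1):

```python
from diffusers import LMSDiscreteScheduler

# How much noise is added at each of the 1000 training timesteps:
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012,
    beta_schedule="scaled_linear", num_train_timesteps=1000)

scheduler.set_timesteps(15)   # we only want 15 sampling steps
print(scheduler.timesteps)    # maps linearly from t=999 down to t=0
print(scheduler.sigmas)       # the actual noise level at each sampling step
```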
00:05:14.440 | Okay, so the sigma is the amount of noise added. Let's see what that looks like. So
00:05:19.440 | I'm going to start with some random noise that's the same shape as my latent representation,
00:05:23.640 | my encoded image, and then I'd like to be equivalent to sampling step 10 out of 15
00:05:29.200 | here. So I'm going to go and look up what time step that equates to, and that's going
00:05:33.920 | to be one of the arguments that I pass to the scheduler.add_noise function. So I'm calling
00:05:38.760 | scheduler.add_noise, giving it my encoded image, the noise, and what timestep I'd like to
00:05:44.280 | be noising equivalent to. And this is going to give me this noisy but still recognisable
00:05:48.680 | version of the image. And you can go and say, okay, what if I look at somewhere earlier
00:05:52.520 | in the process, right? Does it look more noisy? What about right at the beginning, right at
00:05:56.680 | the end? Feel free to play around there. Okay, so this adding noise, what are we actually
00:06:02.760 | doing? What does the code look like? Let's inspect the function, and you'll see that
00:06:06.600 | there's some setup for different types of arguments and shapes. But the key line is just
00:06:11.240 | this: noisy_samples = original_samples + noise * sigma.
00:06:17.480 | All right, so that's all it is. It's not always the same. Different papers and implementations
00:06:21.960 | will add the noise slightly differently. But in this case, that's all it's doing. So scheduler.add_noise is
00:06:27.320 | just adding noise that's the same shape as the latents, scaled by the sigma parameter.
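So, as a sketch, noising our encoded image to the equivalent of sampling step 10 out of 15 looks something like this (names follow the notebook; `encoded` is the latent from earlier):

```python
noise = torch.randn_like(encoded)   # random noise, same shape as the latents
sampling_step = 10                  # step 10 out of our 15
timestep = scheduler.timesteps[sampling_step]
noisy_latents = scheduler.add_noise(encoded, noise,
                                    timesteps=torch.tensor([timestep]))
# Under the hood, this is (roughly) just:
#   noisy_latents = encoded + noise * scheduler.sigmas[sampling_step]
```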
00:06:32.360 | Okay, so that's what we're doing. So if we want to start from random noise instead of
00:06:36.840 | a noisy image, we're going to scale it by that same sigma value so that it looks the
00:06:40.760 | same as an image that's been scaled by that amount. But then before we feed that to the
00:06:44.960 | actual model, we then have to handle that scaling again. You could do it like this,
00:06:49.280 | but now we have this scale_model_input function associated with the scheduler, just to hide
00:06:54.760 | that complexity away. Okay, so now we're going to look at the same kind of sampling loop
00:06:59.080 | as before. But we're going to start now with our image, we're going to take our encoded
00:07:03.080 | image, we're going to noise it to some timestep, and then we're only going to denoise
00:07:06.840 | from there. So in code, we are now preparing our text and everything the same as before,
00:07:11.320 | which we'll look at. We're setting our number of inference steps to 50, right,
00:07:15.560 | num_inference_steps = 50 here. And we're saying I'd like to start at the equivalent of step
00:07:19.240 | 10 out of 50. So I'll look up what time step that equates to, I'll add noise to my image,
00:07:26.040 | equivalent to that step. And then we're going to run through sampling, but this time we're
00:07:30.440 | only going to start doing things once we get above that start step. So I'm going to ignore
00:07:34.360 | the first 10 out of 50 steps. And then beyond that, I'm now going to start with this noisy
00:07:38.920 | version of my input image. And I'm going to denoise it according to this prompt. And the
00:07:43.160 | hope here is that by starting from something that has some of the sort of rough structure and color
00:07:48.040 | of that input image, I can kind of fix that into my generation. But I've got a new prompt,
00:07:53.320 | a National Geographic photo of a colorful dancer. And here we go, we see this is the same sort of
00:07:58.040 | thing as the parrot. But now we have this completely different actual content thanks to
00:08:02.200 | a different prompt. And so that's a fun kind of use of this image to image process. You might have
00:08:07.880 | seen this for taking drawings, adding a bunch of noise and then denoising them into fancy paintings
00:08:11.880 | and so on. So again, this is something that there are existing tools for, right? The
00:08:16.920 | strength parameter in the image-to-image pipeline is just something like this: what step are we
00:08:23.560 | starting at? How many steps are we skipping? But you can see that this is a pretty powerful
00:08:29.000 | technique for getting a bit of extra control over like composition and color and a bit of the
00:08:33.720 | structure. Okay, so that's that trick with adding noise and then using that as image to image.
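A rough sketch of that modified loop, assuming the unet, scheduler, text_embeddings and guidance_scale prepared in the surrounding cells (the two-copies-and-chunk business is the classifier-free guidance trick covered later, and details vary a little across diffusers versions):

```python
scheduler.set_timesteps(50)
start_step = 10
noise = torch.randn_like(encoded)
latents = scheduler.add_noise(
    encoded, noise, timesteps=torch.tensor([scheduler.timesteps[start_step]]))

for i, t in enumerate(scheduler.timesteps):
    if i < start_step:
        continue  # skip the first steps entirely; we begin from the noised image
    latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t,
                          encoder_hidden_states=text_embeddings).sample
    uncond, cond = noise_pred.chunk(2)
    noise_pred = uncond + guidance_scale * (cond - uncond)
    latents = scheduler.step(noise_pred, t, latents).prev_sample
```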
00:08:39.560 | The next big section I'd like to look at is how do we go from a piece of text that describes what
00:08:46.120 | we want into a numerical representation that we can feed to the model. So we're going to trace
00:08:52.360 | out that pipeline. And along the way, we'll see how we can modify that for a bit of fun.
00:08:56.920 | So step number one, we're taking our prompt and turning it into a sequence of discrete tokens.
00:09:03.640 | So here we have, in this case, 77 discrete tokens, because that's the maximum length.
00:09:09.640 | It's always going to be that; if your prompt is longer, it'll truncate it. And if we decode these
00:09:13.800 | tokens back, we'll see that we have a special token for the start of the text, then a picture of a
00:09:19.880 | puppy. And then the rest is all the same token, which is this kind of end of text padding token.
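In code, that tokenization step looks something like this (a sketch, using the CLIP tokenizer loaded earlier):

```python
prompt = "A picture of a puppy"
text_input = tokenizer(prompt, padding="max_length",
                       max_length=tokenizer.model_max_length,  # 77 for CLIP
                       truncation=True, return_tensors="pt")
print(text_input.input_ids)                                    # the 77 discrete tokens
print([tokenizer.decode(t) for t in text_input.input_ids[0]])  # tokens as text
```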
00:09:24.680 | Right, so we have this token for puppy, and this special padding token has its own meaning: end of
00:09:30.600 | text. And the prompts are always going to be padded to be the same length. So before, in the
00:09:37.000 | code that we were using there, we always jumped straight to these so-called output embeddings,
00:09:41.560 | which is what we fed to the model as conditioning. And so somehow this captures some information
00:09:46.120 | about this prompt. But now we want to say, well, how do we get there? How do we get from this
00:09:50.520 | sequence of tokens to these output embeddings? What is this text encoder forward pass doing?
00:09:57.080 | Right, so we can look at this, and there's going to be multiple steps. And the first is going to
00:10:01.960 | be some embeddings. So if we look at text_encoder.text_model.embeddings, we'll see
00:10:06.200 | there's a couple of different ones, we have token embeddings, right? And so this is to take those
00:10:10.600 | individual tokens, token 408, or whatever, and map each one into a numerical
00:10:18.760 | representation. So here it's a learned embedding. There are about 50,000 rows, one for each token.
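In code, that layer and a single lookup look roughly like this (the attribute path follows the CLIP text model in transformers; the exact token ID shown is illustrative):

```python
token_emb_layer = text_encoder.text_model.embeddings.token_embedding
print(token_emb_layer)  # Embedding(49408, 768): ~50K rows of 768 values each

puppy_token_id = 6829   # illustrative; read yours off the tokenizer output above
puppy_emb = token_emb_layer(torch.tensor(puppy_token_id, device=text_encoder.device))
print(puppy_emb.shape)  # torch.Size([768])
```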
00:10:25.800 | And for each token, we have 768 values. So that's the embedding of that token. And if we want to
00:10:31.160 | feed one in and see what the embedding looks like, here's the token for puppy. And here's the token
00:10:35.640 | embedding, right? 768 numbers that somehow capture that meaning of that token on its own. And we can
00:10:42.280 | do the same for all of the tokens in our prompt. So we feed them through this token embedding layer.
00:10:46.360 | And now we get 77 768-dimensional representations, one for each token. Now, these are all on their
00:10:55.640 | own. And no matter where in the sentence a token is, its token embedding will be the same. So the next
00:11:01.080 | step is to add some positional information. Some models will do this with some kind of like learned
00:11:05.560 | pattern of positioning. But in this case, the positional embedding is just another learned
00:11:09.800 | embedding. But now instead of having one embedding for every token, we have one embedding for every
00:11:14.840 | position out of all 77 possible positions. And so just like we did for the tokens, we can feed them
00:11:20.280 | the position IDs, one for every possible position, and we'll get back out an embedding for every
00:11:26.520 | position in the prompt. And combining them together, there's again, multiple ways people do this in the
00:11:32.360 | literature. But in this case, it's as simple as adding them. That's why they made them the same
00:11:36.360 | shape, so that you can just add the two together. And now, these input embeddings have some
00:11:41.560 | information related to the token and some related to the position. And so far, we haven't seen
00:11:47.000 | any big model, just two learned embeddings, but this is getting everything ready to feed through that
00:11:51.240 | model. And so we can check that this is the same as if we just called the embeddings layer of that
00:11:56.680 | model, which is going to do both of those steps at once. But we'll see just now why we want to separate that out into individual ones.
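Putting those two embedding steps into code, a rough sketch (attribute paths follow the notebook; they can shift between transformers versions):

```python
token_emb_layer = text_encoder.text_model.embeddings.token_embedding
pos_emb_layer = text_encoder.text_model.embeddings.position_embedding

token_embeddings = token_emb_layer(text_input.input_ids.to(text_encoder.device))
position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
position_embeddings = pos_emb_layer(position_ids)

# Both are (1, 77, 768), so combining them is literally just addition:
input_embeddings = token_embeddings + position_embeddings
```

00:12:01.240 | Okay, so we have these individual tokens, and they have some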
00:12:07.080 | positional information, we have these final embeddings. Now we'd like to turn them into something
00:12:11.000 | that has a richer representation, thanks to some big transformer model. And so we're going to feed
00:12:15.720 | these through. And I made this little diagram here, each token is going to turn into a token
00:12:20.120 | embedding combined with the positional embedding. And then it's going to get fed through this
00:12:24.760 | transformer encoder, which is just a stack of these blocks. And so each block has some magic like
00:12:29.960 | attention, some feed-forward components; there's additions and normalizations and skips and so on
00:12:35.320 | as well. And but we're going to have some number of these blocks all stacked together, and the
00:12:39.480 | outputs of each one get fed into the next block and so on. And so we get our final set of hidden
00:12:44.120 | states, these encoder hidden states, aka the output embeddings. And this is what we feed to our UNet
00:12:50.200 | to make its predictions. So the way we get this, I just copied the text_encoder.text_model.forward
00:12:56.200 | method, pulled out the relevant bits, we are going to take in those input embeddings, combined
00:13:00.360 | positional and token embeddings, and we're going to feed that through the text model dot encoder
00:13:05.000 | function with some additional parameters around attention masking and telling it that we'd like
00:13:10.520 | to output the hidden states rather than the final outputs. So if we run this, we can just double
00:13:16.040 | check, these embeddings are going to look just like the output embeddings we saw right at the
00:13:21.240 | beginning. So we've taken that one step, tokens to output embeddings, and we've broken it down into
00:13:26.040 | a number of smaller steps, where we have tokenization, getting our token embeddings,
00:13:30.120 | combining with position embeddings, feeding it through the model, and then that gives us those final outputs.
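That pulled-out forward pass looks roughly like this (a sketch following the notebook; the causal-mask helper is internal to transformers and has moved around between versions):

```python
def get_output_embeds(input_embeddings):
    # Mimic text_encoder.text_model.forward, starting from combined embeddings
    bsz, seq_len = input_embeddings.shape[:2]
    causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(
        bsz, seq_len, dtype=input_embeddings.dtype)
    encoder_outputs = text_encoder.text_model.encoder(
        inputs_embeds=input_embeddings,
        attention_mask=None,
        causal_attention_mask=causal_attention_mask.to(input_embeddings.device),
        output_hidden_states=True)
    # The final layer norm gives us the 'output embeddings' the UNet consumes
    return text_encoder.text_model.final_layer_norm(encoder_outputs[0])
```

00:13:33.720 | So why have we gone through all this trouble? Well, there's a couple of things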
00:13:38.520 | we can do. One demo here, I'm getting the token embeddings, but then I'm looking up where is the
00:13:46.120 | token for puppy, and I'm going to replace it with a new set of embeddings. And this is going to be
00:13:51.080 | another just learned embedding of this particular token here, 2368. So I'm kind of cutting out the
00:13:57.160 | token embedding for puppy, slipping in this new set of token embeddings, and I'm going to get some
00:14:01.640 | output embeddings which at the start look very similar to the previous ones, in fact identical.
00:14:05.640 | But as soon as you get past the position of puppy in that prompt, you're going to see that the rest
00:14:10.440 | have changed. So we've somehow messed with these embeddings by slipping in this new token embedding
00:14:15.400 | right at the start. And if we generate with those embeddings, which is what this function is doing,
00:14:20.600 | we should see something other than a puppy. And sure enough, drum roll, we don't, we get a cat.
00:14:27.320 | And so now you know what token 2368 means. We've managed to slip in a new token embedding and get
00:14:33.880 | a different image. Okay, what can we do with this? Why is this fun? Well, a couple of tricks. First
00:14:38.840 | off, we could look up the token embedding for skunk, right, which is this number here. And then
00:14:44.360 | instead of now just replacing that in place of puppy, what if I make a new token embedding
00:14:50.120 | that's some combination of the embedding of puppy and the embedding of skunk, right?
00:14:54.680 | So I'm taking these two token embeddings, I'm just averaging them, and I'm inserting them into my
00:14:59.720 | set of token embeddings for my prompt in place of just the word puppy. And so hopefully when we
00:15:05.720 | generate with this, we get something that looks a bit like a puppy, a bit like a skunk. And this
00:15:10.440 | doesn't work all the time, but it's pretty cute when it does. There we go, puppy skunk hybrid.
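As a sketch, the replace-and-average trick looks something like this (the token IDs are illustrative; token_emb_layer, position_embeddings and get_output_embeds come from the snippets above):

```python
prompt = "A picture of a puppy"
text_input = tokenizer(prompt, padding="max_length",
                       max_length=tokenizer.model_max_length,
                       truncation=True, return_tensors="pt")
input_ids = text_input.input_ids.to(text_encoder.device)
token_embeddings = token_emb_layer(input_ids)

puppy_id, skunk_id = 6829, 42194  # illustrative IDs; look yours up from the tokenizer
replacement = (token_emb_layer(torch.tensor(puppy_id, device=input_ids.device)) +
               token_emb_layer(torch.tensor(skunk_id, device=input_ids.device))) / 2

# Slip the averaged embedding in where "puppy" sits:
token_embeddings[0, torch.where(input_ids[0] == puppy_id)[0]] = replacement

input_embeddings = token_embeddings + position_embeddings
output_embeddings = get_output_embeds(input_embeddings)  # then generate with these
```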
00:15:15.400 | Okay, so that's not the real reason we're looking at this. The main application at the moment of
00:15:20.280 | being able to mess with these token embeddings is to be able to do something called textual
00:15:24.120 | inversion. So in textual inversion, we're going to have our prompt tokenize it and so on. But here
00:15:29.560 | we're going to have a special learned embedding for some new concept, right? And so the way that's
00:15:35.000 | trained is going to be outside of the scope of this notebook. But there's a good blog post and
00:15:40.520 | community notebooks and things for doing that. But let's just see this in application here. So
00:15:45.080 | there's a whole library of these concepts, the Stable Diffusion Concepts Library, where you can browse
00:15:51.160 | through tons and tons (look, over 1,400) of different community-contributed token embeddings
00:15:59.240 | that people have trained. And so I'm going to use this one here, this bird style. Here's some example
00:16:04.120 | outputs. And then these are the images it was trained on. So these pretty little bird paintings
00:16:08.440 | done by my mother. And I've trained a new token embedding that tries to capture the essence of
00:16:14.360 | the style. And that's represented here in this learned embed stop in. So if you download this,
00:16:19.480 | and then upload it to wherever your notebooks running, I have it here, learned embed stop in,
00:16:23.640 | we can load that in. And you'll see that it's just a dictionary, where we have one key, that's the
00:16:28.840 | name of my new style. And then we have this token embedding 768 numbers. And so now instead of
00:16:35.240 | slipping in the token embedding for cat, we're going to slip in this new embedding, which we've
00:16:39.160 | loaded from the file into this prompt. So a mouse in the style of puppy, tokenize, get my token
00:16:44.360 | embeddings, and then I'm going to slip in this replacement embedding in place of the embedding
00:16:49.640 | for puppy. And when we generate with that, we should hopefully get a mouse in the style of
00:16:55.640 | this kind of cutesy watercolor on rough paper image. And sure enough, that's what we get,
00:17:01.400 | very cute little drawing of a mouse in an apron, apparently. Okay, so very, very cool
00:17:06.120 | application. Again, there's a nice inference notebook that makes this really easy. You can
00:17:11.320 | say a cat toy in the style of birb-style, and you don't have to worry about manually replacing
00:17:15.400 | the token embeddings yourself. But it's good to know what the code looks like under the hood,
00:17:19.240 | right? How are we doing that? What stage of the text embedding process we're modifying?
00:17:23.320 | Very fun to get a bit of extra control, and a very useful technique, because now we can
00:17:28.200 | kind of augment our model's vocabulary without having to actually retrain the model itself,
00:17:32.840 | we're just learning a new token embedding. It's a very, very powerful idea, and really fun to
00:17:37.000 | play with. And like I said, there's thousands of community contributed tokens, but you can also
00:17:42.120 | train your own, I think I linked the notebook from here, but it's also in all the docs and so on.
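Loading and using one of those learned embeddings is just a variation on the earlier snippet - a sketch (the file and key names follow the birb-style example; puppy_id is the placeholder token from before):

```python
birb_embed = torch.load("learned_embeds.bin")   # {'<birb-style>': 768-value tensor}
replacement = birb_embed["<birb-style>"].to(text_encoder.device)

prompt = "A mouse in the style of puppy"        # "puppy" is just a placeholder
text_input = tokenizer(prompt, padding="max_length",
                       max_length=tokenizer.model_max_length,
                       truncation=True, return_tensors="pt")
input_ids = text_input.input_ids.to(text_encoder.device)
token_embeddings = token_emb_layer(input_ids)
token_embeddings[0, torch.where(input_ids[0] == puppy_id)[0]] = replacement

output_embeddings = get_output_embeds(token_embeddings + position_embeddings)
```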
00:17:47.000 | Here's the training notebook. Okay, final little trick with embeddings, rather than messing with
00:17:52.840 | them at the token embedding level, we can push the whole prompt through that entire process to
00:17:57.000 | get our final output embeddings, and we can mess with those at that stage as well. So here I have
00:18:01.160 | two prompts, a mouse and a leopard, tokenizing them, encoding them with a text encoder, so that's that
00:18:06.200 | whole process. And these final output embeddings, I'm just going to mix them together according to
00:18:10.680 | some factor, and generate with that. And so you can try this with, you know, a cat and a snake.
00:18:16.280 | And you should be able to get some really fun, different chimeras and oops, a snail apparently.
00:18:25.880 | Okay, well, I can't spell. But yeah, have fun with that, doesn't have to be animals.
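A sketch of that final-embedding mixing (generate_with_embs is the notebook's generation helper):

```python
def embed_prompt(prompt):
    tok = tokenizer(prompt, padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True, return_tensors="pt")
    with torch.no_grad():
        return text_encoder(tok.input_ids.to(text_encoder.device))[0]

mix_factor = 0.35  # how much mouse vs. leopard
mixed_embeddings = (embed_prompt("A mouse") * mix_factor +
                    embed_prompt("A leopard") * (1 - mix_factor))
# generate_with_embs(mixed_embeddings)
```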
00:18:31.240 | I'd love to see what you create with these weird mixed up generations. Okay, we should look at the
00:18:37.960 | actual model itself, the key UNet model, the diffusion model. What is it doing? What is it
00:18:43.320 | predicting? What is it accepting as arguments? So this is the kind of call signature, we call
00:18:49.560 | our UNet's forward pass, and we feed in our noisy latents, the timestep (it's like the training
00:18:55.480 | timestep), and the encoder hidden states, right? So those text embeddings that we've just been
00:19:00.200 | having fun with. So doing that without any loops or anything, I'm setting up my scheduler, getting my
00:19:05.640 | timestep, getting my noisy latents, and my text embeddings. And then we're going to get our model
00:19:11.960 | prediction. And if you look at the shape of that, you'll see that this prediction has the same
00:19:15.320 | shape as the latents. And given these noisy latents, what the model is predicting is the noise component
00:19:21.960 | of that. And actually, it's predicting the noise component scaled by sigma. So if we wanted to see
00:19:27.560 | what the original image looks like, we could say, well, the denoised latents are going to be the
00:19:32.280 | current noisy latents minus sigma times the model prediction, right? And so when we're denoising,
00:19:39.480 | we're not going to go straight to that output prediction, we're going to just remove a little
00:19:43.160 | bit of the noise at a time. But it might be useful to visualize what that final prediction looks like.
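As code, a single prediction looks something like this (a sketch; names follow the notebook, and text_embeddings here is a single (1, 77, 768) conditioning, ignoring the classifier-free guidance doubling discussed in the next section):

```python
scheduler.set_timesteps(15)
t = scheduler.timesteps[0]
sigma = scheduler.sigmas[0]
latents = torch.randn(1, 4, 64, 64, device=unet.device) * sigma
latent_model_input = scheduler.scale_model_input(latents, t)

with torch.no_grad():
    noise_pred = unet(latent_model_input, t,
                      encoder_hidden_states=text_embeddings).sample

print(noise_pred.shape)                  # same shape as the latents: (1, 4, 64, 64)
denoised = latents - sigma * noise_pred  # the model's guess at the clean latents
```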
00:19:48.040 | So that's what we're doing here, making a folder to store some images, preparing our text scheduler
00:19:53.000 | and input. And then we're going to do this loop. But now we're going to get the model prediction.
00:19:58.200 | And instead of just updating our latency by one step, we're also going to store an image, right,
00:20:03.320 | and decoding these two images, an image of the predicted completely de-noised, like original
00:20:08.200 | sample. So that's this predicted original sample here. You could also calculate this yourself.
00:20:12.120 | latents_x0 is equal to the current latents minus sigma times the noise prediction.
00:20:16.600 | All right, so those two should work equivalently. But this loop is going to run, and it's going to
00:20:21.320 | save those images to the steps folder, which we can then visualize. And so once this finishes,
00:20:26.600 | in a second or two, on the left, we're going to see the kind of noisy input to the model at each
00:20:31.560 | stage. And on the right, we're going to see the noisy input minus the noise prediction, right,
00:20:37.160 | so the de-noised version. And so we'll just give it a second or two to run. It's taking it a little
00:20:41.640 | bit longer because it's decoding those images each time, saving them. But once this finishes,
00:20:47.720 | we should have a nice little preview video. Okay, here we go. So this is the noisy latent.
00:21:02.200 | And if we take the model's noise prediction and subtract it from that, we get this very blurry
00:21:06.840 | output. And so you'll see as we play this -- oh, I've left some modifications in from last time,
00:21:12.680 | sorry. I'll set this guidance scale back to, I think it was, eight. In the next section,
00:21:20.280 | we'll talk about classifier-free guidance. And so I've been modifying that example. My bad. I might
00:21:25.320 | cut this out of the video. We'll see. So I've got to wait a few seconds again for that to generate.
00:21:30.040 | And I'll do so as patiently as I can.
00:21:32.760 | Okay, so here we go again, the noisy input, the predicted de-noised version. And you can see at
00:21:59.640 | the start, it's very blurry. But over time, it gradually converges on our final output.
00:22:03.960 | And you'll notice that on the left, these are the latents as they are each step. They don't
00:22:10.680 | change particularly drastically a little bit at a time. But at the start, when the model doesn't
00:22:15.400 | have much to go on, its predictions do change quite a bit at each step, right? It's much less
00:22:20.840 | well-defined. And then as we go forward in time, it gets more and more refined, better and better
00:22:25.560 | predictions. And so it's got a more accurate estimation of the noise to remove. And we remove
00:22:31.160 | that noise gradually until we finally get our output. Quite fun to visualize the process. And
00:22:36.440 | hopefully that helps you understand why we don't just make one prediction and do it in one step,
00:22:40.760 | right? Because we get this very blurry mess. But instead, we do this kind of iterative
00:22:44.520 | sampling there, which we'll talk about very shortly. Before then, though, the final thing
00:22:49.720 | I should mention, classifier-free guidance. What is that? Well, like you saw when I accidentally
00:22:54.840 | generated the version with a much lower guidance scale, the way classifier-free guidance works
00:23:01.240 | is that in all of these loops, we haven't actually been passing one set of noisy latents through the
00:23:06.280 | model. We've been passing two identical versions. And as our text embeddings, we've not just been
00:23:13.080 | passing the embeddings of our prompts, right? These ones here, we've been concatenating them
00:23:18.440 | with some unconditional embeddings as well. And what the unconditional embeddings are is just a
00:23:22.600 | blank prompt, right? No text whatsoever. So just all padding passing that through. So when we get
00:23:28.680 | our predictions here, we've given in two sets of latents and two sets of text embeddings, we're
00:23:34.120 | going to get out two predictions for the noise. So we splitting that apart, one prediction for the
00:23:39.800 | unconditional, like no prompt version, and one for the prediction based on the prompt. And so what we
00:23:45.880 | can do now is we can say, well, my final prediction is going to be the unconditional version plus the
00:23:51.160 | guidance scale times the difference, right? So if you think about it, if I predict without the noise,
00:23:55.960 | I'm predicting here. If I predict with the noise, sorry, with the text encoding, with the prompt,
00:24:02.120 | I get this prediction instead. And I'd like to move more in that direction. I'd like to push it
00:24:06.040 | even further towards the prompt version and beyond. So this guidance scale can be larger than one,
00:24:12.440 | to push it even more in that direction. And this, it turns out, is kind of key for getting it to
00:24:16.920 | follow the prompt nicely. And I think it was first brought up in the glide paper. AI Coffee Break on
00:24:23.400 | YouTube has a great video on that. But yeah, really useful trick or really neat hack, depending on who
00:24:28.360 | you talk to. But it does seem to work. And the higher the guidance scale, the more the model will
00:24:33.080 | try and look like the prompts kind of in the extreme versus that lower guidance scale, it might
00:24:38.040 | just try and look like a generic good picture. Okay, we've been hiding away some complexity in
00:24:44.360 | terms of this scheduler.step function. So I think we're going to step away from the notebook
00:24:48.520 | now and scribble a bit on some paper to try and explain exactly what's going on with sampling and
00:24:53.080 | so on. And then we'll come back to the notebook for one final trick. All right, so here's my take
00:25:00.040 | on sampling. And to start with, I'd like you to imagine the space of all possible images.
00:25:05.480 | So this is a very large high-dimensional space: for a 256 by 256 by 3 image, that's about 200,000
00:25:14.120 | dimensional. And my paper, unfortunately, is only two dimensional. So we're going to have to squish
00:25:19.000 | this down a fair bit and use our imagination. Now, if you just look at a random point in this space,
00:25:25.160 | this is most likely not going to look like anything recognizable, it'll probably just look
00:25:30.280 | like garbled noise. But if we map an image into the space, we'll see that it has some sort of fixed
00:25:36.920 | point. And a very similar image almost pixel equivalent, it's going to be very close by.
00:25:42.760 | Now, there's this theory that you'll hear talked about called manifold theory,
00:25:47.240 | which says that for most real images, like a data set of images, these are going to lie
00:25:52.840 | on some lower dimensional manifold within this higher dimensional space, right? In other words,
00:25:57.800 | if we map a whole bunch of images into the space, they're not going to fill the whole space,
00:26:02.360 | they're going to be kind of clustered onto some surface. Now, I've drawn it as a line here because
00:26:07.000 | we stuck with 2D, but this is a much higher dimensional plane equivalent. Okay, so each of
00:26:12.520 | these ones here is some image. And the reason that I'm starting with this is because we'd like to
00:26:18.840 | generate images, we'd like to generate plausible looking images, not just random nonsense. And so
00:26:24.040 | we'd like to do that with diffusion models. So where did they come in? Well, we can start with
00:26:29.160 | some image here, some real image from our training data. And we can push it away from the manifold
00:26:35.800 | of like plausible existing images by corrupting it somehow. So for example, just adding random noise,
00:26:41.080 | that's equivalent to like moving in some random direction in this space of all possible images.
00:26:45.800 | And so that's going to push the image away. And then we can try and predict using some model,
00:26:52.200 | what this noise looks like, right? How do I go from here back to a plausible image? What is this
00:26:57.880 | noise that's been added? And so that's going to be our big UNet that does that prediction,
00:27:02.040 | that's going to be our diffusion model, right? And so that's, in this language, going to be called
00:27:07.320 | something like a score function, right? How do I get from wherever I am? What's the noise that I
00:27:11.480 | need to remove to get back to a plausible image? Okay, so that's all well and good. We can train
00:27:19.240 | this model with a number of examples, because we can just take our training data, add some random
00:27:23.080 | noise, try and predict the noise, update our model parameters. So we can hopefully
00:27:27.400 | learn that function fairly well. Now we'd like to generate with this model, right? So how do we do
00:27:32.360 | that? Well, we can start at some random point, right? Like, let's start over here. And you might
00:27:38.840 | think, well, surely I can just now predict the noise, remove that, and then I get my output image.
00:27:44.520 | And that's great, except that you've got to remember now we're starting from a random point in the
00:27:47.880 | space of all possible images. It just looks like garbled nonsense. And the model's trying to say,
00:27:52.280 | well, what does the noise look like? And so you can imagine here, during training,
00:27:56.680 | the further away we get from our examples, the sparser our training data will have been.
00:28:00.840 | But also, it's not like it's very obvious how we got to this noisy version, right? We could have
00:28:05.880 | come from this image over here, added a bunch of noise. We could have come from one over here,
00:28:10.600 | one over here. And so this model's not going to be able to make a perfect prediction. At best,
00:28:15.160 | it might say, well, somewhere in that direction, right? It could point towards something like the
00:28:19.960 | dataset mean, or at least the edge that's closer. But it's not going to be able to perfectly give
00:28:24.440 | you one nice solution. And sure enough, that's what we see. If we sample the diffusion model in just one
00:28:29.960 | step, we get the prediction, look at what that corresponds to as an image, and it's just going to
00:28:33.800 | look like a blurry mess, maybe like the mean of the data or, you know, some sort of garbled output,
00:28:38.360 | definitely not going to look like a nice image. So how do we do better? And the idea of sampling
00:28:45.560 | is to say, well, there's a couple of framings. So I'll start with the existing framing that you'll
00:28:50.840 | see talked about a lot around score-based models and so on. And then we'll talk about some other ways
00:28:55.080 | to think about it as well. So this process of gradually corrupting our images away, adding a
00:29:01.720 | little bit of noise at a time, people like to talk of this as a stochastic differential equation.
00:29:06.920 | Stochastic because there's some randomness, right, we're picking random amounts of noise, random
00:29:10.360 | directions to add, and a differential equation because it's not talking about anything absolute,
00:29:15.320 | just how we should change this from moment to moment to get more and more corrupted, right?
00:29:19.240 | So that's why it's a differential equation. And with that framing, the question of, well,
00:29:25.000 | how do I go now back to the image? That's framed as solving an ordinary differential equation that
00:29:30.440 | corresponds to like the reverse of this process. You can't solve ODEs in a single step, but you
00:29:36.920 | can find an approximate solution. And the more sort of sub-steps you take, the better your
00:29:42.200 | approximation. And so that's what these samplers are doing: given, okay, we think this image
00:29:45.960 | is over here, here's my prediction; rather than moving the whole way there in one go, we'll
00:29:50.360 | remove some of that noise, right, do a little update, and then we'll get a new prediction,
00:29:56.040 | right? And so maybe now the prediction is slightly better. It says up here. So we move a little bit
00:29:59.480 | in that direction. And now it makes an even better prediction, because as we get closer to the
00:30:03.080 | manifold, right, as we have less and less noise, and more and more of like some image emerging,
00:30:07.480 | the model is able to get more and more accurate predictions. And so in some sort of number of
00:30:11.960 | steps, we divide up this process, and we get closer and closer and closer until we
00:30:17.240 | ideally find some image that looks very plausible as our output. And so that's what we're doing
00:30:23.240 | here with a lot of these samplers, they're effectively trying to solve this ODE in some
00:30:28.360 | number of steps by, yeah, breaking the process up and only moving a small amount at a time.
00:30:33.240 | Now, you get sort of first order solvers, right, where all we're doing is just linearly moving
00:30:38.920 | within each one. And this is equivalent to something called Euler's method (pronounced 'Oiler's',
00:30:43.480 | or 'Yooler's' if, like me, you've only ever read it). And this is what some of the most basic samplers
00:30:47.480 | are doing, just linear approximations for each of these little steps. But you also get additional
00:30:53.560 | approaches. So for example, maybe if we were to make a prediction from here, it might look like
00:30:59.480 | something like this. And if we were to make a prediction from here, it might look like something
00:31:03.640 | like that. So we have our arrow here. But as you move in that direction, it's also changing,
00:31:09.880 | right? So there's like a derivative of a derivative, a gradient of a gradient. And that's where this
00:31:15.400 | second order solver comes in and says, well, if I know how this prediction changes as I move in this
00:31:21.400 | direction, like what is the derivative of it, then I can kind of account for that curvature when I
00:31:25.960 | make my update step, and maybe know that it's going to curve a bit in that direction. And so that's
00:31:30.760 | where we get things like these so-called second order solvers and higher order solvers. The upside
00:31:35.560 | of this is that we can, you know, do a larger step at a time, because we have a more accurate
00:31:40.520 | prediction, we're not just doing a first order linear approximation, we have this kind of curvature
00:31:45.000 | taken into account. The downside is that to estimate that curvature for a given point, we might need to
00:31:50.680 | call our model multiple times to get multiple estimates. And so that takes time. So we can
00:31:54.600 | take a larger step, but we need more model evaluations per step. A kind of hybrid approach
00:32:00.200 | is to say, well, rather than trying to estimate the curvature here, I might just take a linear step,
00:32:05.800 | look at the next prediction, but I'll keep a history of my previous steps. And so then over here,
00:32:10.280 | it predicts like this. So I have now this history. And I'm going to use that to better guess what
00:32:15.080 | this trajectory is. So I might keep a history of the past, you know, three or four or five predictions,
00:32:20.920 | and know that since they're quite close to each other, maybe that tells me some information about
00:32:24.040 | the curvature here. And I can use that again, take larger steps. And so that's where we see
00:32:28.600 | the so-called linear multi-step sampling coming in, just keeping this buffer of past predictions
00:32:33.800 | to try and do a better job estimating than the simple one-step linear type first order solvers.
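For a concrete picture, here's what that simplest first-order update looks like in the sigma parameterization from earlier, where the (scaled) noise prediction is exactly the estimated derivative of the latents with respect to sigma - a sketch:

```python
def euler_step(latents, noise_pred, sigma, sigma_next):
    # First-order (Euler) update: move linearly along the estimated
    # derivative d(latents)/d(sigma) as sigma decreases toward zero.
    derivative = noise_pred          # = (latents - denoised) / sigma
    return latents + derivative * (sigma_next - sigma)
```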
00:32:40.200 | Okay, so that's the score-based sampling version. And all of the variants and innovation come down
00:32:47.160 | to things like, how can we do this in as few steps as possible? Maybe we have a schedule that says we
00:32:51.880 | take larger steps at first and then gradually smaller steps as we get closer. There's, I think,
00:32:56.920 | now some dynamic methods and can we estimate how many steps we need to take, and so on. So that's
00:33:01.720 | all trying to attack it from this kind of score-based ODE solving framework. But there's another way to
00:33:08.360 | think of this as well. And that's to say, okay, well, I don't really care about solving this
00:33:13.560 | exact reverse ODE, right? All I care about is that I end up with an image that's on this manifold,
00:33:19.160 | like a plausible looking image. And so I have a model that estimates how much noise there is,
00:33:25.160 | right? And if that noise is very small, then that means I've got a good image. And if that noise is
00:33:31.000 | really large, then that means I've got some work to do. And so this kind of starts bringing up some
00:33:37.640 | analogies to training neural networks, because in neural networks, we have the space of all possible
00:33:42.520 | parameters. And we're trying to adjust those parameters not to solve the gradient flow equation,
00:33:48.040 | right? Although that's, you know, possible in theory that you might try and do that.
00:33:52.360 | We don't care about that, we just want to find a minima, we want to find a point where our loss is
00:33:55.720 | really good. And so when we're training a neural network, that's exactly what we do. We set up an
00:34:00.040 | optimizer, and we take some number of steps trying to reduce some loss. And once that loss gets sort
00:34:05.880 | of, you know, levels off, right, reduced over time levels off, okay, cool, I guess we found a good
00:34:10.200 | neural network. And so we can apply that same kind of thinking here to say, all right, I'll start at
00:34:15.560 | some point. And I'll have an estimate of the gradient, right, like maybe pointing over here.
00:34:21.400 | But remember, that estimate is not very good, just like the first gradients estimated when
00:34:25.960 | training a neural network are pretty bad, because it's all just these randomly initialized weights,
00:34:29.400 | but hopefully it at least points in a useful direction. So then I'll take some step, and the
00:34:33.640 | length of the step, I won't try and do some fancy schedule, I'll just offload this to a sort of
00:34:38.760 | off the shelf optimizer, right? So I have some learning rate, maybe something like momentum,
00:34:43.000 | that determines how big of a step I take. And then I update my prediction, right, take another step
00:34:49.080 | in that direction, and so on. So now, instead of following a fixed schedule, we can use tricks that
00:34:53.880 | have been developed for training neural networks, right, adaptive learning rates, momentum,
00:34:57.720 | weight decay, and so on. And we can apply them back to this kind of sampling case. And so it turns out
00:35:03.160 | this works okay. I've tried this for stable diffusion; it needs some tricks to get it working.
00:35:06.840 | But it's a slightly different way of thinking about sampling, rather than relying on sort of
00:35:11.560 | a hard coded ODE solver that you figured out yourself, just saying, why don't we treat this
00:35:16.040 | like an optimization problem, where if the model predicts almost no noise, that's good, we're doing
00:35:20.840 | a good job. And if the model predicts lots of noise, then we can use that as a gradient, and take a
00:35:26.040 | gradient update step according to our optimizer, and try and sort of converge on a good image as
00:35:30.760 | our output. And this is, you know, you can stop early once your model prediction is sufficiently
00:35:35.800 | low for the amount of noise, okay, cool, I'm done. And so I found, you know, in 10, 15 steps,
00:35:40.360 | you can get some pretty good images out. Yeah, so that's a different way of viewing it.
00:35:45.320 | Not so popular at the moment, but maybe, hopefully something we'll see. Yeah, just a different
00:35:50.040 | framing. And for me, at least that helps me think about what we're actually doing with the samplers,
00:35:54.040 | we try to find a point where the model predicts very little noise. And so starting from a bad
00:35:58.280 | prediction, moving towards it getting better, by looking at this estimated amount of noise
00:36:03.240 | as our sort of gradient and solving that, just kind of iteratively removing bits at a time.
00:36:08.360 | So I hope that helps elucidate the different kinds of samplers, and the goal of that whole thing,
00:36:14.040 | and also illustrate at least why we don't just do this in a single step, right? Why we need some sort
00:36:18.040 | of iterative approach, otherwise, we'd end up with just very bad blurry predictions. All right,
00:36:22.520 | I hope that helps. Now we're going to head back to the notebook to talk about our final trick of
00:36:26.920 | guidance. Okay, the final part of this notebook, guidance, how do we add some extra control to this
00:36:36.680 | generation process, right? So we already have control via the text, and we've seen how we can
00:36:41.080 | modify those embeddings. We have some control via starting at a noisy version of an input image,
00:36:46.200 | rather than pure noise to kind of control the structure. But what if there's something else,
00:36:49.960 | what if we'd like a particular style, or to enforce that the model looks like some input image,
00:36:55.640 | or maybe sticks to some color palette, it would be nice to have some way to add this additional
00:37:00.280 | control. And so the way we do this is to look at some loss function on the decoded denoised
00:37:08.360 | predicted image, right? The predicted, denoised final output. And we use that loss to then update
00:37:15.640 | the noisy latents as we generate in a direction that tries to reduce that loss. So for demo,
00:37:20.840 | we're going to make a very simple loss function. I would like the image to be quite blue. And to
00:37:25.240 | enforce that my error is going to be the difference between the blue channel, right? Red, green, blue,
00:37:29.800 | blue is the third channel of the color channels, and the difference between the blue channel
00:37:34.040 | and 0.9. So the closer all the blue values are to 0.9, the lower my error will be. So that's going
00:37:39.800 | to be my kind of guidance loss. And then during sampling here, what I'm going to do,
00:37:43.880 | everything's going to be the same as before. But every few iterations, you could do it every
00:37:50.280 | iteration, but that's a little slow. So here, every five iterations, I'm going to set requires
00:37:56.040 | grad equals true on the latents. I'm then going to compute my predicted denoised version. I'm
00:38:01.800 | going to decode that into image space, and then I'm going to calculate my loss using my special
00:38:06.120 | blue loss and scale it with some scaling factor. Then I'm going to use torch to find the gradient
00:38:12.920 | of this loss with respect to those latents, those noisy latents. And I'm going to modify them,
00:38:18.200 | right? And I want to reduce the loss. I'm going to subtract here this gradient multiplied by sigma
00:38:23.640 | squared because we're going to be working at different noise levels. And so if we run this,
00:38:28.120 | we should see, hopefully, it's going to do that same sort of sampling process as before,
00:38:32.520 | but we also are occasionally modifying our latents by looking at the gradient of the loss with respect
00:38:37.880 | to those latents and updating them in a direction that reduces that loss. And sure enough, we get a
00:38:42.680 | very nice blue picture out. And if I change the scale here down to something lower and run it,
00:38:48.760 | we'll see that scale is lower. So the loss is lower. So our modifications to the latents
00:38:54.600 | are smaller. We'll see that we get out a much less blue image. There we go. So that's the
00:39:01.880 | default image, very red and dark, because the prompt is just a picture of a campfire.
00:39:06.040 | But as soon as we add our additional loss, our guidance, we're going to get out something that
00:39:13.240 | better matches that additional constraint that we've imposed, right? So this is very useful,
00:39:18.280 | not just for making your images blue, but like I said, color palettes or using some classifier
00:39:22.920 | model to make it look like a specific class of image or using a model like clip to, again,
00:39:28.360 | associate it with some text. So lots and lots of different things you can do. Now, a few things I
00:39:33.000 | should note. One, we're decoding the image back to image space, calculating our loss and then tracing
00:39:38.280 | back. That's very computationally intensive compared to just working in latent space. And so
00:39:44.040 | we can do that only every fifth iteration to reduce the time, but it still is much slower
00:39:48.840 | than just your generic sampling. And then also, we're actually still cheating a little bit here
00:39:54.440 | because what we should do is set requires_grad = True on the latents and then use those to
00:39:59.800 | make our noise prediction, use that to calculate the denoised version and decode that, calculate
00:40:05.720 | our loss and trace back all the way through the decoder and the process and the UNet back to the
00:40:10.680 | latents, right? The reason I'm not doing that is because that takes a lot of memory. So you'll see,
00:40:15.960 | for example, like the CLIP-guided diffusion notebook from the Hugging Face examples,
00:40:20.760 | they do it that way, but they have to use tricks like gradient checkpointing and so on to kind of
00:40:24.360 | keep the RAM usage under control. And for simple losses, it works fine to do it this way, because
00:40:29.400 | now we're just tracing back through denoised_latents = latents - sigma * noise_pred,
00:40:34.280 | right? So we don't have to trace any gradients back through the UNet. But if you wanted
00:40:38.680 | to get more accurate gradients, maybe it's not working as well as you'd hoped, you can do it that
00:40:43.240 | other way that I described. But however you do it, very, very powerful technique,
00:40:47.480 | fun to be able to again inject some additional control into this generation process by crafting
00:40:52.440 | a loss that expresses exactly what you'd like to see. All right, that's the end of the notebook
00:40:57.560 | for now. If you have any questions, feel free to reach out to me, I'll be on the forums and
00:41:01.800 | you can find me on Twitter and so on. But for now, enjoy and I can't wait to see what you make.