Lesson 9A 2022 - Stable Diffusion deep dive
Chapters
0:00 Introduction
0:40 Replicating the sampling loop
1:17 The Auto-Encoder
3:55 Adding Noise and image-to-image
8:43 The Text Encoding Process
15:15 Textual Inversion
18:36 The UNET and classifier free guidance
24:41 Sampling explanation
36:30 Additional guidance
Hello, everyone. My name is Jonathan, and today I'm going to be taking you through this stable diffusion deep dive notebook, looking into the code behind the popular high-level APIs, libraries, and tools to see what exactly the generation process looks like and how we can modify it. How does each of the individual components work? Feel free to run along with me. If you haven't run it before, it might take a little while, because it's downloading those large models (if they aren't already downloaded) and loading them up. We're going to start by recreating what it looks like to generate an image using one of the existing pipelines in Hugging Face. We've basically copied the code from the core method of the default stable diffusion pipeline, so if you go and view that, you'll see we're replicating that code, but now in our own notebook, and then we'll slowly work through what each of the different parts is doing. So we've got some setup, we've got a loop running through a number of sampling timesteps, and we're generating an image. This is supposed to be a watercolor picture of an otter, and it's very, very cool that this model can just do this. But now we want to know: how does that actually work?
What's going on? The first component is the autoencoder. Stable diffusion is a latent diffusion model, which means it doesn't operate on pixels; it operates in the latent space of some other autoencoder model — in this case, a variational autoencoder that's been trained on a large number of images to compress them down into a latent representation and bring them back up again. I have some functions to do that, and we're going to look at them in action: downloading a picture from the internet and opening it up with PIL. We have this 512 by 512 pixel image, we load it in, and then we use the function defined above to encode it into a latent representation. What that's doing is calling the vae.encode method on a tensor version of the image, which gives us a distribution. We sample from that distribution and scale by a constant, because that's what the authors of the model did: they scaled the latents down before feeding them to the model, so we have to do the same scaling, and the reverse when we decode, to be consistent. The key idea is that we go from this big image down to a 4 by 64 by 64 latent representation. And if we visualize the four channels — four different 64 by 64 channels — we'll see they're capturing something of the image. You can sort of see the same shapes and structures there, but it's not quite a direct mapping; for example, there's some weirdness going on around the beak, and some channels look stranger than others. So there's some rich information captured there. If we decode this back, the decoded image looks really good — you really have to look closely to tell the difference between our input image and the decoded version. Very, very impressive compression: a factor of 8 in each spatial dimension, so 512 by 512 down to 64 by 64 — a 64x reduction in spatial locations (roughly 48x in total values, since we go from 3 channels up to 4) — yet it's still somehow capturing most of that information. It's a very information-rich representation, and that's great, because now our diffusion model can work with these 64 by 64 latents and still give nice high-resolution results. Now, it doesn't have to be 64 by 64: you can modify this to, say, a 640-pixel-wide image, encode that down, and you'll see it's that same factor of 8 reduction — now we have 80 by 64 latents. The size just has to be a multiple of 8; otherwise you'll get, I think, an error.
Okay. So we have our encoded version of the image, and that's pretty great. The next component to look at is the scheduler — I'll look more closely at it later, but for now we're going to focus on the idea of adding noise. During training, we add some noise to an image, and then the model tries to predict what that noise is, and we do that at different amounts. Here we're recreating the same type of schedule the model was trained with (you can try different schedulers from the library). These parameters — beta_start, beta_end, beta_schedule — determine how much noise was added at different timesteps, and num_train_timesteps is how many timesteps were used during training. For sampling, we don't want to have to do a thousand steps, so we can set a new number of timesteps, and the scheduler.timesteps attribute shows how these correspond to the original training timesteps. Here we have 15 sampling steps, which is equivalent to starting at timestep 999 and moving linearly down to timestep 0. We can also look at the actual amount of noise present at each step with the sigmas attribute — again, starting high and moving down. If you want to see what that schedule looks like, we can plot it here; the timesteps themselves are just a linear relationship. So there we go: we're going to start at a very high noise value and slowly, slowly try to reduce it until, ideally, we get an image out.
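(As a sketch, recreating that training schedule with diffusers looks something like this; the beta values are the ones from the Stable Diffusion v1 config.)

```python
from diffusers import LMSDiscreteScheduler

# Recreate the training noise schedule (values from the SD v1 config)
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012,
    beta_schedule="scaled_linear", num_train_timesteps=1000,
)
scheduler.set_timesteps(15)   # 15 sampling steps for inference
print(scheduler.timesteps)    # ~999 down to 0, linearly spaced
print(scheduler.sigmas)       # matching noise levels, high to low
```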
Okay, so sigma is the amount of noise added. Let's see what that looks like. I start with some random noise the same shape as my latent representation — my encoded image — and I'd like to be at the equivalent of sampling step 10 out of 15. So I look up what timestep that equates to, and that's one of the arguments I pass to the scheduler.add_noise function: I call scheduler.add_noise, giving it my encoded image, the noise, and the timestep I'd like the noising to be equivalent to, and that gives me this noisy but still recognisable version of the image. And you can go and ask: what if I look somewhere earlier in the process — does it look more noisy? What about right at the beginning, or right at the end? Feel free to play around there. Okay, so this adding noise — what are we actually doing? What does the code look like? If we inspect the function, we'll see some setup for different types of argument and shapes, but the key line is just: noisy samples equals original samples plus the noise scaled by sigma. That's all it is. It's not always the same — different papers and implementations add the noise slightly differently — but in this case, that's all it's doing. So scheduler.add_noise just adds noise the same shape as the latents, scaled by the sigma parameter.
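(Stripped of the setup code, the operation amounts to something like this, assuming `encoded` holds our image latents and `start_step` is the sampling step we want to match.)

```python
# What scheduler.add_noise boils down to for this scheduler:
sigma = scheduler.sigmas[start_step]
noise = torch.randn_like(encoded)
noisy_latents = encoded + noise * sigma

# Equivalent call through the scheduler API:
timestep = scheduler.timesteps[start_step]
noisy_latents = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([timestep]))
```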
Okay, so that's what we're doing. If we want to start from pure random noise instead of a noisy image, we scale it by that same sigma value so that it looks like an image that's been noised by that amount. But then before we feed it to the actual model, we have to handle that scaling again. You could do it by hand, but now there's a scale_model_input function associated with the scheduler that hides that complexity away.
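(In code, that looks roughly like this: scale the initial noise up to the starting sigma, then scale each model input back down per step.)

```python
# Pure-noise starting point, scaled to match the highest noise level
latents = torch.randn((1, 4, 64, 64), device=device) * scheduler.sigmas.max()

# ...then inside the sampling loop, per step t:
latent_model_input = scheduler.scale_model_input(latents, t)
```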
Okay, so now we're going to run the same kind of sampling loop as before, but this time we start with our image: we take our encoded image, noise it to some timestep, and then only denoise from there. In code: we prepare our text and everything the same as before (which we'll look at shortly), we set our number of inference steps to 50 (num_inference_steps = 50 here), and we say we'd like to start at the equivalent of step 10 out of 50. So I look up the timestep that equates to, add noise to my image equivalent to that step, and then run through sampling — but this time we only start doing things once we get past that start step. I'm ignoring the first 10 of the 50 steps, and beyond that, I start with this noisy version of my input image and denoise it according to the prompt. The hope is that by starting from something with some of the rough structure and color of the input image, I can fix that into my generation. But I've got a new prompt — "A National Geographic photo of a colorful dancer" — and here we go: we see the same sort of composition as the parrot, but with completely different content thanks to the different prompt. That's a fun use of this image-to-image process. You might have seen it used for taking drawings, adding a bunch of noise, and then denoising them into fancy paintings and so on. Again, there are existing tools for this: the strength parameter in the image-to-image pipeline is just something like this — what step are we starting at, how many steps are we skipping? But you can see this is a pretty powerful technique for getting a bit of extra control over composition, color, and some of the structure. Okay, so that's the trick of adding noise and then using it for image-to-image.
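(Putting that together, a sketch of the image-to-image loop might look like this, assuming a loaded `unet`, the concatenated conditional/unconditional `text_embeddings`, and a `guidance_scale` — guidance is handled as in the full loop we'll unpack later.)

```python
scheduler.set_timesteps(50)
start_step = 10  # skip the first 10 of 50 steps

# Noise the encoded image to the equivalent of step 10
noise = torch.randn_like(encoded)
t_start = scheduler.timesteps[start_step]
latents = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([t_start]))

for i, t in enumerate(scheduler.timesteps):
    if i < start_step:
        continue  # these steps are "already done" by the added noise
    latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t,
                          encoder_hidden_states=text_embeddings).sample
    uncond, cond = noise_pred.chunk(2)
    noise_pred = uncond + guidance_scale * (cond - uncond)
    latents = scheduler.step(noise_pred, t, latents).prev_sample
```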
The next big section I'd like to look at is how we go from a piece of text describing what we want to a numerical representation we can feed to the model. We're going to trace out that pipeline, and along the way we'll see how we can modify it for a bit of fun. Step one: we take our prompt and turn it into a sequence of discrete tokens. Here we get 77 tokens, because that's the maximum length — it's always going to be 77, and if your prompt is longer, it'll be truncated. If we decode these tokens back, we see a special token for the start of the text, then "a picture of a puppy", and then the rest is all the same token: a kind of end-of-text padding token. So we have a token for "puppy", this special token with its own meaning (end of text), and prompts are always padded to the same length. Before, in the code we were using, we always jumped straight to the so-called output embeddings, which is what we feed to the model as conditioning — somehow they capture some information about the prompt. But now we want to ask: how do we get there? How do we get from this sequence of tokens to these output embeddings? What is the text encoder's forward pass doing?
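(The high-level jump from prompt to output embeddings is just two calls, assuming the usual CLIP `tokenizer` and `text_encoder` are loaded.)

```python
prompt = ["A picture of a puppy"]

# Tokenize: always padded/truncated to 77 tokens
text_input = tokenizer(prompt, padding="max_length",
                       max_length=tokenizer.model_max_length,
                       truncation=True, return_tensors="pt")
print(text_input.input_ids)  # 49406 = <start>, 49407 = <end>/padding

# Encode: the black-box step we're about to break apart
with torch.no_grad():
    output_embeddings = text_encoder(text_input.input_ids.to(device))[0]
print(output_embeddings.shape)  # torch.Size([1, 77, 768])
```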
Right, so we can look at this, and there are multiple steps. The first is some embeddings. If we look at text_encoder.text_model.embeddings, we see there are a couple of different ones. We have token embeddings: these take individual tokens — token 408, or whatever — and map them into a numerical representation. It's a learned embedding: there are about 50,000 rows, one for each token in the vocabulary, and for each token we have 768 values. That's the embedding of that token. If we feed one in — here's the token for "puppy" — we get its token embedding: 768 numbers that somehow capture the meaning of that token on its own. We can do the same for all the tokens in our prompt: feed them through this token embedding layer, and we get 77 768-dimensional representations, one per token. Now, these are all on their own — no matter where in the sentence a token appears, its token embedding will be the same. So the next step is to add some positional information. Some models do this with a fixed positional encoding pattern, but in this case, the positional embedding is just another learned embedding: instead of one embedding per token, we have one embedding for each of the 77 possible positions. Just like we did for the tokens, we can feed in the position IDs — one for every possible position — and get back out an embedding for every position in the prompt. To combine them, there are again multiple ways people do this in the literature, but here it's as simple as adding them; that's why they were made the same shape, so you can just add the two together. Now these input embeddings carry some information about the token and some about the position. So far, we haven't seen any big model — just two learned embedding layers getting everything ready to feed through the transformer. We can check that this matches just calling the embeddings layer of the model, which does both steps at once — but we'll see shortly why we want to separate it out into individual steps.
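(Here's a sketch of those two embedding lookups and their combination; attribute names follow the transformers CLIP implementation used at the time of this lesson.)

```python
embeddings = text_encoder.text_model.embeddings

with torch.no_grad():
    # Token embeddings: one learned 768-d vector per vocabulary entry
    token_embeddings = embeddings.token_embedding(text_input.input_ids.to(device))      # (1, 77, 768)
    # Positional embeddings: one learned 768-d vector per position
    position_embeddings = embeddings.position_embedding(embeddings.position_ids[:, :77])  # (1, 77, 768)

# Same shape, so combining them is just addition
input_embeddings = token_embeddings + position_embeddings
```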
Okay, so we have these final input embeddings: token embeddings combined with positional information. Now we'd like to turn them into something with a richer representation, thanks to a big transformer model, so we're going to feed them through. I made a little diagram here: each token turns into a token embedding combined with a positional embedding, and then gets fed through this transformer encoder, which is just a stack of blocks. Each block has some magic like attention, some feed-forward components, additions and normalizations and skip connections, and so on. We have some number of these blocks stacked together, with the outputs of each one feeding into the next, and so we get our final set of hidden states — the encoder hidden states, a.k.a. the output embeddings. That's what we feed to our UNet to make its predictions. To get these, I copied the text_encoder.text_model forward method and pulled out the relevant bits: we take in those input embeddings (combined positional and token embeddings) and feed them through the text_model.encoder function, with some additional parameters around attention masking and a flag saying we'd like the hidden states rather than the final pooled outputs. If we run this, we can double-check that these embeddings look just like the output embeddings we saw at the beginning. So we've taken that one step — tokens to output embeddings — and broken it down into smaller steps: tokenization, getting token embeddings, combining with position embeddings, and feeding through the encoder to get those final outputs.
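(Condensed, that helper looks roughly like this; the private causal-mask helper is version-dependent in transformers, so treat this as a sketch of the shape of the computation rather than copy-paste code.)

```python
def get_output_embeds(input_embeddings):
    # Reproduce text_encoder's forward pass from combined embeddings:
    # the transformer encoder stack, then the final layer norm.
    bsz, seq_len = input_embeddings.shape[:2]
    causal_mask = text_encoder.text_model._build_causal_attention_mask(
        bsz, seq_len, dtype=input_embeddings.dtype)
    with torch.no_grad():
        encoder_outputs = text_encoder.text_model.encoder(
            inputs_embeds=input_embeddings,
            attention_mask=None,
            causal_attention_mask=causal_mask,
        )
        # Last hidden state -> final layer norm -> output embeddings
        return text_encoder.text_model.final_layer_norm(encoder_outputs[0])
```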
So why have we gone through all this trouble? Well, there are a couple of things we can do. One demo: I get the token embeddings, but then I look up where the token for "puppy" is, and I replace it with a new set of embeddings — just the learned embedding of a different token, 2368. So I'm cutting out the token embedding for "puppy", slipping in this new token embedding, and I get output embeddings which at the start look very similar to the previous ones — in fact, identical. But as soon as you get past the position of "puppy" in the prompt, the rest have changed. We've somehow messed with these embeddings by slipping in a new token embedding right at the start. And if we generate with those embeddings — which is what this function is doing — we should see something other than a puppy. And sure enough, drum roll... we don't: we get a cat. So now you know what token 2368 means. We've managed to slip in a new token embedding and get a different image.
a different image. Okay, what can we do with this? Why is this fun? Well, a couple of tricks. First 00:14:38.840 |
off, we could look up the token embedding for skunk, right, which is this number here. And then 00:14:44.360 |
instead of now just replacing that in place of puppy, what if I make a new token embedding 00:14:50.120 |
that's some combination of the embedding of puppy and the embedding of skunk, right? 00:14:54.680 |
So I'm taking these two token embeddings, I'm just averaging them, and I'm inserting them into my 00:14:59.720 |
set of token embeddings for my prompt in place of just the word puppy. And so hopefully when we 00:15:05.720 |
generate with this, we get something that looks a bit like a puppy, a bit like a skunk. And this 00:15:10.440 |
doesn't work all the time, but it's pretty cute when it does. There we go, puppy skunk hybrid. 00:15:15.400 |
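(A sketch of both tricks — swapping an embedding in, or averaging two — looking the token IDs up via the tokenizer rather than hard-coding them, and reusing `token_embeddings`, `position_embeddings`, and `get_output_embeds` from above.)

```python
# Find where "puppy" sits in the prompt's token sequence
puppy_id = tokenizer("puppy").input_ids[1]  # [<start>, puppy, <end>] -> middle token
position = (text_input.input_ids[0] == puppy_id).nonzero()[0].item()

token_emb_layer = text_encoder.text_model.embeddings.token_embedding
skunk_id = tokenizer("skunk").input_ids[1]

with torch.no_grad():
    # Trick 1: swap in the skunk embedding outright
    replacement = token_emb_layer(torch.tensor(skunk_id, device=device))
    # Trick 2: a 50/50 puppy/skunk blend
    replacement = (token_emb_layer(torch.tensor(puppy_id, device=device)) + replacement) / 2

token_embeddings[0, position] = replacement
output_embeddings = get_output_embeds(token_embeddings + position_embeddings)
```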
Okay, so that's not the real reason we're looking at this. The main application at the moment of being able to mess with these token embeddings is something called textual inversion. In textual inversion, we take our prompt and tokenize it as usual, but we slip in a special learned embedding for some new concept. The way that embedding is trained is outside the scope of this notebook, but there's a good blog post and community notebooks for doing it. Let's just see it in application here. There's a whole library of these concepts — the Stable Diffusion Concepts Library — where you can browse over 1,400 different community-contributed token embeddings that people have trained. I'm going to use this "birb style" one. Here are some example outputs, and these are the images it was trained on: some pretty little bird paintings done by my mother. I've trained a new token embedding that tries to capture the essence of that style, and it's stored in this learned_embeds.bin file. If you download that and upload it to wherever your notebook is running — I have it here as learned_embeds.bin — we can load it in. You'll see it's just a dictionary with one key (the name of my new style) and a token embedding: 768 numbers. So now, instead of slipping in the token embedding for "cat", we slip in this new embedding, loaded from the file, into the prompt "A mouse in the style of puppy": tokenize, get my token embeddings, then slip in the replacement embedding in place of the embedding for "puppy". When we generate with that, we should hopefully get a mouse in the style of this cutesy watercolor-on-rough-paper look. And sure enough, that's what we get: a very cute little drawing of a mouse in an apron, apparently. A very, very cool application. Again, there's a nice inference notebook that makes this really easy — you can just say "a cat toy in the style of <birb-style>" and not worry about manually replacing the token embeddings yourself. But it's good to know what the code looks like under the hood: how are we doing that, and what stage of the text embedding process are we modifying? Very fun for a bit of extra control, and a very useful technique, because now we can augment our model's vocabulary without having to retrain the model itself — we're just learning a new token embedding. It's a very, very powerful idea, and really fun to play with. And like I said, there are thousands of community-contributed tokens, but you can also train your own — the training notebook is linked from here, and it's also in the docs.
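(Loading and slipping in the learned concept looks something like this sketch, reusing the helpers above; the file really is just a one-entry dict mapping the concept name to a 768-value tensor.)

```python
# learned_embeds.bin: {'<birb-style>': tensor of shape (768,)}
learned = torch.load("learned_embeds.bin")
new_embedding = next(iter(learned.values())).to(device)

prompt = ["A mouse in the style of puppy"]
text_input = tokenizer(prompt, padding="max_length",
                       max_length=tokenizer.model_max_length,
                       truncation=True, return_tensors="pt")
with torch.no_grad():
    token_embeddings = token_emb_layer(text_input.input_ids.to(device))

# Replace the "puppy" embedding with the learned style embedding
puppy_id = tokenizer("puppy").input_ids[1]
position = (text_input.input_ids[0] == puppy_id).nonzero()[0].item()
token_embeddings[0, position] = new_embedding

output_embeddings = get_output_embeds(token_embeddings + position_embeddings)
```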
Okay, one final little trick with embeddings: rather than messing with them at the token embedding level, we can push whole prompts through the entire process to get the final output embeddings, and mess with those at that stage as well. Here I have two prompts, "A mouse" and "A leopard"; I tokenize them and encode them with the text encoder — that whole process — and then I just mix the two sets of final output embeddings together according to some factor and generate with that. You can try this with, say, a cat and a snake, and you should be able to get some really fun chimeras... and oops, a snail, apparently. Okay, well, I can't spell. But have fun with that — it doesn't have to be animals.
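(Mixing at the output-embedding stage is even simpler — a sketch, with `generate_with_embs` standing in for the generation helper used earlier in the notebook.)

```python
embs = []
for p in ["A mouse", "A leopard"]:
    inp = tokenizer([p], padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True, return_tensors="pt")
    with torch.no_grad():
        embs.append(text_encoder(inp.input_ids.to(device))[0])

mix_factor = 0.35  # 0 = all mouse, 1 = all leopard
mixed_embeddings = embs[0] * (1 - mix_factor) + embs[1] * mix_factor
# generate_with_embs(mixed_embeddings)
```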
I'd love to see what you create with these weird mixed-up generations. Okay, now we should look at the actual model itself — the UNet, the diffusion model. What is it doing? What is it predicting? What does it accept as arguments? This is the call signature: we call our UNet's forward pass and feed in our noisy latents, the timestep (one of the training timesteps), and the encoder hidden states — those text embeddings we've just been having fun with. So, doing that without any loops: I set up my scheduler, get my timestep, my noisy latents, and my text embeddings, and then get the model's prediction. If you look at the shape of that prediction, you'll see it's the same shape as the latents. Given noisy latents, what the model predicts is the noise component — and in fact, the noise component scaled by sigma. So if we want to see what the original image looks like, we can say: the denoised latents equal the current noisy latents minus sigma times the model's prediction. When we're actually denoising, we're not going to jump straight to that final prediction; we'll remove just a little bit of the noise at a time. But it can be useful to visualize what that final prediction looks like.
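(One forward pass and the implied denoised image, as a sketch.)

```python
# Noisy latents + timestep + text conditioning in; scaled noise prediction out
with torch.no_grad():
    noise_pred = unet(noisy_latents, t,
                      encoder_hidden_states=text_embeddings).sample
print(noise_pred.shape)  # same shape as the latents: (1, 4, 64, 64)

# The model's implied "fully denoised" image at this step
denoised_latents = noisy_latents - sigma * noise_pred
```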
So that's what we're doing here: making a folder to store some images, preparing our text, scheduler, and input, and then running the loop. But now, as well as updating our latents by one step each iteration, we also decode and store two images: the current noisy latents, and the predicted, completely denoised original sample. That's the predicted original sample the scheduler's step returns — and you can also calculate it yourself: latents_x0 equals the current latents minus sigma times the noise prediction.
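(Inside the loop, the relevant extra lines look roughly like this; LMSDiscreteScheduler's step output exposes the predicted original sample directly.)

```python
step_output = scheduler.step(noise_pred, t, latents)
pred_x0 = step_output.pred_original_sample   # == latents - sigma * noise_pred
latents = step_output.prev_sample

latents_to_pil(latents).save(f"steps/{i:03}_noisy.png")
latents_to_pil(pred_x0).save(f"steps/{i:03}_denoised.png")
```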
All right, so those two should work equivalently. This loop runs and saves those images to the steps folder, which we can then visualize. Once it finishes, in a second or two, on the left we'll see the noisy input to the model at each stage, and on the right the noisy input minus the noise prediction — the denoised version. We'll give it a second or two to run; it's taking a little longer because it's decoding and saving those images each time, but once it finishes, we should have a nice little preview video. Okay, here we go. So this is the noisy latent, and if we take the model's noise prediction and subtract it, we get this very blurry output. And you'll see as we play this — oh, I've left some modifications in from last time, sorry: this guidance scale should be back at, I think it was 8. We'll talk about classifier-free guidance in the next section; I'd been modifying that example. My bad — I might cut this out of the video, we'll see. So we have to wait a few seconds again for that to generate.
Okay, so here we go again: the noisy input and the predicted denoised version. You can see at the start it's very blurry, but over time it gradually converges on our final output. And you'll notice that on the left, the latents as they are at each step don't change particularly drastically — just a little bit at a time. But at the start, when the model doesn't have much to go on, its predictions do change quite a bit from step to step; they're much less well-defined. Then as we go forward in time, they get more and more refined — better and better predictions, a more and more accurate estimate of the noise to remove. And we remove that noise gradually until we finally get our output. Quite fun to visualize the process, and hopefully it helps you understand why we don't just make one prediction and do it in one step — we'd get that very blurry mess. Instead, we do this iterative sampling, which we'll talk about very shortly.
Before then, though, the final thing I should mention is classifier-free guidance. What is that? Well, as you saw when I accidentally generated the version with a much lower guidance scale, the way classifier-free guidance works is that in all of these loops, we haven't actually been passing one set of noisy latents through the model — we've been passing two identical copies. And as our text embeddings, we've not just been passing the embeddings of our prompt; we've been concatenating them with some unconditional embeddings as well. The unconditional embeddings are just a blank prompt — no text whatsoever, all padding — passed through the text encoder. So when we get our predictions, we've fed in two sets of latents and two sets of text embeddings, and we get out two noise predictions. We split them apart: one prediction for the unconditional (no-prompt) version and one based on the prompt. And now we can say: my final prediction is the unconditional version plus the guidance scale times the difference between the two. If you think about it: predicting without the prompt lands here; predicting with the text conditioning, with the prompt, lands there instead — and I'd like to move more in that direction, to push even further towards the prompt version and beyond. So the guidance scale can be larger than one, to push even more in that direction. And this, it turns out, is key to getting the model to follow the prompt nicely. I think it was first introduced in the GLIDE paper — AI Coffee Break on YouTube has a great video on that. A really useful trick, or a really neat hack, depending on who you talk to, but it does seem to work. The higher the guidance scale, the more the model will try to match the prompt, in the extreme; with a lower guidance scale, it might just try to look like a generic good picture.
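(A sketch of classifier-free guidance as it appears in the loop: empty-prompt embeddings, a doubled batch, and the extrapolation.)

```python
# Unconditional embeddings: an empty prompt, all padding
uncond_input = tokenizer([""], padding="max_length",
                         max_length=tokenizer.model_max_length,
                         return_tensors="pt")
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

# Run both versions through the UNet in one batch
latent_model_input = torch.cat([latents] * 2)
both_embeddings = torch.cat([uncond_embeddings, text_embeddings])
with torch.no_grad():
    noise_pred = unet(latent_model_input, t,
                      encoder_hidden_states=both_embeddings).sample

# Extrapolate from the unconditional prediction towards the prompted one
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guidance_scale = 8.0
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```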
Okay, we've been hiding away some complexity in this scheduler.step function. So I think we're going to step away from the notebook now and scribble a bit on some paper to try and explain exactly what's going on with sampling, and then we'll come back to the notebook for one final trick.
All right, so here's my take on sampling. To start with, I'd like you to imagine the space of all possible images. This is a very large, high-dimensional space: for a 256 by 256 by 3 image, that's about 200,000 dimensions. My paper, unfortunately, is only two-dimensional, so we're going to have to squish this down a fair bit and use our imagination. Now, if you pick a random point in this space, it's most likely not going to look like anything recognizable — probably just garbled noise. But if we map a real image into the space, it has some fixed point, and a very similar, almost pixel-equivalent image will be very close by. There's a theory you'll hear talked about, the manifold hypothesis, which says that most real images — like a dataset of images — lie on some lower-dimensional manifold within this higher-dimensional space. In other words, if we map a whole bunch of images into the space, they won't fill the whole space; they'll be clustered onto some surface. I've drawn it as a line here because we're stuck with 2D, but think of it as a much higher-dimensional equivalent of a surface. So each of these points is some image. The reason I'm starting with this is that we'd like to generate images — plausible-looking images, not just random nonsense — and we'd like to do that with diffusion models. So where do they come in? Well, we can start with some real image from our training data and push it away from the manifold of plausible images by corrupting it somehow — for example, adding random noise, which is equivalent to moving in some random direction in this space of all possible images. And then we can try to predict, using some model, what that noise looks like: how do I get from here back to a plausible image? What is the noise that's been added? That's our big UNet that does that prediction — that's our diffusion model. In this language, it's playing the role of something like a score function: from wherever I am, what's the noise I need to remove to get back to a plausible image? Okay, so we can train this model with any number of examples, because we can just take our training data, add some random noise, try to predict the noise, and update our model parameters — so we can hopefully learn that function fairly well. Now we'd like to generate with this model. How do we do that? Well, we can start at some random point — let's start over here. And you might think: surely I can just predict the noise, remove it, and get my output image. That would be great, except remember we're now starting from a random point in the space of all possible images; it just looks like garbled nonsense, and the model is trying to say what the noise looks like. You can imagine that during training, the further away we get from our examples, the sparser the training data will have been. And besides, it's not obvious how we got to this noisy point: we could have come from this image over here with a bunch of noise added, or from one over here, or over there. So the model can't make a perfect prediction. At best, it might say "somewhere in that direction" — it could point towards something like the dataset mean, or at least the closer edge of the manifold, but it's not going to give you one perfect solution. And sure enough, that's what we see: if we sample from the diffusion model in a single step — get the prediction and look at what it corresponds to as an image — it's just going to look like a blurry mess, maybe like the mean of the data, or some sort of garbled output. Definitely not a nice image.
So how do we do better? Well, there are a couple of framings for sampling. I'll start with the existing framing you'll see talked about a lot in score-based models and so on, and then we'll talk about another way to think about it. This process of gradually corrupting our images — adding a little bit of noise at a time — people like to describe as a stochastic differential equation: stochastic because there's some randomness (we're picking random amounts and directions of noise to add), and a differential equation because it doesn't describe anything absolute, just how the image should change from moment to moment as it gets more and more corrupted. With that framing, the question "how do I get back to the image?" becomes: solve an ordinary differential equation corresponding to the reverse of this process. You can't solve an ODE in a single step, but you can find an approximate solution, and the more sub-steps you take, the better your approximation. That's what these samplers are doing: okay, we're at this point over here, here's my prediction — rather than moving the whole way there in one go, we remove some of that noise, do a little update, and then get a new prediction. Maybe now the prediction is slightly better — it says "up here" — so we move a little in that direction, and now it makes an even better prediction, because as we get closer to the manifold, with less and less noise and more of an image emerging, the model makes more and more accurate predictions. So over some number of steps, we divide up this process, getting closer and closer until, ideally, we find an image that looks very plausible as our output. That's what a lot of these samplers are doing: effectively trying to solve this ODE in some number of steps by breaking the process up and only moving a small amount at a time.
Now, you get first-order solvers, where all we're doing within each step is moving linearly. This is equivalent to something called Euler's method — or "Yooler's method" if, like me, you've only ever read it — and it's what some of the most basic samplers are doing: just a linear approximation for each of these little steps. But you also get additional approaches. For example, if we make a prediction from here, it might look something like this, and from over here, something like that. So we have our arrow here, but as we move in that direction, the prediction itself changes — there's a derivative of a derivative, a gradient of a gradient. That's where second-order solvers come in: if I know how the prediction changes as I move in this direction — its derivative — then I can account for that curvature when I make my update step, knowing it's going to curve a bit that way. That's where so-called second-order (and higher-order) solvers come from. The upside is that we can take a larger step at a time, because we have a more accurate prediction — not just a first-order linear approximation, but one with the curvature taken into account. The downside is that to estimate that curvature at a given point, we might need to call our model multiple times to get multiple estimates, and that takes time. So: larger steps, but more model evaluations per step. A kind of hybrid approach says: rather than trying to estimate the curvature here, I'll just take a linear step and look at the next prediction, but I'll keep a history of my previous steps. So over here it predicts like this, and now I have a history I can use to better guess the trajectory. Keeping the past three or four or five predictions, which are quite close to each other, maybe tells me something about the curvature here, and again lets me take larger steps. That's where so-called linear multistep sampling comes in: keeping a buffer of past predictions to do a better job than the simple one-step linear, first-order solvers.
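(For concreteness: under the noise-prediction parameterization we've been using — x = x0 + sigma * eps — a single first-order Euler update in sigma space is just a linear step, something like this sketch.)

```python
def euler_step(latents, noise_pred, sigma, sigma_next):
    # Under x = x0 + sigma * eps, the derivative d(latents)/d(sigma)
    # is just the noise estimate, so a first-order step is a straight line
    d = noise_pred
    return latents + (sigma_next - sigma) * d
```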
Okay, so that's the score-based sampling version, and all the variants and innovation come down to things like: how can we do this in as few steps as possible? Maybe a schedule that takes larger steps at first and gradually smaller steps as we get closer; I think there are now also dynamic methods that estimate how many steps are needed, and so on. That's all attacking it from this score-based, ODE-solving framework. But there's another way to think about it as well. We can say: okay, I don't really care about solving this exact reverse ODE — all I care about is ending up with an image on this manifold, a plausible-looking image. And I have a model that estimates how much noise there is: if that estimated noise is very small, I've got a good image; if it's really large, I've got some work to do. This starts to suggest an analogy with training neural networks, because with neural networks, we have the space of all possible parameters, and we're trying to adjust those parameters — not to solve the gradient flow equation (though in theory you could try), but simply to find a minimum, a point where our loss is really good. When we train a neural network, that's exactly what we do: set up an optimizer, take some number of steps trying to reduce some loss, and once that loss levels off — okay, cool, I guess we found a good network. We can apply that same kind of thinking here: I'll start at some point with an estimate of the gradient, maybe pointing over here. That estimate isn't very good — just like the first gradients when training a neural network are pretty bad, because everything is randomly initialized — but hopefully it at least points in a useful direction. Then I'll take a step, and rather than hand-crafting some fancy schedule for the step size, I'll offload that to an off-the-shelf optimizer: a learning rate, maybe something like momentum, determines how big a step I take. Then I update my prediction, take another step in that direction, and so on. So instead of following a fixed schedule, we can use tricks developed for training neural networks — adaptive learning rates, momentum, weight decay — and apply them back to this sampling case. And it turns out this works okay — I've tried it for stable diffusion; it needs some tricks to get working — but it's a slightly different way of thinking about sampling. Rather than relying on a hard-coded ODE solver you've figured out yourself, treat it as an optimization problem: if the model predicts almost no noise, that's good, we're doing well; if it predicts lots of noise, use that as a gradient and take a gradient update step according to our optimizer, trying to converge on a good image as our output. And you can stop early: once the model's predicted amount of noise is sufficiently low — okay, cool, I'm done. I've found that in 10 or 15 steps, you can get some pretty good images out. So that's a different way of viewing it.
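(A very rough sketch of this optimizer framing — emphatically an illustration, not a recipe: the learning rate, momentum, and stopping threshold here are all hypothetical, and as mentioned, it needs extra tricks to work well in practice.)

```python
latents = torch.randn((1, 4, 64, 64), device=device) * scheduler.sigmas.max()
latents.requires_grad_()
opt = torch.optim.SGD([latents], lr=0.1, momentum=0.9)  # hypothetical settings

for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = unet(scheduler.scale_model_input(latents, t), t,
                          encoder_hidden_states=text_embeddings).sample
    opt.zero_grad()
    latents.grad = noise_pred  # treat the predicted noise as the "gradient"
    opt.step()
    if noise_pred.abs().mean() < 0.5:  # hypothetical early-stopping threshold
        break
```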
It's not so popular at the moment, but maybe, hopefully, something we'll see more of — just a different framing. For me, at least, it helps me think about what we're actually doing with the samplers: trying to find a point where the model predicts very little noise, starting from a bad prediction and moving towards better ones, treating the estimated amount of noise as our gradient and iteratively removing bits of it at a time. So I hope that helps elucidate the different kinds of samplers and the goal of the whole thing, and also illustrates why we don't just do this in a single step — otherwise we'd end up with very bad, blurry predictions. All right, I hope that helps. Now we're going to head back to the notebook to talk about our final trick: guidance.
Okay, the final part of this notebook: guidance. How do we add some extra control to this generation process? We already have control via the text, and we've seen how we can modify those embeddings. We have some control via starting from a noisy version of an input image rather than pure noise, to influence the structure. But what if there's something else — we'd like a particular style, or to enforce that the output looks like some input image, or sticks to some color palette? It would be nice to have a way to add this additional control. The way we do it is to define some loss function on the decoded, denoised predicted image — the predicted denoised final output — and use that loss to update the noisy latents as we generate, in a direction that reduces the loss. For the demo, we'll make a very simple loss function: I would like the image to be quite blue. To enforce that, my error will be the difference between the blue channel (blue is the third of the RGB color channels) and 0.9 — the closer all the blue values are to 0.9, the lower my error. That's going to be my guidance loss.
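(The loss itself is tiny — a sketch.)

```python
def blue_loss(images):
    # How far is the blue channel (index 2) from 0.9, on average?
    return torch.abs(images[:, 2] - 0.9).mean()
```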
Then during sampling, everything is the same as before, but every few iterations — you could do it every iteration, but that's a little slow, so here it's every five — I set requires_grad to true on the latents, compute my predicted denoised version, decode it into image space, and calculate my loss using my special blue loss, scaled by some scaling factor. Then I use torch to find the gradient of that loss with respect to those noisy latents, and I modify them to reduce the loss: I subtract the gradient multiplied by sigma squared, because we're working at different noise levels.
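(Inside the loop, the guidance step looks roughly like this sketch; `blue_loss_scale` is a weighting factor you'd tune.)

```python
if i % 5 == 0:
    latents = latents.detach().requires_grad_()
    latents_x0 = latents - sigma * noise_pred           # predicted denoised latents
    denoised_images = vae.decode(latents_x0 / 0.18215).sample / 2 + 0.5
    loss = blue_loss(denoised_images) * blue_loss_scale
    grad = torch.autograd.grad(loss, latents)[0]
    latents = latents.detach() - grad * sigma ** 2      # nudge latents to reduce the loss
```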
If we run this, we should see the same sort of sampling process as before, but with the latents occasionally modified in a direction that reduces the loss — by looking at the gradient of the loss with respect to those latents. And sure enough, we get a very nice blue picture out. If I change the scale down to something lower and run it, the loss term is weaker, so our modifications to the latents are smaller, and we get a much less blue image.
There we go — that's the default image, very red and dark, because the prompt is just "A picture of a campfire". But as soon as we add our additional loss, our guidance, we get something that better matches the extra constraint we've imposed. This is very useful, and not just for making your images blue: like I said, color palettes, or using some classifier model to make the output look like a specific class of image, or using a model like CLIP to associate it with some text. Lots and lots of different things you can do. Now, a few things I should note. One: decoding the image back to image space, calculating the loss, and tracing gradients back is very computationally intensive compared to just working in latent space. We can do it only every fifth iteration to reduce the time, but it's still much slower than generic sampling. Two: we're actually still cheating a little here, because what we should really do is set requires_grad on the latents, use them to make the noise prediction, use that to calculate the denoised version, decode it, calculate our loss, and trace back all the way through the decoder and the UNet to the latents. The reason I'm not doing that is that it takes a lot of memory. The CLIP-guided diffusion notebook from the Hugging Face examples, for instance, does it that way, but has to use tricks like gradient checkpointing to keep the RAM usage under control. For simple losses, it works fine to do it our way, because we're only tracing back through "denoised latents equals latents minus sigma times the noise prediction" — no gradients through the UNet. But if you want more accurate gradients — maybe it's not working as well as you'd hoped — you can do it the other way I described. However you do it, it's a very, very powerful technique: fun to inject some additional control into the generation process by crafting a loss that expresses exactly what you'd like to see. All right, that's the end of the notebook for now. If you have any questions, feel free to reach out — I'll be on the forums, and you can find me on Twitter and so on. But for now, enjoy, and I can't wait to see what you make.