Hello, everyone. My name is Jonathan, and today I'm going to be taking you through this stable diffusion deep dive notebook, looking into the code behind the kind of popular high level APIs and libraries and tools and so on to see what exactly does the generation process look like and how can we modify that?
How do each of the individual components work? So feel free to run along with me. If you haven't run this before, it might take a little while, just because it's downloading these large models (if they aren't already downloaded) and loading them up. So we're going to start by just recreating what it looks like to generate an image using, say, one of the existing pipelines in Hugging Face diffusers.
So we've basically copied the code from the core method of the default stable diffusion pipeline. If you go and view that here, you'll see that we're going to be basically replicating this code, but now in our own notebook, and then we'll slowly understand what each of these different parts is doing.
So we've got some setup, we've got a loop running through a number of sampling timesteps, and we're generating an image. So this is supposed to be a watercolor picture of an otter. And it's very, very cool that this model can just do this. But now we want to know how does that actually work?
What's going on? So the first component is the autoencoder. Now this stable diffusion is a latent diffusion model. And what that means is that it doesn't operate on pixels, it operates in the latent space of some other autoencoder model. In this case, a variational autoencoder that's been trained on a large number of images to compress them down into this latent representation and bring them back up again.
So I have some functions to do that. We're going to look at what it's like in action, just downloading a picture from the internet, opening it up with PIL. So we have this 512 by 512 pixel image, and we're going to load it in. And then we're going to use our function defined above to encode that into some latent representation.
And what this is doing is calling the vae.encode method on a tensor version of the image. And that gives us a distribution. And so we're sampling from that distribution and we're scaling by this factor because that's what the authors of the model did. They scaled the latents down before they fed them to the model.
And so we have to do that scaling, and then the reverse when we're decoding, just to be consistent with that. But the key idea is that we go from this big image down to this 4 by 64 by 64 latent representation. So we've gone from this much larger image down.
And if we visualize the four channels here, these four different 64 by 64 channels, we'll see that each is capturing something of the image. You can sort of see the same shapes and things there. But it's not quite a direct mapping or anything. For example, there's this weirdness going on around the beak.
Some of the channels look slightly stranger than the others. So there's some sort of rich information captured there. And if we decode this back, what we'll see is that the decoded image looks really good. You really have to look closely to tell the difference between our input image here and the decoded version.
So very, very impressive compression, right? This is a factor of 8 in each dimension. So 512 by 512 down to 64 by 64. It's like a factor of 64 reduction in data, but it's still somehow capturing most of that information. It's a very information-rich representation. And this is going to be great because now we can work with that with our diffusion model and get nice high resolution results even though we're only working with these 64 by 64 latents.
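Roughly what those encode and decode helpers look like (a sketch; it assumes `vae` is the autoencoder loaded at the start of the notebook, and 0.18215 is the latent scaling factor used for Stable Diffusion v1):

```python
import torch
from PIL import Image
from torchvision import transforms as tfms

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

def pil_to_latent(input_im):
    # Map a PIL image into the VAE's latent space: to a [-1, 1] tensor, encode,
    # sample from the returned distribution, then apply the authors' scaling.
    with torch.no_grad():
        x = tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1
        latent = vae.encode(x)
    return 0.18215 * latent.latent_dist.sample()

def latents_to_pil(latents):
    # Reverse the process: undo the scaling, decode, and map back to PIL images.
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    return [Image.fromarray((im * 255).round().astype("uint8")) for im in image]
```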
Now, it doesn't have to be 64 by 64. You can go and modify this to say what if this is 640 and encode that down and you'll see that it's just that same factor of 8 reduction. And there we go. Now we have 80 by 64. This just has to be a multiple of 8.
Otherwise, you'll get, I think, an error. Okay. So we have our encoded version of this image, and that's pretty great. The next component we're going to look at is the scheduler, and I'll look more closely at this later. But for now, we're going to focus on this idea of adding noise, right?
So during training, we add some noise to an image, and then the model tries to predict what that noise is, and we're going to do that to different amounts. So here we're going to recreate the same type of schedule, and you can try different schedulers from the library. Oops.
And these parameters here, beta_start, beta_end, beta_schedule, control how much noise was added at different timesteps and how many timesteps were used during training. For sampling, we don't want to have to do a thousand steps, so we can set a new number of timesteps, and then we can see, via the scheduler.timesteps attribute, how these correspond to the original training timesteps.
So here we're going to have 15 sampling steps, and that's going to be equivalent to starting at timestep 999 and just moving linearly down to timestep zero. We can also look at the actual amount of noise present with the sigmas attribute. So again, starting high, moving down, and if you want to see what that schedule looks like, we can plot that here.
And if you want to see the time steps, you'll see that it's just a linear relationship. So there we go. We're going to start at a very high noise value, and we're going to slowly, slowly try and reduce this down until ideally we get an image out. Okay, so the sigma is the amount of noise added.
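In code, recreating that scheduler and inspecting it looks something like this (a sketch; these are the settings used for Stable Diffusion v1 in this notebook):

```python
from diffusers import LMSDiscreteScheduler

# Recreate the noise schedule the model was trained with
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                 beta_schedule="scaled_linear", num_train_timesteps=1000)

scheduler.set_timesteps(15)    # how many sampling steps we want at inference time
print(scheduler.timesteps)     # ~999 down to 0, spread linearly over 15 steps
print(scheduler.sigmas)        # the corresponding noise levels, from high to low
```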
Let's see what that looks like. So I'm going to start with some random noise that's the same shape as my latent representation, my encoded image, and then I'd like to be at the equivalent of sampling step 10 out of 15 here. So I'm going to go and look up what timestep that equates to, and that's going to be one of the arguments that I pass to the scheduler.add_noise function.
So I'm calling scheduler.add_noise, giving it my encoded image, the noise, and what timestep I'd like the noising to be equivalent to. And this is going to give me this noisy but still recognisable version of the image. And you can go and say, okay, what if I look at somewhere earlier in the process, right?
Does it look more noisy? What about right at the beginning, right at the end? Feel free to play around there. Okay, so this adding noise, what are we actually doing? What does the code look like? Let's inspect the function, and you'll see that there's some set up for different types of argument and shapes.
But the key line is just this: noisy_samples is equal to original_samples plus the noise scaled by the sigma parameter. All right, so that's all it is. It's not always the same. Different papers and implementations will add the noise slightly differently. But in this case, that's all it's doing.
So scheduler.add_noise is just adding noise that's the same shape as the latents, scaled by the sigma parameter. Okay, so that's what we're doing. So if we want to start from random noise instead of a noisy image, we're going to scale it by that same sigma value so that it looks the same as an image that's been noised by that amount.
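Written out, that noise addition and the "start from scaled random noise" trick look roughly like this (a sketch, assuming `encoded` is the encoded image from earlier):

```python
# What scheduler.add_noise boils down to for this scheduler:
sampling_step = 10
sigma = scheduler.sigmas[sampling_step]
noise = torch.randn_like(encoded)            # same shape as the latents
noisy_latents = encoded + noise * sigma      # the key line: original + noise * sigma

# Starting from pure noise instead of a noisy image: scale by the starting sigma
# so it matches what the model expects at the first (highest-noise) step.
latents = torch.randn((1, 4, 64, 64), device=torch_device) * scheduler.sigmas[0]
```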
But then before we feed that to the actual model, we have to handle that scaling again. You could do it by hand like this, but now we have this scale_model_input function associated with the scheduler, just to hide that complexity away. Okay, so now we're going to look at the same kind of sampling loop as before.
But we're going to start now with our image: we're going to take our encoded image, we're going to noise it to some timestep, and then we're only going to denoise from there. So in code, we're now preparing our text and everything the same as before, which we'll look at; we're setting our number of inference steps to 50, right, num_inference_steps is equal to 50 here.
And we're saying I'd like to start at the equivalent of step 10 out of 50. So I'll look up what time step that equates to, I'll add noise to my image, equivalent to that step. And then we're going to run through sampling, but this time we're only going to start doing things once we get above that start step.
So I'm going to ignore the first 10 out of 50 steps. And then beyond that, I'm now going to start with this noisy version of my input image. And I'm going to denoise it according to this prompt. And the hope here is that by starting from something that has some of the sort of rough structure and color of that input image, I can kind of fix that into my generation.
But I've got a new prompt, a National Geographic photo of a colorful dancer. And here we go, we see this is the same sort of thing as the parrot. But now we have this completely different actual content thanks to a different prompt. And so that's a fun kind of use of this image to image process.
You might have seen this for taking drawings, adding a bunch of noise and then denoising them into fancy paintings and so on. So again, this is something that there are existing tools for, right: the strength parameter in the image-to-image pipeline is just controlling something like this, what step are we starting at?
How many steps are we skipping? But you can see that this is a pretty powerful technique for getting a bit of extra control over like composition and color and a bit of the structure. Okay, so that's that trick with adding noise and then using that as image to image.
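Put together, the image-to-image loop described above looks roughly like this (a sketch; it assumes `encoded`, `text_embeddings`, `unet` and the `latents_to_pil` helper from earlier, and includes the classifier-free guidance step that's explained properly later on):

```python
scheduler.set_timesteps(50)
start_step = 10                                   # skip the first 10 of 50 steps
guidance_scale = 8

# Noise the encoded image to the level matching the start step
noise = torch.randn_like(encoded)
latents = scheduler.add_noise(encoded, noise,
                              timesteps=torch.tensor([scheduler.timesteps[start_step]]))
latents = latents.to(torch_device).float()

for i, t in enumerate(scheduler.timesteps):
    if i < start_step:                            # ignore the early, high-noise steps
        continue
    latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    latents = scheduler.step(noise_pred, t, latents).prev_sample

result = latents_to_pil(latents)[0]               # decode the final latents
```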
The next big section I'd like to look at is how do we go from a piece of text that describes what we want into a numerical representation that we can feed to the model. So we're going to trace out that pipeline. And along the way, we'll see how we can modify that for a bit of fun.
So step number one, we're taking our prompt and turning it into a sequence of discrete tokens. So here we have 77 discrete tokens, in this case, because that's the maximum length; it's always going to be that, and if your prompt is longer, it'll be truncated. And if we decode these tokens back, we'll see that we have a special token for the start of the text, then a picture of a puppy.
And then the rest is all the same token, which is this kind of end-of-text padding token. Right, so we have this special token for puppy, and this special token with its own meaning, end of text. And the prompts are always going to be padded to be the same length.
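In code, that tokenization step is just (a sketch using the notebook's `tokenizer`):

```python
prompt = "A picture of a puppy"

# Tokenize: always padded or truncated to the tokenizer's max length (77 here)
text_input = tokenizer(prompt, padding="max_length",
                       max_length=tokenizer.model_max_length,
                       truncation=True, return_tensors="pt")
print(text_input.input_ids.shape)            # torch.Size([1, 77])

# Decode the ids back to text to see the special tokens: a start-of-text token,
# the prompt tokens, then end-of-text tokens used as padding
print(tokenizer.decode(text_input.input_ids[0][:8]))
```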
So before, in the code that we were using there, we always jumped straight to these so-called output embeddings, which is what we fed to the model as conditioning. And so somehow this captures some information about this prompt. But now we want to say, well, how do we get there?
How do we get from this sequence of tokens to these output embeddings? What is this text encoder forward pass doing? Right, so we can look at this, and there are going to be multiple steps. And the first is going to be some embeddings. So if we look at text_encoder.text_model.embeddings, we'll see there are a couple of different ones; we have token embeddings, right?
And so this is to take those individual tokens, token 408 or whatever, and map each one into a numerical representation. So here it's a learned embedding. There are about 50,000 rows, one for each token. And for each token, we have 768 values. So that's the embedding of that token.
And if we want to feed one in and see what the embedding looks like, here's the token for puppy. And here's the token embedding, right? 768 numbers that somehow capture that meaning of that token on its own. And we can do the same for all of the tokens in our prompt.
So we feed them through this token embedding layer, and now we get 77 768-dimensional representations, one for each token. Now, these are all on their own: no matter where in the sentence a token is, its token embedding will be the same. So the next step is to add some positional information.
Some models will do this with some kind of fixed positional encoding pattern. But in this case, the positional embedding is just another learned embedding. But now instead of having one embedding for every token, we have one embedding for every position, out of all 77 possible positions. And so just like we did for the tokens, we can feed in the position IDs, one for every possible position, and we'll get back out an embedding for every position in the prompt.
And combining them together, there's again, multiple ways people do this in the literature. But in this case, it's as simple as adding them. That's why they made them the same shape, so that you can just add the two together. And now, these input embeddings have some information related to the token and some related to the position.
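As a sketch, those two embedding lookups and the addition look like this:

```python
emb_layer = text_encoder.text_model.embeddings
token_emb_layer = emb_layer.token_embedding        # one row per vocab entry, 768 wide
position_emb_layer = emb_layer.position_embedding  # one row per position (77), 768 wide

input_ids = text_input.input_ids.to(torch_device)
token_embeddings = token_emb_layer(input_ids)      # [1, 77, 768]

position_ids = torch.arange(tokenizer.model_max_length, device=torch_device).unsqueeze(0)
position_embeddings = position_emb_layer(position_ids)   # [1, 77, 768]

# Combining them is just addition (that's why the shapes match):
input_embeddings = token_embeddings + position_embeddings
# Sanity check: should match the combined embeddings layer doing both steps at once
print(torch.allclose(input_embeddings, emb_layer(input_ids)))
```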
And so, so far, we haven't seen any big model, just two learned embeddings, but this is getting everything ready to feed through that model. And so we can check that this is the same as if we just called the embeddings layer of that model, which is going to do both of those steps at once.
But we'll see just now why we want to separate that out into the individual steps. Okay, so we have these individual tokens, and they have some positional information; we have these final input embeddings. Now we'd like to turn them into something that has a richer representation, thanks to some big transformer model.
And so we're going to feed these through. And I made this little diagram here: each token is going to turn into a token embedding combined with the positional embedding, and then it's going to get fed through this transformer encoder, which is just a stack of these blocks. And so each block has some magic like attention, some feed-forward components; there are additions and normalizations and skips and so on as well.
But we're going to have some number of these blocks all stacked together, and the outputs of each one get fed into the next block and so on. And so we get our final set of hidden states, these encoder hidden states, aka the output embeddings. And this is what we feed to our UNet to make its predictions.
So the way we get this: I just copied the text_encoder.text_model forward method and pulled out the relevant bits. We're going to take in those input embeddings, the combined positional and token embeddings, and we're going to feed that through the text_model.encoder function, with some additional parameters around attention masking and telling it that we'd like to output the hidden states rather than the final outputs.
So if we run this, we can just double check, these embeddings are going to look just like the output embeddings we saw right at the beginning. So we've taken that one step, tokens to output embeddings, and we've broken it down into this number of smaller steps where we have tokenization, getting our token embeddings, combining with position embeddings, feeding it through the model, and then that gives us those final outputs.
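Roughly what that helper looks like (a sketch mirroring the text_encoder.text_model forward pass; note that `_build_causal_attention_mask` is a private helper whose name and location have moved around between transformers versions, so adjust for yours):

```python
def get_output_embeds(input_embeddings):
    # CLIP's text model uses causal attention, so build that mask first
    bsz, seq_len = input_embeddings.shape[:2]
    causal_mask = text_encoder.text_model._build_causal_attention_mask(
        bsz, seq_len, dtype=input_embeddings.dtype)

    encoder_outputs = text_encoder.text_model.encoder(
        inputs_embeds=input_embeddings,
        attention_mask=None,
        causal_attention_mask=causal_mask.to(torch_device),
        output_hidden_states=True,   # we want the hidden states, not the pooled output
    )
    output = encoder_outputs[0]                               # last hidden state
    return text_encoder.text_model.final_layer_norm(output)   # [1, 77, 768]
```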
So why have we gone to all this trouble? Well, there's a couple of things we can do. One demo here: I'm getting the token embeddings, but then I'm looking up where the token for puppy is, and I'm going to replace it with a new set of embeddings. And this is going to be just another learned embedding, of this particular token here, 2368.
So I'm kind of cutting out the token embedding for puppy, slipping in this new set of token embeddings, and I'm going to get some output embeddings which at the start look very similar to the previous ones, in fact identical. But as soon as you get past the position of puppy in that prompt, you're going to see that the rest have changed.
So we've somehow messed with these embeddings by slipping in this new token embedding right at the start. And if we generate with those embeddings, which is what this function is doing, we should see something other than a puppy. And sure enough, drum roll, we don't get a puppy, we get a cat.
And so now you know what token 2368 means. We've managed to slip in a new token embedding and get a different image. Okay, what can we do with this? Why is this fun? Well, a couple of tricks. First off, we could look up the token embedding for skunk, right, which is this number here.
And then instead of now just replacing that in place of puppy, what if I make a new token embedding that's some combination of the embedding of puppy and the embedding of skunk, right? So I'm taking these two token embeddings, I'm just averaging them, and I'm inserting them into my set of token embeddings for my prompt in place of just the word puppy.
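A sketch of that token-embedding surgery (the token ids are looked up rather than hard-coded, and it reuses `position_embeddings` and `get_output_embeds` from the earlier sketches):

```python
token_emb_layer = text_encoder.text_model.embeddings.token_embedding

prompt = "A picture of a puppy"
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                       truncation=True, return_tensors="pt")
input_ids = text_input.input_ids.to(torch_device)

# Look up the ids rather than trusting any hard-coded numbers
puppy_id = tokenizer.encode("puppy", add_special_tokens=False)[0]
skunk_id = tokenizer.encode("skunk", add_special_tokens=False)[0]

# Average the two token embeddings and drop the result in where "puppy" sits
token_embeddings = token_emb_layer(input_ids)
blended = 0.5 * token_emb_layer(torch.tensor(puppy_id, device=torch_device)) \
        + 0.5 * token_emb_layer(torch.tensor(skunk_id, device=torch_device))
token_embeddings[0, torch.where(input_ids[0] == puppy_id)[0]] = blended

# Add position embeddings and push through the encoder as in the previous sketch,
# then generate with the resulting output embeddings
input_embeddings = token_embeddings + position_embeddings
modified_output_embeddings = get_output_embeds(input_embeddings)
```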
And so hopefully when we generate with this, we get something that looks a bit like a puppy, a bit like a skunk. And this doesn't work all the time, but it's pretty cute when it does. There we go, puppy skunk hybrid. Okay, so that's not the real reason we're looking at this.
The main application at the moment of being able to mess with these token embeddings is to be able to do something called textual inversion. So in textual inversion, we're going to have our prompt tokenize it and so on. But here we're going to have a special learned embedding for some new concept, right?
And so the way that's trained is going to be outside of the scope of this notebook. But there's a good blog post and community notebooks and things for doing that. But let's just see this in application here. So there's a whole library of these concepts, the Stable Diffusion concepts library, where you can browse through tons and tons, look, over 1,400 different community-contributed token embeddings that people have trained.
And so I'm going to use this one here, this birb style. Here's some example outputs. And then these are the images it was trained on. So these pretty little bird paintings done by my mother. And I've trained a new token embedding that tries to capture the essence of that style.
And that's represented here in this learned_embeds.bin file. So if you download this, and then upload it to wherever your notebook's running, I have it here, learned_embeds.bin, we can load that in. And you'll see that it's just a dictionary, where we have one key, that's the name of my new style.
And then we have this token embedding, 768 numbers. And so now, instead of slipping in the token embedding for cat, we're going to slip this new embedding, which we've loaded from the file, into this prompt. So: a mouse in the style of puppy. Tokenize, get my token embeddings, and then I'm going to slip in this replacement embedding in place of the embedding for puppy.
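As a sketch, loading the concept and slipping it in looks like this (again reusing the helpers from the earlier sketches):

```python
# Load the downloaded concept file: a dict mapping the concept's token name
# to a single 768-dimensional embedding
learned = torch.load("learned_embeds.bin")
replacement_embedding = list(learned.values())[0].to(torch_device)

prompt = "A mouse in the style of puppy"
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                       truncation=True, return_tensors="pt")
input_ids = text_input.input_ids.to(torch_device)

# Swap the learned concept embedding in where "puppy" sits
token_embeddings = text_encoder.text_model.embeddings.token_embedding(input_ids)
puppy_id = tokenizer.encode("puppy", add_special_tokens=False)[0]
token_embeddings[0, torch.where(input_ids[0] == puppy_id)[0]] = replacement_embedding

# Then exactly as before: add position embeddings, run the encoder, generate
input_embeddings = token_embeddings + position_embeddings
modified_output_embeddings = get_output_embeds(input_embeddings)
```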
And when we generate with that, we should hopefully get a mouse in the style of this kind of cutesy watercolor on rough paper image. And sure enough, that's what we get, very cute little drawing of a mouse in an apron, apparently. Okay, so very, very cool application. Again, there's a nice inference notebook that makes this really easy.
You can say a cat toy in the style of birb style, and you don't have to worry about manually replacing the token embeddings yourself. But it's good to know what the code looks like under the hood, right? How are we doing that? Which stage of the text embedding process are we modifying?
Very fun to get a bit of extra control, and a very useful technique, because now we can kind of augment our model's vocabulary without having to actually retrain the model itself, we're just learning a new token embedding. It's a very, very powerful idea, and really fun to play with.
And like I said, there's thousands of community contributed tokens, but you can also train your own, I think I linked the notebook from here, but it's also in all the docs and so on. Here's the training notebook. Okay, final little trick with embeddings, rather than messing with them at the token embedding level, we can push the whole prompt through that entire process to get our final output embeddings, and we can mess with those at that stage as well.
So here I have two prompts, a mouse and a leopard, tokenizing them, encoding them with a text encoder, so that's that whole process. And these final output embeddings, I'm just going to mix them together according to some factor, and generate with that. And so you can try this with, you know, a cat and a snake.
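A sketch of that embedding mix (the 0.35 mix factor is just an example value):

```python
def embed_prompt(prompt):
    # The whole pipeline: tokenize, then encode to the final output embeddings
    tok = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                    truncation=True, return_tensors="pt")
    with torch.no_grad():
        return text_encoder(tok.input_ids.to(torch_device))[0]

mix_factor = 0.35   # how much "mouse" vs how much "leopard"
mixed_embeddings = (embed_prompt("A mouse") * mix_factor
                    + embed_prompt("A leopard") * (1 - mix_factor))
# Generate using `mixed_embeddings` wherever you'd normally pass the prompt's embeddings
```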
And you should be able to get some really fun, different chimeras and, oops, a snail apparently. Okay, well, I can't spell. But yeah, have fun with that, it doesn't have to be animals. I'd love to see what you create with these weird mixed-up generations. Okay, we should look at the actual model itself, the key UNet model, the diffusion model.
What is it doing? What is it predicting? What is it accepting as arguments? So this is the kind of call signature: we call our UNet's forward pass, and we feed in our noisy latents, the timestep (that's the training-style timestep), and the encoder hidden states, right? So those text embeddings that we've just been having fun with.
So, doing that without any loops or anything: I'm setting up my scheduler, getting my timestep, getting my noisy latents and my text embeddings. And then we're going to get our model prediction. And if you look at the shape of that, you'll see that this prediction has the same shape as the latents.
And given these noisy latents, what the model is predicting is the noise component of that. And actually, it's predicting the noise component scaled by sigma. So if we wanted to see what the original image looks like, we could say, well, the denoised latents are going to be the current noisy latents minus sigma times the model prediction, right?
And so when we're denoising, we're not going to jump straight to that predicted output; we're going to just remove a little bit of the noise at a time. But it might be useful to visualize what that final prediction looks like. So that's what we're doing here: making a folder to store some images, preparing our text, scheduler and input.
And then we're going to do this loop. But now, when we get the model prediction, instead of just updating our latents by one step, we're also going to store some images, right, decoding two images each step: one of the current noisy latents, and one of the predicted completely denoised, like, original sample. So that's this predicted original sample here.
You could also calculate this yourself: the denoised latents are equal to the current latents minus sigma times the noise prediction. All right, so those two should be equivalent. But this loop is going to run, and it's going to save those images to the steps folder, which we can then visualize.
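Inside the loop, that extra bookkeeping looks roughly like this (a sketch; `latents_to_pil` is the decode helper from earlier, and `i`, `t`, `noise_pred` are the loop variables):

```python
import os
os.makedirs("steps", exist_ok=True)

# Inside the sampling loop, after computing noise_pred for step i at timestep t:
sigma = scheduler.sigmas[i]
latents_x0 = latents - sigma * noise_pred              # predicted fully denoised latents
step_output = scheduler.step(noise_pred, t, latents)   # the usual small update
latents = step_output.prev_sample
# step_output.pred_original_sample should match latents_x0 for this scheduler

latents_to_pil(latents)[0].save(f"steps/{i:04}_noisy.png")        # left-hand image
latents_to_pil(latents_x0)[0].save(f"steps/{i:04}_denoised.png")  # right-hand image
```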
And so once this finishes, in a second or two, on the left, we're going to see the kind of noisy input to the model at each stage. And on the right, we're going to see the noisy input minus the noise prediction, right, so the de-noised version. And so we'll just give it a second or two to run.
It's taking it a little bit longer because it's decoding those images each time, saving them. But once this finishes, we should have a nice little preview video. Okay, here we go. So this is the noisy latent. And if we take the model's noise prediction and subtract it from that, we get this very blurry output.
And so you'll see as we play this -- oh, I've left some modifications in from last time, sorry. You see this guidance scale? We'll set that back to, I think it was, eight. In the next section, we'll talk about classifier-free guidance, and so I'd been modifying that example. My bad.
I might cut this out of the video. We'll see. So I've got to wait a few seconds again for that to generate. And I'll do so as patiently as I can. Okay, so here we go again, the noisy input, the predicted de-noised version. And you can see at the start, it's very blurry.
But over time, it gradually converges on our final output. And you'll notice that on the left, these are the latents as they are at each step. They don't change particularly drastically, just a little bit at a time. But at the start, when the model doesn't have much to go on, its predictions do change quite a bit at each step, right?
It's much less well-defined. And then as we go forward in time, it gets more and more refined, better and better predictions. And so it's got a more accurate estimation of the noise to remove. And we remove that noise gradually until we finally get our output. Quite fun to visualize the process.
And hopefully that helps you understand why we don't just make one prediction and do it in one step, right? Because we get this very blurry mess. But instead, we do this kind of iterative sampling there, which we'll talk about very shortly. Before then, though, the final thing I should mention, classifier-free guidance.
What is that? Well, like you saw when I accidentally generated the version with a much lower guidance scale, the way classifier-free guidance works is that in all of these loops, we haven't actually been passing one set of noisy latents through the model; we've been passing two identical copies. And as our text embeddings, we've not just been passing the embeddings of our prompt, right?
These ones here, we've been concatenating them with some unconditional embeddings as well. And the unconditional embeddings are just the embeddings of a blank prompt, right? No text whatsoever, so just all padding, passed through the text encoder. So when we get our predictions here, having given it two sets of latents and two sets of text embeddings, we're going to get out two predictions for the noise.
So we're splitting that apart: one prediction for the unconditional, like, no-prompt version, and one for the prediction based on the prompt. And so what we can do now is say, well, my final prediction is going to be the unconditional version plus the guidance scale times the difference, right?
So if you think about it: if I predict without the prompt, I'm predicting here. If I predict with the text conditioning, with the prompt, I get this prediction instead. And I'd like to move more in that direction, I'd like to push it even further towards the prompt version and beyond.
So this guidance scale can be larger than one, to push it even more in that direction. And this, it turns out, is kind of key for getting it to follow the prompt nicely. I think it was first brought up in the GLIDE paper; AI Coffee Break on YouTube has a great video on that.
But yeah, really useful trick or really neat hack, depending on who you talk to. But it does seem to work. And the higher the guidance scale, the more the output will try and look like the prompt, kind of in the extreme, versus with a lower guidance scale it might just try and look like a generic good picture.
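In code, classifier-free guidance looks roughly like this (a sketch; the prompt is just an example):

```python
prompt = ["A watercolour painting of an otter"]
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                       truncation=True, return_tensors="pt")
uncond_input = tokenizer([""] * len(prompt), padding="max_length",
                         max_length=tokenizer.model_max_length, return_tensors="pt")
with torch.no_grad():
    cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])  # both go through in one batch

# Inside the sampling loop, for timestep t:
latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
with torch.no_grad():
    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

guidance_scale = 8   # >1 pushes further in the "towards the prompt" direction
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```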
Okay, we've been hiding away some complexity in terms of this scheduler dot step function. So I think we're going to step away from the notebook now and scribble a bit on some paper to try and explain exactly what's going on with sampling and so on. And then we'll come back to the notebook for one final trick.
All right, so here's my take on sampling. And to start with, I'd like you to imagine the space of all possible images. So this is a very large, high-dimensional space; for a 256 by 256 by 3 image, that's about 200,000 dimensional. And my paper, unfortunately, is only two-dimensional.
So we're going to have to squish this down a fair bit and use our imagination. Now, if you just look at a random point in this space, it's most likely not going to look like anything recognizable; it'll probably just look like garbled noise. But if we map an image into the space, it lands on some specific point.
And a very similar image almost pixel equivalent, it's going to be very close by. Now, there's this theory that you'll hear talked about called manifold theory, which says that for most real images, like a data set of images, these are going to lie on some lower dimensional manifold within this higher dimensional space, right?
In other words, if we map a whole bunch of images into the space, they're not going to fill the whole space, they're going to be kind of clustered onto some surface. Now, I've drawn it as a line here because we stuck with 2D, but this is a much higher dimensional plane equivalent.
Okay, so each of these ones here is some image. And the reason that I'm starting with this is because we'd like to generate images, we'd like to generate plausible looking images, not just random nonsense. And so we'd like to do that with diffusion models. So where did they come in?
Well, we can start with some image here, some real image from our training data. And we can push it away from the manifold of like plausible existing images by corrupting it somehow. So for example, just adding random noise, that's equivalent to like moving in some random direction in this space of all possible images.
And so that's going to push the image away. And then we can try and predict, using some model, what this noise looks like, right? How do I go from here back to a plausible image? What is this noise that's been added? And so that's going to be our big UNet that does that prediction, that's going to be our diffusion model, right?
And so that's, in this language, going to be called something like a score function, right? How do I get from wherever I am? What's the noise that I need to remove to get back to a plausible image? Okay, so that's all well and good. We can train this model with any number of examples, because we can just take our training data, add some random noise, try and predict that noise, and update our model parameters.
So we can hopefully learn that function fairly well. Now we'd like to generate with this model, right? So how do we do that? Well, we can start at some random point, right? Like, let's start over here. And you might think, well, surely I can just now predict the noise, remove that, and then I get my output image.
And that's great, except that you've got to remember we're now starting from a random point in the space of all possible images; it just looks like garbled nonsense. And the model's trying to say, well, what does the noise look like? And you can imagine that during training, the further away we get from our examples, the sparser the training data will have been.
But also, it's not like it's very obvious how we got to this noisy version, right? We could have come from this image over here, added a bunch of noise. We could have come from one over here, one over here. And so this model's not going to be able to make a perfect prediction.
At best, it might say, well, somewhere in that direction, right? It could point towards something like the dataset mean, or at least the edge that's closest. But it's not going to be able to perfectly give you one nice solution. And sure enough, that's what we see: if we sample from the diffusion model in a single step, we get the prediction, look at what that corresponds to as an image, and it's just going to look like a blurry mess, maybe like the mean of the data or, you know, some sort of garbled output, definitely not a nice image.
So how do we do better? And the idea of sampling is to say, well, there's a couple of framings. So I'll start with the existing framing that you'll see talked about a lot, score-based models and so on. And then we'll talk about some other ways to think about it as well.
So this process of gradually corrupting our images away, adding a little bit of noise at a time, people like to talk of this as a stochastic differential equation. Stochastic because there's some randomness, right, we're picking random amounts of noise, random directions to add, and a differential equation because it's not talking about anything absolute, just how we should change this from moment to moment to get more and more corrupted, right?
So that's why it's a differential equation. And with that framing, the question of, well, how do I go now back to the image? That's framed as solving an ordinary differential equation that corresponds to like the reverse of this process. You can't solve ODEs in a single step, but you can find an approximate solution.
And the more sort of sub-steps you take, the better your approximation. And so that's what these samples are doing, given like, okay, we set this image over here, here's my prediction, rather than moving the whole way there in one go, we'll remove some of that noise, right, do a little update, and then we'll get a new prediction, right?
And so maybe now the prediction is slightly better. It says up here. So we move a little bit in that direction. And now it makes an even better prediction, because as we get closer to the manifold, right, as we have less and less noise, and more and more of like some image emerging, the model is able to get more and more accurate predictions.
And so over some number of steps, we divide up this process, and we get closer and closer until we ideally find some image that looks very plausible as our output. And so that's what we're doing here with a lot of these samplers: they're effectively trying to solve this ODE in some number of steps by, yeah, breaking the process up and only moving a small amount at a time.
Now, you get sort of first-order solvers, right, where all we're doing is just moving linearly within each step. And this is equivalent to something called Euler's method, or "Yooler's" method if, like me, you've only ever read it written down. And this is what some of the most basic samplers are doing, just linear approximations for each of these little steps.
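Not any particular library's exact code, but the shape of a single first-order (Euler) update in the sigma parameterisation we've been using is something like:

```python
# Move linearly along the current estimated direction, from this step's sigma
# down to the next one's (sigma_next is 0 at the very last step).
sigma = scheduler.sigmas[i]
sigma_next = scheduler.sigmas[i + 1]

denoised = latents - sigma * noise_pred              # the model's current guess at x0
d = (latents - denoised) / sigma                     # direction to move in (the "score")
latents = latents + d * (sigma_next - sigma)         # one small linear step
```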
But you also get additional approaches. So for example, maybe if we were to make a prediction from here, it might look something like this. And if we were to make a prediction from over here, it might look something like that. So we have our prediction here, but as we move in that direction, the prediction itself is also changing, right?
So there's like a derivative of a derivative, a gradient of a gradient. And that's where this second order solver comes in and says, well, if I know how this prediction changes as I move in this direction, like what is the derivative of it, then I can kind of account for that curvature when I make my update step, and maybe know that it's going to curve a bit in that direction.
And so that's where we get things like these so-called second-order solvers and higher-order solvers. The upside of this is that we can, you know, do a larger step at a time, because we have a more accurate prediction; we're not just doing a first-order linear approximation, we have this kind of curvature taken into account.
The downside is that to estimate that curvature for a given point, we might need to call our model multiple times to get multiple estimates. And so that takes time. So we can take a larger step, but we need more model evaluations per step. A kind of hybrid approach is to say, well, rather than trying to estimate the curvature here, I might just take a linear step, look at the next prediction, but I'll keep a history of my previous steps.
And so then over here, it predicts like this. So I have now this history. And I'm going to use that to better guess what this trajectory is. So I might keep a history of the past, you know, three or four or five predictions, and know that since they're quite close to each other, maybe that tells me some information about the curvature here.
And I can use that again, take larger steps. And so that's where we see the so-called linear multi-step sampling coming in, just keeping this buffer of past predictions to try and do a better job estimating than the simple one-step linear type first order solvers. Okay, so that's the score-based sampling version.
And all of the variance and innovation comes down to things like, how can we do this in as few steps as possible? Maybe we have a schedule that says we take larger steps at first and then gradually smaller steps as we get closer. There's, I think, now some dynamic methods and can we estimate how many steps we need to take, and so on.
So that's all trying to attack it from this kind of score-based ODE solving framework. But there's another way to think of this as well. And that's to say, okay, well, I don't really care about solving this exact reverse ODE, right? All I care about is that I end up with an image that's on this manifold, like a plausible looking image.
And so I have a model that estimates how much noise there is, right? And if that noise is very small, then that means I've got a good image. And if that noise is really large, then that means I've got some work to do. And so this kind of starts bringing up some analogies to training neural networks, because in neural networks, we have the space of all possible parameters.
And we're trying to adjust those parameters not to solve the gradient flow equation, right? Although that's, you know, possible in theory that you might try and do that. We don't care about that, we just want to find a minima, we want to find a point where our loss is really good.
And so when we're training a neural network, that's exactly what we do. We set up an optimizer, and we take some number of steps trying to reduce some loss. And once that loss has, you know, reduced over time and levels off, okay, cool, I guess we found a good neural network.
And so we can apply that same kind of thinking here to say, all right, I'll start at some point. And I'll have an estimate of the gradient, right, like maybe pointing over here. But remember, that estimate is not very good, just like the first gradients estimated when training a neural network are pretty bad, because it's all just these randomly initialized weights, but hopefully it at least points in a useful direction.
So then I'll take some step, and for the length of the step, I won't try and do some fancy schedule, I'll just offload this to a sort of off-the-shelf optimizer, right? So I have some learning rate, maybe something like momentum, that determines how big of a step I take.
And then I update my prediction, right, take another step in that direction, and so on. So now, instead of following a fixed schedule, we can use tricks that have been developed for training neural networks, right, adaptive learning rates, momentum, weight decay, and so on. And we can apply them back to this kind of sampling case.
And so it turns out this works okay, I've tried this for stable diffusion, needs some tricks to get it working. But it's a slightly different way of thinking about sampling, rather than relying on sort of a hard coded ODE solver that you figured out yourself, just saying, why don't we treat this like an optimization problem, where if the model predicts almost no noise, that's good, we're doing a good job.
And if the model predicts lots of noise, then we can use that as a gradient, take a gradient update step according to our optimizer, and try and sort of converge on a good image as our output. And, you know, you can stop early: once your model's prediction for the amount of noise is sufficiently low, okay, cool, I'm done.
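Purely as an illustration of that framing (this is not the notebook's code, the settings are made up, and as mentioned it needs extra tricks to actually work well), a sketch might look like:

```python
# Treat the latents as parameters, use the predicted noise as a surrogate gradient,
# and let an off-the-shelf optimizer choose the step size. Assumes `text_embeddings`
# here is just the prompt's embeddings (no classifier-free guidance, batch size 1).
latents = torch.nn.Parameter(
    torch.randn((1, 4, 64, 64), device=torch_device) * scheduler.sigmas[0])
opt = torch.optim.Adam([latents], lr=0.1)     # hypothetical settings, would need tuning
threshold = 0.5                               # hypothetical "little noise left" cutoff

for i, t in enumerate(scheduler.timesteps):
    with torch.no_grad():
        latent_in = scheduler.scale_model_input(latents.detach(), t)
        noise_pred = unet(latent_in, t, encoder_hidden_states=text_embeddings).sample
    opt.zero_grad()
    latents.grad = noise_pred                 # "how much noise is left" as the gradient
    opt.step()                                # the optimizer decides how far to move
    if noise_pred.abs().mean() < threshold:   # stop early once the model sees little noise
        break
```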
And so I found, you know, in 10 or 15 steps, you can get some pretty good images out. Yeah, so that's a different way of viewing it. Not so popular at the moment, but maybe, hopefully, something we'll see more of. Yeah, just a different framing. And for me, at least, it helps me think about what we're actually doing with the samplers: trying to find a point where the model predicts very little noise.
So, starting from a bad prediction and moving towards it getting better and better, by looking at this estimated amount of noise as our sort of gradient and following it, just kind of iteratively removing bits at a time. So I hope that helps elucidate the different kinds of samplers, and the goal of that whole thing, and also illustrates at least why we don't just do this in a single step, right?
Why we need some sort of iterative approach, otherwise, we'd end up with just very bad blurry predictions. All right, I hope that helps. Now we're going to head back to the notebook to talk about our final trick of guidance. Okay, the final part of this notebook, guidance, how do we add some extra control to this generation process, right?
So we already have control via the text, and we've seen how we can modify those embeddings. We have some control via starting at a noisy version of an input image, rather than pure noise, to kind of control the structure. But what if there's something else: what if we'd like a particular style, or to enforce that the output looks like some input image, or sticks to some color palette? It would be nice to have some way to add this additional control.
And so the way we do this is to define some loss function on the decoded, denoised predicted image, right, the predicted denoised final output, and use that loss to update the noisy latents as we generate, in a direction that tries to reduce that loss. So for a demo, we're going to make a very simple loss function.
I would like the image to be quite blue. And to enforce that, my error is going to be the difference between the blue channel (red, green, blue: blue is the third of the color channels) and 0.9. So the closer all the blue values are to 0.9, the lower my error will be.
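As a sketch, that loss is just:

```python
def blue_loss(images):
    # images: a batch of decoded images, shape [batch, 3, height, width], values roughly in [0, 1].
    # Penalise how far the blue channel (index 2) is from 0.9, averaged over the image.
    return torch.abs(images[:, 2] - 0.9).mean()
```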
So that's going to be my kind of guidance loss. And then during sampling here, what I'm going to do is keep everything the same as before, but every few iterations (you could do it every iteration, but that's a little slow, so here it's every five iterations) I'm going to set requires_grad equals True on the latents.
I'm then going to compute my predicted denoised version. I'm going to decode that into image space, and then I'm going to calculate my loss using my special blue loss and scale it with some scaling factor. Then I'm going to use torch to find the gradient of this loss with respect to those latents, those noisy latents.
And I'm going to modify them, right? I want to reduce the loss, so I'm going to subtract this gradient, multiplied by sigma squared because we're going to be working at different noise levels. And so if we run this, we should see, hopefully, the same sort of sampling process as before, but we're also occasionally modifying our latents by looking at the gradient of the loss with respect to those latents and updating them in a direction that reduces that loss.
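A sketch of that guidance step inside the sampling loop (`blue_loss_scale` is the scaling factor mentioned above; the value here is just illustrative, and `noise_pred` is the usual classifier-free guided prediction for this step):

```python
blue_loss_scale = 100   # illustrative value; tune to taste

if i % 5 == 0:
    latents = latents.detach().requires_grad_()
    sigma = scheduler.sigmas[i]
    denoised = latents - sigma * noise_pred                            # predicted clean latents
    denoised_images = vae.decode(denoised / 0.18215).sample / 2 + 0.5  # to image space, ~[0, 1]
    loss = blue_loss(denoised_images) * blue_loss_scale                # how "un-blue" is it?
    grad = torch.autograd.grad(loss, latents)[0]                       # flows through the VAE, not the UNet
    latents = latents.detach() - grad * sigma ** 2                     # nudge latents to reduce the loss

latents = scheduler.step(noise_pred, t, latents).prev_sample           # then the normal update
```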
And sure enough, we get a very nice blue picture out. And if I change the scale here down to something lower and run it, we'll see that scale is lower. So the loss is lower. So our modifications to the latents are smaller. We'll see that we get out a much less blue image.
There we go. So that's the default image, very red and dark, because the prompt is just a picture of a campfire. But as soon as we add our additional loss, our guidance, we're going to get out something that better matches that additional constraint we've imposed, right? So this is very useful, not just for making your images blue, but, like I said, for color palettes, or using some classifier model to make it look like a specific class of image, or using a model like CLIP to, again, associate it with some text.
So lots and lots of different things you can do. Now, a few things I should note. One: we're decoding the image back to image space, calculating our loss and then tracing gradients back. That's very computationally intensive compared to just working in latent space. And so we only do it every fifth iteration to reduce the time, but it's still much slower than just your generic sampling.
And then also, we're actually still cheating a little bit here, because what we should really do is set requires_grad equals True on the latents and then use those to make our noise prediction, use that to calculate the denoised version and decode that, calculate our loss and trace back all the way through the decoder and the process and the UNet back to the latents, right?
The reason I'm not doing that is because it takes a lot of memory. So you'll see, for example, the CLIP-guided diffusion notebook from the Hugging Face examples does it that way, but they have to use tricks like gradient checkpointing and so on to keep the RAM usage under control.
And for simple losses, it works fine to do it this way, because now we're just tracing back through denoised latents = latents minus sigma times the noise prediction, right? So we don't have to trace any gradients back through the UNet. But if you wanted more accurate gradients, maybe because it's not working as well as you'd hoped, you can do it that other way that I described.
But however you do it, very, very powerful technique, fun to be able to again inject some additional control into this generation process by crafting a loss that expresses exactly what you'd like to see. All right, that's the end of the notebook for now. If you have any questions, feel free to reach out to me, I'll be on the forums and you can find me on Twitter and so on.
But for now, enjoy and I can't wait to see what you make.