Hi, we are here for lesson 24. And once again, it's becoming a bit of a tradition now: we're joined by Johno and Tanishq, which is always a pleasure. Hi, Johno. Hi, Tanishq. Hello. Another great lesson. Yeah, are you guys looking forward to finally actually completing Stable Diffusion, at least the unconditional Stable Diffusion?
Well, I should say no, even conditional. So conditional Stable Diffusion from scratch, except for the CLIP bit. We should be able to finish today, time permitting. Oh, that's exciting. That is exciting. All right. Let's do it. Jump in any time. We've got things to talk about. So we're going to start with a very hopefully named notebook, 26 diffusion U-Net.
And what we're going to do in 26 diffusion U-Net is unconditional diffusion from scratch. And there aren't really too many new pieces, if I remember correctly. So all the stuff at the start we've already seen. And when I wrote this, it was before I had noticed that the Karras approach was doing less well than the regular cosine schedule approach.
So I'm still using Karras noisify, but this is all the same as in the Karras notebook, which was 23. Okay, so we can now create a U-Net that is based on what diffusers has, which is in turn based on lots of other prior art. I mean, the code's not at all based on it, but the basic structure is going to be the same as what you'll get in diffusers.
The convolution we're going to use is the same as the final kind of convolution we used for Tiny ImageNet, which is what's called the preactivation convolution. So the normalization and activation happen first, and the convolution itself happens at the end. So this is a preact convolution. Then I've got a U-Net ResNet block.
I kind of wrote this before I actually did the preact version of Tiny ImageNet. So I suspect this is actually the same, quite possibly exactly the same, as the Tiny ImageNet one. So maybe there's nothing U-Net-specific about this; it's really just a preact conv and a preact ResNet block.
So we've got the two convs as per usual and the identity conv. Now there is one difference, though, from what we've seen before for ResNet blocks, which is that this ResNet block has no option to do downsampling, no option to use a stride. It's always stride one, which is our default.
The reason for that is that when we get to the thing that strings a bunch of them together, which will be called a down block, that's where you have the option to add downsampling. And if you do add downsampling, we're going to add a stride-2 convolution after the res blocks.
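To see what that optional stride-2 conv does to the spatial size, here's a small arithmetic sketch using the standard convolution output-size formula. The 3x3 kernel with padding 1 and the 32x32 starting size are assumptions for illustration, not values taken from the notebook.

```python
def conv_out_size(size: int, ks: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Standard output-size formula for a convolution along one spatial dimension."""
    return (size + 2 * padding - ks) // stride + 1


# A stride-1 3x3 conv with padding 1 leaves the size unchanged (what the res blocks do).
print(conv_out_size(32, stride=1))

# A stride-2 conv at the end of each down block halves it.
sizes = [32]
for _ in range(3):
    sizes.append(conv_out_size(sizes[-1], stride=2))
print(sizes)  # [32, 16, 8, 4]
```

So the res blocks only ever change the number of channels, and every resolution change is concentrated in that one extra conv.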
And that's because this is how diffusers and Stable Diffusion do it. I haven't studied this closely. Tanishq or Johno, do either of you know where this idea came from, or why? I'd be curious. You know, the difference is that normally we would have average pooling here in this connection.
But yeah, this different approach is what we're using. A lot of the history of the diffusers unconditional U-Net is about being compatible with the DDPM weights that were released and some follow-on work from that. And I know that Improved DDPM and these others all kind of built on that same sort of U-Net structure, even though it's slightly unconventional if you're coming from a normal computer vision background.
And do you recall where the DDPM architecture came from? Because some of the ideas came from some of the earlier U-Nets, but I don't know about DDPM. Yeah, they had something called an efficient U-Net that was inspired by some prior work; I can't remember the lineage. Anyway, I just think the diffusers one has since become more flexible, you know, like you can pass in parameters to control some of this stuff.
But yeah, we shouldn't assume that this is the optimal approach, I suppose. But I will dig into the history and try to find out what ablation studies have been done. For those of you who haven't heard of ablation studies, that's where you try a bunch of different ways of doing things, score which ones work better and which work less well, and create a table of all of those options.
And where you can't find ablation studies for something you're interested in, often that means that maybe not many other options were tried, because researchers don't have time to try everything. Okay, now the U-Net. Let's go back to the U-Net that we used for super resolution, to our most basic version.
What we did was, as we went down through the layers in the downsampling section, we stored the activations at each point in a list called layers. And then as we went through the upsampling, we added those downsampling activations back into the upsampling activations. So that's the basic structure of a U-Net.
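Here's a toy, pure-Python version of that store-then-add pattern, just to show the bookkeeping. The 1D lists and the `down`/`up` helpers (subsampling and nearest-neighbour repetition) are hypothetical stand-ins for real convolutional layers, not the lesson's code.

```python
def down(xs):
    """Toy downsampling: keep every other element."""
    return xs[::2]


def up(xs):
    """Toy upsampling: repeat each element twice (nearest neighbour)."""
    return [v for v in xs for _ in range(2)]


def toy_unet(xs, depth=2):
    layers = []
    for _ in range(depth):
        layers.append(xs)                      # store the activation before downsampling
        xs = down(xs)
    for _ in range(depth):
        xs = up(xs)
        skip = layers.pop()                    # last stored = same resolution as xs now
        xs = [a + b for a, b in zip(xs, skip)] # add the cross connection back in
    return xs


print(toy_unet([1, 2, 3, 4]))  # [3, 4, 7, 8]
```

Because the list is used as a stack, the activation popped at each upsampling step is always the one saved at the matching resolution. Swapping the addition for a concatenation along the channel axis gives the variant this lesson uses.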
You don't have to add; you can also concatenate. Concatenating is, I think, more common nowadays, and I think the original U-Net might have been concatenating, although for super resolution, just adding seems pretty sensible. So we're going to concatenate. But we're also going to exercise our Python muscles a little bit, to look at some interesting ways to make it a little easier to turn different downsampling backbones into U-Nets.
And we'll also use that as an opportunity to learn a bit more Python. So what we're going to do is create something called a saved res block and a saved convolution. Our down blocks contain a certain number of res block layers, followed by this optional stride-2 conv.
We're going to use saved res blocks and saved convs. These are going to be the same as a normal convolution and a normal U-Net res block, but they're going to remember their activations. The reason for that is that later on in the U-Net, we're going to go through and grab all those saved activations at once into a big list.
So then we basically don't have to think about it. To do that, we create a class called SaveModule. All SaveModule does is call forward to grab the res block or conv result, and store it before returning it. Now that's weird, because hopefully you know by now that super calls the thing in the parent class, but SaveModule doesn't have a parent class.
So this is what's called a mixin, and it uses something called multiple inheritance. A mixin, as it describes here, is a design pattern, which is to say it's not particularly a part of Python per se; it's a design pattern that uses multiple inheritance. Multiple inheritance is where you can say, oh, this class called SavedResBlock inherits from two things, SaveModule and UnetResBlock.
What that means is that all of the methods in both of these end up in here. Now that would be simple enough, except there's a bit of a complication here, which is that UnetResBlock contains forward and SaveModule also contains forward. So it's all very well just combining the methods from both of them.
But what if they have the same method? The answer is that the one you list first wins, and when it calls super().forward, it's actually calling forward in the class listed after it. And that's why it's a mixin: it's mixing this functionality into that functionality. So it's a UnetResBlock where we've customized forward.
It calls the existing forward and also saves the result. You see mixins quite a lot in the Python standard library. For example, some of the basic HTTP server and networking-with-threads classes (socketserver's ThreadingMixIn, for instance) use multiple inheritance with this mixin pattern. So with this approach, the actual implementation of SavedResBlock is nothing at all.
pass means don't do anything. So this is literally just a class with no implementation of its own, other than being a mixin of these two classes. Similarly, SavedConv is an nn.Conv2d with SaveModule mixed in. So now we can call a SavedResBlock just like a UnetResBlock, and a SavedConv just like an nn.Conv2d.
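The mixin mechanics can be checked without PyTorch. In this sketch `Block` and `SavedBlock` are hypothetical stand-ins for UnetResBlock and SavedResBlock; `Block.__call__` imitates the way nn.Module routes a call to `forward`.

```python
class Block:
    """Hypothetical stand-in for a layer such as UnetResBlock or nn.Conv2d."""
    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale

    def __call__(self, x):          # like nn.Module: calling the object runs forward
        return self.forward(x)


class SaveModule:
    """Mixin: no parent class of its own, so super() resolves via the final class's MRO."""
    def forward(self, x, *args, **kwargs):
        self.saved = super().forward(x, *args, **kwargs)
        return self.saved


class SavedBlock(SaveModule, Block):   # mixin listed first, so its forward wins
    pass


b = SavedBlock(3)
print(b(2), b.saved)                              # 6 6
print([c.__name__ for c in SavedBlock.__mro__])   # ['SavedBlock', 'SaveModule', 'Block', 'object']
```

The method resolution order is what makes `super()` inside SaveModule land on Block's forward, even though SaveModule itself inherits only from object.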
But that object is going to end up with its activations in the .saved attribute. So now a downsampling block is just a sequential of saved res blocks. As per usual, the very first one takes the number of input channels to start with, and they all output nf, the number of filters; after the first one, the inputs will also be nf, because the first one has changed the number of channels.
And we'll do that for however many layers we have. Then at the end of that process, as we discussed, we add to that sequential a saved conv with stride 2 to do the downsampling, if requested. So we end up with a single nn.Sequential for a down block.
And then an up block is going to look very similar, but instead of using an nn.Conv2d with stride 2, upsampling is done with a sequence starting with an upsampling layer. Literally all that does is duplicate every pixel four times into a little two-by-two grid. That's what an upsampling layer does, nothing clever.
That's followed by a stride-1 convolution, which allows it to adjust some of those pixels as necessary with a simple three-by-three conv. So where a stride-2 conv does the downsampling, this is kind of the rough equivalent for upsampling. There are other ways of doing upsampling.
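To make the "duplicate every pixel into a two-by-two grid" idea concrete, here's a pure-Python sketch of nearest-neighbour upsampling on a nested-list image. In PyTorch this is what nn.Upsample(scale_factor=2) does with its default 'nearest' mode; the helper name below is hypothetical.

```python
def upsample_nearest(img, scale=2):
    """Nearest-neighbour upsampling: each pixel becomes a scale x scale block."""
    out = []
    for row in img:
        wide = [p for p in row for _ in range(scale)]   # repeat each pixel horizontally
        out.extend(list(wide) for _ in range(scale))    # repeat the widened row vertically
    return out


print(upsample_nearest([[1, 2], [3, 4]]))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```

The 3x3 stride-1 conv that follows is what lets the network smooth out those blocky duplicates.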
This is just the one that Stable Diffusion uses. So an up block looks a lot like a down block. As before, we create a bunch of U-Net res blocks. These are not saved res blocks, of course: the upsampling path of the U-Net is where we use the saved results, not where we save them.
So we just use normal res blocks. But what we're going to do now is, as we go through each ResNet, call it not just on our activations: we concatenate them with whatever was stored during the downsampling path. So this is going to be a list of all of the things stored in the downsampling path.
It'll be passed to the up block, and .pop will grab the last one off that list, concatenate it with the activations, and pass that to the ResNet. So we need to know how many filters, how many activations, there were in the downsampling path. That's stored here.
This is the previous number of filters in the downsampling path. The res block needs to accept those channels in addition to the normal number, so that's what happens there. And we do that for each layer as before, and then at the end add an upsampling layer if it's been requested.
So it's a boolean. OK, so that's the upsampling block. Does that all make sense so far? Yeah, it looks good. OK. So the U-Net now is going to look a lot like our previous U-Net. We're going to start out, as we tend to, with a convolution to allow us to create a few more channels.
What we pass to our U-Net is just how many channels are in your input image and how many channels are in your output image. So for normal full-color images, that'll be 3 and 3. Then, how many filters are there for each of those ResNet blocks, up blocks and down blocks you've got?
And in the downsampling, how many layers are there in each block? So the conv will go from in channels, so 3, to nfs[0], which is the number of filters in the Stable Diffusion model. They're pretty big by default, as you can see. And that's the number of channels we create, which is very redundant, in that this is a 3 by 3 conv.
So each output only sees 3 by 3 by 3 channels, which is 27 inputs, and there are 224 outputs. So it's not doing useful computation, in a sense. It's just giving it more space to work with down the line, which I don't think makes sense, but I haven't played with it enough to be sure.
Normally we would do a few res blocks or something at this level to increase it more gradually, because this feels like a lot of wasted effort. But yeah, I haven't studied that closely enough to be sure. So, Jeremy, just a tweak: this is the default, I think, the default settings for the unconditional U-Net in diffusers.
But the Stable Diffusion U-Net actually has even more channels. It has 320, 640, and then 1,280, 1,280. Cool. Thanks for clarifying. And yes, it's the unconditional one, which is what we're doing right now. That's a great point. Okay. So then we go through all of our numbers of filters, and actually the first down block goes from 224 to 224.
That's why it keeps track of this stuff. And then the second down block is 224 to 448, then 448 to 672, and then 672 to 896. So that's why we have to keep track of these things. So we have a sequential for our down blocks, and we just add a down block.
The very last one doesn't have downsampling, which makes sense, right? Because there's nothing after the very last one, so there's no point downsampling. Other than that, they all have downsampling. And then we have one more res block in the middle, which... is that the same as what we did?
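The channel bookkeeping described above can be reproduced with plain arithmetic. This sketch only mirrors the numbers quoted in the lesson; the variable names are mine, and the parameter count assumes a standard 3x3 conv with a bias per output channel.

```python
nfs = (224, 448, 672, 896)   # default filter counts discussed in the lesson
in_ch = 3                    # full-colour input image

# Stem conv: every output position sees 3 x 3 pixels x 3 channels = 27 inputs.
inputs_per_output = 3 * 3 * in_ch
stem_params = nfs[0] * inputs_per_output + nfs[0]   # weights + one bias per output channel

# Down blocks: (input channels, output channels, does it downsample?)
down_blocks = []
nf = nfs[0]
for i, next_nf in enumerate(nfs):
    prev_nf, nf = nf, next_nf
    add_down = i != len(nfs) - 1    # nothing comes after the last block, so no downsampling
    down_blocks.append((prev_nf, nf, add_down))

print(inputs_per_output, stem_params)
for blk in down_blocks:
    print(blk)
```

The first block maps 224 to 224 because the stem conv has already produced nfs[0] channels, and only the final block skips the stride-2 conv.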
Okay. So we didn't have a middle res block in our original U-Net here. What about this one? Do we have any mid blocks? No, we haven't. Okay. But I mean, it's just another res block that you do after the downsampling. And then we go through the reversed list of filters, adding up blocks.
And then one convolution at the end to turn it from 224 channels to three channels. Okay. And the forward is going to store all the layers in saved, just like we did back with this U-Net, but we don't really have to do it explicitly now. We just call the sequential model.
And thanks to our automatic saving, we can just go through each of its blocks and grab their .saved. So that's handy. We then call that mid block, which is just another res block. And then the same thing for the ups: we just pass in those saved activations, right?
And remember, each up block pops them off as it goes. And then the conv at the end. So that's it. That's our unconditional model. It's not quite the same as the diffusers unconditional model, because it doesn't have attention, which is something we're going to add next.
But other than that, this is the same. Because we're doing a simpler problem, Fashion MNIST, we'll use fewer channels than the default. Using two layers per block is standard. One thing to note, though, is that the upsampling blocks actually have three layers, num_layers + 1.
The reason for that is that, in the way Stable Diffusion and diffusers do it, even the output of the downsampling conv is also saved. So if you have num_layers equal to two, there'll be two res blocks saving things here and one conv saving things here, so you'll have three saved cross connections.
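Here's a quick count, under stated assumptions, showing why num_layers + 1 res blocks per up block lines up with what the down path saves. I'm assuming, as diffusers does, that the output of the very first conv is saved too; that extra entry makes up for the last down block having no downsampling conv. Both helper functions are hypothetical.

```python
def count_saved(n_blocks=4, num_layers=2, save_stem=True):
    """Activations saved on the way down (assumes diffusers-style saving)."""
    saved = 1 if save_stem else 0      # assumption: the stem conv's output is saved as well
    for i in range(n_blocks):
        saved += num_layers            # one per saved res block
        if i != n_blocks - 1:
            saved += 1                 # plus the stride-2 saved conv, except in the last block
    return saved


def count_consumed(n_blocks=4, num_layers=2):
    """Each up block has num_layers + 1 res blocks, and each pops one saved activation."""
    return n_blocks * (num_layers + 1)


print(count_saved(), count_consumed())   # 12 12
```

In general the down path saves 1 + n*L + (n - 1) = n*(L + 1) activations, exactly what n up blocks of L + 1 res blocks pop.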
So that's why there's an extra plus one here. Okay. And then we can just train it using miniai as per usual. Nope, I didn't save it after I last trained it. Sorry about that. So trust me, it trained. Okay. Now that... oh, okay. No, that is actually missing something else important as well as attention.
The other thing it's missing is that thing we discovered is pretty important, which is the time embedding. We already know that sampling doesn't work particularly well without the time embedding, so I didn't even bother sampling this. I didn't want to add all the stuff necessary to make that work a bit better.
So let's just go ahead and do time embedding. There are a few ways to do it, and the way it's done in Stable Diffusion is what's called sinusoidal embeddings. Maybe we'll skip ahead a bit. The basic idea is that we're going to create a res block with embeddings, where forward doesn't just get the activations; it also gets t, which is a vector that represents the embedding of each time step.
Actually it'll be a matrix, because it's really a batch, but for one element of the batch it's a vector. And it's an embedding in exactly the same way as when we did NLP: each token had an embedding. So the word "the" would have an embedding, the word "Johno" would have an embedding, and the word "Tanishq" would have an embedding, although Tanishq would probably actually be multiple tokens until he's famous enough to be mentioned in nearly every piece of literature, at which point Tanishq will get his own token, I expect.
That's how you know you've made it. So the time embedding will be the same: time step zero will have a particular vector, time step one will have a particular vector, and so forth. Well, we're doing Karras, so actually they're not time steps one, two, three; they're actually sigmas, you know.
So they're continuous. But same idea: a specific value of sigma, which is actually what t is going to be, slightly confusingly, will have a specific embedding. Now, we want two values of sigma or t that are very close to each other to have similar embeddings, and values that are different from each other to have different embeddings.
So how do we make that happen, and also make sure there's a lot of variety in the embeddings across all the possibilities? The way we do that is with these sinusoidal time steps. So let's have a look at how they work. First you have to decide: how big do you want your embeddings to be?
Just like we do in NLP. Is the word "the" represented by eight floats, or 16 floats, or 400 floats, or whatever? Let's just assume it's 16 for now. And let's say we're just looking at a bunch of time steps between negative 10 and 10, and we'll just do 100 of them.
I mean, we don't actually have negative sigmas or t, so it doesn't exactly make sense, but it doesn't matter; it gives you the idea. And then we say, OK, what's the largest time step, or the largest sigma, that you could have? Interestingly, every single model I've found uses 10,000 for this.
Even though that number actually comes from the NLP transformers literature, and it's based on the idea of, OK, what's the maximum sequence length we support? You could have up to 10,000 things in a document, or whatever, in a sequence. But we don't actually have sigmas that go up to 10,000.
So I'm using the number that's used in real life in Stable Diffusion and all the other models, but interestingly it's there purely, as far as I can tell, as a historical accident, because this is like the maximum sequence length that NLP transformers people thought they would need to support. OK, now we're going to be doing e to the power of a bunch of things.
So that's going to be our exponent. Our exponent is going to be equal to minus the log of the max period, which is about nine, times eight evenly spaced numbers between 0 and 1. Eight of them, because we said we want 16; you'll see why we want eight and not 16 in a moment.
But basically, here are the eight exponents we're going to use. Then, not surprisingly, we do e to the power of each of these eight things. And we've also got the actual time steps, so imagine these are the actual time steps we have in our batch.
So there's a batch of 100, and they contain this range of sigmas or time steps. To create our embeddings, we do an outer product of exponent.exp() and the time steps. This is step one. And this uses a broadcasting trick we've seen before.
We add a unit axis at axis 1 to the time steps, making them 100 by 1, and a unit axis at axis 0 to the exponents, making them 1 by 8. If we multiply those together, the time steps get broadcast across the exponent axis and the exponents across the time step axis, so we end up with 100 by 8.
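The exponent and outer-product steps can be sketched in pure Python with lists. I'm assuming the exponent is negative, as in common diffusers-style implementations, so the frequencies shrink as the index grows; in PyTorch the broadcast would be written something like `tsteps[:, None] * exponent.exp()[None, :]`.

```python
import math

emb_dim, max_period = 16, 10_000
half = emb_dim // 2                       # 8 frequencies: one set for sin, one for cos

# Exponents: 8 evenly spaced points in [0, 1], scaled by -log(max_period) (about -9.21)
exponent = [-math.log(max_period) * i / (half - 1) for i in range(half)]
freqs = [math.exp(e) for e in exponent]   # from 1.0 down to 1/10,000

# 100 toy "time steps" between -10 and 10, as in the lesson's plot
tsteps = [-10 + 20 * i / 99 for i in range(100)]

# Outer product, the list equivalent of tsteps[:, None] * freqs[None, :]: shape 100 x 8
outer = [[t * f for f in freqs] for t in tsteps]
print(len(outer), len(outer[0]))          # 100 8
```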
So it's basically a Cartesian product: all the possible combinations of time step and exponent multiplied together. And here are a few of those different exponents for a few different values. OK, so that's not very interesting yet. We haven't yet reached something where each time step is similar to its next-door time step.
You know, over here these embeddings look very different from each other, and over here they're very similar. So what we then do is take the sine and the cosine of those. That's 100 by 8, and that's 100 by 8, and together that gives us 100 by 16.
So we concatenate those together. That's a little bit hard to wrap your head around, so let's take a look. Across the 100 time steps, or 100 sigmas, this one here is the first sine wave, this one here is the second sine wave, and this one here is the third.
And this one here is the fourth and the fifth. So you can see that as you go up to higher indices, you're basically stretching the sine wave out. And then once you get up to index 8, you're back to the same frequency as this blue one, because now we're starting on the cosines rather than the sines.
And cosine has exactly the same shape as sine; it's just shifted across by a quarter of a period. You can see these two light blue lines are the same, and these two orange lines are the same; they're just shifted across. Or I should say curves rather than lines. So when we concatenate those all together, we can actually draw a picture of it.
This picture is 100 pixels across and 16 pixels top to bottom. If you pick out a particular point, for example in the middle here for t equals 0, well, sigma equals 0, one column is an embedding. Bright represents higher numbers and dark represents lower numbers.
And you can see every column looks different, even though columns next to each other look similar. So that's called a time step embedding. This is definitely something you want to experiment with. I've tried to do the plots I thought were useful for understanding this, and Johno and Tanishq also had ideas about plots for these, which we've shown.
But the only way to really understand them is to experiment. So then we can put that all into a function where you just say: what are the time steps? How many embedding dimensions do you want? What's the maximum period? All I did was copy and paste the previous cells and merge them together.
So you can see there's our outer product, and there's our cat of sine and cos. If you have an odd-numbered embedding dimension, you have to pad it to make it even. Don't worry about that. So now you can pass in the actual time steps, or sigmas, and the number of embedding dimensions.
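Merged into one function, the steps look roughly like this pure-Python sketch. The function name, the negative exponent, and zero-padding on the right for odd sizes are assumptions modelled on common implementations, not a copy of the notebook.

```python
import math


def timestep_embedding(tsteps, emb_dim, max_period=10_000):
    """Sinusoidal embeddings: one row of length emb_dim per time step (pure-Python sketch)."""
    half = emb_dim // 2
    denom = max(half - 1, 1)                  # linspace(0, 1, half) spacing, safe for half == 1
    freqs = [math.exp(-math.log(max_period) * i / denom) for i in range(half)]
    rows = []
    for t in tsteps:
        args = [t * f for f in freqs]         # one row of the outer product
        row = [math.sin(a) for a in args] + [math.cos(a) for a in args]  # sin then cos
        if emb_dim % 2 == 1:
            row.append(0.0)                   # pad odd sizes with a zero to reach emb_dim
        rows.append(row)
    return rows


emb = timestep_embedding([0.0, 1.0, 1.05, 5.0], 16)
print(len(emb), len(emb[0]))                  # 4 16
```

Nearby time steps get nearby rows (1.0 and 1.05 are much closer than 1.0 and 5.0), which is the property we wanted.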
And you'll get back something like this. It won't be a nice curve, because the time steps in a batch won't all be next to each other, but it's the same idea. Can I call out something on that little visualization there, which goes back to your comment about the max period being super high?
So you said, OK, adjacent ones are somewhat similar, because that's what we want, but there is some change. But if you look across these first 100, about half of the embeddings look like they don't really change at all. And that's because a range that small, on a scale of 0 to 10,000, is tiny: you'd want those to be quite similar, because they're still very early in the super long sequence these embeddings were designed for.
Yeah. So here, actually, we've got... ...wasted space. Yeah. So here we've got a max period of 1,000 instead, and I've changed the figure size so you can see it better. It's using up a bit more of the space. Or go to a max period of 10, and now it's actually using it much better.
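Johno's observation can be checked numerically. This sketch counts how many of the 16 dimensions barely move (spread below an arbitrary 0.1 threshold) over the toy range of -10 to 10, for a few max periods. The helper and threshold are made up for illustration, and the negative-exponent frequency formula is the same assumption as before.

```python
import math


def embedding_spreads(tsteps, emb_dim=16, max_period=10_000):
    """For each embedding dimension, how much it varies (max - min) across the time steps."""
    half = emb_dim // 2
    freqs = [math.exp(-math.log(max_period) * i / max(half - 1, 1)) for i in range(half)]
    spreads = []
    for fn in (math.sin, math.cos):               # sin dims first, then cos dims
        for f in freqs:
            vals = [fn(t * f) for t in tsteps]
            spreads.append(max(vals) - min(vals))
    return spreads


tsteps = [-10 + 20 * i / 99 for i in range(100)]  # same toy range as the lesson's plot
flat = {}
for mp in (10_000, 1_000, 10):
    # Count dimensions that barely move across the whole range.
    flat[mp] = sum(s < 0.1 for s in embedding_spreads(tsteps, max_period=mp))
    print(f"max_period={mp}: {flat[mp]} of 16 dims nearly constant")
```

With the usual 10,000, around half the dimensions are nearly constant over this range, and shrinking the max period puts them to work.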
Yeah. So based on what you're saying, Johno, I agree. It seems like it would be a lot richer to use these time step embeddings with a suitable max period. Or maybe you just wouldn't need as many embedding dimensions. I guess if you used something very wasteful like this but had lots of embedding dimensions, it would still capture some useful ones.
Yeah. Thanks, Johno. So yeah, this is one of those interesting little insights about things that are buried deep in code that I'm not sure anybody much looks at. OK. So let's do a U-Net with time step embedding in it. What do you do once you've got this column of embeddings for each item of the batch?
What do you do with it? Well, there are a few things you can do with it. What Stable Diffusion does, I think (I'm not promising I remember all these details right), is make the embedding dimension twice as big as the number of activations. Then they use chunk to take that and split it into two separate variables.
So that's literally just the opposite of concatenate: two separate variables. One of them is added to the activations and one of them is multiplied by the activations, so this is a scale and a shift. We don't just grab the embeddings as is, though, because each res block might want to do different things with them.
So we have an embedding projection, which is just a linear layer that allows them to be projected. It projects from the number of embeddings to 2 times the number of filters, so that torch.chunk works. We also have an activation function called SiLU. This is the activation function that's used in Stable Diffusion.
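Here's a pure-Python sketch of the chunk-then-scale-and-shift step on a single C x H x W feature map. It assumes the embedding has already been through the projection layer, and it uses the `x * (1 + scale) + shift` form that diffusers uses for its scale-shift time conditioning; whether the lesson's notebook includes the `1 +` offset is an assumption. Each channel's scale and shift are reused at every (h, w) position, which is what the added unit axes achieve through broadcasting in PyTorch.

```python
def scale_shift(x, emb):
    """Apply a per-channel scale and shift from a projected time embedding.

    x:   nested lists with shape [C][H][W]
    emb: flat list of length 2 * C (output of the embedding projection)
    """
    c = len(x)
    assert len(emb) == 2 * c, "projection must output 2 * n_filters values"
    scale, shift = emb[:c], emb[c:]   # like torch.chunk(emb, 2): first half, second half
    # One scalar pair per channel, broadcast over every (h, w) position
    return [[[v * (1 + scale[ch]) + shift[ch] for v in row] for row in x[ch]]
            for ch in range(c)]


x = [[[1, 2], [3, 4]], [[0, 0], [0, 0]]]     # 2 channels, 2 x 2 pixels
print(scale_shift(x, [1, 0, 10, 5]))         # [[[12, 14], [16, 18]], [[5, 5], [5, 5]]]
```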
I don't think the details are particularly important, but it looks basically like a rectified linear unit with a slight curvy bit. Also known as Swish. Also known as Swish. And it's just equal to x times sigmoid(x). And yeah, I think activation functions don't make a huge difference.
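A quick pure-Python sketch comparing SiLU with ReLU. Unlike ReLU, SiLU lets small negative values through (its minimum is about -0.28, near x = -1.28) and is smooth around zero; the sigmoid is written in a numerically stable form so large negative inputs don't overflow.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)                 # avoids overflow for large negative x
    return z / (1.0 + z)


def silu(x: float) -> float:
    """SiLU, a.k.a. Swish: x * sigmoid(x)."""
    return x * sigmoid(x)


def relu(x: float) -> float:
    return max(0.0, x)


for x in (-4.0, -1.0, 0.0, 1.0, 4.0):
    print(f"{x:5.1f}  relu={relu(x):.4f}  silu={silu(x):.4f}")
```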
But they can make things train a little better or a little faster, and Swish has worked pretty well, so a lot of people use Swish, or SiLU. I always call it Swish, but I think the GELU paper was where SiLU was originally introduced.
And maybe people didn't quite notice. Then another paper called it Swish, and everybody called it Swish, and then people were like, wait, that wasn't the original paper. So I guess I should try to call it SiLU. Other than that, it's just a normal res block. So we do our first conv.
Then we apply the embedding projection to the activation function of the time steps. That's going to be applied to every pixel across the height and width, so we have to add unit axes for the height and width, which will cause it to broadcast across those two axes.
Do our chunk, do the scale and shift, and then we're ready for the second conv. Then we add it to the input, with an additional 1x1, stride-1 conv if necessary, as we've done before, if we have to change the number of channels. OK. Yeah, because I like exercising our Python muscles, I decided to use a second approach for the down block and the up block.
I'm not saying which one's better or worse. We're not going to use multiple inheritance anymore. Instead, we're going to use... well, it's not even a decorator; it's a function which takes a function. We're going to use nn.Conv2d and EmbResBlock directly, but we're going to pass them to a function called saved.
The function called saved is going to take as input a callable, which could be a function or a module or whatever. In this case it's a module: it takes an EmbResBlock or a Conv2d, and it returns a callable. The callable it returns is identical to the one passed into it, except that it saves the result, the activations, of the function.
Where does it save it? It's going to save it into a list in the second argument you pass to it, which is the block. So you pass the saved function the module; we grab the forward from it and store that away to remember what it was.
And then the function we want to replace it with, call it _f, takes some arguments and some keyword arguments. Basically, it just calls the original module's forward, passing in the arguments and keyword arguments, and then appends the result to the saved attribute of the block.
And then we have to return the result. So then we replace the module's forward method with this function and return the module. I said callable earlier, but actually it can't be just any callable: it has to specifically be a module, because it's the forward that we're changing.
This @wraps is from the Python standard library (functools). It automatically copies the documentation and everything from the original forward, so that it all looks like nothing's changed. Now, where does this .saved come from? I realize now that we could actually make this easier and automate it.
But I didn't think of this at the time. So we have to create the saved list here in the down block. It would actually have made more sense, I think, for this function to say: if the saved attribute doesn't exist, then create it, which would look like this:
if not hasattr(block, 'saved'): block.saved = []. If you did that, you wouldn't need the line in the down block anymore. Anyway, I didn't think of that at the time, so let's pretend that's not what we do. OK, so now the downsampling conv and the ResNets are both saved versions of modules.
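Here's a self-contained sketch of a `saved` helper along the lines just described, including the lazy `hasattr` check. Since there's no PyTorch here, `Module` is a minimal hypothetical stand-in whose `__call__` dispatches to `forward`, and `Double` and `DownBlockStub` are made-up toy classes; with real modules you'd pass in an nn.Conv2d or EmbResBlock instead.

```python
from functools import wraps


class Module:
    """Minimal stand-in for torch.nn.Module: calling the object calls self.forward."""
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class Double(Module):
    """Toy layer (hypothetical) used in place of an EmbResBlock or Conv2d."""
    def forward(self, x):
        """Double the input."""
        return 2 * x


class DownBlockStub:
    """Hypothetical container standing in for the down block that owns the saved list."""


def saved(m, blk):
    m_ = m.forward                       # keep the original bound forward

    @wraps(m.forward)                    # copy __name__, __doc__, etc. onto the wrapper
    def _f(*args, **kwargs):
        res = m_(*args, **kwargs)
        blk.saved.append(res)            # stash the activation on the block
        return res

    if not hasattr(blk, "saved"):        # create the list lazily, as suggested in the lesson
        blk.saved = []
    m.forward = _f                       # instance attribute shadows the class's forward
    return m


blk = DownBlockStub()
layer = saved(Double(), blk)
print(layer(3), layer(5), blk.saved)     # 6 10 [6, 10]
```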
We don't have to do anything to make that work; we just have to call them. We can't use sequential anymore, because we have to pass the time step to the ResNets as well. It would be easy enough to create your own sequential for layers with time steps, which passes them along.
But that's not what we're doing here. Yeah, maybe it would make sense for sequential to always pass along all the extra arguments, but I don't think that's how it works. So our up block is basically exactly the same as before, except we're now using EmbResBlocks instead. Just like before, we're going to concatenate.
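The custom sequential mentioned above could look something like this. `TimestepSequential` is a hypothetical name, not part of PyTorch; `torch.nn.Sequential`'s own forward takes a single input, which is why it can't pass the time embedding along by itself.

```python
class TimestepSequential:
    """Hypothetical sequential container that passes the time embedding t to every layer."""
    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x, t):
        for layer in self.layers:
            x = layer(x, t)      # each layer gets the running activations plus the same t
        return x


# Toy layers standing in for EmbResBlocks: each takes (activations, time embedding).
seq = TimestepSequential(lambda x, t: x + t, lambda x, t: x * t)
print(seq(1, 2))   # (1 + 2) * 2 = 6
```

A real version would also need to handle layers that don't take t, such as the plain downsampling conv, for example by checking each layer's type before calling it.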
So that's all the same. OK, so for a U-Net model with time embeddings, if we look at the forward, the thing we're passing into it now is a tuple containing the activations and the time steps, or the sigmas in our case. So we split them out, and we call that time step embedding function we wrote, saying, OK, these are the time steps.
And the number of time step embedding dimensions we want is however many we asked for, which we're just going to set equal to the first number of filters. That's all that happens there. And then we want to give the model the ability to do whatever it wants with those, to make them work the way it wants to.
The easiest, smallest way to do that is to create a tiny little MLP, which takes the time step embeddings and returns the actual embeddings to pass into the ResNet blocks. So the tiny little MLP is just a linear layer with... let me think here.
That's interesting. My linear layer by default has an activation function. I'm pretty sure we should have act=None here. It should be a linear layer, then an activation, then a linear layer. So I think I've got a bug, and we'll need to try rerunning. OK.
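Here's a tiny pure-Python illustration of why an activation on the final linear layer hurts: any negative output is clamped to zero. The weights are made up, and `linear`, `relu`, and `time_mlp` are hypothetical stand-ins for the miniai/PyTorch layers.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def linear(v, w, b):
    """Dense layer: w holds one row of weights per output, b the biases."""
    return [sum(wi * xi for wi, xi in zip(row, v)) + bi for row, bi in zip(w, b)]


def time_mlp(v, w1, b1, w2, b2, final_act=False):
    h = relu(linear(v, w1, b1))             # linear -> activation
    out = linear(h, w2, b2)                 # -> linear
    return relu(out) if final_act else out  # final_act=True reproduces the accidental extra ReLU


v = [1.0, -1.0]
w1, b1 = [[1, 0], [0, 1]], [0, 0]
w2, b2 = [[-2, 0], [0, 1]], [0, 0]
print(time_mlp(v, w1, b1, w2, b2))                  # [-2.0, 0.0]
print(time_mlp(v, w1, b1, w2, b2, final_act=True))  # [0.0, 0.0]: the negative output is gone
```

Nothing crashes in the buggy version; the downstream scale and shift just never see negative values, which is exactly the silent kind of handicap discussed next.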
It won't be the end of the world. It just means all the negatives will be lost here, which makes it only half as useful. That's not great. OK. And these are the kinds of things, as you can see, that you've got to be super careful of: where do you have activation functions?
Where do you have batch norms? Is it pre-activation? Is it post-activation? It trains even if you make that mistake. In this case it probably doesn't cost too much performance, but often it's like, oh, you've done something where you accidentally zeroed out all except the last few channels of your output block, or something like that.
It still trains anyway; it does its best and uses what it can. Yeah, it makes it very difficult to make sure you're not giving it those handicaps. Yeah. It's not like you're making a CRUD app or something, where you know that it's not working because it crashes or because, like, it doesn't show the username or whatever.
Instead, you just get, like, slightly less good results. But since you haven't done it correctly in the first place, you don't know that they're less good results. Yeah, there aren't really great ways to do this. It's really nice if you can have an existing model to compare to or something like that, which is where Kaggle competitions work really well.
Actually, if somebody's got a Kaggle result, then you know that's a really good baseline and you can check whether yours is as good as theirs. All right. So, yeah, that's what this MLP is for. So, the down and up blocks are the same as before. The convout is the same as before.
So, yeah, so we grab our time step embedding. So, that's just that outer product passed through this sinusoidal, the sine and cosine. We then pass that through the MLP. And then we call our downsampling, passing in those embeddings each time. You know, it's kind of interesting that we pass in the embeddings every time in the sense I don't exactly know why we don't just pass them in at the start.
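The time step embedding function itself was written in an earlier notebook. The version below is only a minimal sketch of the idea (an outer product of the time steps with a geometric range of frequencies, passed through sine and cosine), assuming the commonly used `max_period` of 10000.

```python
import math
import torch

def timestep_embedding(tsteps, emb_dim, max_period=10000):
    "Sketch: sinusoidal embedding of shape (len(tsteps), emb_dim), emb_dim assumed even."
    half = emb_dim // 2
    # Frequencies range geometrically from 1 down to 1/max_period
    exponent = -math.log(max_period) * torch.linspace(0, 1, half)
    # Outer product: one row per time step, one column per frequency
    emb = tsteps[:, None].float() * exponent.exp()[None, :]
    return torch.cat([emb.sin(), emb.cos()], dim=-1)

emb = timestep_embedding(torch.tensor([0.0, 0.5, 1.0]), 8)
print(emb.shape)   # torch.Size([3, 8])
```

At time step zero every sine is 0 and every cosine is 1, so the first row is four zeros followed by four ones.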
And in fact, in NLP, these kinds of embeddings, I think, are generally just passed in at the start. So this is kind of a curious difference. I don't know why, or whether there have been ablation studies or whatever. Do you guys know, are there any popular diffusion-y or generative models with time embeddings that don't pass them in at every stage, or is this pretty universal?
>> Some of the fancier architectures like recurrent interface networks and stuff just pass in the conditioning. I'm actually not sure. Yeah, maybe they do still do it at every stage. I think some of them just take in everything all at once up front and then do a stack of transformer blocks or something like that.
So, I don't know if it's universal, but it definitely seems like all the U-Net-style ones have the time step embedding going in. >> Maybe we should try some ablations to see if it matters. I mean, I guess it doesn't matter too much either way. But if you didn't need it at every step, then it would maybe save you a bit of compute, potentially.
Yeah, so now for the upsampling, you're passing in the activations, the time step embeddings, and that list of saved activations. So now we have a non-attention stable diffusion U-Net. So we can train that. And we can sample from it using the same code: I just copied and pasted all the stuff from the Karras notebook that we had.
And there we have it. This is our first diffusion from scratch. >> So we wrote every piece of code for this diffusion model. >> Yeah, I believe so. I mean, obviously, in terms of the optimized CUDA implementations of stuff, no. But, yeah, we've written our version of everything here, I believe.
>> A big milestone. >> I think so, yeah. And these FIDs are about the same as the FIDs that we get from the stable diffusion one. They're not particularly higher or lower. They bounce around a bit, so it's a little hard to compare. Yeah, they're basically the same. Yeah, so that is an exciting step.
And okay, yeah, that's probably a good time to have a five-minute break. Yeah, okay. Let's have a five-minute break. Okay. Normally, I would say we're back, but only some of us are back. Jono's internet and electricity in Zimbabwe are not the most reliable things, and he seems to have disappeared, but we expect him to reappear at some point.
So we will kick on Jono-less and hope that Zimbabwe's infrastructure sorts itself out. All right. So we're going to talk about attention. We're going to talk about attention for a few reasons. Reason number one, very pragmatic: we said that we would replicate stable diffusion, and the stable diffusion U-Net has attention in it.
So we would be lying if we didn't do attention. Okay. Number two: attention is one of the two basic building blocks of transformers. A transformer layer is attention attached to a one-layer MLP. We already know how to create a one-layer, or one-hidden-layer, MLP. So once we learn how to do attention, we'll know how to create transformer blocks.
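To make "attention attached to a one-hidden-layer MLP" concrete, here is a minimal transformer block built on PyTorch's own `nn.MultiheadAttention`. The pre-norm layout, `LayerNorm`, GELU, and the 4x hidden width are assumptions (one common variant), not something taken from this lesson.

```python
import torch
from torch import nn

class TransformerBlock(nn.Module):
    "Sketch: attention, then a one-hidden-layer MLP, each with a residual connection."
    def __init__(self, d, n_heads, hidden_mult=4):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(d, d * hidden_mult), nn.GELU(),
                                 nn.Linear(d * hidden_mult, d))

    def forward(self, x):                       # x: (batch, sequence, d)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]   # self-attention part
        return x + self.mlp(self.norm2(x))                   # MLP part
```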
So those are two good reasons. I'm not including a reason which is our model is going to look a lot better with attention, because I actually haven't had any success seeing any diffusion models I've trained work better with attention. So just to set your expectations, we are going to get it all working.
But regardless of whether I use our implementation of attention or the diffusers one, it's not actually making it better. That might be because we need to use better types of attention than what diffusers has, or it might be because it's just a very subtle difference that you only see on bigger images.
I'm not sure. That's something we're still trying to figure out. This is all pretty new, and not many people have done the kind of ablation studies on diffusion models needed to figure these things out. So yeah, that's just life. Anyway, there are lots of good reasons to know about attention.
We'll certainly be using it a lot once we do NLP, which we'll be coming to pretty soon. And it looks like Jono is reappearing as well. So that's good. Okay, so let's talk about attention. The basic idea of attention is that we have an image, and we're going to be sliding a convolution kernel across that image.
And obviously, we've got channels as well, or filters. And so this also has that. Okay. And as we bring it across, we might be, you know, we're trying to figure out like what activations do we need to create to eventually, you know, correctly create our outputs. But the correct answer as to what's here may depend on something that's way over here, and/or something that's way over here.
So for example, if it's a cute little bunny rabbit, and this is where its ear is, you know, and there might be two different types of bunny rabbit that have different shaped ears, well, it'd be really nice to be able to see over here what its other ear looks like, for instance.
With just convolutions, that's challenging. It's not impossible. We talked in part one about the receptive field. And as you get deeper and deeper in a convnet, the receptive field gets bigger and bigger. But it's, you know, at higher up, it probably can't see the other ear at all. So it can't put it into those kind of more texture level layers.
And later on, you know, even though this might be in the receptive field here, most of the weight, you know, the vast majority of the activations it's using is the stuff immediately around it. So what attention does is it lets you take a weighted average of other pixels around the image, regardless of how far away they are.
And so in this case, for example, we might be interested in bringing in at least a few of the channels of these pixels over here. The way that attention is done in stable diffusion is pretty hacky and known to be suboptimal. But it's what we're going to implement because we're implementing stable diffusion and time permitting.
Maybe we'll look at some other options later. But the kind of attention we're going to be doing is 1D attention, which was developed for NLP. And NLP is about sequences: one-dimensional sequences of tokens. So to do attention stable diffusion style, we're going to take this image and flatten out the pixels.
So we've got all these pixels. We're going to take this row and put it here. And then we're going to take this row, we're going to put it here. So we're just going to flatten the whole thing out into one big vector of all the pixels of row one and then all the pixels of row two and then all the pixels of the row three.
Or maybe it's column one, column two, column three. I can't remember whether it's row-wise or column-wise, but it's flattened out anyhow. And then for each image, it's actually a matrix, which I'm going to draw a little bit 3D because we've got the channel dimension as well. So the number across this way is going to be equal to the height times the width.
And then the number this way is going to be the number of channels. Okay, so how do we decide which of these other pixels to bring in? Well, what we do is basically create a weighted average of all of these pixels. So maybe these ones get a bit of a negative weight and these ones get a bit of a positive weight and, you know, these get a weight somewhere in between.
And so we're going to have a weighted average. So basically each pixel, let's say we're doing this pixel here right now, is going to equal its original pixel, let's call it x_i, plus the weighted average, that is, the sum over all the other pixels.
So from zero to the height times the width: some weight times each pixel. The weights are going to sum to one, and that way the pixel value scale isn't going to change. Well, that's not actually quite true. It's going to end up potentially twice as big, I guess, because it's being added to the original pixel.
So attention itself doesn't include the x plus, but the way it's done in stable diffusion, at least, is that the attention is added to the original pixel. So, yeah, now that I think about it, I'm not going to need to think about how this is being scaled, anyhow. So the big question is what values to use for the weights.
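Written as a formula (a sketch of the description above, with $x_j$ the vector of channel values for pixel $j$ and $hw$ the number of pixels), the update for pixel $i$ is:

$$x_i' = x_i + \sum_{j=1}^{hw} w_{ij}\, x_j, \qquad \sum_{j=1}^{hw} w_{ij} = 1$$

The leading $x_i$ term is the residual add used in stable diffusion; attention on its own is just the weighted sum.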
And the way that we calculate those is with a matrix product. For a particular pixel, we've got the values of all the channels for that one pixel. And what we can do is compare those to the channels of all the other pixels.
So we've got kind of, this is pixel, let's say x1. And then we've got pixel number x2. Right, all those channels. We can take the dot product between those two things. And that will tell us how similar they are. And so one way of doing this would be to say, like, okay, well, let's take that dot product for every pair of pixels.
And that's very easy to do, because that's exactly what the matrix product computes. So if we've got an (h·w) by c matrix and multiply it by its transpose, c by (h·w) (sorry, I said transpose and then totally failed to do the transpose), that will give us an (h·w) by (h·w) matrix.
So all the pixels are down here, and for each pixel, as long as these add up to one, we've got a weight for each pixel. And it's easy to make these add up to one: we can just take the result of this matrix multiplication and take the softmax over the last dimension. That gives us something where each row sums to one. Okay. Now, the thing is, it's not just that we want to find the places where they look the same, where the channels are basically the same; we want to find the places where they're similar in some particular way, you know?
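A minimal sketch of the version described so far, with no projections yet, using random activations flattened to 256 pixels by 32 channels (the sizes are just illustrative). The matrix product gives every pairwise dot product, and a softmax over the last dimension turns each row into weights that sum to one.

```python
import torch

torch.manual_seed(0)
n, s, c = 2, 256, 32             # batch, pixels (16*16), channels
x = torch.randn(n, s, c)

sim = x @ x.transpose(1, 2)      # (n, s, s): dot product of every pair of pixels
w = sim.softmax(dim=-1)          # each row of weights now sums to one
out = x + w @ x                  # weighted average of all pixels, added to the original
```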
And so some particular set of channels in one pixel is similar to some different set of channels in another. So in this case, we may be looking for the pointy-earedness activations, you know, which are actually represented by, you know, this, this, and this, and we want to just find those.
So the way we do that is, before we do this matrix product, we first put our matrix through a projection. So we basically put our matrix through a matrix multiplication, this one. So it's the same matrix, right? But we put it through two different projections. And that lets it pick two different sets of channels to focus on or not focus on before it decides whether this pixel is similar to that pixel in the way we care about.
And then actually, we don't even just multiply the result by the original pixels. We put those through a different projection as well. So there are three different projections: projection one, projection two, and projection three. And that gives it the ability to say, like, oh, I want to compare these channels to these channels to find similarity.
And based on that similarity, it picks out these channels, right? With both positive and negative weights. So that's why there are these three different projections. And the projections are called K, Q, and V. They're all being passed the same matrix, and because they're all being passed the same matrix, we call this self-attention.
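In matrix form, a sketch of what has been described up to this point (biases and any further normalization left out): let $X$ be the $(hw \times c)$ flattened activations and $W_Q, W_K, W_V$ the three $c \times c$ projection matrices. Then

$$Q = XW_Q,\qquad K = XW_K,\qquad V = XW_V,\qquad \text{out} = X + \operatorname{softmax}\!\left(QK^\top\right) V$$

with the softmax taken along each row, so row $i$ of $\operatorname{softmax}(QK^\top)$ holds the weights for pixel $i$.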
Okay, Jono, Tanishk, I know you guys know this very well, but you also know it's really confusing. Did you have anything to add or change? Anything else? Yeah, I like that you introduced this without resorting to the "let's think of these as queries" framing at all, which I think is, yeah.
Yeah, these are actually short for key, query, and value, even though I personally don't find those useful concepts. Yeah. You'll note on the scaling, you said, oh, we set it so that the weights sum to one, and so then we'd need to worry about, like, are we doubling the scale of x?
Yeah. But because of that P3, aka V, that projection can learn to scale the thing that's added to x appropriately. So it's not just doubling the size of x; it's increasing it a little bit, which is why we scatter normalization in between all of these attention layers.
But it's not as bad as it might be because we have that V projection. Yeah, that's a good point. And if P3, which is actually the V projection, is initialized such that it would have a mean of zero, then on average it should start out by not messing with our scale.
OK, so yeah, I guess I find it easier to think in terms of code. So let's look at the code. You know, there's actually not much code. I think you've got a bit of background noise too, Jono, maybe. Yes, that's much better. Thank you. So in terms of code, this is one of those things where you need to get everything exactly right.
And not just right: I wanted to get it identical to stable diffusion, so we can say we've made it identical to stable diffusion. I've actually imported the attention block from diffusers so we can compare. And it is so nice when you've got an existing version of something to compare to, to make sure you're getting the same results.
So we're going to start off by saying, let's say we've got a 16 by 16 pixel image. And this is some deeper level of activation, so it's got 32 channels, with a batch size of 64. So NCHW. I'm just going to use random numbers for now, but this has reasonable dimensions for an activation inside a batch size 64 CNN or diffusion model or U-Net, whatever.
OK, so the first thing we have to do is to flatten these out because, as I said, in 1D attention, the 2D structure is just ignored. And it's easy to flatten things out. You just say .view and pass in the dimensions, in this case the three dimensions we want, which is 64, 32, and everything else.
Minus one means everything else. So x.shape[:2]: in this case, obviously it'd be easy just to type 64, 32, but I'm trying to create something that I can paste into a function later, so it's general. So that's the first two elements, 64 and 32. And then the star just inserts them directly in here.
So 64, 32, minus one, and the minus one becomes 16 by 16, which is 256. Now then, again, because this is all stolen from the NLP world: in the NLP world, they call this dimension the sequence. So I'm going to call this sequence, by which I mean height by width. Sequence comes before channel, which is often called D, or dimension.
So we then transpose those last two dimensions. So we've now got batch by sequence (16 by 16) by channel, or dimension. So they'd tend to call this N, S, D: batch, sequence, dimension. Okay, so we've got 32 channels. So we now need three different projections that go from 32 channels in to 32 channels out.
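A small sketch of the flattening step with random activations of the sizes mentioned above: `.view` merges height and width into one sequence dimension, and `.transpose` puts the sequence before the channels. Because `.view` flattens the last dimension fastest, the pixels end up ordered row by row, which the indexing checks in the code confirm.

```python
import torch

x = torch.randn(64, 32, 16, 16)     # N, C, H, W
t = x.view(*x.shape[:2], -1)        # N, C, H*W   -> (64, 32, 256)
t = t.transpose(1, 2)               # N, S, D     -> (64, 256, 32)

# Pixels are flattened row by row: sequence position 1 is row 0, column 1,
# and position 16 is the first pixel of row 1.
assert torch.equal(t[0, 1], x[0, :, 0, 1])
assert torch.equal(t[0, 16], x[0, :, 1, 0])
```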
So that's just a linear layer. Okay, and just remember a linear layer is just a matrix multiply plus a bias. So there's three of them. And so they're all going to be randomly initialized at different random numbers. We're going to call them SK, SQ, SV. And so we can then, they're just callable.
So we can then pass the exact same thing into three, all three, because we're doing self-attention to get back our keys, queries, and values, or K, Q, and V. I just think of them as K, Q, and V, because they're not really keys, queries, and values to me. So then we have to do the matrix multiply by the transpose.
And so then for every one of the 64 items in the batch, for every one of the 256 pixels, there are now 256 weights. At least there would be if we had done the softmax, which we haven't yet. So we can now put that into a self-attention module. As Jono mentioned, we want to make sure that we normalize things.
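A short sketch of the three projections and the score matrix, using random data shaped like the flattened activations (batch 64, sequence 256, 32 channels). The projection names mirror the ones mentioned in the lesson; everything else is illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
t = torch.randn(64, 256, 32)        # batch, sequence (pixels), channels
sk, sq, sv = nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 32)

# Self-attention: the same input goes through all three projections
k, q, v = sk(t), sq(t), sv(t)
s = q @ k.transpose(1, 2)           # (64, 256, 256): one score per pair of pixels
```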
So we add proper normalization here. We talked about group norm back when we talked about batch norm. Group norm is like batch norm, except the statistics are computed for each item separately, over groups of channels, rather than across the batch. Okay, so then we are going to create our K, Q, V. Yep, Jono? I was just going to ask, should those be bias=False, so that they're only a matrix multiply, to strictly match the traditional implementation?
No, because... Okay, they also do it that way. Yeah, they have bias in their attention blocks. Cool. Okay, so we've got our QK and V, self.q, self.k, self.v being our projections. And so to do 2D self-attention, we need to find the NCHW from our shape. We can do a normalization.
We then do our flattening as discussed. We then transpose the last two dimensions. We then create our QKV by doing the projections. And we then do the matrix multiply. Now, we've got to be a bit careful now because as a result of that matrix multiply, we've changed the scale by multiplying and adding all those things together.
So we then divide by the square root of the number of filters. You can convince yourself of this if you wish to, but it turns out that's going to return it to the original scale. We can now do the softmax across the last dimension, and then multiply the result by V.
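A quick numerical check of the scaling claim, under the simplifying assumption that the query and key entries are independent with unit variance: each dot product sums `d` such products, so its standard deviation grows like the square root of `d`, and dividing by that brings it back to about one.

```python
import torch

torch.manual_seed(0)
d = 32
q = torch.randn(10000, d)
k = torch.randn(10000, d)
dots = (q * k).sum(-1)              # 10,000 independent dot products of length d
print(dots.std())                   # close to sqrt(32), about 5.66
print((dots / d**0.5).std())        # close to 1
```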
So we use a matrix multiply to do them all in one go. We didn't mention it, but we then do one final projection, again just to give it the opportunity to map things to some different scale, and to shift it too if necessary. Then we transpose the last two dimensions back to where they started, reshape it back to where it started, and then add it.
Remember, I said it's going to be x plus: we add it back to the original. So this is actually kind of self-attention ResNet style, if you like. Diffusers, if I remember correctly, does include the x plus in theirs, but some implementations, for example PyTorch's implementation, don't. Okay, so that's a self-attention module, and all you need to tell it is how many channels to do attention on.
And you need to tell it that because that's what we need for our four different projections, our group norm, and our scale. I guess, strictly speaking, that doesn't have to be stored here; you could calculate it in the forward, but either way is fine. Okay, so if we create a self-attention layer, we can then call it on our little randomly generated numbers.
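Pulling those steps together, here's a minimal single-head sketch. It is not the notebook's exact code: the class name, `nn.GroupNorm(1, ni)` as the normalization, and the layer names are assumptions. It follows the order described: normalize, flatten, transpose, project to q, k and v, scale, softmax, multiply by v, project, transpose and reshape back, then add to the input.

```python
import math
import torch
from torch import nn

class SelfAttention2D(nn.Module):
    "Sketch of single-head 2D self-attention with a residual connection."
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.GroupNorm(1, ni)            # assumed normalization choice
        self.q, self.k, self.v = nn.Linear(ni, ni), nn.Linear(ni, ni), nn.Linear(ni, ni)
        self.proj = nn.Linear(ni, ni)              # final output projection

    def forward(self, inp):
        n, c, h, w = inp.shape
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)   # (n, h*w, c)
        q, k, v = self.q(x), self.k(x), self.v(x)
        s = (q @ k.transpose(1, 2)) / self.scale            # (n, h*w, h*w) scores
        x = s.softmax(dim=-1) @ v                           # weighted averages of v
        x = self.proj(x).transpose(1, 2).reshape(n, c, h, w)
        return inp + x                                      # add back to the input
```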
And it doesn't change the shape, because we transpose it back and reshape it back, and we can see it's basically worked: it creates some numbers. How do we know if they're right? Well, we could create a diffusers attention block. That will randomly initialize its own Q, K, and V projections.
Sorry, actually they call them something else: query, key, value, proj_attn, and group_norm. We call them q, k, v, proj, and norm. They're the same things. And so then we can just zip those together. That's going to take each pair (first pair, second pair, third pair) and copy the weight and the bias from their attention block.
Sorry, from our attention block to the diffusers attention block. And then we can check that they give the same value, which you can see they do. So this shows us that our attention block is the same as the diffusers attention block, which is nice. Here's a trick which diffusers doesn't use, for reasons I don't understand, which is that we don't actually need three separate projections here.
We could create one projection from Ni to Ni times three. That's basically doing three projections. So we could call this QKV. And so that gives us 64 by 256 by 96 instead of 64 by 256 by 32, because it's the three sets. And then we can use chunk, which we saw earlier, to split that into three separate variables along the last dimension to get us our QKV.
And we can then do the same thing, q @ k.transpose, et cetera. So here's another version of attention where we just have one projection for QKV, and we chunk it into separate Q, K, and V. And this does the same thing. It's just a bit more concise. And it should be faster as well, at least for normal PyTorch, if you're not using some kind of XLA compiler or ONNX or Triton or whatever.
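To see that one fused projection really is equivalent to three, here's a sketch in the same spirit as the diffusers comparison: stack the three weight matrices (and biases) into one `nn.Linear(ni, ni*3)`, then `chunk` its output along the last dimension. The sizes are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
ni = 32
x = torch.randn(4, 256, ni)

# Three separate projections...
q_l, k_l, v_l = nn.Linear(ni, ni), nn.Linear(ni, ni), nn.Linear(ni, ni)
# ...and one fused projection from ni to ni*3 carrying the same parameters
qkv_l = nn.Linear(ni, ni * 3)
with torch.no_grad():
    qkv_l.weight.copy_(torch.cat([q_l.weight, k_l.weight, v_l.weight], dim=0))
    qkv_l.bias.copy_(torch.cat([q_l.bias, k_l.bias, v_l.bias], dim=0))

# One matrix multiply, then split the 96 output features back into three
q, k, v = qkv_l(x).chunk(3, dim=-1)
```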
This should be faster because it runs one bigger matrix multiply on the GPU instead of three smaller ones. All right. So that's basic self-attention. This is not what's done basically ever, however, because in fact the question of which pixels I care about depends on which channels you're referring to. Because the ones about what color its ear is, as opposed to how pointy its ear is, might depend more on whether this bunny is in the shade or in the sun.
And so you may want to look at its body over here to decide what color to make the ears, rather than how pointy to make them. And so, yeah, different channels need to bring in information from different parts of the picture. And the way we do that is with multi-headed attention.
And multi-headed attention actually turns out to be really simple to implement, and conceptually it's also really simple. What we do is, let's come back to the C dimension here and split it into four separate vectors: one, two, three, four. And let's do the whole dot product thing on just the first part with the first part.
And then do the whole dot product part with the second part with the second part and so forth, right? So we're just going to do it separately, separate matrix multiplies for different groups of channels. And the reason we do that is it then allows, yeah, different parts, different sets of channels to pull in different parts of the image.
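A direct, if slow, sketch of that idea (the function name is hypothetical): split the channels of q, k and v into equal groups with `chunk`, run ordinary scaled dot-product attention on each group separately, and concatenate the results back along the channel dimension. Scaling by the square root of each head's width is assumed here.

```python
import math
import torch

def multihead_attn_loop(q, k, v, nheads):
    "q, k, v: (n, s, c). Each head attends using only its own slice of the channels."
    outs = []
    for qh, kh, vh in zip(q.chunk(nheads, -1), k.chunk(nheads, -1), v.chunk(nheads, -1)):
        w = (qh @ kh.transpose(1, 2) / math.sqrt(qh.shape[-1])).softmax(dim=-1)
        outs.append(w @ vh)                # (n, s, c // nheads) for this head
    return torch.cat(outs, dim=-1)         # put the heads' channels back side by side
```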
And so these different groups are called heads. And I don't know why, but they are. Does that seem reasonable? Anything to add to that? It's maybe worth thinking about why, with just a single head, specifically the softmax starts to come into play. Because, you know, we said it's like a weighted sum, just able to bring in information from different parts and whatever else.
But with softmax, what tends to happen is whatever weight is highest gets scaled up quite dramatically. And so it's like almost like focused on just that one thing. And then, yeah, like, as you said, Jeremy, like different channels might want to refer to different things. And, you know, just having this one like single weight that's across all the channels means that that signal is going to be like focused on maybe only one or two things as opposed to being able to bring in lots of different kinds of information based on the different channels.
Right. I was going to mention the same thing, actually. That's a good point. So you're mentioning the second important point about softmax: point one is that it creates something that adds to one, but point two is that, because of its e to the x, it tends to highlight one thing very strongly.
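A tiny illustration of that second point: because softmax exponentiates, multiplying the inputs by ten pushes nearly all the weight onto the largest one.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
print(a.softmax(0))          # tensor([0.0900, 0.2447, 0.6652])
print((a * 10).softmax(0))   # almost all the weight lands on the last element
```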
And yes, so if we had single-headed attention, your point, I guess, is that it would end up basically picking nearly all one pixel, which would not be very interesting. OK, awesome. Oh, I see why everything got thick: I've accidentally turned it into a marker. Right. OK, so multi-headed attention.
I'll come back to the details of how it's implemented, but I'm just going to mention the basic idea first. This is multi-headed attention. And this is identical to before, except I've just stored one more thing, which is how many heads you want. And then the forward is actually nearly all the same.
So this is identical, identical, identical. This is new. Identical, identical, identical, new, identical, identical. So there's just two new lines of code, which might be surprising, but that's all we needed to make this work. And they're also pretty wacky, interesting new lines of code to look at. Conceptually, what these two lines of code do is they first, they do the projection, right?
And then they basically take the number of heads. So we're going to do four heads. We've got 32 channels and four heads, so each head is going to contain eight channels. And we're going to keep each image as having just eight channels, not 32 channels.
And we're going to make each batch four times bigger, right? Because the images in a batch don't combine with each other at all. They're totally separate. So instead of having one image containing 32 channels, we're going to turn that into four images containing eight channels. And that's actually all we need, right?
Because remember, I told you that each group of channels, each head, should have nothing to do with the others. So if we literally turn them into different images, then they can't have anything to do with each other, because items in a batch don't interact with each other at all. So this rearrange (I'll explain how it works in a moment) is basically saying: think of the channel dimension as being h groups of d, and rearrange it.
So instead, the batch dimension is now n groups of h, and the channels are now just d. So that would be eight channels instead of four by eight. And then we do everything else exactly the same way as usual, but now the channels are split into h groups, four groups.
And then after that, okay, well, we were thinking of the batch as being of size n by h; let's now think of the channels as being of size h by d. That's what these rearranges do. So let me explain how these work. In the diffusers code (I can't remember if I copied it or was just inspired by it),
they've got things called heads_to_batch and batch_to_heads, which do exactly these things. And so for heads_to_batch, they say, okay, you've got 64 items per batch by 256 pixels by 32 channels. Okay, let's reshape it so you've got 64 images by 256 pixels by four heads by the rest.
So that would be 32 divided by four, that is, eight channels. So it's split out into a separate dimension. And then if we transpose these two dimensions, it'll then be n by four, so n by heads by SL by minus one. And so then we can reshape, so those first two dimensions get combined into one.
So that's what heads_to_batch does. And batch_to_heads does the exact opposite, right? It reshapes to bring the batch back to here, then heads by SL by D, and then transposes it back again and reshapes it back again so that the heads get merged back into the channels. So this is how to do it using just the traditional PyTorch methods that we've seen before.
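Here's a sketch of those two helpers using only `reshape` and `transpose`. It's written from the description above rather than copied from diffusers, so the exact signatures are assumptions.

```python
import torch

def heads_to_batch(x, heads):
    "(n, sl, d) -> (n*heads, sl, d//heads): each head becomes its own batch item."
    n, sl, d = x.shape
    x = x.reshape(n, sl, heads, -1)                     # split channels into heads
    return x.transpose(2, 1).reshape(n * heads, sl, -1) # heads next to batch, then merge

def batch_to_heads(x, heads):
    "(n*heads, sl, d) -> (n, sl, d*heads): the exact opposite of heads_to_batch."
    n, sl, d = x.shape
    x = x.reshape(-1, heads, sl, d)                     # split batch back into (n, heads)
    return x.transpose(2, 1).reshape(-1, sl, d * heads) # heads back into the channels

t = torch.randn(64, 256, 32)
b = heads_to_batch(t, 4)
print(b.shape)   # torch.Size([256, 256, 8])
```

Batch item 1 of the result is the second head (channels 8 to 15) of the first image, and `batch_to_heads` undoes the whole thing.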
But I wanted to show you guys this new-ish library called einops, inspired, as it suggests, by Einstein summation notation. But it's absolutely not Einstein summation notation; it's something different. And the main thing it has is this thing called rearrange. And rearrange is kind of like a nifty rethinking of Einstein summation notation as a tensor rearrangement notation.
And so we've got a tensor called t we created earlier, 64 by 256 by 32. And what einops rearrange does is you pass it this specification string that says, turn this into this. Okay, this says that I have a rank three tensor (three dimensions, three axes): the first dimension is of length n, the second dimension is of length s, and the third dimension, in parentheses, is of length h times d, where h is eight.
Okay, and then I want you to just move things around so that nothing is broken, so everything's shifted correctly into the right spots, so that the batch dimension is now n times eight, n times h. The sequence length is the same, and d is now the number of channels.
Previously the number of channels was h by d. Now it's d, so the number of channels has been reduced by a factor of eight. And you can see it here, it's turned t from something of 64 by 256 by 32 into something of size 64 times eight by 256 by 32 divided by eight.
And so this is really nice because, a, this one line of code is to me clearer and easier, and I liked writing it better than these lines of code. But what's particularly nice is when I had to go in the opposite direction: I literally took this, cut it, put it here, and put the arrow in the middle.
Like it's literally backwards, which is really nice, right? Because we're just rearranging it in the other order. And so if we rearrange in the other order, we take our 512 by 256 by 4 thing that we just created and end up with a 64 by 256 by 32 thing, which we started with, and we can confirm that the end thing equals, or every element equals the first thing.
So that shows me that my rearrangement has returned it to its original correctly. Yeah, so multi-headed attention, I've already shown you. It's the same thing as before, but pulling everything out into the batch for each head and then pulling the heads back into the channels. So we can do multi-headed attention with 32 channels and four heads and check that it all looks okay.
So PyTorch has that all built in. It's called nn.MultiheadAttention. Be very careful, more careful than me, in fact, because I keep forgetting that by default it expects the batch to be the second dimension. So make sure you write batch_first=True to make batch the first dimension; that way it'll be the same as diffusers.
I mean, it might not be identical, but the same. It should be almost the same idea. And to make it self-attention, you've got to pass in three things, right? So the three things will all be the same for self-attention. This is the thing that's going to be passed through the Q projection, the K projection and the V projection.
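A minimal sketch of self-attention with PyTorch's built-in module, using the same illustrative sizes as before. `batch_first=True` makes it accept (batch, sequence, channels), and passing the same tensor as query, key and value makes it self-attention. It returns a tuple of the output and the attention weights (averaged over heads by default), and it doesn't add the residual for you.

```python
import torch
from torch import nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(32, num_heads=4, batch_first=True)  # inputs are (N, S, E)
x = torch.randn(64, 256, 32)
out, weights = mha(x, x, x)   # same tensor as query, key and value: self-attention
res = x + out                 # any residual connection has to be added by hand
```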
And you can pass different things to those. If you pass different things to those, you'll get something called cross-attention rather than self-attention, which I'm not sure we'll talk about until we get to NLP. Just on the rearrange thing, I know that if you've been doing PyTorch and you're used to it, like, you really know what transpose and reshape and whatever do, then it can be a little bit weird to see this new notation.
But once you get into it, it's really, really nice. And if you look at the multi-headed self-attention implementation there, you've got .view and .transpose and .reshape. It's quite fun practice: if you're saying, oh, this einops thing looks really useful, take an existing implementation like this and see whether you can start replacing the individual operations, the .reshape or whatever, with the equivalent rearrange call.
And then check that the output is the same. That's what helped it click for me: oh, okay, I can start to express these. If it's just a transpose, then that's a rearrange that swaps the last two dimensions. Yeah. I only just started using this, and I've obviously had many years of using reshape, transpose, et cetera, in Theano, TensorFlow, Keras, PyTorch, APL.
And I would say within 10 minutes, I was like, oh, I like this much better. You know, like, it's fine for me at least. It didn't take too long to be convinced. It's not part of PyTorch or anything. You've got to pip install it, by the way. And it seems to be becoming super popular now, at least in the kind of diffusion research crowd.
Everybody seems to be using einops suddenly, even though it's been around for a few years. And I actually put in an issue there asking them to add Einstein summation notation as well, which they've now done. So it's kind of your one place for everything, which is great.
And it also works across TensorFlow and other libraries, which is nice. Okay. So we can now add that to our U-Net. This is basically a copy of the previous notebook, except that I did this at the point where I'd realized that cosine scheduling is better.
So I'm back to the cosine schedule now. This is copied from the cosine schedule notebook, and we're still doing the minus 0.5 thing because we love it. And this time, I actually decided to export stuff into miniai.diffusion. At this point, I still think things are working pretty well.
And so I renamed unet_conv to pre_conv, since it's a better name. The timestep embedding has been exported. Upsample has been exported. This pre-act linear version has been exported too. I tried using nn.MultiheadAttention, and it didn't work very well for some reason. I haven't figured out why that is yet.
So I'm using the self-attention we just talked about, multi-headed self-attention. For the scale, we divide the number of channels by the number of heads, because each head only works with ni / nheads of the channels. And instead of specifying the number of heads, you specify attention channels, that is, the number of channels per head.
So if ni is 32 and attention channels is 8, then you calculate 32 / 8 = 4 heads. Yeah, that's what diffusers does, I think. It's not what nn.MultiheadAttention does; that takes the number of heads directly. And actually, I think ni divided by (ni divided by attention channels) is just equal to attention channels, so I could probably have just put that.
Anyway, never mind. So okay, that's all copied in from the previous one. The only thing that's different here is that I haven't got the .view with -1 that flattens the grid. So this is a 1D self-attention, and then the 2D self-attention just adds that .view before we call forward and then reshapes it back again.
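Here is a sketch of 1D and 2D multi-head self-attention along the lines just described, where you pass attention channels (channels per head) rather than a head count. It is not the lesson's exported code: it uses plain reshape/transpose instead of einops so it needs only PyTorch, and the layer layout (layer norm, one fused QKV linear, output projection) is an assumption.

```python
import math
import torch
from torch import nn

class SelfAttention(nn.Module):
    "1D multi-head self-attention over a (batch, channels, seq) tensor."
    def __init__(self, ni, attn_chans):
        super().__init__()
        self.nheads = ni // attn_chans            # e.g. ni=32, attn_chans=8 -> 4 heads
        self.scale = math.sqrt(ni / self.nheads)  # sqrt(channels per head) == sqrt(attn_chans)
        self.norm = nn.LayerNorm(ni)
        self.qkv = nn.Linear(ni, ni * 3)          # Q, K and V projections in one layer
        self.proj = nn.Linear(ni, ni)

    def _split_heads(self, t, n, s, c):
        # (batch, seq, heads*d) -> (batch*heads, seq, d)
        h = self.nheads
        return t.reshape(n, s, h, c // h).transpose(1, 2).reshape(n * h, s, c // h)

    def forward(self, x):
        n, c, s = x.shape
        x = self.norm(x.transpose(1, 2))          # -> (batch, seq, channels)
        q, k, v = (self._split_heads(t, n, s, c) for t in self.qkv(x).chunk(3, dim=-1))
        w = (q @ k.transpose(1, 2) / self.scale).softmax(dim=-1)  # (batch*heads, seq, seq)
        d = c // self.nheads
        x = (w @ v).reshape(n, self.nheads, s, d).transpose(1, 2).reshape(n, s, c)
        return self.proj(x).transpose(1, 2)       # back to (batch, channels, seq)

class SelfAttention2D(SelfAttention):
    "Flatten the h*w grid into a sequence, attend, then restore the 2D shape."
    def forward(self, x):
        n, c, h, w = x.shape
        return super().forward(x.view(n, c, -1)).reshape(n, c, h, w)
```

With ni=32 and attn_chans=8 this gives 4 heads and a scale of sqrt(8), matching the arithmetic in the discussion.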
So yeah, we've got 1D and 2D self-attention. Okay, so now our EmbResBlock has one extra thing you can pass in, which is attention channels. If you pass in attention channels, we create something called self.attention, which is a 2D self-attention layer with the right number of filters and the requested number of attention channels.
And so this is all identical to what we've seen before, except that if we've got attention, we add it. Oh yeah, and the attention I did here is the non-ResNet-y version, so we have to do x plus the attention output ourselves. That's more flexible: this way you can choose whether or not to have it.
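A simplified sketch of a res block whose attention layer is not residual, so the block adds it back itself. ResBlockWithAttn is a hypothetical name: it leaves out the timestep embedding, uses BatchNorm and SiLU as assumed choices, and uses nn.MultiheadAttention as a stand-in for the custom self-attention layer (even though, as mentioned, nn.MultiheadAttention didn't work as well in the lesson).

```python
import torch
from torch import nn

class ResBlockWithAttn(nn.Module):
    "Hypothetical, simplified res block with optional attention (no timestep embedding)."
    def __init__(self, ni, nf, attn_chans=0):
        super().__init__()
        # Pre-activation convs: norm -> activation -> conv
        self.convs = nn.Sequential(
            nn.BatchNorm2d(ni), nn.SiLU(), nn.Conv2d(ni, nf, 3, padding=1),
            nn.BatchNorm2d(nf), nn.SiLU(), nn.Conv2d(nf, nf, 3, padding=1))
        self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1)
        # Attention is optional; with attn_chans=0 this is a plain res block.
        self.attn = (nn.MultiheadAttention(nf, nf // attn_chans, batch_first=True)
                     if attn_chans else None)

    def forward(self, x):
        x = self.convs(x) + self.idconv(x)
        if self.attn is not None:
            n, c, h, w = x.shape
            seq = x.view(n, c, h * w).transpose(1, 2)          # (batch, pixels, channels)
            a = self.attn(seq, seq, seq, need_weights=False)[0]
            # The attention layer itself isn't residual, so add x back here.
            x = x + a.transpose(1, 2).reshape(n, c, h, w)
        return x
```

Keeping the residual add outside the attention layer means the same layer can be dropped in with or without a skip connection.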
Okay, so that's an EmbResBlock with attention. And so now for our down block, you have to tell it how many attention channels you want, because the res blocks need that. Same for the up block, because again the res blocks need that.
And so now the U-Net model: where does the attention go? We have to say how many attention channels you want, and then at which block index to start adding attention. What then happens is the attention is done here: each ResNet block has attention.
And so, as we discussed, you just do the normal res block and then the attention, right? And if you put that in at the very start, say you've got a 256 by 256 image, then you're going to end up with this attention matrix here being 256 x 256 = 65,536 positions on one side and 65,536 on the other. That's over 4 billion entries, for each head and each image.
That's huge. And you have to backprop through it, so you have to store all of it to allow backprop to happen. It's going to explode your memory. So basically nobody puts attention in the first layers. That's why I've added an attention start parameter, which says at which block we start adding attention, and it's not zero for the reason we just discussed.
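Some back-of-the-envelope arithmetic for why attention isn't used at full resolution: for one image, the attention weights have (h*w)^2 entries per head. This sketch assumes fp32 (4 bytes per entry) and a single head, and counts only those weights, not the other activations stored for backprop.

```python
def attn_matrix_bytes(h, w, nheads=1, bytes_per_el=4):
    "Size of one image's attention weights: (h*w) x (h*w) entries per head."
    seq = h * w
    return seq * seq * nheads * bytes_per_el

for size in (256, 64, 32, 16):
    gib = attn_matrix_bytes(size, size) / 2**30
    print(f"{size:>3}x{size:<3} -> {(size * size) ** 2:>14,} entries, {gib:10.4f} GiB (fp32, 1 head)")
```

At 256 x 256 that comes to 16 GiB for a single image and head, versus 256 KiB at 16 x 16.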
Another way you could do this is to say at what grid size you should start adding attention. Generally speaking, people say that when you get to 16 by 16, that's a good time to start adding attention. Although stable diffusion adds it at 32 by 32, because remember they're using latents, which we'll see very shortly, I guess in the next lesson.
So the latents start at 64 by 64, and then they add attention at 32 by 32. So again, we're replicating stable diffusion here: stable diffusion uses attention start at index one. So when we append each down block, it gets zero attention channels if we're not up to that block yet.
And ditto on the up block, except we have to count from the end. Now that I think about it, the mid block should have attention as well, so that's missing. Yeah, so the forward doesn't actually change at all for attention; it's only the __init__. So we can train that.
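The attention start idea can be sketched as a hypothetical helper that lists the attention channels for each down block, with 0 meaning no attention. Because the up path counts from the end, it simply mirrors the list. The block count and channel values are assumptions.

```python
def attn_chans_per_block(n_blocks, attn_chans=8, attn_start=1):
    "Hypothetical helper: attention channels for each down block (0 = no attention)."
    return [attn_chans if i >= attn_start else 0 for i in range(n_blocks)]

down = attn_chans_per_block(4, attn_chans=8, attn_start=1)
# The up path runs from the lowest resolution back up, so count from the end.
up = list(reversed(down))
print(down, up)
```

With attn_start=1, the first (highest-resolution) down block and the last up block get no attention, which is where the memory cost would be worst.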
And so previously, without attention, we got to 137. With attention, oh, we can't compare directly, because we've changed from Karras to cosine. We can compare the sampling, though. So what are we getting? 4, 5, 5, 5. It's very hard to tell if it's any better or not because, again, our cosine schedule is better.
But when I've done a direct like-for-like comparison, I haven't managed to find any obvious improvements from adding attention. But it's doing fine; 4 is great. Yeah. All right. So finally, did you guys want to add anything before we go into a conditional model?
I was just going to make a note, just to clarify: with attention, part of the motivation was certainly to do spatial mixing, to take information from different parts of the image and mix it. But the problem is, if it's too early, where you're dealing with more individual pixels, then the memory use is very high.
So you have to find a balance. You kind of want it early, so you can do some of that mixing, but not too early, where the memory usage is too high. So there's certainly a balance in finding the right place to add attention in your network.
I was just thinking about that, and maybe it's a point worth noting. Yeah, for sure. There is a trick, which is what they do in, for example, vision transformers or DiT, diffusion transformers: if you take, say, an 8 by 8 patch of the image and flatten it out, or run it through some convolutional thing to turn it into a 1 by 1 with some larger number of channels, then you reduce the spatial dimension by increasing the number of channels.
And that gets you down to a manageable size where you can then start doing attention as well. So that's another trick: patching, where you take a patch of the image and treat it as a single position with some embedding dimension, a 1 by 1 rather than an 8 by 8 or a 16 by 16.
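A sketch of the patching trick: a convolution whose kernel size equals its stride turns each non-overlapping 8 x 8 patch into a single position with more channels. The image size, patch size and embedding width are arbitrary assumptions.

```python
import torch
from torch import nn

img = torch.randn(2, 3, 64, 64)            # batch of RGB images
patch, emb_dim = 8, 256                     # assumed patch size and embedding width

# Kernel size == stride: each non-overlapping 8x8 patch is seen exactly once
# and mapped to a single emb_dim-channel "pixel".
to_patches = nn.Conv2d(3, emb_dim, kernel_size=patch, stride=patch)
x = to_patches(img)                         # (2, 256, 8, 8)
seq = x.flatten(2).transpose(1, 2)          # (2, 64, 256): 64 positions to attend over
print(x.shape, seq.shape)
```

Attention over those 64 positions needs a 64 x 64 weight matrix, rather than 4096 x 4096 for the raw 64 x 64 pixels.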
And that's why you'll see 32 by 32 patch models, like some of the smaller CLIP models, or 14 by 14 patches for some of the larger ViT classification models and things like that. Yeah, I guess that's mainly used when you have a full transformer network.
Whereas this is one where we incorporate attention into a convolutional network. So I guess there are different tricks for different sorts of networks. Yeah. And I haven't decided yet if we're going to look at ViT or not. Maybe we should, based on what you're describing.
I was just going to mention, though, since you mentioned transformers, that we've actually now got everything we need to create a transformer. Here's a transformer block with embeddings. The embeddings are exactly the same embeddings we've seen before. Then we add attention, as we've seen before, and there's a scale and shift.
And then we pass it through an MLP, which is just a linear layer, an activation, a normalization, and a linear layer. For whatever reason, GELU, which is just another activation function, is what people always use in transformers. And, for reasons I suspect don't quite make sense in vision, everybody uses layer norm.
And again, I was just trying to replicate an existing paper, but this is just a standard MLP. In fact, if we get rid of the embeddings, just to show you a true, pure transformer: here's a pure transformer block. It's just normalize, attention, add, normalize, multi-layer perceptron, add.
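A minimal sketch of a pure transformer block following that recipe (normalize, attention, add, normalize, MLP, add), with the MLP as linear, GELU, layer norm, linear as described. The 4x hidden width is an assumption, and nn.MultiheadAttention stands in for a hand-written attention layer.

```python
import torch
from torch import nn

class TransformerBlock(nn.Module):
    "Pre-norm transformer block: norm, attention, add, norm, MLP, add."
    def __init__(self, ni, nheads, mlp_ratio=4):
        super().__init__()
        hidden = ni * mlp_ratio
        self.norm1 = nn.LayerNorm(ni)
        self.attn = nn.MultiheadAttention(ni, nheads, batch_first=True)
        self.norm2 = nn.LayerNorm(ni)
        self.mlp = nn.Sequential(nn.Linear(ni, hidden), nn.GELU(),
                                 nn.LayerNorm(hidden), nn.Linear(hidden, ni))

    def forward(self, x):                  # x: (batch, seq, channels)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.norm2(x))

# A transformer network is just a sequential stack of these blocks.
net = nn.Sequential(*[TransformerBlock(64, 4) for _ in range(3)])
out = net(torch.randn(2, 10, 64))
```

Every block maps (batch, seq, channels) to the same shape, which is what makes stacking them trivial.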
That's all a transformer block is. And what's a transformer network? A transformer network is a sequential stack of transformer blocks. So in this diffusion model, I replaced my mid block with a sequential list of transformer blocks, and that is a transformer network. And to prove it, here's another version in which I replaced that entire thing with PyTorch's transformer encoder.
It's called nn.TransformerEncoder, just taken from PyTorch, and I simply replaced the mid block with that. So yeah, we've now built transformers. Now, okay, why aren't we using them right now? And why did I just say I'm not even sure if we're going to do ViT, which is vision transformers?
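A sketch, with assumed sizes, of how PyTorch's nn.TransformerEncoder could stand in for a mid block: flatten the (batch, channels, h, w) activations into a sequence, run the encoder, and reshape back. norm_first=True gives the normalize-first layout described in the lesson; that flag and the GELU activation are choices made here, not necessarily the lesson's settings.

```python
import torch
from torch import nn

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256,
                                   activation="gelu", batch_first=True, norm_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=3)

feat = torch.randn(2, 64, 4, 4)            # mid-block activations (batch, channels, h, w)
n, c, h, w = feat.shape
seq = feat.flatten(2).transpose(1, 2)      # (batch, h*w, channels)
mid = encoder(seq).transpose(1, 2).reshape(n, c, h, w)
```

The reshape around the encoder is all that's needed to use a sequence model on a grid of activations.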
The reason is that transformers are doing something very interesting, right? Remember, we're just doing 1D versions here. So transformers take something where we've got a sequence, which in our case is the pixels, height by width, but we'll just call it the sequence. And everything in that sequence has a bunch of channels, or dimensions, right?
I'm not going to draw them all, but you get the idea. So for each element of that sequence, which in our case is just some particular pixel, these are the filters, channels, activations, whatever; activations, I guess. What we do is first do attention, which, remember, has a projection for each of Q, K and V.
So it's mixing the channels a little bit, but putting that aside, the main thing it's doing is mixing the rows (the sequence elements) together, each output being a weighted average of them. Then, after that, we put the whole thing through a multi-layer perceptron. And the multi-layer perceptron looks at each pixel entirely on its own.
So take this one, say, and put it through linear, activation, norm, linear, which we call an MLP. And a transformer network is a bunch of transformer layers, so it basically goes attention, MLP, attention, MLP, and so on, ending with an MLP. That's all it's doing. In other words, it's mixing together the pixels (the sequence), and then it's mixing together the channels.
Then it's mixing together the sequence and then mixing together the channels, repeating this over and over. Because of the projections done in the attention, it's not only mixing the pixels, but it's largely mixing the pixels. And this combination is very, very flexible.
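A small experiment to check the mixing claim: perturb one position of the input and see which output positions change. The sizes and random weights are arbitrary. The MLP only changes the perturbed position, while attention changes every position.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq, ch = 6, 16
attn = nn.MultiheadAttention(ch, 2, batch_first=True)
mlp = nn.Sequential(nn.Linear(ch, 4 * ch), nn.GELU(), nn.Linear(4 * ch, ch))

x = torch.randn(1, seq, ch)
x2 = x.clone()
x2[0, 0] += 1.0                        # perturb only the first position

with torch.no_grad():
    mlp_diff = (mlp(x2) - mlp(x)).abs().sum(-1)[0]
    attn_diff = (attn(x2, x2, x2)[0] - attn(x, x, x)[0]).abs().sum(-1)[0]

print(mlp_diff > 1e-6)   # only position 0 changes: the MLP works on each position alone
print(attn_diff > 1e-6)  # every position changes: attention mixes across the sequence
```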
And it's flexible enough that it can provably approximate any convolution you can think of, given enough layers, enough time, and the right learned parameters. The problem is that for this to approximate a convolution requires a lot of data, a lot of layers, a lot of parameters and a lot of compute.
So this is a transformer network, a transformer architecture. Say you pass images into it and try to predict, for example, the ImageNet class of each image, using SGD to find weights for these attention projections and MLPs.
If you do that on ImageNet, you will end up with something that does indeed predict the class of each image, but it does it poorly. It doesn't do it poorly because it's incapable of approximating a convolution. It does it poorly because ImageNet, as in the entire ImageNet-1k, is not big enough for a transformer to learn how to do this.
However, if you give it a much bigger dataset, many times larger than ImageNet-1k, then it will learn to approximate this very well. In fact, it'll figure out ways of doing something like convolutions that are actually better than convolutions. So if you take a vision transformer, or ViT, that's been pre-trained on a dataset much bigger than ImageNet, and then fine-tune it on ImageNet, you will end up with something that is actually better than a ResNet.
And the reason it's better than a ResNet is that these combinations of attention and MLP, which together can approximate a convolution, aren't limited to convolutions. Convolutions are our best guess as to a good way to represent the calculations we should do on images, but there are much more sophisticated things you could do, if you're a computer and can figure these things out better than a human can.
And so a ViT actually figures out things that are even better than convolutions. That's why, when you fine-tune on ImageNet a ViT that's been pre-trained on lots of data, it ends up being better than a ResNet. And that's why the diffusion models I'm showing you don't contain transformers: to make that work would require pre-training on a really, really large dataset for a really, really long time.
So anyway, we might only come to transformers when we do NLP; in vision, maybe we'll cover them briefly, since they're very interesting to use as pre-trained models. The main thing to know about them is that a ViT, when pre-trained on lots of data, which they all are nowadays, is a very successful architecture.
But literally, the ViT paper says, in effect, we wondered what would happen if we took a totally plain 1D transformer and made it work on images with as few changes as possible. So everything we've learned about attention and MLPs today applies directly, because they haven't changed anything.
And one of the things you might realize that means is that you can't use a ViT trained on 224 by 224 pixel images on 128 by 128 pixel images, because all of these self-attention things are the wrong size. Actually, specifically, it's not really the attention; let me take that back.
All of the position embeddings are the wrong size. And that's something I forgot to mention: in transformers, the first thing you always do is take your pixels and add to them a positional embedding. That can be done lots of different ways, but one popular way is identical to what we did for the timestep embedding: the sinusoidal embedding.
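A sketch of a sinusoidal positional embedding, one vector per position. It follows the usual sin/cos recipe; the max_period of 10,000 and the sin-then-cos ordering are conventional assumptions. The number of positions is tied to the image (or patch grid) size, which is why embeddings built for one size don't directly fit another.

```python
import math
import torch

def sinusoidal_pos_emb(n_pos, dim, max_period=10000):
    "Standard sin/cos embedding: one dim-channel vector per position (dim must be even)."
    half = dim // 2
    pos = torch.arange(n_pos, dtype=torch.float32)[:, None]             # (n_pos, 1)
    freqs = torch.exp(-math.log(max_period) * torch.arange(half) / half)  # (half,)
    angles = pos * freqs                                                  # (n_pos, half)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)                # (n_pos, dim)

# A 14x14 grid of patches gives 196 positions; a smaller image gives fewer.
pe = sinusoidal_pos_emb(14 * 14, 64)
```

These vectors are added to the per-position inputs before the first transformer block.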
And that's specific to how many pixels there are in your image. So yeah, that's one of the things that makes ViTs a little tricky. Anyway, hopefully you get the idea that we've got all the pieces we need. Okay, so with that discussion, I think that's officially taken us over time.
So maybe we should do the conditional model next time. Actually, you know what, it's tiny. Let's just quickly do it now. You guys got time? Yeah. Okay. So let's finish by doing a conditional model. For a conditional model, we basically want something where I can say draw me the number, sorry, draw me a shirt, or draw me some pants, or draw me some sandals.
So we're going to pick one of the 10 Fashion MNIST classes and create an image of that particular class. To do that, we need to know what class each item is. We already know that, because it's the y label, which way back at the beginning of time we set up; it's just called label.
That tells you what category it is. So we're going to change our collation function. We call noisify as per usual, which gives us our noised image, our timestep, and our noise. But we then also add to that tuple what kind of fashion item this is.
So the first tuple will be noised image, timestep, and label, and the dependent variable, as per usual, is the noise. So what happens now when we call our U-Net, which is now a conditioned U-Net model, is that the input contains not just the activations and the timestep, but also the label.
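A sketch of a conditional collation function that returns ((noised image, timestep, label), noise). The noisify here is a hypothetical cosine-schedule version written only so the block is self-contained, and the batch is assumed to be (image, label) pairs rather than the dataset's own format.

```python
import math
import torch
from torch.utils.data import default_collate

def noisify(x0):
    "Hypothetical cosine-schedule noisify: returns (noised image, t), noise."
    t = torch.rand(len(x0)).clamp(0, 0.999)            # continuous time in [0, 1)
    eps = torch.randn_like(x0)
    abar = ((t * math.pi / 2).cos() ** 2).reshape(-1, 1, 1, 1)  # cosine alpha-bar
    xt = abar.sqrt() * x0 + (1 - abar).sqrt() * eps
    return (xt, t), eps

def collate_cond(batch):
    "Collate (image, label) pairs into ((noised image, t, label), noise)."
    imgs, labels = default_collate(batch)
    (xt, t), eps = noisify(imgs)
    return (xt, t, labels), eps
```

The target stays the noise; the label just rides along as a third model input.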
That label will be a number between zero and nine. So how do we convert a number between zero and nine into a vector that represents it? Well, we know exactly how to do that: nn.Embedding. We did that lots in part one. So let's make it exactly the same size as our time embedding.
So the number of activations in this embedding is the same as in our timestep embedding, which is convenient. Now in the forward we do our timestep embedding as usual, and we pass the labels into our conditional embedding. The time embedding we put through the embedding MLP as before, and then we just add the two together.
And that's it, right? This now represents a combination of the time and the fashion item. Everything else is identical. So all we've added is this one thing, and we literally just sum it up. So we've now got a joint embedding representing two things.
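A sketch of that joint embedding: an nn.Embedding with one learned vector per class, the same size as the processed timestep embedding, simply added to it. CondEmbedding, its layer sizes, and the simplified embedding MLP are hypothetical choices for illustration.

```python
import math
import torch
from torch import nn

def timestep_embedding(t, emb_dim, max_period=10000):
    "Sinusoidal embedding of a batch of (possibly fractional) timesteps."
    freqs = torch.exp(-math.log(max_period) * torch.linspace(0, 1, emb_dim // 2))
    angles = t[:, None].float() * freqs[None, :]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)

class CondEmbedding(nn.Module):
    "Hypothetical: the joint time + class embedding a conditional U-Net would consume."
    def __init__(self, n_classes=10, n_temb=32, n_emb=128):
        super().__init__()
        self.n_temb = n_temb
        self.emb_mlp = nn.Sequential(nn.Linear(n_temb, n_emb), nn.SiLU(), nn.Linear(n_emb, n_emb))
        self.cond_emb = nn.Embedding(n_classes, n_emb)    # one learned vector per class

    def forward(self, t, c):
        temb = self.emb_mlp(timestep_embedding(t, self.n_temb))
        return temb + self.cond_emb(c)                      # just add the two embeddings

emb = CondEmbedding()(torch.rand(4), torch.tensor([0, 3, 7, 9]))
```

Because both embeddings have the same width, summing them needs no extra layer, and the rest of the network can consume the result exactly as it consumed the time embedding alone.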
And then we train it. And interestingly, the loss ends up about the same, but you don't often see 0.031. It is a bit easier for a conditional model, because you're telling it what the item is.
It just makes it a bit easier. So to do conditional sampling, you have to pass in which type of thing you want from these labels. We create a vector containing that number repeated once for each item in the batch, and we pass it to our model.
Our model has now learned how to denoise something of type c. So now if we say, trust me, this is a noised image of type c, it should hopefully denoise it into something of type c. That's all there is to it; there's no magic there.
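A sketch of conditional sampling: the only change from unconditional sampling is building a batch-sized label vector and passing it to the model. The cosine schedule, the deterministic DDIM-style update (eta = 0), and sample_cond are assumptions written for self-containment, not the lesson's own DDIM step function.

```python
import math
import torch

def abar(t): return (t * math.pi / 2).cos() ** 2          # assumed cosine schedule

@torch.no_grad()
def sample_cond(model, c, sz, steps=50):
    "Hypothetical conditional sampler; model takes (x, t, labels) and predicts noise."
    x = torch.randn(sz)
    labels = torch.full((sz[0],), c, dtype=torch.long)    # same class for the whole batch
    ts = torch.linspace(1 - 1 / steps, 0, steps)
    for i, t in enumerate(ts):
        eps = model((x, t.expand(sz[0]), labels))           # the only conditional change
        ab_t = abar(t)
        ab_next = abar(ts[i + 1]) if i + 1 < steps else torch.tensor(1.0)
        x0 = (x - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()    # predicted clean image
        x = ab_next.sqrt() * x0 + (1 - ab_next).sqrt() * eps  # deterministic DDIM update
    return x
```

The update itself never looks at the label, which is why the step function doesn't need to change.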
So yeah, that's all we have to do to change the sampling. We didn't have to change the DDIM step at all, right? Literally all we did was add this one line of code and pass its result in there. So now we can say, okay, let's take class ID zero, which is T-shirt/top.
So we'll pass that to sample, and there we go: everything looks like T-shirts and tops. Yeah, okay. I'm glad we didn't leave that till next time, because we can now say we have successfully replicated everything in stable diffusion, except for being able to condition on whole sentences, which is what we do with CLIP.
Getting really close. Yes. Well, except that CLIP requires all of NLP. So I guess we might do that next, depending on how research goes. All right. We still need the latent diffusion part. Oh, good point. Latents. Okay. We'll definitely do that next time. So let's see. Yeah.
So we'll do a VAE and latent diffusion, which isn't enough for one lesson, so maybe some of the research I'm doing will end up in the next lesson as well. Yes. Okay. Thanks for the reminder. Although we've already kind of done autoencoders, so VAEs are going to be pretty easy.
Well, thank you, Tanishk and Jono. Fantastic comments, as always. Glad your internet and power reappeared, Jono. Back up. All right. Thanks, gang. Cool. Thanks, everybody. That was great.