Lesson 24: Deep Learning Foundations to Stable Diffusion
00:00:00.000 |
Hi, we are here for lesson 24. And once again, it's becoming a bit of a tradition now. We're 00:00:06.320 |
joined by Jono and Tanishk, which is always a pleasure. Hi, Jono. Hi, Tanishk. 00:00:12.600 |
Yeah, are you guys looking forward to finally actually completing stable diffusion, at least 00:00:19.960 |
the unconditional stable diffusion? Well, I should say not just unconditional, even conditional. So conditional 00:00:24.040 |
stable diffusion, except for the CLIP bit, from scratch. We should be able to finish that. 00:00:32.000 |
That is exciting. All right. Let's do it. Jump in any time. We've got things to talk 00:00:41.520 |
about. So we're going to start with a very hopefully named 26 diffusion unet. And what 00:00:52.720 |
we're going to do in 26 diffusion unet is to do unconditional diffusion from scratch. 00:01:04.480 |
And there's not really too many new pieces, if I remember correctly. So all the stuff 00:01:09.880 |
at the start we've already seen. And so when I wrote this, it was before I had noticed 00:01:17.920 |
that the Karras approach was doing less well than the regular cosine schedule approach. 00:01:22.480 |
So I'm still using Karras noisify, but this is all the same as from the Karras notebook, 00:01:31.440 |
Okay, so we can now create a unet that is based on what diffusers has, which is in turn 00:01:46.240 |
based on lots of other prior art. I mean, the code's not at all based on it, but the 00:01:52.280 |
basic structure is going to be the same as what you'll get in diffusers. The convolution 00:02:00.880 |
we're going to use is the same as the final kind of convolution we used for tiny image 00:02:06.080 |
net, which is what's called the preactivation convolution. So the convolution itself happens 00:02:12.460 |
at the end and the normalization and activation happen first. So this is a preact convolution. 00:02:23.720 |
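(A minimal sketch of the pre-activation conv being described, my reconstruction rather than the notebook's exact code; the BatchNorm and SiLU defaults here are assumptions.)

```python
import torch.nn as nn

def pre_conv(ni, nf, ks=3, stride=1, act=nn.SiLU, norm=nn.BatchNorm2d):
    # pre-activation ordering: normalization first, then activation, then the conv
    return nn.Sequential(norm(ni), act(), nn.Conv2d(ni, nf, ks, stride=stride, padding=ks//2))
```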
So then I've got a unet res net block. So I kind of wrote this before I actually did 00:02:31.000 |
the preact version of tiny image net. So I suspect this is actually the same, quite possibly 00:02:36.880 |
exactly the same as the tiny image net one. So maybe there's nothing specific about this 00:02:41.200 |
for a unet; this is just really a preact conv and a preact res net block. So we've got the 00:02:49.320 |
two convs as per usual and the identity conv. Now there is one difference though to what 00:02:56.240 |
we've seen before for res net blocks, which is that this res net block has no option to 00:03:01.360 |
do downsampling, no option to do a stride. This is always stride 1, which is our default. 00:03:12.760 |
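(And a sketch of that res block: two pre-activation convs plus an identity conv when the channel count changes, always stride 1. Again this is my reconstruction, with assumed norm and activation choices.)

```python
import torch.nn as nn

def pre_conv(ni, nf, ks=3):  # norm -> act -> conv, as sketched above
    return nn.Sequential(nn.BatchNorm2d(ni), nn.SiLU(), nn.Conv2d(ni, nf, ks, padding=ks//2))

class UnetResBlock(nn.Module):
    def __init__(self, ni, nf, ks=3):
        super().__init__()
        self.convs = nn.Sequential(pre_conv(ni, nf, ks), pre_conv(nf, nf, ks))
        # 1x1 "identity conv" is only needed when the number of channels changes
        self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1)

    def forward(self, x):
        return self.convs(x) + self.idconv(x)
```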
So the reason for that is that when we get to the thing that strings a bunch of them 00:03:18.860 |
together, which will be called down block, this is where you have the option to add downsampling. 00:03:26.040 |
But if you do add downsampling, we're going to add a stride-2 convolution after the res 00:03:31.040 |
block. And that's because this is how diffusers and stable diffusion do it. I haven't studied 00:03:40.960 |
this closely. Tanishk or Johnno, do either of you know where this idea came 00:03:46.440 |
from or why? I'd be curious. You know, the difference is that normally we would have 00:03:54.480 |
average pooling here in this connection. But yeah, this different approach is what we're going to use. 00:04:04.060 |
A lot of the history of the diffusers unconditional unet is to be compatible with the DDPM weights 00:04:12.520 |
that were released and some follow-on work from that. And I know like then improved DDPM 00:04:17.560 |
and these others, like they all kind of built on that same sort of unet structure, even 00:04:21.120 |
though it's slightly unconventional if you're coming from like a normal computer vision background. 00:04:26.720 |
And do you recall where the DDPM architecture came from? Because like some of the ideas 00:04:31.600 |
came from some of the earlier unets, but I don't know about DDPM. 00:04:38.080 |
Yeah, they had something called an efficient unet that was inspired by some prior work 00:04:42.760 |
whose lineage I can't remember. Anyway, yeah, I just think the diffusers one has since 00:04:50.320 |
become you know, like you can add in parameters to control some of this stuff. But yeah, it's 00:04:56.000 |
we shouldn't assume that this is the optimal approach, I suppose. But yeah, I will dig 00:05:01.280 |
into the history and try and find out how much like, what ablation studies have been 00:05:05.480 |
done. So for those of you who haven't heard of ablation studies, that's where you'd like 00:05:08.480 |
try, you know, a bunch of different ways of doing things and score which one works better 00:05:13.160 |
and which one works less well and kind of create a table of all of those options. And 00:05:19.000 |
so where you can't find ablation studies for something you're interested in, often that 00:05:24.160 |
means that, you know, maybe not many other options were tried, because researchers don't try everything. 00:05:31.480 |
Okay, now the unet. If we go back to the unet that we used for super resolution, we just 00:05:41.640 |
go back to our most basic version. What we did as we went down through the layers in 00:05:51.560 |
the down sampling section, we stored the activations at each point into a list called layers. 00:06:06.600 |
And then as we went through the up sampling, we added those down sampling layers back into 00:06:13.000 |
the up sampling activations. So that's the kind of basic structure of a unet. You don't have 00:06:20.440 |
to add, you can also concatenate, and actually concatenating is, I think, more 00:06:27.160 |
common nowadays, and I think the original unet might have been concatenating. Although 00:06:31.840 |
for super resolution, just adding seems pretty sensible. So we're going to concatenate. But 00:06:37.920 |
what we're going to do is we're going to try to, we're going to kind of exercise our Python 00:06:42.400 |
muscles a little bit to try to see interesting ways to make some of this a little easier 00:06:47.520 |
to turn different down sampling backbones into unets. And we'll also use that as an opportunity 00:06:55.880 |
to learn a bit more Python. So what we're going to do is we're going to create something called 00:07:03.460 |
a saved res block and a saved convolution. And so our down blocks, so these are our res 00:07:11.880 |
blocks containing a certain number of res block layers, followed by this optional stride-2 00:07:18.800 |
conv. We're going to use saved res blocks and saved convs. And what these are going to 00:07:23.080 |
do, it's going to be the same as a normal convolution and the same as a normal res block, 00:07:27.600 |
the same as a normal unet res block. But they're going to remember the activations. And the 00:07:35.240 |
reason for that is that later on in the unit, we're going to go through and grab those saved 00:07:43.040 |
activations all at once into a big list. So then yeah, we basically don't have to kind 00:07:51.240 |
of think about it. And so to do that, we create a class called a SaveModule. And all SaveModule 00:07:58.400 |
does is call super's forward to grab the res block or conv result and store that 00:08:11.400 |
before returning it. Now that's weird, because hopefully you know by now that super calls 00:08:19.120 |
the thing in the parent class, but SaveModule doesn't have a parent class. So this is what's 00:08:25.480 |
called a mixin. And it's using something called multiple inheritance. And mixins are as it 00:08:42.800 |
describes here. It's a design pattern, which is to say it's not particularly a part of 00:08:50.800 |
Python per se. It's a design pattern that uses multiple inheritance. Now multiple 00:08:55.000 |
inheritance is where you can say, oh, this class called SavedResBlock inherits from 00:09:03.000 |
two things, SaveModule and UnetResBlock. And what that does is it means that all of 00:09:11.080 |
the methods in both of these will end up in here. Now that would be simple enough, except 00:09:18.040 |
we've got a bit of a confusion here, which is that UnetResBlock contains forward and 00:09:23.440 |
SaveModule contains forward. So it's all very well just combining the methods from 00:09:27.960 |
both of them. But what if they have the same method? And the answer is that the one you 00:09:35.440 |
list first wins, and when it calls super's forward, it's actually calling forward in the later 00:09:44.040 |
one. And that's why it's a mixin. It's mixing this functionality into this functionality. 00:09:51.560 |
So it's a unet res block where we've customized forward, so it calls the existing forward 00:09:57.520 |
and also saves it. So you see mixins quite a lot in the Python standard library. For 00:10:04.560 |
example, the basic HTTP stuff, some of the basic thread stuff with networking uses multiple 00:10:16.280 |
inheritance using this mixin pattern. So with this approach, then the actual implementation 00:10:22.120 |
of saved res block is nothing at all. So pass means don't do anything. So this is just literally 00:10:29.080 |
just a class which has no implementation of its own other than just to be a mixin of these 00:10:36.520 |
two classes. So a saved convolution is an nn.Conv2d with the SaveModule mixed in. 00:10:47.800 |
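(Roughly what that looks like in code, reusing the UnetResBlock sketched earlier; a sketch of the pattern rather than the notebook's exact classes.)

```python
import torch.nn as nn

class SaveModule:
    # A mixin: it has no parent of its own, so super() here resolves through the MRO
    # of whatever class mixes it in (UnetResBlock below, or nn.Conv2d).
    def forward(self, x, *args, **kwargs):
        self.saved = super().forward(x, *args, **kwargs)
        return self.saved

class SavedResBlock(SaveModule, UnetResBlock):  # UnetResBlock as sketched above
    pass

class SavedConv(SaveModule, nn.Conv2d):
    pass
```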
So what's going to happen now is that we can call a saved res block just like a unet res 00:10:54.280 |
block and a saved conv just like an nn.Conv2d. But that object is going to end up with the 00:11:01.400 |
activations inside the dot saved attribute. So now a downsampling block is just a sequential 00:11:10.920 |
of saved res blocks. As per usual, the very first one is going to have the number of in 00:11:22.640 |
channels to start with, and it will always have nf, the number of filters, as 00:11:27.640 |
output, and then after that the inputs will also be equal to nf because the first one's 00:11:32.560 |
changed the number of channels. And we'll do that for however many layers we have. And 00:11:38.680 |
then at the end of that process, as we discussed, we will add to that sequential a saved conv 00:11:45.600 |
with stride 2 to do the downsampling if requested. So we're going to end up with a single nn.Sequential for a down block. 00:11:51.360 |
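(A sketch of that down block, building on the SavedResBlock and SavedConv classes above; the argument names num_layers and add_down are my guesses at the notebook's.)

```python
import torch.nn as nn

def down_block(ni, nf, num_layers=2, add_down=True):
    # num_layers saved res blocks: the first goes ni -> nf, the rest nf -> nf
    layers = [SavedResBlock(ni if i == 0 else nf, nf) for i in range(num_layers)]
    if add_down:
        # optional stride-2 saved conv to halve the spatial resolution
        layers.append(SavedConv(nf, nf, kernel_size=3, stride=2, padding=1))
    return nn.Sequential(*layers)
```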
And then an up block is going to look very similar, but instead of 00:12:01.800 |
using an nn.Conv2d with stride 2, upsampling will be done with a sequence of an upsampling layer and a conv. 00:12:13.400 |
And so literally all the upsampling layer does is it just duplicates every pixel four times into a little 00:12:17.360 |
two by two grid. That's what an upsampling layer does, nothing clever. And then we follow 00:12:22.360 |
that by a stride-1 convolution. So that allows it to, you know, adjust some of those pixels 00:12:31.000 |
if necessary with a simple three by three conv. So that's pretty similar to a stride-2 downsampling. 00:12:40.040 |
This is kind of the rough equivalent for upsampling. There are other ways of doing upsampling. 00:12:47.120 |
This is just the one that stable diffusion does. So an up block looks a lot like a down 00:12:54.540 |
block, except that now, so as before, we're going to create a bunch of unet res blocks. 00:13:02.360 |
These are not saved res blocks, of course. We want to use the saved results in the upsampling 00:13:06.600 |
path of the unet. So we just use normal res blocks. But what we're going to do now is 00:13:13.960 |
as we go through each res net, we're going to call it not just on our activations, but 00:13:20.800 |
we're going to concatenate that with whatever was stored during the downsampling path. So 00:13:29.640 |
this is going to be a list of all of the things stored in the downsampling path. It'll be 00:13:34.880 |
passed to the up block. And so .pop will grab the last one off that list and concatenate 00:13:41.120 |
it with the activations and pass that to the res net. 00:13:45.480 |
So we need to know how many filters there were, how many activations there were in the 00:13:52.760 |
downsampling path. So that's stored here. This is the previous number of filters in the downsampling 00:13:57.920 |
path. And so the res block wanted to add those in in addition to the normal number. So that's 00:14:12.120 |
what's going to happen there. And so yeah, do that for each layer as before. And then 00:14:19.880 |
at the end, add an upsampling layer if it's been requested. So it's a boolean. OK, so 00:14:30.760 |
that's the upsampling block. Does that all make sense so far? 00:14:38.600 |
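(Here's a sketch of the upsampling helper and the up block, building on the UnetResBlock above. To keep the channel bookkeeping simple I pass in the channel count of each saved activation explicitly; the notebook instead works these out from the previous number of filters and the layer count, so treat the signature as illustrative.)

```python
import torch
import torch.nn as nn

def upsample(nf):
    # duplicate each pixel into a 2x2 grid, then let a stride-1 3x3 conv adjust the result
    return nn.Sequential(nn.Upsample(scale_factor=2.0), nn.Conv2d(nf, nf, 3, padding=1))

class UpBlock(nn.Module):
    def __init__(self, ni, nf, saved_channels, add_up=True):
        super().__init__()
        resnets, ch = [], ni
        for sc in saved_channels:              # one res block per saved cross connection
            resnets.append(UnetResBlock(ch + sc, nf))
            ch = nf
        self.resnets = nn.ModuleList(resnets)
        self.up = upsample(nf) if add_up else nn.Identity()

    def forward(self, x, saved):
        for resnet in self.resnets:
            # pop the most recent saved activation and concatenate it on the channel axis
            x = resnet(torch.cat([x, saved.pop()], dim=1))
        return self.up(x)
```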
OK. OK, so the unet now is going to look a lot like our previous unet. We're going to 00:14:49.320 |
start out as we tend to with a convolution to now allow us to create a few more channels. 00:14:57.080 |
And so we're passing to our unit. That's just how many channels are in your image and how 00:15:02.320 |
many channels are in your output image. So for normal full color images, that'll be 3 and 3. 00:15:08.780 |
How many filters are there for each of those res net blocks, up blocks and down blocks 00:15:13.800 |
you've got. And in the downsampling, how many layers are there in each block? So the conv 00:15:21.240 |
will go from in channels, so that'd be 3, to nfs[0], which is the number of filters 00:15:26.800 |
in the stable diffusion model. They're pretty big, as you see, by default. And so that's 00:15:37.160 |
the number of channels we would create, which is, like, very redundant in that this is a 3 00:15:45.480 |
by 3 conv. So it only contains 3 by 3 by 3 channels equals 27 inputs and 224 outputs. 00:15:53.360 |
So it's not doing useful computation in a sense. It's just giving it more space 00:16:00.220 |
to work with down the line, which I'm not sure makes sense, but I haven't played 00:16:06.360 |
with it enough to be sure. Normally we would do like, you know, like a few res blocks or 00:16:15.200 |
something at this level to more gradually increase it because this feels like a lot 00:16:18.640 |
of wasted effort. But yeah, I haven't studied that closely enough to be sure. 00:16:24.360 |
So Jeremy, just to note, this is the default, I think, the default settings for the unconditional 00:16:30.000 |
unet in diffusers. But the stable diffusion unet actually has even more channels. It has 00:16:34.400 |
320, 640, and then 1,280, 1,280. Cool. Thanks for clarifying. And it's, yeah, 00:16:41.800 |
the unconditional one, which is what we're doing right now. That's a great point. Okay. 00:16:47.720 |
So then we, yeah, we go through all of our number of filters and actually the first res 00:16:54.520 |
block contains 224 to 224. So that's why it's kind of keeping track of this stuff. And then 00:17:01.480 |
the second res block is 224 to 448 and then 448 to 672 and then 672 to 896. So that's 00:17:07.760 |
why we're just going to have to keep track of these things. So yeah, we add, so we have 00:17:12.800 |
a sequential for our down blocks and we just add a down block. The very last one doesn't 00:17:18.440 |
have downsampling, which makes sense, right? Because the very last one, there's nothing 00:17:23.560 |
after it, so no point downsampling. Other than that, they all have downsampling. 00:17:29.240 |
And then we have one more res block in the middle, which, is that the same as what we 00:17:37.000 |
did? Okay. So we didn't have a middle res block in our original unet here. What about 00:17:47.280 |
this one? Do we have any mid blocks? No, so we hadn't. Okay. But I mean, it's 00:17:51.240 |
just another res block that you do after the downsampling. And then we go through the reversed 00:17:57.920 |
list of filters and go through those and adding up blocks. And then one convolution at the 00:18:02.960 |
end to turn it from 224 channels to three channels. Okay. And so the forward then is 00:18:14.560 |
going to store in saved all the layers, just like we did back with this unit. But we don't 00:18:27.240 |
really have to do it explicitly now. We just call the sequential model. And thanks to our 00:18:32.840 |
automatic saving, each of those now will, we can just go through each of those and grab 00:18:38.400 |
their dot saved. So that's handy. We then call that mid block, which is just another 00:18:44.080 |
res block. And then same thing. Okay. Now for the up blocks. And what we do is we just pass 00:18:48.360 |
in those saved activations, right? And just remember, it's going to pop them out each time. And 00:18:56.600 |
then the conv at the end. So that's, yeah, that's it. That's our unconditional model. 00:19:06.200 |
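(Putting that together, the forward pass looks roughly like this, a rough reconstruction. I'm passing in the blocks as already-built modules, with the down blocks assumed to be Sequentials of Saved* modules as sketched earlier, because the interesting bit here is just the plumbing of the saved activations; the real constructor also works out all the channel counts.)

```python
import torch.nn as nn

class UNetModel(nn.Module):
    def __init__(self, conv_in, downs, mid_block, ups, conv_out):
        super().__init__()
        self.conv_in, self.downs, self.mid_block = conv_in, downs, mid_block
        self.ups, self.conv_out = ups, conv_out

    def forward(self, x):
        x = self.conv_in(x)
        saved = [x]                       # the conv_in output becomes a cross connection too
        x = self.downs(x)                 # the Saved* modules stash their outputs as they go
        saved += [layer.saved for block in self.downs for layer in block]
        x = self.mid_block(x)
        for block in self.ups:
            x = block(x, saved)           # each up block pops what it needs off the end
        return self.conv_out(x)
```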
It's not quite the same as the diffusers unconditional model, because it doesn't have attention, which 00:19:09.800 |
is something we're going to add next. But other than that, this is the same. So, 00:19:17.120 |
because we're doing a simpler problem, which is Fashion MNIST, we'll use fewer channels 00:19:21.560 |
than the default. Using two layers per block is standard. One thing to note though, is 00:19:29.280 |
that in the up sampling blocks, it actually is going to be three layers, num layers plus 00:19:34.560 |
one. And the reason for that is that the way stable diffusion and diffusers do it is that 00:19:41.360 |
even the output of the down sampling is also saved. So if you have num layers equals two, 00:19:48.600 |
then there'll be two res blocks saving things here and one conv saving things here. So you'll 00:19:54.280 |
have three saved cross connections. So that's why there's an extra plus one here. 00:20:02.760 |
Okay. And then we can just train it using mini AI as per usual. Nope, I didn't save 00:20:13.240 |
it after I last trained it. Sorry about that. So trust me, it trained. Okay. Now that, oh, 00:20:22.800 |
okay. No, that is actually missing something else important as well as attention. The other 00:20:26.840 |
thing it's missing is that thing that we discovered is pretty important, which is the time embedding. 00:20:34.600 |
So we already know that sampling doesn't work particularly well without that time embedding. 00:20:38.280 |
So I didn't even bother sampling this. I didn't want to add all this stuff necessary to make 00:20:42.240 |
that work a bit better. So let's just go ahead and do time embedding. 00:20:48.520 |
So time embedding, there's a few ways to do it. And the way it's done in stable diffusion 00:20:55.340 |
is what's called sinusoidal embeddings. The basic idea, maybe we'll skip ahead a bit. 00:21:04.680 |
The basic idea is that we're going to create a res block with embeddings where forward 00:21:11.080 |
is not just going to get the activations, but it's also going to get T, which is a vector 00:21:19.480 |
that represents the embeddings of each time step. So actually it'll be a matrix because 00:21:25.320 |
it's really in the batch. But for one element of the batch, it's a vector. And it's an embedding 00:21:29.880 |
in exactly the same way as when we did NLP. Each token had an embedding. And so the word 00:21:36.760 |
"the" would have an embedding and the word "Johnno" would have an embedding and the word 00:21:41.040 |
"Tanishk" would have an embedding, although Tanishk would probably actually be multiple 00:21:44.800 |
tokens until he's famous enough that he's mentioned in nearly every piece of literature, 00:21:50.320 |
at which point Tanishk will get his own token, I expect. That's how you know when you've 00:21:56.480 |
made it. So the time embedding will be the same. T of time step zero will have a particular 00:22:05.600 |
vector, time step one will have a particular vector, and so forth. Well, we're doing Karras. 00:22:12.760 |
So actually they're not time step one, two, three. They're actually sigmas, you know. 00:22:18.560 |
So they're continuous. But same idea. A specific value of sigma, which is actually what T is 00:22:26.960 |
going to be, slightly confusingly, will have a specific embedding. Now, we want two values 00:22:37.240 |
of sigma or T, which are very close to each other, should have similar embeddings. And 00:22:43.840 |
if they're different to each other, they should have different embeddings. So how do we make 00:22:51.600 |
that happen? You know, and also make sure there's a lot of variety of the embeddings across 00:22:56.440 |
all the possibilities. So the way we do that is with these sinusoidal time steps. So let's 00:23:04.680 |
have a look at how they work. So you first have to decide how big do you want your embeddings 00:23:09.440 |
to be? Just like we do at NLP. Does the word "the," is it represented by eight floats 00:23:16.440 |
or 16 floats or 400 floats or whatever? Let's just assume it's 16 now. So let's say we're 00:23:24.120 |
just looking at a bunch of time steps, which is between negative 10 and 10. And we'll just 00:23:30.400 |
do 100 of them. I mean, we don't actually have negative sigmas or T. So it doesn't exactly 00:23:36.120 |
make sense. But it doesn't matter. It gives you the idea. And so then we say, OK, what's 00:23:43.960 |
the largest time step you could have or the largest sigma that you could have? Interestingly, 00:23:50.440 |
every single model I've found, every single model I've found uses 10,000 for this. Even 00:23:57.280 |
though that number actually comes from the NLP transformers literature, and it's based 00:24:01.920 |
on the idea of, like, OK, what's the maximum sequence length we support? You could have 00:24:05.960 |
up to 10,000 things in a document or whatever in a sequence. But we don't actually have 00:24:14.800 |
a sigmas that go up to 10,000. So I'm using the number that's used in real life in stable 00:24:19.200 |
diffusion and all the other models. But interestingly, here purely, as far as I can tell, as a historical 00:24:25.280 |
accident, because this is like the maximum sequence length that NLP transformers people 00:24:31.200 |
thought they would need to support. OK, now what we're then going to do is we're going 00:24:38.640 |
to be then doing e to the power of a bunch of things. And so that's going to be our exponent. 00:24:47.720 |
And so our exponent is going to be equal to the log of the max period, which is about nine, times 00:24:57.820 |
the numbers between 0 and 1, eight of them, because we said we want 16. So you'll see 00:25:04.720 |
why we want eight of them and not 16 in a moment. But basically here are the eight exponents 00:25:11.600 |
we're going to use. So then not surprisingly, we do e to the power of that. OK, so we do 00:25:18.880 |
e to the power of that, each of these eight things. And we've also got the actual time 00:25:28.440 |
steps. So imagine these are the actual time steps we have in our batch. So there's a batch 00:25:33.320 |
of 100, and they contain this range of sigmas or time steps. So to create our embeddings, 00:25:43.200 |
what we do is we do an outer product of exponent.exp() and the time steps. This is step 00:25:52.520 |
one. And so this is using a broadcasting trick we've seen before. We add a unit axis 00:26:01.560 |
at axis 1 here, and add a unit axis at axis 0 here. So if 00:26:13.840 |
we multiply those together, then it's going to broadcast this one across this axis and 00:26:22.520 |
this one across this axis. So we end up with a 100 by 8. So it's basically a Cartesian 00:26:28.760 |
product of the possible combinations of time step and exponent multiplied together. And 00:26:36.080 |
so here's a few of those different exponents for a few different values. OK, so that's 00:26:50.800 |
not very interesting yet. We haven't yet reached something where each time step is similar 00:26:57.960 |
to each next door time step. You know, over here, you know, these embeddings look very 00:27:04.000 |
different to each other. And over here, they're very similar. So what we then do is we take 00:27:10.720 |
the sine and the cosine of those. So that is 100 by 8. And that is 100 by 8. And that 00:27:23.780 |
gives us 100 by 16. So we concatenate those together. And so that's a little bit hard 00:27:34.920 |
to wrap your head around. So let's take a look. So across the 100 time steps, 100 sigma, 00:27:44.120 |
this one here is the first sine wave. And then this one here is the second sine wave. 00:27:57.240 |
And this one here is the third. And this one here is the fourth and the fifth. So you can 00:28:05.200 |
see as you go up to higher numbers, you're basically stretching the sine wave out. And 00:28:17.160 |
then once you get up to index 8, you're back up to the same frequency as this blue one, 00:28:25.240 |
because now we're starting the cosine rather than sine. And cosine is identical to sine. 00:28:30.200 |
It's just shifted across a tiny bit. You can see these two light blue lines are the same. 00:28:35.080 |
And these two orange lines are the same. They're just shifted across, I shouldn't say, lines 00:28:39.600 |
or curves. So when we concatenate those all together, we can actually draw a picture of 00:28:45.720 |
it. And so this picture is 100 pixels across and 16 pixels top to bottom. And so if you 00:28:56.880 |
picked out a particular point, so for example, in the middle here for t equals 0, well sigma 00:29:03.000 |
equals 0, one column is an embedding. So the bright represents higher numbers and the dark 00:29:11.260 |
represents lower numbers. And so you can see every column looks different, even though 00:29:17.280 |
the columns next to each other look similar. So that's called a time step embedding. And 00:29:26.180 |
this is definitely something you want to experiment with. I've tried to do the plots I thought 00:29:34.060 |
are useful to understand this. And Johnno and Tanishk also had ideas about plots for these, 00:29:42.320 |
which we've shown. But the only way to really understand them is to experiment. So then 00:29:49.320 |
we can put that all into a function where you just say, OK, well, how many times-- sorry, 00:29:53.680 |
what are the time steps? How many embedding dimensions do you want? What's the maximum 00:29:57.780 |
period? And then all I did was I just copied and pasted the previous cells and merged them 00:30:02.200 |
together. So you can see there's our outer product. And there's our cat of sine and cos. 00:30:15.020 |
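(Collected into a function, it's roughly the following; my reconstruction. I've put a minus sign on the exponent so that later columns get progressively lower frequencies, which matches the stretched-out sine waves in the plots; the odd-dimension padding detail may also differ from the notebook.)

```python
import math
import torch
import torch.nn.functional as F

def timestep_embedding(tsteps, emb_dim, max_period=10000):
    # half the embedding dims get sines, the other half cosines
    exponent = -math.log(max_period) * torch.linspace(0, 1, emb_dim // 2, device=tsteps.device)
    emb = tsteps[:, None].float() * exponent.exp()[None, :]   # outer product: (n, emb_dim//2)
    emb = torch.cat([emb.sin(), emb.cos()], dim=-1)           # (n, emb_dim) when emb_dim is even
    return F.pad(emb, (0, 1)) if emb_dim % 2 else emb         # pad odd dims with a zero column

# e.g. timestep_embedding(torch.linspace(-10, 10, 100), 16) gives the 100-by-16 picture above
```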
If you end up with a-- if you have an odd numbered embedding dimension, you have to 00:30:18.480 |
pad it to make it even. Don't worry about that. So here's something that now you can pass 00:30:22.800 |
in the number of-- sorry, the actual time steps or sigma's and the number of embedding 00:30:27.440 |
dimensions. And you will get back something like this. It won't be a nice curve because 00:30:34.480 |
your time steps in a batch won't all be next to each other. It's the same idea. 00:30:39.600 |
Can I call out something on that little visualization there, which goes back to your comment about 00:30:45.400 |
the max period being super high? So you said, OK, adjacent ones are somewhat similar because 00:30:51.200 |
that's what we want. But there is some change. But if you look at all of this first 100, 00:30:57.000 |
some, like half of the embeddings, look like they don't really change at all. 00:31:01.200 |
And that's because 50 to 100 on a scale of like 0 to 10,000, you want those to be quite 00:31:07.240 |
similar because those are still very early in this super long sequence that these are designed for. 00:31:15.560 |
Yeah. So here we've got a max period of 1,000 instead. And I've changed the figure size so 00:31:20.040 |
you can see it better. And it's using up a bit more of the space. Yeah. Or go to max 00:31:27.240 |
period of 10. And it's actually now-- this is, yeah, using it much better. Yeah. So based 00:31:36.440 |
on what you're saying, Johnno, I agree. It seems like it would be a lot richer to use 00:31:44.280 |
these time step embeddings with a suitable max period. Or maybe you just wouldn't need 00:31:48.640 |
as many embedding dimensions. I guess if you did use something very wasteful like this 00:31:53.600 |
but you used lots of embedding dimensions, then it's going to still capture some useful information. 00:32:01.520 |
Yeah. Thanks, Johnno. So yeah. Yeah. So this is one of these interesting little insights 00:32:10.840 |
about things that are buried deep in code, which I'm not sure anybody probably much looks at. 00:32:20.800 |
OK. So let's do a unet with time step embedding in it. So what do you do once you've got like 00:32:31.440 |
this column of embeddings for each item of the batch? What do you do with it? Well, there's 00:32:37.640 |
a few things you can do with it. What stable diffusion does, I think (I'm not 00:32:44.960 |
promising I remember all these details right), is that they make their embedding dimension 00:32:51.560 |
length twice as big as the number of activations. And what they then do is we can use chunk 00:33:04.800 |
to take that and split it into two separate variables. So that's literally just the opposite 00:33:10.720 |
of concatenate. It's just two separate variables. And one of them is added to the activations 00:33:17.700 |
and one of them is multiplied by the activations. So this is a scale and a shift. We don't just 00:33:27.680 |
grab the embeddings as is, though, because each layer might want to do-- each res block 00:33:35.000 |
might want to do different things with them. So we have a embedding projection, which is 00:33:41.400 |
just a linear layer which allows them to be projected. So it's projected from the number 00:33:45.920 |
of embeddings to 2 times the number of filters so that that torch.chunk works. We also have 00:33:55.200 |
an activation function called silu. This is the activation function that's used in stable 00:34:01.800 |
diffusion. I don't think the details are particularly important. But it looks basically like a rectified 00:34:15.120 |
linear with a slight curvy bit. Also known as SWISH. Also known as SWISH. And it's just 00:34:25.960 |
equal to x times sigmoid x. And yeah, I think it's like activation functions don't make 00:34:37.120 |
a huge difference. But they can make things train a little better or a little faster. 00:34:44.120 |
And SWISH has been something that's worked pretty well. So a lot of people using SWISH 00:34:48.880 |
or silu. I always call it SWISH. But I think silu was originally from the GELU paper, which 00:34:57.440 |
is where it was originally kind of invented. And maybe people didn't quite 00:35:01.320 |
notice. And then another paper called it SWISH. And everybody called it SWISH. And then people 00:35:05.640 |
were like, wait, that wasn't the original paper. So I guess I should try to call it 00:35:11.080 |
silu. Other than that, it's just a normal res block. So we do our first conv. Then we do 00:35:19.200 |
our embedding projection of the activation function of time steps. And so that's going 00:35:24.800 |
to be applied to every pixel, height and width. So that's why we have to add unit axes on 00:35:34.840 |
the height and width, so that it's going to broadcast across those two axes. Do 00:35:39.400 |
our chunk. Do the scale and shift. Then we're ready for the second conv. And then we add 00:35:44.800 |
it to the input with an additional conv, one stride one conv if necessary as we've done 00:35:52.680 |
before if we have to change the number of channels. OK. 00:36:00.920 |
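(A sketch of that res block with embeddings; the norm and activation defaults are assumptions, and the literal scale-and-shift follows the description here, although some implementations multiply by (1 + scale) instead.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbResBlock(nn.Module):
    def __init__(self, n_emb, ni, nf, ks=3):
        super().__init__()
        # project the time embedding to 2*nf so torch.chunk can split it into a scale and a shift
        self.emb_proj = nn.Linear(n_emb, nf * 2)
        self.conv1 = nn.Sequential(nn.BatchNorm2d(ni), nn.SiLU(), nn.Conv2d(ni, nf, ks, padding=ks//2))
        self.conv2 = nn.Sequential(nn.BatchNorm2d(nf), nn.SiLU(), nn.Conv2d(nf, nf, ks, padding=ks//2))
        self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1)

    def forward(self, x, t):
        inp = x
        x = self.conv1(x)
        # silu the embedding, project it, then add unit axes so it broadcasts over height and width
        emb = self.emb_proj(F.silu(t))[:, :, None, None]
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = x * scale + shift      # some implementations use (1 + scale) here
        x = self.conv2(x)
        return x + self.idconv(inp)
```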
Yeah, because I like exercising our Python muscles, I decided to use a second approach now for the down block and the up 00:36:07.680 |
block. I'm not saying which one's better or worse. We're not going to use multiple inheritance 00:36:15.960 |
anymore. But instead, we're going to use-- well, it's not even a decorator. It's a function 00:36:23.200 |
which takes a function. What we're going to do now is we're going to use nn.Conv2d and EmbResBlock 00:36:29.680 |
directly. But we're going to pass them to a function called saved. The function called 00:36:34.120 |
saved is something which is going to take as input a callable, which could be a function 00:36:41.920 |
or a module or whatever. So in this case, it's a module. Takes an EmbResBlock or a Conv2d. 00:36:49.480 |
And it returns a callable. The callable it returns is identical to the callable that's 00:36:56.000 |
passed into it, except that it saves the result, saves the activations, saves the result of 00:37:01.640 |
a function. Where does it save it? It's going to save it into a list in the second argument 00:37:10.080 |
you pass to it, which is the block. So the save function, you're going to pass it the 00:37:18.280 |
module. We're going to grab the forward from it and store that away to remember what it 00:37:23.080 |
was. And then the function that we want to replace it with, call it _f, is going 00:37:29.320 |
to take some arguments and some keyword arguments. Well, basically, it's just going to call the 00:37:32.840 |
original module's forward, passing in the arguments and keyword arguments. And we're then going 00:37:39.880 |
to store the result in something called the saved attribute inside here. And then we have 00:37:51.320 |
to return the result. So then we're going to replace the module's forward method with 00:37:57.720 |
this function and return the module. So that module's now been... yeah, I said callable, 00:38:04.760 |
actually. It can't just be any callable. It has to specifically be a module, because it's the forward that 00:38:08.600 |
we're changing. This @wraps is just something from the Python 00:38:14.120 |
standard library which automatically copies in the documentation and everything from the 00:38:18.200 |
original forward so that it all looks like nothing's changed. 00:38:26.280 |
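(Roughly, the saved function looks like this; my reconstruction from the description. The second argument blk is the block whose .saved list collects the activations, as discussed next.)

```python
from functools import wraps

def saved(m, blk):
    m_ = m.forward                 # remember the original forward

    @wraps(m.forward)              # copy over the original forward's name and docs
    def _f(*args, **kwargs):
        res = m_(*args, **kwargs)  # call the original forward
        blk.saved.append(res)      # stash the result on the block
        return res

    m.forward = _f                 # swap the module's forward for our wrapper
    return m
```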
Now, where does this dot saved come from? I realized now, actually, we could make this 00:38:30.280 |
easier and automate it. But I forgot, didn't think of this at the time. So we have to create 00:38:36.200 |
the saved here in the down block. It actually would have made more sense, I think, here 00:38:42.120 |
for it to have said if the saved attribute doesn't exist, then create it, which would 00:38:47.160 |
look like this: if not hasattr(block, 'saved'): block.saved = []. Because if you do this, 00:38:59.800 |
then you wouldn't need this anymore. Anyway, I didn't think of that at the time. So let's 00:39:06.680 |
pretend that's not what we do. OK, so now the downsampling conv and the resnets both 00:39:19.880 |
contain saved versions of modules. We don't have to do anything to make that work. We 00:39:26.120 |
just have to call them. We can't use sequential anymore, because we have to pass in the time 00:39:31.880 |
step to the resnets as well. It would be easy enough to create your own sequential for things 00:39:39.480 |
with time steps, which passes them along. But that's not what we're doing here. Yeah, maybe 00:39:48.600 |
it makes sense for sequential to always pass along all the extra arguments. But I don't 00:39:53.240 |
think that's how they work. Yeah, so our up block is basically exactly the same as before, 00:40:01.320 |
except we're now using EmbResBlocks instead. Just like before, we're going to concatenate. 00:40:06.360 |
So that's all the same. OK, so a unit model with time embeddings 00:40:15.720 |
is going to look, if we look at the forward, the thing we're passing into it now is a tuple 00:40:25.160 |
containing the activations and the time steps, or the sigmas in our case. So split them out. 00:40:31.560 |
And what we're going to do is we're going to call that time step embedding function we wrote, 00:40:37.800 |
saying, OK, these are the time steps. And the number of time step embeddings we want 00:40:44.200 |
is equal to however many we asked for. And we're just going to set it equal to the first number 00:40:52.200 |
of filters. That's all that happens there. And then we want to give the model the ability 00:41:01.400 |
then to do whatever it wants with those, to make those work the way it wants to. 00:41:04.680 |
And the easiest, smallest way to do that is to create a tiny little MLP. So we create a tiny 00:41:09.960 |
little MLP, which is going to take the time step embeddings and return the actual embeddings to 00:41:15.480 |
pass into the res net blocks. So the tiny little MLP is just a linear layer with... hmm, I'm thinking here. 00:41:27.160 |
That's interesting. My linear layer by default has an activation function. I'm pretty sure we 00:41:44.600 |
should have act equals none here. It should be a linear layer and then an activation and then a 00:41:50.040 |
linear layer. So I think I've got a bug, which we will need to try rerunning. 00:41:55.160 |
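(For reference, the corrected MLP would look something like the following, with the activation only between the two linear layers; the sizes here are hypothetical, and the course's own linear helper, which can also include a norm, is approximated with plain layers.)

```python
import torch.nn as nn

n_emb = 224  # hypothetical: the first number of filters

emb_mlp = nn.Sequential(
    nn.Linear(n_emb, n_emb),
    nn.SiLU(),                 # activation in the middle
    nn.Linear(n_emb, n_emb),   # no activation at the end, so negatives aren't thrown away
)
```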
OK. It won't be the end of the world. It just means all the negatives will be lost here. 00:42:14.280 |
Makes it half-- only half as useful. That's not great. OK. And these are the kind of things like, 00:42:23.240 |
you know, as you can see, you've got to be super careful of, like, where do you have activation 00:42:28.040 |
functions? Where do you have batch norms? Is it pre-activation? Is it post-activation? 00:42:33.480 |
It trains even if you make that mistake. And in this case, it's probably not costing too much performance, 00:42:40.760 |
but often it's like, oh, you've done something where you accidentally zeroed out, you know, 00:42:46.040 |
all except the last few channels of your output block or something like that. 00:42:53.000 |
The network tries anyway; it does its best. It uses what it can. 00:43:00.760 |
So you've got to make sure you're not giving it those handicaps. 00:43:04.040 |
Yeah. It's not like you're making a CRUD app or something and you know that it's not working 00:43:08.120 |
because it crashes or because, like, it doesn't show the username or whatever. 00:43:13.800 |
Instead, you just get, like, slightly less good results. But since you haven't 00:43:20.200 |
done it correctly in the first place, you don't know it's the less good results. 00:43:23.640 |
Yeah, there's not really great ways to do this. It's really nice if you can have an 00:43:28.360 |
existing model to compare to or something like that, which is where Kaggle competitions work 00:43:35.080 |
really well. Actually, if somebody's got a Kaggle result, then you know that's a really 00:43:38.840 |
good baseline and you can check whether yours is as good as theirs. 00:43:43.320 |
All right. So, yeah, that's what this MLP is for. So, the down and up blocks are the 00:43:52.920 |
same as before. The convout is the same as before. So, yeah, so we grab our time step embedding. So, 00:43:57.960 |
that's just that outer product passed through this sinusoidal, the sine and cosine. We then 00:44:03.400 |
pass that through the MLP. And then we call our downsampling, passing in those embeddings each 00:44:19.160 |
time. You know, it's kind of interesting that we pass in the embeddings every time 00:44:23.240 |
in the sense I don't exactly know why we don't just pass them in at the start. 00:44:28.280 |
And in fact, in NLP, these kinds of embeddings, I think, are generally just passed in at the start. 00:44:34.440 |
So, this is kind of a curious difference. I don't know why. It's, you know, if there's 00:44:41.800 |
been ablation studies or whatever. Do you guys know, are there like any popular diffusion-y or 00:44:48.760 |
generative models with time embeddings that don't pass them in at every stage, or is this pretty universal? 00:45:02.200 |
Recurrent interface networks and stuff just pass in the conditioning. 00:45:06.680 |
I'm actually not sure. Yeah, maybe they do still do it at every stage. I think some of them just 00:45:15.240 |
take in everything all at once up front and then do a stack of transformer blocks or something 00:45:19.800 |
like that. So, I don't know if it's universal, but it definitely seems like all the unet-style ones 00:45:24.840 |
have the time step embedding going in. >> Maybe we should try some ablations to see, 00:45:31.320 |
yeah, if it matters. I mean, I guess it doesn't matter too much either way. But, yeah, 00:45:37.800 |
if you didn't need it at every step, then it would maybe save you a bit of compute, potentially. 00:45:43.480 |
Yeah, so now the upsampling, you're passing in the activations, the time step embeddings, 00:45:50.280 |
and that list of saved activations. So, yeah, now we have a non-attention stable diffusion unet. 00:46:01.080 |
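(The forward of the time-conditioned unet then looks roughly like this, with the blocks again assumed to be built elsewhere, timestep_embedding from the earlier sketch, and each down block assumed to collect its children's outputs in its .saved list; names like n_temb and emb_mlp are my guesses.)

```python
import torch.nn as nn

class EmbUNetModel(nn.Module):
    def __init__(self, conv_in, downs, mid_block, ups, conv_out, emb_mlp, n_temb):
        super().__init__()
        self.conv_in, self.downs, self.mid_block = conv_in, downs, mid_block
        self.ups, self.conv_out = ups, conv_out
        self.emb_mlp, self.n_temb = emb_mlp, n_temb

    def forward(self, inp):
        x, t = inp                                  # activations and sigmas arrive as a tuple
        temb = timestep_embedding(t, self.n_temb)   # sinusoidal embedding, sketched earlier
        emb = self.emb_mlp(temb)                    # the little MLP remaps the embedding
        x = self.conv_in(x)
        saved = [x]                                 # the conv_in output is a cross connection too
        for block in self.downs:
            x = block(x, emb)                       # the embedding goes into every down block
        saved += [s for block in self.downs for s in block.saved]
        x = self.mid_block(x, emb)
        for block in self.ups:
            x = block(x, emb, saved)                # up blocks pop what they need off saved
        return self.conv_out(x)
```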
So, we can train that. And we can sample from it 00:46:11.640 |
using the same -- I just copied and pasted all the stuff from the Karras notebook that we had. 00:46:19.880 |
And there we have it. This is our first diffusion from scratch. 00:46:26.440 |
>> So, we wrote every piece of code for this diffusion model. 00:46:32.280 |
>> Yeah, I believe so. I mean, obviously, in terms of the optimized CUDA implementations of 00:46:38.600 |
stuff, no. But, yeah, we've written our version of everything here, I believe. 00:46:46.280 |
>> I think so, yeah. And these FIDs are about the same as the FIDs that we get 00:46:50.360 |
from the stable diffusion one. They're not particularly higher or lower. They bounce 00:46:56.200 |
around a bit, so it's a little hard to compare. Yeah, they're basically the same. 00:46:59.880 |
Yeah, so that's -- that is an exciting step. And 00:47:11.240 |
okay, yeah, that's probably a good time to have a five-minute break. 00:47:20.440 |
Okay. Normally, I would say we're back, but only some of us are back. 00:47:29.480 |
Johnno -- Johnno's internet and electricity in Zimbabwe is not the most reliable thing. 00:47:36.280 |
And he seems to have disappeared, but we expect him to reappear at some point. 00:47:39.960 |
So we will kick on Johnno-less and hope that Zimbabwe's infrastructure sorts itself out. 00:47:50.360 |
All right. So we're going to talk about attention. We're going to talk about attention for a few 00:47:58.760 |
reasons. Reason number one, very pragmatic. We said that we would replicate stable diffusion, 00:48:04.120 |
and the stable diffusion unet has attention in it. So we would be lying if we didn't do attention. 00:48:12.840 |
Reason two, attention is one of the two basic building blocks of transformers. A transformer layer is attention 00:48:22.280 |
attached to a one-layer MLP. We already know how to create a one-layer or one-hidden layer MLP. 00:48:28.280 |
So once we learn how to do attention, we'll know how to -- we'll know how to create transformer 00:48:34.760 |
blocks. So those are two good reasons. I'm not including a reason which is our model is going 00:48:46.760 |
to look a lot better with attention, because I actually haven't had any success seeing any 00:48:51.480 |
diffusion models I've trained work better with attention. So just to set your expectations, 00:48:59.960 |
we are going to get it all working. But regardless of whether I use our implementation of attention 00:49:05.960 |
or the diffusers one, it's not actually making it better. That might be because we need to use 00:49:15.560 |
better types of attention than what diffusers has, or it might be because it's just a very 00:49:21.160 |
subtle difference that you only see on bigger images. I'm not sure. That's something we're still 00:49:28.360 |
trying to figure out. This is all pretty new. And not many people have done kind of the diffusion, 00:49:34.840 |
the kind of ablation studies necessary to figure these things out. So yeah, so that's just life. 00:49:42.040 |
Anyway, so there's lots of good reasons to know about attention. We'll certainly be using it a 00:49:46.840 |
lot once we do NLP, which we'll be coming to pretty shortly, pretty soon. And it looks like 00:49:54.360 |
Jono is reappearing as well. So that's good. Okay, so let's talk about attention. So let's say we've got 00:50:17.320 |
an image, and we're going to be sliding a convolution kernel across that image. 00:50:25.880 |
And obviously, we've got channels as well, or filters. And so this also has that. Okay. 00:50:36.760 |
And as we bring it across, we might be, you know, we're trying to figure out like what 00:50:45.640 |
activations do we need to create to eventually, you know, correctly create our outputs. 00:50:51.240 |
But the correct answer as to what's here may depend on something that's way over here, 00:51:00.760 |
and/or something that's way over here. So for example, if it's a cute little bunny rabbit, 00:51:08.600 |
and this is where its ear is, you know, and there might be two different types of bunny rabbit that 00:51:15.640 |
have different shaped ears, well, it'd be really nice to be able to see over here what its other 00:51:22.040 |
ear looks like, for instance. With just convolutions, that's challenging. It's not 00:51:29.240 |
impossible. We talked in part one about the receptive field. And as you get deeper and 00:51:35.640 |
deeper in a convnet, the receptive field gets bigger and bigger. But it's, you know, 00:51:41.960 |
at higher up, it probably can't see the other ear at all. So it can't put it into those kind of more 00:51:46.920 |
texture level layers. And later on, you know, even though this might be in the receptive field here, 00:51:54.200 |
most of the weight, you know, the vast majority of the activations it's using is the stuff 00:51:59.480 |
immediately around it. So what attention does is it lets you take a weighted average of other pixels 00:52:10.440 |
around the image, regardless of how far away they are. And so in this case, for example, 00:52:18.840 |
we might be interested in bringing in at least a few of the channels of these pixels over here. 00:52:28.280 |
The way that attention is done in stable diffusion is pretty hacky and known to be suboptimal. 00:52:39.960 |
But it's what we're going to implement because we're implementing stable diffusion and time 00:52:45.400 |
permitting. Maybe we'll look at some other options later. But the kind of attention we're going to 00:52:50.600 |
be doing is 1D attention. And it was attention that was developed for NLP. And NLP is sequences, 00:52:59.400 |
one-dimensional sequences of tokens. So to do attention stable diffusion style, 00:53:04.840 |
we're going to take this image and we're going to flatten out the pixels. 00:53:11.400 |
So we've got all these pixels. We're going to take this row and put it here. And then we're 00:53:18.680 |
going to take this row, we're going to put it here. So we're just going to flatten the whole thing out 00:53:23.400 |
into one big vector of all the pixels of row one and then all the pixels of row two and then all 00:53:29.880 |
the pixels of the row three. Or maybe it's column one, column two, column three. I can't remember 00:53:32.520 |
if it's row-wise or column-wise, but it's flattened out anywho. And then it's actually, for each image, 00:53:44.120 |
it's actually a matrix, which I'm going to draw it a little bit 3D because we've got the channel 00:53:52.280 |
dimension as well. So this is going to be the number across this way is going to be equal to 00:54:00.520 |
the height times the width. And then the number this way is going to be the number of channels. 00:54:13.800 |
Okay, so how do we decide, yeah, which of these other pixels to bring in? Well, what we do 00:54:13.800 |
is we basically create a weighted average of all of these pixels. So maybe these ones get a bit 00:54:34.680 |
of a negative weight and these ones get a bit of a positive weight and, you know, these get a weight 00:54:45.960 |
kind of somewhere in between. And so we're going to have a weighted average. And so basically each 00:54:52.520 |
pixel, so let's say we're doing this pixel here right now, is going to equal its original pixel, 00:54:59.400 |
so let's call it x_i, plus the weighted average: a sum of weight times pixel, w_j times x_j, 00:55:13.240 |
over all the other pixels j, so from zero to the height times the width. 00:55:35.800 |
The weights, they're going to sum to one. And so that way the, you know, the pixel value 00:55:52.920 |
scale isn't going to change. Well, that's not actually quite true. It's going to end up 00:55:57.640 |
potentially twice as big, I guess, because it's being added to the original pixel. 00:56:01.000 |
So attention itself is not with the x plus, but the way it's done in stable diffusion, 00:56:10.120 |
at least, is that the attention is added to the original pixel. So, yeah, now I think about it. 00:56:17.640 |
I'm not going to need to think about how this is being scaled, anyhow. So the big question is what 00:56:27.640 |
values to use for the weights. And the way that we calculate those is we do a matrix product. 00:56:41.320 |
And so our, for a particular pixel, we've got, 00:56:52.680 |
you know, the number of channels for that one pixel. 00:57:02.120 |
And what we do is we can compare that to all of the number of channels for all the other pixels. 00:57:13.080 |
So we've got kind of, this is pixel, let's say x1. And then we've got pixel number x2. 00:57:21.560 |
Right, all those channels. We can take the dot product between those two things. 00:57:31.960 |
And that will tell us how similar they are. And so one way of doing this would be to say, like, 00:57:40.200 |
okay, well, let's take that dot product for every pair of pixels. And that's very easy to 00:57:46.600 |
do, because that's just what the matrix product is equal to. So if we've got an (h times w) by c matrix and we 00:57:57.960 |
multiply it by its transpose (sorry, I said transpose and then totally failed to do the 00:58:09.400 |
transpose), multiply it by its transpose, that will give us an (h times w) by (h times w) matrix. 00:58:30.440 |
So each pixel, all the pixels are down here. And for each pixel, as long as these add up to one, 00:58:38.920 |
then we've got a weight for each pixel. And it's easy to make these add up to one. We could just 00:58:43.880 |
take this matrix multiplication and take the sigmoid over the last dimension. And that makes, 00:58:54.760 |
sorry, not sigmoid. Man, what's wrong with me? Softmax, right? Yep. And take the softmax over 00:59:04.360 |
the last dimension. And that will give me something that adds up to one. 00:59:12.200 |
Okay. Now, the thing is, it's not just that we want to find the places where they look the same, 00:59:21.960 |
where the channels are basically the same, but we want to find the places where they're, like, 00:59:27.160 |
similar in some particular way, you know? And so some particular set of channels are similar in one 00:59:34.200 |
to some different set of channels in another. And so, you know, in this case, we may be looking for 00:59:39.240 |
the pointy-earedness activations, you know, which actually represented by, you know, this, this, 00:59:47.640 |
and this, you know, and we want to just find those. So the way we do that is before we do this 00:59:54.360 |
matrix product, we first put our matrix through a projection. 01:00:03.000 |
So we just basically put our matrix through a matrix multiplication, this one. So it's the 01:00:12.680 |
same matrix, right? But we put it through two different projections. And so that lets it pick 01:00:17.960 |
two different kind of sets of channels to focus on or not focus on before it decides, you know, 01:00:24.280 |
of, is this pixel similar to this pixel in the way we care about. And then actually, we don't even 01:00:29.640 |
just multiply it then by the original pixels. We also put that through a different projection as 01:00:36.120 |
well. So there's these different projections. Well, then projection one, projection two, 01:00:40.760 |
and projection three. And that gives it the ability to say, like, oh, I want to compare 01:00:45.720 |
these channels and, you know, these channels to these channels to find similarity. And based on 01:00:51.160 |
similarity, yeah, they want to pick out these channels, right? Both positive and negative 01:00:56.280 |
weight. So that's why there's these three different projections. And so the projections are called 01:01:02.760 |
K, Q, and V. Those are the projections. And so they're all being passed the same 01:01:14.440 |
matrix. And because they're all being passed the same matrix, we call this self-attention. 01:01:21.400 |
Okay, Johnno, Tanishk, I know this is, I know you guys know this very well, but you also 01:01:27.960 |
know it's really confusing. Did you have anything to add? Change? Anything else? 01:01:33.960 |
Yeah, I like that you introduced this without resorting to the, 01:01:39.000 |
let's think of this as queries at all, which I think is, yeah. 01:01:44.120 |
Yeah, these are actually short for key, query, and value, even though I personally don't find those 01:01:54.520 |
useful concepts. Yeah. One note on the scaling: you said, oh, we set it so that the weights 01:02:01.880 |
sum to one. And so then we'd need to worry about, like, are we doubling the scale of X? Yeah. But 01:02:08.360 |
because of that P3, aka V, that projection that can learn to scale this thing that's added to X 01:02:19.960 |
appropriately. And so it's not like just doubling the size of X, it's increasing it a little bit, 01:02:24.200 |
which is why we scatter normalization in between all of these attention layers. 01:02:28.680 |
But it's not as bad as it might be because we have that V projection. 01:02:36.280 |
Yeah, that's a good point. And if this P3, or actually the V projection, is 01:02:46.040 |
initialized such that it would have a mean of zero, then on average it should start out by not 01:02:54.760 |
messing with our scale. OK, so yeah, I guess I find it easier to think in terms of code. 01:03:04.440 |
So let's look at the code. You know, there's actually not much code. 01:03:07.320 |
I think you've got a bit of background noise too, Jono, maybe. Yes, that's much better. Thank you. 01:03:15.480 |
So in terms of code, there's, you know, this is one of those things where getting everything exactly 01:03:26.840 |
right matters. And it's not just getting it right; I wanted to get it identical to stable diffusion. So we can say 01:03:31.640 |
we've made it identical to stable diffusion. I've actually imported the attention block from 01:03:36.760 |
diffusers so we can compare. And it is so nice when you've got an existing version of something 01:03:41.960 |
to compare to to make sure you're getting the same results. So we're going to start off by saying, 01:03:50.120 |
let's say we've got a 16 by 16 pixel image. And this is some deeper level of activation. So it's 01:03:56.360 |
got 32 channels with a batch size of 64. So NCHW. I'm just going to use random numbers for now, 01:04:03.640 |
but this has the, you know, reasonable dimensions for an activation inside a batch size 64 01:04:09.880 |
CNN or diffusion model or unit, whatever. OK, so the first thing we have to do 01:04:15.640 |
is to flatten these out because, as I said, in 1D attention the spatial layout is just ignored. 01:04:25.400 |
So it's easy to flatten things out. You just say dot view and you pass in the dimensions of the, 01:04:31.560 |
in this case, the three dimensions we want, which is 64, 32, and everything else. Minus one means 01:04:38.040 |
everything else. So x.shape[:2]. In this case, you know, obviously it'd be easy just to 01:04:43.400 |
type 64, 32, but I'm trying to create something that I can paste into a function later. So it's general. 01:04:48.840 |
So that's the first two elements, 64 and 32. And then the star just inserts them directly in here. So 01:04:54.760 |
64, 32, minus one. So 16 by 16. Now then, again, because this is all stolen from the NLP world, 01:05:04.360 |
in the NLP world, they call this the sequence dimension. So I'm going to call this the sequence, 01:05:12.760 |
which for us is height by width. Sequence comes before channel, which is often called D or dimension. 01:05:19.160 |
So we then transpose those last two dimensions. So we've now got batch by sequence, 16 by 16, 01:05:27.800 |
by channel or dimension. So NSD, not that they really call it that: batch, sequence, dimension. 01:05:38.120 |
Okay, so we've got 32 channels. So we now need three different projections that go from 32 01:05:48.680 |
channels in to 32 channels out. So that's just a linear layer. Okay, and just remember a linear 01:05:54.120 |
layer is just a matrix multiply plus a bias. So there's three of them. And so they're all going 01:06:01.560 |
to be randomly initialized at different random numbers. We're going to call them SK, SQ, SV. 01:06:08.040 |
And so we can then, they're just callable. So we can then pass the exact same thing into three, 01:06:13.320 |
all three, because we're doing self-attention to get back our keys, queries, and values, 01:06:18.200 |
or K, Q, and V. I just think of them as K, Q, and V, because they're not really 01:06:22.600 |
keys, queries, and values to me. So then we have to do the matrix multiply by the transpose. 01:06:30.600 |
And so then for every one of the 64 items in the batch, for every one of the 256 pixels, 01:06:40.200 |
there are now 256 weights. So at least there would be if we had done softmax, which we haven't yet. 01:06:45.720 |
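(Here's the kind of exploration being described, with made-up activations; shapes are in the comments.)

```python
import torch
import torch.nn as nn

x = torch.randn(64, 32, 16, 16)        # NCHW: batch 64, 32 channels, a 16x16 grid

t = x.view(*x.shape[:2], -1)           # -> (64, 32, 256): flatten the spatial dims
t = t.transpose(1, 2)                  # -> (64, 256, 32): sequence before channels (NSD)

sk, sq, sv = nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 32)
k, q, v = sk(t), sq(t), sv(t)          # self-attention: the same input goes into all three

w = q @ k.transpose(1, 2)              # -> (64, 256, 256): one weight per pair of pixels
print(w.shape)                         # still needs the 1/sqrt(d) scaling and a softmax
```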
So we can now put that into a self-attention. As Johnno mentioned, we want to make sure that 01:06:52.120 |
we normalize things. So we add proper normalization here. We talked about group norm back when we 01:06:57.960 |
talked about batch norm. Group norm splits the channels into groups and normalizes within each group of 01:07:04.200 |
channels. Okay, so then we are going to create our K, Q, V. Yep, Johnno? 01:07:18.280 |
I was just going to ask, should those be just bias equals false so that they're only a matrix 01:07:23.480 |
multiply, to strictly match the traditional implementation? 01:07:38.520 |
Yeah, they have bias in their attention blocks. 01:07:52.200 |
Okay, so we've got our QK and V, self.q, self.k, self.v being our projections. 01:07:58.200 |
And so to do 2D self-attention, we need to find the NCHW from our shape. We can do a normalization. 01:08:09.800 |
We then do our flattening as discussed. We then transpose the last two dimensions. We then create 01:08:17.480 |
our QKV by doing the projections. And we then do the matrix multiply. Now, we've got to be a bit 01:08:24.360 |
careful now because as a result of that matrix multiply, we've changed the scale by multiplying 01:08:30.520 |
and adding all those things together. So if we then simply divide by the square root of the number 01:08:38.360 |
of filters, it turns out that you can convince yourself of this if you wish to, but that's going 01:08:43.000 |
to return it to the original scale. We can now do the softmax across the last dimension, and then 01:08:52.600 |
multiply each of them by V. So using matrix multiply to do them all in one go. 01:08:58.280 |
We didn't mention, but we then do one final projection. Again, just to give it the opportunity 01:09:05.880 |
to map things to some different scale. Shift it also if necessary. Transpose the last two back 01:09:15.880 |
to where they started from, and then reshape it back to where it started from, and then add it. 01:09:20.760 |
Remember, I said it's going to be X plus. Add it back to the original. So this is actually kind of 01:09:25.320 |
self-attention ResNet style, if you like. Diffusers, if I remember correctly, does include the X plus 01:09:25.320 |
in theirs, but some implementations, like, for example, PyTorch implementation doesn't. 01:09:39.560 |
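Putting those pieces together, a self-attention module along the lines just described might look like this. It's a sketch of the idea rather than a copy of the miniai source: the single norm group and the attribute names are assumptions.

```python
import math
import torch
from torch import nn

class SelfAttention2D(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.GroupNorm(1, ni)
        self.q, self.k, self.v = nn.Linear(ni, ni), nn.Linear(ni, ni), nn.Linear(ni, ni)
        self.proj = nn.Linear(ni, ni)

    def forward(self, x):
        n, c, h, w = x.shape
        inp = x
        x = self.norm(x)
        x = x.view(n, c, -1).transpose(1, 2)        # NCHW -> (N, S, D)
        q, k, v = self.q(x), self.k(x), self.v(x)
        s = (q @ k.transpose(1, 2)) / self.scale    # rescale back to the original scale
        x = s.softmax(dim=-1) @ v                   # weighted sum of the values
        x = self.proj(x)                            # one final projection
        x = x.transpose(1, 2).reshape(n, c, h, w)   # back to NCHW
        return x + inp                              # ResNet-style residual

sa = SelfAttention2D(32)
out = sa(torch.randn(64, 32, 16, 16))               # shape is unchanged
```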
Okay, so that's a self-attention module, and all you need to do is tell it how many channels to do 01:09:45.800 |
attention on. And you need to tell it that because that's what we need for our four different 01:09:52.920 |
projections and our group norm and our scale. I guess, strictly speaking, it doesn't have to be stored 01:10:00.120 |
here. You could calculate it here, but anyway, either way is fine. Okay, so if we create a 01:10:06.440 |
self-attention layer, we can then call it on our little randomly generated numbers. 01:10:14.120 |
And it doesn't change the shape because we transpose it back and reshape it back, 01:10:21.160 |
but we can see that's basically worked. We can see it creates some numbers. How do we know if 01:10:25.160 |
they're right? Well, we could create a diffuser's attention block. That will randomly generate 01:10:31.800 |
a QKV projection. Sorry, actually they call them something else. They call them query, key, 01:10:40.040 |
value, proj_attn, and group_norm. We call them q, k, v, proj, and norm. They're the same things. 01:10:46.840 |
And so then we can just zip those tuples together. So that's going to take each pair, 01:10:53.480 |
first pair, second pair, third pair, and copy the weight and the bias from their attention block. 01:11:02.680 |
Sorry, from our attention block to the diffuser's attention block. And then we can check that they 01:11:10.520 |
give the same value, which you can see they do. So this shows us that our attention block is the same as the diffusers attention block, which is nice. 01:11:15.480 |
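The weight-copying check is roughly the following. The AttentionBlock import and its query/key/value/proj_attn/group_norm attribute names are from the diffusers version in use around the time of the lesson, so treat them as assumptions; newer diffusers releases have replaced that class.

```python
import torch
from diffusers.models.attention import AttentionBlock

at = AttentionBlock(32, norm_num_groups=1)     # diffusers' attention block
sa = SelfAttention2D(32)                       # ours, from the sketch above

# Zip our layers with theirs and copy the weights and biases across.
ours   = (sa.q, sa.k, sa.v, sa.proj, sa.norm)
theirs = (at.query, at.key, at.value, at.proj_attn, at.group_norm)
for o, d in zip(ours, theirs):
    d.weight.data.copy_(o.weight.data)
    d.bias.data.copy_(o.bias.data)

# With identical weights, both should now give essentially the same output.
x = torch.randn(64, 32, 16, 16)
print((sa(x) - at(x)).abs().max())
```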
Here's a trick which neither diffusers nor PyTorch 01:11:27.880 |
use for reasons I don't understand, which is that we don't actually need three separate projections 01:11:34.200 |
here. We could create one projection from Ni to Ni times three. That's basically doing three 01:11:40.920 |
projections. So we could call this QKV. And so that gives us 64 by 256 by 96 instead of 01:11:49.720 |
64 by 256 by 32, because it's the three sets. And then we can use chunk, which we saw earlier, 01:11:59.560 |
to split that into three separate variables along the last dimension to get us our QKV. 01:12:07.240 |
And we can then do the same thing, q @ k.transpose, et cetera. So here's another version 01:12:12.040 |
of attention where we just have one projection for QKV, and we chunkify it into separate QK and V. 01:12:22.360 |
And this does the same thing. It's just a bit more concise. And it should be faster as well, 01:12:30.040 |
at least if you're not using some kind of XLA compiler or ONNX or Triton or whatever, 01:12:36.680 |
for normal PyTorch. This should be faster because it's doing less back and forth 01:12:40.840 |
between the CPU and the GPU. All right. So that's basic self-attention. 01:12:56.920 |
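Here's a sketch of that fused version, with the same caveats as before about names and details:

```python
import math
import torch
from torch import nn

class SelfAttentionFused(nn.Module):
    """Same idea as before, but one projection produces Q, K and V in one go."""
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.GroupNorm(1, ni)
        self.qkv = nn.Linear(ni, ni * 3)      # one matmul instead of three
        self.proj = nn.Linear(ni, ni)

    def forward(self, x):
        n, c, h, w = x.shape
        inp = x
        x = self.norm(x).view(n, c, -1).transpose(1, 2)   # (N, S, D)
        q, k, v = self.qkv(x).chunk(3, dim=-1)            # split 3*D back into Q, K, V
        s = (q @ k.transpose(1, 2)) / self.scale
        x = s.softmax(dim=-1) @ v
        x = self.proj(x).transpose(1, 2).reshape(n, c, h, w)
        return x + inp
```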
This is not what's done basically ever, however, because, in fact, the question of which pixels do I care about 01:13:07.720 |
depends on which channels you're referring to. Because the ones which are about, oh, 01:13:16.840 |
what color is its ear, as opposed to how pointy is its ear, might depend more on is this bunny 01:13:26.280 |
in the shade or in the sun. And so maybe you want to look at its body over here to decide what 01:13:33.640 |
color to make it rather than how pointy to make it. And so, yeah, different channels need to bring 01:13:44.040 |
in information from different parts of the picture depending on which channel we're talking about. 01:13:51.080 |
And so the way we do that is with multi-headed attention. And multi-headed attention actually 01:13:57.800 |
turns out to be really simple. And conceptually, it's also really simple. What we do is we say, 01:14:05.800 |
let's come back to when we look at C here and let's split them into four separate 01:14:20.680 |
vectors. One, two, three, four. Let's split them, right? And let's do the whole dot product thing 01:14:36.920 |
on just the first part with the first part. And then do the whole dot product part with the second 01:14:46.840 |
part with the second part and so forth, right? So we're just going to do it separately, 01:14:53.800 |
separate matrix multiplies for different groups of channels. 01:14:58.200 |
And the reason we do that is it then allows, yeah, different parts, different sets of channels to pull 01:15:11.000 |
in different parts of the image. And so these different groups are called heads. And I don't 01:15:25.080 |
know why, but they are. Does that seem reasonable? Anything to add to that? 01:15:34.040 |
It's maybe worth thinking about why, with just a single head, specifically the softmax starts to 01:15:41.160 |
come into play. Because, you know, we said it's like a weighted sum, just able to bring in information 01:15:46.760 |
from different parts and whatever else. But with softmax, what tends to happen is whatever 01:15:52.040 |
weight is highest gets scaled up quite dramatically. And so it's like almost like focused on just that 01:15:58.200 |
one thing. And then, yeah, like, as you said, Jeremy, like different channels might want to refer to 01:16:03.160 |
different things. And, you know, just having this one like single weight that's across all the 01:16:08.840 |
channels means that that signal is going to be like focused on maybe only one or two things as 01:16:14.040 |
opposed to being able to bring in lots of different kinds of information based on the different 01:16:18.040 |
channels. Right. I was going to mention the same thing, actually. That's a good point. So you're 01:16:18.040 |
mentioning the second interesting important point about softmax, you know, point one is that it 01:16:32.520 |
creates something that adds to one. But point two is that because of its e to the z, it tends to 01:16:38.040 |
highlight one thing very strongly. And yes, so if we had single-headed attention, your point, 01:16:45.240 |
guys, I guess, is that you're saying it would end up basically picking nearly all one pixel, 01:16:50.840 |
which would not be very interesting. OK, awesome. Oh, I see why everything's got thick. I've 01:16:59.720 |
accidentally turned it into a marker. Right. OK, so multi-headed attention. I'll come back to 01:17:12.440 |
the details of how it's implemented in terms of, but I'm just going to mention the basic idea. 01:17:18.200 |
This is multi-headed attention. And this is identical to before, except I've just stored 01:17:24.680 |
one more thing, which is how many heads do you want. And then the forward is actually nearly 01:17:33.240 |
all the same. So this is identical, identical, identical. This is new. Identical, identical, 01:17:42.840 |
identical, new, identical, identical. So there's just two new lines of code, which might be 01:17:50.280 |
surprising, but that's all we needed to make this work. And they're also pretty wacky, interesting 01:17:55.560 |
new lines of code to look at. Conceptually, what these two lines of code do is they first, 01:18:02.760 |
they do the projection, right? And then they basically take the number of heads. 01:18:21.320 |
So we're going to do four heads. We've got 32 channels, four heads. So each head is going to 01:18:25.800 |
contain eight channels. And they basically grab, they're going to, we're going to keep it as being 01:18:34.520 |
eight channels, not as 32 channels. And we're going to make each batch four times bigger, right? 01:18:41.800 |
Because the images in a batch don't combine with each other at all. They're totally separate. 01:18:50.120 |
So instead of having one image containing 32 channels, we're going to turn that into four 01:18:59.800 |
images containing eight channels. And that's actually all we need, right? Because remember, 01:19:06.040 |
I told you that each group of channels, each head, we want to have nothing to do with each other. 01:19:13.080 |
So if we literally turn them into different images, then they can't have anything to do 01:19:18.680 |
with each other because batches don't react to each other at all. So these rearrange, 01:19:25.400 |
this rearrange, and I'll explain how this works in a moment, but it's basically saying, 01:19:30.360 |
think of the channel dimension as being of H groups of D and rearrange it. So instead, 01:19:39.640 |
the batch channel is n groups of H and the channels is now just D. So that would be eight 01:19:47.720 |
instead of four by eight. And then we do everything else exactly the same way as usual, 01:19:53.240 |
but now the channels are split into h groups, one per head. And then after 01:20:01.160 |
that, okay, well, we were thinking of the batches as being of size n by H. Let's now think of the 01:20:07.080 |
channels as being of size H by D. That's what these rearranges do. So let me explain how these 01:20:14.440 |
work. In the diffusers code, I've, can't remember if I duplicated it or just inspired by it. 01:20:22.520 |
They've got things called heads to batch and batch to heads, which do exactly these things. 01:20:26.520 |
And so for heads to batch, they say, okay, you've got 64 in the batch by 256 pixels by 32 channels. 01:20:39.880 |
Okay, let's reshape it. So you've got 64 images by 256 pixels by four heads by the rest. 01:20:54.920 |
So that would be 32 over four, which is eight channels. So it's split out into a separate dimension. 01:21:09.880 |
And then if we transpose these two dimensions, it'll then be n by four. So n by heads by SL by 01:21:18.120 |
minus one. And so then we can reshape. So those first two dimensions get combined into one. 01:21:24.280 |
So that's what heads to batch does. And batch to heads does the exact opposite, right? Reshapes 01:21:30.680 |
to bring the batch back to here and then heads by SL by D and then transpose it back again and 01:21:36.680 |
reshape it back again so that the heads gets it. So this is kind of how to do it using just 01:21:44.360 |
traditional PyTorch methods that we've seen before. But I wanted to show you guys this new-ish 01:21:52.440 |
library called einops, inspired as it suggests by Einstein summation notation. But it's absolutely 01:21:58.760 |
not Einstein summation notation. It's something different. And the main thing it has is this 01:22:02.920 |
thing called rearrange. And rearrange is kind of like a nifty rethinking of Einstein summation 01:22:10.600 |
notation as a tensor rearrangement notation. And so we've got a tensor called t we created earlier, 01:22:21.240 |
64 by 256 by 32. And what einops rearrange does is you pass it this specification string that says, 01:22:31.880 |
turn this into this. Okay, this says that I have a rank three tensor, three dimensions, three axes, 01:22:47.640 |
containing the first dimension is of length n, the second dimension is of length s, 01:22:57.560 |
the third dimension is in parentheses is of length h times d, where h is eight. 01:23:04.040 |
Okay, and then I want you to just move things around so that nothing is like broken, you know, 01:23:14.120 |
so everything's shifted correctly into the right spots so that we now have each batch 01:23:23.000 |
is now instead n times eight, n times h. The sequence length is the same, 01:23:30.040 |
and d is now the number of channels. Previously the number of channels was h by d. Now it's d, 01:23:36.360 |
so the number of channels has been reduced by a factor of eight. And you can see it here, 01:23:40.600 |
it's turned t from something of 64 by 256 by 32 into something of size 64 times eight by 256 01:23:51.480 |
by 32 divided by eight. And so this is like really nice because, you know, a, this one line of code 01:24:01.720 |
to me is clearer and easier and I liked writing it better than these lines of code. But what's 01:24:08.040 |
particularly nice is when I had to go the opposite direction, I literally took this, cut it, 01:24:15.800 |
put it here and put the arrow in the middle. Like it's literally backwards, which is really nice, 01:24:22.040 |
right? Because we're just rearranging it in the other order. And so if we rearrange in the other 01:24:27.160 |
order, we take our 512 by 256 by 4 thing that we just created and end up with a 64 by 256 by 32 01:24:34.040 |
thing, which we started with, and we can confirm that the end thing equals, or every element equals 01:24:41.800 |
the first thing. So that shows me that my rearrangement has returned its original correctly. 01:24:48.760 |
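Here's roughly what the two ways of shuffling heads in and out of the batch look like side by side, using the same shapes as above (h is 8 here just to match the demo; einops needs a pip install, and the helper names are only in the spirit of the diffusers ones):

```python
import torch
from einops import rearrange

t = torch.randn(64, 256, 32)      # (batch, sequence, channels)

# Plain PyTorch, in the spirit of diffusers' heads_to_batch / batch_to_heads.
def heads_to_batch(x, heads):
    n, sl, d = x.shape
    x = x.reshape(n, sl, heads, -1)                    # split channels into heads
    return x.transpose(1, 2).reshape(n * heads, sl, -1)

def batch_to_heads(x, heads):
    ns, sl, d = x.shape
    x = x.reshape(ns // heads, heads, sl, d).transpose(1, 2)
    return x.reshape(ns // heads, sl, heads * d)

# einops: the same move as a rearrangement spec, and its exact reverse.
t2 = rearrange(t, 'n s (h d) -> (n h) s d', h=8)       # torch.Size([512, 256, 4])
t3 = rearrange(t2, '(n h) s d -> n s (h d)', h=8)      # torch.Size([64, 256, 32])

assert (heads_to_batch(t, 8) == t2).all()              # both moves agree
assert (t3 == t).all()                                 # and it round-trips exactly
```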
Yeah, so multi-headed attention, I've already shown you. It's the same thing as before, 01:24:54.200 |
but pulling everything out into the batch for each head and then pulling the heads back into 01:25:01.000 |
the channels. So we can do multi-headed attention with 32 channels and four heads 01:25:09.640 |
and check that all looks okay. So PyTorch has that all built in. It's called nn.MultiheadAttention. 01:25:17.720 |
Be very careful. Be more careful than me, in fact, because I keep forgetting that it actually expects 01:25:25.880 |
the batch to be the second dimension. So make sure you write batch_first=True 01:25:32.760 |
to make batch the first dimension and that way it'll be the same as 01:25:38.040 |
diffusers. I mean, it might not be identical, but the same. It should be almost the same 01:25:43.880 |
idea. And to make it self-attention, you've got to pass in three things, right? So the three things 01:25:51.160 |
will all be the same for self-attention. This is the thing that's going to be passed through the 01:25:55.720 |
Q projection, the K projection and the V projection. And you can pass different things to those. 01:26:03.640 |
If you pass different things to those, you'll get something called cross-attention rather than 01:26:10.120 |
self-attention, which I'm not sure we're going to talk about until we do it in NLP. 01:26:15.080 |
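For reference, a minimal use of the built-in version, with the batch_first gotcha handled and the same tensor passed three times to make it self-attention:

```python
import torch
from torch import nn

x = torch.randn(64, 256, 32)        # (batch, sequence, channels)

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

# Self-attention: the same tensor goes through the Q, K and V projections.
out, attn_weights = mha(x, x, x)
print(out.shape)                    # torch.Size([64, 256, 32])

# Passing a different tensor for the keys and values would give cross-attention.
```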
Just on the rearrange thing, I know that if you've been doing PyTorch and you used to, 01:26:24.040 |
like, you really know what transpose and, you know, reshape and whatever do, then it can be a 01:26:29.480 |
little bit weird to see this new notation. But once you get into it, it's really, really nice. 01:26:33.400 |
And if you look at the self-attention multi-headed implementation there, you've got 01:26:37.000 |
dot view and dot transpose and dot reshape. It's quite fun practice. Like, if you're just saying, 01:26:43.000 |
oh, this einops thing looks really useful, like, take an existing implementation like this and say, 01:26:47.880 |
oh, maybe instead of, like, can I do it instead of dot reshape or whatever, can I start replacing 01:26:53.000 |
these individual operations with the equivalent, like, rearrange call? And then checking that the 01:26:58.440 |
output is the same. That's what helped it click for me: oh, okay, 01:27:03.240 |
I can start to express it. If it's just a transpose, then that's a rearrange with the last two dimensions swapped. 01:27:08.840 |
Yeah. I only just started using this. And I've obviously had many years of using 01:27:16.200 |
reshape transpose, et cetera, in Theano, TensorFlow, Keras, PyTorch, APL. And I would say within 01:27:27.640 |
10 minutes, I was like, oh, I like this much better. You know, like, it's 01:27:32.040 |
fine for me at least. It didn't take too long to be convinced. It's not part of PyTorch or anything. 01:27:40.040 |
You've got to pip install it, by the way. And it seems to be becoming super popular now, 01:27:47.240 |
at least in the kind of diffusion research crowd. Everybody seems to be using einops suddenly, 01:27:53.800 |
even though it's been around for a few years. And I actually put in an issue there and asked them 01:28:00.040 |
to add in Einstein summation notation as well, which they've now done. So it's kind of like your 01:28:05.560 |
one place for everything, which is great. And it also works across 01:28:08.600 |
TensorFlow and other libraries as well, which is nice. Okay. So we can now add that to our 01:28:21.800 |
unet. So this is basically a copy of the previous notebook, except what I've now done is I did this 01:28:28.920 |
at the point where it's like, oh, yeah, it turns out that cosine scheduling is better. So I'm back 01:28:34.520 |
to cosine schedule now. This is copied from the cosine schedule notebook. And we're still doing the 01:28:39.400 |
minus 0.5 thing because we love it. And so this time, I actually decided to export stuff into a 01:28:47.960 |
miniai.diffusion. So at this point, I still think things are working pretty well. And so I renamed 01:28:54.200 |
unet_conv to pre_conv, since it's a better name. Timestep embedding has been exported. Upsample's 01:29:00.440 |
been exported. This is like a pre-act linear version, exported. I tried using nn.MultiheadAttention, 01:29:11.160 |
and it didn't work very well for some reason. So I haven't figured out why that is yet. 01:29:17.720 |
So I'm using, yeah, this self-attention, which we just talked about. Multiheaded self-attention. 01:29:25.800 |
You know, just the scale, we have to divide the number of channels by the number of heads 01:29:33.640 |
because the effective number of channels is, you know, divided across the heads. 01:29:44.120 |
And instead of specifying n heads, yeah, you specify attention channels. 01:29:47.480 |
So if you have, like, ni is 32 and attention channels is 8, then you calculate four heads. 01:29:51.640 |
Yeah, that's what diffusers does, I think. It's not what nn.MultiheadAttention does. 01:29:56.920 |
And actually, I think ni divided by (ni divided by attention channels) is actually just equal to 01:30:03.720 |
attention channels. So I could have just put that, probably. Anyway, never mind. Yeah. 01:30:14.040 |
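In other words, something like this (variable names just for illustration):

```python
ni, attn_chans = 32, 8
n_heads = ni // attn_chans      # 32 // 8 = 4 heads
scale_width = ni // n_heads     # 32 // 4 = 8, i.e. just attn_chans again
```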
So okay, so that's all copied in from the previous one. The only thing that's different here is I 01:30:19.880 |
haven't got the dot view minus one thing here. So this is a 1D self-attention, and then 2D 01:30:28.600 |
self-attention just adds the dot view before we call forward and then dot reshape it back again. 01:30:41.960 |
So yeah, so we've got 1D and 2D self-attention. Okay, so now our EmbResBlock has one extra thing 01:30:48.360 |
you can pass in, which is attention channels. And so if you pass in attention channels, we're going 01:30:54.760 |
to create something called self.attention, which is a self-attention 2D layer with the right number 01:31:00.360 |
of filters and the requested number of channels. And so this is all identical to what we've seen 01:31:08.040 |
before, except if we've got attention, then we add it. Oh yeah, and the attention that I did here is 01:31:15.080 |
the non-res-netty version. So we have to do x plus because that's more flexible. You can then choose 01:31:21.880 |
to have it or not have it this way. Okay, so that's an EmbResBlock with attention. 01:31:28.520 |
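The shape of that change is roughly the following sketch. The real block also handles the time embedding and separate in/out channel counts, and the norm choice here is just illustrative; `attn` stands in for a non-residual attention module.

```python
import torch
from torch import nn

class ResBlockWithAttnSketch(nn.Module):
    """Illustrative only: a residual conv path, followed by optional attention."""
    def __init__(self, ni, attn=None):
        super().__init__()
        self.convs = nn.Sequential(
            nn.BatchNorm2d(ni), nn.SiLU(), nn.Conv2d(ni, ni, 3, padding=1),
            nn.BatchNorm2d(ni), nn.SiLU(), nn.Conv2d(ni, ni, 3, padding=1))
        self.attn = attn                      # None means "no attention channels requested"

    def forward(self, x):
        x = x + self.convs(x)                 # the usual residual conv path
        if self.attn is not None:
            x = x + self.attn(x)              # attention here is non-residual, so add it ourselves
        return x

blk = ResBlockWithAttnSketch(32)              # no attention for early, big feature maps
out = blk(torch.randn(8, 32, 64, 64))
```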
And so now our down block, you have to tell it how many attention channels you want, 01:31:35.560 |
because the res blocks need that. The up block, you have to know how many attention channels you 01:31:41.160 |
want, because again the res blocks need that. And so now the unet model, where does the attention 01:31:47.960 |
go? Okay, we have to say how many attention channels you want. And then you say which index 01:31:54.920 |
block do you start adding attention? So then what happens is the attention is done 01:32:05.400 |
here. Each ResNet block has attention. And so as we discussed, you just do the normal res block and then the attention. Now if you 01:32:19.720 |
put attention in at the very start, let's say you've got a 256 by 256 image. 01:32:34.840 |
Then you're going to end up with this matrix here. It's going to be 256 by 256 on one side 01:32:48.360 |
and 256 by 256 on the other side and contain however many, you know, NF channels. That's huge. 01:33:03.640 |
And you have to back prop through it. So you have to store all that to allow back prop to happen. 01:33:09.240 |
It's going to explode your memory. So what happens is basically nobody puts attention 01:33:16.360 |
in the first layers. So that's why I've added an attention start, which is like at which block 01:33:25.640 |
do we start adding attention and it's not zero for the reason we just discussed. Another way you 01:33:34.760 |
could do this is to say like at what grid size should you start adding attention? And so generally 01:33:40.440 |
speaking, people say when you get to 16 by 16, that's a good time to start adding attention. 01:33:47.400 |
Although stable diffusion adds it at 32 by 32 because remember they're using latents, 01:33:54.680 |
which we'll see very shortly I guess in the next lesson. So it starts at 64 by 64 and then they 01:34:00.760 |
add attention at 32 by 32. So we're again, we're replicating stable diffusion here. Stable diffusion 01:34:05.800 |
uses attention start at index one. So we, you know, when we go self.downs.append, the down block 01:34:13.160 |
has zero attention channels if we're not up to that block yet. And ditto on the up blocks. 01:34:26.040 |
Now that I think about it, the mid block should have attention as well. So that's missing. 01:34:39.480 |
Yeah, so the forward actually doesn't change at all for attention. It's only the init. 01:34:50.040 |
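Conceptually, the constructor logic being described is along these lines (names and channel counts are made up for illustration):

```python
# Only blocks at index >= attn_start get attention channels; earlier ones get 0.
nfs = (32, 64, 128, 256)            # channels per resolution (illustrative)
attn_chans, attn_start = 8, 1

for i, nf in enumerate(nfs):
    chans = attn_chans if i >= attn_start else 0
    print(f"down block {i}: nf={nf}, attn_chans={chans}")
    # in the real model: self.downs.append(DownBlock(..., nf, attn_chans=chans))
```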
Yeah, so we can train that. And so previously, yeah, we got 01:34:55.080 |
without attention, we got to 137. And with attention, 01:35:04.680 |
oh, we can't compare directly because we've changed from Karras to cosine. 01:35:13.560 |
We can compare the sampling though. So we're getting, what are we getting? 4, 5, 5, 5. 01:35:21.320 |
It's very hard to tell if it's any better or not because, 01:35:27.160 |
well, again, you know, our cosine schedule is better. But yeah, when I've done kind of direct 01:35:34.520 |
like with like, I haven't managed to find any obvious improvements from adding attention. 01:35:41.560 |
But I mean, it's doing fine, you know, 4 is great. Yeah. All right. So then finally, 01:35:50.440 |
did you guys want to add anything before we go into a conditional model? 01:35:53.160 |
I was just going to make a note that, like, I guess, just to clarify, 01:35:58.120 |
with the attention, part of the motivation was certainly to do the sort of spatial mixing and 01:36:04.120 |
kind of like, yeah, to get from different parts of the image and mix it. But then the problem is, 01:36:09.800 |
if it's too early, where you do have one of, you know, the more individual pixels, 01:36:14.360 |
then the memory is very high. So it seems like you have to get that balance of where you don't, 01:36:21.400 |
you kind of want it to be early. So you can do some of that mixing, but you don't want to be too 01:36:25.080 |
early, where then the memory usage is, is too high. So it seems like there is certainly kind 01:36:30.680 |
of the balance of trying to find maybe that right place where to add attention into your network. 01:36:36.200 |
So I just thought I was just thinking about that. And maybe that's a point worth noting. 01:36:41.560 |
There is a trick, which is like what they do in, for example, vision transformers, 01:36:46.280 |
or DiT, diffusion with transformers, which is that if you take like an eight by 01:36:54.920 |
eight patch of the image, and you flatten that all out, or you run that through some 01:36:59.960 |
like convolutional thing to turn it into a one by one by some larger number of channels, 01:37:05.000 |
but you can reduce the spatial dimension by increasing the number of channels. And that 01:37:10.760 |
gets you down to like a manageable size where you can then start doing attention as well. 01:37:14.360 |
So that's another trick is like patching, where you take a patch of the image and you focus 01:37:18.680 |
on that as some number, like some embedding dimension or whatever you like to think of it, 01:37:23.560 |
that as a one by one rather than an eight by eight or a 16 by 16. And so that's how, 01:37:29.000 |
like you'll see, you know, 32 by 32 patch models, like some of the smaller clip models or 14 by 14 01:37:35.800 |
patch for some of the larger like DIT classification models and things like that. 01:37:40.360 |
So that's another Yeah, I guess that's the yeah, that's mainly used when you have like a full 01:37:45.800 |
transformer network, I guess. And then this is one where we have that sort of incorporating the 01:37:51.240 |
attention into convolutional network. So there's certainly, I guess, yeah, for different sorts of 01:37:56.840 |
networks, different tricks. But yeah. Yeah. And I haven't decided yet if we're going to look at 01:38:06.200 |
ViT or not. Maybe we should, based on what you're describing. I was just going to mention, 01:38:11.080 |
though, that since you mentioned transformers, we've actually now got everything we need to 01:38:18.600 |
create a transformer. Here's a transformer block with embeddings. A transformer block with embeddings 01:38:25.160 |
is exactly the same embeddings that we've seen before. And then we add attention, as we've seen 01:38:31.000 |
before, there's a scale and shift. And then we pass it through an MLP, which is just a linear layer, 01:38:40.680 |
an activation, a normalize, and a linear layer. For whatever reason, this is, you know, 01:38:46.760 |
GELU, which is just another activation function, is what people always use in transformers. 01:38:54.120 |
For reasons I suspect don't quite make sense in vision, everybody uses layer norm. And again, 01:38:58.440 |
I was just trying to replicate an existing paper. But this is just a standard MLP. So if you do, 01:39:03.160 |
so in fact, if we get rid of the embeddings, just to show you a true pure transformer. 01:39:17.960 |
Okay, here's a pure transformer block. So it's just normalize, attention, add, normalize, 01:39:25.240 |
multi-layer perceptron, add. That's all a transformer block is. And then what's a 01:39:30.200 |
transformer network? A transformer network is a sequential of transformers. And so in this 01:39:35.960 |
diffusion model, I replaced my mid block with a list of sequential transformer blocks. 01:39:46.280 |
So that is a transformer network. And to prove it, this is another version in which I 01:39:56.280 |
replaced that entire thing with the PyTorch transformers encoder. This is called encoder. 01:40:03.960 |
This is just taken from PyTorch. And so that's the encoder. And I just replaced it with that. 01:40:13.000 |
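A plain transformer block of the kind described might be sketched like this, using the built-in multi-headed attention. The pre-norm layout and the MLP shape follow the description above (linear, activation, norm, linear); the 4x expansion and head count are assumptions, not the notebook's exact settings.

```python
import torch
from torch import nn

class TransformerBlock(nn.Module):
    def __init__(self, ni, n_heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(ni)
        self.attn = nn.MultiheadAttention(ni, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(ni)
        self.mlp = nn.Sequential(nn.Linear(ni, ni * 4), nn.GELU(),
                                 nn.LayerNorm(ni * 4), nn.Linear(ni * 4, ni))

    def forward(self, x):                     # x: (batch, sequence, channels)
        nx = self.norm1(x)
        x = x + self.attn(nx, nx, nx)[0]      # normalise, attend (mix the sequence), add
        x = x + self.mlp(self.norm2(x))       # normalise, MLP (mix the channels), add
        return x

# A "transformer network" is then just a stack of these blocks.
net = nn.Sequential(*[TransformerBlock(32) for _ in range(4)])
out = net(torch.randn(64, 256, 32))           # torch.Size([64, 256, 32])
```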
So yeah, we've now built transformers. Now, okay, why aren't we using them right now? And why did 01:40:20.600 |
I just say, I'm not even sure if we're going to do VIT, which is vision transformers. The reason is 01:40:25.240 |
that transformers, you know, they're doing something very interesting, right? Which is, 01:40:34.680 |
remember, we're just doing 1D versions here, right? So transformers are taking something 01:40:44.760 |
where we've got a sequence, right? Which in our case is pixels height by width, but we'll just 01:40:50.120 |
call it the sequence. And everything in that sequence has a bunch of channels, right, for 01:40:56.520 |
dimensions, right? I'm not going to draw them all, but you get the idea. And so for each element of 01:41:05.720 |
that sequence, which in our case, it's just some particular pixel, right? And these are just the 01:41:11.480 |
filters, channels, activations, whatever, activations, I guess. 01:41:19.080 |
What we're doing is the, we first do attention, which, you know, remember there's a projection for 01:41:26.200 |
each. So like it's mixing the channels a little bit, but just putting that aside, the main thing 01:41:30.680 |
it's doing is each row is getting mixed together, you know, into a weighted average. 01:41:46.760 |
And then after we do that, we put the whole thing through a multi-layer perceptron. And what the 01:41:53.080 |
multi-layer perceptron does is it entirely looks at each pixel on its own. So let's say this one, 01:42:05.480 |
right, and puts that through linear, activation, norm, linear, which we call an MLP. 01:42:16.840 |
And so a transformer network is a bunch of transformer layers. So it's basically going 01:42:27.800 |
attention, MLP, attention, MLP, attention, et cetera, et cetera, MLP. That's all it's doing. 01:42:37.160 |
And so in other words, it's mixing together the pixels or sequences, and then it's mixing 01:42:48.920 |
together the channels. Then it's mixing together the sequences and then mixing together the channels 01:42:53.240 |
and it's repeating this over and over. Because of the projections being done in the 01:43:00.520 |
attention, it's not just mixing the pixels, but it's kind of, it's largely mixing the pixels. 01:43:07.400 |
And so this combination is very, very, very flexible. And it's flexible enough 01:43:19.640 |
that it provably can actually approximate any convolution that you can think of 01:43:26.040 |
given enough layers and enough time and learning the right parameters. 01:43:34.040 |
The problem is that for this to approximate a convolution requires a lot of data and a lot of 01:43:45.560 |
layers and a lot of parameters and a lot of compute. So if you try to use this, so this is a transformer 01:43:54.760 |
network, transformer architecture. If you pass images into this, so pass an image in 01:44:04.280 |
and try to predict, say from ImageNet, the class of the image. So use SGD to try and 01:44:13.640 |
find weights for these attention projections and MLPs. If you do that on ImageNet, you will end up 01:44:21.160 |
with something that does indeed predict the class of each image, but it does it poorly. 01:44:25.560 |
Now it doesn't do it poorly because it's not capable of approximating a convolution. It 01:44:31.240 |
does it poorly because ImageNet, the entire ImageNet as in ImageNet 1k is not big enough 01:44:38.200 |
to for a transformer to learn how to do this. However, if you pass it a much bigger data set, 01:44:46.360 |
many times larger than ImageNet 1k, then it will learn to approximate this very well. And in fact, 01:44:54.280 |
it'll figure out a way of doing something like convolutions that are actually better than 01:44:58.920 |
convolutions. And so if you then take that, so that's going to be called a vision transformer 01:45:06.120 |
or VIT that's been pre-trained on a data set much bigger than ImageNet, and then you fine tune it 01:45:12.600 |
on ImageNet, you will end up with something that is actually better than ResNet. And the reason 01:45:21.880 |
it's better than ResNet is because these combinations, right, which together when combined 01:45:32.920 |
can approximate a convolution, these transformers, you know, convolutions are our best guess as to 01:45:41.080 |
like a good way to kind of represent the calculations we should do on images. But there's 01:45:47.480 |
actually much more sophisticated things you could do, you know, if you're a computer and you could 01:45:51.960 |
figure these things out better than a human can. And so a VIT actually figures out things that are 01:45:56.920 |
even better than convolutions. And so when you fine tune ImageNet using a very, you know, a VIT 01:46:05.080 |
that's been pre-trained on lots of data, then that's why it ends up being better than a ResNet. So 01:46:11.560 |
that's why, you know, the things I'm showing you are not the things that contain transformers and 01:46:22.200 |
diffusion because to make that work would require pre-training on a really, really large data set 01:46:29.880 |
for a really, really long amount of time. So anyway, we might not come back to transformers 01:46:38.040 |
for a while, but when we do them in NLP, and maybe in vision, we'll cover them briefly, 01:46:47.240 |
you know, they're very interesting to use as pre-trained models. The main thing to know about 01:46:52.440 |
them is, yeah, a VIT, you know, which is a really successful and when pre-trained on lots of data, 01:46:59.080 |
which they all are nowadays, is a very successful architecture. But like literally the VIT paper 01:47:03.960 |
says, oh, we wondered what would happen if we take a totally plain 1D transformer, you know, 01:47:11.000 |
and convert it and make it work on images with as few changes as possible. So everything we've 01:47:17.800 |
learned about attention today and MLPs applies directly because they haven't changed anything. 01:47:25.480 |
And so one of the things you might realize that means is that you can't use a VIT that was trained 01:47:34.040 |
on 224 by 224 pixel images on 128 by 128 pixel images because, you know, all of these 01:47:41.960 |
self-attention things are the wrong size, you know, and specifically the problem is actually the, 01:47:53.560 |
actually it's not really the attention, let me take that back. All of the 01:48:02.120 |
position embeddings are the wrong size. And so actually that's something I, sorry, 01:48:06.600 |
I forgot to mention, is that in transformers the first thing you do is you always take 01:48:15.640 |
your, you know, these pixels and you add to them a positional embedding. 01:48:24.840 |
And that's done, I mean, it can be done lots of different ways, but the most popular way is 01:48:30.440 |
identical to what we did for the time step embedding is the sinusoidal embedding. And so 01:48:36.040 |
that's specific, you know, to how many pixels there are in your image. So yeah, that's an example, 01:48:47.320 |
it's one of the things that makes VITs a little tricky. Anyway, hopefully, yeah, you get the idea 01:48:52.520 |
that we've got all the pieces that we need. Okay, so with that discussion, 01:49:08.280 |
I think that's officially taken us over time. So maybe we should do the conditional next time. 01:49:20.040 |
Do you know what actually it's tiny? Let's just quickly do it now. You guys got time? 01:49:24.840 |
Yeah. Okay. So let's just, yeah, let's finish by doing a conditional model. So for a conditional 01:49:31.640 |
model, we're going to basically say I want something where I can say draw me the number, 01:49:38.120 |
sorry, draw me a shirt, or draw me some pants, or draw me some sandals. So we're going to pick one 01:49:44.040 |
of the 10 fashion MNIST classes and create an image of a particular class. To do that, 01:49:53.960 |
we need to know what class each thing is. Now, we already know what class each thing is 01:50:05.080 |
because it's the Y label, which way back in the beginning of time, we set, okay, 01:50:14.360 |
it's just called the label. So that tells you what category it is. 01:50:19.240 |
So we're going to change our collation function. So we call noisify as per usual. That gives us 01:50:29.480 |
our noised image, our time step, and our noise. But we're also going to then add to that tuple 01:50:40.600 |
what kind of fashion item is this. And so the first tuple will be noised image, 01:50:48.440 |
noise, and label, and then the dependent variable as per usual is the noise. 01:50:56.280 |
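A sketch of that collation change, assuming the `noisify` from the earlier notebooks returns `(noised_image, timestep), noise`, and that each dataset item is already transformed to tensors under 'image' and 'label' keys:

```python
from torch.utils.data import default_collate

def collate_ddpm_cond(b):
    b = default_collate(b)
    (xt, t), eps = noisify(b['image'])     # noised image, timestep, and the noise target
    return (xt, t, b['label']), eps        # the inputs now also carry the class label
```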
And so what's going to happen now when we call our unet, which is now a conditioned unet model, 01:51:00.920 |
is the input is now going to contain not just the activations and the time step, 01:51:08.520 |
but it's also going to contain the label. Okay, that label will be a number between zero and nine. 01:51:14.120 |
So how do we convert the number between zero and nine into a vector which represents that number? 01:51:21.080 |
Well, we know exactly how to do that with nn.Embedding. Okay, so we did that lots in part one. 01:51:27.320 |
So let's make it exactly, you know, the same size as our time embedding. So n number of 01:51:42.920 |
activations in the embedding. It's going to be the same as our time step embedding. 01:51:50.920 |
And so that's convenient. So now in the forward we do our time step embedding as usual. 01:51:55.240 |
We'll pass the labels into our conditioned embedding. 01:52:00.520 |
The time embedding we will put through the embedding MLP as before, and then we're just 01:52:06.520 |
going to add them together. And that's it, right? So this now represents a combination of the time 01:52:12.680 |
and the fashion item. And then everything else is identical in both parts. So all we've 01:52:21.400 |
added is this one thing. And then we just literally sum it up. So we've now got a joint 01:52:28.360 |
embedding representing two things. And then, yeah, and then we train it. 01:52:33.320 |
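The heart of it is just those few lines. Here's a minimal runnable sketch of the embedding sum; the sinusoidal stand-in below is simplified, not the notebook's exact timestep embedding.

```python
import torch
from torch import nn

n_emb, n_classes = 128, 10

emb_mlp = nn.Sequential(nn.Linear(n_emb, n_emb), nn.SiLU(), nn.Linear(n_emb, n_emb))
cond_emb = nn.Embedding(n_classes, n_emb)       # one learned vector per fashion class

def timestep_embedding(t, dim):
    # Simplified sinusoidal embedding, standing in for the one built earlier.
    freqs = torch.exp(-torch.linspace(0, 8, dim // 2))
    args = t[:, None] * freqs[None] * 1000
    return torch.cat([args.sin(), args.cos()], dim=-1)

t = torch.rand(64)                               # timesteps for a batch of 64
c = torch.randint(0, n_classes, (64,))           # class labels, 0..9

emb = emb_mlp(timestep_embedding(t, n_emb)) + cond_emb(c)   # joint time + class embedding
print(emb.shape)                                 # torch.Size([64, 128])
```

At sampling time the only extra ingredient is the label vector, e.g. something like `c = torch.full((n_samples,), class_id)` passed to the model alongside the noised image and timestep.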
And, you know, interestingly, it looks like the loss, well, it ends up a bit the same, 01:52:41.640 |
but you don't often see 0.031. It is a bit easier for it to do a conditional embedding 01:52:48.360 |
model because you're telling it what it is. It just makes it a bit easier. So then to do 01:52:53.160 |
conditional sampling, you have to pass in what type of thing do you want from these labels. 01:53:06.600 |
And so then we create a vector just containing that number repeated once for each item in 01:53:19.080 |
the batch. And we pass it to our model. So our model has now learned how to denoise something 01:53:27.160 |
of type C. And so now if we say like, oh, trust me, this noise contains is a noised image of type C, 01:53:34.600 |
it should hopefully denoise it into something of type C. That's all there is to it. There's no 01:53:41.960 |
magic there. So, yeah, that's all we have to do to change the sampling. So we didn't have to change 01:53:48.760 |
DDIM step at all, right? Literally all we did was we added this one line of code and we added it 01:53:55.960 |
there. So now we can say, okay, let's say class ID zero, which is t-shirt top. So we'll pass that to 01:54:04.280 |
sample. And there we go. Well, everything looks like t-shirts and tops. Yeah, okay. I'm glad we 01:54:16.760 |
didn't leave that till next time because we can now say we have successfully replicated 01:54:25.320 |
everything in stable diffusion, except for being able to create whole sentences, 01:54:32.520 |
which is what we do with CLIP. Getting really close. Yes. Well, except CLIP requires 01:54:39.720 |
all of NLP. So I guess we'll, we might do that next or depending on how research goes. 01:54:53.160 |
All right. We still need a latent diffusion part. Oh, good point. Latents. Okay. We'll definitely 01:55:00.600 |
do that next time. So let's see. Yeah. So we'll do a VAE and latent diffusion, 01:55:08.040 |
which isn't enough for one lesson. So maybe some of the research I'm doing will end up in the next 01:55:14.600 |
lesson as well. Yes. Okay. Thanks for the reminder. Although we've already kind of done auto-encoders, 01:55:22.120 |
so VAEs are going to be pretty, pretty easy. Well, thank you, Tanishk and Johnno. Fantastic 01:55:22.120 |
comments, as always. Glad your internet/power reappeared, Johnno. Back up. 01:55:42.600 |
All right. Thanks, gang. Cool. Thanks, everybody. That was great.