
Lesson 24: Deep Learning Foundations to Stable Diffusion



00:00:00.000 | Hi, we are here for lesson 24. And once again, it's becoming a bit of a tradition now. We're
00:00:06.320 | joined by Jono and Tanishk, which is always a pleasure. Hi, Jono. Hi, Tanishk.
00:00:10.600 | Hello.
00:00:11.600 | Another great lesson.
00:00:12.600 | Yeah, are you guys looking forward to finally actually completing stable diffusion, at least
00:00:19.960 | the unconditional stable diffusion? Well, I should say no, even conditional. So conditional
00:00:24.040 | stable diffusion, except for the CLIP bit, from scratch. We should be able to finish
00:00:30.000 | today. Time permitting.
00:00:31.000 | Oh, that's exciting.
00:00:32.000 | That is exciting. All right. Let's do it. Jump in any time. We've got things to talk
00:00:41.520 | about. So we're going to start with a hopefully named notebook, 26 diffusion unet. And what
00:00:52.720 | we're going to do in 26 diffusion unet is to do unconditional diffusion from scratch.
00:01:04.480 | And there's not really too many new pieces, if I remember correctly. So all the stuff
00:01:09.880 | at the start we've already seen. And so when I wrote this, it was before I had noticed
00:01:17.920 | that the Karras approach was doing less well than the regular cosine schedule approach.
00:01:22.480 | So I'm still using the Karras noisify, but this is all the same as from the Karras notebook,
00:01:28.320 | which was 23.
00:01:31.440 | Okay, so we can now create a unet that is based on what diffusers has, which is in turn
00:01:46.240 | based on lots of other prior art. I mean, the code's not at all based on it, but the
00:01:52.280 | basic structure is going to be the same as what you'll get in diffusers. The convolution
00:02:00.880 | we're going to use is the same as the final kind of convolution we used for tiny image
00:02:06.080 | net, which is what's called the preactivation convolution. So the convolution itself happens
00:02:12.460 | at the end and the normalization and activation happen first. So this is a preact convolution.
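Roughly, a preact conv helper looks like this (a sketch; the notebook's version may differ in which norm and activation it defaults to):

```python
import torch.nn as nn

def pre_conv(ni, nf, ks=3, stride=1, act=nn.SiLU, norm=nn.BatchNorm2d):
    # Pre-activation ordering: normalization, then activation, then the conv itself.
    layers = []
    if norm: layers.append(norm(ni))
    if act:  layers.append(act())
    layers.append(nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2))
    return nn.Sequential(*layers)
```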
00:02:23.720 | So then I've got a unet res block. So I kind of wrote this before I actually did
00:02:31.000 | the preact version of tiny image net. So I suspect this is actually the same, quite possibly
00:02:36.880 | exactly the same as the tiny image net one. So maybe this is nothing specific about this
00:02:41.200 | for the unet; this is just really a preact conv and a preact res block. So we've got the
00:02:49.320 | two convs as per usual and the identity conv. Now there is one difference though to what
00:02:56.240 | we've seen before for res net blocks, which is that this res net block has no option to
00:03:01.360 | do downsampling, no option to do a stride. This is always stride 1, which is our default.
00:03:12.760 | So the reason for that is that when we get to the thing that strings a bunch of them
00:03:18.860 | together, which will be called down block, this is where you have the option to add downsampling.
00:03:26.040 | But if you do add downsampling, we're going to add a stride-2 convolution after the res
00:03:31.040 | blocks. And that's because this is how diffusers and stable diffusion do it. I haven't studied
00:03:40.960 | this closely. Tanishk or Jono, do either of you know where this idea came
00:03:46.440 | from or why. I'd be curious, you know, the difference is that normally we would have
00:03:54.480 | average pooling here in this connection. But yeah, this different approach is what we're
00:04:03.060 | using.
00:04:04.060 | A lot of the history of the diffusers unconditional unet is to be compatible with the DDPM weights
00:04:12.520 | that were released and some follow-on work from that. And I know like then improved DDPM
00:04:17.560 | and these others like they all kind of built on that same sort of unet structure, even
00:04:21.120 | though it's slightly unconventional if you're coming from like a normal computer vision
00:04:25.720 | background.
00:04:26.720 | And do you recall where the DDPM architecture came from? Because like some of the ideas
00:04:31.600 | came from some of the earlier unets, but I don't know about DDPM specifically.
00:04:38.080 | Yeah, they had something called an efficient unet that was inspired by some prior work
00:04:42.760 | that I can't remember the lineage. Anyway, yeah, I just think the diffusers one has since
00:04:50.320 | become you know, like you can add in parameters to control some of this stuff. But yeah, it's
00:04:56.000 | we shouldn't assume that this is the optimal approach, I suppose. But yeah, I will dig
00:05:01.280 | into the history and try and find out how much like, what ablation studies have been
00:05:05.480 | done. So for those of you who haven't heard of ablation studies, that's where you'd like
00:05:08.480 | try, you know, a bunch of different ways of doing things and score which one works better
00:05:13.160 | and which one works less well and kind of create a table of all of those options. And
00:05:19.000 | so where you can't find ablation studies for something you're interested in, often that
00:05:24.160 | means that, you know, maybe not many other options were tried because researchers don't
00:05:28.720 | have time to try everything.
00:05:31.480 | Okay, now the unet. If we go back to the unet that we used for super resolution, we just
00:05:41.640 | go back to our most basic version. What we did as we went down through the layers in
00:05:51.560 | the down sampling section, we stored the activations at each point into a list called layers.
00:06:06.600 | And then as we went through the up sampling, we added those down sampling layers back into
00:06:13.000 | the up sampling activations. So that's the kind of basic structure of a unet. You don't have
00:06:20.440 | to add, you can also concatenate, and actually concatenating is, I think, more
00:06:27.160 | common nowadays, and I think the original unet might have been concatenating. Although
00:06:31.840 | for super resolution, just adding seems pretty sensible. So we're going to concatenate. But
00:06:37.920 | what we're going to do is we're going to try to, we're going to kind of exercise our Python
00:06:42.400 | muscles a little bit to try to see interesting ways to make some of this a little easier
00:06:47.520 | to turn different down sampling backbones into unets. And you can also use that as an opportunity
00:06:55.880 | to learn a bit more Python. So what we're going to do is we're going to create something called
00:07:03.460 | a saved res block and a saved convolution. And so our down blocks, so these are our res
00:07:11.880 | blocks containing a certain number of res block layers, followed by this optional stride-2
00:07:18.800 | conv. We're going to use saved res blocks and saved convs. And what these are going to
00:07:23.080 | do, it's going to be the same as a normal convolution and the same as a normal res block,
00:07:27.600 | the same as a normal unet res block. But they're going to remember the activations. And the
00:07:35.240 | reason for that is that later on in the unet, we're going to go through and grab those saved
00:07:43.040 | activations all at once into a big list. So then yeah, we basically don't have to kind
00:07:51.240 | of think about it. And so to do that, we create a class called a save module. And all saved
00:07:58.400 | module does is it calls forward to grab the res block or conv results and stores that
00:08:11.400 | before returning it. Now that's weird because hopefully you know by now that super calls
00:08:19.120 | the thing in the parent class, but save module doesn't have a parent class. So this is what's
00:08:25.480 | called a mixin. And it's using something called multiple inheritance. And mixins are as it
00:08:42.800 | describes here. It's a design pattern, which is to say it's not particularly a part of
00:08:50.800 | Python per se. It's a design pattern that uses multiple inheritance. Now what multiple
00:08:55.000 | inheritance is where you can say, oh, this class called saved res block inherits from
00:09:03.000 | two things, save module and unet res block. And what that does is it means that all of
00:09:11.080 | the methods in both of these will end up in here. Now that would be simple enough, except
00:09:18.040 | we've got a bit of a confusion here, which is that the unet res block contains forward and
00:09:23.440 | saved module contains forward. So it's all very well just combining the methods from
00:09:27.960 | both of them. But what if they have the same method? And the answer is that the one you
00:09:35.440 | list first can call, when it calls forward, it's actually calling forward in the later
00:09:44.040 | one. And that's why it's a mixin. It's mixing this functionality into this functionality.
00:09:51.560 | So it's a unet res block where we've customized forward. So it calls the existing forward
00:09:57.520 | and also saves it. So you see mixins quite a lot in the Python standard library. For
00:10:04.560 | example, the basic HTTP stuff, some of the basic thread stuff with networking uses multiple
00:10:16.280 | inheritance using this mixin pattern. So with this approach, then the actual implementation
00:10:22.120 | of saved res block is nothing at all. So pass means don't do anything. So this is just literally
00:10:29.080 | just a class which has no implementation of its own other than just to be a mixin of these
00:10:36.520 | two classes. So a saved convolution is an nn.Conv2d with the save module mixed in.
00:10:47.800 | So what's going to happen now is that we can call a saved res block just like a unet res
00:10:54.280 | block and a saved conv just like an nn.Conv2d. But that object is going to end up with the
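A minimal sketch of the pattern (using nn.Conv2d here; the notebook's unet res block would be mixed in the same way):

```python
import torch.nn as nn

class SaveModule:
    # Mixin: call the parent class's forward and stash the result in .saved.
    def forward(self, x, *args, **kwargs):
        self.saved = super().forward(x, *args, **kwargs)
        return self.saved

class SavedConv(SaveModule, nn.Conv2d):
    pass  # behaves exactly like nn.Conv2d, but remembers its last activations

# A saved res block would be declared the same way, mixing SaveModule
# into the notebook's unet res block class.
```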
00:11:01.400 | activations inside the dot saved attribute. So now a downsampling block is just a sequential
00:11:10.920 | of saved res blocks. As per usual, the very first one is going to have the number of input
00:11:22.640 | channels to start with and it will always have nf, the number of filters, as its
00:11:27.640 | output, and then after that the inputs will also be equal to nf because the first one's
00:11:32.560 | changed the number of channels. And we'll do that for however many layers we have. And
00:11:38.680 | then at the end of that process, as we discussed, we will add to that sequential a saved conv
00:11:45.600 | with stride 2 to do the downsampling if requested. So we're going to end up with a single nn.Sequential
00:11:51.360 | for a down block. And then an up block is going to look very similar, but instead of
00:12:01.800 | using an nn.Conv2d with stride 2, upsampling will be done with a sequence of an upsampling layer.
00:12:13.400 | And so literally all that does is it just duplicates every pixel four times into a little
00:12:17.360 | two by two grid. That's what an upsampling layer does, nothing clever. And then we follow
00:12:22.360 | that by a stride-1 convolution. So that allows it to, you know, adjust some of those pixels
00:12:31.000 | if necessary with a simple three by three conv. So that's pretty similar to a stride-2 downsampling.
00:12:40.040 | This is kind of the rough equivalent for upsampling. There are other ways of doing upsampling.
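Sketched in code, that upsample-then-conv sequence might look like this (names assumed):

```python
import torch.nn as nn

def upsample(nf):
    # Duplicate each pixel into a 2x2 grid, then tidy it up with a stride-1 conv.
    return nn.Sequential(nn.UpsamplingNearest2d(scale_factor=2),
                         nn.Conv2d(nf, nf, kernel_size=3, padding=1))
```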
00:12:47.120 | This is just the one that stable diffusion does. So an up block looks a lot like a down
00:12:54.540 | block, except that now, so as before, we're going to create a bunch of unet res blocks.
00:13:02.360 | These are not saved res blocks, of course. We want to use the saved results in the upsampling
00:13:06.600 | path of the unet. So we just use normal res blocks. But what we're going to do now is
00:13:13.960 | as we go through each res net, we're going to call it not just on our activations, but
00:13:20.800 | we're going to concatenate that with whatever was stored during the downsampling path. So
00:13:29.640 | this is going to be a list of all of the things stored in the downsampling path. It'll be
00:13:34.880 | passed to the up block. And so .pop will grab the last one off that list and concatenate
00:13:41.120 | it with the activations and pass that to the res net.
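The pop-and-concatenate pattern he is describing looks roughly like this inside the up block's forward (a sketch; the optional upsampling layer would run after the loop):

```python
import torch

def forward(self, x, ups):
    # `ups` holds the activations saved on the way down; pop the most recent
    # one and concatenate it on the channel axis before each res block.
    for resnet in self.resnets:
        x = resnet(torch.cat([x, ups.pop()], dim=1))
    return x
```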
00:13:45.480 | So we need to know how many filters there were, how many activations there were in the
00:13:52.760 | downsampling path. So that's stored here. This is the previous number of filters in the downsampling
00:13:57.920 | path. And so the res block wants to add those in, in addition to the normal number. So that's
00:14:12.120 | what's going to happen there. And so yeah, do that for each layer as before. And then
00:14:19.880 | at the end, add an upsampling layer if it's been requested. So it's a boolean. OK, so
00:14:30.760 | that's the upsampling block. Does that all make sense so far?
00:14:34.520 | Yeah, it looks good.
00:14:38.600 | OK. OK, so the unet now is going to look a lot like our previous unet. We're going to
00:14:49.320 | start out as we tend to with a convolution to now allow us to create a few more channels.
00:14:57.080 | And so we're passing to our unet: that's just how many channels are in your image and how
00:15:02.320 | many channels are in your output image. So for normal full color images, that'll be 3 and 3.
00:15:08.780 | How many filters are there for each of those res net blocks, up blocks and down blocks
00:15:13.800 | you've got. And in the downsampling, how many layers are there in each block? So we go from
00:15:21.240 | the conv will go from the in channels, so it'd be 3, to nf[0]. These are the number of filters
00:15:26.800 | in the stable diffusion model. They're pretty big, as you see by default. And so that's
00:15:37.160 | the number of channels we would create, which is like very redundant in that this is a 3
00:15:45.480 | by 3 conv. So it only contains 3 by 3 by 3 channels equals 27 inputs and 224 outputs.
00:15:53.360 | So it's not doing useful computation in a sense. It's just giving it more space
00:16:00.220 | to work with down the line, which I don't think makes sense, but I haven't played
00:16:06.360 | with it enough to be sure. Normally we would do like, you know, like a few res blocks or
00:16:15.200 | something at this level to more gradually increase it because this feels like a lot
00:16:18.640 | of wasted effort. But yeah, I haven't studied that closely enough to be sure.
00:16:24.360 | So Jeremy, just to note, this is the default, I think the default settings for the unconditional
00:16:30.000 | unet in diffusers. But the stable diffusion unet actually has even more channels. It has
00:16:34.400 | 320, 640, and then 1,280, 1,280. Cool. Thanks for clarifying. And it's, yeah,
00:16:41.800 | the unconditional one, which is what we're doing right now. That's a great point. Okay.
00:16:47.720 | So then we, yeah, we go through all of our number of filters and actually the first res
00:16:54.520 | block contains 224 to 224. So that's why it's kind of keeping track of this stuff. And then
00:17:01.480 | the second res block is 224 to 448 and then 448 to 672 and then 672 to 896. So that's
00:17:07.760 | why we're just going to have to keep track of these things. So yeah, we add, so we have
00:17:12.800 | a sequential for our down blocks and we just add a down block. The very last one doesn't
00:17:18.440 | have downsampling, which makes sense, right? Because the very last one, there's nothing
00:17:23.560 | after it, so no point downsampling. Other than that, they all have downsampling.
00:17:29.240 | And then we have one more res block in the middle, which, is that the same as what we
00:17:37.000 | did? Okay. So we didn't have a middle res block in our original unet here. What about
00:17:47.280 | this one? Do we have any mid blocks? No, so we haven't done that before. Okay. But I mean, so it's
00:17:51.240 | just another res block that you do after the downsampling. And then we go through the reversed
00:17:57.920 | list of filters and go through those and adding up blocks. And then one convolution at the
00:18:02.960 | end to turn it from 224 channels to three channels. Okay. And so the forward then is
00:18:14.560 | going to store in saved all the layers, just like we did back with this unet. But we don't
00:18:27.240 | really have to do it explicitly now. We just call the sequential model. And thanks to our
00:18:32.840 | automatic saving, each of those now will, we can just go through each of those and grab
00:18:38.400 | their dot saved. So that's handy. We then call that mid block, which is just another
00:18:44.080 | res block. And then same thing. Okay. Now for the up blocks. And what we do is we just pass
00:18:48.360 | in those saved, right? And just remember, it's going to pop them out each time. And
00:18:56.600 | then the conv at the end. So that's, yeah, that's it. That's our unconditional model.
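As a sketch, the forward being described has roughly this shape (attribute names are assumptions):

```python
def forward(self, x):
    x = self.conv_in(x)
    saved = [x]                        # assumption: the initial conv output is saved too
    x = self.downs(x)                  # each saved block stashes its output as .saved
    saved += [layer.saved for block in self.downs for layer in block]
    x = self.mid_block(x)
    for block in self.ups:
        x = block(x, saved)            # each up block pops what it needs from `saved`
    return self.conv_out(x)
```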
00:19:06.200 | It's not quite the same as the diffusers unconditional model, because it doesn't have attention, which
00:19:09.800 | is something we're going to add next. But other than that, this is the same. So let's
00:19:17.120 | for, because we're doing a simpler problem, which is fashion MNIST, we'll use less channels
00:19:21.560 | than the default. Using two layers per block is standard. One thing to note though, is
00:19:29.280 | that in the up sampling blocks, it actually is going to be three layers, num layers plus
00:19:34.560 | one. And the reason for that is that the way stable diffusion and diffusers do it is that
00:19:41.360 | even the output of the down sampling is also saved. So if you have num layers equals two,
00:19:48.600 | then there'll be two res blocks saving things here and one conv saving things here. So you'll
00:19:54.280 | have three saved cross connections. So that's why there's an extra plus one here.
00:20:02.760 | Okay. And then we can just train it using mini AI as per usual. Nope, I didn't save
00:20:13.240 | it after I last trained it. Sorry about that. So trust me, it trained. Okay. Now that, oh,
00:20:22.800 | okay. No, that is actually missing something else important as well as attention. The other
00:20:26.840 | thing it's missing is that thing that we discovered is pretty important, which is the time embedding.
00:20:34.600 | So we already know that sampling doesn't work particularly well without the time embedding.
00:20:38.280 | So I didn't even bother sampling this. I didn't want to add all this stuff necessary to make
00:20:42.240 | that work a bit better. So let's just go ahead and do time embedding.
00:20:48.520 | So time embedding, there's a few ways to do it. And the way it's done in stable diffusion
00:20:55.340 | is what's called sinusoidal embeddings. The basic idea, maybe we'll skip ahead a bit.
00:21:04.680 | The basic idea is that we're going to create a res block with embeddings where forward
00:21:11.080 | is not just going to get the activations, but it's also going to get T, which is a vector
00:21:19.480 | that represents the embeddings of each time step. So actually it'll be a matrix because
00:21:25.320 | it's really in the batch. But for one element of the batch, it's a vector. And it's an embedding
00:21:29.880 | in exactly the same way as when we did NLP. Each token had an embedding. And so the word
00:21:36.760 | "the" would have an embedding and the word "Johnno" would have an embedding and the word
00:21:41.040 | "Tanishk" would have an embedding, although Tanishk would probably actually be multiple
00:21:44.800 | tokens until he's famous enough that he's mentioned in nearly every piece of literature,
00:21:50.320 | at which point Tanishk will get his own token, I expect. That's how you know when you've
00:21:56.480 | made it. So the time embedding will be the same. T of time step zero will have a particular
00:22:05.600 | vector, time step one will have a particular vector, and so forth. Well, we're doing Karras.
00:22:12.760 | So actually they're not time step one, two, three. They're actually sigmas, you know.
00:22:18.560 | So they're continuous. But same idea. A specific value of sigma, which is actually what T is
00:22:26.960 | going to be, slightly confusingly, will have a specific embedding. Now, we want two values
00:22:37.240 | of sigma or T, which are very close to each other, should have similar embeddings. And
00:22:43.840 | if they're different to each other, they should have different embeddings. So how do we make
00:22:51.600 | that happen? You know, and also make sure there's a lot of variety of the embeddings across
00:22:56.440 | all the possibilities. So the way we do that is with these sinusoidal time steps. So let's
00:23:04.680 | have a look at how they work. So you first have to decide how big do you want your embeddings
00:23:09.440 | to be? Just like we do at NLP. Does the word "the," is it represented by eight floats
00:23:16.440 | or 16 floats or 400 floats or whatever? Let's just assume it's 16 now. So let's say we're
00:23:24.120 | just looking at a bunch of time steps, which is between negative 10 and 10. And we'll just
00:23:30.400 | do 100 of them. I mean, we don't actually have negative sigmas or T. So it doesn't exactly
00:23:36.120 | make sense. But it doesn't matter. It gives you the idea. And so then we say, OK, what's
00:23:43.960 | the largest time step you could have or the largest sigma that you could have? Interestingly,
00:23:50.440 | every single model I've found uses 10,000 for this. Even
00:23:57.280 | though that number actually comes from the NLP transformers literature, and it's based
00:24:01.920 | on the idea of, like, OK, what's the maximum sequence length we support? You could have
00:24:05.960 | up to 10,000 things in a document or whatever in a sequence. But we don't actually have
00:24:14.800 | sigmas that go up to 10,000. So I'm using the number that's used in real life in stable
00:24:19.200 | diffusion and all the other models. But interestingly, here purely, as far as I can tell, as a historical
00:24:25.280 | accident, because this is like the maximum sequence length that NLP transformers people
00:24:31.200 | thought they would need to support. OK, now what we're then going to do is we're going
00:24:38.640 | to be then doing e to the power of a bunch of things. And so that's going to be our exponent.
00:24:47.720 | And so our exponent is going to be equal to log of the period, which is about nine, times
00:24:57.820 | the numbers between 0 and 1, eight of them, because we said we want 16. So you'll see
00:25:04.720 | why we want eight of them and not 16 in a moment. But basically here are the eight exponents
00:25:11.600 | we're going to use. So then not surprisingly, we do e to the power of that. OK, so we do
00:25:18.880 | e to the power of that, each of these eight things. And we've also got the actual time
00:25:28.440 | steps. So imagine these are the actual time steps we have in our batch. So there's a batch
00:25:33.320 | of 100, and they contain this range of sigmas or time steps. So to create our embeddings,
00:25:43.200 | what we do is we do an outer product of the exponent.x and the time steps. This is step
00:25:52.520 | one. And so this is using a broadcasting trick we've seen before. We add a unit axis and
00:26:01.560 | an axis 0 here, and add a unit axis 1 here, and add a unit axis and axis 0 here. So if
00:26:13.840 | we multiply those together, then it's going to broadcast this one across this axis and
00:26:22.520 | this one across this axis. So we end up with a 100 by 8. So it's basically a Cartesian
00:26:28.760 | product or the possible combinations of time step and exponent multiplied together. And
00:26:36.080 | so here's a few of those different exponents for a few different values. OK, so that's
00:26:50.800 | not very interesting yet. We haven't yet reached something where each time step is similar
00:26:57.960 | to each next door time step. You know, over here, you know, these embeddings look very
00:27:04.000 | different to each other. And over here, they're very similar. So what we then do is we take
00:27:10.720 | the sine and the cosine of those. So that is 100 by 8. And that is 100 by 8. And that
00:27:23.780 | gives us 100 by 16. So we concatenate those together. And so that's a little bit hard
00:27:34.920 | to wrap your head around. So let's take a look. So across the 100 time steps, 100 sigma,
00:27:44.120 | this one here is the first sine wave. And then this one here is the second sine wave.
00:27:57.240 | And this one here is the third. And this one here is the fourth and the fifth. So you can
00:28:05.200 | see as you go up to higher numbers, you're basically stretching the sine wave out. And
00:28:17.160 | then once you get up to index 8, you're back up to the same frequency as this blue one,
00:28:25.240 | because now we're starting the cosine rather than sine. And cosine is identical to sine.
00:28:30.200 | It's just shifted across a tiny bit. You can see these two light blue lines are the same.
00:28:35.080 | And these two orange lines are the same. They're just shifted across, I shouldn't say, lines
00:28:39.600 | or curves. So when we concatenate those all together, we can actually draw a picture of
00:28:45.720 | it. And so this picture is 100 pixels across and 16 pixels top to bottom. And so if you
00:28:56.880 | picked out a particular point, so for example, in the middle here for t equals 0, well sigma
00:29:03.000 | equals 0, one column is an embedding. So the bright represents higher numbers and the dark
00:29:11.260 | represents lower numbers. And so you can see every column looks different, even though
00:29:17.280 | the columns next to each other look similar. So that's called a time step embedding. And
00:29:26.180 | this is definitely something you want to experiment with. I've tried to do the plots I thought
00:29:34.060 | are useful to understand this. And Johno and Tanishk also had ideas about plots for these,
00:29:42.320 | which we've shown. But the only way to really understand them is to experiment. So then
00:29:49.320 | we can put that all into a function where you just say, OK, well, how many times-- sorry,
00:29:53.680 | what are the time steps? How many embedding dimensions do you want? What's the maximum
00:29:57.780 | period? And then all I did was I just copied and pasted the previous cells and merged them
00:30:02.200 | together. So you can see there's our outer product. And there's our cat of sine and cos.
00:30:15.020 | If you end up with a-- if you have an odd numbered embedding dimension, you have to
00:30:18.480 | pad it to make it even. Don't worry about that. So here's something that now you can pass
00:30:22.800 | in the number of-- sorry, the actual time steps or sigmas and the number of embedding
00:30:27.440 | dimensions. And you will get back something like this. It won't be a nice curve because
00:30:34.480 | your time steps in a batch won't all be next to each other. It's the same idea.
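A sketch of that function, following the recipe just described. Note the negative sign on the exponent, which is what gives the progressively stretched sine waves; argument names are assumptions:

```python
import math
import torch
import torch.nn.functional as F

def timestep_embedding(tsteps, emb_dim, max_period=10000):
    # Outer product of the timesteps (or sigmas) with a range of exponentials,
    # then concatenate sin and cos to get one emb_dim-long vector per timestep.
    half = emb_dim // 2
    exponent = -math.log(max_period) * torch.linspace(0, 1, half, device=tsteps.device)
    emb = tsteps[:, None].float() * exponent.exp()[None, :]
    emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
    # An odd embedding dimension needs one column of padding to reach emb_dim.
    return F.pad(emb, (0, 1, 0, 0)) if emb_dim % 2 == 1 else emb
```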
00:30:39.600 | Can I call out something on that little visualization there, which goes back to your comment about
00:30:45.400 | the max period being super high? So you said, OK, adjacent ones are somewhat similar because
00:30:51.200 | that's what we want. But there is some change. But if you look at all of this first 100,
00:30:57.000 | like half of the embeddings look like they don't really change at all.
00:31:01.200 | And that's because 50 to 100 on a scale of like 0 to 10,000, you want those to be quite
00:31:07.240 | similar because those are still very early in this super long sequence that these are
00:31:10.760 | designed for.
00:31:11.760 | Yeah. So here, actually, we've got--
00:31:13.560 | [INTERPOSING VOICES]
00:31:14.560 | --wasted space.
00:31:15.560 | Yeah. So here we've got a max period of 1,000 instead. And I've changed the figure size so
00:31:20.040 | you can see it better. And it's using up a bit more of the space. Yeah. Or go to max
00:31:27.240 | period of 10. And it's actually now-- this is, yeah, using it much better. Yeah. So based
00:31:36.440 | on what you're saying, Johnno, I agree. It seems like it would be a lot richer to use
00:31:44.280 | these time step embeddings with a suitable max period. Or maybe you just wouldn't need
00:31:48.640 | as many embedding dimensions. I guess if you did use something very wasteful like this
00:31:53.600 | but you used lots of embedding dimensions, then it's going to still capture some useful
00:31:58.360 | ones.
00:32:01.520 | Yeah. Thanks, Johnno. So yeah. Yeah. So this is one of these interesting little insights
00:32:10.840 | about things that are buried deep in code, which I'm not sure many people look at much.
00:32:20.800 | OK. So let's do a unet with time step embedding in it. So what do you do once you've got like
00:32:31.440 | this column of embeddings for each item of the batch? What do you do with it? Well, there's
00:32:37.640 | a few things you can do with it. What stable diffusion does, I think, is correct. I'm not
00:32:44.960 | promising. I remember all these details right, is that they make their embedding dimension
00:32:51.560 | length twice as big as the number of activations. And what they then do is we can use chunk
00:33:04.800 | to take that and split it into two separate variables. So that's literally just the opposite
00:33:10.720 | of concatenate. It's just two separate variables. And one of them is added to the activations
00:33:17.700 | and one of them is multiplied by the activations. So this is a scale and a shift. We don't just
00:33:27.680 | grab the embeddings as is, though, because each layer might want to do-- each res block
00:33:35.000 | might want to do different things with them. So we have a embedding projection, which is
00:33:41.400 | just a linear layer which allows them to be projected. So it's projected from the number
00:33:45.920 | of embeddings to 2 times the number of filters so that that torch.chunk works. We also have
00:33:55.200 | an activation function called silu. This is the activation function that's used in stable
00:34:01.800 | diffusion. I don't think the details are particularly important. But it looks basically like a rectified
00:34:15.120 | linear with a slight curvy bit. Also known as SWISH. Also known as SWISH. And it's just
00:34:25.960 | equal to x times sigmoid x. And yeah, I think it's like activation functions don't make
00:34:37.120 | a huge difference. But they can make things train a little better or a little faster.
00:34:44.120 | And SWISH has been something that's worked pretty well. So a lot of people using SWISH
00:34:48.880 | or SiLU. I always call it Swish. But I think SiLU was originally from the GELU paper, which
00:34:57.440 | had SiLU, and was where it originally was kind of invented. And maybe people didn't quite
00:35:01.320 | notice. And then another paper called it SWISH. And everybody called it SWISH. And then people
00:35:05.640 | were like, wait, that wasn't the original paper. So I guess I should try to call it
00:35:11.080 | silu. Other than that, it's just a normal res block. So we do our first conv. Then we do
00:35:19.200 | our embedding projection of the activation function of time steps. And so that's going
00:35:24.800 | to be applied to every pixel height and width. So that's why we have to add unit axes on
00:35:34.840 | the height and width that it's going to cause it to broadcast across those two axes. Do
00:35:39.400 | our chunk. Do the scale and shift. Then we're ready for the second conv. And then we add
00:35:44.800 | it to the input with an additional conv, one stride one conv if necessary as we've done
00:35:52.680 | before if we have to change the number of channels.
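A sketch of that forward (attribute names are assumptions; conv1/conv2 are preact convs, emb_proj is a Linear layer to 2*nf, idconv is the optional 1x1 conv):

```python
import torch
import torch.nn.functional as F

def forward(self, x, t):
    inp = x
    x = self.conv1(x)
    # Project the (SiLU-activated) time embedding and broadcast over height and width.
    emb = self.emb_proj(F.silu(t))[:, :, None, None]
    scale, shift = torch.chunk(emb, 2, dim=1)
    x = x * scale + shift          # some implementations use x * (1 + scale) + shift
    x = self.conv2(x)
    return x + self.idconv(inp)    # identity conv only if channel counts differ
```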
00:36:00.920 | OK. Yeah, because I like exercising our Python muscles, I decided to use a second approach now for the down block and the up
00:36:07.680 | block. I'm not saying which one's better or worse. We're not going to use multiple inheritance
00:36:15.960 | anymore. But instead, we're going to use-- well, it's not even a decorator. It's a function
00:36:23.200 | which takes a function. What we're going to do now is we're going to use nn.Conv2d and EmbResBlock
00:36:29.680 | directly. But we're going to pass them to a function called saved. The function called
00:36:34.120 | saved is something which is going to take as input a callable, which could be a function
00:36:49.480 | or a module or whatever. So in this case, it's a module. Takes an EmbResBlock or an nn.Conv2d.
00:36:49.480 | And it returns a callable. The callable it returns is identical to the callable that's
00:36:56.000 | passed into it, except that it saves the result, saves the activations, saves the result of
00:37:01.640 | a function. Where does it save it? It's going to save it into a list in the second argument
00:37:10.080 | you pass to it, which is the block. So the save function, you're going to pass it the
00:37:18.280 | module. We're going to grab the forward from it and store that away to remember what it
00:37:23.080 | was. And then the function that we want to replace it with, call it underscore f, going
00:37:29.320 | to take some arguments and some keyword arguments. Well, basically, it's just going to call the
00:37:32.840 | original module's forward, passing in the arguments and keyword arguments. And we're then going
00:37:39.880 | to store the result in something called the saved attribute inside here. And then we have
00:37:51.320 | to return the result. So then we're going to replace the modules forward method with
00:37:57.720 | this function and return the module. So that module's now been-- yeah, I said callable,
00:38:04.760 | actually. It can't be just any callable. It has to specifically be a module, because it's the forward that
00:38:08.600 | we're changing. This @wraps is just something which automatically-- it's from the Python
00:38:14.120 | standard library. So it's going to copy in the documentation and everything from the
00:38:18.200 | original forward so that it all looks like nothing's changed.
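A sketch of that saved helper as just described (details may differ slightly from the notebook; blk.saved is the list created on the block, which comes up next):

```python
from functools import wraps

def saved(m, blk):
    # Patch m.forward so that every call also appends its result to blk.saved.
    m_ = m.forward

    @wraps(m.forward)
    def _f(*args, **kwargs):
        res = m_(*args, **kwargs)
        blk.saved.append(res)
        return res

    m.forward = _f
    return m
```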
00:38:26.280 | Now, where does this dot saved come from? I realized now, actually, we could make this
00:38:30.280 | easier and automate it. But I forgot, didn't think of this at the time. So we have to create
00:38:36.200 | the saved here in the down block. It actually would have made more sense, I think, here
00:38:42.120 | for it to have said if the saved attribute doesn't exist, then create it, which would
00:38:47.160 | look like this: if not hasattr(block, 'saved'): block.saved = []. Because if you do this,
00:38:59.800 | then you wouldn't need this anymore. Anyway, I didn't think of that at the time. So let's
00:39:06.680 | pretend that's not what we do. OK, so now the downsampling conv and the resnets both
00:39:19.880 | contain saved versions of modules. We don't have to do anything to make that work. We
00:39:26.120 | just have to call them. We can't use sequential anymore, because we have to pass in the time
00:39:31.880 | step to the resnets as well. It would be easy enough to create your own sequential for things
00:39:39.480 | with time steps, which passes them along. But that's not what we're doing here. Yeah, maybe
00:39:48.600 | it makes sense for sequential to always pass along all the extra arguments. But I don't
00:39:53.240 | think that's how they work. Yeah, so our up block is basically exactly the same as before,
00:40:01.320 | except we're now using EmbResBlocks instead. Just like before, we're going to concatenate.
00:40:06.360 | So that's all the same. OK, so a unet model with time embeddings
00:40:15.720 | is going to look, if we look at the forward, the thing we're passing into it now is a tuple
00:40:25.160 | containing the activations and the time steps, or the sigmas in our case. So split them out.
00:40:31.560 | And what we're going to do is we're going to call that time step embedding function we wrote,
00:40:37.800 | saying, OK, these are the time steps. And the number of time step embeddings we want
00:40:44.200 | is equal to however many we asked for. And we're just going to set it equal to the first number
00:40:52.200 | of filters. That's all that happens there. And then we want to give the model the ability
00:41:01.400 | then to do whatever it wants with those, to make those work the way it wants to.
00:41:04.680 | And the easiest, smallest way to do that is to create a tiny little MLP. So we create a tiny
00:41:09.960 | little MLP, which is going to take the time step embeddings and return the actual embeddings to
00:41:15.480 | pass into the res blocks. So the tiny little MLP is just a linear layer with-- hmm, thinking here.
00:41:27.160 | That's interesting. My linear layer by default has an activation function. I'm pretty sure we
00:41:44.600 | should have act equals none here. It should be a linear layer and then an activation and then a
00:41:50.040 | linear layer. So I think I've got a bug, which we will need to try rerunning.
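For reference, the intended shape of that MLP is something like this (a sketch with assumed sizes, using plain PyTorch rather than the notebook's lin helper):

```python
import torch.nn as nn

def temb_mlp(n_temb, n_emb):
    # Two linear layers with a single activation between them; the final layer
    # stays linear so negative values in the embedding aren't thrown away.
    return nn.Sequential(nn.Linear(n_temb, n_emb), nn.SiLU(), nn.Linear(n_emb, n_emb))
```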
00:41:55.160 | OK. It won't be the end of the world. It just means all the negatives will be lost here.
00:42:14.280 | Makes it half-- only half as useful. That's not great. OK. And these are the kind of things like,
00:42:23.240 | you know, as you can see, you've got to be super careful of, like, where do you have activation
00:42:28.040 | functions? Where do you have batch norms? Is it pre-activation? Is it post-activation?
00:42:33.480 | It trains even if you make that mistake. And in this case, it's probably not too much performance,
00:42:40.760 | but often it's like, oh, you've done something where you accidentally zeroed out, you know,
00:42:46.040 | all except the last few channels of your output block or something like that.
00:42:53.000 | The network tries anyway, it does its best. It uses what it can.
00:42:57.080 | Yeah, it makes it very difficult.
00:43:00.760 | To make sure you're not giving it those handicaps.
00:43:04.040 | Yeah. It's not like you're making a CRUD app or something and you know that it's not working
00:43:08.120 | because it crashes or because, like, it doesn't show the username or whatever.
00:43:13.800 | Instead, you just get, like, slightly less good results. But since you haven't
00:43:20.200 | done it correctly in the first place, you don't know it's the less good results.
00:43:23.640 | Yeah, there's not really great ways to do this. It's really nice if you can have an
00:43:28.360 | existing model to compare to or something like that, which is where Kaggle competitions work
00:43:35.080 | really well. Actually, if somebody's got a Kaggle result, then you know that's a really
00:43:38.840 | good baseline and you can check whether yours is as good as theirs.
00:43:43.320 | All right. So, yeah, that's what this MLP is for. So, the down and up blocks are the
00:43:52.920 | same as before. The convout is the same as before. So, yeah, so we grab our time step embedding. So,
00:43:57.960 | that's just that outer product passed through this sinusoidal, the sine and cosine. We then
00:44:03.400 | pass that through the MLP. And then we call our downsampling, passing in those embeddings each
00:44:19.160 | time. You know, it's kind of interesting that we pass in the embeddings every time
00:44:23.240 | in the sense I don't exactly know why we don't just pass them in at the start.
00:44:28.280 | And in fact, in NLP, these kinds of embeddings, I think, are generally just passed in at the start.
00:44:34.440 | So, this is kind of a curious difference. I don't know why. It's, you know, if there's
00:44:41.800 | been ablation studies or whatever. Do you guys know, are there like any popular diffusion-y or
00:44:48.760 | generative models with time embeddings that don't pass them in or is this pretty universal?
00:44:55.240 | >> Some of the fancier architectures like
00:45:02.200 | recurrent interface networks and stuff just pass in the conditioning.
00:45:06.680 | I'm actually not sure. Yeah, maybe they do still do it at every stage. I think some of them just
00:45:15.240 | take in everything all at once up front and then do a stack of transformer blocks or something
00:45:19.800 | like that. So, I don't know if it's universal, but it definitely seems like all the unet-style ones
00:45:24.840 | have this time step embedding going in. >> Maybe we should try some ablations to see,
00:45:31.320 | yeah, if it matters. I mean, I guess it doesn't matter too much either way. But, yeah,
00:45:37.800 | if you didn't need it at every step, then it would maybe save you a bit of compute, potentially.
00:45:43.480 | Yeah, so now the upsampling, you're passing in the activations, the time step embeddings,
00:45:50.280 | and that list of saved activations. So, yeah, now we have a non-attention stable diffusion unet.
00:46:01.080 | So, we can train that. And we can sample from it
00:46:11.640 | using the same -- I just copied and pasted all the stuff from the Karras notebook that we had.
00:46:19.880 | And there we have it. This is our first diffusion from scratch.
00:46:26.440 | >> So, we wrote every piece of code for this diffusion model.
00:46:32.280 | >> Yeah, I believe so. I mean, obviously, in terms of the optimized CUDA implementations of
00:46:38.600 | stuff, no. But, yeah, we've written our version of everything here, I believe.
00:46:45.240 | >> A big milestone.
00:46:46.280 | >> I think so, yeah. And these FIDs are about the same as the FIDs that we get
00:46:50.360 | from the stable diffusion one. They're not particularly higher or lower. They bounce
00:46:56.200 | around a bit, so it's a little hard to compare. Yeah, they're basically the same.
00:46:59.880 | Yeah, so that's -- that is an exciting step. And
00:47:11.240 | okay, yeah, that's probably a good time to have a five-minute break.
00:47:16.200 | Yeah, okay. Let's have a five-minute break.
00:47:20.440 | Okay. Normally, I would say we're back, but only some of us are back.
00:47:29.480 | Johnno -- Johnno's internet and electricity in Zimbabwe are not the most reliable thing.
00:47:36.280 | And he seems to have disappeared, but we expect him to reappear at some point.
00:47:39.960 | So we will kick on Johnno-less and hope that Zimbabwe's infrastructure sorts itself out.
00:47:50.360 | All right. So we're going to talk about attention. We're going to talk about attention for a few
00:47:58.760 | reasons. Reason number one, very pragmatic. We said that we would replicate stable diffusion,
00:48:04.120 | and the stable diffusion unet has attention in it. So we would be lying if we didn't do attention.
00:48:09.160 | Okay. Number two,
00:48:12.840 | attention is one of the two basic building blocks of transformers. A transformer layer is attention
00:48:22.280 | attached to a one-layer MLP. We already know how to create a one-layer or one-hidden layer MLP.
00:48:28.280 | So once we learn how to do attention, we'll know how to -- we'll know how to create transformer
00:48:34.760 | blocks. So those are two good reasons. I'm not including a reason which is our model is going
00:48:46.760 | to look a lot better with attention, because I actually haven't had any success seeing any
00:48:51.480 | diffusion models I've trained work better with attention. So just to set your expectations,
00:48:59.960 | we are going to get it all working. But regardless of whether I use our implementation of attention
00:49:05.960 | or the diffuser's one, it's not actually making it better. That might be because we need to use
00:49:15.560 | better types of attention than what diffusers has, or it might be because it's just a very
00:49:21.160 | subtle difference that you only see on bigger images. I'm not sure. That's something we're still
00:49:28.360 | trying to figure out. This is all pretty new. And not many people have done kind of the diffusion,
00:49:34.840 | the kind of ablation studies necessary to figure these things out. So yeah, so that's just life.
00:49:42.040 | Anyway, so there's lots of good reasons to know about attention. We'll certainly be using it a
00:49:46.840 | lot once we do NLP, which we'll be coming to pretty shortly, pretty soon. And it looks like
00:49:54.360 | Jono is reappearing as well. So that's good. Okay, so let's talk about attention.
00:50:03.480 | The basic idea of attention is that we have
00:50:17.320 | an image, and we're going to be sliding a convolution kernel across that image.
00:50:25.880 | And obviously, we've got channels as well, or filters. And so this also has that. Okay.
00:50:36.760 | And as we bring it across, we might be, you know, we're trying to figure out like what
00:50:45.640 | activations do we need to create to eventually, you know, correctly create our outputs.
00:50:51.240 | But the correct answer as to what's here may depend on something that's way over here,
00:51:00.760 | and/or something that's way over here. So for example, if it's a cute little bunny rabbit,
00:51:08.600 | and this is where its ear is, you know, and there might be two different types of bunny rabbit that
00:51:15.640 | have different shaped ears, well, it'd be really nice to be able to see over here what its other
00:51:22.040 | ear looks like, for instance. With just convolutions, that's challenging. It's not
00:51:29.240 | impossible. We talked in part one about the receptive field. And as you get deeper and
00:51:35.640 | deeper in a convnet, the receptive field gets bigger and bigger. But it's, you know,
00:51:41.960 | at higher up, it probably can't see the other ear at all. So it can't put it into those kind of more
00:51:46.920 | texture level layers. And later on, you know, even though this might be in the receptive field here,
00:51:54.200 | most of the weight, you know, the vast majority of the activations it's using is the stuff
00:51:59.480 | immediately around it. So what attention does is it lets you take a weighted average of other pixels
00:52:10.440 | around the image, regardless of how far away they are. And so in this case, for example,
00:52:18.840 | we might be interested in bringing in at least a few of the channels of these pixels over here.
00:52:28.280 | The way that attention is done in stable diffusion is pretty hacky and known to be suboptimal.
00:52:39.960 | But it's what we're going to implement because we're implementing stable diffusion and time
00:52:45.400 | permitting. Maybe we'll look at some other options later. But the kind of attention we're going to
00:52:50.600 | be doing is 1D attention. And it was a tension that was developed for NLP. And NLP is sequences,
00:52:59.400 | one-dimensional sequences of tokens. So to do attention stable diffusion style,
00:53:04.840 | we're going to take this image and we're going to flatten out the pixels.
00:53:11.400 | So we've got all these pixels. We're going to take this row and put it here. And then we're
00:53:18.680 | going to take this row, we're going to put it here. So we're just going to flatten the whole thing out
00:53:23.400 | into one big vector of all the pixels of row one and then all the pixels of row two and then all
00:53:29.880 | the pixels of the row three. Or maybe it's column one, column two, column three. I can't remember
00:53:32.520 | if this is row-wise or column-wise, but it's flattened out anywho. And then it's actually, for each image,
00:53:44.120 | it's actually a matrix, which I'm going to draw it a little bit 3D because we've got the channel
00:53:52.280 | dimension as well. So this is going to be the number across this way is going to be equal to
00:54:00.520 | the height times the width. And then the number this way is going to be the number of channels.
00:54:13.800 | Okay, so how do we decide, yeah, which, you know, bring in these other pixels? Well, what we do
00:54:25.560 | is we basically create a weighted average of all of these pixels. So maybe these ones get a bit
00:54:34.680 | of a negative weight and these ones get a bit of a positive weight and, you know, these get a weight
00:54:45.960 | kind of somewhere in between. And so we're going to have a weighted average. And so basically each
00:54:52.520 | pixel, so let's say we're doing this pixel here right now, is going to equal its original pixel
00:54:59.400 | plus, so let's call it x, plus the weighted average. So the sum across, so maybe this is like x, i,
00:55:09.160 | plus the sum of
00:55:13.240 | over all the other pixels. So from zero to the height times the width.
00:55:27.320 | Sum weight times each pixel.
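Written out, the weighted sum being described is roughly:

$$\mathrm{out}_i = x_i + \sum_{j=1}^{h \cdot w} w_{ij}\,x_j, \qquad \sum_{j=1}^{h \cdot w} w_{ij} = 1$$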
00:55:35.800 | The weights, they're going to sum to one. And so that way the, you know, the pixel value
00:55:52.920 | scale isn't going to change. Well, that's not actually quite true. It's going to end up
00:55:57.640 | potentially twice as big, I guess, because it's being added to the original pixel.
00:56:01.000 | So attention itself is not with the x plus, but the way it's done in stable diffusion,
00:56:10.120 | at least, is that the attention is added to the original pixel. So, yeah, now I think about it.
00:56:17.640 | I'm not going to need to think about how this is being scaled, anyhow. So the big question is what
00:56:27.640 | values to use for the weights. And the way that we calculate those is we do a matrix product.
00:56:41.320 | And so our, for a particular pixel, we've got,
00:56:52.680 | you know, the number of channels for that one pixel.
00:57:02.120 | And what we do is we can compare that to all of the number of channels for all the other pixels.
00:57:13.080 | So we've got kind of, this is pixel, let's say x1. And then we've got pixel number x2.
00:57:21.560 | Right, all those channels. We can take the dot product between those two things.
00:57:31.960 | And that will tell us how similar they are. And so one way of doing this would be to say, like,
00:57:40.200 | okay, well, let's take that dot product for every pair of pixels. And that's very easy dot product
00:57:46.600 | do, because that's just what the matrix product is equal to. So if we've got h by w by c and then
00:57:57.960 | multiply it by its transpose, h by w base, sorry, it said transpose and then totally failed to do
00:58:09.400 | transpose, multiply by its transpose, that will give us an h by w by h by w matrix.
00:58:30.440 | So each pixel, all the pixels are down here. And for each pixel, as long as these add up to one,
00:58:38.920 | then we've got to wait for each pixel. And it's easy to make these add up to one. We could just
00:58:43.880 | take this matrix multiplication and take the sigmoid over the last dimension. And that makes,
00:58:54.760 | sorry, not sigmoid. Man, what's wrong with me? Softmax, right? Yep. And take the softmax over
00:59:04.360 | the last dimension. And that will give me something that sums to one.
00:59:12.200 | Okay. Now, the thing is, it's not just that we want to find the places where they look the same,
00:59:21.960 | where the channels are basically the same, but we want to find the places where they're, like,
00:59:27.160 | similar in some particular way, you know? And so some particular set of channels are similar in one
00:59:34.200 | to some different set of channels in another. And so, you know, in this case, we may be looking for
00:59:39.240 | the pointy-earedness activations, you know, which actually represented by, you know, this, this,
00:59:47.640 | and this, you know, and we want to just find those. So the way we do that is before we do this
00:59:54.360 | matrix product, we first put our matrix through a projection.
01:00:03.000 | So we just basically put our matrix through a matrix multiplication, this one. So it's the
01:00:12.680 | same matrix, right? But we put it through two different projections. And so that lets it pick
01:00:17.960 | two different kind of sets of channels to focus on or not focus on before it decides, you know,
01:00:24.280 | of this pixel, similar to this pixel in the way we care about. And then actually, we don't even
01:00:29.640 | just multiply it then by the original pixels. We also put that through a different projection as
01:00:36.120 | well. So there's these different projections. We'll call them projection one, projection two,
01:00:40.760 | and projection three. And that gives it the ability to say, like, oh, I want to compare
01:00:45.720 | these channels and, you know, these channels to these channels to find similarity. And based on
01:00:51.160 | similarity, yeah, they want to pick out these channels, right? Both positive and negative
01:00:56.280 | weight. So that's why there's these three different projections. And so the projections are called
01:01:02.760 | K, Q, and V. Those are the projections. And so they're all being passed the same
01:01:14.440 | matrix. And because they're all being passed the same matrix, we call this self-attention.
01:01:21.400 | Okay, Jono, Tanishk, I know this is, I know you guys know this very well, but you also
01:01:27.960 | know it's really confusing. Did you have anything to add? Change? Anything else?
01:01:33.960 | Yeah, I like that you introduced this without resorting to the,
01:01:39.000 | let's think of this as queries at all, which I think is, yeah.
01:01:44.120 | Yeah, these are actually short for key, query, and value, even though I personally don't find those
01:01:54.520 | useful concepts. Yeah. You'll note on the scaling, you said, oh, so we said it so that the weights
01:02:01.880 | sum to one. And so then we'd need to worry about like, are we doubling the scale of X? Yeah. But
01:02:08.360 | because of that P3, aka V, that projection that can learn to scale this thing that's added to X
01:02:19.960 | appropriately. And so it's not like just doubling the size of X, it's increasing it a little bit,
01:02:24.200 | which is why we scatter normalization in between all of these attention layers.
01:02:28.680 | But it's not as bad as it might be because we have that V projection.
01:02:36.280 | Yeah, that's a good point. And if this is, if P3, or actually the V projection, is
01:02:46.040 | initialized such that it would have a mean of zero, then on average it should start out by not
01:02:54.760 | messing with our scale. OK, so yeah, I guess I find it easier to think in terms of code.
01:03:04.440 | So let's look at the code. You know, there's actually not much code.
01:03:07.320 | I think you've got a bit of background noise too, Jono, maybe. Yes, that's much better. Thank you.
01:03:15.480 | So in terms of code, there's, you know, this is one of these things getting everything exactly
01:03:26.840 | right. And it's not just right. I wanted to get it identical to the stable diffusion. So we can say
01:03:31.640 | we've made it identical to stable diffusion. I've actually imported the attention block from
01:03:36.760 | diffusers so we can compare. And it is so nice when you've got an existing version of something
01:03:41.960 | to compare to to make sure you're getting the same results. So we're going to start off by saying,
01:03:50.120 | let's say we've got a 16 by 16 pixel image. And this is some deeper level of activation. So it's
01:03:56.360 | got 32 channels with a batch size of 64. So NCHW. I'm just going to use random numbers for now,
01:04:03.640 | but this has the, you know, reasonable dimensions for an activation inside a batch size 64
01:04:09.880 | CNN or diffusion model or unet, whatever. OK, so the first thing we have to do
01:04:15.640 | is to flatten these out because, as I said, in 1D attention, the 2D structure is just ignored.
01:04:25.400 | So it's easy to flatten things out. You just say dot view and you pass in the dimensions of the,
01:04:31.560 | in this case, the three dimensions we want, which is 64, 32, and everything else. Minus one means
01:04:38.040 | everything else. So x.shape[:2]. In this case, you know, obviously it'd be easy just to
01:04:43.400 | type 64, 32, but I'm trying to create something that I can paste into a function later. So it's general.
01:04:48.840 | So that's the first two elements, 64, 32. And then the star just inserts them directly in here. So
01:04:54.760 | 64, 32, minus one. So 16 by 16. Now then, again, because this is all stolen from the NLP world,
01:05:04.360 | in the NLP world, things are, have, they call this sequence. So I'm going to call this sequence by
01:05:12.760 | which we're in height by width. Sequence comes before channel, which is often called D or dimension.
01:05:19.160 | So we then transpose those last two dimensions. So we've now got batch by sequence, 16 by 16,
01:05:27.800 | by channel or dimension. So you could call this NSD: batch, sequence, dimension.
01:05:38.120 | Okay, so we've got 32 channels. So we now need three different projections that go from 32
01:05:48.680 | channels in to 32 channels out. So that's just a linear layer. Okay, and just remember a linear
01:05:54.120 | layer is just a matrix multiply plus a bias. So there's three of them. And so they're all going
01:06:08.040 | to be randomly initialized with different random numbers. We're going to call them sk, sq, sv.
01:06:08.040 | And so we can then, they're just callable. So we can then pass the exact same thing into three,
01:06:13.320 | all three, because we're doing self-attention to get back our keys, queries, and values,
01:06:18.200 | or K, Q, and V. I just think of them as K, Q, and V, because they're not really
01:06:22.600 | keys, queries, and values to me. So then we have to do the matrix multiply by the transpose.
01:06:30.600 | And so then for every one of the 64 items in the batch, for every one of the 256 pixels,
01:06:40.200 | there are now 256 weights. So at least there would be if we had done softmax, which we haven't yet.
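A sketch of those three projections and the resulting pre-softmax weight matrix, on a tensor shaped like the flattened activations above (sk, sq, sv are just illustrative names for the three linear layers):

```python
import torch
import torch.nn as nn

nc = 32
t = torch.randn(64, 256, nc)                       # batch, sequence, channels
sk, sq, sv = nn.Linear(nc, nc), nn.Linear(nc, nc), nn.Linear(nc, nc)
k, q, v = sk(t), sq(t), sv(t)                      # each (64, 256, 32)
w = q @ k.transpose(1, 2)                          # (64, 256, 256): a weight for every pixel pair
print(w.shape)                                     # torch.Size([64, 256, 256])
```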
01:06:45.720 | So we can now put that into a self-attention. As Johnno mentioned, we want to make sure that
01:06:52.120 | we normalize things. So we add proper normalization here. We talked about group norm back when we
01:06:57.960 | talked about batch norm. So group norm is just batch norm, which has been split into a bunch of
01:07:04.200 | sets of channels. Okay, so then we are going to create our K, Q, V. Yep, Johnno?
01:07:18.280 | I was just going to ask, should those be just bias equals false so that they're only a matrix
01:07:23.480 | multiply, to strictly match the traditional implementation?
01:07:28.840 | No, because...
01:07:35.160 | Okay, they also do it that way.
01:07:38.520 | Yeah, they have bias in their attention blocks.
01:07:46.600 | Cool.
01:07:52.200 | Okay, so we've got our QK and V, self.q, self.k, self.v being our projections.
01:07:58.200 | And so to do 2D self-attention, we need to find the NCHW from our shape. We can do a normalization.
01:08:09.800 | We then do our flattening as discussed. We then transpose the last two dimensions. We then create
01:08:17.480 | our QKV by doing the projections. And we then do the matrix multiply. Now, we've got to be a bit
01:08:24.360 | careful now because as a result of that matrix multiply, we've changed the scale by multiplying
01:08:30.520 | and adding all those things together. So if we then simply divide by the square root of the number
01:08:38.360 | of filters, it turns out that you can convince yourself of this if you wish to, but that's going
01:08:43.000 | to return it to the original scale. We can now do the softmax across the last dimension, and then
01:08:52.600 | multiply each of them by V. So using matrix multiply to do them all in one go.
01:08:58.280 | We didn't mention, but we then do one final projection. Again, just to give it the opportunity
01:09:05.880 | to map things to some different scale. Shift it also if necessary. Transpose the last two back
01:09:15.880 | to where they started from, and then reshape it back to where it started from, and then add it.
01:09:20.760 | Remember, I said it's going to be X plus. Add it back to the original. So this is actually kind of
01:09:25.320 | self-attention ResNet style, if you like. Diffusers, if I remember correctly, does include the X plus
01:09:34.200 | in theirs, but some implementations, like, for example, the PyTorch implementation, don't.
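Putting those pieces together, a sketch of a 2D self-attention block along the lines just described might look like this. It's an approximation of the notebook's version, not a verbatim copy: the single-group GroupNorm and the initialization details are simplifications.

```python
import math
import torch
import torch.nn as nn

class SelfAttention2D(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)                 # divide by sqrt of the number of filters
        self.norm = nn.GroupNorm(1, ni)            # simplification: one group
        self.q, self.k, self.v = nn.Linear(ni, ni), nn.Linear(ni, ni), nn.Linear(ni, ni)
        self.proj = nn.Linear(ni, ni)              # the final output projection

    def forward(self, x):                          # x: (N, C, H, W)
        n, c, h, w = x.shape
        inp = x
        x = self.norm(x)
        x = x.view(n, c, -1).transpose(1, 2)       # (n, h*w, c)
        q, k, v = self.q(x), self.k(x), self.v(x)
        w_ = (q @ k.transpose(1, 2)) / self.scale  # (n, h*w, h*w)
        x = w_.softmax(dim=-1) @ v                 # weighted average of v
        x = self.proj(x)
        x = x.transpose(1, 2).reshape(n, c, h, w)  # back to NCHW
        return x + inp                             # the "x plus" residual add
```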
01:09:39.560 | Okay, so that's a self-attention module, and all you need to do is tell it how many channels to do
01:09:45.800 | attention on. And you need to tell it that because that's what we need for our four different
01:09:52.920 | projections and our group norm and our scale. I guess, strictly speaking, it doesn't have to be stored
01:10:00.120 | here. You could calculate it here, but anyway, either way is fine. Okay, so if we create a
01:10:06.440 | self-attention layer, we can then call it on our little randomly generated numbers.
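Using the SelfAttention2D sketch from above, that call looks roughly like this:

```python
import torch

# Calling the sketched layer on random activations with the same shapes as before.
x = torch.randn(64, 32, 16, 16)
sa = SelfAttention2D(32)
print(sa(x).shape)        # torch.Size([64, 32, 16, 16]): the shape is unchanged
```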
01:10:14.120 | And it doesn't change the shape because we transpose it back and reshape it back,
01:10:21.160 | but we can see that's basically worked. We can see it creates some numbers. How do we know if
01:10:25.160 | they're right? Well, we could create a diffuser's attention block. That will randomly generate
01:10:31.800 | a QKV projection. Sorry, actually they call it something else. They call them query, key,
01:10:40.040 | value, projection-attention, and group norm. We call them q, k, v, proj, and norm. They're the same things.
01:10:46.840 | And so then we can just zip those tuples together. So that's going to take each pair,
01:10:53.480 | first pair, second pair, third pair, and copy the weight and the bias from their attention block.
01:11:02.680 | Sorry, from our attention block to the diffuser's attention block. And then we can check that they
01:11:10.520 | give the same value, which you can see they do. So this shows us that our attention block is the same
01:11:15.480 | as the diffuser's attention block, which is nice. Here's a trick which neither diffusers nor PyTorch
01:11:27.880 | use for reasons I don't understand, which is that we don't actually need three separate projections
01:11:34.200 | here. We could create one projection from Ni to Ni times three. That's basically doing three
01:11:40.920 | projections. So we could call this QKV. And so that gives us 64 by 256 by 96 instead of
01:11:49.720 | 64 by 256 by 32, because it's the three sets. And then we can use chunk, which we saw earlier,
01:11:59.560 | to split that into three separate variables along the last dimension to get us our QKV.
01:12:07.240 | And we can then do the same thing, Q at Q dot transpose, et cetera. So here's another version
01:12:12.040 | of attention where we just have one projection for QKV, and we chunkify it into separate QK and V.
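A sketch of that fused projection plus chunk, on the same illustrative shapes:

```python
import math
import torch
import torch.nn as nn

nc = 32
t = torch.randn(64, 256, nc)
qkv_proj = nn.Linear(nc, nc * 3)              # one projection producing q, k and v at once
q, k, v = qkv_proj(t).chunk(3, dim=-1)        # (64, 256, 96) -> three (64, 256, 32) tensors
w = (q @ k.transpose(1, 2)) / math.sqrt(nc)   # scaled weights
out = w.softmax(dim=-1) @ v                   # (64, 256, 32)
print(out.shape)
```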
01:12:22.360 | And this does the same thing. It's just a bit more concise. And it should be faster as well,
01:12:30.040 | at least if you're not using some kind of XLA compiler or ONNX or Triton or whatever,
01:12:36.680 | for normal PyTorch. This should be faster because it's doing less back and forth
01:12:40.840 | between the CPU and the GPU. All right. So that's basic self-attention. This is not what's done
01:12:56.920 | basically ever, however, because, in fact, the question of which pixels do I care about
01:13:07.720 | depends on which channels you're referring to. Because the ones which are about, oh,
01:13:16.840 | what color is its ear, as opposed to how pointy is its ear, might depend more on is this bunny
01:13:26.280 | in the shade or in the sun. And so maybe you may want to look at its body over here to decide what
01:13:33.640 | color to make them rather than how pointy to make it. And so, yeah, different channels need to bring
01:13:44.040 | in information from different parts of the picture depending on which channel we're talking about.
01:13:51.080 | And so the way we do that is with multi-headed attention. And multi-headed attention actually
01:13:57.800 | turns out to be really simple. And conceptually, it's also really simple. What we do is we say,
01:14:05.800 | let's come back to when we look at C here and let's split them into four separate
01:14:20.680 | vectors. One, two, three, four. Let's split them, right? And let's do the whole dot product thing
01:14:36.920 | on just the first part with the first part. And then do the whole dot product part with the second
01:14:46.840 | part with the second part and so forth, right? So we're just going to do it separately,
01:14:53.800 | separate matrix multiplies for different groups of channels.
01:14:58.200 | And the reason we do that is it then allows, yeah, different parts, different sets of channels to pull
01:15:11.000 | in different parts of the image. And so these different groups are called heads. And I don't
01:15:25.080 | know why, but they are. Does that seem reasonable? Anything to add to that?
01:15:34.040 | It's maybe worth thinking about why, with just a single head, specifically the softmax starts to
01:15:41.160 | come into play. Because, you know, we said it's like a weighted sum, just able to bring in information
01:15:46.760 | from different parts and whatever else. But with softmax, what tends to happen is whatever
01:15:52.040 | weight is highest gets scaled up quite dramatically. And so it's like almost like focused on just that
01:15:58.200 | one thing. And then, yeah, like, as you said, Jeremy, like different channels might want to refer to
01:16:03.160 | different things. And, you know, just having this one like single weight that's across all the
01:16:08.840 | channels means that that signal is going to be like focused on maybe only one or two things as
01:16:14.040 | opposed to being able to bring in lots of different kinds of information based on the different
01:16:18.040 | channels. Right. I was going to mention the same thing, actually. That's a good point. So you're
01:16:26.600 | mentioning the second interesting important point about softmax, you know, point one is that it
01:16:32.520 | creates something that adds to one. But point two is that because of its e to the z, it tends to
01:16:38.040 | highlight one thing very strongly. And yes, so if we had single-headed attention, your point,
01:16:45.240 | guys, I guess, is that you're saying it would end up basically picking nearly all one pixel,
01:16:50.840 | which would not be very interesting. OK, awesome. Oh, I see where everything's got thick. I've
01:16:59.720 | accidentally turned it into a marker. Right. OK, so multi-headed attention. I'll come back to
01:17:12.440 | the details of how it's implemented in terms of, but I'm just going to mention the basic idea.
01:17:18.200 | This is multi-headed attention. And this is identical to before, except I've just stored
01:17:24.680 | one more thing, which is how many heads do you want. And then the forward is actually nearly
01:17:33.240 | all the same. So this is identical, identical, identical. This is new. Identical, identical,
01:17:42.840 | identical, new, identical, identical. So there's just two new lines of code, which might be
01:17:50.280 | surprising, but that's all we needed to make this work. And they're also pretty wacky, interesting
01:17:55.560 | new lines of code to look at. Conceptually, what these two lines of code do is they first,
01:18:02.760 | they do the projection, right? And then they basically take the number of heads.
01:18:21.320 | So we're going to do four heads. We've got 32 channels, four heads. So each head is going to
01:18:25.800 | contain eight channels. And they basically grab, they're going to, we're going to keep it as being
01:18:34.520 | eight channels, not as 32 channels. And we're going to make each batch four times bigger, right?
01:18:41.800 | Because the images in a batch don't combine with each other at all. They're totally separate.
01:18:50.120 | So instead of having one image containing 32 channels, we're going to turn that into four
01:18:59.800 | images containing eight channels. And that's actually all we need, right? Because remember,
01:19:06.040 | I told you that each group of channels, each head, we want to have nothing to do with each other.
01:19:13.080 | So if we literally turn them into different images, then they can't have anything to do
01:19:18.680 | with each other because batches don't react to each other at all. So these rearrange,
01:19:25.400 | this rearrange, and I'll explain how this works in a moment, but it's basically saying,
01:19:30.360 | think of the channel dimension as being of H groups of D and rearrange it. So instead,
01:19:39.640 | the batch dimension is n groups of H and the channels are now just D. So that would be eight
01:19:47.720 | instead of four by eight. And then we do everything else exactly the same way as usual,
01:19:53.240 | but now that group, that the channels are split into groups of H, groups of four. And then after
01:20:01.160 | that, okay, well, we were thinking of the batches as being of size n by H. Let's now think of the
01:20:07.080 | channels as being of size H by D. That's what these rearranges do. So let me explain how these
01:20:14.440 | work. In the diffusers code, I can't remember if I duplicated it or was just inspired by it.
01:20:22.520 | They've got things called heads to batch and batch to heads, which do exactly these things.
01:20:26.520 | And so for heads to batch, they say, okay, you've got 64 per batch by 256 pixels by 32 channels.
01:20:39.880 | Okay, let's reshape it. So you've got 64 images by 256 pixels by four heads by the rest.
01:20:54.920 | So that would be 32 over eight channels. So it's split it out into a separate dimension.
01:21:09.880 | And then if we transpose these two dimensions, it'll then be n by four. So n by heads by SL by
01:21:18.120 | minus one. And so then we can reshape. So those first two dimensions get combined into one.
01:21:24.280 | So that's what heads to batch does. And batch to heads does the exact opposite, right? Reshapes
01:21:30.680 | to bring the batch back to here and then heads by SL by D and then transpose it back again and
01:21:36.680 | reshape it back again so that the heads gets it. So this is kind of how to do it using just
01:21:44.360 | traditional PyTorch methods that we've seen before. But I wanted to show you guys this new-ish
01:21:52.440 | library called einops, inspired as it suggests by Einstein summation notation. But it's absolutely
01:21:58.760 | not Einstein summation notation. It's something different. And the main thing it has is this
01:22:02.920 | thing called rearrange. And rearrange is kind of like a nifty rethinking of Einstein summation
01:22:10.600 | notation as a tensor rearrangement notation. And so we've got a tensor called t we created earlier,
01:22:21.240 | 64 by 256 by 32. And what einops rearrange does is you pass it this specification string that says,
01:22:31.880 | turn this into this. Okay, this says that I have a rank three tensor, three dimensions, three axes,
01:22:47.640 | containing the first dimension is of length n, the second dimension is of length s,
01:22:57.560 | the third dimension is in parentheses is of length h times d, where h is eight.
01:23:04.040 | Okay, and then I want you to just move things around so that nothing is like broken, you know,
01:23:14.120 | so everything's shifted correctly into the right spots so that we now have each batch
01:23:23.000 | is now instead n times eight, n times h. The sequence length is the same,
01:23:30.040 | and d is now the number of channels. Previously the number of channels was h by d. Now it's d,
01:23:36.360 | so the number of channels has been reduced by a factor of eight. And you can see it here,
01:23:40.600 | it's turned t from something of 64 by 256 by 32 into something of size 64 times eight by 256
01:23:51.480 | by 32 divided by eight. And so this is like really nice because, you know, a, this one line of code
01:24:01.720 | to me is clearer and easier and I liked writing it better than these lines of code. But whereas
01:24:08.040 | particularly nice is when I had to go the opposite direction, I literally took this, cut it,
01:24:15.800 | put it here and put the arrow in the middle. Like it's literally backwards, which is really nice,
01:24:22.040 | right? Because we're just rearranging it in the other order. And so if we rearrange in the other
01:24:27.160 | order, we take our 512 by 256 by 4 thing that we just created and end up with a 64 by 256 by 32
01:24:34.040 | thing, which we started with, and we can confirm that the end thing equals, or every element equals
01:24:41.800 | the first thing. So that shows me that my rearrangement has returned its original correctly.
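Here's a sketch of both approaches side by side: a manual heads_to_batch / batch_to_heads using plain reshape and transpose, and the einops rearrange equivalent, with 8 heads as in the example shapes above.

```python
import torch
from einops import rearrange

def heads_to_batch(x, heads):
    n, sl, d = x.shape
    x = x.reshape(n, sl, heads, -1).transpose(1, 2)   # (n, h, sl, d/h)
    return x.reshape(n * heads, sl, -1)               # heads folded into the batch

def batch_to_heads(x, heads):
    nh, sl, d = x.shape
    x = x.reshape(-1, heads, sl, d).transpose(1, 2)   # (n, sl, h, d/h)
    return x.reshape(-1, sl, d * heads)               # heads folded back into channels

t = torch.randn(64, 256, 32)
t2 = heads_to_batch(t, 8)                             # (512, 256, 4)
t3 = rearrange(t, 'n s (h d) -> (n h) s d', h=8)      # same result via einops
assert torch.equal(t2, t3)
assert torch.equal(batch_to_heads(t2, 8), t)          # round trip restores the original
print(t2.shape)                                       # torch.Size([512, 256, 4])
```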
01:24:48.760 | Yeah, so multi-headed attention, I've already shown you. It's the same thing as before,
01:24:54.200 | but pulling everything out into the batch for each head and then pulling the heads back into
01:25:01.000 | the channels. So we can do multi-headed attention with 32 channels and four heads
01:25:09.640 | and check that all looks okay. So PyTorch has that all built in. It's called nn.MultiheadAttention.
01:25:17.720 | Be very careful. Be more careful than me, in fact, because I keep forgetting that it actually expects
01:25:25.880 | the batch to be the second dimension. So make sure you write batch_first=True
01:25:32.760 | to make batch the first dimension and that way it'll be the same as
01:25:38.040 | diffusers. I mean, it might not be identical, but the same. It should be almost the same
01:25:43.880 | idea. And to make it self-attention, you've got to pass in three things, right? So the three things
01:25:51.160 | will all be the same for self-attention. This is the thing that's going to be passed through the
01:25:55.720 | Q projection, the K projection and the V projection. And you can pass different things to those.
01:26:03.640 | If you pass different things to those, you'll get something called cross-attention rather than
01:26:10.120 | self-attention, which I'm not sure we're going to talk about until we do it in NLP.
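A sketch of using the built-in layer as described: batch_first=True so inputs are (batch, sequence, dimension), and the same tensor passed as query, key, and value for self-attention.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
t = torch.randn(64, 256, 32)            # (batch, seq, dim)
out, attn_weights = mha(t, t, t)        # self-attention: q, k, v all come from t
print(out.shape)                        # torch.Size([64, 256, 32])
# Passing a different tensor as key/value would give cross-attention instead.
```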
01:26:15.080 | Just on the rearrange thing, I know that if you've been doing PyTorch and you used to,
01:26:24.040 | like, you really know what transpose and, you know, reshape and whatever do, then it can be a
01:26:29.480 | little bit weird to see this new notation. But once you get into it, it's really, really nice.
01:26:33.400 | And if you look at the self-attention multi-headed implementation there, you've got
01:26:37.000 | dot view and dot transpose and dot reshape. It's quite fun practice. Like, if you're just saying,
01:26:43.000 | oh, this einops thing looks really useful, like, take an existing implementation like this and say,
01:26:47.880 | oh, maybe instead of, like, can I do it instead of dot reshape or whatever, can I start replacing
01:26:53.000 | these individual operations with the equivalent, like, rearrange call? And then checking that the
01:26:58.440 | output is the same. Like, that's what helped it click for me: oh, okay. Like,
01:27:03.240 | I can start to express, if it's just a transpose, then that's a rearrange with the last two dimensions swapped.
01:27:08.840 | Yeah. I only just started using this. And I've obviously had many years of using
01:27:16.200 | reshape transpose, et cetera, in Theano, TensorFlow, Keras, PyTorch, APL. And I would say within
01:27:27.640 | 10 minutes, I was like, oh, I like this much better. You know, like, it's
01:27:32.040 | fine for me at least. It didn't take too long to be convinced. It's not part of PyTorch or anything.
01:27:40.040 | You've got to pip install it, by the way. And it seems to be becoming super popular now,
01:27:47.240 | at least in the kind of diffusion research crowd. Everybody seems to be using einops suddenly,
01:27:53.800 | even though it's been around for a few years. And I actually put in an issue there and asked them
01:28:00.040 | to add in Einstein summation notation as well, which they've now done. So it's kind of like your
01:28:05.560 | one place for everything, which is great. And it also works across
01:28:08.600 | TensorFlow and other libraries as well, which is nice. Okay. So we can now add that to our
01:28:21.800 | unit. So this is basically a copy of the previous notebook, except what I've now done is I did this
01:28:28.920 | at the point where it's like, oh, yeah, it turns out that cosine scheduling is better. So I'm back
01:28:34.520 | to cosine schedule now. This is copied from the cosine schedule notebook. And we're still doing the
01:28:39.400 | minus 0.5 thing because we love it. And so this time, I actually decided to export stuff into a
01:28:47.960 | miniai.diffusion. At this point, I still think things are working pretty well. And so I renamed
01:28:54.200 | the unet conv to pre-conv, since it's a better name. Time step embedding has been exported. Upsample's
01:29:00.440 | been exported. This is like a pre-act linear version, exported. I tried using nn.MultiheadAttention,
01:29:11.160 | and it didn't work very well for some reason. So I haven't figured out why that is yet.
01:29:17.720 | So I'm using, yeah, this self-attention, which we just talked about. Multiheaded self-attention.
01:29:25.800 | You know, just the scale, we have to divide the number of channels by the number of heads
01:29:33.640 | because the effective number of channels is, you know, divided across the heads.
01:29:44.120 | And instead of specifying n heads, yeah, you specify attention channels.
01:29:47.480 | So if you have, like, ni is 32 and attention channels is 8, then you calculate four heads.
01:29:51.640 | Yeah, that's what diffusers does, I think. It's not what nn.MultiheadAttention does.
01:29:56.920 | And actually, I think ni divided by (ni divided by attention chans) is actually just equal to
01:30:03.720 | attention chans. So I could have just put that, probably. Anyway, never mind. Yeah.
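Here's a sketch of a multi-headed self-attention along those lines: you specify the channels per head, the heads get folded into the batch with rearrange, and the scale uses the per-head width. The choice of norm here is an assumption, and the residual add is left to the calling res block, as discussed below.

```python
import math
import torch
import torch.nn as nn
from einops import rearrange

class SelfAttentionMH(nn.Module):
    def __init__(self, ni, attn_chans):
        super().__init__()
        self.nheads = ni // attn_chans
        self.scale = math.sqrt(ni / self.nheads)     # per-head width
        self.norm = nn.LayerNorm(ni)                 # assumption: layer norm here
        self.qkv = nn.Linear(ni, ni * 3)
        self.proj = nn.Linear(ni, ni)

    def forward(self, x):                            # x: (n, seq, ni)
        x = self.norm(x)
        x = self.qkv(x)                              # (n, seq, 3*ni)
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)
        q, k, v = x.chunk(3, dim=-1)                 # each ((n*h), seq, ni/h)
        w = (q @ k.transpose(1, 2)) / self.scale
        x = w.softmax(dim=-1) @ v
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)
        return self.proj(x)                          # no residual here; the res block adds it

x = torch.randn(64, 256, 32)
print(SelfAttentionMH(32, attn_chans=8)(x).shape)    # torch.Size([64, 256, 32])
```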
01:30:14.040 | So okay, so that's all copied in from the previous one. The only thing that's different here is I
01:30:19.880 | haven't got the dot view minus one thing here. So this is a 1D self-attention, and then 2D
01:30:28.600 | self-attention just adds the dot view before we call forward and then dot reshape it back again.
01:30:41.960 | So yeah, so we've got 1D and 2D self-attention. Okay, so now our EmbRes block has one extra thing
01:30:48.360 | you can pass in, which is attention channels. And so if you pass in attention channels, we're going
01:30:54.760 | to create something called self.attention, which is a self-attention 2D layer with the right number
01:31:00.360 | of filters and the requested number of channels. And so this is all identical to what we've seen
01:31:08.040 | before, except if we've got attention, then we add it. Oh yeah, and the attention that I did here is
01:31:15.080 | the non-res-netty version. So we have to do x plus because that's more flexible. You can then choose
01:31:21.880 | to have it or not have it this way. Okay, so that's an EmbRes block with attention.
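A sketch of that wiring, reusing the SelfAttentionMH sketch from above; the convolution here is just a stand-in for the real res block body, not the notebook's code.

```python
import torch
import torch.nn as nn

class ResBlockWithAttn(nn.Module):
    def __init__(self, ni, attn_chans=0):
        super().__init__()
        self.body = nn.Conv2d(ni, ni, 3, padding=1)          # stand-in for the res block body
        self.attn = SelfAttentionMH(ni, attn_chans) if attn_chans else None

    def forward(self, x):
        x = self.body(x)
        if self.attn is not None:
            n, c, h, w = x.shape
            flat = x.view(n, c, -1).transpose(1, 2)          # NCHW -> (n, seq, c) for 1D attention
            x = x + self.attn(flat).transpose(1, 2).reshape(n, c, h, w)  # the "x plus"
        return x

print(ResBlockWithAttn(32, attn_chans=8)(torch.randn(64, 32, 16, 16)).shape)
```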
01:31:28.520 | And so now our down block, you have to tell it how many attention channels you want,
01:31:35.560 | because the res blocks need that. The up block, you have to know how many attention channels you
01:31:41.160 | want, because again the res blocks need that. And so now the unit model, where does the attention
01:31:47.960 | go? Okay, we have to say how many attention channels you want. And then you say which index
01:31:54.920 | block do you start adding attention? So why don't we, so then what happens is the attention is done
01:32:05.400 | here. Each res-net has attention. And so as we discussed, you just do the normal res and then
01:32:16.440 | the attention, right? And if you
01:32:19.720 | put that in at the very start, right, let's say you've got a 256 by 256 image.
01:32:34.840 | Then you're going to end up with this matrix here. It's going to be 256 by 256 on one side
01:32:48.360 | and 256 by 256 on the other side and contain however many, you know, NF channels. That's huge.
01:33:03.640 | And you have to back prop through it. So you have to store all that to allow back prop to happen.
01:33:09.240 | It's going to explode your memory. So what happens is basically nobody puts attention
01:33:16.360 | in the first layers. So that's why I've added a attention start, which is like at which block
01:33:25.640 | do we start adding attention and it's not zero for the reason we just discussed. Another way you
01:33:34.760 | could do this is to say like at what grid size should you start adding attention? And so generally
01:33:40.440 | speaking, people say when you get to 16 by 16, that's a good time to start adding attention.
01:33:47.400 | Although stable diffusion adds it at 32 by 32 because remember they're using latents,
01:33:54.680 | which we'll see very shortly I guess in the next lesson. So it starts at 64 by 64 and then they
01:34:00.760 | add attention at 32 by 32. So we're again, we're replicating stable diffusion here. Stable diffusion
01:34:05.800 | uses attention start at index one. So we, you know, when we go self.down, dot append, the down block
01:34:13.160 | has zero attention channels if we're not up to that block yet. And ditto on the up block,
01:34:23.480 | except we have to count from the end blocks.
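A purely illustrative sketch of that rule (the channel counts and names here are made up for the example, not the notebook's exact code):

```python
nfs = (32, 64, 128, 256)          # per-stage channel counts, illustrative only
attn_chans, attn_start = 8, 1
for i, nf in enumerate(nfs):
    ac = attn_chans if i >= attn_start else 0   # no attention before attn_start
    print(f"down block {i}: nf={nf}, attn_chans={ac}")
# The up path uses the mirrored rule, counting block indices from the end.
```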
01:32:26.040 | Now that I think about it, the mid block should have attention as well. So that's missing.
01:32:39.480 | Yeah, so the forward actually doesn't change at all for attention. It's only the init.
01:34:50.040 | Yeah, so we can train that. And so previously, yeah, we got
01:34:55.080 | without attention, we got to 137. And with attention,
01:35:04.680 | oh, we can't compare directly because we've changed from Keras to cosine.
01:35:13.560 | We can compare the sampling though. So we're getting, what are we getting? 4, 5, 5, 5.
01:35:21.320 | It's very hard to tell if it's any better or not because,
01:35:27.160 | well, again, you know, our cosine schedule is better. But yeah, when I've done kind of direct
01:35:34.520 | like with like, I haven't managed to find any obvious improvements from adding attention.
01:35:41.560 | But I mean, it's doing fine, you know, 4 is great. Yeah. All right. So then finally,
01:35:50.440 | did you guys want to add anything before we go into a conditional model?
01:35:53.160 | I was just going to make a note that, like, I guess, just to clarify,
01:35:58.120 | with the attention, part of the motivation was certainly to do the sort of spatial mixing and
01:36:04.120 | kind of like, yeah, to get from different parts of the image and mix it. But then the problem is,
01:36:09.800 | if it's too early, where you do have one of, you know, the more individual pixels,
01:36:14.360 | then the memory is very high. So it seems like you have to get that balance of where you don't,
01:36:21.400 | you kind of want it to be early. So you can do some of that mixing, but you don't want to be too
01:36:25.080 | early, where then the memory usage is, is too high. So it seems like there is certainly kind
01:36:30.680 | of the balance of trying to find maybe that right place where to add attention into your network.
01:36:36.200 | So I just thought I was just thinking about that. And maybe that's a point worth noting.
01:36:40.680 | Yeah, for sure.
01:36:41.560 | There is a trick, which is like what they do in, for example, vision transformers,
01:36:46.280 | or the DIT, the diffusion with transformers, which is that if you take like an eight by
01:36:54.920 | eight patch of the image, and you flatten that all out, or you run that through some
01:36:59.960 | like convolutional thing to turn it into a one by one by some larger number of channels,
01:37:05.000 | but you can reduce the spatial dimension by increasing the number of channels. And that
01:37:10.760 | gets you down to like a manageable size where you can then start doing attention as well.
01:37:14.360 | So that's another trick is like patching, where you take a patch of the image and you focus
01:37:18.680 | on that as some number, like some embedding dimension or whatever you like to think of it,
01:37:23.560 | that as a one by one rather than an eight by eight or a 16 by 16. And so that's how,
01:37:29.000 | like you'll see, you know, 32 by 32 patch models, like some of the smaller clip models or 14 by 14
01:37:35.800 | patch for some of the larger like DIT classification models and things like that.
01:37:40.360 | So that's another trick. Yeah, I guess that's, yeah, that's mainly used when you have, like, a full
01:37:45.800 | transformer network, I guess. And then this is one where we have that sort of incorporating the
01:37:51.240 | attention into convolutional network. So there's certainly, I guess, yeah, for different sorts of
01:37:56.840 | networks, different tricks. But yeah. Yeah. And I haven't decided yet if we're going to look at
01:38:06.200 | DiT or not. Maybe we should, based on what you're describing. I was just going to mention,
01:38:11.080 | though, that since you mentioned transformers, we've actually now got everything we need to
01:38:18.600 | create a transformer. Here's a transformer block with embeddings. A transformer block with embeddings
01:38:25.160 | is exactly the same embeddings that we've seen before. And then we add attention, as we've seen
01:38:31.000 | before, there's a scale and shift. And then we pass it through an MLP, which is just a linear layer,
01:38:40.680 | an activation, a normalize, and a linear layer. For whatever reason, this is, you know,
01:38:46.760 | GELU, which is just another activation function, is what people always use in transformers.
01:38:54.120 | For reasons I suspect don't quite make sense in vision, everybody uses layer norm. And again,
01:38:58.440 | I was just trying to replicate an existing paper. But this is just a standard MLP. So if you do,
01:39:03.160 | so in fact, if we get rid of the embeddings, just to show you a true pure transformer.
01:39:17.960 | Okay, here's a pure transformer block. So it's just normalize, attention, add, normalize,
01:39:25.240 | multi-layer perceptron, add. That's all a transformer block is. And then what's a
01:39:30.200 | transformer network? A transformer network is a sequential of transformers. And so in this
01:39:35.960 | diffusion model, I replaced my mid block with a list of sequential transformer blocks.
01:39:46.280 | So that is a transformer network. And to prove it, this is another version in which I
01:39:56.280 | replaced that entire thing with the PyTorch transformer encoder. This is called encoder.
01:40:03.960 | This is just taken from PyTorch. And so that's the encoder. And I just replaced it with that.
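Here's a sketch of such a block, with nn.MultiheadAttention standing in for the attention layer, and the MLP laid out as linear, GELU, norm, linear, as described; the sizes are illustrative.

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, ni, attn_chans=8, mlp_mult=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(ni)
        self.attn = nn.MultiheadAttention(ni, ni // attn_chans, batch_first=True)
        self.norm2 = nn.LayerNorm(ni)
        self.mlp = nn.Sequential(
            nn.Linear(ni, ni * mlp_mult), nn.GELU(),
            nn.LayerNorm(ni * mlp_mult), nn.Linear(ni * mlp_mult, ni))

    def forward(self, x):                           # x: (batch, seq, ni)
        nx = self.norm1(x)
        x = x + self.attn(nx, nx, nx)[0]            # normalize, attention, add
        return x + self.mlp(self.norm2(x))          # normalize, MLP, add

x = torch.randn(64, 256, 32)
blocks = nn.Sequential(*[TransformerBlock(32) for _ in range(4)])  # a tiny "transformer network"
print(blocks(x).shape)                              # torch.Size([64, 256, 32])
```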
01:40:13.000 | So yeah, we've now built transformers. Now, okay, why aren't we using them right now? And why did
01:40:20.600 | I just say, I'm not even sure if we're going to do VIT, which is vision transformers. The reason is
01:40:25.240 | that transformers, you know, they're doing something very interesting, right? Which is,
01:40:34.680 | remember, we're just doing 1D versions here, right? So transformers are taking something
01:40:44.760 | where we've got a sequence, right? Which in our case is pixels height by width, but we'll just
01:40:50.120 | call it the sequence. And everything in that sequence has a bunch of channels, right, for
01:40:56.520 | dimensions, right? I'm not going to draw them all, but you get the idea. And so for each element of
01:41:05.720 | that sequence, which in our case, it's just some particular pixel, right? And these are just the
01:41:11.480 | filters, channels, activations, whatever, activations, I guess.
01:41:19.080 | What we're doing is the, we first do attention, which, you know, remember there's a projection for
01:41:26.200 | each. So like it's mixing the channels a little bit, but just putting that aside, the main thing
01:41:30.680 | it's doing is each row is getting mixed together, you know, into a weighted average.
01:41:46.760 | And then after we do that, we put the whole thing through a multi-layer perceptron. And what the
01:41:53.080 | multi-layer perceptron does is it entirely looks at each pixel on its own. So let's say this one,
01:42:05.480 | right, and puts that through linear, activation, norm, linear, which we call an MLP.
01:42:16.840 | And so a transformer network is a bunch of transformer layers. So it's basically going
01:42:27.800 | attention, MLP, attention, MLP, attention, et cetera, et cetera, MLP. That's all it's doing.
01:42:37.160 | And so in other words, it's mixing together the pixels or sequences, and then it's mixing
01:42:48.920 | together the channels. Then it's mixing together the sequences and then mixing together the channels
01:42:53.240 | and it's repeating this over and over. Because of the projections being done in the
01:43:00.520 | attention, it's not just mixing the pixels, but it's kind of, it's largely mixing the pixels.
01:43:07.400 | And so this combination is very, very, very flexible. And it's flexible enough
01:43:19.640 | that it provably can actually approximate any convolution that you can think of
01:43:26.040 | given enough layers and enough time and learning the right parameters.
01:43:34.040 | The problem is that for this to approximate a convolution requires a lot of data and a lot of
01:43:45.560 | layers and a lot of parameters and a lot of compute. So if you try to use this, so this is a transformer
01:43:54.760 | network, transformer architecture. If you pass images into this, so pass an image in
01:44:04.280 | and try to predict, say from ImageNet, the class of the image. So use SGD to try and
01:44:13.640 | find weights for these attention projections and MLPs. If you do that on ImageNet, you will end up
01:44:21.160 | with something that does indeed predict the class of each image, but it does it poorly.
01:44:25.560 | Now it doesn't do it poorly because it's not capable of approximating a convolution. It
01:44:31.240 | does it poorly because ImageNet, the entire ImageNet as in ImageNet 1k is not big enough
01:44:38.200 | for a transformer to learn how to do this. However, if you pass it a much bigger data set,
01:44:46.360 | many times larger than ImageNet 1k, then it will learn to approximate this very well. And in fact,
01:44:54.280 | it'll figure out a way of doing something like convolutions that are actually better than
01:44:58.920 | convolutions. And so if you then take that, so that's going to be called a vision transformer
01:45:06.120 | or VIT that's been pre-trained on a data set much bigger than ImageNet, and then you fine tune it
01:45:12.600 | on ImageNet, you will end up with something that is actually better than ResNet. And the reason
01:45:21.880 | it's better than ResNet is because these combinations, right, which together when combined
01:45:32.920 | can approximate a convolution, these transformers, you know, convolutions are our best guess as to
01:45:41.080 | like a good way to kind of represent the calculations we should do on images. But there's
01:45:47.480 | actually much more sophisticated things you could do, you know, if you're a computer and you could
01:45:51.960 | figure these things out better than a human can. And so a VIT actually figures out things that are
01:45:56.920 | even better than convolutions. And so when you fine-tune on ImageNet using, you know, a VIT
01:46:05.080 | that's been pre-trained on lots of data, then that's why it ends up being better than a ResNet. So
01:46:11.560 | that's why, you know, the things I'm showing you are not the things that contain transformers and
01:46:22.200 | diffusion because to make that work would require pre-training on a really, really large data set
01:46:29.880 | for a really, really long amount of time. So anyway, so we might only come to transformers,
01:46:38.040 | well, not in a very long time, but when we do them in NLP; in vision, maybe we'll cover them briefly,
01:46:47.240 | you know, they're very interesting to use as pre-trained models. The main thing to know about
01:46:52.440 | them is, yeah, a VIT, you know, which is a really successful and when pre-trained on lots of data,
01:46:59.080 | which they all are nowadays, is a very successful architecture. But like literally the VIT paper
01:47:03.960 | says, oh, we wondered what would happen if we take a totally plain 1D transformer, you know,
01:47:11.000 | and convert it and make it work on images with as few changes as possible. So everything we've
01:47:17.800 | learned about attention today and MLPs applies directly because they haven't changed anything.
01:47:25.480 | And so one of the things you might realize that means is that you can't use a VIT that was trained
01:47:34.040 | on 224 by 224 pixel images on 128 by 128 pixel images because, you know, all of these
01:47:41.960 | self-attention things are the wrong size, you know, and specifically the problem is actually the,
01:47:53.560 | actually it's not really the attention, let me take that back. All of the
01:48:02.120 | position embeddings are the wrong size. And so actually that's something I, sorry,
01:48:06.600 | I forgot to mention, is that in transformers the first thing you do is you always take
01:48:15.640 | your, you know, these pixels and you add to them a positional embedding.
01:48:24.840 | And that's done, I mean, it can be done lots of different ways, but the most popular way is
01:48:30.440 | identical to what we did for the time step embedding is the sinusoidal embedding. And so
01:48:36.040 | that's specific, you know, to how many pixels there are in your image. So yeah, that's an example,
01:48:47.320 | it's one of the things that makes VITs a little tricky. Anyway, hopefully, yeah, you get the idea
01:48:52.520 | that we've got all the pieces that we need. Okay, so with that discussion,
01:49:08.280 | I think that's officially taken us over time. So maybe we should do the conditional next time.
01:49:20.040 | Do you know what actually it's tiny? Let's just quickly do it now. You guys got time?
01:49:24.840 | Yeah. Okay. So let's just, yeah, let's finish by doing a conditional model. So for a conditional
01:49:31.640 | model, we're going to basically say I want something where I can say draw me the number,
01:49:38.120 | sorry, draw me a shirt, or draw me some pants, or draw me some sandals. So we're going to pick one
01:49:44.040 | of the 10 fashion MNIST classes and create an image of a particular class. To do that,
01:49:53.960 | we need to know what class each thing is. Now, we already know what class each thing is
01:50:05.080 | because it's the Y label, which way back in the beginning of time, we set, okay,
01:50:14.360 | it's just called the label. So that tells you what category it is.
01:50:19.240 | So we're going to change our collation function. So we call noisify as per usual. That gives us
01:50:29.480 | our noised image, our time step, and our noise. But we're also going to then add to that tuple
01:50:40.600 | what kind of fashion item is this. And so the first tuple will be noised image,
01:50:48.440 | time step, and label, and then the dependent variable as per usual is the noise.
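A minimal sketch of that collation change; noisify is the function from earlier in the lesson, and xb, yb stand for the batch's images and integer class labels (how you pull those out depends on your dataloader, so this is illustrative):

```python
def collate_cond(xb, yb):
    (xt, t), eps = noisify(xb)       # noised images, timesteps, and the target noise
    return (xt, t, yb), eps          # the model input tuple now also carries the label
```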
01:50:56.280 | And so what's going to happen now when we call our unit, which is now a conditioned unit model,
01:51:00.920 | is the input is now going to contain not just the activations and the time step,
01:51:08.520 | but it's also going to contain the label. Okay, that label will be a number between zero and nine.
01:51:14.120 | So how do we convert the number between zero and nine into a vector which represents that number?
01:51:21.080 | Well, we know exactly how to do that with an nn.Embedding. Okay, so we did that lots in part one.
01:51:27.320 | So let's make it exactly, you know, the same size as our time embedding. So n number of
01:51:42.920 | activations in the embedding. It's going to be the same as our time step embedding.
01:51:50.920 | And so that's convenient. So now in the forward we do our time step embedding as usual.
01:51:55.240 | We'll pass the labels into our conditioned embedding.
01:52:00.520 | The time embedding we will put through the embedding MLP like before, and then we're just
01:52:06.520 | going to add them together. And that's it, right? So this now represents a combination of the time
01:52:12.680 | and the fashion item class. And then everything else is identical in both parts. So all we've
01:52:21.400 | added is this one thing. And then we just literally sum it up. So we've now got a joint
01:52:28.360 | embedding representing two things. And then, yeah, and then we train it.
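A sketch of that embedding change, assuming the timestep_embedding helper exported earlier; the class name, the MLP layout, and the sizes here are illustrative assumptions.

```python
import torch
import torch.nn as nn

class TimeAndClassEmbedding(nn.Module):
    def __init__(self, n_classes=10, n_emb=128):
        super().__init__()
        self.cond_emb = nn.Embedding(n_classes, n_emb)   # one learned vector per class id
        self.emb_mlp = nn.Sequential(nn.Linear(n_emb, n_emb), nn.SiLU(),
                                     nn.Linear(n_emb, n_emb))
        self.n_emb = n_emb

    def forward(self, t, c):
        temb = timestep_embedding(t, self.n_emb)          # sinusoidal embedding from earlier notebooks
        return self.emb_mlp(temb) + self.cond_emb(c)      # joint time + class embedding, just summed
```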
01:52:33.320 | And, you know, interestingly, it looks like the loss, well, it ends up a bit the same,
01:52:41.640 | but you don't often see 0.031. It is a bit easier for it to do a conditional embedding
01:52:48.360 | model because you're telling it what it is. It just makes it a bit easier. So then to do
01:52:53.160 | conditional sampling, you have to pass in what type of thing do you want from these labels.
01:53:06.600 | And so then we create a vector just containing that number repeated once for each item in
01:53:19.080 | the batch. And we pass it to our model. So our model has now learned how to denoise something
01:53:27.160 | of type C. And so now if we say like, oh, trust me, this noise contains is a noised image of type C,
01:53:34.600 | it should hopefully denoise it into something of type C. That's all there is to it. There's no
01:53:41.960 | magic there. So, yeah, that's all we have to do to change the sampling. So we didn't have to change
01:53:48.760 | DDIM step at all, right? Literally all we did was we added this one line of code and we added it
01:53:55.960 | there. So now we can say, okay, let's say class ID zero, which is t-shirt top. So we'll pass that to
01:54:04.280 | sample. And there we go. Well, everything looks like t-shirts and tops. Yeah, okay. I'm glad we
01:54:16.760 | didn't leave that till next time because we can now say we have successfully replicated
01:54:25.320 | everything in stable diffusion, except for being able to create whole sentences,
01:54:32.520 | which is what we do with CLIP. Getting really close. Yes. Well, except that CLIP requires
01:54:39.720 | all of NLP. So I guess we'll, we might do that next or depending on how research goes.
01:54:53.160 | All right. We still need a latent diffusion part. Oh, good point. Latents. Okay. We'll definitely
01:55:00.600 | do that next time. So let's see. Yeah. So we'll do a VAE and latent diffusion,
01:55:08.040 | which isn't enough for one lesson. So maybe some of the research I'm doing will end up in the next
01:55:14.600 | lesson as well. Yes. Okay. Thanks for the reminder. Although we've already kind of done auto-encoders,
01:55:22.120 | so VAEs are going to be pretty, pretty easy. Well, thank you, Tanishk and Johnno. Fantastic
01:55:30.280 | comments, as always. Glad your internet/power reappeared, Johnno. Back up.
01:55:42.600 | All right. Thanks, gang. Cool. Thanks, everybody. That was great.