
Stanford CS25: V5 | Transformers in Diffusion Models for Image Generation and Beyond


Whisper Transcript

00:00:00.000 | All right, thank you so much for joining us today.
00:00:04.560 | Today, I'm very honored to welcome Sayak Paul from Hugging Face,
00:00:09.720 | who works a lot on diffusion models, image generation, and so forth.
00:00:15.760 | So his day-to-day includes contributing to the Diffusers library,
00:00:21.560 | training and babysitting diffusion models, and working on applied ideas.
00:00:26.560 | He's interested in subject-driven generation, preference alignment,
00:00:30.560 | and evaluation of diffusion models.
00:00:33.040 | And when he's not working, he can be found playing the guitar
00:00:36.800 | and binge-watching ICML tutorials and Suits.
00:00:40.560 | So without further ado, I'll hand it off to Sayak.
00:00:43.920 | I guess I'm here to depart and deviate from the usual theme that's followed at CS25.
00:00:50.040 | CS25 usually isn't much about the visual modality, especially its generative aspects.
00:00:55.280 | So I guess I'm happy, in a way, to be here to depart and deviate from that theme.
00:00:59.040 | So, yeah.
00:01:01.760 | I wanted to start with a couple of disclaimers.
00:01:05.760 | And this talk is definitely not going to be an exhaustive overview of all possible methods.
00:01:11.760 | So it might be the case that I didn't cover your work, even though you consider it seminal.
00:01:18.800 | So I apologize in advance.
00:01:20.880 | And then I'm not going to cover what diffusion or flow matching is in detail,
00:01:25.280 | but I'll give a very quick overview just to set the context and tone.
00:01:29.840 | And the architectures I'll discuss today are fairly agnostic to whether you use diffusion or flow matching.
00:01:37.680 | And since I primarily work on images and videos, I will draw my examples from images.
00:01:42.800 | But just know that these architectures are well known to generalize to other continuous modalities, such as audio.
00:01:51.600 | And then, of course, I'll share my slides after the talk.
00:01:57.600 | And here's how I want to approach this talk.
00:02:00.560 | This is the rough overview of all the things that I want to cover.
00:02:04.000 | As I mentioned, I'll give you a brief introduction to diffusion models or flow matching as well.
00:02:10.720 | And then I'll try to set the context by discussing the early architectures for diffusion,
00:02:17.040 | the early architectures for image generation in this field.
00:02:20.160 | And then we'll head straight to DiTs (Diffusion Transformers) and, as I like to call them, their friends.
00:02:26.000 | And I'll conclude with some thoughts.
00:02:28.320 | And in those thoughts, I'll discuss some of the promising directions that I've lately become interested in.
00:02:34.160 | And then I'm sure there will be time for Q&A.
00:02:36.320 | I really wanted to kind of fascinate you with all the cool examples in the text-to-image arena.
00:02:45.920 | But I guess at this point in time, we all know these examples.
00:02:49.680 | Because I think, like, these days we are becoming more and more interested in native multimodal generation, not just images.
00:02:58.720 | But nonetheless, I just wanted to start off by giving you a couple of examples from the text-to-image arena.
00:03:05.520 | My favorite is the last one, "tiny astronaut hatching from an egg on the moon."
00:03:10.000 | I think human imagination really took off there.
00:03:14.560 | But yeah, you can see, apart from DALL-E 3, all these are open models.
00:03:18.880 | And the photorealism aspects of these models are quite impressive.
00:03:23.840 | So, yeah.
00:03:24.640 | And then I also want to give you an infographic of how I like to think about diffusion models in general.
00:03:34.880 | So, I like to think about diffusion models as follows.
00:03:37.680 | You take a random noise vector and denoise it over a period of time
00:03:44.160 | so that it becomes a photorealistic image.
00:03:46.640 | And if you take a look closely, you will notice that it's an iterative process, unlike GANs.
00:03:53.680 | GANs are one-shot in nature, but diffusion models are iterative.
00:03:56.960 | It's sequential in nature.
00:03:58.560 | So, in a way, we are essentially denoising the random noise that we had started off with
00:04:04.000 | until we arrive at the kind of image that we are looking for.
00:04:09.680 | So, just know that it's iterative in nature.
00:04:11.440 | That's the main takeaway from this slide.
00:04:13.520 | And you can condition the denoising process in lots of ways,
00:04:19.520 | for example, with text.
00:04:21.760 | I think text is one of the more liberating conditions.
00:04:26.320 | As you start conditioning the denoising process with text, you get abstract creatures like this.
00:04:32.080 | And I like to believe that you will start feeling very liberated with text-to-image models.
00:04:41.280 | And then, let's take a step back and start developing a mental model of what it takes to have
00:04:47.440 | a fairly state-of-the-art text-to-image model.
00:04:51.440 | Like, let's say I want to get from this text prompt, "a cat looking like a tiger," all the way up to the
00:04:56.960 | image at the top.
00:04:59.040 | What does it take?
00:05:00.560 | What are the components that we are looking for?
00:05:04.160 | How should they be connected?
00:05:05.440 | So, I want to kind of give you a connected graph of how the different components should be
00:05:12.800 | connected so that we can backtrack them and start developing more intuition.
00:05:18.560 | So, of course, when you have text, you need some sort of embedding so that you can
00:05:23.840 | work with it, and we have text encoders for that.
00:05:26.320 | And state-of-the-art diffusion models usually rely on more than one text encoder.
00:05:31.680 | For example, Stable Diffusion 3 relies on three text encoders: not one, not two, but three.
00:05:37.120 | So, you have your text prompt, you pass it off to text encoders, and you then have the embeddings.
00:05:43.440 | And as we saw in the earlier slide, we are starting off with some random noise drawn from a Gaussian
00:05:49.040 | distribution. So, we are starting with some noisy latents and then you have your text embeddings.
00:05:55.200 | And then you have some time step, whose math is handled by the scheduler.
00:06:02.560 | I'm going to come to the scheduler component in a minute, but you have your conditions at this point
00:06:07.520 | in time. You have your text embeddings, you have your noisy latents, and you have some time step.
00:06:11.760 | And then you have your core diffusion network, and you pass all these inputs to your core diffusion
00:06:22.240 | network. And it's a sequential process. You run the diffusion network over a period of time as we
00:06:29.200 | saw in the earlier slide. And then it's going to give you some refined latent.
00:06:33.200 | And then you give it to some decoder model, and you have your image out. Now, I would like to call
out two broad classes of diffusion models. This overview describes latent space
diffusion models. But there's another class of diffusion models, called pixel space
diffusion models. But the recent, state-of-the-art diffusion models are all latent space based,
because pixel space is quite prohibitive and computationally intensive. That's why
00:07:01.200 | it's more common to see latent space diffusion models. So that's why you see noisy latents there,
00:07:08.320 | and not raw pixels. And if you were to sort of generalize this to other continuous modalities,
00:07:14.560 | the basic idea here is how do you represent the raw modality data points, and how do you compute
00:07:21.840 | intermediate representations of them? So if you were to do this on the audio space, you can think of
some similar analogies. But long story short, for a fairly well-performing text-to-image system,
00:07:34.240 | these are roughly the components that you need. So yeah.
00:07:37.120 | And then, let's now start developing some notation, so that we can start discussing
00:07:46.080 | how these models are trained, and finally how you should perform inference with these models.
00:07:50.720 | So again, I'm giving these examples with images, but these are fairly modality agnostic.
00:08:00.720 | The exception is text, of course, since these won't work on discrete tokens. So let's say I have some original image,
00:08:06.800 | and I'm drawing some noise from a standard Gaussian distribution. And then, let's say I'm also drawing
00:08:12.400 | some time step from a uniform distribution, and then I have a particular noise schedule, and also some terms
00:08:19.360 | that are controlling the noise schedule. And then, let's say I have some conditioning vector, and by
00:08:24.240 | conditioning vector, I essentially mean some text embeddings or some other form of structural
00:08:29.680 | controls, like let's say depth maps, segmentation maps, and so on. And then, you have your diffusion model
00:08:36.400 | that you will learn. And during training, what we basically do is, we compute some intermediate
representations of the images (or skip this step if you are working in pixel space), add a small amount of
noise to these clean inputs, and make the model predict the noise that was added. This is one
00:08:58.720 | very popular and widely adopted parameterization of the diffusion network training. It's called the
00:09:03.520 | epsilon objective, if you will. And that's basically it. So, we basically make the model learn what was
00:09:10.320 | the amount of noise that was added. So, that's training. And during sampling, we repeat the noise
prediction part in a sequential manner until we arrive at an image that we feel good about. So, we start
00:09:26.240 | from a random noise, and we denoise it over a period of time with some condition, which can be text embeddings, for example.
00:09:34.960 | And then we repeat the denoising process over a period of time until we arrive at a data point that we feel good about.
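To make the epsilon objective concrete, here is a minimal pure-Python sketch of one training step on a flat list of floats. The linear beta schedule (1e-4 to 0.02 over 1,000 steps, as in the original DDPM paper) and the `model(x_t, t, cond)` callable are assumptions for illustration; a real pipeline works on batched latent tensors.

```python
import math
import random


def alpha_bar(t, num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_s) for s = 0..t under a linear beta schedule."""
    value = 1.0
    for s in range(t + 1):
        beta = beta_start + (beta_end - beta_start) * s / (num_steps - 1)
        value *= 1.0 - beta
    return value


def add_noise(x0, eps, a_bar):
    """Forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps."""
    a, b = math.sqrt(a_bar), math.sqrt(1.0 - a_bar)
    return [a * xi + b * ei for xi, ei in zip(x0, eps)]


def epsilon_loss(model, x0, cond, rng, num_steps=1000):
    """One training step of the epsilon objective: predict the noise that was added."""
    t = rng.randrange(num_steps)                   # t ~ Uniform{0, ..., T-1}
    eps = [rng.gauss(0.0, 1.0) for _ in x0]        # eps ~ N(0, I)
    x_t = add_noise(x0, eps, alpha_bar(t, num_steps))
    eps_pred = model(x_t, t, cond)                 # hypothetical denoiser callable
    return sum((p - e) ** 2 for p, e in zip(eps_pred, eps)) / len(eps)   # MSE


def dummy_model(x_t, t, cond):
    return [0.0] * len(x_t)                        # placeholder that always predicts zero noise


rng = random.Random(0)
print(epsilon_loss(dummy_model, [0.1, -0.2, 0.3], cond=None, rng=rng))
```

Sampling then reuses the same network in a loop, starting from pure noise and removing a bit of the predicted noise at each step.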
00:09:43.040 | So, yeah. And for flow matching, which is becoming more and more common these days,
for example, Flux, if you have heard of the model, or Stable Diffusion 3: these are all flow
matching based. In flow matching, the paths become straighter. We try to connect noise and clean data
00:10:02.480 | through a straight path. That's why you see the linear interpolation equation in the first point. And we try to
00:10:08.800 | predict the target velocity field with a neural network. So, the key takeaway here is we try to connect
00:10:16.960 | noise and clean data through a straight line. But in diffusion models, the path is not assumed to be a
straight line. So, that's why you will see a lot of simplification in flow matching.
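Here is a similarly minimal sketch of the flow matching setup, assuming the common rectified-flow convention: t = 0 is clean data, t = 1 is pure noise, and the network regresses the constant velocity of the straight path between them. Direction and sign conventions differ between papers, so treat this as one illustrative choice; the velocity network is a hypothetical callable.

```python
import random


def interpolate(x0, eps, t):
    """Straight path between clean data (t=0) and noise (t=1): x_t = (1 - t) * x0 + t * eps."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, eps)]


def target_velocity(x0, eps):
    """dx_t/dt along the straight path; it is constant in t: v = eps - x0."""
    return [b - a for a, b in zip(x0, eps)]


def flow_matching_loss(model, x0, cond, rng):
    t = rng.random()                                   # t ~ Uniform[0, 1)
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    v_pred = model(interpolate(x0, eps, t), t, cond)   # hypothetical velocity network
    v = target_velocity(x0, eps)
    return sum((p - q) ** 2 for p, q in zip(v_pred, v)) / len(v)


def euler_sample(velocity_fn, noise, cond, num_steps=10):
    """Integrate dx/dt = v from t=1 (noise) back to t=0 (data) with explicit Euler steps."""
    x, dt = list(noise), 1.0 / num_steps
    for i in range(num_steps):
        t = 1.0 - i * dt
        v = velocity_fn(x, t, cond)
        x = [xi - dt * vi for xi, vi in zip(x, v)]
    return x


# With a perfect (oracle) velocity for a single pair, the straight path is followed exactly.
x0, eps = [0.5, -1.0, 2.0], [0.1, 0.3, -0.7]
v_true = target_velocity(x0, eps)
print(euler_sample(lambda x, t, c: v_true, eps, cond=None))
```

For a single noise/data pair the oracle velocity is constant, so Euler integration recovers x0 up to float error. A learned velocity field averages over many pairs and is only approximately straight, which is why real samplers still take several steps.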
00:10:30.240 | But long story short, in this talk, we are more interested in the network component of things. We are
00:10:36.320 | not interested in how you should add the noise, what particular noise schedules you should follow,
00:10:41.760 | and so on. We are more focused on the parameterization side of things in this talk. So, yeah.
00:10:52.400 | And let's now start discussing the components we might need. The components that are kind of expected
in a diffusion model. What are the core requirements that a diffusion model has to meet
so that it performs within some bound of expectations?
00:11:15.120 | So, we have to figure out a way to deal with the noisy inputs. As we saw, we can't really get rid of
00:11:22.640 | the noise because that kind of gives us the foundation. And usually, in pixel space, your shapes will look
something like this if you're using a channels-first layout. And if you are working in the latent space,
00:11:38.800 | it will, of course, be compressed a bit. And we usually always deal with the latent space
00:11:43.680 | diffusion model. So, it helps to have the kind of shapes that you would expect to see in your models
00:11:50.880 | that will flow through. And then you will have to deal with the conditions, right? Because all the
text-to-image models that you see, text is your condition. So, you'll have to deal with your
00:12:02.000 | conditions. And, of course, time step. So, the amount of noise that gets added to your latents,
00:12:08.000 | it should depend on the time step. Like, let's say, if your time step is 10, the amount of noise that
00:12:13.600 | will get added to your data points, it will be different from the time step if the time step were to be
00:12:20.480 | somewhere like 1,000 or 100. So, that's why time step is also a very crucial condition. Because
it basically tells the model which point in the trajectory it's at. Should it denoise less? Should
00:12:36.400 | it denoise more aggressively? And so on. And then you have got other forms of conditions such as class,
00:12:43.440 | maybe tiger, bird, or just like natural language. And then we have to also model the dependencies quite
00:12:51.680 | a bit. How should the noisy inputs interact with your conditions? Of course, you might want to think
00:12:58.480 | about these things through the lens of cross-attention and so on. But we'll get to that in a moment.
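Going back to the input shapes mentioned above, here is some channels-first shape arithmetic. The factor-of-8 spatial downsampling with 4 latent channels matches the Stable Diffusion / SDXL VAE, while newer models such as SD3 and Flux use 16 latent channels; treat these numbers as assumptions about a particular autoencoder rather than a general rule.

```python
def pixel_shape(batch, height, width, channels=3):
    """Channels-first (B, C, H, W) shape of an RGB image batch."""
    return (batch, channels, height, width)


def latent_shape(batch, height, width, downsample=8, latent_channels=4):
    """Shape after a VAE encoder that downsamples H and W by `downsample`."""
    if height % downsample or width % downsample:
        raise ValueError("height and width must be divisible by the downsampling factor")
    return (batch, latent_channels, height // downsample, width // downsample)


def numel(shape):
    """Total number of values in a tensor of the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n


px = pixel_shape(1, 1024, 1024)
lat = latent_shape(1, 1024, 1024)
print(px, lat, numel(px) // numel(lat))   # how many times fewer values the denoiser processes
```

Whatever the latent shape turns out to be, the denoising network is expected to return a tensor of exactly that shape.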
00:13:05.360 | And then how should the final outputs be produced? Should we just flat out decode everything if we were
00:13:11.920 | to deal with transformer-based architectures? Should we be up-sampling if we were to deal with pure
00:13:17.280 | convolutional architectures? And so on. And usually, for the diffusion network, the input shape exactly
equals the output shape, unless you are dealing with a different parameterization of diffusion.
00:13:31.520 | There are other parameterizations, but it's way more common to see
00:13:34.720 | the inputs that are flowing through the diffusion model will equal the outputs that the diffusion model
00:13:41.360 | is supposed to produce. And then, yes, here's a bit of a history. So, DDPM, all the early work
in this area, DDPM and latent diffusion models, they all used U-Net-based architectures. And I think,
historically speaking, U-Nets have kind of dominated this area for quite a bit, like StyleGAN, they all had
a U-Net-based architecture, right? And then, until SDXL, which is fairly recent, I would like to say
recent because it's from 2023. And if you are trying to develop a chronology of all the architectures
that have come around, I would like to argue that SDXL is fairly decent. And also, based on its usage,
I think it's decent. So, until SDXL, all the works relied on the U-Net. So, I think it makes sense to discuss
the U-Net-based architecture in this paradigm and then start seeing why we really need to transition to
transformers. So, yeah. And, of course, the U-Net for diffusion is giant. It's also one of the reasons
why you would probably want to get rid of it, but that's not the end of the story. But let's try to see
what the different components involved in our giant U-Net are. So, you, of course, need to have
an input convolutional stem that directly operates on the inputs that are coming at it.
00:15:14.720 | And then, you have a bunch of down blocks, which is basically comprised of custom residual blocks
made of normalization layers and convolutional layers. And then, you have got custom transformer
blocks: again, normalization, projections, and regular transformer blocks. And then, you have got
convolutional layers for downsampling. So, as you go through the down blocks, the latents get
downsampled to lower resolutions. And then, you would have some middle block like you
would have in a standard U-Net architecture, where you would not have any resolution changes. And then,
you would have a series of up blocks, which basically upsample the downsampled outputs
that you had in your down blocks. And finally, you will produce the output that will have the same shape
as your input. The up blocks basically resemble the same kind of blocks that you would have in your down
block counterpart, but instead of downsampling layers, they have upsampling layers.
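The resolution bookkeeping of such a U-Net can be sketched as follows. The layout is an assumption loosely modeled on Stable Diffusion-style U-Nets: four down blocks, each except the last halving the spatial size, a mid block with no resolution change, and up blocks mirroring the path. Actual block counts vary between models.

```python
def unet_resolutions(latent_size, num_down_blocks=4):
    """Spatial size at which each down block, the mid block and each up block operates."""
    down, size = [], latent_size
    for i in range(num_down_blocks):
        down.append(size)
        if i < num_down_blocks - 1:   # every down block except the last ends with a downsampler
            size //= 2
    mid = size                        # the mid block keeps the resolution unchanged
    up = list(reversed(down))         # up blocks upsample back along the same path
    return down, mid, up


down, mid, up = unet_resolutions(128)   # e.g. a 128x128 latent
print("down:", down, "mid:", mid, "up:", up)
```

The last up block works at the input resolution again, which is how the U-Net's output ends up with the same shape as its input.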
00:16:24.160 | So, yeah. And there are some miscellaneous things that you should worry about. For example,
00:16:29.680 | time step embeddings, how you should embed your time steps, and additional embeddings. For example,
if it's a class-conditional model, the way to embed them would be different. If it's a text-to-image model,
00:16:42.400 | the way to embed the text embeddings would be different. Basically, the way you modulate
00:16:46.640 | all your conditional embeddings, that will change depending on the kind of conditions that you are
dealing with. And here's basically what a down block of the U-Net architecture looks like. I mean,
it's so, so giant that I had to transpose it. So, it's very painful to even hold in your head.
00:17:04.960 | So, if you are someone that works with architectures quite a bit, I think we'll have to solemnly agree
00:17:10.320 | that this is bad. This is bad news already. I mean, I can't even imagine it in my head. So,
and here's a bit of a zoomed-in look at what goes into that down block component of the U-Net. You have
00:17:25.120 | got a bunch of custom ResNet blocks, as I was discussing. I've got a bunch of custom transformer
00:17:30.720 | blocks, as I was discussing, but nothing too brutal. Nothing that you haven't seen already. So,
00:17:35.760 | it's just a bunch of composition of those blocks. And I have also tried mentioning the resolution
00:17:41.440 | changes in each of the stages. So, yeah. And putting it all together, it looks something like
00:17:49.200 | this. I mean, I have tried shortening it quite a bit. But I mean, if I had to imagine it in my head,
00:17:55.200 | it's going to be extremely painful. And the blocks are just, oh, damn. It's very prohibitive in nature,
00:18:01.200 | just for the task of image generation. So, yeah. That's like the complete perspective. I know it's
not as complete as you might expect. But it's also hard to fit the whole U-Net on a single screen.
00:18:15.760 | So, yeah. And then of course, the natural next thing that one would have tried would be to try to
00:18:25.680 | replace the convolutional layers with MLP blocks to try to simplify quite a bit. And yeah, folks from
Google tried it in the U-ViT architecture. So, that's there. That was a precursor to the pure
transformer-based architectures, but not quite. It still had its fair share of complexities and
architectural painfulness, as you saw in the U-Net-based design. So, yeah.
00:18:53.360 | Now, I'm going to try to motivate why we really need a pure transformer-based architecture.
00:19:01.920 | Now, the first point is probably very obvious. We would want to benefit from the advancements that
are happening in the transformer-based architectures, like all the divine benevolence, as Noam Shazeer likes
to call it. You would want to have SwiGLU. You would want to have QK normalization. You would want to
have parallel attention and MLP layers and so on. So, that's one. Of course, good scaling properties and so on. And then,
00:19:26.000 | let's say you want to connect the pure transformer-based diffusion architecture or the backbone
00:19:33.760 | with some other backbones, let's say a pure LLM-based backbone, the integration becomes very easy.
And then, it allows you to get rid of the giant U-Net, which I guess is my main motivation. But I
hope I was able to convince you fairly strongly of why we need a change, like a paradigm shift away from the
architectures that are primarily inspired by the U-Net design.
00:20:00.800 | And also, this is a sight for sore eyes already. I mean, this is not uncommon. We all know that this
is the standard forward pass in a vision transformer network. It should feel very familiar at this point
in time. Now, the point I am trying to make here is that this doesn't have to change a whole lot if we were to
sort of extrapolate this to image generation. And I think you will agree with me that this is not
00:20:30.640 | changing a lot. Like all the core components are there. The patchification is there. The positional
embeddings are there. Of course, the way to embed classes, the y_embedder bit, that's different, and
you would still need to have a component to embed your time steps. That's different too. But the rest of the
components are still there. Like, you are still iterating through the blocks. You have your final
layer to finally decode your outputs and so on. And then there's the unpatchify layer. So this is still
very similar to how you would do it in a standard ViT. But of course, you have to, you know, account for the
generation head at the same time. So my point is, you are not changing a whole lot
in the standard ViT forward pass. Now, taking a closer look, I think this should also feel
00:21:26.320 | very familiar apart from a few mods here and there. Like we have got some scale and shift parameters,
which I'm going to discuss in a bit. But for the rest of the blocks, you have got the same layer norm,
00:21:38.560 | you have got the patchification layer, you have got some embeddings and so on. And then you have got a linear and
00:21:43.600 | reshape operation. So most of it should feel familiar, but the ones that are apparently a little foreign, I'm going to
00:21:52.080 | discuss them now. So yeah. Let's start with the time step bit. Like how do we actually embed time steps?
00:21:59.680 | I'm, I've been talking about time steps for quite a bit now. But let's now see how do we actually embed the time
00:22:06.960 | steps. And time steps are really important. And I'm also going to show you the expected shapes, the output
shapes. In this case, the expected shape is (batch size, hidden dimension of the transformer blocks).
00:22:21.760 | So how do we embed time steps? And time steps can range in between zero to thousand, where zero meaning like no
00:22:28.240 | noise and thousand should mean it's fully noised.
00:22:31.280 | So each t is embedded with sinusoidal frequencies to make it phase-aware. Like, at any point in time,
the network is seeing both extremely low and extremely high frequencies, and it must be aware of the kind
of phase it should be modulating into. So that's why sinusoidal frequencies are really
helpful. And after that, the network has to learn how to weigh these different frequencies. And in order to
00:23:01.120 | model those weights, we basically pass it through a very shallow MLP. And then how do you embed class
labels? You just take an nn.Embedding layer, as simple as that. And it's the standard patchification.
00:23:15.520 | You do it with a convolutional stem. And for positional encodings, they use the standard sine cosine
00:23:21.760 | scheme. And the final conditioning is you first embed the time steps and then you embed the class
00:23:31.040 | labels and you basically sum them up. And then you have your final condition that goes into the transformer
blocks. And it's very important to note that c remains fixed across all the blocks. So that's very
00:23:43.920 | important to note. And you would probably think in order to model the conditioning with the actual
00:23:56.240 | inputs, the noisy latents, you would probably want to use cross-attention, but that's not the case, as we
00:24:03.200 | will see in a few slides later. So let's take a step back and try to think of how we can inject the
00:24:14.640 | conditioning bit into the transformer blocks. So we have something called adaptive layer norm, which is very
00:24:23.200 | important in order to be able to model the stylistic aspects that you are getting out of your images.
00:24:31.440 | And it's basically this. So you have your standard layer norm and then you have an additional set of
00:24:36.880 | parameters, which we call modulation parameters, which basically operate on the condition space. And remember,
00:24:44.960 | the condition is basically a summation of the time step embeddings as well as the class embeddings. So that's your condition right there.
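The condition c described above can be sketched in pure Python: a sinusoidal embedding of the time step, a lookup table for class labels standing in for nn.Embedding, and an element-wise sum. The max period of 10,000 and the cos-then-sin ordering are assumptions following the DiT-style convention (implementations differ), the table values are random placeholders, and the small MLP that DiT applies to the sinusoid before the sum is omitted.

```python
import math
import random


def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding [cos(t * f_k)..., sin(t * f_k)...], f_k = max_period ** (-k / half)."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]  # high -> low
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:                                   # pad odd sizes with one zero
        emb.append(0.0)
    return emb


def make_class_table(num_classes, dim, seed=0):
    """Stand-in for nn.Embedding: one random vector per class label."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(num_classes)]


def condition(t, y, table):
    """c = t_emb + y_emb (the MLP on t_emb is omitted in this sketch)."""
    t_emb = timestep_embedding(t, len(table[y]))
    return [a + b for a, b in zip(t_emb, table[y])]


table = make_class_table(num_classes=1000, dim=8)
c = condition(t=500, y=207, table=table)   # the same c is then fed to every transformer block
print(len(c))
```

At t = 0 every cosine term is 1 and every sine term is 0, and larger t rotates each frequency by a different amount, which is what lets the network tell time steps apart.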
00:24:55.440 | And skipping the regular transformer bits, like the QKV, the multi-head self-attention and the MLP layers,
00:25:02.640 | we know the equations that govern the computations that take place within a standard transformer encoder block, right?
00:25:09.920 | So this is fairly well known. Now the part that's not known at this point in time, hopefully,
00:25:15.840 | is how do we modulate the conditioning bit? How do we actually inject the conditioning in the transformer blocks?
00:25:23.760 | So this is how we do it. Instead of doing any cross-attention, which would have been a fairly natural choice,
00:25:29.760 | I guess, we actually do not do cross-attention. We are still doing self-attention and then we are basically
00:25:36.720 | modulating the conditioning along with self-attention, as you can see in the bottom half of your equations.
00:25:44.320 | And these modulation parameters are not fixed per layer; they are regressed from the condition c by a small learned layer in each block.
00:25:49.520 | And then in order to get your final outputs, you basically have a single-layer decoder,
00:25:58.960 | and then you basically unpatchify it to get the same shape as your inputs.
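Patchify at the input and unpatchify after the final linear layer are inverse reshapes. The pure-Python sketch below works on a single (C, H, W) latent stored as nested lists. The (row-in-patch, column-in-patch, channel) ordering inside each token is an assumption that mirrors the (p, p, C) layout DiT's reference unpatchify reshapes from, and the learned projection to the hidden size is left out; a real model would use tensor reshapes or a strided convolution.

```python
def patchify(x, p):
    """(C, H, W) nested lists -> (H/p * W/p) tokens, each a flat list of p*p*C values."""
    c, h, w = len(x), len(x[0]), len(x[0][0])
    if h % p or w % p:
        raise ValueError("H and W must be divisible by the patch size")
    tokens = []
    for i in range(0, h, p):              # patch rows, top to bottom
        for j in range(0, w, p):          # patch columns, left to right
            tokens.append([x[ch][i + di][j + dj]
                           for di in range(p) for dj in range(p) for ch in range(c)])
    return tokens


def unpatchify(tokens, p, c, h, w):
    """Inverse of patchify: scatter each token's p*p*C values back into a (C, H, W) grid."""
    x = [[[0.0] * w for _ in range(h)] for _ in range(c)]
    patches_per_row = w // p
    for k, tok in enumerate(tokens):
        i, j = (k // patches_per_row) * p, (k % patches_per_row) * p
        values = iter(tok)
        for di in range(p):
            for dj in range(p):
                for ch in range(c):
                    x[ch][i + di][j + dj] = next(values)
    return x


c, h, w, p = 4, 8, 8, 2
latent = [[[float(100 * ch + 10 * r + col) for col in range(w)] for r in range(h)] for ch in range(c)]
tokens = patchify(latent, p)
print(len(tokens), len(tokens[0]))   # sequence length and per-token size
```

For a 4×128×128 latent and patch size 2 this gives 64 × 64 = 4,096 tokens of 16 values each, which is why the patch size is a major compute knob for a diffusion transformer.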
00:26:08.560 | Now, a note on init, because the initialization is fairly important. It's all standard vision transformer init,
but with two key modifications. Each transformer block is initialized as an identity function, taking inspiration
00:26:26.000 | from the early works in ImageNet training, wherein if you have a bunch of ResNet blocks,
you usually initialize the scale (gamma) parameter of the last batch normalization layer in each block to zero. It helps with
training stability, and it turns out that's the case here as well.
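Here is a minimal sketch of one adaLN-modulated, gated residual sub-layer, in the spirit of DiT's adaLN-Zero block. Assumptions: plain Python lists instead of tensors; a single linear map regresses [shift, scale, gate] from SiLU(c) for one sub-layer (DiT regresses six vectors at once, covering both the attention and the MLP sub-layer); and `sublayer` is a hypothetical stand-in for attention or the MLP. Zero-initializing that linear map makes every gate zero, so the block starts out as the identity.

```python
import math


def layer_norm(x, eps=1e-6):
    """LayerNorm with no learnable affine parameters; adaLN supplies scale and shift instead."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def silu(v):
    return v / (1.0 + math.exp(-v))


def linear(weight, bias, x):
    return [sum(w * xj for w, xj in zip(row, x)) + b for row, b in zip(weight, bias)]


def modulate(x, shift, scale):
    return [v * (1.0 + s) + b for v, s, b in zip(x, scale, shift)]


def adaln_sublayer(tokens, c, weight, bias, sublayer):
    """x + gate * sublayer(modulate(LN(x), shift, scale)), with shift/scale/gate regressed from c."""
    d = len(tokens[0])
    params = linear(weight, bias, [silu(v) for v in c])      # 3*d values from the condition
    shift, scale, gate = params[:d], params[d:2 * d], params[2 * d:]
    h = sublayer([modulate(layer_norm(t), shift, scale) for t in tokens])
    return [[ti + g * hi for ti, g, hi in zip(t, gate, ht)] for t, ht in zip(tokens, h)]


def identity_sublayer(hs):
    return hs                                                # stand-in for attention / MLP


d = 4
tokens = [[1.0, 2.0, 3.0, 4.0], [0.5, -0.5, 1.5, -1.5]]
c = [0.3, -0.1, 0.2, 0.7]                                    # condition vector (t_emb + y_emb)
zero_w, zero_b = [[0.0] * d for _ in range(3 * d)], [0.0] * (3 * d)   # adaLN-Zero init
print(adaln_sublayer(tokens, c, zero_w, zero_b, identity_sublayer) == tokens)
```

Note that the condition enters only through per-channel shift, scale and gate vectors, so no attention over the condition is needed, which is part of why this is cheaper than cross-attention.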
00:26:40.240 | Now, coming to the adaptive layer normalization thing, it's very important and it's also more compute
00:26:48.880 | efficient. And as I was mentioning, cross-attention would have been the natural choice in order to
kind of model the dependency between the conditions and the noisy inputs, but that's
00:27:02.160 | not at all the case, if you take a look at the graph. In fact, adaptive layer norm performs the best and it
00:27:08.000 | beats cross-attention big time. And it's also not because the conditions are fairly simple to model,
00:27:14.240 | but it's because when you are operating with continuous modalities like images,
00:27:19.120 | it's not that trivial to model dependencies with cross-attention. And when you are
00:27:26.000 | conditions are simple like class embeddings, it doesn't make sense to model them with cross-attention.
00:27:32.720 | It's also because it's a waste of compute, and the graph kind of confirms it.
00:27:40.720 | And later works have also explored like a more compute efficient variant of adaptive layer norm,
which we are going to get to a couple of slides later.
00:27:48.320 | And how you are modulating these conditions, like
00:27:51.920 | in this case, we are basically operating on a summation of the time step embeddings and the
00:27:58.240 | class embeddings. Now, how you are modulating it across and throughout your different transformer blocks,
that becomes very important, as we will see a couple of slides later.
00:28:10.560 | And as expected, it scales fairly gracefully with more compute, and it turns out that you can basically
apply all the U-Net-based training techniques to a diffusion transformer. So, that's pretty cool.
00:28:25.920 | And it scales pretty gracefully.
00:28:27.760 | It also performs well when compared to equally sized U-Net counterparts.
00:28:34.720 | And of course, at this point in time, you must be thinking that no one really does class-conditional image generation anymore.
00:28:46.160 | So, yeah, that's where we are headed next.
00:28:48.880 | Now, I want to try to motivate what it would take to enable text-to-image generation in the standard
00:29:01.200 | diffusion transformer architecture, because I think it makes sense to approach the problem in that sense:
00:29:08.240 | like, what are the components that are missing in a standard diffusion transformer so that it becomes a tool for text-to-image generation as well?
00:29:16.160 | And I think PixArt-α is one of the early works that explored it.
00:29:20.960 | So, yeah, we are definitely going to see it in detail.
00:29:25.520 | Now, one natural question would be how to embed the input natural language text prompts.
00:29:32.480 | And the answer would be simple.
00:29:34.000 | You would need a text encoder.
00:29:36.400 | And that's exactly what the PixArt-α work does.
00:29:39.840 | And then how to learn your contextual dependencies.
00:29:45.200 | Now, instead of classes, we have got natural language text on top of time steps.
00:29:54.160 | We could do self-attention on noisy latents.
00:29:57.120 | And then we could do cross-attention between your noisy latents and text.
00:30:01.360 | And mind you, this text is not just a class.
00:30:04.320 | This text is some natural language description, like a baby astronaut hatching out of an egg on the moon.
00:30:11.440 | So this is a natural language description we are talking about.
00:30:13.920 | These are not simply class labels.
00:30:16.000 | So that could make sense.
00:30:18.720 | It could make sense to have self-attention on noisy latents,
00:30:23.040 | to model the local dependencies within the patches.
00:30:26.160 | And then also to use cross-attention between the noisy patches and the text embeddings.
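To see how the two attention types differ, here is a single-head scaled dot-product attention in pure Python. Self-attention passes the latent tokens as queries, keys and values; cross-attention, as sketched here, takes queries from the noisy latent tokens and keys/values from the text embeddings. The learned query/key/value projections and multiple heads are omitted, and the inputs are made-up numbers.

```python
import math


def softmax(xs):
    m = max(xs)                                   # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """softmax(Q K^T / sqrt(d)) V for lists of vectors: N queries, L keys/values -> N outputs."""
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))])
    return out


latent_tokens = [[1.0, 0.0], [0.0, 2.0], [3.0, -1.0]]     # N = 3 noisy patch tokens
text_tokens = [[0.5, 0.5], [1.0, -1.0]]                   # L = 2 text-embedding tokens
self_out = attention(latent_tokens, latent_tokens, latent_tokens)
cross_out = attention(latent_tokens, text_tokens, text_tokens)
print(len(self_out), len(cross_out))   # both keep one output per latent token
```

The output always has one row per latent token, regardless of how many text tokens there are, which is what lets the result be added back residually to the image stream.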
00:30:32.320 | And then we'll have to figure out a way to modulate the time steps as well.
00:30:37.520 | And as we saw in the DiT work, it's important to modulate the time step embeddings throughout your transformer blocks.
00:30:47.440 | And then, if you have got access to a class conditional diffusion transformer, and if it's compatible, it might also make sense to kind of initialize some blocks from it.
00:30:58.400 | Because why waste compute, right?
00:31:01.680 | It might help with training stabilization, and so on.
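Initializing a text-to-image model from a class-conditional checkpoint usually amounts to copying every parameter whose name and shape still match, and leaving new or reshaped ones (such as freshly added cross-attention layers) at their fresh initialization. The sketch below uses plain dicts that map hypothetical parameter names to (shape, values); in PyTorch one would typically filter a state dict this way before calling `load_state_dict(strict=False)`.

```python
def load_compatible(target, source):
    """Copy parameters from `source` into `target` when both name and shape match.

    Both arguments map a parameter name to (shape, flat list of values).
    Returns the merged parameters plus the names that were loaded or skipped."""
    merged, loaded, skipped = {}, [], []
    for name, (shape, values) in target.items():
        if name in source and source[name][0] == shape:
            merged[name] = (shape, list(source[name][1]))   # reuse pretrained weights
            loaded.append(name)
        else:
            merged[name] = (shape, values)                   # new or incompatible: keep fresh init
            skipped.append(name)
    return merged, loaded, skipped


# Hypothetical parameter names and shapes, for illustration only.
class_cond_ckpt = {
    "blocks.0.attn.qkv.weight": ((2, 2), [1.0, 2.0, 3.0, 4.0]),
    "pos_embed": ((4, 2), [0.0] * 8),
}
text_model_init = {
    "blocks.0.attn.qkv.weight": ((2, 2), [0.0] * 4),
    "blocks.0.cross_attn.q.weight": ((2, 2), [0.5] * 4),   # new layer, not in the checkpoint
    "pos_embed": ((16, 2), [0.0] * 32),                     # e.g. a different sequence length
}
merged, loaded, skipped = load_compatible(text_model_init, class_cond_ckpt)
print("loaded:", loaded, "skipped:", skipped)
```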
00:31:05.200 | Now, that's exactly what PixArt-α does.
00:31:12.000 | So if you were paying attention at this point in time, you would have realized that implementing these things is not extremely challenging, but it helps to know that that's exactly what PixArt-α does.
00:31:29.120 | So it uses a text encoder to embed the prompts.
00:31:32.240 | It uses self-attention on noisy latents.
00:31:35.040 | It uses cross-attention to model the dependency between the noisy latents and the text embeddings.
00:31:42.000 | And it also initializes from a class conditional DIT model in order to accelerate training.
00:31:46.960 | So it kind of helps to know that you can still think about these things and see it getting implemented in practice.
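A minimal NumPy sketch of a PixArt-Alpha-style block as just described. It has self-attention over the noisy latent patches and cross-attention from the patches to the text tokens. It also has timestep modulation around the self-attention and the MLP. The sketch makes several simplifications. It is single-head, uses ReLU in place of the real activation, and assumes the text embeddings were already projected to the model width. All names and shapes are hypothetical; this is not PixArt's actual code.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    # single-head scaled dot-product attention: (N, d), (M, d), (M, d) -> (N, d)
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

def layer_norm(x, eps=1e-6):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def pixart_style_block(x, text, mod, p):
    """x: noisy latent patches (N, D); text: text tokens (M, D); mod: (6, D) derived from the timestep."""
    shift1, scale1, gate1, shift2, scale2, gate2 = mod
    # 1) self-attention among the noisy patches, modulated by the timestep
    h = layer_norm(x) * (1 + scale1) + shift1
    x = x + gate1 * attention(h @ p["wq"], h @ p["wk"], h @ p["wv"])
    # 2) cross-attention: image patches are the queries, text tokens the keys/values
    x = x + attention(x @ p["cq"], text @ p["ck"], text @ p["cv"])
    # 3) MLP, again modulated by the timestep (ReLU is a simplification)
    h = layer_norm(x) * (1 + scale2) + shift2
    return x + gate2 * (np.maximum(h @ p["w1"], 0) @ p["w2"])

rng = np.random.default_rng(0)
N, M, D = 16, 8, 32
p = {name: rng.normal(scale=0.1, size=(D, D)) for name in ["wq", "wk", "wv", "cq", "ck", "cv"]}
p["w1"] = rng.normal(scale=0.1, size=(D, 4 * D))
p["w2"] = rng.normal(scale=0.1, size=(4 * D, D))
x = rng.normal(size=(N, D))          # noisy latent patches
text = rng.normal(size=(M, D))       # text-encoder tokens, already projected to width D
mod = rng.normal(scale=0.1, size=(6, D))
out = pixart_style_block(x, text, mod, p)
```

The block keeps the patch sequence length unchanged, and the text length M only appears inside the cross-attention. That is why prompts of different lengths can condition the same image grid.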
00:31:54.960 | So yeah.
00:31:55.360 | And here's some discussion around the use of the text encoder.
00:32:01.840 | PixArt Alpha used Flan-T5-XXL to really get that text-rendering ability, and works like Imagen showed that if you scale the text encoder, a better text encoder actually helps, especially if you ever want to render text in your images.
00:32:22.880 | Then there's the problem of long-prompt following, because models like SDXL rely on CLIP for embedding text, and CLIP has a very short context length.
00:32:35.280 | It's 77 tokens, but with T5-XXL you get a much longer context length.
00:32:41.040 | So you get to use longer prompts and describe what you want in more detail.
00:32:47.680 | So that's there.
00:32:50.080 | Exploring the space of text encoders is also still a good research problem to take a look at.
00:32:56.720 | And if you were wondering, why not just use a standard language model in place of T5?
00:33:02.320 | Well, it's not that difficult to actually use one.
00:33:04.720 | Many works, like Lumina, explore it quite a bit.
00:33:08.560 | And you might wonder: diffusion models are already compute-bound,
00:33:12.480 | so why add the baggage of another heavy model?
00:33:16.720 | Well, I think it's okay, because computing the prompt embeddings is a one-step process: you compute them once per prompt and reuse them across all denoising steps.
00:33:22.320 | So it's okay to have a memory-bound model like a large language model in the mix.
00:33:27.680 | Cool.
00:33:31.360 | And here we will again see the return of adaptive layer norm.
00:33:36.720 | Remember that, for each diffusion transformer block, we were operating on a sum of timestep embeddings and class embeddings.
00:33:50.000 | Right?
00:33:50.320 | For PixArt Alpha, we already have a way to compute our text embeddings, which we do not want to touch,
00:34:01.360 | because those embeddings are already computed with a dedicated, rich text encoder.
00:34:07.200 | So we probably do not want to touch those embeddings and modulate with them, unlike in DiT, where we modulated with the class embeddings as well.
00:34:16.080 | And instead of having a full adaptive layer normalization projection in every transformer block, we maintain tables.
00:34:27.360 | We maintain per-block embedding tables, and we sum them with a shared modulation signal.
00:34:31.360 | So instead of doing another matmul in each and every transformer block, we get away with an addition:
00:34:41.040 | an addition with another embedding table.
00:34:44.640 | And it reduces the compute quite a bit.
00:34:47.360 | It cuts 27% of the original diffusion transformer computation.
00:34:52.560 | And I think the idea is fairly elegant.
00:34:54.880 | It also gives another perspective to think about how you can reduce parameters and still sort of maintain performance.
00:35:03.680 | This was good.
00:35:05.040 | Cool.
00:35:07.840 | And I must say, this is impressive performance for a fairly compact model.
00:35:11.920 | It's only a 0.6-billion-parameter model, and it was already breaking all the charts.
00:35:17.840 | When it came out in 2023, it was very, very good.
00:35:22.880 | I think I have the GenEval scores.
00:35:27.520 | No, we don't.
00:35:28.720 | But the last chart here shows the human preference ratings.
00:35:33.440 | Quality-wise, it's the overall image quality; alignment-wise,
00:35:38.560 | it's the alignment between the text and the generated images.
00:35:42.480 | And in both of those aspects, PixArt Alpha performed fairly well.
00:35:51.280 | Now, before I jump to the time and memory complexity of these models, because we are still using the quadratic vanilla attention thing,
00:36:04.480 | I want to see if there are any questions at this point in time.
00:36:08.720 | So, I'm going to open the floor for questions, if that's okay, Steven.
00:36:11.920 | That's okay.
00:36:13.120 | Yeah, go ahead.
00:36:17.200 | You mentioned that the state-of-the-art diffusion models require more than one text encoder.
00:36:25.440 | Why is that?
00:36:26.160 | What's the benefit of having multiple different types of encodings for the same text?
00:36:32.880 | Oh, yeah, for sure.
00:36:34.080 | So, the question is, why would you want to have multiple text encoders?
00:36:38.320 | And why does it help improve the image generation performance?
00:36:43.040 | So, you saw that in class-conditional DiTs, you modulate with not just the timestep embeddings, but also the class embeddings.
00:36:53.120 | The conditional embedding we were modulating with, along with self-attention, was a sum of class embeddings and timestep embeddings.
00:37:00.640 | Now, for text-to-image models, that is not as trivial.
00:37:05.680 | So, apart from your text embeddings, you also kind of have your time step embeddings.
00:37:11.120 | But what about modulating the other condition that you computed from your text input?
00:37:19.200 | It turns out that the richer and more diverse the representations are, the better it is for the generation backbone.
00:37:27.440 | Now, when I said many models use more than one text encoder, they usually have a combination of CLIP and T5.
00:37:36.320 | CLIP and T5 are entirely different models.
00:37:39.840 | With CLIP, you embed a contrastive nature into the text embeddings.
00:37:44.880 | With T5, your text embeddings have a completely different nature.
00:37:50.800 | So the more diverse these embeddings are, the better it is for generative performance.
00:37:56.800 | But there's no systematic study of this.
00:37:59.840 | That's the general belief.
00:38:03.520 | Yeah, but there are works that get away with using just a single language model.
00:38:09.440 | So maybe language models inherit both properties by virtue of their good pre-training.
00:38:19.840 | I think most of them are using diffusion transformers because of their efficiency, and also because they are easy to adapt.
00:38:41.200 | Like, if you wanted to embed another form of control, say you wanted to additionally prompt the model with a stylistic reference from images,
00:38:51.760 | it's easier to do that on a diffusion transformer, as we will see a couple of slides later.
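A sketch of one way two text encoders' outputs can be combined. It loosely follows the Stable Diffusion 3 paper as we understand it, so treat the layout as an assumption. There are two paths. On the token path, the CLIP features are zero-padded to the T5 width and stacked with the T5 tokens along the sequence. On the vector path, a pooled CLIP vector is added to the timestep embedding for modulation. The encoders are replaced by random arrays, and every dimension is a made-up toy value.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 64                                    # hypothetical transformer width
clip_tokens = rng.normal(size=(77, 48))   # stand-in for CLIP per-token features
clip_pooled = rng.normal(size=48)         # stand-in for the pooled CLIP embedding
t5_tokens = rng.normal(size=(256, 64))    # stand-in for T5 per-token features
t_emb = rng.normal(size=D)                # timestep embedding

# Token-level context: zero-pad CLIP channels to the T5 width, then stack along the sequence axis
pad = t5_tokens.shape[1] - clip_tokens.shape[1]
context = np.concatenate([np.pad(clip_tokens, ((0, 0), (0, pad))), t5_tokens], axis=0)

# Vector-level condition: project the pooled CLIP embedding and add it to the timestep embedding
W_pool = rng.normal(scale=0.02, size=(48, D))
cond_vec = t_emb + clip_pooled @ W_pool   # this vector would drive the adaLN modulation
```

The two paths give the backbone two different views of the prompt. As the answer says, the claim that this diversity is what helps is a common belief rather than a systematically studied result.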
00:38:57.120 | So, in the original diffusion transformer block,
00:39:16.000 | you compute all your modulation parameters with a single affine layer.
00:39:20.800 | So, you can do that, right?
00:39:23.040 | You chunk its output, and then you have your different modulation parameters.
00:39:27.040 | And this is where your matrix multiplication lies.
00:39:34.400 | So notice how the modulation parameters are computed there, right?
00:39:38.320 | But here, I am only computing them through an embedding table.
00:39:43.040 | There's no affine transformation happening.
00:39:45.920 | And I end up adding it to my timestep embeddings, that's it.
00:39:52.080 | So, that's how I avoid the matmul.
00:39:54.560 | And that's largely why I am able to sort of reduce the computation by 27%.
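A NumPy sketch of the contrast just described. It assumes DiT-XL-like sizes (width 1152, 28 blocks) purely for illustration. In DiT, every block has its own affine layer that maps the condition to six modulation vectors. In an adaLN-single-style design, one shared affine layer runs once, and each block only adds a small (6, D) embedding table. The counts below cover only the modulation weights. They are not meant to reproduce the 27% figure quoted in the talk.

```python
import numpy as np

D, depth = 1152, 28  # hypothetical width and depth

def dit_modulation_params(D, depth):
    # DiT-style: every block owns a Linear(D -> 6D), weight + bias
    return depth * (D * 6 * D + 6 * D)

def adaln_single_params(D, depth):
    # adaLN-single-style: one shared Linear(D -> 6D) plus a (6, D) table per block
    return (D * 6 * D + 6 * D) + depth * 6 * D

rng = np.random.default_rng(0)
t_emb = rng.normal(size=D)
W_shared = rng.normal(scale=0.02, size=(D, 6 * D))
b_shared = np.zeros(6 * D)
tables = rng.normal(scale=0.02, size=(depth, 6, D))

global_mod = (t_emb @ W_shared + b_shared).reshape(6, D)   # the only matmul, done once
per_block = [global_mod + tables[i] for i in range(depth)]  # inside each block: just an addition

print(dit_modulation_params(D, depth), adaln_single_params(D, depth))  # 223147008 8163072
```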
00:40:01.840 | I'll take one last question before I jump to the next section.
00:40:07.200 | If there's any.
00:40:08.560 | I've got some over Zoom as well as online.
00:40:12.240 | I'll ask some, let me see.
00:40:16.160 | Someone's wondering, is there still any point to GANs for image and video generation?
00:40:20.800 | Or have they basically been fully replaced by diffusion models?
00:40:24.560 | I wouldn't say they have been fully replaced because, as we saw, diffusion models are still sequential in nature.
00:40:31.120 | So, if you're looking for really ultra-real-time generation, I think GANs are still the way to go.
00:40:36.560 | And many companies are, in fact, using them.
00:40:38.880 | So, if you're looking for really cool one-shot generation, I think GANs are the way.
00:40:42.800 | And there's this literature around timestep distillation that looks at reducing the number of timesteps, or inference steps, that you need in order to produce a good-quality image.
00:40:54.480 | There you need a GAN loss, actually.
00:40:56.160 | So, GANs are not going to be completely replaced yet.
00:40:59.920 | Okay, because I thought GANs had a lot of issues like mode collapse.
00:41:03.520 | Oh yeah, they do, but if you have specific use cases and a fairly good dataset, you can still train a good GAN.
00:41:14.800 | I can ask one more if there's time from online.
00:41:16.960 | Sure, go ahead.
00:41:17.280 | Image generation generally requires a lot of data to train, especially for diffusion models.
00:41:24.560 | Are there techniques or architectural choices for low data regimes?
00:41:28.240 | That's a good one.
00:41:31.840 | The short answer is no.
00:41:33.360 | I mean, I think it's correlated with your use case a bit.
00:41:41.120 | Like for medical imaging, you probably do not need a whole lot of diversity, but if you need a fairly general generative model, I think you need to train it on a lot of data to inherit all the biases that you are looking for.
00:41:55.840 | Like if you want to train on the distribution of natural images and you want the model to always produce realistic images, I think you need to have a lot of data.
00:42:06.880 | At least diverse data.
00:42:08.880 | Because nowadays with LLMs, you know, there's zero-shot and few-shot learning after you pre-train.
00:42:16.160 | Oh, I'm going to talk a bit about in context learning for diffusion models.
00:42:20.000 | Okay.
00:42:20.320 | But, but you still, you are assuming you have access to a pre-trained model, right?
00:42:26.240 | Right, right.
00:42:26.960 | So there you go.
00:42:30.640 | Alright, I'm going to start this leg of the talk by discussing the quadratic time and memory complexity argument a bit, because now we are in image generation territory, and when you are trying to generate really high-resolution images, like 4K images,
00:42:54.640 | vanilla attention becomes extremely prohibitive even if you operate in the latent space.
00:43:01.920 | Take a look at the dimensions of the latents.
00:43:03.920 | Let's say your number of latent channels is 16 and your latent spatial dimensions are 512 by 512.
00:43:11.920 | It's still way too large, way larger than what we are used to seeing in the VLM space.
00:43:19.200 | Right.
00:43:20.080 | And I have some dummy computations here.
00:43:22.720 | This is of course not using flash attention.
00:43:24.640 | To give you some dummy estimates, it would be about 190 GB in FP16.
00:43:34.400 | And these are all reasonable defaults.
00:43:36.800 | I have 24 attention heads, batch size one, and the spatial dimensions flattened out into the sequence length.
00:43:45.040 | I would need 190 GB.
00:43:50.480 | But you get an idea of how prohibitive the quadratic attention we usually use in diffusion models becomes
00:43:59.440 | if we are dealing with really high-quality, high-resolution images.
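The ~190 GB estimate can be reproduced with back-of-the-envelope arithmetic under two assumptions of our own, since the talk doesn't state them. First, the 512 x 512 latent is patchified into 2 x 2 patches, which gives 65,536 tokens. Second, we count only the materialized FP16 attention-score matrices for 24 heads at batch size 1.

```python
def attn_score_bytes(batch, heads, seq_len, bytes_per_el=2):
    # one (seq_len x seq_len) score matrix per head; FP16 uses 2 bytes per element
    return batch * heads * seq_len ** 2 * bytes_per_el

# Assumption: 512 x 512 latent grid with 2 x 2 patches -> 256 * 256 tokens
seq_len = (512 // 2) * (512 // 2)
gib = attn_score_bytes(batch=1, heads=24, seq_len=seq_len) / 2**30
print(seq_len, gib)  # 65536 192.0

# Without patchifying, every latent position is a token: 4x the tokens, 16x the memory
no_patch_gib = attn_score_bytes(batch=1, heads=24, seq_len=512 * 512) / 2**30
print(no_patch_gib)  # 3072.0
```

192 GiB is about 206 GB in decimal units, so "roughly 190 GB" is the right ballpark. The 16 latent channels don't enter this term at all, because the score matrix depends only on the number of tokens.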
00:44:05.520 | So what could we do?
00:44:07.680 | Two simple things we could do: first, operate in an even more compressed space.
00:44:12.800 | This (1, 16, 512, 512) space is already compressed,
00:44:19.040 | but could we increase the compression ratio even further?
00:44:22.080 | Right.
00:44:23.520 | And the second obvious thing would be to use some form of linear attention
00:44:28.560 | that doesn't do the n-by-n multiplication.
00:44:31.440 | So that brings me to my next architecture, which is the SANA architecture.
00:44:38.160 | And it uses both.
00:44:39.360 | It uses an even more compressed latent space,
00:44:43.680 | and it also uses a linear variant of the attention mechanism.
00:44:49.200 | So let's see.
00:44:50.640 | So that's linear attention on one hand.
00:44:55.680 | And of course you might expect some performance loss from the linear complexity of the attention
00:45:02.400 | mechanism.
00:45:03.440 | To compensate for that loss, we use Mix-FFN.
00:45:07.440 | I'm going to get into the details in a bit, but I just wanted to give you a quick overview.
00:45:13.120 | So SANA does self-attention, linear self-attention if you will,
00:45:17.840 | but it still does cross-attention to model the dependencies between the noisy latents and
00:45:23.680 | the text prompts that you provide.
00:45:26.640 | And there's no n-squared computation happening in the self-attention.
00:45:31.360 | And the equation makes that clear.
00:45:33.600 | So we have shared terms computed from the K and V projections.
00:45:38.080 | And these are reused for all the queries.
00:45:40.640 | This way we are effectively not doing the n-by-n multiplications.
00:45:47.120 | All of the multiplications scale linearly with n.
00:45:49.840 | That's why we are able to go from order n-squared to order n.
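A NumPy sketch of the shared-term trick. It assumes a ReLU feature map, which is the kernel we understand SANA to use; treat that as an assumption. The d x d matrix built from K and V, and the summed keys, are computed once and reused for every query. The cost therefore grows linearly in the number of tokens n. The reference function materializes the n x n matrix only to check that both give the same result.

```python
import numpy as np

def relu_linear_attention(q, k, v, eps=1e-6):
    # feature map phi = ReLU (assumption), applied to queries and keys
    q, k = np.maximum(q, 0), np.maximum(k, 0)
    kv = k.T @ v                 # (d, d_v): shared term, computed once for all queries
    k_sum = k.sum(axis=0)        # (d,): shared normalizer term
    num = q @ kv                 # (n, d_v), cost O(n * d * d_v)
    den = q @ k_sum + eps        # (n,)
    return num / den[:, None]

def quadratic_reference(q, k, v, eps=1e-6):
    # the same formula, but materializing the n x n similarity matrix
    q, k = np.maximum(q, 0), np.maximum(k, 0)
    sim = q @ k.T                # (n, n)
    return (sim @ v) / (sim.sum(axis=1, keepdims=True) + eps)

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(128, 16)) for _ in range(3))
out = relu_linear_attention(q, k, v)
ref = quadratic_reference(q, k, v)
```

The two functions agree because matrix multiplication is associative: (phi(Q) phi(K)^T) V = phi(Q) (phi(K)^T V). Only the bracketing, and therefore the cost, changes.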
00:45:55.360 | And then, as I was mentioning, there has to be some form of compensation, since we are not using the
00:46:02.640 | quadratic attention mechanism, right?
00:46:05.200 | So we use Mix-FFN blocks.
00:46:07.280 | They are basically inverted residual blocks with pointwise and depthwise convolutions, in order to model the
00:46:12.800 | local dependencies.
00:46:14.160 | Because it turns out that when you take the softmax out of the picture,
00:46:18.240 | you lose much of the notion of locality.
00:46:20.960 | Let's not say all of it, but the majority of the locality is taken out.
00:46:26.240 | So that's why you need some components that model the locality aspect.
00:46:32.320 | And Mix-FFN blocks are used for that.
00:46:34.240 | And for the first time, SANA got rid of the positional embeddings.
00:46:40.160 | I mean, it's funny to say, but it's NoPE.
00:46:45.520 | So, no positional embeddings.
00:46:47.440 | And the Mix-FFN blocks actually help here, because they have convolutional layers in them.
00:46:52.640 | So we are returning to the convolutional argument, but not so much.
00:46:56.880 | It's not fully convolutional.
00:46:59.360 | You just have a few convolutional layers thrown in there.
00:47:02.720 | So it's largely DiT-based, but instead of a purely linear MLP,
00:47:07.680 | you have convolutional layers to account for the local interactions.
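A rough NumPy sketch of a Mix-FFN-style layer. It is loosely modeled on the structure SANA describes: a pointwise expansion, a 3 x 3 depthwise convolution, a gated linear unit, and a pointwise projection. The exact activations and widths here are assumptions. Only the depthwise convolution mixes neighboring positions. So perturbing one token changes the outputs only in its 3 x 3 neighborhood, and that is the locality that linear attention lacks.

```python
import numpy as np

def silu(x):
    return x / (1 + np.exp(-x))

def depthwise_conv3x3(x, w):
    # x: (H, W, C), w: (3, 3, C); each channel is filtered independently, zero padding keeps H x W
    H, W, _ = x.shape
    xp = np.pad(x, ((1, 1), (1, 1), (0, 0)))
    out = np.zeros_like(x)
    for di in range(3):
        for dj in range(3):
            out += xp[di:di + H, dj:dj + W, :] * w[di, dj]
    return out

def mix_ffn(tokens, H, W, w_in, w_dw, w_out):
    x = tokens.reshape(H, W, -1)        # token sequence back onto the 2D latent grid
    x = silu(x @ w_in)                  # 1x1 (pointwise) conv: C -> 2 * hidden
    x = depthwise_conv3x3(x, w_dw)      # 3x3 depthwise conv: mixes neighboring patches
    a, gate = np.split(x, 2, axis=-1)   # gated linear unit
    x = a * silu(gate)
    x = x @ w_out                       # 1x1 (pointwise) conv: hidden -> C
    return x.reshape(H * W, -1)

rng = np.random.default_rng(0)
H = W = 8
C, hidden = 32, 64
w_in = rng.normal(scale=0.1, size=(C, 2 * hidden))
w_dw = rng.normal(scale=0.1, size=(3, 3, 2 * hidden))
w_out = rng.normal(scale=0.1, size=(hidden, C))
tokens = rng.normal(size=(H * W, C))
out = mix_ffn(tokens, H, W, w_in, w_dw, w_out)

# Locality check: perturb only the token at grid position (0, 0)
bumped = tokens.copy()
bumped[0] += 1.0
diff = np.abs(mix_ffn(bumped, H, W, w_in, w_dw, w_out) - out).sum(-1).reshape(H, W)
```

In `diff`, only the grid positions (0, 0), (0, 1), (1, 0) and (1, 1) change. This is also why such convolutions can carry some positional information once explicit positional embeddings are dropped.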
00:47:14.240 | And it performs fairly well.
00:47:19.680 | For its compact size, it performs fairly well.
00:47:21.920 | I must say that we shouldn't count the kernel fusion and the fancy DPM-Solver variant
00:47:29.840 | that they're using.
00:47:30.880 | But up until the fourth row, I think it's still performing fairly well.
00:47:36.480 | It's giving very decent GenEval performance.
00:47:40.000 | It's giving very decent DPG-Bench performance.
00:47:43.600 | And these are all metrics that assess a given image in terms of compositionality,
00:47:50.800 | in terms of fidelity, in terms of overall quality, and so on.
00:47:56.240 | And these metrics are fairly well grounded in terms of reality.
00:48:00.400 | So yeah.
00:48:01.920 | And this is probably going to be the final flavor of attention that I'm going to discuss
00:48:12.080 | before I move on to other topics.
00:48:14.720 | And this is, I think, a bit new.
00:48:18.560 | And also new in the sense that no one in the VLM space really does it.
00:48:22.640 | So what happens if we were to model
00:48:26.720 | the dependencies of the different modalities in separate spaces?
00:48:30.400 | So let's see.
00:48:31.600 | One motivation behind this could be the text embeddings.
00:48:38.240 | Let's say you are computing text embeddings with a large language model.
00:48:42.160 | You will end up inheriting a lot of bias from it, right?
00:48:45.840 | Like, the unidirectionality bias will be there.
00:48:49.440 | If you're using a standard autoregressive large language model, you will end up inheriting that
00:48:54.640 | unidirectionality bias.
00:48:56.320 | So you will have biases,
00:48:59.360 | and they might creep into your generative model in all sorts of different ways.
00:49:05.280 | So one idea could be to do the QKV projections, but separately.
00:49:12.960 | You do QKV projections on the text embeddings, and you also maintain another set of QKV
00:49:19.680 | projection parameters for your noisy latents.
00:49:23.360 | And then you concatenate them before you compute the attention.
00:49:26.720 | So you operate on a concatenated representation when you actually do the attention.
00:49:33.440 | So you project the QKV for your image latents separately from the text embeddings.
00:49:44.960 | You concatenate them before you actually compute attention.
00:49:49.040 | So this is basically it.
00:49:51.360 | So this is MMDiT, which was introduced in the Stable Diffusion 3 paper.
00:49:55.120 | And their motivation was to get rid of the different biases that might be there in the text embeddings.
00:50:02.640 | And they also showed how qualitatively different text embeddings can be from image embeddings.
00:50:08.640 | So I guess that gives you a hint about why we might need different text embeddings, so that we do not end up
00:50:16.240 | inheriting each other's biases, and have more diversity in the mix.
00:50:19.600 | So it basically looks like this.
00:50:24.960 | So it might feel a little complicated, but the reason it's so big is that we have
00:50:31.200 | separate projection matrices for the different modalities that we are modeling.
00:50:35.440 | So on the left-hand side, we have got the captions.
00:50:40.560 | And on the right-hand side, we have got the noisy latents.
00:50:43.280 | And we have separate QKV projections and adaptive layer norm parameters for the separate modalities.
00:50:49.840 | So that's why it feels big.
00:50:51.840 | But conceptually, it's basically this.
00:50:54.560 | So you have separate adaptive layer norm parameters for the separate modalities.
00:51:02.720 | You have got separate QKV projection matrices.
00:51:05.040 | You have got separate output projection matrices.
00:51:08.400 | And then basically everything is separate.
00:51:11.920 | Everything is separate for the two modalities that we are interested in.
00:51:16.800 | So in a way, it kind of gives a way to co-evolve the two embeddings from the two different modalities
00:51:28.160 | that we are working with for the given task, which is image generation in this case.
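A single-head NumPy sketch of the joint attention at the core of an MMDiT block as described here. Each modality gets its own Q/K/V and output projections. The projected sequences are concatenated, one attention runs over the combined sequence, and the result is split back per modality. The per-modality adaLN and MLPs are omitted, and all names and sizes are hypothetical.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def joint_attention(img, txt, w_img, w_txt):
    # separate projection weights per modality, then concatenation along the sequence axis
    q = np.concatenate([txt @ w_txt["q"], img @ w_img["q"]])
    k = np.concatenate([txt @ w_txt["k"], img @ w_img["k"]])
    v = np.concatenate([txt @ w_txt["v"], img @ w_img["v"]])
    # one attention over the joint sequence: text-text, text-image and image-image interactions
    out = softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v
    n_txt = txt.shape[0]
    # split back and apply modality-specific output projections; both streams get updated
    return img + out[n_txt:] @ w_img["o"], txt + out[:n_txt] @ w_txt["o"]

rng = np.random.default_rng(0)
D = 32
make = lambda: {n: rng.normal(scale=0.1, size=(D, D)) for n in "qkvo"}
w_img, w_txt = make(), make()
img = rng.normal(size=(64, D))   # noisy latent patches
txt = rng.normal(size=(10, D))   # text-encoder tokens
img_out, txt_out = joint_attention(img, txt, w_img, w_txt)
```

Because the text stream is also updated at every block, text and image representations co-evolve. Plain cross-attention, by contrast, only updates the image stream.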
00:51:31.760 | And also, if you are used to cross-attention, you might want to ask:
00:51:38.800 | if we compute attention this way, how do we do masking in the first place?
00:51:43.360 | And that's actually an active area of research.
00:51:45.440 | We do not know how to do masking holistically with the MMDiT variant of attention.
00:51:52.400 | And then modulation happens with both time steps and the conditional embedding that you are operating with.
00:52:02.240 | Stable Diffusion 3 uses a separate set of pooled text embeddings, which are computed from CLIP rather than from the T5-based text encoder.
00:52:12.000 | So again, you need some form of diversity to not end up inheriting the bias from the other text encoder that you have.
00:52:21.840 | Stable Diffusion 3 uses three text encoders, but they can be mixed and matched during inference.
00:52:28.000 | They showed that you need T5 if you want really good text-rendering capabilities.
00:52:34.480 | But other than that, you can still do a lot with the two CLIP encoders that they use.
00:52:41.360 | And you can drop T5 if you are not focused on text-rendering tasks.
00:52:46.080 | And of course, it matters quite a bit.
00:52:52.960 | You might want to ask, does MMDiT matter at all?
00:52:57.040 | And it turns out that it does.
00:52:59.200 | They tried different attention variants.
00:53:03.440 | They tried CrossDiT, they tried UViT.
00:53:06.240 | And MMDiT seems to be the variant that gives you the lowest validation loss.
00:53:10.800 | They also show that validation loss is fairly well correlated with the image generation
00:53:16.960 | metrics that we care about, such as GenEval and so on.
00:53:20.640 | And it scales fairly well, but it needs QK normalization.
00:53:25.760 | Thanks to concurrent work, they didn't have to reinvent QK-norm from scratch.
00:53:32.240 | Thanks to developments in the regular transformer literature, they were able to just use QK-norm to
00:53:40.240 | solve the training instability issues.
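A small NumPy illustration of why QK normalization helps. It assumes a plain RMSNorm without the learnable gain that real implementations usually have. Normalizing queries and keys per token before the dot product caps the attention logits at sqrt(d), no matter how large the activations grow. Unbounded logit growth of this kind is what tends to destabilize training at scale.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # RMSNorm without a learnable gain (simplifying assumption)
    return x / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
d = 64
q = rng.normal(scale=50.0, size=(8, d))   # deliberately huge activations
k = rng.normal(scale=50.0, size=(8, d))

raw_logits = q @ k.T / np.sqrt(d)                      # grow with the activation scale
qk_logits = rms_norm(q) @ rms_norm(k).T / np.sqrt(d)   # bounded by sqrt(d) = 8
```

After RMSNorm, every query and key row has an L2 norm of at most sqrt(d). By Cauchy-Schwarz, each scaled logit is therefore at most d / sqrt(d) = sqrt(d) in absolute value.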
00:53:42.320 | And I think Stable Diffusion 3 is basically incomplete without this picture.
00:53:48.400 | I'm not showing this picture just for fun, but also because it shows how complex this prompt is,
00:53:58.080 | and how well Stable Diffusion 3 was able to interpret it and give us a creature like this.
00:54:05.360 | So I must give credit to the authors of Stable Diffusion 3, who came up with these kinds of prompts.
00:54:11.520 | But it does fairly well.
00:54:13.840 | It was among the first models to show impressive prompt-following ability
00:54:20.960 | while also preserving the details that we care about.
00:54:27.040 | So that was quick.
00:54:28.080 | And MMDiT didn't stop here.
00:54:34.000 | There are different flavors of MMDiT that I wanted to discuss.
00:54:38.320 | In Stable Diffusion 3, all the transformer blocks followed the MMDiT flavor of attention.
00:54:47.440 | And that's computationally demanding.
00:54:49.520 | If you have worked with transformer encoders,
00:54:52.720 | you would appreciate the computational intensity of MMDiTs,
00:54:59.440 | because you have to maintain separate projection matrices for the different
00:55:05.120 | modalities that you are working with.
00:55:06.960 | So it's computationally extremely demanding.
00:55:08.880 | So maybe we could combine MMDiT blocks and regular DiT blocks to better utilize the FLOPs.
00:55:17.120 | So conceptually, this becomes this.
00:55:21.280 | You have some number of MMDiT blocks, you concatenate the final representations,
00:55:28.080 | and then you operate on the concatenated space and basically do vanilla DiT.
00:55:33.680 | You do not do MMDiT there.
00:55:35.200 | And you end up utilizing the FLOPs a bit better.
00:55:37.840 | And that's exactly what recent models like Flux from Black Forest Labs do.
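A compact sketch of the hybrid layout described here. First come a few dual-stream (MMDiT-style) blocks with per-modality weights. Then come single-stream blocks that treat the concatenated text-and-image sequence as one sequence with one set of weights. Several parts are simplifying assumptions, not Flux's actual configuration: the block counts, the widths, the attention-only blocks (no MLPs or modulation), and the random weights standing in for trained ones.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 32

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def weights():
    # random stand-ins for trained Q/K/V/O projections
    return {n: rng.normal(scale=0.1, size=(D, D)) for n in "qkvo"}

def attend(q, k, v):
    return softmax(q @ k.T / np.sqrt(D)) @ v

def dual_stream_block(img, txt, w_img, w_txt):
    # MMDiT-style: per-modality projections, one joint attention
    n = txt.shape[0]
    q = np.concatenate([txt @ w_txt["q"], img @ w_img["q"]])
    k = np.concatenate([txt @ w_txt["k"], img @ w_img["k"]])
    v = np.concatenate([txt @ w_txt["v"], img @ w_img["v"]])
    out = attend(q, k, v)
    return img + out[n:] @ w_img["o"], txt + out[:n] @ w_txt["o"]

def single_stream_block(x, w):
    # vanilla DiT-style: one set of weights for the whole concatenated sequence
    return x + attend(x @ w["q"], x @ w["k"], x @ w["v"]) @ w["o"]

def hybrid_forward(img, txt, n_dual=2, n_single=4):
    for _ in range(n_dual):
        img, txt = dual_stream_block(img, txt, weights(), weights())
    x = np.concatenate([txt, img])       # from here on, a single shared sequence
    for _ in range(n_single):
        x = single_stream_block(x, weights())
    return x[txt.shape[0]:]              # keep the image tokens as the output

img_out = hybrid_forward(rng.normal(size=(64, D)), rng.normal(size=(10, D)))
```

A dual-stream block carries two sets of projection weights, while a single-stream block carries one. Moving part of the depth to single-stream blocks is roughly where the efficiency argument comes from.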
00:55:43.360 | Another twist could be this: you have modality A, and you compute output A by passing it through
00:55:54.560 | a bunch of transformer blocks.
00:55:56.480 | And then you have got modality B, let's say text embeddings.
00:55:59.600 | You pass it through another set of transformer blocks.
00:56:02.560 | You concatenate these outputs, and then you pass them through another set of transformer blocks.
00:56:09.280 | And then you basically have your final output.
00:56:11.920 | And in this way, you can configure all the different blocks in their own manner.
00:56:19.680 | like the transformer blocks for modality A could be different from the transformer blocks that you would
00:56:27.040 | use for modality B. So this way, you have a greater level of flexibility and control.
00:56:32.960 | And that's what the Lumina 2 work did.
00:56:36.080 | And as we can see, for conditions, they pass it through a separate set of transformer blocks.
00:56:42.400 | For noisy latents, they pass it through a separate set of transformer blocks.
00:56:46.640 | And they end up concatenating them and then again, they have another set of transformer blocks.
00:56:52.400 | So the transformer blocks used for the conditions, on the left-hand side of the big block, can be different from the transformer
00:57:01.760 | blocks used to model the noisy latents.
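A sketch of the Lumina-2-style layout described here, under our own simplifications. Each modality first runs through its own stack, and only then are the sequences concatenated and processed by a joint stack. Here the modality stacks are self-attention blocks whose depths can differ. Depths, widths, and the attention-only blocks are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 32

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def make_block():
    # a self-attention block with its own random stand-in weights
    w = {n: rng.normal(scale=0.1, size=(D, D)) for n in "qkvo"}
    def block(x):
        q, k, v = x @ w["q"], x @ w["k"], x @ w["v"]
        return x + softmax(q @ k.T / np.sqrt(D)) @ v @ w["o"]
    return block

cond_stack = [make_block() for _ in range(2)]    # condition-only blocks
latent_stack = [make_block() for _ in range(1)]  # noisy-latent-only blocks, configured separately
joint_stack = [make_block() for _ in range(3)]   # blocks over the concatenated sequence

def forward(img, txt):
    for blk in cond_stack:
        txt = blk(txt)
    for blk in latent_stack:
        img = blk(img)
    x = np.concatenate([txt, img])
    for blk in joint_stack:
        x = blk(x)
    return x[txt.shape[0]:]

img, txt = rng.normal(size=(64, D)), rng.normal(size=(10, D))
out = forward(img, txt)
```

Because the per-modality stacks are just lists, each one can be given its own depth or block type without touching the others.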
00:57:05.360 | And I want to give you a sense of how we can think of simplifying all of this design.
00:57:18.400 | Because it might feel complicated at this point in time, but I don't know.
00:57:24.080 | Maybe we can simplify it quite a bit.
00:57:26.560 | But I want to quickly see if there are any questions at this point.
00:57:30.880 | Do we have any questions over Zoom?
00:57:33.120 | Yeah, we've got some online questions.
00:57:35.200 | Let me see.
00:57:38.240 | Somebody just asked, is GPT-4o a completely different architecture?
00:57:41.600 | How does that compare with the diffusion transformer?
00:57:44.880 | That's a good one.
00:57:47.120 | I think it uses a hybrid architecture.
00:57:50.720 | It definitely has an LLM component to it,
00:57:53.440 | which was adapted to generate images.
00:57:56.720 | I'm going to come to these sorts of hybrid architectures in a moment.
00:58:00.240 | Okay.
00:58:00.560 | And then someone's wondering for evaluation,
00:58:03.440 | are there good automatic metrics for image and video generation?
00:58:06.720 | Or is it mainly human-based subjective evaluation, similar to creative writing?
00:58:11.520 | I think it's an ensemble of different metrics.
00:58:14.800 | Like you can't evaluate an image on a single metric.
00:58:18.880 | And also it depends on what you are exactly looking for.
00:58:22.320 | Like if you are more interested in compositionality,
00:58:25.600 | the metrics will change.
00:58:29.440 | If you are more interested in aesthetics,
00:58:29.440 | the metrics will change.
00:58:30.560 | So that depends.
00:58:32.240 | Let me see.
00:58:35.440 | What are your thoughts on the vision models?
00:58:37.840 | Someone noted that there's a common perception that image generation models fail to generate things like fingers.
00:58:46.240 | I also see that, you know,
00:58:50.400 | do they still suffer from things like counting and like spatial consistency?
00:58:54.400 | Or are those basic things pretty much like solved?
00:58:57.520 | I think models like Flux definitely do a whole lot better.
00:59:01.600 | And there's a question: what are the major additional challenges of video generation compared to image generation?
00:59:16.800 | And what are the ways to overcome them?
00:59:18.000 | Well, the first problem is the time.
00:59:21.280 | Because diffusion-based image generation models are already compute-intensive.
00:59:26.720 | And then with video generation models, you have another dimension: temporality.
00:59:32.080 | So they just become more compute-intensive.
00:59:35.440 | And then you have more of a curse of dimensionality to address.
00:59:41.360 | And if you are generating more and more frames, they just become more and more compute-intensive.
00:59:47.840 | So how you make them more efficient is definitely the need of the hour.
00:59:52.800 | It also adds things like temporal dependencies.
00:59:55.280 | Yeah, yeah, exactly.
00:59:56.560 | Yeah, another axis of dependency to model.
00:59:59.520 | Right, right.
01:00:00.080 | And you are basically doing full 3D attention, which is extremely prohibitive to even think about.
01:00:15.440 | For video, you don't have to generate every frame from scratch.
01:00:18.560 | You already have something figured out, and then you just make a small change.
01:00:23.120 | Well, I mean, if you take a 2D image generation model and try to inflate it
01:00:31.920 | so that it can also do spatio-temporal generation, it doesn't work out well,
01:00:37.440 | at least in the realistic setting.
01:00:40.400 | If you wanted to generate cinematic frames, it won't work,
01:00:44.080 | because the spatio-temporal consistency just gets lost.
01:00:47.840 | Oh, but there has to be a limit, right?
01:00:55.040 | If you are operating with that many frames as your previous input,
01:01:01.040 | the sequence also becomes extremely prohibitive.
01:01:03.840 | So you have to figure out how you can compress the temporal dimension effectively.
01:01:09.920 | Like, let's say you have a variable number of frames.
01:01:12.640 | Now, how do you map it to the latent temporal level so that it still has some meaning
01:01:18.560 | while being efficient at the same time?
01:01:20.880 | So when you are doing videos, you no longer just have the spatial compression, but also you have temporal compression.
01:01:29.440 | And how do you model the two as you make progress?
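Back-of-the-envelope arithmetic for why temporal compression matters. The setup is hypothetical and chosen only for illustration. It assumes a causal video VAE that keeps the first frame and compresses the remaining frames 4x in time. It also assumes 8x spatial compression and 2 x 2 patchification. None of these factors are tied to a specific model.

```python
def video_tokens(frames, height, width, t_down=4, s_down=8, patch=2):
    # causal temporal compression (assumption): first frame kept, the rest grouped in chunks of t_down
    latent_frames = 1 + (frames - 1) // t_down
    h = height // s_down // patch
    w = width // s_down // patch
    return latent_frames * h * w

image_tokens = video_tokens(1, 480, 720)                # a single frame
video_tokens_4x = video_tokens(49, 480, 720)            # 49 frames, 4x temporal compression
video_tokens_1x = video_tokens(49, 480, 720, t_down=1)  # 49 frames, no temporal compression

print(image_tokens, video_tokens_4x, video_tokens_1x)   # 1350 17550 66150
# full 3D attention cost grows with the square of the token count
print((video_tokens_4x / image_tokens) ** 2)            # 169.0
```

Even with 4x temporal compression, full 3D attention over this clip costs about 169 times as much as attention over a single frame. Without temporal compression the token count would be nearly four times larger again.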
01:01:33.200 | So, are there any other questions?
01:01:38.160 | I'll just take one maybe?
01:01:43.920 | Someone on Zoom asks, can you give some intuition about adaptive layer normalization and why it works so well?
01:01:50.080 | Oh yeah. So, for layer norm, you are basically using it to stabilize training.
01:01:57.760 | I would say you get more stable representations across the blocks.
01:02:04.800 | But for images, there are certain characteristics that you would want to model beyond the standard representations that you are computing.
01:02:12.320 | And you have to let those features flow freely into your transformer blocks.
01:02:20.320 | And that's why you need the modulation parameters.
01:02:22.480 | Otherwise, the normalization parameters and your regular attention features, or whatever MLP features you are computing,
01:02:32.160 | don't get to interact in a way that benefits the generation performance.
01:02:37.200 | If you were dealing with just understanding, or maybe just discriminative performance, it wouldn't have mattered that much.
01:02:44.560 | But if you care about stylistic aspects and fidelity, you need to modulate the additional features that you are getting from the visual cues.
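A minimal NumPy sketch of DiT-style adaptive layer norm (adaLN-Zero) to make the answer above concrete. A SiLU followed by an affine layer maps the conditioning vector to shift, scale, and gate vectors. The normalized features are then scaled and shifted, and the gated branch is added back. Because the affine layer is zero-initialized, every block starts out as the identity; the DiT paper found this initialization to work well. The attention branch here is a hypothetical identity stand-in.

```python
import numpy as np

def silu(x):
    return x / (1 + np.exp(-x))

def modulate(x, shift, scale, eps=1e-6):
    # LayerNorm without a learned affine, then condition-dependent scale and shift
    x = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)
    return x * (1 + scale) + shift

rng = np.random.default_rng(0)
N, D = 16, 32
x = rng.normal(size=(N, D))
c = rng.normal(size=D)                 # conditioning vector, e.g. timestep + class embedding
W = np.zeros((D, 6 * D))               # adaLN-Zero: the affine layer starts at zero
b = np.zeros(6 * D)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = np.split(silu(c) @ W + b, 6)

attn = lambda h: h                     # hypothetical stand-in for the attention branch
x_out = x + gate_msa * attn(modulate(x, shift_msa, scale_msa))
```

At initialization the gate is zero, so `x_out` equals `x`. During training, the condition progressively learns how strongly to shift, scale, and gate each channel.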
01:02:54.320 | Yeah.
01:02:55.440 | Any other questions?
01:02:59.520 | Yeah.
01:03:00.320 | So you mentioned, we don't know yet how to do masking.
01:03:03.440 | For, for mmdit.
01:03:05.200 | Yeah, mmdit.
01:03:06.320 | Yeah.
01:03:06.640 | What's the next best thing?
01:03:09.200 | Is there an alternative to masking?
01:03:11.200 | Maybe that's a hunch, yeah.
01:03:12.320 | Yeah.
01:03:13.840 | So, why would you need masking for MMDiT?
01:03:16.640 | That's a good question to ask.
01:03:18.960 | MMDiT was designed to get away from the unidirectionality bias that you may have in your text embedding representations.
01:03:31.760 | So that was one.
01:03:33.680 | Second, if you have really long prompts, you do not need to compute masks in the first place, right?
01:03:38.560 | And it also turns out that it's always better to have long, descriptive prompts rather than short prompts.
01:03:46.880 | So that is a hand-wavy way to answer this question.
01:03:51.120 | But long story short, it's really non-trivial to add masks when you are doing MMDiT attention.
01:03:57.680 | So there you go.
01:03:59.200 | So that, go ahead.
01:04:00.160 | It's about doing QKV separately on both text and image. How about on the image side?
01:04:06.320 | Yeah, so when you are doing the image-image interaction, you could in theory use masks.
01:04:12.960 | But again, how would you frame the problem?
01:04:16.480 | Where would you end up adding the masks if you are just doing text-to-image?
01:04:20.560 | Maybe if you have another form of conditional control, you would want to mask some of the interactions there.
01:04:28.640 | But if you are restricted to just text-to-image, where would you add the masks in the first place?
01:04:33.200 | Yeah, but what I'm asking is how to avoid masking, you know?
01:04:36.240 | Oh, yeah, yeah.
01:04:37.680 | I mean, probably that's why they didn't do masking in the first place.
01:04:40.400 | Maybe that was their motivation. Yeah, exactly.
01:04:43.760 | I'll maybe take one final question if there are any.
01:04:47.520 | Okay, cool.
01:04:50.400 | So I also want to leave you with some ideas for simplifying the design a bit,
01:04:57.520 | in case these things felt a little complicated, because they do not have to be complicated at all.
01:05:03.360 | We can simplify it quite a bit, and let's see how.
01:05:06.880 | So how useful is parameter sharing?
01:05:12.000 | Things like adaLN, which we saw, and the work on PixArt-α, which reduced the parameter count by about 27%.
01:05:20.560 | How useful is that kind of parameter sharing?
01:05:23.680 | And we saw adaLN already.
01:05:26.240 | Can we also share QKVO and the MLP, like ALBERT?
01:05:31.280 | This is from back in the day: ALBERT, where we shared all the projection matrices across layers.
01:05:38.320 | We also shared the MLP matrices and so on.
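To make the ALBERT-style idea concrete, here is a hypothetical NumPy sketch that counts parameters when every layer owns its QKVO and MLP weights versus when all layers reuse a single block's weights. The block layout (four `dim x dim` projections plus a 4x-wide MLP, no biases) is an assumption for illustration.

```python
import numpy as np

def make_block(dim, rng):
    # One block's weights: Q, K, V, O projections plus a two-layer MLP with 4x hidden width
    return {
        "qkvo": rng.normal(size=(4, dim, dim)) * 0.02,
        "mlp_in": rng.normal(size=(dim, 4 * dim)) * 0.02,
        "mlp_out": rng.normal(size=(4 * dim, dim)) * 0.02,
    }

def count_params(blocks):
    # Count each distinct block object once, so shared weights are not double-counted
    unique = {id(b): b for b in blocks}.values()
    return sum(w.size for b in unique for w in b.values())

rng = np.random.default_rng(0)
dim, depth = 64, 12
unshared = [make_block(dim, rng) for _ in range(depth)]   # every layer has its own weights
shared_block = make_block(dim, rng)
shared = [shared_block] * depth                           # ALBERT-style: all layers reuse one block
```

Sharing divides the weight count by the depth (here 12x), while compute per forward pass stays the same, since every layer still runs its projections.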
01:05:40.960 | And do we really need self-attention and then cross-attention?
01:05:46.560 | Or do we really need MMDiT?
01:05:48.080 | I think this will probably answer your question.
01:05:50.400 | So look forward to that.
01:05:52.160 | So do we really need these things?
01:05:54.640 | Because this, in a sense, complicates the design space quite a bit.
01:05:59.120 | Can we simplify it?
01:06:02.160 | So can we basically do self-attention on this concatenated space?
01:06:06.320 | We do not have to do self-attention on the noisy latents.
01:06:10.080 | We do not have to do cross-attention between the noisy latents and text.
01:06:13.680 | We do not have to do MMDiT.
01:06:15.360 | Can we basically concatenate all the image tokens and the text tokens and compute self-attention on it?
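A minimal sketch of self-attention over the concatenated image and text tokens, using one shared set of Q/K/V projections for both modalities (unlike MMDiT, which keeps separate per-modality projections). It is single-head, with no masking or normalization, and `joint_self_attention` is a hypothetical helper.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def joint_self_attention(img_tokens, txt_tokens, Wq, Wk, Wv):
    """Single-head self-attention over the concatenation [image; text]."""
    x = np.concatenate([img_tokens, txt_tokens], axis=0)  # (N_img + N_txt, dim)
    q, k, v = x @ Wq, x @ Wk, x @ Wv                      # the same weights for both modalities
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))        # every token attends to every token
    out = attn @ v
    n_img = img_tokens.shape[0]
    return out[:n_img], out[n_img:], attn                 # split back into image and text parts

rng = np.random.default_rng(0)
dim = 16
img = rng.normal(size=(6, dim))   # e.g. 6 noisy-latent patch tokens
txt = rng.normal(size=(3, dim))   # e.g. 3 text-encoder tokens
Wq, Wk, Wv = (rng.normal(size=(dim, dim)) for _ in range(3))
img_out, txt_out, attn = joint_self_attention(img, txt, Wq, Wk, Wv)
```

Because the attention matrix spans both modalities, image-to-text, text-to-image, and within-modality interactions all come out of a single operation, which is what removes the need for separate self- and cross-attention layers.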
01:06:21.840 | Can we do it?
01:06:24.880 | And for text encoders, what's the secret sauce?
01:06:27.440 | Some models use three.
01:06:28.560 | Some use two.
01:06:29.520 | What helps?
01:06:30.720 | What helps the cause of text to image generation at the end of the day?
01:06:35.040 | It turns out you can simplify things a lot.
01:06:38.160 | So one thing you can do is share parameters quite a bit.
01:06:42.800 | You can share the QKVO parameters if you are looking for efficiency.
01:06:46.800 | You can definitely share the adaptive layer norm parameters.
01:06:52.080 | And you can basically do self-attention on a concatenated
01:06:54.960 | representation space of image tokens and text tokens.
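The savings from sharing adaptive layer norm parameters are easy to estimate. The sketch below loosely follows the adaLN-single idea from PixArt-α: it compares a DiT-style design, where every block has its own conditioning-to-modulation projection, with one shared projection plus a small learnable per-block embedding. The sizes (28 blocks, hidden size 1152, six modulation vectors per block) mirror DiT-XL but are assumptions here, and the count covers only the adaLN projections.

```python
import numpy as np

def adaln_param_counts(depth, dim, cond_dim, n_mod=6):
    """Count adaLN parameters: per-block projections vs. one shared projection.

    n_mod=6 covers shift/scale/gate for both the attention and MLP sublayers.
    """
    proj = cond_dim * n_mod * dim + n_mod * dim   # weight + bias of one projection
    per_block = depth * proj                      # DiT-style: one projection per block
    shared = proj + depth * n_mod * dim           # one shared projection + per-block embeddings
    return per_block, shared

def block_modulation(t_emb, W_shared, b_shared, block_embed):
    # Global modulation computed once per step, then offset by this block's learnable embedding
    return t_emb @ W_shared + b_shared + block_embed

per_block, shared = adaln_param_counts(depth=28, dim=1152, cond_dim=1152)
print(per_block, shared)  # 223147008 8163072
```

Under these assumptions the shared variant needs about 8.2M adaLN parameters instead of about 223M.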
01:06:59.040 | So this is good news, I think, but only partly, because Apple didn't open-source this work.
01:07:04.880 | So we'll have to do our own.
01:07:06.560 | But this is good to know that you can simplify the design quite a bit.
01:07:10.960 | So the extreme right hand side is the simplified design that I was talking about.
01:07:15.120 | So you are basically operating on a concatenated space.
01:07:20.400 | You are reducing the adaptive layer norm parameters quite a bit.
01:07:24.800 | And then you are sharing the parameters of the QKV projections as well as the MLP layers.
01:07:29.920 | So yeah.
01:07:30.800 | And it turns out to work fairly well in practice, I must say.
01:07:38.160 | So the green one is the simplified design, and it consistently gets the lowest loss.
01:07:45.280 | So yeah.
01:07:46.880 | And as I was mentioning, other than sharing the adaptive layer norm parameters, you can also
01:07:54.880 | share the QKV parameters as well as the MLP,
01:07:59.280 | if you can compromise on quality a bit and are mainly targeting efficiency.
01:08:04.160 | And for text encoders, the Apple folks found that a bidirectional CLIP and a text-only LLM work better.
01:08:12.880 | So instead of using a combination of CLIP and T5, if you use CLIP and a regular large language model,
01:08:19.680 | you get better text representations.
01:08:21.280 | So the performance is pretty good.
01:08:25.200 | The final row is Apple's work.
01:08:28.240 | It's called DiT-Air.
01:08:29.120 | So it turns out fairly well.
01:08:32.080 | All the bold numbers are from DiT-Air.
01:08:36.480 | So it turns out pretty well.
01:08:38.240 | So, long story short, you can simplify the design quite a bit.
01:08:41.360 | As it turns out, you do not need MMDiT, but we will see, I guess.
01:08:44.800 | Now, one burning question you must be having at this point in time is, text-to-image is liberating,
01:08:53.280 | but how do I inject more control, more sources of control?
01:08:58.800 | Like, if I want to feed more structural inputs to my text-to-image model, how do I do that?
01:09:05.760 | Maybe I want the model to follow a particular pose.
01:09:09.200 | Maybe I want the model to follow a particular segmentation map and so on.
01:09:13.600 | And how do I do that?
01:09:14.960 | And can I combine multiple structural signals at the same time?
01:09:18.720 | So one option is to learn an auxiliary network that computes
01:09:28.960 | salient representations from your structural image signals, and then figure out a way to inject
01:09:35.040 | those structural signals into your base diffusion model.
01:09:40.480 | So that could be one.
01:09:41.600 | ControlNet and that line of work basically follow this philosophy.
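A toy sketch of the ControlNet philosophy: a trainable copy of a base block processes the structural signal, and its output is injected back into the base stream through a zero-initialized projection, so training starts from exactly the base model's behavior. The single `tanh` block stands in for a transformer block, and all names here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 16

def base_block(h, W):
    # Stand-in for one frozen block of the base diffusion model
    return np.tanh(h @ W)

W_base = rng.normal(size=(dim, dim)) * 0.1
W_ctrl = W_base.copy()          # trainable copy of the base block's weights
W_zero = np.zeros((dim, dim))   # zero-initialized output projection of the control branch

def controlled_forward(h, cond_feat):
    out = base_block(h, W_base)                  # frozen base path
    ctrl = base_block(h + cond_feat, W_ctrl)     # control branch also sees the structural signal
    return out + ctrl @ W_zero                   # residual injection into the base stream

h = rng.normal(size=(4, dim))          # e.g. 4 noisy-latent tokens
cond_feat = rng.normal(size=(4, dim))  # e.g. encoded pose or segmentation features
```

Only `W_ctrl` and `W_zero` would be trained, so the base model stays untouched and the control branch can be added or removed.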
01:09:49.680 | And then maybe you could also change the base diffusion transformer model.
01:09:54.800 | You basically increase its input channels to accept more controls.
01:10:00.640 | And that's what the Flux Control framework does.
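Widening the input channels can be sketched as concatenating the control latents with the noisy latents along the channel axis and padding the input projection with zero rows, so the new channels have no effect at initialization. This is a hedged NumPy illustration of the general technique; the sizes are assumptions, and it does not reproduce Flux Control's actual code.

```python
import numpy as np

def expand_input_projection(W_in, extra_channels):
    """W_in: (in_channels, dim). Append zero rows so extra control channels start inert."""
    zeros = np.zeros((extra_channels, W_in.shape[1]), dtype=W_in.dtype)
    return np.concatenate([W_in, zeros], axis=0)

rng = np.random.default_rng(0)
n_tokens, in_ch, dim = 10, 64, 32
W_in = rng.normal(size=(in_ch, dim))            # pretrained input projection (assumed sizes)
noisy = rng.normal(size=(n_tokens, in_ch))      # packed noisy latent tokens
control = rng.normal(size=(n_tokens, in_ch))    # e.g. packed latents of a depth or edge map
W_new = expand_input_projection(W_in, extra_channels=in_ch)
x = np.concatenate([noisy, control], axis=-1) @ W_new  # channel-wise concatenation
```

Because the control latents share the spatial layout of the noisy latents, this works for structural signals with direct spatial correspondence.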
01:10:05.840 | And maybe you could also learn a small adapter network to model the dependencies between all your conditions and your noisy latent tokens.
01:10:17.920 | And one important call-out here: structural control will always have spatial correspondence,
01:10:25.200 | but other tasks, like subject-driven generation or image editing, may not have direct spatial correspondence.
01:10:32.240 | So what do we do in those cases?
01:10:34.160 | And this is, again, an active area of research.
01:10:38.240 | And for videos, well, RoPE works fairly well for positional encoding.
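For video, RoPE is commonly extended to three axes by splitting each attention head's dimensions between time, height, and width, and rotating each slice by that axis's position. The sketch assumes a 64-dim head split 16/24/24 and an interleaved pairing convention. Both are assumptions, and real implementations differ on them.

```python
import numpy as np

def rope_1d(pos, dim, base=10000.0):
    # Rotation angles for one axis: dim/2 frequencies per position
    freqs = base ** (-np.arange(0, dim, 2) / dim)   # (dim/2,)
    return np.outer(pos, freqs)                      # (n, dim/2)

def apply_rotary(x, angles):
    # Rotate consecutive pairs (x0, x1), (x2, x3), ... by the given angles
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = np.cos(angles), np.sin(angles)
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_3d_angles(T, H, W, dims=(16, 24, 24)):
    """Split the head dim across (time, height, width); the split is a hypothetical choice."""
    t, h, w = np.meshgrid(np.arange(T), np.arange(H), np.arange(W), indexing="ij")
    t, h, w = t.ravel(), h.ravel(), w.ravel()        # flatten to T*H*W tokens
    return np.concatenate([rope_1d(t, dims[0]), rope_1d(h, dims[1]), rope_1d(w, dims[2])], axis=-1)

angles = rope_3d_angles(T=2, H=3, W=4)               # (24, 32): 32 rotation pairs per token
q = np.random.default_rng(0).normal(size=(24, 64))
q_rot = apply_rotary(q, angles)
```

Rotations preserve vector norms, and the token at position (0, 0, 0) is left unchanged.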
01:10:45.520 | And as I was mentioning, attention computation becomes way more prohibitive because it's full 3D attention.
01:10:52.800 | And if you were thinking about some form of factorized attention, it doesn't work.
01:10:56.720 | I have been very explicit about it.
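Some rough arithmetic shows why full 3D attention gets expensive. The sketch assumes a VAE with 8x spatial and 4x temporal compression (keeping the first frame) followed by 2x2 patchification. These settings are assumptions for illustration, not a specific model's configuration.

```python
def video_tokens(frames, height, width, t_down=4, s_down=8, patch=2):
    """Assumed causal-VAE setting: first frame kept, remaining frames compressed by t_down."""
    latent_t = (frames - 1) // t_down + 1
    return latent_t * (height // s_down // patch) * (width // s_down // patch)

def image_tokens(height, width, s_down=8, patch=2):
    return (height // s_down // patch) * (width // s_down // patch)

img = image_tokens(1024, 1024)       # 64 * 64 = 4096 tokens
vid = video_tokens(81, 480, 832)     # 21 * 30 * 52 = 32760 tokens
ratio = (vid / img) ** 2             # attention cost grows quadratically with token count
```

Under these assumptions, an 81-frame 480x832 clip has about 8x the tokens of a 1024x1024 image and roughly 64x as many attention scores per layer.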
01:10:59.280 | And for efficiency, if you are particularly interested in the efficiency-related literature around video models,
01:11:07.200 | I highly suggest checking out the LTX-Video work.
01:11:10.480 | And for performance, like fidelity, photorealism, and so on, I think Wan has to be my favorite.
01:11:17.040 | It's pretty good.
01:11:19.040 | I have a little demo here.
01:11:21.520 | It's not that bad.
01:11:23.600 | So it's still short-form video, but the realism and fidelity of these videos,
01:11:30.640 | I think, have improved quite a bit.
01:11:32.160 | And some next-generation architectures.
01:11:37.360 | Now I'm coming towards the end of my talk.
01:11:39.280 | I wanted to also give you a flavor and a sense of next-generation architectures.
01:11:45.200 | You probably have this question, how do we enable in-context learning in diffusion models?
01:11:49.360 | In language models, it's very common: zero-shot learning and so on.
01:11:53.040 | How do we enable in-context learning in diffusion models?
01:11:56.400 | Current architectures are clearly not sufficient to enable it in a reasonable time.
01:12:01.760 | So the basic idea is you take an LLM and add some components so that it also becomes adept at generating images.
01:12:09.440 | Recent works like BAGEL, and then works like Lada, Mada, basically follow this line of work.
01:12:16.720 | They are not necessarily diffusion-based, but they share a similar philosophy.
01:12:22.000 | So you basically start from a pre-trained LLM, you add components to it so that it also becomes adept at generating images.
01:12:29.120 | For example, you could do autoregression on discrete text tokens and, at the same time, diffusion on continuous image tokens.
01:12:36.960 | So that's one line of work that's called Transfusion.
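The Transfusion recipe trains one model with two objectives at once: next-token prediction on the text and a diffusion loss on the image latents, summed with a weighting coefficient. Here is a minimal NumPy sketch, assuming epsilon-prediction MSE for the diffusion term. `lam` is a hypothetical weight, not a value from the paper.

```python
import numpy as np

def cross_entropy(logits, targets):
    # Next-token prediction loss on discrete text tokens (numerically stable log-softmax)
    logits = logits - logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

def diffusion_mse(pred_noise, true_noise):
    # Simple epsilon-prediction loss on continuous image latents
    return np.mean((pred_noise - true_noise) ** 2)

def mixed_modal_loss(text_logits, text_targets, pred_noise, true_noise, lam=1.0):
    # One model, two objectives: L = L_LM + lam * L_diffusion
    return cross_entropy(text_logits, text_targets) + lam * diffusion_mse(pred_noise, true_noise)
```

In a real model both terms come from the same transformer forward pass over an interleaved sequence; here the predictions are passed in directly to keep the sketch small.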
01:12:39.520 | So, as I was mentioning: Playground v3, FuseDiT, Transfusion.
01:12:45.920 | FuseDiT is from our group.
01:12:47.440 | It's probably the only open-source work you will find that tackles this problem of in-context learning and explores these architectures in a holistic manner.
01:13:01.520 | And I hope you are already feeling inspired to explore these architectures.
01:13:08.480 | If so, I highly encourage you to check out the library that I work on at Hugging Face.
01:13:12.240 | It's called Diffusers.
01:13:13.120 | We have got reasonably clean implementations of all the models that I discussed today,
01:13:18.800 | which should probably inspire you to hack into these things and tweak them.
01:13:26.160 | And there are a bunch of things I didn't get to cover, of course.
01:13:29.200 | I am a little bit over time.
01:13:31.120 | I think one minute over time.
01:13:32.320 | I'm probably going to finish it in the next one minute.
01:13:35.600 | So I apologize.
01:13:37.120 | I didn't cover MoEs, a hot topic again.
01:13:40.720 | But MoEs are making their way into the diffusion community.
01:13:45.120 | One example is HiDream.
01:13:46.080 | I didn't cover training at all.
01:13:48.320 | There are all kinds of shenanigans there.
01:13:51.040 | How do you do data?
01:13:52.080 | How do you do alignment?
01:13:53.200 | How do you do post-training?
01:13:54.400 | How do you do safety mitigation?
01:13:56.240 | How do you do memorization mitigation?
01:13:57.920 | And so on.
01:13:58.480 | And of course, these architectures go well beyond image and video generation.
01:14:03.600 | They find their application in all sorts of stuff, like robotics, gene synthesis, and so on.
01:14:10.880 | And if you are into mechanistic interpretability, interpreting these beasts is not trivial at all.
01:14:17.680 | So they interest me quite a bit.
01:14:19.440 | And yeah, that's about it.
01:14:23.440 | Thank you.