Lesson 9B - the math of diffusion
Chapters
0:00 Introduction
2:19 Data distribution
6:38 Math behind lesson 9’s “Magic API”
18:50 CLIP (Contrastive Language–Image Pre-training)
27:04 Forward diffusion (Markov process with Gaussian transitions)
36:11 Likelihood vs log likelihood
42:16 Denoising diffusion probabilistic model (DDPM)
48:04 Conclusion
00:00:00.000 |
Hello, everyone. My name is Wasim. I am an entrepreneur in residence at fast.ai and I'm 00:00:10.920 |
currently at the fast.ai headquarters in Australia, although I'm originally 00:00:17.240 |
from South Africa, from Cape Town. And I'm joined here today by Tanishq. Tanishq works 00:00:25.080 |
at Stability AI. And we've been working together with a couple other people on diffusion models, 00:00:33.080 |
generative kind of modeling. And that's been super fun. So, Tanishq, do you want to maybe, 00:00:39.400 |
you know, introduce yourself as well? Yeah, so my name is Tanishq. I am a PhD student 00:00:46.600 |
at UC Davis, but I also work at Stability AI. And I've been exploring and playing around 00:00:52.560 |
with diffusion models for the past several months. And so, it's been great to also explore 00:01:00.920 |
that with the fast.ai community in these last few weeks as well. Awesome. Cool. 00:01:07.920 |
Cool. So this talk is for me trying to understand the math behind diffusion. So, you know, if 00:01:18.440 |
you've done the fast.ai courses before, you know that you don't need to understand the 00:01:22.720 |
math to be effective with any of these models. In fact, you don't even need the math to do 00:01:28.120 |
research, to do novel research and contribute to these. But for me, it was all about, it 00:01:34.920 |
came out of interest. And, you know, I thought it was kind of beautiful how 00:01:42.140 |
diffusion models were discovered. And I think a large part of that was thanks to 00:01:48.320 |
some really clever math. And so I wanted to understand that. I don't have 00:01:55.280 |
a math background. And so I want to help kind of describe how I think about it and how I, 00:02:07.920 |
how you can kind of interpret, you know, all of these notations and things. Cool. Yeah, 00:02:16.480 |
so I can just dive into it, I think. So the first bit of math that we see in this paper 00:02:25.720 |
is Q of X superscript zero. And they call this the data distribution. 00:02:35.200 |
Do you want to mention what the paper, exactly which paper this is? 00:02:38.640 |
Good, good question. So this paper is the 2015 paper. Do you remember the authors of 00:02:45.280 |
that paper, Tanishq? I think it's Jascha Sohl-Dickstein, who now works at Google. 00:02:53.800 |
And it's from Surya Ganguli's lab. So cool. Yeah, so this was the paper, as far as I understand, 00:03:01.080 |
that introduced this idea of diffusion. Yeah, 2015 by those authors. They start out by defining 00:03:11.600 |
this data distribution and they use this notation. And already, like a lot of people, you know, 00:03:17.080 |
myself included, find this quite confusing. But let's go through what's described here. 00:03:23.760 |
So they have an X. And, you know, in math, X is often used as the input variable, much 00:03:36.840 |
like Y, which is then used often as the output variable. Yeah, and the fact that it has a 00:03:50.920 |
superscript also implies something. So the fact that we have X superscript zero implies 00:04:00.360 |
that there might be a sequence of Xs. And, you know, I think it's useful to get comfortable 00:04:06.960 |
with this idea of simple compact notations implying a lot more than, you know, might be 00:04:16.760 |
obvious at first glance. So X implies something about this quantity: it's 00:04:22.720 |
an input variable. And, you know, the zero implies that there might be other things that 00:04:26.960 |
you might have. You might have an X1, an X2, and so on; we'll see that later. And then 00:04:36.600 |
the third part is you have Q. And Q is what we call a probability density function. So 00:04:49.840 |
the first part here is probability. And the question is, you know, what does Q have to 00:04:55.200 |
do with probabilities? Well, it's because usually we use the letter P to describe probability 00:05:03.040 |
density functions of interest. And then because Q is right after that, it's another common 00:05:08.480 |
one. So it's kind of like how you use X and Y. We use P and Q. And the fact that we use 00:05:14.400 |
Q here instead of P is because it suggests that there might be a P that we'll introduce. 00:05:21.160 |
And maybe P is the thing that we're modeling and Q is kind of supplementary to that. Does 00:05:28.040 |
that sound right, Tanishq? Yeah, yeah. And I think it's also helpful to kind of maybe 00:05:33.520 |
think about like X0 in a more practical, concrete way. Of course, if we're working with images 00:05:40.400 |
then X0 would be, you know, that's what's representing the images. So it's also useful 00:05:45.840 |
to think about it from kind of that concrete practical approach as well. 00:05:50.120 |
Right. So X0 might be, you know, an MNIST digit. And then we got Q. So Q, I'll just 00:06:02.240 |
use this to mean Q is some function. So we look at it as a box. And it takes in X0 and 00:06:19.440 |
it gives us the probability that this X0, which is an image, looks like an MNIST digit. So 00:06:28.360 |
in this case, you know, this would be 0.9 or maybe even, yeah, it's 0.98. So this is quite 00:06:35.600 |
high probability that this is an MNIST digit. Hi, this is Jeremy. Can I jump in for a moment? 00:06:41.840 |
Please do. Oh, thank you. I just wanted to double check. This looks a lot like the magic 00:06:47.920 |
API that we had at the start of lesson one that you feed in a digit and it gives you 00:06:52.000 |
back a probability. Is that basically what Q is doing here? 00:06:56.080 |
Absolutely, yeah. It's a magic API. That's a good way to think of it. We don't know, 00:07:02.080 |
we couldn't write down what Q is, but we imagine that somebody has it somewhere. Yeah, so this 00:07:10.160 |
is a concrete example. And like, if you had to do something to this image, you might get 00:07:17.400 |
a smaller number here. So another thing worth mentioning here is probability density functions. 00:07:26.960 |
So these are these magic APIs that, you know, give us a number, tells us how likely the 00:07:32.080 |
thing is. They, you don't often see them, they don't often make it all the way to your 00:07:38.960 |
code. In fact, they very rarely will appear in your code. But it turns out that there 00:07:45.720 |
are very useful ways or tools to work with random quantities, because they allow you 00:07:52.840 |
to represent random quantities as functions, just ordinary functions. And because they're 00:07:59.800 |
functions, you have a whole century's worth of math to analyze and understand them. So 00:08:08.440 |
you'll often find probability density functions in papers, and eventually they work out to 00:08:13.960 |
really simple equations or formulas that end up in your code. Do you want to add anything 00:08:23.040 |
Tanishq? I think that sounds all correct. Of course, I think you probably will go over 00:08:29.880 |
some examples of probability density functions, especially relevant to this one. But yeah, 00:08:35.700 |
it's useful to think about the, also the sorts of functions you may have in a simplified 00:08:41.240 |
case. And that's what we probably are going to talk about next, right? 00:08:45.360 |
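To make the "magic API" idea concrete before moving on, here is a small sketch (plain Python, not from the lesson notebooks) of the simplest probability density function we'll need: a function that takes a value and returns how likely that value is under a 1-D normal distribution.

```python
import math

def gaussian_pdf(x, mu=0.0, sigma=1.0):
    """Density of N(mu, sigma^2) evaluated at x: a function that takes a
    value and returns how likely that value is under the distribution."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2 * math.pi))

# High density near the mean, tiny density out in the tails.
print(gaussian_pdf(0.0))  # ~0.399 for a standard normal at its mean
print(gaussian_pdf(4.0))  # ~0.00013, far out in the tail
```

The density function for images is far too complicated to write down like this, but the interface is the same: value in, likeliness out.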
Yeah, yeah, that's exactly what we'll talk about next. So we have this Q of X0, and then we introduce 00:08:51.520 |
another one. And like you said, this is going to turn out to have a really nice simple form. 00:08:59.840 |
But before that, the next thing we define is QXT, given XT minus one. So we'll say what 00:09:10.760 |
we define this to be, but to begin with, this is another probability density function. And 00:09:18.320 |
this bar over here means it's a conditional probability density function, which you can 00:09:23.920 |
think of as you are given the thing on the right to calculate probabilities over the 00:09:31.880 |
thing on the left. In this case, you can think of it as something that takes images. So maybe 00:09:43.880 |
another magic API and produces other images. But we don't know what these look like yet 00:09:56.760 |
because we haven't defined over here. And this, we would call XT minus one, which could 00:10:02.880 |
be X0. And this would be XT, which in the X0 case would be X1. Something worth noting 00:10:15.120 |
here is this notation can be a little bit confusing because we've said Q is one thing 00:10:22.760 |
earlier and now we think Q is another thing. So here, I'm going to need your help 00:10:27.840 |
on this one, Tanishq. I think people would usually, in the strictest sense, define the 00:10:38.360 |
first one like this, maybe, and the second one with a subscript. And this notation that 00:10:56.600 |
we see here on the left is just a shortcut where they wanted to save the space of writing 00:11:06.960 |
that and included that, implied it by what was in the parentheses. Is that true? 00:11:13.840 |
Yeah. I think they used the variables Q and then, of course, later on, we see P to describe, 00:11:23.360 |
as we'll see, different aspects of the diffusion model and the different processes of the diffusion 00:11:30.600 |
model, which we'll see. So I think that's why they use the same variables to kind of 00:11:36.680 |
demonstrate this is corresponding to this process and the other variable corresponds 00:11:41.160 |
to the other process of the diffusion model. So we'll obviously go over that. So I think 00:11:46.120 |
that's where those variables or those letters are being used in that matter. But if you 00:11:53.640 |
do want to make it more specific, more clear, yeah, I think that notation is fine as well. 00:12:00.640 |
Okay. Yeah, that makes sense. Okay. So let's describe what this Q does to the image on 00:12:10.280 |
the left to produce the one on the right. So I'll start over here so we have more space. 00:12:21.880 |
I'll write it out first and then we can go into the details. Okay. So kind of like the 00:12:44.240 |
bar, you can think of this semicolon as grouping things together. And so you have the things 00:12:52.680 |
on the left and the things on the right. My understanding is these two things on the right 00:12:59.760 |
are the parameters of the model, sorry, of the probability. And the thing on the left 00:13:09.960 |
is actually, Tanishq, could you help me understand what the thing on the left is? Do you know? 00:13:16.560 |
Right. Well, so this is, again, like a probability distribution. And the thing on the left is 00:13:22.200 |
saying this is a probability distribution for this particular variable. So that's just 00:13:27.240 |
representing what it is a probability distribution for. And then the stuff on the right are the 00:13:32.480 |
parameters for this probability distribution. So that's kind of what's going on here. So 00:13:39.600 |
like, yeah, anytime you have a normal distribution and it's describing some variable, you'll 00:13:44.920 |
have that sort of notation where it's the normal distribution of some variable. And 00:13:49.960 |
then these are the parameters that describe that normal distribution. 00:13:58.040 |
Right. So just to clarify, the bit after the semicolon is the bit that we're kind of used 00:14:04.400 |
to seeing to describe a normal distribution, which is the mean and variance of the normal 00:14:12.040 |
distribution. So we're going to be sampling random numbers from that normal distribution 00:14:19.920 |
according to that mean and that variance. Is that right? 00:14:27.280 |
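For reference, the transition being written out on the whiteboard, in the 2015 paper's notation, is:

```latex
q\left(x^{(t)} \mid x^{(t-1)}\right)
  = \mathcal{N}\!\left(x^{(t)};\; x^{(t-1)}\sqrt{1-\beta_t},\; I\,\beta_t\right)
```

The variable before the semicolon is what the distribution is over; the two parameters after it are the mean (the previous image scaled by $\sqrt{1-\beta_t}$) and the covariance ($I\beta_t$), both of which are discussed below.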
So we need to describe a bit more there about normal distribution. We kind of skip past 00:14:36.520 |
that. So we have this fancy N and fancy letters in math for distributions usually refer to 00:14:44.760 |
well-known distributions. And the N here stands for normal, which is also known as the Gaussian 00:14:56.040 |
distribution. And it's probably the most well-known probability distribution that you can find. 00:15:04.560 |
And when I say well-known, I mean that these things pop up everywhere. You know, you can 00:15:12.440 |
do in all sorts of fields, measuring all sorts of things, turns out that they follow roughly 00:15:19.200 |
something that looks like this distribution. And because they pop up so much, you 00:15:25.840 |
know, people studied them, studied all of their properties, and we understand them really 00:15:31.720 |
well now. The reason that they're used often in cases like this is because it turns out 00:15:39.760 |
they have really useful properties and they're easy to work with. Some reasons are they're 00:15:46.360 |
described by just two parameters, called the mean and the covariance. Another 00:15:55.000 |
property is that they have what people would call thin tails, which kind 00:16:02.400 |
of means that you only need to describe their behavior in a small region 00:16:07.440 |
of space. You can kind of just ignore the rest. Yeah. Do you mind drawing a quick example 00:16:17.760 |
of a normal distribution? That's a good point. So we have, let's say our random variable 00:16:24.440 |
is just one kind of dimensional. So just a single number of floats. This is sort of what 00:16:32.000 |
the normal distribution would look like. And in this case, that would be our mean. And 00:16:41.480 |
the variance would sort of describe the width over here, which in this case, you'd use a 00:16:51.120 |
small sigma because you're doing a single variable. In our case, we used a capital sigma, 00:16:58.880 |
which is the symbol for multiple variables or multiple dimensions. And yeah, I also didn't 00:17:08.320 |
say that this is the Greek letter mu. So capital sigma, mu, and lowercase sigma. I just wanted 00:17:18.640 |
to note that typically the lowercase sigma represents the standard deviation, which is 00:17:25.840 |
the square root of the variance. So for example, sometimes you may see in papers sigma squared, 00:17:35.920 |
and that's just the variance, but they will write it sometimes as sigma squared instead. 00:17:41.920 |
So it depends on the notation. So sigma is the standard deviation often, and sigma squared 00:17:49.800 |
would then be the variance. Cool. Yeah, we can also show with our example what this would 00:17:57.800 |
look like. So we start out with a MNIST digit, put it through this magic API, and what would 00:18:10.800 |
we get out? Okay, so something we didn't describe is what does this I mean? 00:18:21.800 |
Did you want me to talk about that, Wasim? Yes, please. Okay, sure. Because I think this 00:18:27.040 |
is something which actually-- can I borrow your pen? It actually came up in the lesson 00:18:32.980 |
we were doing kind of in an interesting way. So in that lesson-- 00:18:38.040 |
Do you want to get in the video? Ah, no, they know what I look like. Oh, well, okay, I'm 00:18:43.760 |
in the video now. Yeah, in the video-- Hi, Tanish, nice to see you. Yeah, so in the lesson, 00:18:53.000 |
we did this thing for clip, I don't know if you remember, where we had the various pictures 00:19:01.880 |
down here. I'm so embarrassed, you're better at the graphics tablet than I am, and it's 00:19:06.000 |
my graphics tablet. And we had the various sentences along here. And we said, oh, it 00:19:11.800 |
would be kind of cool to take the dot product of their embeddings. Because if their dot 00:19:17.440 |
products are high, that means they're similar to each other. And if we subtracted the means 00:19:25.920 |
from those first, then you've got the dot-- and instead of having images down here, what 00:19:35.280 |
if we had the exact same vectors on each side? Then what you've got down here is basically 00:19:49.920 |
x minus its average, if it's a check that first, squared. And that is the variance. 00:20:00.480 |
So that's the variance for each one of these vectors. But what's interesting, as you pointed 00:20:08.720 |
out, is that normally, at high school, when we look at a normal distribution, it looks 00:20:13.840 |
like this, right? But you're not just doing one normal distribution. You've got a whole 00:20:18.260 |
bunch of normal distributions for all of your different pixels. They're the pixels, right, 00:20:24.200 |
Tanishq? A normal distribution for every pixel. So there's a whole bunch of them. And so one 00:20:29.680 |
of them might have a normal distribution that's there. And another one might have a normal 00:20:33.080 |
distribution that's here. And another one might have a normal distribution that's here. 00:20:38.320 |
And it's more than that, though, because it's possible that one pixel tends to be higher 00:20:45.960 |
when another pixel tends to be higher, or one pixel tends to be higher when another 00:20:50.160 |
pixel's lower. So it actually has kind of created this surface in n-dimensional space 00:20:58.200 |
where n is the number of pixels. So if you now, like, look at, like, OK, well, what happens 00:21:03.040 |
if we multiply this by this, just like we did in CLIP, right? Then if this number is 00:21:10.380 |
high, then it's saying that when this variable is high, where this pixel's high, this pixel 00:21:15.360 |
tends to be high, and vice versa. Or if it's low, it's saying when this pixel tends to 00:21:19.800 |
be high, this one tends to be low. Or, interesting to us, what happened-- oopsie, Daisy, sorry 00:21:25.760 |
-- what happens if this is zero? That says that if this is high, then this could be anything. 00:21:37.760 |
Or if this is high, this could be anything. There's no relationship between them. So statistically, 00:21:42.120 |
we would say that these two pixels are independent. And so now, that basically means we could 00:21:50.000 |
do that for all of these. We could say, oh, these are all zeros. And what that says is 00:21:57.880 |
that, oh, every pixel is independent of every other pixel. Now, of course, in real pictures, 00:22:03.160 |
that's not how real pixels work. But that's the assumption we're making. Because if we 00:22:08.640 |
start with a very special matrix called I, which has 1s down the diagonal and 0s everywhere else. If we take this 00:22:22.520 |
very special matrix-- it's very special because I can multiply it by something, say beta. And 00:22:30.360 |
if I multiply it by a matrix, I get back the original matrix. If I multiply it by a scalar, 00:22:35.240 |
I'm going to get beta, beta, beta, dah, dah, dah, dah, and lots of zeros. And so if I multiply 00:22:42.240 |
something by this matrix, then I'm just multiplying it by beta. But what's interesting about this 00:22:50.120 |
is that this is what Wasim wrote. Wasim wrote i times beta, i times beta t. So what he's 00:22:58.440 |
saying is, oh, we've now got a covariance matrix where for each individual pixel, it's 00:23:07.400 |
like pixel number 1, beta 1, pixel number 2, beta 2. This is the variances of each one. 00:23:14.520 |
And the covariances, the relationship between the pixels is 0. They're expected to be independent. 00:23:22.600 |
So that's where we're going from statistics you do in high school to statistics you do 00:23:28.320 |
at university. It's like suddenly covariances are now matrices, not individual numbers. 00:23:41.160 |
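To make that "I times beta" covariance concrete, here's a small sketch (plain Python, with a made-up beta and a made-up tiny image size) of what the matrix actually looks like:

```python
beta = 0.1
n = 3  # pretend our "image" has just 3 pixels

# Sigma = I * beta: multiplying the identity matrix by the scalar beta gives
# variance beta for every pixel on the diagonal, and zero covariance
# (i.e. independence) between every pair of different pixels.
cov = [[beta if i == j else 0.0 for j in range(n)] for i in range(n)]
for row in cov:
    print(row)  # [0.1, 0.0, 0.0] / [0.0, 0.1, 0.0] / [0.0, 0.0, 0.1]
```

The zeros off the diagonal are exactly the "these pixels are independent" assumption Jeremy describes.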
Awesome. Cool. So now let's try to describe what this would do to MNIST digits. So let's 00:23:52.240 |
put back our mean equation and our covariance. So mean and our covariance. And let's look 00:24:11.400 |
at how this behaves at the edges. So it's really hard to understand this. I don't think 00:24:19.220 |
anybody can just look at this and know what it means. What we typically do is we try to 00:24:25.760 |
describe it at the edges. And so we'll start with what happens if that's 0. And we'll work 00:24:33.440 |
with x0 as well instead of xt minus 1, which would mean an MNIST digit. So if beta is 0 00:24:42.520 |
then we get our x0, you know, square root 1 minus 0, which is 1 and square root of 1 00:24:51.320 |
is 1. So that kind of falls away. So we just have a mean of our previous image. And this 00:24:59.040 |
is just variance of 0. So we have a normal distribution with a mean of our previous image 00:25:06.800 |
of variance of 0, which means we have the same image. 00:25:18.000 |
Yeah, just to clarify, when you have variance of 0, that means that there's really no noise 00:25:26.200 |
or anything. It's just at that mean and, you know, your distribution is just saying that's 00:25:31.400 |
the only point that you can get from it. So yeah, that just becomes the same image because 00:25:35.800 |
yeah, there's no noise or variance because the variance is 0. 00:25:41.800 |
Yeah, exactly. And then when our beta is 1, we still have this and then we have, you know, 00:25:51.880 |
square root 1 minus 1 and that becomes 0. So this whole thing becomes 0. And this thing 00:26:02.800 |
becomes i times beta t, which is, you know, i. And if it's just i, then as Jeremy described, 00:26:11.480 |
it would, you know, imply a variance of 1. And so our image through this function would 00:26:25.040 |
just be pure noise. So, you know, mean of 0, standard deviation of 1, and it would 00:26:33.880 |
just be a bunch of noise and kind of somewhere in between that, we have to say over here, 00:26:44.940 |
you know, what would it produce? It would be some mixture. So, you know, like maybe a light, 00:26:53.160 |
the lighter pixels of 8 and some noise, maybe a bit darker. And we can kind of draw this 00:27:06.720 |
and you would have seen this in the previous lecture. You can draw the sequence of things 00:27:18.440 |
that become progressively more noisy in very small steps, all the way until it becomes 00:27:26.160 |
pure noise. This is what we call the forward diffusion process. And we can now describe 00:27:41.000 |
some of these things. So this would be a sample from our data distribution qx0. This would 00:27:52.480 |
be the function for the conditional probability density function that takes, so of x1 given 00:28:05.080 |
x0 and so on. And the way that the terminology that we would use or that mathematicians used 00:28:20.440 |
to describe this is they would call it a Markov process with Gaussian transitions. And this 00:28:40.720 |
can sound quite scary, but we've just described exactly what this is. So when we say process, 00:28:47.520 |
it usually means, you know, something where there's a sequence involved. When we say Markov, 00:28:55.860 |
it means that the thing at time t depends only on the thing at t minus 1. The transition 00:29:06.920 |
is this function. How do you actually go from t minus 1 to t? And Gaussian is the fact that 00:29:16.240 |
that transition is the normal distribution. Does that sound right? 00:29:22.320 |
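A minimal sketch of this forward process in plain Python (assuming, for simplicity, a single constant beta rather than the per-step schedule beta_t that the papers use):

```python
import math, random

def forward_step(x_prev, beta):
    """One Gaussian transition of the Markov chain:
    x_t ~ N(sqrt(1 - beta) * x_{t-1}, beta * I)."""
    return [math.sqrt(1 - beta) * px + math.sqrt(beta) * random.gauss(0.0, 1.0)
            for px in x_prev]

random.seed(0)
x = [0.0, 0.5, 1.0]      # a tiny stand-in for an image's pixel values
for t in range(1000):    # many small steps, slowly drowning the signal in noise
    x = forward_step(x, beta=0.01)
# At beta = 0 the step returns the image unchanged; at beta = 1 it returns
# pure noise; small betas in between add a little noise per step.
```

Note the Markov property in the code: each new `x` depends only on the `x` from the previous step, nothing earlier.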
Yes. Just to also clarify a couple of things, when we say that, you know, we're sampling 00:29:28.760 |
from the data distribution, what that is referring to is trying to find some random sample or 00:29:38.160 |
some random data point that maximizes that likelihood or that has a high likelihood. 00:29:45.040 |
So when we say that, you know, we're looking at that API, that magic API we were talking 00:29:50.400 |
about, and we're trying to get some, you know, some data points that have a high value from 00:29:56.920 |
that API. And, you know, for some distributions, it's very simple and we know how it works 00:30:03.600 |
like a Gaussian distribution. If we know the parameters of that Gaussian distribution, 00:30:08.100 |
it's very easy to be able to do that sampling. And then of course, in other cases, it's not 00:30:13.120 |
very easy, it's quite difficult to do that sampling. So then we have to figure out 00:30:18.000 |
alternative ways of doing that sampling. But that's why in this case, with the forward 00:30:22.520 |
diffusion, we just have these simple Gaussian transitions. And we already know the parameters 00:30:28.400 |
of those Gaussian transitions, so we can easily do that sampling. And going back also to that, 00:30:36.200 |
I think it's worthwhile to also kind of show and think about maybe how this is again done 00:30:41.680 |
practically. Because one of the nice properties of Gaussian distributions as a whole is that 00:30:48.480 |
you can, you know, simply take some normal noise at with a mean of zero and variance 00:30:55.680 |
of one. So I think they usually call that a unit distribution, which is just 00:31:03.120 |
like, yeah, normal of zero, one. And then if you want to get to some other point with 00:31:08.640 |
a mean of whatever value you specify and a variance of whatever value you specify, you 00:31:14.560 |
can simply take that normal distribution, multiply it by the standard deviation, 00:31:24.760 |
and then you add your, your mean. So then there's a simple equation that you can take 00:31:29.560 |
to get the, to get any particular mean and variance. So that's how you would get the samples 00:31:41.280 |
for these other distributions that we have defined throughout the forward diffusion process. 00:31:47.720 |
So, you know, for example, when you're coding this up, of course, a lot of these software libraries, 00:31:54.200 |
they will have a way of getting a sample from this normal distribution of zero, one. And 00:32:02.360 |
then you just use that equation then to get it at the desired mean and variance. So that's 00:32:07.400 |
how it kind of happens under the hood when you're kind of describing this. 00:32:13.400 |
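The "take unit noise, then scale and shift it" recipe described here is just a one-liner. A sketch (note the scaling is by the standard deviation, the square root of the variance):

```python
import random

def sample_normal(mu, sigma):
    """Sample from N(mu, sigma^2): draw eps ~ N(0, 1), then scale by the
    standard deviation sigma and shift by the mean mu."""
    eps = random.gauss(0.0, 1.0)
    return mu + sigma * eps

random.seed(0)
samples = [sample_normal(3.0, 0.5) for _ in range(10_000)]
print(sum(samples) / len(samples))  # close to the requested mean of 3.0
```

This is the same trick deep-learning code uses to sample from a Gaussian with any desired mean and variance, starting from the library's standard-normal sampler.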
That's really helpful. Yeah. And this idea of, we can't really sample from this thing. 00:32:23.600 |
That's exactly, you know, the problem that generative kind of modeling is trying to solve. 00:32:28.920 |
Like how do you represent this in such a way that you can easily sample from it? And so 00:32:35.320 |
it turns out that if you have one of these processes, you know, where you have many, 00:32:41.680 |
many steps, so let's say a thousand steps, a thousand of these steps going to the right 00:32:48.120 |
and they're all very small steps that eventually go to noise, somebody, you know, maybe in 00:32:57.080 |
the 1950s, I think, discovered that you can represent the process of going backwards in 00:33:09.720 |
exactly the same functional form with just different parameters. So what that means is 00:33:15.600 |
if we say P is the thing that goes backwards. So, you know, the previous one, given the 00:33:23.600 |
current one, the P has the same functional form. So it's also, the transitions are also 00:33:36.000 |
normal, but the mean is, you know, some unknown. So we'll use a square and the variance is 00:33:43.000 |
some unknown. So we use a triangle. Is that correct? 00:33:51.160 |
Yeah, that's correct. And just going back to our previous point about P versus Q, here 00:33:56.520 |
we can see that the Q was describing the sort of forward process going, you know, yeah, 00:34:02.440 |
the sort of steps that we're doing. And then the P is describing what we're going in the 00:34:07.320 |
reverse way. So that's why, you know, these papers are using, you know, Q for one process 00:34:15.880 |
and then P for another. That's what they're kind of indicating, at least in the diffusion 00:34:22.720 |
And P is kind of like X, you know, it's the one we want to figure out. So like Q is kind 00:34:28.440 |
of like Y and P is kind of like X. That's how I like to think of that. And so, you know, 00:34:34.720 |
we have this functional form and the next question is how can we use this or, you know, 00:34:42.040 |
we just don't know what these parameters are. How can we figure out what those are? And 00:34:48.280 |
this goes back, you know, to early kind of statistics literature where you can fit this 00:34:56.560 |
model using by maximizing what's called the likelihood function. So we can try different 00:35:05.080 |
parameters until we have one that maximizes the likelihood. It turns out that we can't 00:35:15.080 |
quite do this exactly because you would need to calculate some integral and that integral 00:35:24.320 |
is over very high dimensional values, continuous values. So you can't actually calculate this. 00:35:32.280 |
I think you can think of it because, you know, we're having these thousands of steps that 00:35:38.080 |
we're trying to go in this reverse process. And so, you know, you have these thousands 00:35:42.560 |
of steps that there are going to be many possible values for each step. So it's kind of hard 00:35:47.040 |
to evaluate it over all these thousands of steps and all the possible values for all 00:35:51.600 |
these different steps. So I think that's kind of where the challenges arise and that's what 00:35:56.720 |
makes it difficult because you have to evaluate it over these multiple steps and 00:36:02.680 |
try to find these functions for all these different steps. So that's kind of where the difficulty comes from. 00:36:13.040 |
And so you might see people talk not about the likelihood function, but about the log-likelihood. 00:36:23.800 |
And correct me if I'm wrong here, Tanishq, but I think the log here is a bit of a computational 00:36:30.720 |
trick almost. So I think it has a few properties. The first is that it's always increasing. 00:36:38.000 |
You know, people would call this monotonic. You know, it looks always kind of increasing. 00:36:46.560 |
And because it's always increasing, you get the same parameters if you optimize the log-likelihood 00:36:53.720 |
versus you optimize the likelihood. It also takes products to sums. And that's helpful 00:37:05.460 |
because we have joint distributions, you know, which turn out to be products. So it turns 00:37:10.160 |
out we have a lot of products here and they become sums, which is easy to work with. And 00:37:14.840 |
the last thing is that, you know, this normal distribution has exponential functions and 00:37:25.240 |
those disappear with the log. So this is the much friendlier thing to optimize. 00:37:37.480 |
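A quick numerical illustration of the products-to-sums point (plain Python, with made-up probabilities):

```python
import math

probs = [0.9, 0.8, 0.95, 0.7]  # made-up likelihoods of independent observations

# The likelihood of all of them together is a product, which shrinks toward
# zero as you multiply more and more terms...
likelihood = math.prod(probs)

# ...while the log-likelihood is a sum, which is numerically stable and has
# the same maximiser, because log is monotonically increasing.
log_likelihood = sum(math.log(p) for p in probs)

print(likelihood, log_likelihood)
```

With thousands of diffusion steps instead of four terms, the raw product would underflow to zero, which is exactly why papers work with the log-likelihood.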
Cool. And then there's one more step. You know, we still can't optimize the log-likelihood 00:37:44.880 |
of the thing that this eventually describes. But again, and this is kind of the beauty 00:37:51.600 |
of math is that somebody figured out a long time ago that there's a way to optimize some 00:37:57.200 |
other quantity called the ELBO for short, which stands for Evidence Lower Bound. 00:38:10.520 |
And the evidence is just another name for the likelihood. And the lower bound means 00:38:30.120 |
it's, you know, the lower bound of the evidence. And if you optimize that, it's almost as good 00:38:36.720 |
as optimizing the thing that we really want to. But this one we can calculate very, very 00:38:44.920 |
easily. And so you can use this as a loss function to train two neural networks that 00:39:04.640 |
predict our square from earlier, which was our mean, and our triangle, which is our variance 00:39:13.000 |
of this reverse process. And once you have that, you go all the way back here. So then 00:39:20.840 |
you have these values. You can start with pure noise and keep calling these neural networks 00:39:30.080 |
sampling from those normal distributions, kind of applying that iteratively over many 00:39:37.980 |
steps and you recover this data distribution. One thing that's important to clarify here 00:39:47.920 |
is that you can recover the whole distribution, but you can't necessarily take a single image, 00:39:57.280 |
get it to pure noise and then convert it back. So this operates sort of at the distribution 00:40:03.240 |
level. So you can take this kind of magic API, you can reconstruct that whole API. And 00:40:13.680 |
if you can do that, then you can generate images, MNIST digits or cats or dogs or whatever you want. 00:40:27.320 |
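The reverse-process sampling loop described here can be sketched as follows. The two `predict_*` functions are hypothetical stand-ins for the trained neural networks that output the mean (the "square") and the standard deviation (from the "triangle" variance) at each step:

```python
import random

def predict_mean(x, t):   # hypothetical stand-in: in practice a neural network
    return [0.99 * px for px in x]

def predict_std(x, t):    # hypothetical stand-in: a network, or a fixed constant
    return 0.05

def sample_reverse(n_pixels=4, n_steps=1000):
    # Start from pure noise...
    x = [random.gauss(0.0, 1.0) for _ in range(n_pixels)]
    # ...then repeatedly sample x_{t-1} ~ N(predicted mean, predicted variance),
    # stepping backwards from t = n_steps - 1 down to 0.
    for t in reversed(range(n_steps)):
        mu, sigma = predict_mean(x, t), predict_std(x, t)
        x = [m + sigma * random.gauss(0.0, 1.0) for m in mu]
    return x

random.seed(0)
print(sample_reverse())
```

With real trained networks in place of the stand-ins, this loop is what turns pure noise into a sample from the data distribution.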
I want to just clarify one thing about this process of the kind of the loss function. 00:40:33.200 |
So this sort of evidence, lower bound loss function, the kind of approach that it's taking 00:40:38.400 |
is that we have this forward process, right? We can go from the original images and figure 00:40:47.040 |
out these sorts of intermediate distributions going all the way finally to noise. With this 00:40:53.400 |
sort of evidence lower bound loss function, what we're really kind of doing is trying 00:41:00.120 |
to match our distribution that we're trying to optimize to those distributions that we 00:41:06.200 |
saw in the forward process. So that's what we're trying to do. We're trying to match 00:41:11.560 |
those distributions, and there's a specific type of function that's able 00:41:18.440 |
to do that, called the KL divergence. That's the sort of function that can compare probability 00:41:24.920 |
distributions. And again, because we're dealing with Gaussians, you can calculate that analytically 00:41:31.920 |
and a lot of the math becomes very simple. So that's again, with the whole Gaussians, 00:41:40.040 |
we know them quite well and the math is very simple. So that allows us to do this sort 00:41:44.320 |
of comparison between these distributions very easily and optimize that. And so we want 00:41:50.000 |
to kind of minimize the difference between the distributions we see in the forward process 00:41:55.240 |
and the distributions we're trying to determine for the reverse process. 00:42:01.960 |
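The "analytic" point here can be made concrete: for two univariate Gaussians the KL divergence has a closed form, so no sampling or numerical integration is needed. A minimal sketch:

```python
import numpy as np

def kl_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(q || p) for two univariate Gaussians.
    This analytic formula is why matching Gaussian-to-Gaussian terms
    in the ELBO is cheap to compute."""
    return (np.log(sigma_p / sigma_q)
            + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2)
            - 0.5)

print(kl_gaussians(0.0, 1.0, 0.0, 1.0))  # identical distributions -> 0.0
```

During training, each term of the ELBO compares a forward-process Gaussian with the corresponding reverse-process Gaussian using exactly this kind of expression.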
Perfect. Then there's one more thing, I think, one more kind of major step to get closer to 00:42:12.320 |
the form that you would have seen in Jeremy's lesson. So there was a 2020 paper. The initials 00:42:22.200 |
of that model is DDPM. Tanishq, do you know what this stands for? 00:42:28.080 |
Yeah, that stands for denoising diffusion probabilistic model. 00:42:33.720 |
Okay, cool. And what they did was they said, let's assume that this variance is just a 00:42:46.800 |
constant so we don't learn it. And we assume also that the step size from earlier, you 00:42:56.320 |
know, the variance of the noise that we add at each step is also a constant. We don't 00:43:02.200 |
learn that. So we're just predicting the mean and these are set to some really convenient 00:43:08.200 |
values. Then the loss turns out to be that you predict the noise. So you can restructure 00:43:24.480 |
this whole thing as you need to train a network that takes in images. So here's your network 00:43:37.320 |
and it tells you what part of this image is noise, thanks to these simplifying assumptions. And 00:43:50.720 |
even though they're assumptions, it turns out you can train models 00:43:57.080 |
that produce much better images. Now, I think this relates to something from 00:44:06.120 |
the lesson that Jeremy gave. Tanishq, do you remember that there was something about the score function? 00:44:14.680 |
Yes, yes. So this idea of adding noise and learning to remove noise, the idea 00:44:25.240 |
is that, again, you have this image that you've added noise to, 00:44:35.040 |
right? 00:44:46.080 |
So, like Jeremy said in the lesson, 00:44:55.760 |
what we want to do is we want to figure out the gradient of this likelihood function. 00:45:03.600 |
So this is just kind of a different way about thinking about this. If we had some information 00:45:07.960 |
about this gradient, then we could, for example, you know, use that information to produce 00:45:16.800 |
like we talked about, kind of this optimization, kind of produce images with high likelihood. 00:45:22.280 |
So the idea is that we can add noise to the images that we have. So those are samples 00:45:30.120 |
that we have. And that kind of takes us away from, you know, the regular images that we 00:45:36.480 |
know that we have. And, you know, that kind of decreases the likelihood, right? So we 00:45:40.800 |
have those images and we're adding noise that decreases the likelihood. And we want to kind 00:45:45.040 |
of learn how to get back to high likelihood images and kind of use that to provide some 00:45:51.920 |
sort of estimate of our gradient. So this sort of denoising process actually allows us 00:45:57.680 |
to do that. So there are actually theorems also, I think, from the 1950s that demonstrate 00:46:04.000 |
that especially in the case of this sort of Gaussian noise that we're working with, this 00:46:09.080 |
denoising process is equivalent to learning what is known as the score function. And the 00:46:16.440 |
score function is the gradient of the log of the likelihood. So again, they have this 00:46:24.040 |
log here, which, again, makes the math nicer and easier to work with. But the general idea 00:46:29.440 |
is the same because as we talked about, log is a monotonic function. So again, the general 00:46:35.200 |
ideas are the same, but the score function specifically refers to the gradient of the 00:46:41.600 |
log likelihood. So this sort of denoising process allows us to learn the score function. So 00:46:51.920 |
that's when we're doing this noise, predicting that we had this whole probabilistic framework 00:46:56.920 |
using that sort of likelihood framework. And it came back down to just predicting the noise. 00:47:02.960 |
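The simplified training objective being described can be sketched as follows. This is a toy illustration with assumed names: `model` is a hypothetical noise-prediction network, and the schedule is a generic linear one, not the exact values from the paper.

```python
import numpy as np

def ddpm_loss(model, x0, alpha_bar, rng):
    """One step of the simplified DDPM objective: corrupt x0 with the
    closed-form forward process, then take the MSE between the true
    noise and the noise predicted by `model` (hypothetical network)."""
    T = len(alpha_bar)
    t = rng.integers(T)                       # random timestep
    eps = rng.standard_normal(x0.shape)       # the noise we actually add
    x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1 - alpha_bar[t]) * eps
    return np.mean((eps - model(x_t, t)) ** 2)

rng = np.random.default_rng(0)
# Assumed linear beta schedule; alpha_bar is the cumulative product.
alpha_bar = np.cumprod(1 - np.linspace(1e-4, 0.02, 100))
loss = ddpm_loss(lambda x_t, t: np.zeros_like(x_t),  # dummy "model"
                 np.ones(8), alpha_bar, rng)
```

The whole probabilistic machinery collapses to this one line of MSE: jump straight to a noisy `x_t`, ask the network for the noise, and penalize the squared error.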
And that's what the DDPM paper showed in 2020. But it turns out that is equivalent to calculating 00:47:10.320 |
out this sort of score function and using that information to be able to sample from 00:47:16.760 |
our distribution. So that's kind of how these two approaches connect. So there's a lot of 00:47:22.360 |
literature talking about maybe the sort of probabilistic likelihood perspectives of diffusion 00:47:27.840 |
models. And there's also a lot of literature talking about this score-based perspective. 00:47:33.760 |
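The equivalence between the two perspectives is often written as follows (standard score-matching notation, with the $\bar\alpha_t$ of the DDPM forward process assumed; a sketch, not the slide's exact symbols):

```latex
s_\theta(x_t, t)
\;=\; \nabla_{x_t} \log p(x_t)
\;\approx\; -\,\frac{\epsilon_\theta(x_t, t)}{\sqrt{1 - \bar\alpha_t}}
```

The score is the gradient of the log-likelihood, and for Gaussian corruption the noise-prediction network $\epsilon_\theta$ estimates it up to a known scale factor, which is why denoising and score estimation are two views of the same training problem.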
But this hopefully allows you to think about the similarities and how these two approaches 00:47:39.280 |
connect with each other. Yeah, awesome. Yeah. And that's kind of the beauty, I think, of 00:47:46.080 |
the math side of things here is that you find all of these relationships between different 00:47:52.240 |
fields and also like between different centuries, basically. And that allows you to do really 00:47:58.840 |
kind of powerful and unexpected things. Okay, so let's just do a quick recap of where 00:48:09.520 |
we got to. So we started out with our data distribution, which we want to model. We said 00:48:16.920 |
we'll define this forward diffusion process, which is a way of kind of adding noise to 00:48:23.040 |
this data. And because we add it in this specific way, thanks to some discovery in 00:48:32.120 |
the 1950s, the reverse process has the same form. And then we already know how to train 00:48:45.280 |
a neural network for this using the elbow. And then a couple of years later came the 00:48:55.400 |
discovery, simplifying assumptions that in the end, all we do is predict the noise. And 00:49:02.240 |
I just remembered we take actually the MSE of this noise prediction, the mean squared 00:49:08.840 |
error, which is a nice, very simple framing of the model. And Tanishq spoke about another 00:49:15.560 |
way to derive all of this, which is the score function approach, the gradient of the 00:49:20.880 |
log-likelihood. Okay, cool. Yeah, I highly recommend checking 00:49:28.440 |
out the course lesson as well, if you haven't. You know, if you don't understand this, there's 00:49:35.760 |
no need to be intimidated. You can still be very effective without ever using the math, 00:49:43.680 |
as fast.ai has shown us, and you can do novel 00:49:48.480 |
research as well. For me, this is, it's interesting. And, you know, it's even beautiful, in a way. 00:49:58.360 |
So I recommend checking it out, but don't feel intimidated. You can find the course 00:50:04.360 |
lesson links in the fast AI forum. We'll add those links as well in the description of 00:50:10.200 |
this video. We'll also have a topic in the forum for this lesson. You can have discussions 00:50:16.800 |
there, post any comments, add any, you know, relevant links to the math. And then we have 00:50:23.380 |
another lesson, you know, video by Jono, which I really recommend checking out. He's a, you 00:50:30.600 |
know, he's a great teacher and he was, I think he was the first person to do a full course 00:50:36.440 |
on Stable Diffusion. Yeah, Jono's video is kind of a deep dive into some of the code 00:50:41.840 |
a little bit more and into some of the concepts a little bit more. So I feel like between these 00:50:45.440 |
three videos, it's a good overview. You know, I think, I mean, just to clarify, you don't 00:50:54.880 |
need to understand all the math that was described in this video. That's not to say there won't 00:50:59.160 |
be math. We'll be covering lots of math in these lessons. But we'll be covering 00:51:06.520 |
just the math you need to understand and build on the code. And we'll be covering it over 00:51:13.840 |
many, many more hours than this rather rapid overview. 00:51:18.520 |
Perfect. Cool. And yeah, thank you so much, Tanishq. I had a lot of fun. And thank you 00:51:25.640 |
so much, Waseem. That was awesome. Awesome. Cool. Bye-bye.