
Lesson 9B - the math of diffusion


Chapters

0:00 Introduction
2:19 Data distribution
6:38 Math behind lesson 9’s “Magic API”
18:50 CLIP (Contrastive Language–Image Pre-training)
27:04 Forward diffusion (Markov process with Gaussian transitions)
36:11 Likelihood vs log likelihood
42:16 Denoising diffusion probabilistic model (DDPM)
48:04 Conclusion

Whisper Transcript

00:00:00.000 | Hello, everyone. My name is Wasim. I am an entrepreneur in residence at fast.ai, and I'm
00:00:10.920 | currently at the fast.ai headquarters in Australia, although I'm originally
00:00:17.240 | from South Africa, from Cape Town. And I'm joined here today by Tanishq. Tanishq works
00:00:25.080 | at Stability AI. And we've been working together with a couple of other people on diffusion models,
00:00:33.080 | generative kind of modeling. And that's been super fun. So, Tanishq, do you want to maybe,
00:00:39.400 | you know, introduce yourself as well? Yeah, so my name is Tanishq. I am a PhD student
00:00:46.600 | at UC Davis, but I also work at Stability AI. And I've been exploring and playing around
00:00:52.560 | with diffusion models for the past several months. And so, it's been great to also explore
00:01:00.920 | that with the fast.ai community in these last few weeks. Awesome. Cool.
00:01:07.920 | Cool. So this talk is about me trying to understand the math behind diffusion. So, you know, if
00:01:18.440 | you've done the fast.ai courses before, you know that you don't need to understand the
00:01:22.720 | math to be effective with any of these models. In fact, you don't even need the math to do
00:01:28.120 | research, to do novel research and contribute to these. But for me, it all
00:01:34.920 | came out of interest. And, you know, I thought it was kind of beautiful how
00:01:42.140 | diffusion models were discovered. And I think a large part of that was thanks to
00:01:48.320 | some really clever math. And so I wanted to understand that. I don't have
00:01:55.280 | a math background. And so I want to help describe how I think about it and how
00:02:07.920 | you can interpret, you know, all of these notations and things. Cool. Yeah,
00:02:16.480 | so I can just dive into it, I think. So the first bit of math that we see in this paper
00:02:25.720 | is Q of X superscript zero. And they call this the data distribution.
00:02:35.200 | Do you want to mention what the paper, exactly which paper this is?
00:02:38.640 | Good, good question. So this paper is the 2015 paper. Do you remember the authors of
00:02:45.280 | that paper, Tanishq? I think it's Jascha Sohl-Dickstein, who now works at Google, I think.
00:02:53.800 | And it's from Surya Ganguli's lab. So cool. Yeah, so this was the paper, as far as I understand
00:03:01.080 | that introduced this idea of diffusion. Yeah, 2015 by those authors. They start out by defining
00:03:11.600 | this data distribution and they use this notation. And already, like a lot of people, you know,
00:03:17.080 | myself included, find this quite confusing. But let's go through what's described here.
00:03:23.760 | So they have an X. And, you know, in math, X is often used as the input variable, much
00:03:36.840 | like Y, which is then used often as the output variable. Yeah, and the fact that it has a
00:03:50.920 | superscript also implies something. So the fact that we have X superscript zero implies
00:04:00.360 | that there might be a sequence of Xs. And, you know, I think it's useful to get comfortable
00:04:06.960 | with this idea of simple compact notations implying a lot more than, you know, might be
00:04:16.760 | obvious at first glance. So X implies that it means something about this quantity. It's
00:04:22.720 | an input variable. And, you know, the zero implies that there might be other things that
00:04:26.960 | you might have. You might have an X1, an X2, and so on, and we might see those later. And then
00:04:36.600 | the third part is you have Q. And Q is what we call a probability density function. So
00:04:49.840 | the first part here is probability. And the question is, you know, what does Q have to
00:04:55.200 | do with probabilities? Well, it's because usually we use the letter P to describe probability
00:05:03.040 | density functions of interest. And then because Q is right after that, it's another common
00:05:08.480 | one. So it's kind of like how you use X and Y. We use P and Q. And the fact that we use
00:05:14.400 | Q here instead of P is because it suggests that there might be a P that we'll introduce.
00:05:21.160 | And maybe P is the thing that we're modeling and Q is kind of supplementary to that. Does
00:05:28.040 | that sound right, Tanishq? Yeah, yeah. And I think it's also helpful to kind of maybe
00:05:33.520 | think about like X0 in a more practical, concrete way. Of course, if we're working with images
00:05:40.400 | then X0 would be, you know, that's what's representing the images. So it's also useful
00:05:45.840 | to think about it from kind of that concrete practical approach as well.
00:05:50.120 | Right. So X0 might be, you know, an MNIST digit. And then we got Q. So Q, I'll just
00:06:02.240 | use this to mean Q is some function. So we look at it as a box. And it takes in X0 and
00:06:19.440 | it gives us the probability that this X0, which is an image, looks like an MNIST digit. So
00:06:28.360 | in this case, you know, this would be 0.9 or maybe even, yeah, it's 0.98. So this is quite
00:06:35.600 | high probability that this is an MNIST digit. Hi, this is Jeremy. Can I jump in for a moment?
00:06:41.840 | Please do. Oh, thank you. I just wanted to double check. This looks a lot like the magic
00:06:47.920 | API that we had at the start of lesson one that you feed in a digit and it gives you
00:06:52.000 | back a probability. Is that basically what Q is doing here?
00:06:56.080 | Absolutely, yeah. It's a magic API. That's a good way to think of it. We don't know,
00:07:02.080 | we couldn't write down what Q is, but we imagine that somebody has it somewhere. Yeah, so this
00:07:10.160 | is a concrete example. And like, if you had to do something to this image, you might get
00:07:17.400 | a smaller number here. So another thing worth mentioning here is probability density functions.
00:07:26.960 | So these are these magic APIs that, you know, give us a number, tells us how likely the
00:07:32.080 | thing is. They, you don't often see them, they don't often make it all the way to your
00:07:38.960 | code. In fact, they very rarely will appear in your code. But it turns out that they
00:07:45.720 | are very useful tools for working with random quantities, because they allow you
00:07:52.840 | to represent random quantities as functions, just ordinary functions. And because they're
00:07:59.800 | functions, you have a whole century's worth of math to analyze and understand them. So
00:08:08.440 | you'll often find probability density functions in papers, and eventually they work out to
00:08:13.960 | really simple equations or formulas that end up in your code. Do you want to add anything
00:08:23.040 | Tanishq? I think that all sounds correct. Of course, I think you probably will go over
00:08:29.880 | some examples of probability density functions, especially relevant to this one. But yeah,
00:08:35.700 | it's also useful to think about the sorts of functions you may have in a simplified
00:08:41.240 | case. And that's what we're probably going to talk about next, right?
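To illustrate the point that a probability density is just an ordinary function, here is a minimal Python sketch. It uses a one-dimensional normal density as a stand-in; it is not the MNIST q(x0), which, as discussed, nobody can write down. The name `normal_pdf` is our own. Because the density is a plain function, ordinary numerical tools apply, and a simple Riemann sum confirms that it integrates to about 1.

```python
import math

def normal_pdf(x, mu=0.0, sigma=1.0):
    """Density of a 1-D normal distribution, written as a plain Python function."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2 * math.pi))

# Higher density near the mean, lower density far from it.
print(normal_pdf(0.0), normal_pdf(2.0))

# Riemann sum over [-10, 10]: a density should integrate to (about) 1.
n = 20_000
dx = 20 / n
total = sum(normal_pdf(-10 + i * dx) * dx for i in range(n))
print(total)  # very close to 1.0
```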
00:08:45.360 | Yeah, yeah, that's exactly what we'll talk about. So we have this Q of X0, and then we introduce
00:08:51.520 | another one. And like you said, this is going to turn out to have a really nice simple form.
00:08:59.840 | But before that, the next thing we define is Q of XT given XT minus one. So we'll say what
00:09:10.760 | we define this to be, but to begin with, this is another probability density function. And
00:09:18.320 | this bar over here means it's a conditional probability density function, which you can
00:09:23.920 | think of as you are given the thing on the right to calculate probabilities over the
00:09:31.880 | thing on the left. In this case, you can think of it as something that takes images. So maybe
00:09:43.880 | another magic API and produces other images. But we don't know what these look like yet
00:09:56.760 | because we haven't defined it yet. And this, we would call XT minus one, which could
00:10:02.880 | be X0. And this would be XT, which in the X0 case would be X1. Something worth noting
00:10:15.120 | here is this notation can be a little bit confusing because we've said Q is one thing
00:10:22.760 | earlier and now we think Q is another thing. So here, I'm going to need your help
00:10:27.840 | on this one, Tanishq. I think people would usually, in the strictest sense, define the
00:10:38.360 | first one like this, maybe, and the second one with a subscript. And this notation that
00:10:56.600 | we see here on the left is just a shortcut where they wanted to save the space of writing
00:11:06.960 | that, and instead implied it by what was inside the parentheses. Is that true?
00:11:13.840 | Yeah. I think they used the variables Q and then, of course, later on, we see P to describe,
00:11:23.360 | as we'll see, different aspects of the diffusion model and the different processes of the diffusion
00:11:30.600 | model, which we'll see. So I think that's why they use the same variables to kind of
00:11:36.680 | demonstrate this is corresponding to this process and the other variable corresponds
00:11:41.160 | to the other process of the diffusion model. So we'll obviously go over that. So I think
00:11:46.120 | that's where those variables or those letters are being used in that manner. But if you
00:11:53.640 | do want to make it more specific, more clear, yeah, I think that that notation is fine as
00:11:59.640 | well.
00:12:00.640 | Okay. Yeah, that makes sense. Okay. So let's describe what this Q does to the image on
00:12:10.280 | the left to produce the one on the right. So I'll start over here so we have more space.
00:12:21.880 | I'll write it out first and then we can go into the details. Okay. So kind of like the
00:12:44.240 | bar, you can think of this semicolon as grouping things together. And so you have the things
00:12:52.680 | on the left and the things on the right. My understanding is these two things on the right
00:12:59.760 | are the parameters of the model, sorry, of the probability. And the thing on the left
00:13:09.960 | is actually, Tanishq, could you help me understand what the thing on the left is? Do you know?
00:13:16.560 | Right. Well, so this is, again, like a probability distribution. And the thing on the left is
00:13:22.200 | saying this is a probability distribution for this particular variable. So that's just
00:13:27.240 | representing what it is a probability distribution for. And then the stuff on the right are the
00:13:32.480 | parameters for this probability distribution. So that's kind of what's going on here. So
00:13:39.600 | like, yeah, anytime you have a normal distribution and it's describing some variable, you'll
00:13:44.920 | have that sort of notation where it's the normal distribution of some variable. And
00:13:49.960 | then these are the parameters that describe that normal distribution.
00:13:58.040 | Right. So just to clarify, the bit after the semicolon is the bit that we're kind of used
00:14:04.400 | to seeing to describe a normal distribution, which is the mean and variance of the normal
00:14:12.040 | distribution. So we're going to be sampling random numbers from that normal distribution
00:14:19.920 | according to that mean and that variance. Is that right?
00:14:23.600 | Yes, that's correct. Yeah.
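The whiteboard equation itself is not captured in the transcript. Reconstructed from the spoken description (the mean is the square root of 1 minus beta times the previous image, and the covariance is I times beta t) and written in the standard form used in the diffusion literature, the forward transition reads as follows. The variable being described sits to the left of the semicolon, and the parameters (mean, then covariance) sit to the right:

```latex
q\left(\mathbf{x}^{(t)} \mid \mathbf{x}^{(t-1)}\right)
  = \mathcal{N}\!\left(\mathbf{x}^{(t)};\ \sqrt{1-\beta_t}\,\mathbf{x}^{(t-1)},\ \beta_t \mathbf{I}\right)
```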
00:14:27.280 | So we need to describe a bit more there about normal distributions. We kind of skipped past
00:14:36.520 | that. So we have this fancy N and fancy letters in math for distributions usually refer to
00:14:44.760 | well-known distributions. And the N here stands for normal, which is also known as the Gaussian
00:14:56.040 | distribution. And it's probably the most well-known probability distribution that you can find.
00:15:04.560 | And when I say well-known, I mean that these things pop up everywhere. You know, in
00:15:12.440 | all sorts of fields, measuring all sorts of things, it turns out that the data roughly follows
00:15:19.200 | something that looks like this distribution. And because they pop up so much, you
00:15:25.840 | know, people studied them, studied all of their properties, and we understand them really
00:15:31.720 | well now. The reason they're used so often in cases like this is that it turns out
00:15:39.760 | they have really useful properties and they're easy to work with. Some reasons are they're
00:15:46.360 | described by just two parameters, called the mean and the covariance. Another
00:15:55.000 | property is that they have, you know, what people would call thin tails, which kind
00:16:02.400 | of means that you only need to describe their behavior in a small region
00:16:07.440 | of space. You can kind of just ignore the rest. Yeah. Do you mind drawing a quick example
00:16:17.760 | of a normal distribution? That's a good point. So we have, let's say our random variable
00:16:24.440 | is just one-dimensional. So just a single float. This is sort of what
00:16:32.000 | the normal distribution would look like. And in this case, that would be our mean. And
00:16:41.480 | the variance would sort of describe the width over here, which in this case, you'd use a
00:16:51.120 | small sigma because you're doing a single variable. In our case, we used a capital sigma,
00:16:58.880 | which is the symbol for multiple variables or multiple dimensions. And yeah, I also didn't
00:17:08.320 | say that this is the Greek letter mu. So capital sigma, mu, and lowercase sigma. I just wanted
00:17:18.640 | to note that typically the lowercase sigma represents the standard deviation, which is
00:17:25.840 | the square root of the variance. So for example, sometimes you may see in papers sigma squared,
00:17:35.920 | and that's just the variance, but they will write it sometimes as sigma squared instead.
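To make the sigma versus sigma squared distinction concrete, here is a small NumPy sketch (NumPy is our choice; the talk shows no code). `numpy.random.Generator.normal` is parameterised by the standard deviation, so passing a variance there is an easy mistake. The last lines also illustrate the tail behaviour mentioned earlier: nearly all of the mass sits within a few standard deviations of the mean.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 2.0, 3.0  # mean and standard deviation; the variance is sigma**2 = 9

# NumPy's sampler takes the standard deviation (scale), not the variance.
samples = rng.normal(loc=mu, scale=sigma, size=200_000)

print(samples.mean())  # close to 2.0 (mu)
print(samples.std())   # close to 3.0 (sigma, the standard deviation)
print(samples.var())   # close to 9.0 (sigma squared, the variance)

# Thin tails: nearly all samples land within 3 standard deviations of the mean.
within_3_sigma = np.mean(np.abs(samples - mu) < 3 * sigma)
print(within_3_sigma)  # about 0.997
```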
00:17:41.920 | So it depends on the notation. So sigma is the standard deviation often, and sigma squared
00:17:49.800 | would then be the variance. Cool. Yeah, we can also show with our example what this would
00:17:57.800 | look like. So we start out with an MNIST digit, put it through this magic API, and what would
00:18:10.800 | we get out? Okay, so something we didn't describe is what does this I mean?
00:18:21.800 | Did you want me to talk about that, Wasim? Yes, please. Okay, sure. Because I think this
00:18:27.040 | is something which actually-- can I borrow your pen? It actually came up in the lesson
00:18:32.980 | we were doing kind of in an interesting way. So in that lesson--
00:18:38.040 | Do you want to get in the video? Ah, no, they know what I look like. Oh, well, okay, I'm
00:18:43.760 | in the video now. Yeah, in the video-- Hi, Tanishq, nice to see you. Yeah, so in the lesson,
00:18:53.000 | we did this thing for CLIP, I don't know if you remember, where we had the various pictures
00:19:01.880 | down here. I'm so embarrassed, you're better at the graphics tablet than I am, and it's
00:19:06.000 | my graphics tablet. And we had the various sentences along here. And we said, oh, it
00:19:11.800 | would be kind of cool to take the dot product of their embeddings. Because if their dot
00:19:17.440 | products are high, that means they're similar to each other. And if we subtracted the means
00:19:25.920 | from those first, then you've got the dot-- and instead of having images down here, what
00:19:35.280 | if we had the exact same vectors on each side? Then what you've got down here is basically
00:19:49.920 | x minus its average, if we subtract that first, squared. And that is the variance.
00:20:00.480 | So that's the variance for each one of these vectors. But what's interesting, as you pointed
00:20:08.720 | out, is that normally, at high school, when we look at a normal distribution, it looks
00:20:13.840 | like this, right? But you're not just doing one normal distribution. You've got a whole
00:20:18.260 | bunch of normal distributions for all of your different pixels. They're the pixels, right,
00:20:24.200 | Tanishq? A normal distribution for every pixel. So there's a whole bunch of them. And so one
00:20:29.680 | of them might have a normal distribution that's there. And another one might have a normal
00:20:33.080 | distribution that's here. And another one might have a normal distribution that's here.
00:20:38.320 | And it's more than that, though, because it's possible that one pixel tends to be higher
00:20:45.960 | when another pixel tends to be higher, or one pixel tends to be higher when another
00:20:50.160 | pixel's lower. So it actually has kind of created this surface in n-dimensional space
00:20:58.200 | where n is the number of pixels. So if you now, like, look at, like, OK, well, what happens
00:21:03.040 | if we multiply this by this, just like we did in CLIP, right? Then if this number is
00:21:10.380 | high, then it's saying that when this variable is high, when this pixel's high, this pixel
00:21:15.360 | tends to be high, and vice versa. Or if it's low, it's saying when this pixel tends to
00:21:19.800 | be high, this one tends to be low. Or, interesting to us, what happens-- oopsie-daisy, sorry
00:21:25.760 | -- what happens if this is zero? That says that if this is high, then this could be anything.
00:21:37.760 | Or if this is high, this could be anything. There's no relationship between them. So statistically,
00:21:42.120 | we would say that these two pixels are independent. And so now, that basically means we could
00:21:50.000 | do that for all of these. We could say, oh, these are all zeros. And what that says is
00:21:57.880 | that, oh, every pixel is independent of every other pixel. Now, of course, in real pictures,
00:22:03.160 | that's not how real pixels work. But that's the assumption we're making. Because if we
00:22:08.640 | start with a very special matrix called I, which has ones down the diagonal and zeros everywhere else. If we take this
00:22:22.520 | very special matrix-- it's very special because I can multiply it by something, say beta. And
00:22:30.360 | if I multiply it by a matrix, I get back the original matrix. If I multiply it by a scalar,
00:22:35.240 | I'm going to get beta, beta, beta, dah, dah, dah, dah, and lots of zeros. And so if I multiply
00:22:42.240 | something by this beta-times-I matrix, then I'm just multiplying it by beta. But what's interesting about this
00:22:50.120 | is that this is what Wasim wrote. Wasim wrote I times beta t. So what he's
00:22:58.440 | saying is, oh, we've now got a covariance matrix where for each individual pixel, it's
00:23:07.400 | like pixel number 1 has variance beta t, pixel number 2 has variance beta t, and so on. Those are the variances of each one.
00:23:14.520 | And the covariances, the relationship between the pixels is 0. They're expected to be independent.
00:23:22.600 | So that's where we're going from statistics you do in high school to statistics you do
00:23:28.320 | at university. It's like suddenly covariances are now matrices, not individual numbers.
00:23:34.320 | Does that sound about right to you, Tanishq?
00:23:36.360 | Yeah, that's a great explanation of it, yes.
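Here is a small NumPy sketch of the beta-times-I covariance just described, using tiny 4-pixel "images" so the matrix is readable. The sizes and the value of beta are arbitrary choices for illustration. It checks two things from the explanation: multiplying by beta times I is the same as multiplying by beta, and samples drawn with this covariance have about beta on the diagonal and about zero off it, computed here as centred dot products, as in the CLIP comparison.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pixels, n_samples, beta = 4, 100_000, 0.5

# beta * I: beta on the diagonal (each pixel has variance beta),
# zeros off the diagonal (no relationship between any two pixels).
cov = beta * np.eye(n_pixels)

# Multiplying a vector by beta * I is the same as multiplying it by beta.
v = np.arange(n_pixels, dtype=float)
print(cov @ v, beta * v)

# Draw many tiny 4-pixel "images" from N(0, beta * I).
x = rng.multivariate_normal(mean=np.zeros(n_pixels), cov=cov, size=n_samples)

# Covariance as centred dot products (same vectors on both sides), divided by n - 1.
centred = x - x.mean(axis=0)
empirical_cov = centred.T @ centred / (n_samples - 1)
print(np.round(empirical_cov, 2))  # roughly 0.5 on the diagonal, roughly 0 elsewhere
```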
00:23:41.160 | Awesome. Cool. So now let's try to describe what this would do to MNIST digits. So let's
00:23:52.240 | put back our mean equation and our covariance. So mean and our covariance. And let's look
00:24:11.400 | at how this behaves at the edges. So it's really hard to understand this. I don't think
00:24:19.220 | anybody can just look at this and know what it means. What we typically do is we try to
00:24:25.760 | describe it at the edges. And so we'll start with what happens if that's 0. And we'll work
00:24:33.440 | with x0 as well instead of xt minus 1, which would mean an MNIST digit. So if beta is 0
00:24:42.520 | then we get our x0, you know, square root 1 minus 0, which is 1 and square root of 1
00:24:51.320 | is 1. So that kind of falls away. So we just have a mean of our previous image. And this
00:24:59.040 | is just variance of 0. So we have a normal distribution with a mean of our previous image
00:25:06.800 | and a variance of 0, which means we have the same image.
00:25:18.000 | Yeah, just to clarify, when you have variance of 0, that means that there's really no noise
00:25:26.200 | or anything. It's just at that mean and, you know, your distribution is just saying that's
00:25:31.400 | the only point that you can get from it. So yeah, that just becomes the same image because
00:25:35.800 | yeah, there's no noise or variance because the variance is 0.
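A minimal NumPy sketch of one forward step, covering the beta = 0 edge case just discussed and the beta = 1 case discussed right after it. The function name `forward_step` and the uniform random stand-in image are hypothetical; the step samples the transition as its mean plus its standard deviation (square root of beta) times unit noise.

```python
import numpy as np

def forward_step(x_prev, beta, rng):
    """One transition q(x_t | x_{t-1}) = N(sqrt(1 - beta) * x_{t-1}, beta * I)."""
    noise = rng.standard_normal(x_prev.shape)
    return np.sqrt(1.0 - beta) * x_prev + np.sqrt(beta) * noise

rng = np.random.default_rng(0)
x0 = rng.uniform(0.0, 1.0, size=(28, 28))  # hypothetical stand-in for an MNIST digit

unchanged = forward_step(x0, beta=0.0, rng=rng)   # mean x0, variance 0: the same image
pure_noise = forward_step(x0, beta=1.0, rng=rng)  # mean 0, variance 1: pure noise
partial = forward_step(x0, beta=0.5, rng=rng)     # in between: a faded image plus noise

print(np.array_equal(unchanged, x0))  # True
print(pure_noise.mean(), pure_noise.std())  # close to 0 and 1
```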
00:25:41.800 | Yeah, exactly. And then when our beta is 1, we still have this and then we have, you know,
00:25:51.880 | square root 1 minus 1 and that becomes 0. So this whole thing becomes 0. And this thing
00:26:02.800 | becomes i times beta t, which is, you know, i. And if it's just i, then as Jeremy described,
00:26:11.480 | it would, you know, imply a variance of 1. And so our image through this function would
00:26:25.040 | just be pure noise. So, you know, mean of 0, standard deviation of 1, and it would
00:26:33.880 | just be a bunch of noise. And somewhere in between, say over here,
00:26:44.940 | you know, what would it produce? It would be some mixture. So, you know, maybe
00:26:53.160 | the lighter pixels of an 8 and some noise, maybe a bit darker. And we can kind of draw this
00:27:06.720 | and you would have seen this in the previous lecture. You can draw the sequence of things
00:27:18.440 | that become progressively more noisy in very small steps, all the way until it becomes
00:27:26.160 | pure noise. This is what we call the forward diffusion process. And we can now describe
00:27:41.000 | some of these things. So this would be a sample from our data distribution, q of x0. This would
00:27:52.480 | be the conditional probability density function of x1 given
00:28:05.080 | x0, and so on. And the terminology that mathematicians would use
00:28:20.440 | to describe this is a Markov process with Gaussian transitions. And this
00:28:40.720 | can sound quite scary, but we've just described exactly what this is. So when we say process,
00:28:47.520 | it usually means, you know, something where there's a sequence involved. When we say Markov,
00:28:55.860 | it means that the thing at time t depends only on the thing at t minus 1. The transition
00:29:06.920 | is this function. How do you actually go from t minus 1 to t? And Gaussian is the fact that
00:29:16.240 | that transition is the normal distribution. Does that sound right?
00:29:22.320 | Yes. Just to also clarify a couple of things, when we say that, you know, we're sampling
00:29:28.760 | from the data distribution, what that is referring to is trying to get some random sample or
00:29:38.160 | some random data point that has a high likelihood under that distribution.
00:29:45.040 | So when we say that, you know, we're looking at that API, that magic API we were talking
00:29:50.400 | about, and we're trying to get some, you know, some data points that have a high value from
00:29:56.920 | that API. And, you know, for some distributions, it's very simple and we know how it works
00:30:03.600 | like a Gaussian distribution. If we know the parameters of that Gaussian distribution,
00:30:08.100 | it's very easy to be able to do that sampling. And then of course, in other cases, it's not
00:30:13.120 | very easy, it's not, it's quite difficult to do that sampling. So then we have to figure
00:30:18.000 | alternative ways of doing that sampling. But that's why in this case, with the forward
00:30:22.520 | process, we just have these simple Gaussian transitions. And we already know the parameters
00:30:28.400 | of those Gaussian transitions, so we can easily do that sampling. And going back also to that,
00:30:36.200 | I think it's worthwhile to also kind of show and think about maybe how this is again done
00:30:41.680 | practically. Because one of the nice properties of Gaussian distributions as a whole is that
00:30:48.480 | you can, you know, simply take some normal noise with a mean of zero and variance
00:30:55.680 | of one. I think they usually call that a standard normal, or unit, distribution, which is just
00:31:03.120 | like, yeah, normal of zero, one. And then if you want to get to some other point with
00:31:08.640 | a mean of whatever value you specify and a variance of whatever value you specify, you
00:31:14.560 | can simply take that normal sample, multiply it by the standard deviation, which is the square root of the variance,
00:31:24.760 | and then you add your mean. So then there's a simple equation that you can take
00:31:29.560 | to get the, to get any particular mean and variance. So that's how you would get the samples
00:31:41.280 | for these other distributions that we have defined throughout the forward process.
00:31:47.720 | So, you know, for example, when you're coding this up, of course, a lot of these software libraries,
00:31:54.200 | they will have a way of getting a sample from this normal distribution of zero, one. And
00:32:02.360 | then you just use that equation then to get it at the desired mean and variance. So that's
00:32:07.400 | how it happens under the hood when you're describing this
00:32:11.680 | with code.
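A minimal NumPy sketch of the recipe just described: draw unit N(0, 1) noise, multiply by the standard deviation (the square root of the variance), and add the mean. The same recipe then drives the whole forward process as a Markov chain, where each step uses only the previous image. The helper names, the all-ones stand-in image, and the schedule (1000 steps with beta rising linearly from 1e-4 to 0.02) are assumptions for illustration, not something stated in the talk.

```python
import numpy as np

def sample_normal(mean, variance, rng):
    """Draw from N(mean, variance * I) using unit noise: mean + sqrt(variance) * eps."""
    eps = rng.standard_normal(np.shape(mean))  # N(0, 1) noise from the library
    return mean + np.sqrt(variance) * eps

def forward_process(x0, betas, rng):
    """Markov chain with Gaussian transitions: x_t depends only on x_{t-1}."""
    xs = [x0]
    for beta in betas:
        xs.append(sample_normal(np.sqrt(1.0 - beta) * xs[-1], beta, rng))
    return xs

rng = np.random.default_rng(0)

# Check the recipe on its own: mean 5, variance 4.
check = sample_normal(np.full(200_000, 5.0), 4.0, rng)
print(check.mean(), check.var())  # close to 5.0 and 4.0

# Assumed schedule: 1000 small steps, beta rising linearly from 1e-4 to 0.02.
x0 = np.ones((28, 28))  # hypothetical stand-in for an image
xs = forward_process(x0, np.linspace(1e-4, 0.02, 1000), rng)
print(len(xs))                      # 1001: x_0 plus one sample per step
print(xs[-1].mean(), xs[-1].std())  # close to 0 and 1: essentially pure noise
```

Scaling by the standard deviation rather than the variance matters: for variance 4 the noise must be multiplied by 2, not 4.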
00:32:13.400 | That's really helpful. Yeah. And this idea of, we can't really sample from this thing.
00:32:23.600 | That's exactly, you know, the problem that generative kind of modeling is trying to solve.
00:32:28.920 | Like how do you represent this in such a way that you can easily sample from it? And so
00:32:35.320 | it turns out that if you have one of these processes, you know, where you have many,
00:32:41.680 | many steps, so let's say a thousand steps, a thousand of these steps going to the right
00:32:48.120 | and they're all very small steps that eventually go to noise, somebody, you know, maybe in
00:32:57.080 | the 1950s, I think, discovered that you can represent the process of going backwards in
00:33:09.720 | exactly the same functional form with just different parameters. So what that means is
00:33:15.600 | if we say P is the thing that goes backwards. So, you know, the previous one, given the
00:33:23.600 | current one, the P has the same functional form. So it's also, the transitions are also
00:33:36.000 | normal, but the mean is, you know, some unknown. So we'll use a square and the variance is
00:33:43.000 | some unknown. So we use a triangle. Is that correct?
00:33:51.160 | Yeah, that's correct. And just going back to our previous point about P versus Q, here
00:33:56.520 | we can see that the Q was describing the sort of forward process going, you know, yeah,
00:34:02.440 | the sort of steps that we're doing. And then the P is describing going in the
00:34:07.320 | reverse way. So that's why, you know, these papers are using, you know, Q for one process
00:34:15.880 | and then P for another. That's what they're kind of indicating, at least in the diffusion
00:34:19.360 | model literature.
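Written in standard notation (our reconstruction; the whiteboard is not in the transcript), the reverse transition has the same Gaussian form as the forward one. The square corresponds to the mean and the triangle to the covariance, both of which may depend on the current image and the step t. The subscript theta is the usual label for the learnable parameters:

```latex
p_\theta\left(\mathbf{x}^{(t-1)} \mid \mathbf{x}^{(t)}\right)
  = \mathcal{N}\!\left(\mathbf{x}^{(t-1)};\ \boldsymbol{\mu}_\theta(\mathbf{x}^{(t)}, t),\ \boldsymbol{\Sigma}_\theta(\mathbf{x}^{(t)}, t)\right)
```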
00:34:22.720 | And P is kind of like X, you know, it's the one we want to figure out. So like Q is kind
00:34:28.440 | of like Y and P is kind of like X. That's how I like to think of that. And so, you know,
00:34:34.720 | we have this functional form and the next question is how can we use this or, you know,
00:34:42.040 | we just don't know what these parameters are. How can we figure out what those are? And
00:34:48.280 | this goes back, you know, to early kind of statistics literature where you can fit this
00:34:56.560 | model by maximizing what's called the likelihood function. So we can try different
00:35:05.080 | parameters until we have one that maximizes the likelihood. It turns out that we can't
00:35:15.080 | quite do this exactly because you would need to calculate some integral and that integral
00:35:24.320 | is over very high dimensional values, continuous values. So you can't actually calculate this.
00:35:32.280 | I think you can think of it this way: you know, we have these thousands of steps that
00:35:38.080 | we're trying to go through in this reverse process. And so, you know, you have these thousands
00:35:42.560 | of steps, and there are going to be many possible values for each step. So it's kind of hard
00:35:47.040 | to evaluate it over all these thousands of steps and all the possible values for all
00:35:51.600 | these different steps. So I think that's kind of where the challenges arise, and that's what
00:35:56.720 | makes it difficult, because you have to evaluate it over these multiple steps and
00:36:02.680 | try to find these functions for all these different steps. So that's kind of where the
00:36:07.200 | challenge is.
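The idea of trying different parameters until one maximizes the likelihood can be sketched on a toy case where it is tractable, unlike the diffusion likelihood with its high-dimensional integral. Assuming 50 hypothetical data points and a normal model with known standard deviation 1, a grid search over the mean lands next to the sample mean, which is the known maximum-likelihood answer for this model.

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=1.0, size=50)  # hypothetical observations

def likelihood(mu, data, sigma=1.0):
    """Product of N(mu, sigma^2) densities over all data points."""
    dens = np.exp(-0.5 * ((data - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
    return np.prod(dens)

# Try different parameters and keep the one with the highest likelihood.
candidates = np.linspace(0.0, 6.0, 601)  # step of 0.01
scores = np.array([likelihood(mu, data) for mu in candidates])
best_mu = candidates[np.argmax(scores)]

print(best_mu, data.mean())  # the best grid value sits next to the sample mean
```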
00:36:13.040 | And so you might see people talk not about the likelihood function, but about the log-likelihood.
00:36:23.800 | And correct me if I'm wrong here, Tanishq, but I think the log here is a bit of a computational
00:36:30.720 | trick almost. So I think it has a few properties. The first is that it's always increasing.
00:36:38.000 | You know, people would call this monotonic. You know, it looks always kind of increasing.
00:36:46.560 | And because it's always increasing, you get the same parameters if you optimize the log-likelihood
00:36:53.720 | versus you optimize the likelihood. It also takes products to sums. And that's helpful
00:37:05.460 | because we have joint distributions, you know, which turn out to be products. So it turns
00:37:10.160 | out we have a lot of products here and they become sums, which is easy to work with. And
00:37:14.840 | the last thing is that, you know, this normal distribution has exponential functions and
00:37:25.240 | those disappear with the log. So this is the much friendlier thing to optimize.
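A small NumPy sketch of the three properties just listed, on 1,000 hypothetical standard-normal data points. The product of many densities underflows to zero in floating point, while the sum of log-densities stays finite. Taking the log removes the exponential from the normal density, leaving a simple quadratic. Because the log is monotonic, comparisons (and therefore the best parameters) come out the same.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000)

def normal_pdf(x, mu=0.0, sigma=1.0):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def normal_logpdf(x, mu=0.0, sigma=1.0):
    # Taking the log removes the exponential: what is left is a simple quadratic.
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))

likelihood = np.prod(normal_pdf(x))        # product of 1,000 small numbers underflows to 0.0
log_likelihood = np.sum(normal_logpdf(x))  # sum of logs stays a finite, usable number

print(likelihood, log_likelihood)

# Log is monotonic, so comparisons (and hence the best parameters) are unchanged.
a, b = normal_pdf(0.5), normal_pdf(1.5)
assert (a > b) == (normal_logpdf(0.5) > normal_logpdf(1.5))
```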
00:37:37.480 | Cool. And then there's one more step. You know, we still can't optimize the log-likelihood
00:37:44.880 | of the thing that this eventually describes. But again, and this is kind of the beauty
00:37:51.600 | of math is that somebody figured out a long time ago that there's a way to optimize some
00:37:57.200 | other quantity called the ELBO for short, which stands for Evidence Lower Bound.
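For reference, here is the usual statement of the bound in the notation of this talk (superscripts for steps; our transcription, not taken from the whiteboard). It follows from Jensen's inequality. The left side is the log-likelihood (log evidence) of a data point, and the right side is the evidence lower bound:

```latex
\log p_\theta\left(\mathbf{x}^{(0)}\right)
  \;\ge\;
  \mathbb{E}_{q\left(\mathbf{x}^{(1:T)} \mid \mathbf{x}^{(0)}\right)}\!\left[
    \log \frac{p_\theta\left(\mathbf{x}^{(0:T)}\right)}{q\left(\mathbf{x}^{(1:T)} \mid \mathbf{x}^{(0)}\right)}
  \right]
```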
00:38:10.520 | And the evidence is just another name for the likelihood. And the lower bound means
00:38:30.120 | it's, you know, the lower bound of the evidence. And if you optimize that, it's almost as good
00:38:36.720 | as optimizing the thing that we really want to. But this one we can calculate very, very
00:38:44.920 | easily. And so you can use this as a loss function to train two neural networks that
00:39:04.640 | predict our square from earlier, which was our mean, and our triangle, which is our variance
00:39:13.000 | of this reverse process. And once you have that, you go all the way back here. So then
00:39:20.840 | you have these values. You can start with pure noise and keep calling these neural networks
00:39:30.080 | and sampling from those normal distributions, kind of applying that iteratively over many
00:39:37.980 | steps and you recover this data distribution. One thing that's important to clarify here
00:39:47.920 | is that you can recover the whole distribution, but you can't necessarily take a single image,
00:39:57.280 | get it to pure noise and then convert it back. So this operates sort of at the distribution
00:40:03.240 | level. So you can take this kind of magic API, you can reconstruct that whole API. And
00:40:13.680 | if you can do that, then you can generate images, MNIST digits or cats or dogs or whatever
00:40:22.120 | you want to.
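A minimal sketch of that sampling loop, in one dimension and under a loud assumption: the "data distribution" is itself a Gaussian, N(3, 0.5²), so the mean and variance of each reverse step can be computed in closed form and stand in for the two trained neural networks. The linear noise schedule (1e-4 to 0.02 over 1000 steps) follows common DDPM practice and is also an assumption here. Each call to `sample` starts from pure noise and samples its way back; only the collection of outputs matches the data distribution, which mirrors the point that this works at the distribution level.

```python
import math
import random

# Toy assumption: the 1-D "data distribution" is N(DATA_MEAN, DATA_STD^2)
DATA_MEAN, DATA_STD = 3.0, 0.5
T = 1000
# Assumed linear schedule for the per-step noise variance beta_t
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

# Marginal mean and variance of x_t under the forward process, for t = 0..T
means, variances = [DATA_MEAN], [DATA_STD ** 2]
for beta in betas:
    means.append(math.sqrt(1 - beta) * means[-1])
    variances.append((1 - beta) * variances[-1] + beta)


def reverse_step_params(x_t, t):
    """Mean and variance of x_{t-1} given x_t (t is 1-based).

    Stand-in for the mean and variance networks: for Gaussian toy data
    this Gaussian posterior is exact, so nothing needs to be learned."""
    beta = betas[t - 1]
    prior_mean, prior_var = means[t - 1], variances[t - 1]
    var = 1.0 / (1.0 / prior_var + (1 - beta) / beta)
    mean = var * (prior_mean / prior_var + math.sqrt(1 - beta) * x_t / beta)
    return mean, var


def sample(rng):
    x = rng.gauss(0.0, 1.0)  # start from pure noise
    for t in range(T, 0, -1):
        mean, var = reverse_step_params(x, t)
        x = rng.gauss(mean, math.sqrt(var))  # sample from the reverse-step Gaussian
    return x


rng = random.Random(0)
xs = [sample(rng) for _ in range(1000)]
sample_mean = sum(xs) / len(xs)
sample_std = math.sqrt(sum((x - sample_mean) ** 2 for x in xs) / len(xs))
print(sample_mean, sample_std)  # should land close to 3.0 and 0.5
```

With real images the posterior is not available in closed form, which is exactly why the mean (and, before DDPM, the variance) has to be predicted by a network.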
00:40:27.320 | I just want to clarify one thing about this loss function.
00:40:33.200 | The approach that this evidence lower bound loss function takes
00:40:38.400 | is that we have this forward process, right? We can go from the original images and figure
00:40:47.040 | out these intermediate distributions, going all the way to noise. With this
00:40:53.400 | evidence lower bound loss function, what we're really doing is trying
00:41:00.120 | to match the distributions that we're optimizing to the distributions that we
00:41:06.200 | saw in the forward process. To compare
00:41:11.560 | those distributions, we use a specific type of function
00:41:18.440 | called the KL divergence, which measures how different two probability
00:41:24.920 | distributions are. And again, because we're dealing with Gaussians, you can calculate it analytically,
00:41:31.920 | and a lot of the math becomes very simple. We know Gaussians
00:41:40.040 | quite well, so this lets us do this
00:41:44.320 | comparison between distributions very easily and optimize it. So we want
00:41:50.000 | to minimize the difference between the distributions we see in the forward process
00:41:55.240 | and the distributions we're trying to determine for the reverse process.
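Here is a small sketch of why Gaussians make the KL divergence easy: for two 1-D Gaussians it has a closed form, which we can check against a brute-force Monte Carlo estimate. The specific means and standard deviations are made-up values standing in for a forward-process distribution and a model's reverse step.

```python
import math
import random


def gaussian_kl(mu1, sigma1, mu2, sigma2):
    """Closed-form KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ) for 1-D Gaussians."""
    return (math.log(sigma2 / sigma1)
            + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)
            - 0.5)


def log_normal(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def monte_carlo_kl(mu1, sigma1, mu2, sigma2, n=100_000, seed=0):
    """Estimate the same KL as the average log-ratio over samples from the first Gaussian."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = rng.gauss(mu1, sigma1)
        total += log_normal(x, mu1, sigma1) - log_normal(x, mu2, sigma2)
    return total / n


# Made-up values: a "forward-process" Gaussian vs. a "model" reverse-step Gaussian
exact = gaussian_kl(0.3, 0.8, 0.0, 1.0)
estimate = monte_carlo_kl(0.3, 0.8, 0.0, 1.0)
print(exact, estimate)
```

The closed form needs no sampling at all, which is what makes the per-step KL terms in the ELBO cheap to compute and differentiate.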
00:42:01.960 | Perfect. Then there's one more major step to get closer to
00:42:12.320 | the form that you would have seen in Jeremy's lesson. There was a 2020 paper, and the initials
00:42:22.200 | of that model are DDPM. Tanishq, do you know what this stands for?
00:42:28.080 | Yeah, that stands for denoising diffusion probabilistic model.
00:42:33.720 | Okay, cool. And what they did was say: let's assume that this variance is just a
00:42:46.800 | constant, so we don't learn it. And we also assume that the step size from earlier,
00:42:56.320 | the variance of the noise that we add at each step, is also a constant. We don't
00:43:02.200 | learn that either. So we're just predicting the mean, and these constants are set to some really convenient
00:43:08.200 | values. Then the loss turns out to be that you predict the noise. So you can restructure
00:43:24.480 | this whole thing as: you need to train a network that takes in noisy images. So here's your network,
00:43:37.320 | and it tells you which part of this image is noise, thanks to these simplifying assumptions. And
00:43:50.720 | even though they're assumptions, it turns out you can train
00:43:57.080 | models that produce much better images. Now, I think this relates to something from
00:44:06.120 | the lesson that Jeremy gave. Tanishq, do you remember that there was something about the
00:44:12.000 | gradient or something like that?
00:44:14.680 | Yes, yes. So with this idea of adding noise and learning to remove noise, the idea
00:44:25.240 | is that, again, you have this image that you've added noise to,
00:44:35.040 | right? And by... sorry, let me think about the best way to say this. Oh, yeah, sorry. Okay,
00:44:46.080 | let me just start over. Yeah, so as Jeremy says in the lesson,
00:44:55.760 | what we want to do is figure out the gradient of this likelihood function.
00:45:03.600 | This is just a different way of thinking about it. If we had some information
00:45:07.960 | about this gradient, then we could, for example, use that information, through the
00:45:16.800 | kind of optimization we talked about, to produce images with high likelihood.
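A minimal sketch of "use the gradient to move toward high likelihood", under a toy assumption: the data distribution is a 1-D Gaussian N(2, 0.5²), so its score (the gradient of the log-likelihood with respect to the input) is known exactly. In a real model that gradient is unknown and has to be learned. The step size and starting point are arbitrary illustrative choices.

```python
import math

# Toy assumption: a 1-D "data distribution" N(MU, SIGMA^2) whose score we know exactly
MU, SIGMA = 2.0, 0.5


def log_likelihood(x):
    return -0.5 * ((x - MU) / SIGMA) ** 2 - math.log(SIGMA) - 0.5 * math.log(2 * math.pi)


def score(x):
    """Gradient of log_likelihood with respect to x, i.e. the score function."""
    return -(x - MU) / SIGMA ** 2


x = -3.0          # a low-likelihood starting point (e.g. a very noisy sample)
step_size = 0.05  # illustrative step size
start_ll = log_likelihood(x)
for _ in range(200):
    x += step_size * score(x)  # step uphill in log-likelihood

# Sanity check: a central finite difference of log_likelihood matches score()
h = 1e-5
fd = (log_likelihood(1.3 + h) - log_likelihood(1.3 - h)) / (2 * h)
print(x, start_ll, log_likelihood(x), fd, score(1.3))
```

Plain gradient ascent like this collapses onto the single most likely point. Score-based samplers such as Langevin dynamics add a little fresh noise at every step, so that they draw varied samples from the distribution instead of only finding its mode.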
00:45:22.280 | So the idea is that we can add noise to the images that we have, which are the samples
00:45:30.120 | that we have. That takes us away from the regular images that we
00:45:36.480 | know we have, and that decreases the likelihood, right? So we
00:45:40.800 | have those images, and adding noise decreases their likelihood. We want to
00:45:45.040 | learn how to get back to high-likelihood images and use that to provide some
00:45:51.920 | estimate of our gradient. This denoising process actually allows us
00:45:57.680 | to do that. There are actually also theorems, I think from the 1950s, that demonstrate
00:46:04.000 | that, especially for the kind of Gaussian noise that we're working with, this
00:46:09.080 | denoising process is equivalent to learning what is known as the score function. And the
00:46:16.440 | score function is the gradient of the log of the likelihood. So again, there's this
00:46:24.040 | log here, which makes the math nicer and easier to work with. But the general idea
00:46:29.440 | is the same, because, as we talked about, log is a monotonically increasing function, so the likelihood and the log-likelihood are highest for the same images. But the score function
00:46:35.200 | specifically refers to the gradient of the
00:46:41.600 | log-likelihood. So this denoising process allows us to learn the score function.
00:46:51.920 | So when we're doing this noise prediction: we had this whole probabilistic framework
00:46:56.920 | built on the likelihood, and it came back down to just predicting the noise.
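A minimal sketch of the simplified DDPM objective (sample a random step t, jump straight to the noisy x_t, then take the mean squared error between the true and predicted noise), together with its link to the score. Loud assumptions: the data is 1-D Gaussian N(3, 0.5²) and the schedule is the linear one commonly used with DDPM, so the best possible noise predictor E[eps | x_t] is known in closed form and replaces the trained network (in practice, often a U-Net). The identity "score of the noised data = -E[eps | x_t] / sqrt(1 - abar_t)" is one form of what's often called Tweedie's formula.

```python
import math
import random

# Toy assumptions: 1-D data x_0 ~ N(DATA_MEAN, DATA_STD^2) and a linear beta schedule
DATA_MEAN, DATA_STD = 3.0, 0.5
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alpha_bars = []  # alpha_bars[t - 1] = product of (1 - beta_s) for s = 1..t
prod = 1.0
for beta in betas:
    prod *= 1 - beta
    alpha_bars.append(prod)


def noisy_sample(x0, t, rng):
    """Jump straight to step t: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    eps = rng.gauss(0.0, 1.0)
    abar = alpha_bars[t - 1]
    return math.sqrt(abar) * x0 + math.sqrt(1 - abar) * eps, eps


def optimal_noise_predictor(x_t, t):
    """E[eps | x_t] for this Gaussian toy data; a stand-in for a trained network."""
    a, b = math.sqrt(alpha_bars[t - 1]), math.sqrt(1 - alpha_bars[t - 1])
    return b * (x_t - a * DATA_MEAN) / (a ** 2 * DATA_STD ** 2 + b ** 2)


def ddpm_loss(predictor, n=20_000, seed=0):
    """Monte Carlo estimate of the MSE between the true and the predicted noise."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x0 = rng.gauss(DATA_MEAN, DATA_STD)
        t = rng.randint(1, T)  # uniform random timestep
        x_t, eps = noisy_sample(x0, t, rng)
        total += (eps - predictor(x_t, t)) ** 2
    return total / n


def score_from_noise(x_t, t, predictor):
    """Score of the noised data from a noise prediction: -eps_hat / sqrt(1 - abar_t)."""
    return -predictor(x_t, t) / math.sqrt(1 - alpha_bars[t - 1])


zero_loss = ddpm_loss(lambda x_t, t: 0.0)        # a predictor that always says "no noise"
best_loss = ddpm_loss(optimal_noise_predictor)   # the best achievable predictor
# Compare with the exact score of the noised marginal N(a * mean, a^2 * std^2 + b^2) at t = 500
a, b = math.sqrt(alpha_bars[499]), math.sqrt(1 - alpha_bars[499])
analytic_score = -(1.0 - a * DATA_MEAN) / (a ** 2 * DATA_STD ** 2 + b ** 2)
derived_score = score_from_noise(1.0, 500, optimal_noise_predictor)
print(zero_loss, best_loss, analytic_score, derived_score)
```

The design choice this illustrates: minimizing a plain MSE on the noise already gives you, up to a known scale factor, the score function, so the likelihood view and the score-based view end up training the same network.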
00:47:02.960 | That's what the DDPM paper showed in 2020. But it turns out that this is equivalent to estimating
00:47:10.320 | the score function and using that information to be able to sample from
00:47:16.760 | our distribution. So that's how these two approaches connect. There's a lot of
00:47:22.360 | literature on the probabilistic, likelihood-based perspective of diffusion
00:47:27.840 | models. And there's also a lot of literature talking about this score-based perspective.
00:47:33.760 | But this hopefully allows you to think about the similarities and how these two approaches
00:47:39.280 | connect with each other. Yeah, awesome. Yeah. And that's kind of the beauty, I think, of
00:47:46.080 | the math side of things here is that you find all of these relationships between different
00:47:52.240 | fields, and even between different centuries. And that allows you to do really
00:47:58.840 | powerful and unexpected things. Okay, so let's do a quick recap of where
00:48:09.520 | we got to. So we started out with our data distribution, which we want to model. We said
00:48:16.920 | we'll define this forward diffusion process, which is a way of adding noise to
00:48:23.040 | our data. And because we add it in this specific way, thanks to some discovery in
00:48:32.120 | the 1950s, the reverse process has the same form. And then we already know how to train
00:48:45.280 | a neural network for this using the ELBO. And then, a couple of years later, came the
00:48:55.400 | discovery that, with some simplifying assumptions, all we do in the end is predict the noise. And
00:49:02.240 | I just remembered: we actually take the MSE, the mean squared
00:49:08.840 | error, of this noise prediction, which is a nice, very simple framing of the model. And Tanishq spoke about another
00:49:15.560 | way to derive all of this, the score function approach: the gradient of the log-likelihood.
00:49:20.880 | Okay, cool. Yeah, I highly recommend checking
00:49:28.440 | out the course lesson as well, if you haven't. And if you don't understand this, there's
00:49:35.760 | no need to be intimidated. You can be very effective at deep learning without ever using math,
00:49:43.680 | as fast.ai has shown us, and you can do novel
00:49:48.480 | research as well. For me, it's interesting, and it's even beautiful, in a way.
00:49:58.360 | So I recommend checking it out, but don't feel intimidated. You can find the course
00:50:04.360 | lesson links in the fast.ai forum. We'll add those links as well in the description of
00:50:10.200 | this video. We'll also have a topic in the forum for this lesson. You can have discussions
00:50:16.800 | there, post any comments, and add any relevant links to the math. And then we have
00:50:23.380 | another lesson video, by Jono, which I really recommend checking out. He's
00:50:30.600 | a great teacher, and I think he was the first person to do a full course
00:50:36.440 | on Stable Diffusion. Yeah, Jono's video goes deeper into some of the code
00:50:41.840 | and into some of the concepts. So I feel like between these
00:50:45.440 | three videos, it's a good overview. I mean, just to clarify, you don't
00:50:54.880 | need to understand all the math that was described in this video. That's not to say you won't need
00:50:59.160 | to understand math. We'll be covering lots of math in these lessons. But we'll be covering
00:51:06.520 | just the math you need to understand and build on the code. And we'll be covering it over
00:51:13.840 | many, many more hours than this rather rapid overview.
00:51:18.520 | Perfect. Cool. And yeah, thank you so much, Tanishq. I had a lot of fun. And thank you
00:51:25.640 | so much, Waseem. That was awesome. Awesome. Cool. Bye-bye.