Variational Autoencoder - Model, ELBO, loss function and maths explained easily!
Chapters
0:00 Introduction
0:41 Autoencoder
2:35 Variational Autoencoder
4:20 Latent Space
6:06 Math introduction
8:45 Model definition
12:00 ELBO
16:05 Maximizing the ELBO
19:49 Reparameterization Trick
22:41 Example network
23:55 Loss function
00:00:00.000 |
Welcome to my video about the variational autoencoder. 00:00:02.840 |
In this video, I will be introducing the model, how it works, the architecture, 00:00:10.360 |
Why should you learn about the variational autoencoder? 00:00:13.160 |
Well, it's one of the building blocks of the stable diffusion, 00:00:16.960 |
and if you can understand the maths behind the variational autoencoder, 00:00:20.240 |
you have covered more than 50% of the maths that you need for the stable diffusion. 00:00:25.640 |
At the same time, I will also try to simplify the math as much as possible 00:00:29.400 |
so that everyone with whatever background can follow the video. 00:00:34.720 |
Before we go into the details of variational autoencoder, 00:00:37.360 |
we need to understand what is an autoencoder. 00:00:41.880 |
So the autoencoder is a model that is made of two smaller models. 00:00:46.520 |
The first is the encoder, the second the decoder, 00:00:49.240 |
and they are joined together by this bottleneck Z. 00:00:52.200 |
The goal of the encoder is to take some input 00:00:55.160 |
and convert it into a lower dimensional representation, let's call it Z, 00:01:00.240 |
and then if we take this lower dimensional representation and give it to the decoder, 00:01:05.600 |
we hope that the model will reproduce the original data. 00:01:10.960 |
Because we want to compress the original data into a lower dimension. 00:01:19.040 |
We can have an analogy with file compression. 00:01:21.360 |
For example, if you have a picture, let's call it zebra.jpg, 00:01:24.400 |
and you zip the file, you will end up with a zebra.zip file. 00:01:28.800 |
And if you unzip the file or decompress the file, 00:01:34.200 |
The difference between the autoencoder and file compression is that 00:01:39.520 |
the neural network will not reproduce the exact original input, 00:01:44.920 |
but will try to reproduce as much as possible of the original input. 00:01:54.760 |
that is, the lower representation of the data should be as small as possible, 00:01:58.680 |
and the reconstructed input should be as close as possible to the original input. 00:02:05.640 |
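As a concrete sketch of this encoder/bottleneck/decoder structure, here is a minimal autoencoder in PyTorch; the layer sizes and input dimension are purely illustrative and are not taken from the video.

```python
import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        # Encoder: compress the input into the bottleneck code Z
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        # Decoder: try to reconstruct the original input from Z
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, input_dim),
        )

    def forward(self, x):
        z = self.encoder(x)       # lower-dimensional code
        x_hat = self.decoder(z)   # reconstruction
        return x_hat, z

# Training minimizes the reconstruction error, i.e. it pushes x_hat
# to be as close as possible to the original input x.
x = torch.randn(16, 784)          # a dummy batch
model = AutoEncoder()
x_hat, z = model(x)
loss = nn.functional.mse_loss(x_hat, x)
```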
The problem with autoencoders is that the code learned by the model doesn't make sense. 00:02:11.280 |
That is, the model just learns a mapping between input data and a code Z, 00:02:17.680 |
but doesn't learn any semantic relationship between the data. 00:02:21.040 |
For example, if we look at the code learned for the picture of the tomato, 00:02:25.200 |
it's very similar to the code learned for the picture of the zebra, 00:02:28.880 |
or the code for the cat is very similar to the code learned for the pizza, for example. 00:02:33.280 |
So the model didn't capture any relationship between the data 00:02:36.840 |
or any semantic relationship between the data. 00:02:39.160 |
And this is why we introduced the variational autoencoder. 00:02:42.120 |
In the variational autoencoder, we learn a latent space. 00:02:47.800 |
which represents a multivariate distribution over this data. 00:02:52.240 |
And we hope that this multivariate distribution, so this latent space, 00:02:56.720 |
captures also the semantic relationship between the data. 00:02:59.800 |
So for example, we hope that all the food pictures have a similar representation in this latent space, 00:03:05.440 |
and also all the animals have a similar representation, 00:03:08.160 |
and all the cars and all the buildings, for example the stadium, 00:03:13.200 |
And the most important thing that we want to do with this variational autoencoder 00:03:17.640 |
is we want to be able to sample from this latent space to generate new data. 00:03:23.440 |
So what does it mean to sample the latent space? 00:03:26.440 |
Well, for example, when you use Python to generate a random number between 1 and 100, 00:03:32.280 |
you're actually sampling from a uniform distribution, 00:03:37.560 |
because every number has equal probability of being chosen. 00:03:40.440 |
We can sample from the latent space to generate a new random vector, 00:03:44.280 |
give it to the decoder and generate new data. 00:03:46.880 |
For example, if we sample from this latent space, 00:03:49.560 |
which is the latent space of a variational autoencoder that was trained on food pictures, 00:03:54.600 |
and we happen to sample something that was exactly in between these three pictures, 00:04:00.040 |
we hope to get something that also in its meaning is similar to these three pictures. 00:04:05.680 |
So for example, in between the picture of egg, flour and basil leaves, 00:04:10.160 |
we hope to find pasta with basil, for example. 00:04:13.280 |
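As a sketch of what "sampling the latent space" could look like in code: the decoder below is only a stand-in for a trained VAE decoder, and all dimensions are hypothetical.

```python
import torch
import torch.nn as nn

latent_dim = 32
# Stand-in for a trained VAE decoder (illustrative only).
decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 784))

# Sample a new latent vector from the prior N(0, I) and decode it into new data.
z = torch.randn(1, latent_dim)
new_sample = decoder(z)

# "In between" existing pictures: average their latent codes and decode the result.
z_food = torch.randn(3, latent_dim)           # stand-ins for the codes of egg, flour, basil
z_between = z_food.mean(dim=0, keepdim=True)
maybe_pasta = decoder(z_between)
```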
Which means that the model has captured somehow the relationship between the data it was trained upon, 00:04:23.560 |
Because we model our data as coming from a random variable X, 00:04:32.360 |
but this variable X is conditioned on another random variable Z that is not visible to us, 00:04:42.760 |
And we will model this hidden variable as a multivariate Gaussian with a mean and a variance. 00:04:50.040 |
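In symbols, and hedging on the exact notation used on the video's slides, the standard VAE generative model from Kingma and Welling takes the latent variable from a multivariate Gaussian prior and generates the observed data conditioned on it:

```latex
z \sim p(z) = \mathcal{N}(z;\, 0,\, I),
\qquad
x \sim p_\theta(x \mid z),
```

where z is the hidden (latent) variable and θ are the parameters of the decoder.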
I know that this all sounds very abstract, so let me give you a more concrete example. 00:05:01.320 |
we have some people who were born in this cave and have lived all their life there since childhood. 00:05:09.680 |
And these people never left the cave, so they only stayed in this area of the cave. 00:05:17.840 |
These people, since childhood, have seen these pictures on the cave wall 00:05:22.840 |
that are projected from these 3D objects through this fire. 00:05:27.320 |
So they are the shadow of these 3D objects here. 00:05:30.080 |
But these people, they don't know that these pictures are actually cast from these 3D objects. 00:05:37.440 |
For them, the horse is something black that moves like this. 00:05:40.480 |
The bird is something black that moves like this. 00:05:43.360 |
So we need to think that we are just like these people. 00:05:49.960 |
But this data actually comes from something that we cannot observe, 00:05:53.880 |
that is, a higher-level representation of this data, 00:05:59.320 |
And we want to learn something about this abstract representation. 00:06:03.000 |
Before we go into the maths of variational autoencoder, 00:06:08.440 |
Because the math is going to be a little hard to follow for some people 00:06:13.920 |
The point is, in order to understand the variational autoencoder, 00:06:19.480 |
Not only the numerical math, but also the concept. 00:06:22.440 |
So what I will try to do is to give you the necessary background to understand the math, 00:06:29.360 |
But at the same time, I will also try to convey some general information, 00:06:34.360 |
some high-level representation of what is happening 00:06:39.160 |
Also, I believe that the VAE is the most important component of stable diffusion models. 00:06:43.400 |
So concepts like ELBO that we will see in the following slides 00:06:49.640 |
it will make it easy for you to understand the stable diffusion. 00:06:52.800 |
Plus, in 2023, I think you shouldn't be memorizing things, 00:06:56.600 |
so just memorizing the architecture of models, 00:06:58.880 |
because ChatGPT can do that faster and better than you. 00:07:04.560 |
You can't be a machine and compete with a machine. 00:07:06.600 |
I also believe that you should try to learn things not only out of curiosity, 00:07:11.640 |
but because that's the true engine of innovation and creativity. 00:07:17.120 |
So let's start by introducing some math concepts 00:07:23.560 |
Don't be scared if you are not familiar with these concepts, 00:07:26.120 |
because I will try to give a higher-level view of what is happening. 00:07:32.440 |
you will still understand what is happening on a higher level. 00:07:36.480 |
We need to know what the expectation of a random variable is, which is this. 00:07:40.120 |
We need the chain rule of probability, which is this, 00:07:44.800 |
All of these three concepts are usually taught in a bachelor's class, 00:07:50.560 |
And another concept that is usually not taught in a bachelor's, 00:07:53.240 |
but I will introduce now, is the Kullback-Leibler divergence. 00:07:56.720 |
This is a very important concept in machine learning, 00:07:59.200 |
and it's a divergence measure that allows you to measure 00:08:02.760 |
the distance between two probability distributions. 00:08:11.720 |
how far apart these two probability distributions are. 00:08:15.600 |
But at the same time, this is not a distance metric, 00:08:20.120 |
So when you have a distance metric, usually, for example, 00:08:25.720 |
if A is one meter apart from B, then B is also one meter apart from A. 00:08:29.680 |
But this doesn't happen with the Kullback-Leibler divergence. 00:08:42.760 |
and it's equal to zero if and only if the two distributions are the same. 00:08:50.160 |
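For reference, these are the standard definitions being referred to here (the slide's exact notation may differ):

```latex
\mathbb{E}_{p(x)}[f(x)] = \int f(x)\, p(x)\, dx,
\qquad
p(x, z) = p(x \mid z)\, p(z),
```

```latex
D_{KL}\big(q \,\|\, p\big) = \int q(x) \log \frac{q(x)}{p(x)}\, dx \;\ge\; 0,
\qquad
D_{KL}(q \,\|\, p) \neq D_{KL}(p \,\|\, q) \text{ in general},
\qquad
D_{KL}(q \,\|\, p) = 0 \iff q = p.
```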
Now, we saw before that we want to model our data 00:08:53.600 |
as coming from a random distribution that we call X, 00:08:56.720 |
which is conditioned on a hidden variable or latent variable called Z. 00:09:02.880 |
marginalize over the joint probability using this relationship here. 00:09:09.760 |
because we need to integrate over all possible values of the latent variable Z. 00:09:15.680 |
It means that in theory, we can calculate it. 00:09:18.440 |
But in practice, it is so slow and so computationally expensive 00:09:23.280 |
So something intractable is like trying to guess your neighbor's Wi-Fi password. 00:09:27.960 |
In theory, you can do it by generating all possible passwords 00:09:31.800 |
But in practice, it will take you thousands of years. 00:09:34.440 |
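Written out in standard notation, the intractable marginal being described, and the Bayes rewriting that the next lines refer to, are:

```latex
p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz,
\qquad
p_\theta(z \mid x) = \frac{p_\theta(x \mid z)\, p(z)}{p_\theta(x)}.
```

The integral runs over every possible value of z, and the posterior on the right needs the very marginal we cannot compute, which is the chicken-and-egg problem mentioned below.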
So this relationship, we can also write it like this 00:09:44.840 |
but we need the ground truth of this, which we don't have, 00:09:48.400 |
because this is the probability distribution over the latent space 00:09:57.480 |
So this looks like a chicken and egg problem, 00:10:00.640 |
because we are trying to find this using this, 00:10:12.320 |
Usually, when you cannot find something that you want, 00:10:23.160 |
And we think that it's parametrized by some parameters theta 00:10:31.320 |
However, what if we could find something that is a surrogate, 00:10:45.200 |
We start with the log likelihood of our data, 00:10:57.080 |
of a probability distribution, which is always equal to one. 00:11:00.360 |
And we can bring this quantity inside the integral, 00:11:03.760 |
because it doesn't depend on the variable that is integrated. 00:11:11.240 |
We can see that this integral is actually an expectation. 00:11:17.920 |
we can apply the equation given by the chain rule of probability. 00:11:22.560 |
We can multiply the numerator and the denominator 00:11:40.720 |
And finally, we can see that the second expectation 00:11:45.840 |
And we know that it's always greater than or equal to zero. 00:11:48.880 |
Now, let me expand this relationship that we have found. 00:11:55.160 |
is equal to this quantity plus this KL divergence. 00:12:11.040 |
that is always greater than or equal to zero. 00:12:17.480 |
without knowing anything about the quantities involved? 00:12:32.720 |
which is always greater than or equal to zero. 00:12:34.720 |
Without knowing anything about your base salary 00:12:41.480 |
that your total compensation is always greater than 00:12:46.360 |
Now, this expression here has the same structure 00:12:52.360 |
So we can infer the same for the first expression. 00:12:56.520 |
That is, the first quantity is always greater than 00:13:00.800 |
without caring what happens to the third quantity. 00:13:03.240 |
So this also means that this is a lower bound for this. 00:13:21.840 |
we are going to automatically maximize this quantity. 00:13:32.560 |
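For reference, the identity this derivation arrives at, in the notation of Kingma and Welling's paper, is:

```latex
\log p_\theta(x)
= \underbrace{\mathbb{E}_{q_\phi(z \mid x)}\!\left[ \log \frac{p_\theta(x, z)}{q_\phi(z \mid x)} \right]}_{\text{ELBO, } \mathcal{L}(\theta, \phi; x)}
\;+\;
\underbrace{D_{KL}\!\big( q_\phi(z \mid x) \,\|\, p_\theta(z \mid x) \big)}_{\ge\, 0}
```

Since the KL term is never negative, the ELBO is a lower bound on the log likelihood; and because the left-hand side does not depend on the approximate posterior q, raising the ELBO with respect to q also shrinks that KL term.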
And then we can see that the second expectation 00:13:43.360 |
with the probability distribution we see here. 00:14:13.160 |
and at the same time, you are minimizing this. 00:14:23.640 |
and your company has profit, revenue, and cost. 00:14:32.280 |
and at the same time, you maximize your cost. 00:14:38.720 |
we are actually maximizing this first quantity here, 00:14:42.960 |
we are minimizing the second quantity we see here. 00:14:46.240 |
Now, let's look at what these quantities mean. 00:14:49.200 |
And for that, I took this picture from a paper 00:14:57.760 |
and we can see that this is a log likelihood. 00:15:16.440 |
So this is what we want our Z space to look like, 00:15:21.760 |
we want our Z space to be a multivariate Gaussian, 00:15:25.000 |
and this is the learned distribution by the model. 00:15:34.000 |
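The two quantities being discussed here correspond to the usual rewriting of the ELBO (again following the paper's notation):

```latex
\mathcal{L}(\theta, \phi; x)
= \underbrace{\mathbb{E}_{q_\phi(z \mid x)}\!\big[ \log p_\theta(x \mid z) \big]}_{\text{reconstruction log-likelihood}}
\;-\;
\underbrace{D_{KL}\!\big( q_\phi(z \mid x) \,\|\, p(z) \big)}_{\text{how far the learned } z \text{ space is from the prior}}
```

Maximizing the ELBO therefore maximizes the first term and minimizes the second, which is exactly what is described just below.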
the model actually is minimizing the distance 00:15:59.560 |
Now, the problem is when you maximize something 00:16:24.360 |
we take the gradient and adjust the weights of the model 00:16:29.560 |
And this is also what happens when we train our models. 00:16:32.720 |
For example, imagine we have a function that is convex, 00:16:38.840 |
and our minimum is here, our initial weights are here, 00:16:48.560 |
to where the direction of growth of the function, 00:16:52.600 |
so the function is growing in this direction, 00:16:58.840 |
And the problem is we are not calculating the true gradient 00:17:12.440 |
stochastic gradient descent and not just gradient descent? 00:17:17.760 |
you need to evaluate the function over the whole data set, 00:17:28.960 |
you get a distribution over the possible gradient. 00:17:32.760 |
So for example, when we use stochastic gradient descent 00:17:36.480 |
and we evaluate the gradient of our loss function, 00:17:44.840 |
And someone proved that if you do it long enough, 00:17:47.960 |
so if you do it over the entire training set, 00:17:55.280 |
Now, the fact is that the variance of this gradient estimate, 00:18:05.680 |
in stochastic gradient descent, is small enough 00:18:08.840 |
so that we can use stochastic gradient descent. 00:18:14.400 |
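The standard result being alluded to here ("someone proved that...") is that the mini-batch gradient is an unbiased estimator of the full gradient; stated loosely:

```latex
\nabla_\theta L(\theta) = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \ell_i(\theta),
\qquad
\mathbb{E}_{\mathcal{B}}\!\left[ \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \nabla_\theta \ell_i(\theta) \right] = \nabla_\theta L(\theta),
```

where B is a mini-batch drawn uniformly from the training set: on average the noisy gradient points in the true direction, and its variance is low enough in practice to be usable.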
if we do the same job with this one, we get an estimator. 00:18:19.880 |
a stochastic quantity, that has a high variance, 00:18:25.340 |
So if we look at the paper by Kingma and Welling, 00:18:28.180 |
they show that there is an estimator for the ELBO, 00:18:32.580 |
and this estimator, however, exhibits a very high variance. 00:18:38.700 |
imagine we are trying to minimize our function. 00:18:41.380 |
If we use an estimator that has high variance, 00:18:45.700 |
suppose we are here and the minimum of the model is here. 00:18:49.300 |
If we are lucky, when we calculate the gradient, 00:18:58.060 |
However, if we are unlucky because it has high variance, 00:19:01.620 |
the model may return a very different gradient 00:19:04.540 |
than what we expect, for example, in this direction, 00:19:06.900 |
and then we will move to the opposite direction, 00:19:09.180 |
which is this one, so it will take us far from the minimum, 00:19:14.740 |
So we cannot use an estimator that has high variance. 00:19:24.700 |
But because it has a high variance, 00:19:53.740 |
the source of randomness outside of the model, 00:19:56.860 |
and we will call it reparameterization trick. 00:19:59.780 |
So the reparameterization trick means basically 00:20:02.260 |
that we take the stochastic component outside of z, 00:20:14.100 |
combine it with the parameters learned by the model, 00:20:19.420 |
which is the mean and the sigma of our multivariate Gaussian 00:20:24.420 |
and then we will run backpropagation through it. 00:20:34.300 |
and we can see here that when we run backpropagation, 00:20:41.580 |
but before, this node here was random, was stochastic, 00:20:46.940 |
so we couldn't run backpropagation through it 00:20:55.820 |
However, if we take the randomness outside of this node 00:21:09.780 |
will also calculate the gradient along this path, 00:21:17.220 |
We will choose a random source that is fixed. 00:21:24.180 |
in case we are using a multivariate Gaussian, 00:21:27.300 |
and so now we can actually calculate the backpropagation. 00:21:30.060 |
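A minimal sketch of the reparameterization trick in PyTorch (variable names and shapes are illustrative): the noise ε comes from a fixed source N(0, I), and z = μ + σ·ε is a deterministic function of the encoder's outputs, so backpropagation can flow through μ and σ but not through ε.

```python
import torch

# Outputs of the encoder for one batch (illustrative shapes and values).
mu = torch.zeros(16, 32, requires_grad=True)       # mean of q(z|x)
log_var = torch.zeros(16, 32, requires_grad=True)  # log sigma^2 of q(z|x)

# Reparameterization: move the randomness into a fixed external source epsilon.
eps = torch.randn_like(mu)       # epsilon ~ N(0, I), no gradient flows through it
std = torch.exp(0.5 * log_var)   # sigma = exp(0.5 * log sigma^2)
z = mu + std * eps               # z is a deterministic function of mu, sigma, eps

# Backpropagation now reaches mu and log_var.
z.sum().backward()
print(mu.grad.shape, log_var.grad.shape)
```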
Plus, this estimator that we found has lower variance. 00:21:35.900 |
in which we replaced the stochastic quantity here, 00:21:43.260 |
which is actually coming from our noise source, 00:21:47.740 |
We combine it with the parameters learned from the model 00:21:56.980 |
This is also called the Monte Carlo estimator. 00:22:00.300 |
We can also prove that this new estimator is unbiased. 00:22:06.060 |
it will actually converge to the true gradient. 00:22:10.580 |
So if we take the gradient of this estimator, 00:22:16.540 |
we can see that we can write this quantity here like this, 00:22:25.900 |
doesn't depend on the parameters of this estimation, 00:22:33.260 |
and then we can write this quantity inside this one here 00:23:00.140 |
So now I want to combine all this knowledge together 00:23:02.860 |
to simulate what the network will actually do. 00:23:06.060 |
So imagine we have a picture here of something. 00:24:04.100 |
but I will try to simplify the meaning behind it. 00:24:11.180 |
We can see it here and it's made of two components. 00:24:23.180 |
is from what we want our distribution to look like. 00:24:26.420 |
And the second one is the quality of the reconstruction, 00:24:36.980 |
how our image is different from the original image, 00:24:40.620 |
so the reconstructed sample from the original sample. 00:24:45.820 |
allows us to calculate the KL divergence between the prior, 00:24:51.460 |
and what is actually the Z space learned by the model. 00:25:15.460 |
plus the sigma learned by the model multiplied, 00:25:42.340 |
we should force our model to learn a positive quantity, 00:25:48.180 |
So we just pretend that we are learning log sigma squared, 00:25:51.420 |
and then we transform it back into sigma squared. 00:26:04.540 |
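Putting the pieces of the loss together as a sketch (PyTorch; mean squared error is used here for the reconstruction term as one common choice, and the closed-form Gaussian KL term is the one from the Kingma and Welling paper), assuming the encoder outputs mu and log sigma squared so that sigma squared = exp(log sigma squared) is automatically positive:

```python
import torch
import torch.nn.functional as F

def vae_loss(x_hat, x, mu, log_var):
    # Reconstruction term: how far the reconstructed sample is from the original.
    recon = F.mse_loss(x_hat, x, reduction="sum")
    # Closed-form KL divergence between N(mu, sigma^2 I) and the prior N(0, I):
    # KL = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    # Minimizing this loss corresponds to maximizing the ELBO.
    return recon + kl
```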
of what is the ELBO, so the ELBO is something 00:26:09.460 |
And I also wanted to show you the derivation of this ELBO 00:26:14.780 |
because this is the same problems that we will face 00:26:17.100 |
when we will talk about the stable diffusion. 00:26:19.060 |
And this part here, I took from the original paper 00:26:27.300 |
If you're wondering why we got this particular formula here 00:26:37.180 |
to have a better understanding on how to derive it yourself. 00:26:42.540 |
and hopefully you learned everything there is to know about, 00:26:46.460 |
at least from a theoretical point of view, about the VAE. 00:26:49.900 |
In my next video, I want to also make a practical example 00:26:53.460 |
on how to code a VAE and how to train a network, 00:26:56.180 |
and then how to sample from this latent space. 00:27:00.980 |
I'm pretty sure that you will have a deep understanding of the VAE, 00:27:08.540 |
Thank you for watching, and please come back to my channel.