Lesson 21: Deep Learning Foundations to Stable Diffusion
Chapters
0:00 A super cool demo with miniai and CIFAR-10
2:55 The notebook
7:12 Experiment tracking and W&B callback
16:09 Fitting
17:15 Comments on experiment tracking
20:50 FID and KID, metrics for generated images
23:35 FID notebook (18_fid.ipynb)
31:07 Get the FID from an existing model
37:22 Covariance matrix
42:21 Matrix square root
46:17 Why it is called Fréchet Inception Distance (FID)
47:54 Some FID caveats
50:13 KID: Kernel Inception Distance
55:30 FID and KID plots
57:09 Real FID - The Inception network
61:16 Fixing (?) UNet feeding - DDPM_v3
68:49 Schedule experiments
74:52 Train DDPM_v3 and testing with FID
79:01 Denoising Diffusion Implicit Models - DDIM
86:12 How does DDIM work?
90:15 Notation in Papers
92:21 DDIM paper
113:49 Wrapping up
00:00:00.000 |
Hello, Jono. Hello, Tanishk. Are you guys ready for lesson 21? 00:00:07.720 |
I don't know what I would have said if you had said no. So good. 00:00:12.760 |
I'm actually particularly excited because I had a little bit of a sneak preview of something 00:00:16.640 |
that Jono has been working on, which I think is a super cool demo of what's possible with miniai. 00:00:30.320 |
Great, thanks, Jeremy. Yeah, so as you'll see, when it's back to Jeremy to talk through 00:00:35.920 |
some of the experiments and things we've been doing, we've been using the Fashion-MNIST 00:00:40.240 |
dataset at a really small scale to really rapidly try out these different ideas and 00:00:44.840 |
see some maybe nuances or things that we'd like to explore further. And so as we were 00:00:49.640 |
doing that, I started to think that maybe it was about time to explore just ramping 00:00:54.000 |
up the level, like seeing if we can go to the next slightly larger datasets, slightly 00:00:58.120 |
harder difficulty, just to double check that these ideas still hold for longer training runs. 00:01:06.960 |
That's a really good idea because I feel pretty confident that the learnings from Fashion- 00:01:11.720 |
MNIST are going to move across; most of the time these things seem to, but sometimes 00:01:17.720 |
they don't and it can be very hard to predict. So it seems like a very wise choice. 00:01:23.000 |
Yeah. And so we'll keep ramping up, but as a next step, one above Fashion-MNIST, I thought 00:01:30.280 |
I'd look at this dataset called CIFAR-10. And so the CIFAR-10 dataset is a very popular dataset, 00:01:35.920 |
originally for things like image classification, but also now for any paper on generative modeling. 00:01:43.080 |
It's kind of like the smallest dataset that you'll see in these papers. And so yeah, if 00:01:48.600 |
you look at the classification results, for example, pretty much every classification 00:01:51.480 |
paper since they started tracking has reported results on CIFAR10 as well as their larger 00:01:57.600 |
datasets. And likewise with image generation, very, very popular, all of the recent diffusion 00:02:04.040 |
papers will usually report CIFAR-10 and their ImageNet, and then whatever large massive dataset they're using. 00:02:13.200 |
We were somewhat notable in 2018 for managing to train it cheaply. So for CIFAR-10, 94% classification accuracy 00:02:20.640 |
is kind of the benchmark. So there was a competition a few years ago where we managed to get to 00:02:25.000 |
that point at a cost of like 26 cents worth of AWS time, I think, which won a big global 00:02:34.760 |
competition. So I actually hate CIFAR10, but we had some real fun with it a few years ago. 00:02:43.120 |
Yeah. And it's good. It's a nice dataset for quickly testing things out, but we'll talk 00:02:43.120 |
about why we, as a group, also don't like it at all. And we'll pretty soon move on from it. 00:02:55.320 |
So one of the things you'll notice in this notebook, I'm basically using all of the same 00:02:59.560 |
code that Jeremy is going to be looking at and explaining. So I won't go into too much, 00:03:03.800 |
but the dataset's also on HuggingFace. So we can load it just like we did Fashion- 00:03:08.040 |
MNIST. The images are three channel rather than single channel. So the shape of the data 00:03:18.680 |
is slightly different to what we've been working with. That's weird. Yeah. So we have instead 00:03:26.280 |
of the single channel image, we have a three channel red, green, and blue image. And this 00:03:31.520 |
And you've got then two images in your batch. So that's batch by channel by height by width, 00:03:39.160 |
I was a little confused by the 32 by 32, is it? 00:03:42.520 |
Oh, yeah. It's fine. I got it now. Batch size is arbitrary. And so if you plot these, 00:03:48.200 |
one of the things, if you look at this, okay, I can see these are different classes. Like 00:03:51.680 |
I know this is an airplane, a frog, an airplane, but it's actually a puzzle with an airplane 00:03:55.560 |
on the cover, a bird, a horse, a car. That one, if you squint, you can tell it's a deer, 00:04:00.920 |
but only if you really know what you're looking for. And so when we started to talk about 00:04:04.880 |
generating these images, this is actually quite frustrating. Like this, if I generated 00:04:09.320 |
this, I'd say this might be the model doing a really bad job. But it's actually that this 00:04:13.840 |
is a boat, this is a dog. It's just that this is what the data looks like. 00:04:17.560 |
And so I've actually got something that can help you out. I'll show later today, which 00:04:22.000 |
is something like this. It's really actually hard to see whether it's good because the 00:04:26.360 |
images are bad. It can be helpful to have a metric we can generate that shows how good the 00:04:31.840 |
samples are. So I'll be showing a metric for that later today. 00:04:35.600 |
Yeah. And that'll be great. And I hope to have that automated. But anyway, I just wanted 00:04:39.960 |
to flag like for visually inspecting these, it's not great. And so we don't really like 00:04:43.760 |
CIFAR 10 because it's hard to tell, but still a good one to test with. So the Noisify and 00:04:50.520 |
everything I'm following what Jeremy is going to be showing exactly, the code works without 00:04:54.560 |
any changes because we're adding random noise in the same shape as our data. So even though 00:04:58.920 |
our data now has three channels, the Noisify function still works fine. If we try and visualise 00:05:04.120 |
the noised images, because we're adding noise in the red, green, and blue channels, and some 00:05:08.840 |
of that's quite extreme values. Yeah, it looks slightly different, looks all crazy RGB. But 00:05:14.920 |
you can see, for example, this frog doesn't have as much noise and it's vaguely visible. 00:05:19.840 |
But it is, it's a nigh-impossible task to look at this and tell what image is hiding underneath. 00:05:24.840 |
So I think this is really neat that you could use the same Noisify. Yeah, and it still, 00:05:35.400 |
it still works thanks to, it's not just that shape thing, but I guess just thanks to kind 00:05:39.840 |
of PyTorch's broadcasting kind of stuff. This often happens, you can kind of change the 00:05:45.960 |
dimensions of things that just keeps working. 00:05:48.240 |
Exactly. And we've been paying attention to those broadcasting rules and the right dimensions and so on. 00:05:52.560 |
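To make that concrete, here's a minimal sketch of a noisify in the spirit of the lesson's (the names and the alphabar argument are illustrative, not the exact notebook code). The noise is drawn with the batch's own shape, and the per-item alpha-bar is reshaped so it broadcasts over however many channels the images happen to have:

```python
import torch

def noisify(x0, alphabar):
    # x0: (n, c, h, w) -- works for 1-channel Fashion-MNIST or 3-channel CIFAR-10 alike
    device = x0.device
    n = len(x0)
    t = torch.randint(0, len(alphabar), (n,), dtype=torch.long)  # a random timestep per item
    eps = torch.randn(x0.shape, device=device)                   # noise with the batch's own shape
    abar_t = alphabar[t].reshape(n, 1, 1, 1).to(device)          # broadcasts over (c, h, w)
    xt = abar_t.sqrt()*x0 + (1 - abar_t).sqrt()*eps              # standard DDPM forward process
    return (xt, t.to(device)), eps
```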
Cool. So I'm going to use the same sort of approach to the learner and UNet, except 00:05:59.560 |
that now obviously I need to specify three input channels and three output channels because 00:06:03.520 |
we're working with three channel images. But I did want to explore for this demo, like, 00:06:08.440 |
okay, how could I maybe justify wanting to do this kind of experiment tracking thing 00:06:13.240 |
that I'll talk about. And so I'm bumping up the size of the model substantially. I've 00:06:17.160 |
gone from, this is the default settings that we were using for Fashion-MNIST, but the Diffusers 00:06:21.160 |
default UNet has, what, nearly 20 times as many parameters: 274 million versus 15 million. 00:06:30.260 |
So we're going to try a larger model. We're going to try some longer training. And so 00:06:34.480 |
I could just do the same training that we've always done just in the notebook, set up a 00:06:41.800 |
learner with ProgressCB to kind of plot the loss, and track some metrics. But yeah, I don't 00:06:49.640 |
know about you, but once it's beyond a few minutes of training, I quickly get impatient 00:06:54.320 |
and I have to wait for it to finish before we can sample. So I'm doing the DDPM sample, 00:06:58.000 |
but I have to, I actually interrupted the training to say, I just want to get a look 00:07:01.400 |
at what it looks like initially and to plot some samples. And again, the sampling function 00:07:06.160 |
works without any modification, but I'm passing in my size to be a three channel image. Yeah. 00:07:12.560 |
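As a rough sketch of what that call looks like (sample stands for the DDPM sampling function from the lesson; the exact signature and batch size here are assumptions):

```python
# same sampler as before; the only change is asking for 3-channel, 32x32 images
samples = sample(model, sz=(16, 3, 32, 32))
```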
And so this is like, we could do it like this, but at some point I would like to A, keep 00:07:17.920 |
track of what experiments I've tried and B, be able to see things as it's going over time, 00:07:24.040 |
including like, I'd love to see what the samples look like if you generated after the first 00:07:27.480 |
epoch, after the second epoch. And so that's where my little callback that I've been playing with comes in. 00:07:34.560 |
So just before you do that, I'll just mention like, I mean, there are simple ways you could 00:07:39.360 |
do that, right? Like, you know, one popular way a lot of people do is that they'll save 00:07:43.800 |
some sample images as files every epoch or two, or we could like, the same way that we 00:07:50.080 |
have an updating plot as we train with fastprogress, we could have an updating set of 00:08:00.120 |
sample images. So there's a few ways we could solve that. That wouldn't handle the tracking 00:08:09.440 |
that you mentioned of like looking over time at how different changes have improved things 00:08:13.160 |
or made them worse, whatever that would, I guess, would require you kind of like saving 00:08:17.040 |
multiple versions of a notebook or keeping some kind of research journal or something. 00:08:24.600 |
It is. And all of that's doable, but I also find like, I'm a little bit lazy sometimes. 00:08:28.640 |
Maybe I don't write down what I'm trying, or, yeah, I've saved untitled number 37 notebooks. 00:08:34.760 |
So yeah, the idea that I wanted to show here is just that there are lots of other solutions 00:08:39.600 |
for this kind of experiment tracking and logging. And one that I really like is called Weights 00:08:44.080 |
and Biases. And so I'll explain what's going on in the code here that I'm running a training 00:08:50.160 |
with this additional Weights and Biases callback. And what it's doing is it's allowing me to 00:08:54.320 |
log whatever I'd like. So I can log samples at a different-- 00:09:00.120 |
Okay, so you're switching to a website here called wandb.ai. So that's where your callback 00:09:06.760 |
is sending information to. Yeah, so Weights and Biases accounts are free 00:09:11.720 |
for personal and academic use. And it's very, very, like, I don't think anyone hates Weights 00:09:15.840 |
and Biases. But it's a very nice service. You sign in and you log in on your computer 00:09:20.280 |
or you get an authentication token. And then you're able to log these experiments and you 00:09:25.440 |
can log into different projects. And what it gives you is for each experiment, anything 00:09:31.160 |
that you call Weights and Biases.log at any step in the training, that's getting logged 00:09:36.720 |
and sent to their server and stored somewhere where you can later access it and display 00:09:40.920 |
it. They have these plots that you can visualize easily. And you can also share them very easily 00:09:46.940 |
in these reports that integrate this data sort of interactive thing. And why that's 00:09:53.640 |
nice is that later you can go and look at-- So this is now the project that I'm logging 00:09:57.720 |
into. You can log multiple runs with different settings. And for each of those, you have 00:10:04.160 |
all of these things that you've tracked, like your training, loss, and validation. But you 00:10:09.240 |
can also track your learning rate if you're doing a learning rate schedule. And you can 00:10:14.120 |
save your model as an artifact and it'll get saved on their server so you can see exactly 00:10:19.080 |
what run produced what model. It logs the code, if you set that option: you can pass save_code 00:10:25.600 |
equals true. And then it creates a copy of your whole Python environment, what libraries 00:10:29.260 |
were installed, what code you ran. So being able to come back later and say, oh, these 00:10:35.280 |
images here, these look really good. I can go back and see, oh, that was this experiment 00:10:40.600 |
here. I can check what settings I used. In the initialization, you can log whatever configuration 00:10:49.640 |
details you'd like in any comments. And yeah, there's other frameworks for this. 00:10:56.440 |
Yeah, in some ways, it's kind of-- initially, when I first saw Weights and Biases, it felt 00:11:02.080 |
a bit weird to me actually sending your information off to an external website because, I mean, 00:11:08.920 |
before Weights and Biases existed, the most popular way to do this was something called 00:11:12.880 |
TensorBoard, which Google provides, which is actually a lot like this, but it's a little 00:11:17.940 |
server that runs on your computer. And so like when you log things, it just puts it 00:11:24.880 |
into this little database on your computer, which is totally fine. But I guess actually, 00:11:33.880 |
there are some benefits to having somebody else run this service instead of running your 00:11:40.200 |
own little TensorBoard or whatever server. One is that you can have multiple people working 00:11:47.500 |
on a project collaborating. So I've done that before, where we will each be sending different 00:11:52.520 |
sets of hyperparameters, and then they'll end up in the same place. Or if you want to 00:11:58.200 |
be really antisocial, you can interrupt your romantic dinner and look at your phone to 00:12:03.440 |
see how your training's going. So yeah, I'm not going to say it's always the best approach 00:12:11.260 |
to doing things, but I think there's definitely benefits to using this kind of service. And 00:12:16.080 |
it looks like you're showing us that you can also create reports for sharing this, which is nice. 00:12:21.600 |
Yeah, yeah. So I like for working with other people or you want to show somebody the final 00:12:32.040 |
results and being able to, yeah, pull together the results from some different runs or just 00:12:40.160 |
say, oh, by the way, here's a set of examples from my two most recent. And things track 00:12:47.720 |
to different steps. What do you think of this? And being able to have this place where everyone 00:12:53.840 |
can go and they can inspect the different loss curves. For any run, they can say, oh, 00:12:58.440 |
what was the batch size for this? Let me go look at the info there. OK, I didn't log it, 00:13:05.000 |
but I logged the time and the epochs and the learning rate. So yeah, I find it quite nice, especially 00:13:10.160 |
in a team or if you're doing lots and lots of experiments to be able to have this permanent 00:13:15.360 |
record that somebody else deals with, and they host the storage and the tracking. Yeah, it's great. 00:13:22.720 |
Wait, and this is all the code you had to write? That's amazing. 00:13:26.640 |
Yeah. So this is using the callback system. The way Weights and Biases works is that you 00:13:32.280 |
start an experiment with this wandb.init, and you can specify any configuration settings 00:13:38.840 |
that you used there. And then anything you need to log is wandb.log, and you pass in whatever 00:13:45.000 |
the name of your value is, for example logging 'loss' and then the value. And once you've 00:13:50.360 |
done, wandb.finish syncs everything up and sends it to the server. 00:13:53.760 |
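Stripped of the callback machinery, the Weights and Biases API being described boils down to three calls. A minimal sketch, with a made-up project name and values:

```python
import wandb

wandb.init(project="cifar10-diffusion", config={"lr": 1e-3, "epochs": 5})  # start an experiment
wandb.log({"loss": 0.123})  # call as often as you like; each key becomes a chart
wandb.finish()              # sync everything up and send it to the server
```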
Oh, this is wild, the way you've inherited from MetricsCB, and you've replaced that underscore 00:13:58.680 |
log that we previously used to allow fastprogress to do the logging, and you've replaced 00:14:03.400 |
it to allow Weights and Biases to do the logging. So yeah, it's really sweet. 00:14:08.520 |
Yeah, yeah. So this is using the callback system. I wanted to do the things that MetricsCB 00:14:14.320 |
normally does, which is tracking different metrics that you pass in. So this will still 00:14:18.480 |
do that. And I just offload to the super, the original MetricsCB method, for things like 00:14:23.960 |
the after batch. But in addition to that, I'd also like to log to Weights and Biases. 00:14:29.440 |
And so before the fit, I initialize the experiment; every batch, I'm going to log the loss; after 00:14:36.280 |
every epoch, the default metrics callback is going to accumulate the metrics and so 00:14:42.680 |
on, and then it's going to call this underscore log function. So I chose to modify that to 00:14:46.680 |
say: I'm going to log my training loss if it's training, I'm going to log my validation loss 00:14:51.760 |
if I'm doing validation, and I'd like to log some samples. And Weights and Biases is quite 00:14:56.320 |
flexible in terms of what you can log, you can create images or videos or whatever. But 00:15:02.480 |
it also takes a matplotlib figure. And so I'm generating samples and plotting them with 00:15:09.440 |
show_image and spitting back the matplotlib figure, which I can then log, and that 00:15:14.040 |
becomes these pretty pictures. And then you can see over time, like every, every time 00:15:19.280 |
that log function runs, which is after every epoch, you can go in and see what the images 00:15:23.760 |
look like. So maybe we can make your code even simpler in the future. If we had show 00:15:29.480 |
images, maybe it could have like an optional return fig parameter that returns the figure, 00:15:35.520 |
and then we could replace those four lines of code with one, I suspect. 00:15:39.960 |
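Here's a rough sketch of the callback pattern being described, assuming miniai's MetricsCB and its _log hook as covered earlier in the course; the method names follow the callback system, but details of Jono's actual notebook may differ:

```python
import wandb

class WandBCB(MetricsCB):  # MetricsCB from miniai, as earlier in the course
    def before_fit(self, learn):
        wandb.init(project="cifar10-diffusion")  # hypothetical project name
        super().before_fit(learn)
    def after_batch(self, learn):
        super().after_batch(learn)               # let MetricsCB accumulate as usual
        if learn.training: wandb.log({"train_loss": learn.loss.item()})
    def _log(self, d):
        wandb.log(d)   # send the metrics dict to W&B instead of fastprogress
        print(d)       # keep the printed version for debugging
    def after_fit(self, learn):
        wandb.finish()
```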
Yeah. Yeah. And I mean, this, I just sort of threw this together. It's quite early still. 00:15:45.640 |
You could also what I've done in the past is usually just create a PIL image where you 00:15:50.400 |
can, you know, make a grid or overlay text or whatever else you'd like, and then just 00:15:55.960 |
log that as a wandb.Image. And otherwise, like apart from that, I'm just passing in this 00:16:01.440 |
callback as an extra callback to my set of callbacks for the learner instead of a metric 00:16:07.920 |
callback. And so when I call that fit, I still get my little progress bar, I still get this 00:16:13.400 |
printed out version because my log function still also prints those metrics just for debugging. 00:16:20.320 |
But instead of having to like watch the progress in the notebook, I can set this running disconnect 00:16:24.480 |
from the server, go have dinner, and then I can check on my phone or whatever. What 00:16:28.520 |
do the samples look like? And okay, cool. They're starting to look less like random 00:16:34.560 |
nonsense, but still not necessarily recognizable. Maybe we need to train for longer. That can 00:16:39.480 |
be the next experiment. What I should probably do next is think of some extra metrics, but 00:16:44.200 |
Jeremy's going to talk about that. So for now, that's pretty much all I had to show 00:16:47.840 |
is just to say, yeah, it's worth as you move to these longer, you know, 10 minutes, one 00:16:52.400 |
hour, 10 hours, these experiments, it's worth setting up a bit of infrastructure for yourself 00:16:56.920 |
so that you know what were the settings I used. And maybe you're saving the model so 00:17:01.200 |
you have the artifact as a result. And yeah, I like this Weights and Biases approach, but there's 00:17:06.080 |
lots of others. The main thing is that you're doing something to track these experiments 00:17:10.280 |
beyond just, you know, creating many different versions of your notebook. 00:17:14.600 |
I love it. One thing I was going to note that, I don't know if many people know, but like 00:17:19.040 |
Weights and Biases can also save the exact code that you used to run for that run. So like 00:17:25.320 |
if you make any changes to your code, and then you know that you don't know which version 00:17:29.200 |
of your code you use for this particular experiment, so then you can figure out exactly what code 00:17:33.680 |
you use. So it's all completely reproducible. And so I love, you know, weights and biases, 00:17:38.440 |
all these different features it has. And I use weights and biases all the time for my 00:17:41.920 |
own research, like almost daily; like I had, you know, a run going just last night, and checked 00:17:46.400 |
on it this morning. So it's like, I use it all the time for my own research. And it's, 00:17:51.040 |
and yeah, like I use it, especially to just know like, oh, this run had this particular 00:17:54.840 |
config. And then like, yeah, the models go straight into weights and biases. And then 00:17:59.120 |
if I want to run a model on the test set, I literally actually take it off of Weights 00:18:03.320 |
and Biases, like download it from Weights and Biases, and run it on the test set. So I use 00:18:07.320 |
it all the time. And also just having the ability to have everything reproducible and 00:18:10.960 |
know exactly what you were doing is very convenient, instead of having to like manually track it 00:18:15.640 |
in some sort of like, I guess, a big Excel sheet or some sort of journal or something 00:18:20.080 |
like that. Sometimes this is, you know, this is a lot more convenient, I feel so yeah, 00:18:25.480 |
lest we get into too much shilling for Weights and Biases, I'm going to put a slightly alternative 00:18:31.760 |
point of view, which is I don't use it or any experiment tracking framework myself, 00:18:40.760 |
which is not to say maybe I could get some benefits by doing so, but I fairly intentionally 00:18:46.400 |
don't, because I don't want to make it easy for myself to try 1000 different hyperparameters 00:18:52.960 |
or do kind of, you know, undirected sampling of things. I like to be very, 00:19:02.520 |
like, directed, you know. And so that's, that's kind of the workflow I'm looking for, is one 00:19:10.240 |
that allows that to happen, right? Constantly going back and refactoring and thinking what 00:19:14.360 |
did I learn and how do I change things from here and never kind of doing like 17 learning 00:19:20.320 |
rates and six architectures and whatever. Now, obviously, that's not something that 00:19:26.120 |
Jono is doing at the moment, but it'd be so easy for him to if he wanted to. 00:19:32.440 |
I could write a script that just does a hundred runs with different models and different tasks, 00:19:37.520 |
and then I can look at my Weights and Biases and just filter by the best loss, which is very tempting. 00:19:42.400 |
So I would say to people like, yeah, definitely be aware that these tools exist. And I definitely 00:19:47.600 |
agree that as we do this, which is early 2023, weights and biases is by far the best one 00:19:54.360 |
I've seen. It has by far the best integration with fastai. And as of today, if Jono's 00:20:00.400 |
pushed it yet, it has by far the best integration with miniai. I think also fastai is the 00:20:08.160 |
best library for using with Weights and Biases. It works both ways. So yeah, no, it's there. 00:20:16.920 |
Consider using it, but also consider not going crazy on on experiments because, you know, 00:20:27.040 |
I think experiments have their place clearly, but also definitely thought out hypotheses, 00:20:35.440 |
testing them, changing your code is overall the approach that I think is best. 00:20:40.880 |
Well, thank you, Jono. I think that's awesome. I've got some fun stuff to share as well, or 00:20:53.840 |
at least I think it's fun. And what I wanted to share is like, well, first of all, 00:21:00.800 |
I should say we had said, we all had said that we were going to look at units this week. 00:21:08.840 |
We are not going to look at units this week, but we have good reason, which is that we 00:21:15.280 |
had said we're going to go from foundations to stable diffusion. That was also a lie because 00:21:20.920 |
we're actually going beyond stable diffusion. And so we're actually going to start showing 00:21:26.240 |
today some new research directions. I'm going to describe the process that I'm using at 00:21:31.800 |
the moment to investigate some new research directions. And we're also going to be looking 00:21:35.680 |
at some other people's research directions that have gone beyond stable diffusion over 00:21:41.560 |
the past few months. So we will get to units, but we haven't quite finished, you know, as 00:21:52.600 |
it turns out, the training and sampling yet. Now, one challenge that I was having as I 00:22:02.440 |
started experimenting with new things was started getting to the point where actually 00:22:09.000 |
the generated images looked pretty good and it felt like, you know, almost like being 00:22:20.760 |
a parent, you know, each time a new set of images would come out, I would want to convince 00:22:24.880 |
myself that these were the most beautiful. And so, yeah, when they're crap, it's obvious 00:22:33.200 |
they're crap, you know, but when they're starting to look pretty good, it's very easy to convince 00:22:36.080 |
yourself you're improving. So I wanted to have a metric which could tell me how good 00:22:43.120 |
they were. Now, unfortunately, there is no such metric. There's no metric that actually 00:22:50.040 |
says do these images, would these images look to a human being like pictures of clothes? 00:22:58.080 |
Because only talking to a person can do that. But there are some metrics which give you 00:23:02.800 |
an approximation of that. And as it turns out, these metrics are not actually a replacement 00:23:13.960 |
for human beings looking at things, but they're a useful addition. So, and I certainly found 00:23:22.840 |
them useful. So I'm going to show you the two most common, well, there's really the 00:23:26.280 |
one most common metric, which is called FID, and I'm going to show another one called KID 00:23:32.120 |
So let me describe and show how they work. And I'm going to demonstrate them using 00:23:46.080 |
the model we trained in the last lesson, which was in DDPM2. And you might remember, we trained 00:23:55.040 |
one with mixed precision, and we saved it as fashion DDPM MP for mixed precision. Okay, 00:24:04.760 |
so this is all the usual imports and stuff. This is all the usual stuff. But there's a 00:24:12.360 |
slight difference this time, which is that we're going to try to get the FID for a model 00:24:18.000 |
we've already trained. So basically, to get the model we've already trained to get its 00:24:24.400 |
FID, we can just torch.load it, right, and then .cuda to pop it on the GPU. So I'm going 00:24:31.200 |
to call that the S model, which is the model for samples, the samples model. And this is 00:24:35.320 |
just a copied and pasted DDPM from last time. So that's for sampling. So we're going 00:24:40.480 |
to do sampling from that model. And so once we've sampled from the model, we're then going 00:24:48.120 |
to try and calculate this score called the FID. Now, what the FID is going to do is it's 00:24:56.720 |
not going to say how good are these images. It's going to say how similar are they to 00:25:05.360 |
real images. And so the way we're going to do that is we're going to actually look specifically 00:25:13.740 |
at features of the images that we generated in these samples. We're going to look at some 00:25:21.560 |
statistics of some of the activations. So what we're going to do, we've generated these samples, 00:25:31.440 |
and we're going to create a new DataLoaders, which contains no training batches, and it 00:25:39.520 |
contains one validation batch, which contains the samples. It doesn't actually matter what 00:25:44.760 |
the dependent variable is, so I just put in the same dependent variable that we already 00:25:48.720 |
had. And then what we're going to do is we're going to use that to extract some features 00:25:58.520 |
from a model. Now, what do we mean by that? So if you remember back to notebook 14, we 00:26:06.160 |
created this thing called summary. And summary shows us at different blocks of our model, 00:26:15.080 |
there are various different output shapes. In this case, it's a batch size of 1024. 00:26:20.120 |
And so after the first block, we had 16 channels, 28 by 28, and then we had 32 channels, 14 00:26:27.400 |
by 14, and so forth, until just before the final linear layer, we had the 1024 batch items, and 00:26:38.640 |
we had 512 channels with no height and width. Now, the idea of FID and KID is that the distribution 00:26:49.000 |
of these 512 channels for a real image has a particular kind of like signature, right? 00:26:58.480 |
It looks a particular way. And so what we're going to do is we're going to take our samples, 00:27:04.200 |
we're going to run it through a model that's learned to predict, you know, fashion classes, 00:27:12.600 |
and we're going to grab this layer, right? And then we're going to average it across 00:27:18.760 |
a batch, right, to get 512 numbers. And that's going to represent the mean of each of those 00:27:26.600 |
channels. So those channels might represent, for example: does it have a pointed collar? 00:27:35.480 |
Does it have, you know, smooth fabric? Does it have sharp heels and so forth, right? And 00:27:46.240 |
you could recognize that something's probably not a normal fashion image if it says, "Oh, 00:27:51.640 |
yes, it's got sharp heels and flowing fabric." It's like, "Oh, that doesn't sound like anything 00:27:58.640 |
we recognize," right? So there are certain kind of like sets of means of these activations 00:28:05.280 |
that don't make sense. So this is a metric for... it's not a metric for an individual 00:28:13.840 |
image necessarily, but it's across a whole lot of images. So if I generate a bunch of 00:28:19.520 |
fashion images, and I want to say, does this look like a bunch of fashion images? If I 00:28:23.680 |
look at the mean, like maybe X percent have this feature and X percent have that feature. 00:28:28.440 |
So if I'm looking at those means, as like comparing the distribution within all these 00:28:31.760 |
images I generated, do roughly the same amount have sharp collars as those in the training set? 00:28:39.480 |
Yeah, and it's actually gonna get even more sophisticated than that. But let's just start 00:28:44.480 |
at that level, which is the feature means. So the basic idea here is that we're going 00:28:51.040 |
to take our samples and we're going to pass them through a pre-trained model that has 00:28:57.080 |
learned to predict what type of fashion something is. And of course, we train some of those 00:29:04.720 |
in this notebook. And specifically, we trained a nice 20 epoch one in the data augmentation 00:29:13.160 |
section, which had a 94.3% accuracy. And so if we pass our samples through this model, 00:29:22.960 |
we would expect to get some, you know, useful features. One thing that I found made this 00:29:29.680 |
a bit complicated, though, is that this model was trained using data that had gone through 00:29:36.200 |
this transformation of subtracting the mean and dividing by the standard deviation. And 00:29:44.200 |
that's not what we're creating in our samples. And so, generally speaking, samples in most 00:29:55.000 |
of these kinds of diffusion models tend to be between negative one and one. So I actually 00:30:01.560 |
added a new section to the very bottom of this notebook, which simply replaces the transform 00:30:08.520 |
with something that goes from negative one to one and just creates those data loaders 00:30:14.400 |
and then trains something that can classify fashion. And I save this as not data aug, but 00:30:22.560 |
data aug two. So this is just exactly the same as before, but it's a fashion classifier 00:30:28.920 |
where the inputs are expected to be between minus one and one. Having said that, it turns 00:30:35.720 |
out that our samples are not between minus one and one. But actually, if you go back 00:30:45.360 |
and you look at DDPM2, we just use TF.to_tensor, and that actually makes images 00:30:54.520 |
that are between zero and one. So actually, that's a bug. Okay, so our images have a bug, 00:31:04.640 |
which is they go between zero and one. So we'll look at fixing that in a moment. But 00:31:07.840 |
for now, we're just trying to get the fit of our existing model. So let's do that. 00:31:13.320 |
So what we need to do is we need to take the output of our model, and we need to multiply 00:31:21.360 |
by two, so that'll be between zero and two, and subtract one. So that'll change our samples 00:31:26.520 |
to be between minus one and one, and we can now pass them through our pre-trained fashion 00:31:32.480 |
classifier. Okay, so now, how do we get the output of that pooling layer? Because that's 00:31:40.480 |
actually what we want; to remind you, we want the output of this layer. So just to kind 00:31:55.200 |
of flex our PyTorch muscles, I'm going to show a couple of ways to do it. So we're going 00:32:02.440 |
to load the model I just trained, the data aug 2 model. And what we could do is, of course, 00:32:08.740 |
we could use a hook. And we have a hooks callback. So we could just create a function which 00:32:17.080 |
just appends the output. So very straightforward. Okay, so that's what we want. We want the 00:32:22.960 |
output. And specifically, it's, so we've got these are all sequentials. So we can just 00:32:32.320 |
go through and go, oh, one, two, three, four, five, the layer that we want. Okay, and so 00:32:37.760 |
that's the module that we want to hook. So once we've hooked that, we can pass that as 00:32:43.480 |
a callback. And we can then, it's a bit weird calling fit, I suppose, because we're saying 00:32:49.720 |
train equals false, but we're just basically capturing. This is just to make one batch 00:32:54.640 |
go through and grab the outputs. So this means now, in our hook, there's going to be a thing 00:33:02.400 |
called outp, because we put it there. And we can grab, for example, a few of those to 00:33:09.000 |
have a look. And yep, here, we've got a 64 by 512 set of features. Okay, so that's one way we can do it. 00:33:15.600 |
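As a sketch of that first approach, assuming the HooksCallback and to_cpu helpers from notebook 14, and with the layer index 5 standing in for wherever the pooled output lives in this particular classifier:

```python
def append_outp(hook, mod, inp, outp):
    # stash each captured output on the hook object itself
    if not hasattr(hook, 'outp'): hook.outp = []
    hook.outp.append(to_cpu(outp))

hcb = HooksCallback(append_outp, mods=[classifier[5]], on_valid=True)  # hook just that module
learn.fit(1, train=False, cbs=[hcb])      # "fit" with train=False: one capture pass, no training
feats = hcb.hooks[0].outp[0].float()      # e.g. a (64, 512) block of pooled features
```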
Another way we could do it is that actually sequential models are what's 00:33:24.640 |
called, in Python, collections; they have a certain API that they're expected to support. 00:33:31.960 |
And one thing a collection can do, like a list, is you can call del to delete something. 00:33:39.120 |
So we can delete this layer and this layer and be left with just these layers. And once 00:33:48.880 |
we do that, that means we can just call capture_preds, because now they don't have the last 00:33:53.260 |
two layers. So we can just delete layers eight and seven and call capture_preds. And one nice 00:34:00.440 |
thing about this is it's going to give us the entire 10,000 images in the test set. 00:34:08.560 |
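And a sketch of the second approach: del works because nn.Sequential supports the mutable-collection protocol. The indices 8 and 7 are specific to this classifier, and capture_preds is the course's helper, so treat the details as assumptions:

```python
del(model[8])   # drop the final layer
del(model[7])   # drop the one before it (indices specific to this model)
feats, ys = learn.capture_preds()   # "predictions" are now the 512-d pooled features
feats = feats.float().squeeze()     # (10000, 512) for the whole test set
```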
So that's what I ended up deciding to do. There's lots of other ways I played around with which 00:34:12.320 |
worked, but I decided to show these two as being two pretty good techniques. Okay, 00:34:17.960 |
so now we've got what 10,000 real images look like at the end of the pooling layer. 00:34:25.240 |
So now we need to do the same for our sample. So we'll load up our fashion DDPM MP, we'll 00:34:34.240 |
sample, let's just grab 256 images for now, make them go between minus one and one, make 00:34:41.840 |
sure they look okay. And as I described before, create a data loaders where the validation 00:34:47.440 |
set just has one batch, which contains our samples, and call capture_preds. Okay, so that's 00:35:00.000 |
going to give us our features. And the reason why is because we're passing the sample to 00:35:09.920 |
model and model is the classifier, which we've deleted the last two layers from. So that's 00:35:19.720 |
going to give us our 256 by 512. So now we can get the means. Now, that's not really enough 00:35:34.080 |
to tell us whether something looks like real images. So maybe I should draw here. 00:35:44.340 |
So we started out with our batch of 256, and our channels of 512. And we squished them 00:36:11.680 |
by taking their mean. So it's now just a vector... no, sorry, wrong way around. 00:36:27.880 |
We squished them this way, so it's 512 long, because this is the mean for each channel. Okay. And we 00:36:42.320 |
did exactly the same thing for the much bigger, you know, full set of real images. So this 00:36:49.400 |
is our samples and this is our real. But when we squish it, that was 10,000 by 512, we get 00:37:02.560 |
again 512. So we could now compare these two, right? But, you know, you could absolutely 00:37:14.280 |
have some samples that don't look anything like images, but have similar averages for 00:37:21.440 |
each channel. So we do a second thing, which is we create a covariance matrix. Now, if 00:37:30.320 |
you've forgotten what this is, you should go back to our previous lesson where we looked 00:37:33.600 |
at it, but just remind you a covariance matrix says, in this case, we do it across the channels. 00:37:41.200 |
So it's going to be 512 by 512. So it's going to take each of these columns, and it says 00:37:51.840 |
in each cell, so here's cell one-one, basically it's saying: what's the difference between 00:38:01.680 |
each element here and the 00:38:06.120 |
mean of the whole column, multiplied by exactly the same thing for a different column. Now, 00:38:13.880 |
on the diagonal, it's the same column twice. So that means that these in the diagonal is 00:38:19.400 |
just the variance, right? But more interestingly, the ones in the off diagonal, like here, is 00:38:27.520 |
actually saying, what's the relationship between column one and column two, right? So if column 00:38:34.660 |
one and column two are uncorrelated, then this would be zero, right? If they were identical, 00:38:41.760 |
right, then it would be the same as the variance in here. So it's how correlated are they. 00:38:49.320 |
And why is this interesting? Well, if we do the same, exactly the same thing for the reals, 00:38:55.320 |
that's going to give us another 512 by 512. And it's going to say things like, so let's 00:39:01.200 |
say this first column was kind of like that, you know, does it have pointy heels, right? 00:39:08.920 |
(Sorry, I can't spell 'heels'.) And the second one might be, does it have flowing fabric, right? 00:39:17.720 |
And this is where we say, okay, if, you know, generally speaking, you would expect these 00:39:23.300 |
to be negatively correlated, right? So over here in the reals, this is probably going 00:39:30.640 |
to have a negative, right? Whereas if over here it was like zero or even worse if it's 00:39:36.760 |
positive, it'd be like, oh, those are probably not real, right? Because it's very unlikely 00:39:41.560 |
you're going to have images that have both, where pointy heels are positively associated 00:39:46.280 |
with a flowing fabric. So we're basically looking for two data sets where their covariance 00:39:56.560 |
matrices are kind of the same and their means are also kind of the same. All right. So there 00:40:08.760 |
are ways of comparing these, you know, basically comparing two sets of data to say, are they, 00:40:24.000 |
you know, from the same distribution? And you can broadly think of it as being like, oh, 00:40:28.440 |
do they have pretty similar covariance matrices? Do they have pretty similar mean vectors? 00:40:35.400 |
And so this is basically what the Fréchet Inception Distance does. Does that make sense? 00:40:45.600 |
Yes. What's striking me now is the similarity to when we were talking 00:40:53.280 |
about like the style loss and those kinds of things. How do we get the types of features 00:40:58.600 |
that occur together, without worrying about where they are in the image: the Gram 00:41:04.080 |
matrices or whatever. Yeah. Now, the particular way of comparing. So, okay. So I've got the 00:41:15.880 |
means and I've got the covariances for my samples. And I've actually just created this 00:41:24.080 |
little calc stats, right? So I always, I'm showing you how I build things, not just things 00:41:28.760 |
that are built, right? So I always create things step by step and check their shapes, 00:41:32.680 |
right? And then I copy the cells and merge them into 00:41:38.440 |
functions. So here's something that gets the means and the covariance matrix. 00:41:47.000 |
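That function is essentially this sketch, where feats is an (n, 512) feature matrix; Tensor.cov treats rows as variables, hence the transpose:

```python
def calc_stats(feats):
    feats = feats.squeeze()               # (n, 512)
    return feats.mean(0), feats.T.cov()   # per-channel means (512,) and covariance (512, 512)
```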
So then I basically call that both for my sample features and for the features of the actual 00:41:53.920 |
dataset, or rather the test set. Now that I have those stats, I can calculate 00:42:01.600 |
this thing called the Fréchet Inception Distance, which is here. 00:42:05.080 |
And basically what happens is we multiply together the two covariance 00:42:14.160 |
matrices and that's now going to make them like bigger, right? So we now need to basically 00:42:22.640 |
scale that down again. Now, if we were working with, you know, non-matrices, you know, if 00:42:31.080 |
you kind of like multiply two things together, then to kind of bring it back down to the 00:42:37.240 |
original scale, you know, you could kind of like take the square root, right? So particularly 00:42:41.800 |
if it was by itself, you took the square root, you get back to the original. And so we need 00:42:45.960 |
to do exactly the same thing to renormalize these matrices. The problem is that we've 00:42:52.320 |
got matrices and we need to take the matrix square root. Now the matrix square root, you 00:43:02.040 |
might not have come across this before, but it exists and it's the thing where the matrix 00:43:09.800 |
square root of the matrix A times itself is A. Now, I'm going to slightly cheat because 00:43:19.600 |
we've used the float square root before and we did not re-implement it from scratch because 00:43:25.120 |
it's in the Python standard library and also it wouldn't be particularly interesting. But 00:43:29.320 |
basically the way you can calculate the float square root from scratch is by using, there's 00:43:36.360 |
lots of ways, but you know, the classic way that you might have done it in high school 00:43:39.680 |
is to use Newton's method, which is where you basically can solve if you're trying to 00:43:45.040 |
calculate A equals root x, then you're basically saying A squared equals x, which means you're 00:43:57.360 |
saying A squared minus x equals zero. And that's an equation that you can solve and you can 00:44:05.560 |
solve it by basically taking the derivative and taking a step along the derivative a bunch 00:44:09.880 |
of times. You can basically do the same thing to calculate the matrix square root. And so 00:44:21.880 |
here it is, right? It's the Newton method, but because it's for matrices it's slightly 00:44:26.960 |
more complicated, so it's a short method and I'm not going to go through it, but it's basically 00:44:32.240 |
the same deal. You go through up to 100 iterations and you basically do something like traveling 00:44:40.320 |
along that kind of derivative and then you say, okay, well, the result times itself ought 00:44:49.440 |
to equal the original matrix. So let's subtract the matrix times itself from the original 00:44:54.960 |
matrix and see whether the absolute value is small and if it is, we've calculated it. 00:45:00.920 |
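For reference, here's a sketch of that iteration, the Newton-Schulz variant commonly used for this: normalise the matrix, iterate, rescale, and stop once the candidate times itself matches the input. The details are an approximation of the notebook's version:

```python
import torch

def sqrtm_newton_schulz(mat, num_iters=100):
    nrm = mat.norm()                     # normalise so the iteration converges
    Y = mat / nrm
    I = torch.eye(len(mat)).to(mat)      # identity, matching dtype/device
    Z = torch.eye(len(mat)).to(mat)
    for _ in range(num_iters):
        T = (3*I - Z @ Y) / 2
        Y, Z = Y @ T, T @ Z
        res = Y * nrm.sqrt()             # un-normalise the candidate square root
        if (res @ res - mat).abs().max() < 1e-5: break  # res@res should equal mat
    return res
```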
Okay. So that's basically how we do a matrix square root. So we do, that's that. And so 00:45:07.400 |
now that we have, strictly speaking, implemented it from scratch, we're allowed to use the one 00:45:10.520 |
that already exists. PyTorch doesn't have one, sadly, so we have to use the one from SciPy: 00:45:17.520 |
scipy.linalg. So this is basically going to give us a measure of similarity between the 00:45:28.280 |
two covariance matrices. And then we, here's the measure of similarity between the two 00:45:36.840 |
mean vectors, which is just the sum of squared errors. And then basically for reasons that 00:45:42.300 |
aren't really interesting, but it's just normalizing, we subtract what's called the trace, which 00:45:46.640 |
is the sum of the diagonal elements, and we subtract two times the trace of this thing. 00:45:53.920 |
And that's called the Fréchet Inception Distance. So a bit hand-wavy on the math, because I 00:45:59.400 |
don't think it's particularly relevant to anything, but it gives you a number which 00:46:04.720 |
represents how similar is, you know, this for the samples to this for some real data. 00:46:17.640 |
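Putting it all together: this is the standard Fréchet distance between two Gaussians, $\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert^2 + \mathrm{Tr}\left(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\right)$. Here's a sketch close to the notebook's calc_fid, using scipy.linalg.sqrtm now that we've earned it:

```python
import torch
from scipy import linalg

def calc_fid(m1, c1, m2, c2):
    # squared distance between means, plus the covariance term with the matrix square root
    csr = torch.from_numpy(linalg.sqrtm((c1 @ c2).numpy()).real.copy()).float()
    return (((m1 - m2)**2).sum() + c1.trace() + c2.trace() - 2*csr.trace()).item()
```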
Now it's weird, it's called the Fréchet Inception Distance when we've done nothing to do with 00:46:22.400 |
Inception. Well, the reason why is that people do not normally use the fast.ai part two custom 00:46:30.960 |
Fashion-MNIST data aug 2 pickle. They normally use a more famous model. They normally use 00:46:38.200 |
the Inception model, which was an image net winning model from Google Brain from a few 00:46:44.080 |
years ago. There's no reason whatsoever that Inception is a good model to use for this, 00:46:51.120 |
it just happens to be the one which the original paper used. And as a result, everybody now 00:46:56.880 |
uses that, not because they're sheep, but because you want to be able to compare your results 00:47:02.440 |
with other papers' results, perhaps. We actually don't. We actually want to compare our results 00:47:08.760 |
with our other results, and we're going to get a much more accurate metric if we use 00:47:16.200 |
a model that's good specifically at recognizing fashion. So that's why we're using this. So 00:47:23.840 |
very, very few people bother to use this. Most people just pip install pytorch-fid or 00:47:30.000 |
whatever it's called and use Inception. But unless you're 00:47:35.000 |
comparing to papers, it's better to use a model that you've trained on your data and 00:47:38.680 |
you know is good at that. So I guess this is not a FID, it's a... well, maybe FID now stands 00:47:46.960 |
for Fashion-MNIST Inception Distance. I don't know what it stands for. I should have called it something. 00:47:54.920 |
I wanted to bring up two other caveats of FID, especially then like in papers is like, 00:48:02.920 |
the other thing is that FID is dependent on the number of samples that you use. So as 00:48:08.800 |
the number of samples you use for measuring FID increases, it's more accurate; it's 00:48:15.580 |
less accurate if you use fewer samples. Well, it's actually biased. So if you use 00:48:20.760 |
fewer samples, it's too high, specifically. Yeah. So in papers, you'll see them report 00:48:28.880 |
how many samples they used. And so even then comparing to other papers and comparing between 00:48:34.360 |
different models and different things, you want to make sure that you're comparing with 00:48:37.440 |
the same amount of samples. Otherwise, it might just be high because they just used a smaller 00:48:41.400 |
number of samples or something like this. So you want to make sure that's comparable. 00:48:45.640 |
And then the other thing that is because I guess it's a kind of a side effect of using 00:48:50.680 |
the Inception network in these papers is the fact that all of these are at a size 299 by 00:48:58.000 |
299, which is like the size that the Inception model was trained at. So actually, when you're 00:49:03.760 |
applying this Inception network for measuring this distance, you're going to be resizing 00:49:08.680 |
your images to 299 by 299, which in different cases that may not make much sense. So like 00:49:15.520 |
in our case, we're working with 32 by 32 or 28 by 28 images. These are very small images 00:49:21.840 |
and if you resize it to 299, or in other cases, this is now kind of an issue with some of 00:49:27.960 |
these latest models, you have these large 512 by 512 or 1024 by 1024 images. And then 00:49:35.720 |
you're, you know, kind of shrinking these images to 299 by 299. And you're losing a 00:49:41.160 |
lot of that detail and quality in those images. So actually, it's kind of become a problem 00:49:46.840 |
with some of these latest papers, when you look at the FID scores and how they're comparing 00:49:50.720 |
them. And then visually, when you see them, you can kind of notice, oh, yeah, these are 00:49:54.460 |
much better images, but the FID score doesn't capture that as well, because you're actually 00:49:59.120 |
using these much smaller images. So there are a bunch of different caveats. And so FID, 00:50:04.680 |
you know, it's very good for like, yeah, it's nice and simple and automated for this sort 00:50:08.720 |
of comparison, but you have to be aware of all these different caveats of this metric 00:50:13.840 |
So excellent segue, because we're going to look at exactly those two things right now. 00:50:20.320 |
And in fact, there is a metric that compares the two distributions in a way that is not 00:50:27.720 |
biased. So it's not necessarily higher or lower if you use more or less samples, and 00:50:33.760 |
it's called the KID: the Kernel Inception Distance. It's actually significantly 00:50:40.800 |
simpler to calculate than the Fréchet Inception Distance. And basically, what you do is you 00:50:48.320 |
create a bunch of groups, a bunch of partitions, and you go through each of those partitions 00:50:54.880 |
and you grab a few of your X's at a time and a few of your Y's at a time. And then you 00:51:02.040 |
calculate something called the MMD, which is here, which is basically that, again, the 00:51:11.560 |
details don't really matter. We basically do a matrix product and we actually take the 00:51:20.480 |
cube of it. This k is the kernel. And we basically do that for the first sample compared 00:51:27.640 |
to itself, the second compared to itself, and the first compared to the second. And 00:51:33.560 |
we then normalize them in various ways, add the two self-comparisons together, and subtract 00:51:41.520 |
the comparison with the other one. And this one actually does not use the stats; it doesn't use the 00:51:50.320 |
means and covariance matrices. It uses the features directly. And the actual final result 00:52:00.720 |
is basically the mean of this calculated across different little batches. Yeah, again, the 00:52:08.320 |
math doesn't really matter as to exactly why all these are exactly what they are, but it's 00:52:15.440 |
going to give you, again, a measure of the similarity of these two distributions. 00:52:21.560 |
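For the curious, here's a sketch of that squared-MMD estimate with the cubed (polynomial) kernel; KID is then just the mean of this over a bunch of random subsets of the two feature sets. The normalisations follow the standard unbiased estimator rather than the notebook verbatim:

```python
def squared_mmd(x, y):
    # polynomial kernel: k(a, b) = (a.b / d + 1)^3 -- the "cube" mentioned above
    def k(a, b): return (a @ b.T / a.shape[-1] + 1)**3
    m, n = len(x), len(y)
    kxx, kyy, kxy = k(x, x), k(y, y), k(x, y)
    # unbiased estimate: drop the diagonal (each sample compared with itself)
    kxx_sum = kxx.sum() - kxx.diag().sum()
    kyy_sum = kyy.sum() - kyy.diag().sum()
    return kxx_sum/(m*(m-1)) + kyy_sum/(n*(n-1)) - 2*kxy.sum()/(m*n)
```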
At first I was confused as to why more people weren't using this, because people don't tend 00:52:26.560 |
to use this and it doesn't have this nasty bias problem. And now that I've been using 00:52:30.720 |
it for a while, I know why, which is that it has a very high variance, which means when 00:52:35.440 |
I call it multiple times with just like samples with different random seeds, I get very different 00:52:41.560 |
values. And so I actually haven't found this useful at all. So we're left in the situation, 00:52:51.560 |
which is, yeah, we don't actually have a good unbiased metric. And I think that's the truth 00:52:57.340 |
of where we are with best practices. And even if we did, all it would tell you is, like, how 00:53:04.580 |
similar distributions are to each other. It doesn't actually tell you whether they look 00:53:07.700 |
any good, really. So that's why pretty much all good papers, they have a section on human 00:53:14.140 |
testing. But I've definitely found this fairly useful for me for like comparing fashion images, 00:53:22.040 |
which particularly like humans are good at looking at like faces that are reasonably 00:53:26.000 |
high resolution and be like, "Oh, that eye looks kind of weird," but we're not good at 00:53:29.880 |
looking at 28 by 28 fashion images. So it's particularly helpful for stuff that our brains 00:53:35.560 |
aren't good at. So I basically wrapped this up into a class, which I call ImageEval, for 00:53:40.600 |
evaluating images. And so what you're going to do is you're going to pass in a pre-trained 00:53:46.520 |
model for a classifier and your data loaders, which is the thing that we're going to use 00:53:54.880 |
to basically calculate the real images. So that's going to be the data loaders that were 00:54:04.440 |
in this learner, so the real images. And so what it's going to do in this class, then again, 00:54:12.520 |
this is just copying and pasting the previous lines of code and putting them into a class. 00:54:15.800 |
This is going to be then something that we call capture_preds on to get our features 00:54:21.600 |
for the real images, and then we can also calculate the stats for the real images. And 00:54:27.680 |
so then we can get the FID by calling calc_fid, which is the thing we already had, passing 00:54:34.780 |
in the stats for the real images and calculating the stats for the features from our samples, 00:54:42.640 |
where the features, the thing that we've seen before, we pass in our samples, any random 00:54:48.560 |
Y value is fine, so I just have a single tensor there, and call capture_preds. So we can now 00:54:54.600 |
create an ImageEval object, passing in our classifier, passing in our DataLoaders with 00:55:03.240 |
the real data, and any other callbacks you want. And if we call fid, it takes about a quarter 00:55:09.200 |
of a second, and 33.9 is the FID for our samples. Okay, then KID: 00:55:19.800 |
KID's going to be on a very different scale. It's only 0.05, so KIDs are generally much 00:55:24.200 |
smaller than FIDs. So I'm mainly going to be looking at FIDs. And so here's what happens 00:55:31.320 |
if we calculate the FID on sample zero and then sample 50 and then sample 100 and so forth, all the 00:55:38.760 |
way up to 900. And then we also do samples 975, 990, and 999. And so you can see over 00:55:46.760 |
time, our samples' FIDs improved. So that's a good little test. There's something curious 00:55:53.640 |
about the fact that they stopped improving about here. So that's interesting. I've not 00:55:59.120 |
seen anybody plot this graph before. I don't know if Jono or Tanishk, if you guys have, 00:56:03.280 |
I feel like it's something people should be looking at, because it's really telling you 00:56:09.260 |
whether your sampling is making consistent improvements. 00:56:12.920 |
And to clarify, this is like the predicted denoised sample at the different stages during sampling? 00:56:19.860 |
If I was to stop sampling now and just go straight to the predicted x-naught, what would I get? 00:56:24.960 |
So, just to check our samples: yeah, we save the predicted x-naught-hat at each 00:56:30.920 |
time step. Yep. Yep, exactly. Same for KID. And I was hoping that they would look the same, 00:56:38.580 |
and they do. So that's encouraging: KID and FID are basically measuring the same thing. 00:56:43.480 |
And then something else that I haven't seen people do, but I think it's a very good idea 00:56:47.040 |
is to take the FID of an actual batch of data. Okay. And so that tells us how good we could 00:56:53.040 |
get. Now that's a bit unfair because of the different sizes: our data is 512, our 00:57:02.440 |
sample is 256, but anyway, it's a pretty huge difference. 00:57:09.240 |
And then, yeah, the second thing that Tanishk talked about, which I thought I'd actually 00:57:12.600 |
show is: what does it take to get a real FID, to use the Inception network? So I didn't 00:57:20.920 |
particularly feel like re-implementing the Inception network. So I guess I'm cheating 00:57:24.240 |
here. I'm just going to grab it from pytorch-fid. But there's absolutely no reason to study 00:57:29.800 |
the Inception Network because it's totally obsolete at this point. And as Tanishk mentioned, 00:57:37.120 |
it wants 299 by 299 images, which actually you can just pass resize_input to have that 00:57:43.160 |
done for you. It also expects three-channel images. So what I did is I created a wrapper 00:57:52.520 |
for an Inception v3 model that when you call forward, it takes your batch and replicates 00:58:05.000 |
the channel three times. So that's basically creating a three-channel version of a black 00:58:12.160 |
and white image just by replicating it three times. So with that wrapping, and again, this 00:58:18.480 |
is good flexing of your PyTorch muscles. Try to make sure you can replicate this, that 00:58:26.400 |
you can get an Inception model working on your batch of Fashion-MNIST samples. 00:58:36.120 |
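A sketch of that wrapper, assuming the InceptionV3 from the pytorch-fid package: its resize_input flag does the 299 by 299 resizing, and its forward returns a list of block outputs, hence the [0]:

```python
from torch import nn
from pytorch_fid.inception import InceptionV3   # pip install pytorch-fid

class IncepWrap(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = InceptionV3(resize_input=True)  # resizes inputs to 299x299 internally
    def forward(self, x):
        # replicate the single channel three times to fake an RGB image
        return self.m(x.repeat(1, 3, 1, 1))[0]
```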
from there, we can just pass that to our ImageEval instead. And so on our samples, that 00:58:43.760 |
gives us 63.8, and on a real batch of data, it gives 27.9. And I find this a good sign: 00:58:51.720 |
this is much less discriminative than our real Fashion MNIST classifier, because that's 00:58:57.920 |
only a difference of a ratio of three or so. The fact that our FID for real data using 00:59:08.000 |
our real classifier was 6.6, I think that's pretty encouraging. Yeah, so that is that. 00:59:18.080 |
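A hedged sketch of the wrapper described above, assuming the Inception model comes from the pytorch-fid package as mentioned; the class name `IncepWrap` is illustrative.

```python
import torch
from torch import nn
from pytorch_fid.inception import InceptionV3

class IncepWrap(nn.Module):
    "Make InceptionV3 accept single-channel batches by faking RGB."
    def __init__(self):
        super().__init__()
        self.m = InceptionV3(resize_input=True)   # resizes inputs to 299x299 for us
    def forward(self, x):
        # replicate the single channel three times to get a 3-channel image
        return self.m(x.repeat(1, 3, 1, 1))[0]    # pytorch-fid returns a list of block outputs
```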
And we now have a FID. More specifically, we now have an ImageEval. Did you guys have 00:59:26.420 |
any questions or comments about that before we keep going? 00:59:31.960 |
It's worth bearing in mind that pretty much every other FID you see reported is going to be set up for CIFAR-10: 00:59:41.360 |
tiny 32 by 32 pixel images, resized up to 299 and fed through an Inception that was trained on 00:59:46.400 |
ImageNet, not CIFAR-10. Yeah, so bear in mind that, once again, this is a slightly 00:59:51.880 |
weird metric. And even things like the image resizing algorithms 00:59:57.140 |
in PyTorch and TensorFlow might be slightly different. Or if you saved your images as JPEGs and 01:00:02.640 |
then reloaded them, your FID might be twice as bad. 01:00:05.880 |
Yeah, it makes a big difference. Yeah, exactly. 01:00:09.880 |
So just to reiterate, the takeaway from all of this that I get is that it's really useful if 01:00:16.080 |
everything's the same: using the same backbone model, using the same approach, the 01:00:20.440 |
same number of samples; then you can compare to other samples. But yeah, for one set 01:00:27.000 |
of experiments a FID might be good, because of the way everything's set up, and for 01:00:31.200 |
another, that might be terrible. So if you want to compare to a paper or whatever, it's hard to do reliably. 01:00:36.880 |
So maybe the approach is that, like, if you're doing your own experiments, 01:00:42.640 |
these sorts of metrics are good, but then if you're going to compare to other models, 01:00:46.040 |
it's best to rely on human studies. And that, yeah, 01:00:50.600 |
I think, is kind of the sort of approach or mindset that we should be having when it 01:00:57.840 |
comes to these metrics. Yeah, or both, you know. But yeah, so we're going to see this is going to be very useful 01:01:02.680 |
for us. And pretty much all the time, we're going 01:01:09.000 |
to use the same number of samples, and we're going to use the same Fashion MNIST-specific classifier. 01:01:16.000 |
So the first thing I wanted to do was fix our bug. And to remind you, the bug was that 01:01:21.760 |
we were feeding into our UNet, in DDPM v2 and the original DDPM, images that go from 01:01:31.560 |
zero to one. And yeah, that's wrong. That's like, nobody does that; everybody feeds 01:01:37.880 |
in images that are from minus one to one. So that's very easy to fix. You just... 01:01:44.840 |
Do you mind me just asking, like, why is that a bug? Why is it a bug? I mean, it's like everybody 01:01:50.680 |
knows it's a bug because that's what everybody does. Like, I've never seen anybody do anything 01:01:54.960 |
else and it's very easy to fix. So I fixed it by adding this to DDPM v2 and I reran it 01:02:04.040 |
and it didn't work. It made it worse. And this was the start of, you know, a few horrible 01:02:14.400 |
days of pain because, like, when you, you know, fix a bug and it makes things worse, 01:02:22.240 |
that generally suggests there's some other bug somewhere else that somehow is offset 01:02:26.600 |
your first bug. And so I had to go, you know, I basically went back through every other 01:02:31.240 |
notebook at every cell and I did find at least one bug elsewhere, which is that we hadn't 01:02:39.760 |
been shuffling our training sets the whole time. So I fixed that, but it's got absolutely 01:02:45.160 |
nothing to do with this. And I ended up going through everything from scratch three times, 01:02:49.760 |
rerunning everything three times, checking every intermediate output three times. So 01:02:53.160 |
days of, you know, depressing and annoying work and made no progress at all. At which 01:03:00.840 |
point I then asked Jono's question to myself more carefully, and provided a less dismissive 01:03:09.240 |
response to myself, which was: well, I don't know why everybody does this, actually. So 01:03:18.440 |
I asked Tanishk and Jono, and I was like, have you guys seen 01:03:22.720 |
any math, papers, whatever, that's based on this particular input range? And yeah, you 01:03:35.920 |
guys were both like, no, I haven't. It's just what everybody does. So at that 01:03:43.920 |
point, it raised the possibility that like, okay, maybe what everybody does is not the 01:03:50.840 |
right thing to do. And is there any reason to believe it is the right thing to do? Given 01:03:57.840 |
that it seemed like fixing the bug made it worse, maybe not. But then it's like, well, 01:04:05.680 |
okay, we are pretty confident from everything we've learned and discussed that having centered 01:04:11.000 |
data is better than uncentered data. So having data that go from zero to one clearly seems 01:04:16.560 |
weird. So maybe the issue is not that we've changed the center, but that we've scaled 01:04:20.480 |
it down so that rather than having a range of two, it's got a range of one. So at that 01:04:25.480 |
point, you know, I did something very simple, which was I did this, I subtracted 0.5. So 01:04:38.600 |
now rather than going from 0 to 1, it goes from minus 0.5 to 0.5. And so the theory here 01:04:45.200 |
then was, okay, if our hypothesis is correct, which is that the negative one to one range 01:04:50.400 |
has no foundational reason for being. And we've accidentally hit on something, which 01:04:56.400 |
is that a range of one is better than a range of two. And this should be better still, because 01:05:01.000 |
this is a range of one and it's centered properly. And so this is DDPMv3. And I ran that. And 01:05:08.360 |
yes, it appeared to be better. And this is great, because now I've got FID: I was able 01:05:14.480 |
to run FID on DDPMv2 and on DDPMv3, and it was dramatically, dramatically 01:05:20.280 |
better. And in fact, I was running a lot of other experiments at the time, which we will 01:05:28.040 |
talk about soon. And like, all of my experiments were totally falling apart when I fixed the bug. 01:05:33.600 |
And once I did this, all the things that I thought weren't working suddenly started working. 01:05:42.960 |
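A minimal sketch of the three input scalings being compared; the function names are mine, purely for illustration.

```python
# DDPM_v2 and the original DDPM notebooks: images left in [0, 1] (the "bug")
def scale_01(x):   return x
# the "standard" fix everybody uses: [-1, 1]; here it made results worse
def scale_pm1(x):  return x * 2 - 1
# DDPM_v3: [-0.5, 0.5]; properly centered, but keeping a range of one
def scale_half(x): return x - 0.5
```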
So this is often the case, I guess, is that bugs can highlight accidental discoveries. 01:05:52.800 |
And the trick is always to be careful enough to recognize when that's happened. Some people 01:06:01.480 |
might remember the story. This is how the noble gases were discovered. A chemistry experiment 01:06:07.700 |
went wrong and left behind some strange bubbles at the bottom of the test tube. And most people 01:06:13.800 |
would just be like, huh, whoops, bubbles. But people who are careful enough actually 01:06:19.080 |
went, no, there shouldn't be bubbles there. Let's test them carefully. It's like they 01:06:22.400 |
don't react. Again, most people would be like, oh, that didn't work. The reaction failed. 01:06:28.160 |
But if you're really careful, you'll be like, oh, maybe the fact they don't react is the 01:06:33.160 |
interesting thing. So yes, being careful is a key part of the journey. 01:06:39.040 |
When you say things like it didn't work or it was worse: when you first showed us this 01:06:42.600 |
thing, the images looked fine. The FID was slightly worse, but it was okay. 01:06:49.520 |
And if you trained it longer, it eventually got better, mostly. There were some things, 01:06:53.040 |
like sampling occasionally went wrong, one image in a hundred or something like that. 01:06:58.120 |
But it wasn't like everything completely fell apart. It's just that things 01:07:02.800 |
were slightly worse than expected. And if you were doing the run-and-gun, try- 01:07:08.960 |
a-bunch-of-things approach, it's like, oh well, I just doubled my training time and set a few 01:07:12.880 |
runs going and looked at the Weights and Biases stats later. And oh, that seems like it's 01:07:15.720 |
better now; we just needed to train for longer. If you have infinite GPUs and lots of money, 01:07:21.600 |
you would never notice this. So yeah, the fact that you picked up on 01:07:27.640 |
it showed that you had this deep intuition for where it should be at this stage in training 01:07:32.040 |
versus where it was, what the samples should look like. And you had the FID as well, to 01:07:36.120 |
say like, okay, I would have expected a FID of nine and I'm getting 14. What's up here? 01:07:42.080 |
And that was enough to start asking these questions, and we all jumped on it. We all started digging in. 01:07:48.080 |
Yeah, I mean, definitely, I drive the people I work with crazy. I don't know why you guys 01:07:53.720 |
aren't crazy yet, but it's this kind of: no, I need to know exactly why this is not 01:08:00.920 |
exactly what we expected. But yeah, this is why: I find that when something's mysterious 01:08:08.080 |
and weird, it means that there's something you didn't understand, and that's an opportunity 01:08:11.400 |
to learn something new. So that's what we did. And so that was quite exciting because 01:08:20.720 |
yeah, going -0.5 to 0.5 made the FID better still. And I was definitely, yeah, I moved 01:08:29.040 |
from this frame of mind of, like, total depression. I was so mad. I still remember when I spoke 01:08:36.360 |
to Jono, I was just so upset. And then suddenly, like: oh my gosh, we're actually 01:08:41.920 |
onto something. So I started experimenting more and a bit more confidence at this point, 01:08:48.640 |
I guess. And one thing I started looking at was our schedule. We'd always been copying 01:08:55.960 |
and pasting this standard, again, set of stuff. And I started questioning everything. Why 01:09:02.080 |
is this the standard? Like, why are these numbers here? And we don't see any particular 01:09:08.600 |
reason why those numbers were there. And I thought, well, we should maybe experiment 01:09:14.100 |
with them. So to make it easier, I created a little function that would return a schedule. 01:09:20.920 |
Now you could create a new class for a schedule, but something that's really cool is there's 01:09:25.480 |
a thing in Python called SimpleNamespace, which is a lot like a struct in C; it basically 01:09:30.620 |
lets you wrap up a little bunch of keys and values as if it's an object. So I created 01:09:37.840 |
this little SimpleNamespace, which contains our alphas, our alpha-bars, and our sigmas, for 01:09:46.280 |
our normal beta max of 0.02. This is what we always do. And then, yeah, there's 01:09:54.680 |
another paper which mentions an alternative approach, the cosine schedule, which 01:10:01.760 |
is where you basically set alpha-bar equal to the cosine of (t as a fraction of big T, times pi over 2), 01:10:13.560 |
squared. And if you make that your alpha-bar, you can then basically reverse back 01:10:21.000 |
out to calculate what alpha must have been. And so we can create a schedule for this cosine 01:10:27.440 |
schedule as well. And yeah, this cosine schedule is, I think, pretty recognized as being better 01:10:37.840 |
than this linear schedule. And so I thought, okay, it'll be interesting to look at how 01:10:44.600 |
they compare. And in fact, really all that matters is the alpha-bar. The alpha-bar is the 01:10:54.960 |
total amount of noise that you're adding. So in DDPM, when we do noisify, alpha-bar 01:11:07.320 |
is the amount of the image, and 1 minus alpha-bar is the amount of noise. 01:11:15.160 |
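A hedged sketch of both schedules as little functions returning a SimpleNamespace, per the description above; the field names `a`, `abar`, and `sig` are assumptions.

```python
import math, torch
from types import SimpleNamespace

def linear_sched(betamin=0.0001, betamax=0.02, n_steps=1000):
    "The standard DDPM linear beta schedule."
    beta = torch.linspace(betamin, betamax, n_steps)
    alpha = 1. - beta
    return SimpleNamespace(a=alpha, abar=alpha.cumprod(dim=0), sig=beta.sqrt())

def abar_cos(t, T):
    "Cosine schedule: abar = cos(t/T * pi/2) squared."
    return (t / T * math.pi / 2).cos() ** 2

def cos_sched(n_steps=1000):
    "Reverse the cosine abar back out to per-step alphas."
    ts = torch.arange(n_steps)
    ab = abar_cos(ts, n_steps)
    alpha = ab / abar_cos(ts - 1, n_steps)   # abar_t / abar_{t-1}
    return SimpleNamespace(a=alpha, abar=ab, sig=(1 - alpha).sqrt())
```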
Yeah. So I just printed those out, plotted those, for the normal linear schedule and 01:11:21.600 |
this cosine schedule. And you can really see the linear schedule really sucks badly. 01:11:27.480 |
It's got a lot of time steps where it's basically about zero. And that's something we can't really 01:11:39.360 |
do anything with, you know, whereas the cosine schedule is really nice and smooth and there's 01:11:47.000 |
not many steps which are nearly zero or nearly one. So I thought, so I was kind of inclined 01:11:53.000 |
to try using the cosine schedule, but then I thought, well, it'd be easy enough to get 01:11:57.360 |
rid of this big flat bit by just decreasing beta max. That'd be another 01:12:01.960 |
thing we can do. So I tried, oh, sorry, first of all, I should mention that the other thing 01:12:06.880 |
that's really important is the slope of these curves, because that's how much things are 01:12:11.480 |
stepping during the sampling process. And so here's the slope of the lin and the cosine. 01:12:18.120 |
And you can see the cosine slope, really nice, right? You have this nice smooth curve, whereas 01:12:24.640 |
the linear is just a disaster. So yeah, if I change beta max to 0.01, that actually gets 01:12:34.320 |
you nearly the same curve as the cosine. So I thought that was very interesting. It kind 01:12:41.480 |
of made me think like, why on earth does everybody always use 0.02 as the default? And so we 01:12:47.160 |
actually talked to Robin, who is one of the two lead authors on the stable diffusion paper. 01:12:53.600 |
And we talked about all of these things and he said, oh yeah, we noticed not exactly this, 01:12:58.520 |
but we experimented with everything. And we noticed that when we decreased beta max, we 01:13:03.560 |
got better results. And so actually stable diffusion uses beta max at 0.012. I think 01:13:10.200 |
that might be a little bit higher than they should have picked, but it's certainly a lot 01:13:12.520 |
better than the normal default. So it's interesting talking to Robin to see all of these kinds 01:13:17.600 |
of experiments and things that we tried out, they had been there as well and noticed the same things. 01:13:27.120 |
But the input range as well: they have this magical factor of 0.18215 that they scale 01:13:33.960 |
the latents by. And if you ask why, they're like, oh yeah, we wanted the latents to 01:13:36.960 |
be, like, roughly unit range or whatever, but that's also, like, reducing the range. 01:13:44.160 |
Exactly. We independently discovered this idea. Yeah, exactly. So 01:13:54.640 |
we'll be talking more about like what's actually going on with that maybe next lesson. Anyway, 01:14:03.160 |
so here's the curves as well. They're also pretty close. So at this point I was kind 01:14:06.440 |
of thinking, well, I'd like to like change as little as possible. So I'm going to keep 01:14:10.600 |
using a linear schedule, but I'm just going to change beta max to 0.01 for my next version 01:14:16.360 |
of DDPM. So that's what I've got here: linear schedule, beta max 0.01. And so that I wouldn't 01:14:22.360 |
really have to change any of my code, I just put those in the same variable 01:14:25.520 |
names that I've always used. So then noisify is exactly the same as it always has been. 01:14:31.680 |
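A hedged sketch of the noisify step from the earlier DDPM notebooks, which is unchanged here; only the schedule feeding it `abar` differs. Names follow the sketch above.

```python
import torch

def noisify(x0, abar):
    "Mix a clean batch x0 with noise according to a random timestep per image."
    n = len(x0)
    t = torch.randint(0, len(abar), (n,), dtype=torch.long)
    eps = torch.randn(x0.shape, device=x0.device)
    abar_t = abar[t].reshape(-1, 1, 1, 1).to(x0.device)
    # sqrt(abar) scales the image, sqrt(1 - abar) scales the noise
    xt = abar_t.sqrt() * x0 + (1 - abar_t).sqrt() * eps
    return (xt, t.to(x0.device)), eps
```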
So now I just repeat everything that I've done before. So now, when I show a batch of 01:14:37.800 |
data, I can already see that there are more actually recognizable images, which I think 01:14:45.440 |
is very encouraging. Previously, almost all of them had been pure noise, which is 01:14:50.480 |
not a good sign. So, okay. So now I just train it exactly the same as DDPM v2, and save 01:15:00.800 |
this as fashion_ddpm3. Oh, and the other thing I've done here is, you know, this did 01:15:08.040 |
turn out to work pretty well, so I actually decided: let's keep going even further. So I 01:15:13.040 |
doubled all of my channels from before, and I also tripled the number of epochs, 01:15:19.960 |
because things were going so well, I was like, how well could they go? So we've got a bigger 01:15:23.400 |
model trained for longer. It takes a few minutes; the 25 here is the number of epochs. 01:15:31.880 |
So samples exactly the same as it always has been. So create 512 samples and here they are. 01:15:41.400 |
And they definitely look to me, you know, great. Like they, I'm not sure I could recognize 01:15:49.160 |
whether these are real samples or generated samples. But luckily, you know, we can test 01:15:55.760 |
them, so we can load up our data_aug2 classifier, delete the last two layers, pass that to ImageEval, 01:16:04.720 |
and get a FID for our samples. And it's eight. And then I chose 512 for a reason, because 01:16:12.960 |
that's our batch size. So then I can compare that like-with-like to the FID for the actual 01:16:17.920 |
data at 6.6. So this is hugely exciting to me. We've got down to a FID that is nearly 01:16:26.060 |
as good as real images. So I feel like this is, you know, in terms of image quality for 01:16:37.040 |
small unconditional sampling, I feel like we're done, you know, pretty much. And so 01:16:46.960 |
at this point, I was like, OK, well, can we make it faster? You know, at the same quality. 01:16:51.960 |
And I just wanted to experiment with a few things, like really obvious ideas. And in 01:16:55.560 |
particular, I thought: we're calling this a thousand times, which means we're running 01:17:09.160 |
the model a thousand times. And that's slow. And most of the time you 01:17:15.600 |
just move a tiny bit, so the model output is pretty much the same; the noise being 01:17:21.040 |
predicted is pretty much the same. So I just did something really obvious, which is I decided: 01:17:26.720 |
let's only call the model every third time, you know, and maybe also every time for the last 50, 01:17:34.000 |
to help with fine-tuning. I don't know if that's necessary. Other than that, it's exactly the 01:17:39.280 |
same. So now this is basically three times faster. And yeah, samples look basically the 01:17:47.600 |
same. So the FID is 9.78 versus 8.1, and this is within 01:17:52.880 |
the normal variance of FID. So I don't know, like, you'd have to run this a few times or 01:17:57.360 |
use bigger samples. But this is basically saying like, yeah, you probably don't need 01:18:01.800 |
to call the model a thousand times. I did something else slightly weird, which is I 01:18:08.080 |
basically said: oh, let's create a different schedule for how often we call the model. 01:18:14.200 |
So I created a little step schedule. It basically said: for the 01:18:18.520 |
first few time steps, just call the model every 10, then for the next few, every nine, and 01:18:23.800 |
then every eight and so forth. And just for the last hundred, call it every step. 01:18:28.480 |
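A hedged sketch of the simpler every-third variant of that idea: reuse the previous noise prediction except on the steps where the model actually runs. `ddpm_step` stands in for the usual update, and all names are illustrative.

```python
import torch

@torch.no_grad()
def sample_skipped(model, xt, n_steps=1000, every=3, dense_last=50):
    noise_pred = None
    for i in reversed(range(n_steps)):
        # only run the (slow) model every `every` steps, plus every step near the end
        if noise_pred is None or i % every == 0 or i < dense_last:
            noise_pred = model(xt, i)
        xt = ddpm_step(xt, i, noise_pred)   # the usual DDPM update, unchanged
    return xt
```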
So that makes it even faster. Samples look good. It's definitely 01:18:35.720 |
worse now, you know, but it's still not bad. So, yeah, I kind of felt like, 01:18:42.380 |
all right, this is encouraging. And this stuff, before we fixed the minus-one-to-one 01:18:47.840 |
thing, looked really bad, you know; that's why I was thinking that my code 01:18:53.040 |
was full of bugs. So at this point I'm thinking, okay: we can create 01:18:57.560 |
extremely high quality samples using DDPM. What's the like, you know, best paper out 01:19:04.720 |
there for doing it faster. And the most popular paper for doing it faster is DDIM. So I thought 01:19:12.920 |
we might switch to this next. So we're now at the point where we're not actually going 01:19:18.580 |
to retrain our model at all, right? If you noticed with these different sampling approaches, 01:19:26.000 |
I didn't retrain the model at all. We're just saying, okay, we've got a model. The model 01:19:30.560 |
knows how to estimate the noise in an image. How do we use that to call it multiple times 01:19:37.440 |
to denoise, using iterative refinement, as Jono calls it. And so DDIM is another way 01:19:50.640 |
of doing that. So, um, what we're going to do, so I'm going to show you how I built my 01:20:00.200 |
own DDIM from scratch. And, um, I kind of cheated, which is I, there's already an existing 01:20:08.280 |
one in diffusers. So I decided I will use that first, make sure everything works, and 01:20:15.520 |
then I'll try and re-implement it from scratch myself. Um, so that's kind of like when there's 01:20:20.720 |
an existing thing that works, you know, that's what I like to do. And it's been really good 01:20:25.400 |
to have my own DDIM from scratch because now I can modify it, you know, and I've made it 01:20:30.780 |
much more concise code than the diffusers version. So now, we had created this 01:20:41.320 |
class called UNet, which passed the tuple of x's through as individual parameters and returned 01:20:52.920 |
the .sample. But not surprisingly, given that this comes from diffusers and 01:20:58.840 |
we want to use the diffusers schedulers, the diffusers schedulers assume this has 01:21:06.560 |
not happened. It wants the x, you know, as a tuple, and it expects to find the thing called 01:21:11.560 |
.sample. So here's something crazy. When we save this thing, this pickle, it doesn't 01:21:21.560 |
really know anything about the code, right? It just knows that it's from a class called 01:21:28.360 |
UNet. So we can actually lie. We can say: oh yeah, that class called UNet, it's actually 01:21:35.760 |
the same as UNet2DModel, with no other changes, and Python doesn't know or care, right? 01:21:43.120 |
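A hedged sketch of that trick. Pickle records only the module and class name, so rebinding the name before loading swaps in the other class; the checkpoint path is hypothetical.

```python
import torch
from diffusers import UNet2DModel

# "lie" about what UNet is: unpickling looks the name up and finds this instead
UNet = UNet2DModel

model = torch.load('models/fashion_ddpm3.pkl')   # hypothetical path
# model(...) now returns the diffusers output object, so .sample is available again
```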
So we can now load up this model, and it's going to use this UNet. Okay. So this 01:21:49.200 |
is where it's useful to understand how Python works behind the scenes, right? 01:21:53.720 |
It's a very simple programming language. So we've now got a model which we've trained, 01:22:00.600 |
but now it just returns the output with the .sample on it. That means we can 01:22:04.320 |
use it directly with the diffusers schedulers. So we'll start by actually repeating what 01:22:10.040 |
we already know how to do, which is use a DDPM scheduler. So we have to tell it what 01:22:13.680 |
beta we used to train. And so we can grab some random data and say, okay, 01:22:22.440 |
we're going to start at time step 999. So let's create a batch of data and then predict 01:22:29.800 |
the noise. And then this is the way the diffusers thing works: you call scheduler.step, and 01:22:37.560 |
that's the thing which does those lines; that's the thing that calculates the next x_t given the 01:22:48.480 |
noise. So that's what scheduler.step does, and that's why you pass in x_t and the time 01:22:54.700 |
step and the noise. And that's going to give you a new sample. And so I ran that as usual, 01:23:03.200 |
first cell by cell to make sure I understood how it all worked. I then copied those cells 01:23:07.240 |
and merged them together and chucked them in a loop. So this is now going to go through 01:23:12.560 |
all the time steps, use a progress bar to see how we're going, get the noise, call step 01:23:19.780 |
and append. So this is just DDPM, but using diffusers, and not surprisingly it gives 01:23:28.240 |
basically the same results, you know, very nice results, as 01:23:33.480 |
we got from our own DDPM. And so we can now use the same code we've used before to create 01:23:40.040 |
our image evaluator. And I decided, yeah, we're now going to go right up to 2048 images 01:23:48.880 |
at a time. This is the size I found is big enough that it's reasonably stable. 01:23:54.860 |
And so we're now down to 3.7 for our FID, whereas the data itself has a FID of 1.9. So 01:24:02.760 |
again, it's showing that our DDPM output is basically very nearly unrecognizably different from real 01:24:11.600 |
data using its distribution of those activations. So then we can switch to DDIM by just saying 01:24:20.160 |
DDIM scheduler. And so with DDIM, you can say: I don't want to do all thousand steps, 01:24:25.120 |
I just want to do 333 steps, i.e. every third. So that's basically a bit 01:24:35.680 |
like this sample skipping of doing every third, but DDIM, as we'll see, does it in a smarter 01:24:42.320 |
way. And so here's exactly the same code basically as before, but I put it into a little 01:24:51.000 |
function. Okay. So I can basically pass in my model, the size, the scheduler, and 01:24:58.840 |
then there's a parameter called eta, which is basically how much noise to add; 01:25:04.840 |
one means just add all the noise. And so this is now going to be three 01:25:10.920 |
times faster. And yeah, the FID's basically the same. That's encouraging. So then I 01:25:16.680 |
did 200 steps: it's basically the same. 100 steps: at this point, okay, the FID's 01:25:25.240 |
getting worse. And then 50 steps, worse still; 25 steps, worse still. That's interesting. 01:25:38.160 |
Like, when you get down to 25 steps, what does it look like? And you can see that they're 01:25:42.360 |
kind of, like, too smooth. You know, they don't have interesting fabric 01:25:49.880 |
swirls so much, or buckles or logos or patterns, as much as these ones, which have 01:25:58.920 |
got a lot more texture to them. So that's kind of what tends to happen. So you can still 01:26:03.440 |
get something out pretty fast, but that's kind of how they suffer. So, okay. 01:26:12.480 |
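A hedged sketch of the diffusers-based loop just described; the DDIMScheduler arguments follow diffusers' public API as best I recall, and `model` is the pickled UNet from above.

```python
import torch
from diffusers import DDIMScheduler

sched = DDIMScheduler(beta_end=0.01)     # match the beta max we trained with
sched.set_timesteps(100)                 # 100 steps instead of all 1000

xt = torch.randn(512, 1, 32, 32)         # random starting point
for t in sched.timesteps:
    with torch.no_grad():
        noise_pred = model(xt, t).sample             # the .sample trick from earlier
    xt = sched.step(noise_pred, t, xt).prev_sample   # one denoising step
```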
So how does DDIM work? Well, DDIM is nice. It actually, in my opinion, makes things 01:26:19.300 |
a lot easier than DDPM. So there's basically an equation from the paper, which Tanishk 01:26:28.880 |
will explain shortly. But basically, what I did is I grabbed the sample 01:26:38.200 |
function from here and I split it out into two bits. One bit is the bit that says 01:26:48.920 |
what are the time steps, creates that random starting point, loops through, finds what 01:26:56.300 |
my current alpha-bar is, gets the noise, and then basically does the same as sched.step: calls 01:27:04.400 |
some function, right? And that's been pulled out. So this allows me to now create 01:27:10.120 |
my own different step functions. So I wrote the DDIM step, and basically all I did was take this 01:27:22.820 |
equation and turn it into code. Actually, this one is the second equation from the 01:27:28.600 |
paper. Now, it's a bit confusing, in that the notation here is different: 01:27:36.800 |
what DDPM calls alpha-bar, this paper calls alpha. So you've got to look out for that. 01:27:43.280 |
So basically you'll see I've got here x_t, and one minus alpha- 01:27:50.800 |
bar, which we're going to call beta-bar; so beta-bar square root times the noise, and this 01:27:58.040 |
here, the neural net output, is the noise. Okay. And here I've got: 01:28:12.840 |
my next x_t is alpha-bar t-minus-1 square root times this, and you can 01:28:23.040 |
see here it says predicted x-naught, so here's my predicted x-naught; plus beta-bar t-minus-1 minus 01:28:31.240 |
sigma squared, square root, times the noise again; that's the same thing as here. Okay. And then 01:28:42.480 |
plus a bit of random noise, which we only add if we're not at the last step. 01:28:48.720 |
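A hedged sketch of that DDIM step, written with DDPM-style alpha-bar notation and with beta-bar standing for 1 minus alpha-bar; the argument layout is mine.

```python
import torch

def ddim_step(xt, noise_pred, abar_t, abar_t1, eta):
    "One DDIM update from step t to the earlier step t1, per eq. 12 of the paper."
    bbar_t, bbar_t1 = 1 - abar_t, 1 - abar_t1
    sig = eta * ((bbar_t1 / bbar_t) * (1 - abar_t / abar_t1)).sqrt()
    x0_hat = (xt - bbar_t.sqrt() * noise_pred) / abar_t.sqrt()        # predicted x0
    xt = abar_t1.sqrt() * x0_hat + (bbar_t1 - sig**2).sqrt() * noise_pred
    if sig > 0: xt = xt + sig * torch.randn_like(xt)                  # fresh noise, skipped at the last step
    return xt
```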
So I can call that. And rather than saying a hundred steps, I said 01:28:56.120 |
skip 10 steps; so do every tenth step, which is basically going to be a 01:29:00.240 |
hundred steps. And you can see here, actually, this happened to do a bit better for my 01:29:05.640 |
hundred steps. It's not bad at all. So yeah, I mean, getting to this point 01:29:19.000 |
has been a bit of a lifesaver, to be honest, because, you know, I can now run a 01:29:24.720 |
batch of 2048 samples; I can sample them in under a minute, which doesn't 01:29:31.520 |
feel painful. And so, you know, I'm now at a point where I can actually get a pretty good measure 01:29:38.400 |
of how I'm doing in a pretty reasonable amount of time, and I can, you know, easily 01:29:46.160 |
compare it. And I've got to admit, you know, the difference between a FID of five and eight 01:29:51.440 |
and 11, I can't necessarily tell by eye. So for fashion, I think FID is better than 01:29:57.640 |
my eyes for this, as long as I use a consistent sample size. So yeah, Tanishk, did you want 01:30:06.320 |
to talk a bit about, you know, the ideas of why we do this or where it comes from or what 01:30:14.880 |
Can I say a little bit before we do that, which is just that what you have there, Jeremy, 01:30:19.040 |
which is like a screenshot from the paper and then the code that as close as possible 01:30:24.480 |
tries to follow that, like the difference that makes for people is huge. Like I've got 01:30:30.120 |
a little research team that I'm doing some, you know, contract work with. And the fact 01:30:35.040 |
that, like, it's called alpha in the DDIM paper and alpha-bar elsewhere, and then in the 01:30:38.520 |
code that they were copying and pasting from, it was called a and b for alpha and beta... 01:30:44.160 |
And it's like, you can get things kind of working by copying and pasting things, and it 01:30:48.720 |
all just sort of kind of works. But just spending the time to actually take two screenshots 01:30:52.640 |
of equations 14 and 16 from the paper, put them in there, and rewrite the code 01:30:57.720 |
with some comments and things to say, like: this is what this is, 01:31:00.840 |
this is that part from the equation. It's like, you know, the look of pain on their faces 01:31:06.520 |
when I said, oh, by the way, did you notice that it's called alpha there and alpha- 01:31:09.400 |
bar there? They're like, yes, how could they do that? You know, you could 01:31:12.880 |
just tell how many hours had been spent, you know, grinding through the text and figuring out what's what. 01:31:17.960 |
Yeah, and building this stuff in notebooks is such a good idea, like we're doing with 01:31:23.920 |
miniai, because, you know, the next engineer to come along and work on it can see the 01:31:31.960 |
equation right there, and you can add prose and stuff. So I think, you know, nbdev works 01:31:39.000 |
particularly well for this kind of development. 01:31:43.520 |
Yeah, before I talk about this, I just wanted to briefly, in the context 01:31:51.720 |
of all of these different notations, mention this meme I recently created, which I thought was 01:31:58.320 |
relevant: each diffusion model paper basically has a different 01:32:03.320 |
notation. So it's just like this: they all try to come up with their own universal 01:32:07.440 |
notation, and it just keeps proliferating. 01:32:14.360 |
Yes, exactly. We need to implement diffusion models in APL. So yeah, the paper that Jeremy 01:32:25.400 |
had implemented was this denoising diffusion implicit model paper. And if you look at the 01:32:33.400 |
paper again, you can see like the notation could be again, a little bit intimidating, 01:32:38.680 |
but when we walk through it, we'll see it's not too bad actually. So I'll just bring up, 01:32:44.680 |
I guess, some of the important equations, and also compare and contrast, you know, the notation 01:32:53.120 |
and the equations of DDPM with DDIM. 01:32:57.280 |
Not only is it not too bad, I actually discovered it makes life a lot easier. The DDIM notation 01:33:03.640 |
and equations are a lot easier to work with than DDPM's. So I found my life is better since switching. 01:33:12.160 |
Yes, yes. I think a lot of people prefer to use DDIM as well. Yeah. So yeah, basically, 01:33:20.280 |
let's see here. So yeah, in both DDIM and DDPM, we have this same 01:33:32.120 |
sort of equation. This equation is exactly the same. This is telling us the predicted 01:33:40.920 |
denoised image. So basically we predict the... you can see my pointer, right? 01:33:50.260 |
By the way, so the little double-headed arrow in the top right, does that, if you click 01:33:55.080 |
that, do you get more room for us to see what's going on? 01:33:59.960 |
Yeah, I see. Yeah, that works much better. Yeah. So we have our predicted noise. So our 01:34:11.600 |
model is predicting the noise in the image. It is also passed the time step, but that 01:34:17.880 |
is just omitted here; it's basically implied by the x_t, but our model also takes in the 01:34:23.320 |
time step. And so it's predicting the noise in this x_t, our noisy image. And we are trying 01:34:29.940 |
to remove the noise; that's what this whole term here is: remove the noise. Because 01:34:39.240 |
the noise that we're predicting is unit variance noise, we have to scale the variance of 01:34:44.080 |
the noise appropriately to remove it from our noisy image. So we have to scale the noise 01:34:51.080 |
and subtract it out of the noisy image. And that's how we would get our predicted x-naught. 01:35:01.000 |
And I think we derived this one before, by looking at the equation for x_t in the noisify 01:35:10.680 |
function and rearranging it to solve for x-naught. And that's what you get. 01:35:17.960 |
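That rearrangement, written out: the noisify step sets $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, so solving for $x_0$ gives the predicted clean image.

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\;\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}
```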
Yes, that's basically what this is. So basically the idea is, okay, instead of noisifying it 01:35:25.400 |
where we're starting out with X0 and some noise and get an XT, we're doing the opposite 01:35:30.780 |
where we have some noise and we have XT. So how can we get X0? So that's what this equation 01:35:36.000 |
is. So that's the predicted X0 or our predicted clean image. And this equation is the same 01:35:43.000 |
for both DDPM and DDIM, but these distributions are what's different between DDPM and DDIM. 01:35:50.840 |
So we have these distribution, which tells us, okay, so if we have XT, which is our current 01:35:58.320 |
noisy image, and X0, which is our clean image, can we find out what some sort of intermediate 01:36:05.360 |
noisy image is in that process? And that's XT minus one. So we have a distribution for 01:36:12.760 |
that. And so that tells us how to get such an image. And so in the DDPM paper, 01:36:19.000 |
they define this distribution and explain the math behind it. But yeah, basically, they 01:36:25.200 |
have some equations. So you have, again, a Gaussian distribution with some sort of mean 01:36:32.000 |
and variance, but it's, again, some form of you have this sort of interpolation between 01:36:38.840 |
your original clean image and your noisy image. And that gives you your intermediate, slightly 01:36:47.800 |
less noisy image. So that's what this is giving. Given a clean image and a noisy image, you're 01:36:54.240 |
slightly less noisy image. And so the sampling procedure that we do with DDPM basically is: 01:37:03.440 |
predict the x0, and then plug it into this distribution to give you 01:37:10.140 |
your slightly less noisy image. So maybe it's worth drawing that out. So let's say 01:37:19.280 |
we have some sort of distribution; I'm just sketching something here. 01:37:24.120 |
So then, in this case, I'm showing 01:37:30.480 |
a one-dimensional example. It's kind of a one-dimensional 01:37:35.920 |
curve that still sits in a 2D space. And any point on this 01:37:42.200 |
represents an actual image that you want to sample from, right? So this is where your 01:37:46.680 |
distribution of actual images would lie. And you want to estimate this. So the sort 01:37:53.400 |
of algorithm that we've been seeing here says: okay, we take some random point, 01:38:00.160 |
this is some random point that we choose, you know, when we start out. And what we did 01:38:05.240 |
is we learned this function, the score function, to take us to this manifold, but it's only 01:38:10.400 |
going to be accurate in some space; 01:38:15.000 |
it would only be accurate in some area. So we get an estimate of the score function, and it tells 01:38:20.040 |
us the direction to move in. And it's going to give us the direction to predict our denoised 01:38:26.880 |
image, right? So basically, let's say your 01:38:33.600 |
score function in reality is some curve, 01:38:38.520 |
okay? Some curve that points to your... oops, it points here. So that's your 01:38:43.640 |
score function. And, you know, the value here: score function basically means 01:38:48.280 |
your gradient. Yeah. Yes, yes, it's a gradient. So, you know, we're again doing some form 01:38:54.720 |
of, in this case, I guess you would say gradient ascent, because you're not really minimizing 01:39:00.120 |
the score, you're maximizing it; sorry, you're maximizing the likelihood 01:39:05.360 |
of that data point being an actual data point. You want to go towards it. So you're doing 01:39:10.840 |
this sort of gradient ascent process; you're following the gradient to get there. So 01:39:17.920 |
when we estimate epsilon theta and predict our noise, what we're doing is getting 01:39:24.280 |
the score value here. And then we can, you know, follow that. And we follow it to 01:39:30.720 |
some point; I'm kind of exaggerating here. But this point will now represent our 01:39:38.920 |
x-zero-hat. And in reality, you know, 01:39:47.980 |
that's not going to be an actual data point; it won't be on the 01:39:52.920 |
distribution. So, you know, it's not going to be a very good estimate of a clean 01:39:58.240 |
image at the beginning. But, you know, we only have that estimate at the beginning at 01:40:02.920 |
this point, and we have to follow it all the way to some place. So this is where we 01:40:07.880 |
follow it to. And then we want to find some sort of x t minus one. So that's what our 01:40:16.160 |
next point is. And so that's what our second distribution tells us. And it basically takes 01:40:22.900 |
us all the way back to maybe some point here. And now we can re-estimate 01:40:30.080 |
the score function, or our gradient, over there, you know, do this prediction of 01:40:35.200 |
noise. And, you know, it may be a more accurate score function estimate. And maybe we go 01:40:40.880 |
somewhere here. And then we re-estimate and get another point, and then we follow it. And 01:40:47.240 |
so it's kind of this iterative process, where we're trying to follow the score function 01:40:52.740 |
to the end point. And in order to do so, we first have to estimate our x-zero-hat and 01:40:58.360 |
then basically add back some noise to get, you know, 01:41:07.120 |
a new estimate, and keep following, and add back a little bit more noise, and keep estimating. 01:41:11.400 |
So that's what we're doing here in these two steps: we have our x-zero-hat, and then we 01:41:16.920 |
have this distribution. And that's how we do regular DDPM. And I think that's 01:41:23.160 |
maybe where breaking it up into two steps is a bit clearer. I don't think 01:41:30.040 |
the DDPM paper really talks about it too much, but the DDIM paper 01:41:36.740 |
really hammers that point home, I think, especially in their update equation. 01:41:42.720 |
So that's DDPM. But then with DDPM, the one thing is that 01:41:52.320 |
you look at your prediction, use that to make a step, and you also add back some additional 01:41:56.920 |
noise that's always fixed. Right: there's no parameter to control how much extra noise 01:42:01.960 |
you add back at each step. Right, exactly. So, let's see 01:42:10.200 |
here. Yeah, so basically, you won't be exactly at this point; you'll be, you 01:42:14.400 |
know, in that general vicinity. And adding that noise also helps: 01:42:19.520 |
you don't want to fall into specific modes where it's like, oh, 01:42:26.280 |
this is the most likely data point; you want to add some noise so you can 01:42:30.200 |
explore other data points as well. So yeah, the noise can also 01:42:36.080 |
help, and that's something you really can't control with DDPM. And this is something that 01:42:40.920 |
DDIM explores a little bit further, in terms of the noise, and even trying to get rid of 01:42:45.600 |
the noise altogether in DDIM. So with the DDIM paper, the main difference is literally 01:42:54.160 |
this one equation; that's really all it is, in terms of changing this distribution 01:43:01.360 |
where you predict the less noisy... sorry, the less noisy image. 01:43:10.640 |
And basically, as you can see, you have this additional parameter now, which 01:43:18.360 |
is sigma. And the sigma controls how much noise, like we were just mentioning, is going 01:43:25.280 |
to be part of this process. And you can actually, for example, if you want, set sigma 01:43:29.800 |
to zero. And then you can see here, now the variance would be zero. And so 01:43:34.440 |
this becomes a completely deterministic process. So if you want, this could be completely deterministic. 01:43:41.840 |
So that's one aspect of it. And then, yeah, the other aspect of that, right: 01:43:52.880 |
the reason it's called DDIM and not DDPM is because it's not probabilistic anymore; 01:43:57.840 |
it can be made deterministic. So the name was changed for that reason. But the other thing 01:44:04.280 |
is, you would think that you've kind of changed the model altogether, with a new 01:44:08.920 |
distribution altogether. And so you'd say, oh, wouldn't you have to, like, train a different 01:44:14.160 |
model for this purpose? But it turns out the math works out such that the same model objective 01:44:19.840 |
works well with this distribution as well. 01:44:23.120 |
And in fact, I think that's what they were setting out to do from the very beginning: what 01:44:27.840 |
kind of other models can we get with the same objective? And so this is what they're able 01:44:34.600 |
to do: you have this new parameter that you can introduce, in this 01:44:40.040 |
case kind of controlling the stochasticity of the model, and you can 01:44:50.000 |
still use the same exact trained model that you had. 01:44:55.320 |
So what this means is that this actually is just a new sampling algorithm, and not anything 01:45:00.480 |
new with the training itself. And this is just, yeah, just like we talked about a new 01:45:05.680 |
way of sampling the model. And so, yeah, given now this equation, 01:45:13.800 |
you can rewrite your x-t-minus-one term. And again, we're doing the same sort of thing, 01:45:18.360 |
where we split it up into predicting the x-zero, and then a direction term pointing back to your 01:45:27.400 |
x-t. And also, if you need to add a little bit of noise back in, like Jono was saying, 01:45:35.080 |
you can do so: you have this extra term here, and the sigma controls that term. 01:45:41.440 |
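The DDIM update being described, written in DDPM-style notation where $\bar\alpha$ is what the DDIM paper calls $\alpha$; the $\eta$ parameterisation of $\sigma_t$ discussed next is included for reference.

```latex
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,
          \underbrace{\left(\frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\bar\alpha_t}}\right)}_{\text{predicted } x_0}
        + \underbrace{\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t)}_{\text{direction pointing to } x_t}
        + \sigma_t \epsilon_t,
\qquad
\sigma_t = \eta \sqrt{\tfrac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}} \sqrt{1-\tfrac{\bar\alpha_t}{\bar\alpha_{t-1}}}
```

With $\eta = 1$ this matches vanilla DDPM; with $\eta = 0$ the process is deterministic.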
And again, like we said, looking at the DDIM equation versus the DDPM 01:45:48.080 |
equation, you have to be careful: the alphas here refer to alpha-bars in the DDPM 01:45:53.960 |
notation. So that's the other caveat. So yeah, and setting this sigma-t to one 01:46:02.840 |
particular value will give you back DDPM. So sometimes instead they will write it 01:46:10.240 |
in terms of, as Jeremy mentioned, eta: 01:46:20.520 |
basically, sigma is equal to eta times this coefficient. So, 01:46:28.840 |
sorry, let me just go back. So basically, yeah, you have eta here; 01:46:43.600 |
this is where eta would go. So if it's one, it becomes regular DDPM, 01:46:48.800 |
and if it's zero, of course, that's the deterministic case. So this is the eta that, you know, 01:46:53.320 |
all these APIs have, and the code that we have, also the code that Jeremy was showing: they 01:46:59.240 |
have eta equal to one, which they say corresponds to regular DDPM. 01:47:06.280 |
This is actually where the eta would go in the equation. So finally, it's like, yeah, 01:47:14.200 |
you could pass in sigma, right? Like, if you weren't trying to match existing papers, 01:47:18.080 |
you could just say: oh, well, we have this parameter sigma that controls the amount of 01:47:20.760 |
noise, so let's just take a noise scale as an argument. But for convenience, they said: 01:47:25.760 |
let's create this new thing, eta, where zero means sigma is equal to zero, which, if you 01:47:31.320 |
look at the equation, works; and one means we match the amount of noise that's in, like, 01:47:37.600 |
vanilla DDPM. And so then that gives you a nice scale. So you could say eta equals 01:47:42.240 |
two or eta equals 0.7 or whatever, but with a meaningful unit, where one equals the 01:47:48.480 |
same as this previous reference work. Well, it's also convenient because it's sigma t, 01:47:54.360 |
which is to say different time steps, unless you choose eta equals zero, in which case it 01:48:00.240 |
doesn't matter. Different time steps probably want different amounts of noise. And so here's 01:48:07.040 |
a reasonable way of scaling that noise. Then the last thing of importance, which is, of 01:48:16.240 |
course, one of the reasons we were exploring this in the first place, is to be able to do 01:48:21.560 |
this sort of rapid sampling. So the basic idea here is that you can define a similar 01:48:32.440 |
distribution where, again, the math works out similarly, but now over 01:48:40.080 |
some subset of the diffusion steps. In this case, it uses a tau variable. So, for example, 01:48:46.080 |
if it's like 10 diffusion steps out of a thousand, 01:48:52.760 |
then tau-one would just be zero, 01:48:57.920 |
tau-two would be a hundred, 01:49:02.720 |
and then you go all the way up to a thousand. And so you'd get 10 diffusion steps. So that's 01:49:08.200 |
what they're referring to when they have this tau variable here. And so you 01:49:16.880 |
can do a similar equation and a similar derivation to show that this distribution 01:49:24.400 |
here again meets the objective that you used for training. And you can now use 01:49:31.160 |
this for faster sampling, where basically all you have to do is just select 01:49:37.400 |
the appropriate alpha-bar. And sorry, this one I've written out: here 01:49:43.480 |
the alpha-bar is the regular alpha-bar that we've talked about; it's 01:49:47.960 |
a little bit confusing switching between different notations. But basically, you have 01:49:54.760 |
this distribution, and then you just have to select the appropriate alpha-bars, and the math follows 01:50:01.000 |
the same, in terms of giving you an appropriate sampling process. So yeah, I guess 01:50:11.600 |
it makes this accelerated sampling a lot simpler. Yeah, 01:50:21.080 |
I guess with any other note, maybe other comments that maybe you guys had, or was this? 01:50:27.240 |
Yeah, the key for me is that in this equation, we only need one parameter, 01:50:39.360 |
which is the alpha-bar (or alpha, depending which notation is used), and everything else 01:50:43.480 |
is calculated from that. And so we don't have what DDPM calls the alpha or beta anymore. 01:50:54.720 |
And that's more convenient for doing this kind of smaller number of steps, because we can 01:51:01.960 |
just jump straight from time step to alpha-bar. And it's particularly 01:51:10.240 |
convenient with the cosine schedule, because you can calculate the inverse of the cosine 01:51:16.240 |
schedule function, which means you can also go from an alpha-bar to a t. So it's really 01:51:20.960 |
easy to say, like, oh, what would alpha-bar be 10 time steps previous to this one? 01:51:27.680 |
You know, you could just call a function. We don't need 01:51:31.440 |
anything else. And actually, the original cosine schedule paper has to fuss around with 01:51:38.960 |
various kind of epsilon-style small numbers that they add to things to avoid getting weird 01:51:47.920 |
numerical problems. And so yeah, when we only deal with alpha-bar, all that stuff also goes 01:51:53.640 |
away. So yeah, if you're looking at the DDIM code, you know, it's simpler code, 01:52:04.600 |
with fewer parameters, than our DDPM code. And of course, it's dramatically faster. And it's 01:52:11.400 |
also more flexible, because we've got this eta thing we can play with. 01:52:14.800 |
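A hedged sketch of that interconvertibility: with the cosine schedule, alpha-bar and t are each computable from the other, with no extra parameters.

```python
import math, torch

def abar(t, T):
    "Cosine schedule: abar as a function of timestep t."
    return (t / T * math.pi / 2).cos() ** 2

def inv_abar(ab, T):
    "Invert the cosine schedule: recover t from an abar value."
    return T * 2 / math.pi * torch.acos(ab.sqrt())

# e.g. what would abar be 10 time steps previous to this one?
t = torch.tensor(500.)
ab_earlier = abar(inv_abar(abar(t, 1000), 1000) - 10, 1000)
```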
Yes. Yeah, that's the other thing: this idea of controlling stochasticity. 01:52:21.800 |
I think that's something that's interesting to explore, and we've been exploring it 01:52:26.040 |
a bit now. And I think we'll continue to explore it, in terms of deterministic versus stochastic sampling. 01:52:32.480 |
So it's worth talking about the sigma in the middle equation you've got there. So 01:52:37.280 |
you've got the sigma-t epsilon-t adding the random noise. And intuitively, it makes sense that 01:52:42.880 |
if you're adding random noise there, you want to move less back 01:52:48.640 |
towards x-t, which is your noisy image. So that's why, you know, you've got the 1 minus alpha-bar 01:52:55.460 |
t-minus-1, minus sigma squared, and then you're taking the square root of that. So basically, 01:53:02.160 |
you're subtracting sigma-t's worth of variance from 01:53:08.520 |
the direction pointing to x-t and adding it back as the random noise, or vice versa. So yes, 01:53:14.920 |
you know, everything's there for a reason. Yes. And the predicted x0, that entire term... 01:53:28.840 |
And it remains the same in pretty much any diffusion model methodology. 01:53:32.360 |
Well, we'll actually be talking about some places where it's going to change, probably. 01:53:39.760 |
Well, yeah, I guess as long as it's another thing where you're predicting the noise. Yes. Yes. 01:53:44.980 |
If you're predicting the noise, yes, it'll be there. 01:53:49.000 |
So I think, you know, let's wrap it up here, so that we leave ourselves 01:53:58.060 |
plenty of time to cover the kind of new research directions in more detail next lesson. 01:54:03.560 |
But as I mentioned, in terms of where we're at: just like we hit a point a few weeks ago of, okay, 01:54:09.960 |
we can really predict classes for Fashion MNIST, I think we're 01:54:15.600 |
there now for stable diffusion sampling and UNets, except for the UNet architecture, 01:54:24.520 |
for unconditional generation. We can basically now do Fashion MNIST so it's almost unrecognizably 01:54:32.440 |
different to the real samples, and DDIM is the scheduler that the original stable diffusion used. 01:54:43.080 |
So yeah, you know, we're actually about to go beyond stable diffusion for our sampling 01:54:53.140 |
and UNet training now. So I think we've, yeah, definitely been meeting our stretch goals so far, 01:55:03.440 |
and all from scratch, with Weights and Biases experiment logging. 01:55:11.600 |
And you know, if you wanted to have fun, there's no reason you couldn't have a little 01:55:17.400 |
callback that instead logs things into a SQLite database, and then you could write 01:55:21.600 |
a little front end to show your experiments, you know; that'd be fun as well. 01:55:27.920 |
Yeah, I mean, you could also have it send you a text message when the loss gets good enough. 01:55:34.160 |
Alright, well, thanks, guys. That was really fun.