
Lesson 21: Deep Learning Foundations to Stable Diffusion


Chapters

0:00 A super cool demo with miniai and CIFAR-10
2:55 The notebook
7:12 Experiment tracking and W&B callback
16:09 Fitting
17:15 Comments on experiment tracking
20:50 FID and KID, metrics for generated images
23:35 FID notebook (18_fid.ipynb)
31:07 Get the FID from an existing model
37:22 Covariance matrix
42:21 Matrix square root
46:17 Why it is called Fréchet Inception Distance (FID)
47:54 Some FID caveats
50:13 KID: Kernel Inception Distance
55:30 FID and KID plots
57:09 Real FID - The Inception network
61:16 Fixing (?) UNet feeding - DDPM_v3
68:49 Schedule experiments
74:52 Train DDPM_v3 and testing with FID
79:01 Denoising Diffusion Implicit Models - DDIM
86:12 How does DDIM work?
90:15 Notation in Papers
92:21 DDIM paper
113:49 Wrapping up

Whisper Transcript

00:00:00.000 | Hello, Jono. Hello, Tanishk. Are you guys ready for lesson 21?
00:00:05.720 | Ready.
00:00:06.720 | Yep, I'm excited.
00:00:07.720 | I don't know what I would have said if you had said no. So good.
00:00:12.760 | I'm actually particularly excited because I had a little bit of a sneak preview of something
00:00:16.640 | that Jono has been working on, which I think is a super cool demo of what's possible with
00:00:24.480 | very little code with mini AI.
00:00:27.040 | So let me turn it over to Jono.
00:00:30.320 | Great, thanks, Jeremy. Yeah, so as you'll see, when it's back to Jeremy to talk through
00:00:35.920 | some of the experiments and things we've been doing, we've been using the Fashion MNIST
00:00:40.240 | dataset at a really small scale to really rapidly try out these different ideas and
00:00:44.840 | see some maybe nuances or things that we'd like to explore further. And so as we were
00:00:49.640 | doing that, I started to think that maybe it was about time to explore just ramping
00:00:54.000 | up the level, like seeing if we can go to the next slightly larger datasets, slightly
00:00:58.120 | harder difficulty, just to double check that these ideas still hold for longer training
00:01:02.960 | runs and different more difficult data.
00:01:06.960 | That's a really good idea because I feel pretty confident that the learnings from Fashion
00:01:11.720 | MNIST are going to move across. Most of the time these things seem to, but sometimes
00:01:17.720 | they don't and it can be very hard to predict. So it seems like a very wise choice.
00:01:23.000 | Yeah. And so we'll keep ramping up, but as a next step, one above Fashion MNIST, I thought
00:01:30.280 | I'd look at this dataset called CIFAR-10. And so the CIFAR-10 dataset is a very popular dataset
00:01:35.920 | originally for things like image classification, but also now for any paper on generative modeling.
00:01:43.080 | It's kind of like the smallest dataset that you'll see in these papers. And so yeah, if
00:01:48.600 | you look at the classification results, for example, pretty much every classification
00:01:51.480 | paper since they started tracking has reported results on CIFAR10 as well as their larger
00:01:57.600 | datasets. And likewise with image generation, very, very popular, all of the recent diffusion
00:02:04.040 | papers will usually report CIFAR10, add their image net, and then whatever large massive
00:02:09.120 | dataset they're training on.
00:02:13.200 | We were somewhat notable in 2018 for how quickly we managed to train it. So for CIFAR-10, 94% classification
00:02:20.640 | is kind of the benchmark. So there was a competition a few years ago where we managed to get to
00:02:25.000 | that point at a cost of like 26 cents worth of AWS time, I think, which won a big global
00:02:34.760 | competition. So I actually hate CIFAR10, but we had some real fun with it a few years ago.
00:02:43.120 | Yeah. And it's good. It's a nice dataset for quickly testing things out, but we'll talk
00:02:47.560 | about why we as a group also don't like it at all. And we'll pretty soon move
00:02:52.360 | on to something better.
00:02:55.320 | So one of the things you'll notice in this notebook, I'm basically using all of the same
00:02:59.560 | code that Jeremy is going to be looking at and explaining. So I won't go into too much,
00:03:03.800 | but the dataset's also on HuggingFace. So we can load it just like we did the Fashion
00:03:08.040 | MNIST. The images are three channel rather than single channel. So the shape of the data
00:03:18.680 | is slightly different to what we've been working with. That's weird. Yeah. So we have instead
00:03:26.280 | of the single channel image, we have a three channel red, green, and blue image. And this
00:03:30.160 | is what a batch of data looks like.
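The loading step being described boils down to something like this, using the standard HuggingFace datasets API (the split and field names are those of the hub's cifar10 dataset; the course's surrounding helper code is not shown):

```python
from datasets import load_dataset

# CIFAR-10 from the HuggingFace hub, loaded the same way the course
# loaded Fashion MNIST; items are 3-channel 32x32 RGB PIL images
dsd = load_dataset("cifar10")
imgs = dsd["train"][:2]["img"]   # a couple of PIL images; after the usual
print(imgs)                      # to-tensor transform a batch is (bs, 3, 32, 32)
```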
00:03:31.520 | And you've got then two images in your batch. So that's batch by channel by height by width,
00:03:36.840 | right?
00:03:37.840 | Batch by channel by height and width.
00:03:39.160 | I was a little confused by the 32 by 32, is it?
00:03:42.520 | Oh, yeah. It's fine. I got it now. Batch size is arbitrary. And so if you plot these,
00:03:48.200 | one of the things, if you look at this, okay, I can see these are different classes. Like
00:03:51.680 | I know this is an airplane, a frog, an airplane, but it's actually a puzzle with an airplane
00:03:55.560 | on the cover, a bird, a horse, a car. That one, if you squint, you can tell it's a deer,
00:04:00.920 | but only if you really know what you're looking for. And so when we started to talk about
00:04:04.880 | generating these images, this is actually quite frustrating. Like this, if I generated
00:04:09.320 | this, I'd say this might be the model doing a really bad job. But it's actually that this
00:04:13.840 | is a boat, this is a dog. It's just that this is what the data looks like.
00:04:17.560 | And so I've actually got something that can help you out, which I'll show later today. With
00:04:22.000 | something like this, it's really actually hard to see whether it's good, because the
00:04:26.360 | images are bad. It can be helpful to have a metric that can say how good generated
00:04:31.840 | samples are. So I'll be showing a metric for that later today.
00:04:35.600 | Yeah. And that'll be great. And I hope to have it automated. But anyway, I just wanted
00:04:39.960 | to flag like for visually inspecting these, it's not great. And so we don't really like
00:04:43.760 | CIFAR 10 because it's hard to tell, but still a good one to test with. So the Noisify and
00:04:50.520 | everything I'm following what Jeremy is going to be showing exactly, the code works without
00:04:54.560 | any changes because we're adding random noise in the same shape as our data. So even though
00:04:58.920 | our data now has three channels, the Noisify function still works fine. If we try and visualise
00:05:04.120 | the noise-based images, because we're adding noise in the red, green, and blue channels, and some
00:05:08.840 | of that's quite extreme values. Yeah, it looks slightly different, looks all crazy RGB. But
00:05:14.920 | you can see, for example, this frog doesn't have as much noise and it's vaguely visible.
00:05:19.840 | But it is a nigh-impossible task to look at this and tell what image is hiding
00:05:23.840 | under all of that noise.
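A minimal sketch of why this works: the noise is drawn with the same shape as the input, and the per-image coefficients broadcast over however many channels there are (illustrative code in the spirit of the course's noisify, not its exact definition):

```python
import torch

def noisify(x0, alphabar):
    # x0: (bs, c, h, w) - c can be 1 (Fashion MNIST) or 3 (CIFAR-10)
    n = len(x0)
    t = torch.randint(0, len(alphabar), (n,), device=x0.device)
    eps = torch.randn_like(x0)               # noise matches x0's shape, whatever c is
    ab = alphabar[t].reshape(n, 1, 1, 1)     # (n,1,1,1) broadcasts over c, h, w
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps
    return (xt, t), eps
```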
00:05:24.840 | So I think this is really neat that you could use the same Noisify. Yeah, and it still,
00:05:35.400 | it still works thanks to, it's not just that shape thing, but I guess just thanks to kind
00:05:39.840 | of PyTorch's broadcasting kind of stuff. This often happens, you can kind of change the
00:05:45.960 | dimensions of things and it just keeps working.
00:05:48.240 | Exactly. And we've been paying attention to those broadcasting rules and the right dimensions
00:05:52.560 | and so on. Cool. So I'm going to use the same sort of approach with the learner and UNet, except
00:05:59.560 | that now obviously I need to specify three input channels and three output channels because
00:06:03.520 | we're working with three channel images. But I did want to explore for this demo, like,
00:06:08.440 | okay, how could I maybe justify wanting to do this kind of experiment tracking thing
00:06:13.240 | that I'll talk about. And so I'm bumping up the size of the model substantially. I've
00:06:17.160 | gone from, these are the default settings that we were using for Fashion MNIST, but the Diffusers
00:06:21.160 | default UNet has, what, nearly 20 times as many parameters, 274 million versus 15 million.
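With the Diffusers UNet, the channel change is just constructor arguments - a sketch, with an illustrative sample size; the library defaults give the much larger model being described:

```python
from diffusers import UNet2DModel

# Three channels in and out for RGB; defaults otherwise, which is a far
# larger model than the small UNet used for Fashion MNIST
model = UNet2DModel(in_channels=3, out_channels=3, sample_size=32)
print(sum(p.numel() for p in model.parameters()))   # parameter count
```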
00:06:30.260 | So we're going to try a larger model. We're going to try some longer training. And so
00:06:34.480 | I could just do the same training that we've always done just in the notebook, set up a
00:06:41.800 | learner with ProgressCB to kind of plot the loss, track some metrics. But yeah, I don't
00:06:49.640 | know about you, but once it's beyond a few minutes of training, I quickly get impatient
00:06:54.320 | and I have to wait for it to finish before we can sample. So I'm doing the DDPM sample,
00:06:58.000 | but I have to, I actually interrupted the training to say, I just want to get a look
00:07:01.400 | at what it looks like initially and to plot some samples. And again, the sampling function
00:07:06.160 | works without any modification, but I'm passing in my size to be a three channel image. Yeah.
00:07:12.560 | And so this is like, we could do it like this, but at some point I would like to A, keep
00:07:17.920 | track of what experiments I've tried and B, be able to see things as it's going over time,
00:07:24.040 | including like, I'd love to see what the samples look like if you generated after the first
00:07:27.480 | epoch, after the second epoch. And so that's where my little callback that I've been playing
00:07:33.560 | with comes in.
00:07:34.560 | So just before you do that, I'll just mention like, I mean, there are simple ways you could
00:07:39.360 | do that, right? Like, you know, one popular way a lot of people do is that they'll save
00:07:43.800 | some sample images as files every epoch or two, or we could like, the same way that we
00:07:50.080 | have an updating plot, as we train with fast progress, we could have an updating set of
00:08:00.120 | sample images. So there's a few ways we could solve that. That wouldn't handle the tracking
00:08:09.440 | that you mentioned of like looking over time at how different changes have improved things
00:08:13.160 | or made them worse, whatever that would, I guess, would require you kind of like saving
00:08:17.040 | multiple versions of a notebook or keeping some kind of research journal or something.
00:08:21.840 | That'd be a bit fiddly.
00:08:24.600 | It is. And all of that's doable, but I also find like, I'm a little bit lazy sometimes.
00:08:28.640 | Maybe I don't write down what I'm trying, or yeah, I've saved 'Untitled 37' notebooks.
00:08:34.760 | So yeah, the idea that I wanted to show here is just that there are lots of other solutions
00:08:39.600 | for this kind of experiment tracking and logging. And one that I really like is called Weights
00:08:44.080 | and Biases. And so I'll explain what's going on in the code here that I'm running a training
00:08:50.160 | with this additional Weights and Biases callback. And what it's doing is it's allowing me to
00:08:54.320 | log whatever I'd like. So I can log samples at a different--
00:09:00.120 | Okay, so you're switching to a website here called wandb.ai. So that's where your callback
00:09:06.760 | is sending information to. Yeah, so Weights and Biases accounts are free
00:09:11.720 | for personal and academic use. And it's very, very, like, I don't think anyone hates Weights
00:09:15.840 | and Biases. But it's a very nice service. You sign in and you log in on your computer
00:09:20.280 | or you get an authentication token. And then you're able to log these experiments and you
00:09:25.440 | can log into different projects. And what it gives you is for each experiment, anything
00:09:31.160 | that you call Weights and Biases.log at any step in the training, that's getting logged
00:09:36.720 | and sent to their server and stored somewhere where you can later access it and display
00:09:40.920 | it. They have these plots that you can visualize easily. And you can also share them very easily
00:09:46.940 | in these reports that integrate this data in a sort of interactive thing. And why that's
00:09:53.640 | nice is that later you can go and look at-- So this is now the project that I'm logging
00:09:57.720 | into. You can log multiple runs with different settings. And for each of those, you have
00:10:04.160 | all of these things that you've tracked, like your training, loss, and validation. But you
00:10:09.240 | can also track your learning rate if you're doing a learning rate schedule. And you can
00:10:14.120 | save your model as an artifact and it'll get saved on their server so you can see exactly
00:10:19.080 | what run produced what model. It logs the code, if you set that to -- you can set save_code
00:10:25.600 | equals true. And then it creates a copy of your whole Python environment, what libraries
00:10:29.260 | were installed, what code you ran. So being able to come back later and say, oh, these
00:10:35.280 | images here, these look really good. I can go back and see, oh, that was this experiment
00:10:40.600 | here. I can check what settings I used. In the initialization, you can log whatever configuration
00:10:49.640 | details you'd like in any comments. And yeah, there's other frameworks for this.
00:10:56.440 | Yeah, in some ways, it's kind of-- initially, when I first saw Weights and Biases, it felt
00:11:02.080 | a bit weird to me actually sending your information off to an external website because, I mean,
00:11:08.920 | before Weights and Biases existed, the most popular way to do this was something called
00:11:12.880 | TensorBoard, which Google provides, which is actually a lot like this, but it's a little
00:11:17.940 | server that runs on your computer. And so like when you log things, it just puts it
00:11:24.880 | into this little database on your computer, which is totally fine. But I guess actually,
00:11:33.880 | there are some benefits to having somebody else run this service instead of running your
00:11:40.200 | own little TensorBoard or whatever server. One is that you can have multiple people working
00:11:47.500 | on a project collaborating. So I've done that before, where we will each be sending different
00:11:52.520 | sets of hyperparameters, and then they'll end up in the same place. Or if you want to
00:11:58.200 | be really antisocial, you can interrupt your romantic dinner and look at your phone to
00:12:03.440 | see how your training's going. So yeah, I'm not going to say it's always the best approach
00:12:11.260 | to doing things, but I think there's definitely benefits to using this kind of service. And
00:12:16.080 | it looks like you're showing us that you can also create reports for sharing this, which
00:12:20.200 | is also pretty nifty.
00:12:21.600 | Yeah, yeah. So I like for working with other people or you want to show somebody the final
00:12:32.040 | results and being able to, yeah, pull together the results from some different runs or just
00:12:40.160 | say, oh, by the way, here's a set of examples from my two most recent. And things track
00:12:47.720 | to different steps. What do you think of this? And being able to have this place where everyone
00:12:53.840 | can go and they can inspect the different loss curves. For any run, they can say, oh,
00:12:58.440 | what was the batch size for this? Let me go look at the info there. OK, I didn't log it,
00:13:05.000 | but I logged time in the epochs and the learning rate. So yeah, I find it quite nice, especially
00:13:10.160 | in a team or if you're doing lots and lots of experiments to be able to have this permanent
00:13:15.360 | record that somebody else deals with and may host the storage and the tracking. Yeah, it's
00:13:21.720 | quite nice.
00:13:22.720 | Wait, and this is all the code you had to write? That's amazing.
00:13:26.640 | Yeah. So this is using the callback system. The way Weights and Biases works is that you
00:13:32.280 | start an experiment with this wandb.init, and you can specify any configuration settings
00:13:38.840 | that you used there. And then anything you need to log is wandb.log, and you pass in whatever
00:13:45.000 | the name of your value is, again, logging the loss and then the value. And once you're
00:13:50.360 | done, wandb.finish syncs everything up and sends it to the server.
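That bare API looks like this outside of any callback (the project name and values here are made up for illustration):

```python
import wandb

wandb.init(project="ddpm-cifar10",              # start an experiment in a project
           config={"lr": 1e-3, "epochs": 10})   # any settings you want recorded
for step in range(10):
    loss = 1.0 / (step + 1)                     # stand-in for a real training loss
    wandb.log({"loss": loss})                   # each call sends a step to the server
wandb.finish()                                  # sync everything up at the end
```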
00:13:53.760 | Oh, this is wild, the way you've inherited from MetricsCB, and you replaced that underscore
00:13:58.680 | log that we previously used to allow fastprogress to do the logging, and you've replaced
00:14:03.400 | it to allow Weights and Biases to do the logging. So yeah, it's really sweet.
00:14:08.520 | Yeah, yeah. So this is using the callback system. I wanted to do the things that MetricsCB
00:14:14.320 | normally does, which is tracking different metrics that you pass in. So this will still
00:14:18.480 | do that. And I just offload to the super, the original MetricsCB method, for things like
00:14:23.960 | the after_batch. But in addition to that, I'd also like to log to Weights and Biases.
00:14:29.440 | And so before I fit, I initialize the experiments, every batch, I'm going to log the loss after
00:14:36.280 | every epoch, and the default metrics callback is going to accumulate the metrics and so
00:14:42.680 | on. And then it's going to call this underscore log function. So I chose to modify that to
00:14:46.680 | say, I'm going to log my training loss, it's training, I'm going to log my validation loss,
00:14:51.760 | if I'm doing validation, and I'd like to log some samples. And Weights and Bices is quite
00:14:56.320 | flexible in terms of what you can log, you can create images or videos or whatever. But
00:15:02.480 | it also takes a matplotlib figure. And so I'm generating samples and plotting them with
00:15:09.440 | show_image and spitting back the matplotlib figure, which I can then log, and that
00:15:14.040 | becomes these pretty pictures. And that you can see over time, like every, every time
00:15:19.280 | that log function runs, which is after every epoch, you can go in and see what the images
00:15:23.760 | look like. So maybe we can make your code even simpler in the future. If we had show
00:15:29.480 | images, maybe it could have like an optional return_fig parameter that returns the figure,
00:15:35.520 | and then we could replace those four lines of code with one, I suspect.
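Putting those pieces together, a rough sketch of the kind of callback being described - this assumes miniai's MetricsCB with an overridable _log; the import paths, dict keys, and method details are guesses, not the exact code:

```python
import wandb
from miniai.learner import MetricsCB, to_cpu   # assumed import paths

class WandBCB(MetricsCB):
    def before_fit(self, learn):
        wandb.init(project="ddpm-cifar10")          # start the experiment
    def after_batch(self, learn):
        super().after_batch(learn)                  # let MetricsCB accumulate as usual
        wandb.log({"loss": to_cpu(learn.loss)})     # plus log the per-batch loss
    def _log(self, d):
        # runs after each epoch; d is the dict MetricsCB would normally print
        key = "train_loss" if d.get("train") == "train" else "valid_loss"
        wandb.log({key: float(d["loss"])})          # (a samples figure could be logged here too)
        print(d)                                    # keep the printed version for debugging
    def after_fit(self, learn):
        wandb.finish()
```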
00:15:39.960 | Yeah. Yeah. And I mean, this, I just sort of threw this together. It's quite early still.
00:15:45.640 | You could also what I've done in the past is usually just create a PIL image where you
00:15:50.400 | can, you know, make a grid or overlay a text or whatever else you'd like, and then just
00:15:55.960 | log that as a wandb.Image. And otherwise, like apart from that, I'm just passing in this
00:16:01.440 | callback as an extra callback to my set of callbacks for the learner instead of a metric
00:16:07.920 | callback. And so when I call that fit, I still get my little progress bar, I still get this
00:16:13.400 | printed out version because my log function still also prints those metrics just for debugging.
00:16:20.320 | But instead of having to like watch the progress in the notebook, I can set this running disconnect
00:16:24.480 | from the server, go have dinner, and then I can check on my phone or whatever. What
00:16:28.520 | do the samples look like? And okay, cool. They're starting to look like less than random
00:16:34.560 | nonsense, but still not necessarily recognizable. Maybe we need to train for longer. That can
00:16:39.480 | be the next experiment. What I should probably do next is think of some extra metrics, but
00:16:44.200 | Jeremy's going to talk about that. So for now, that's pretty much all I had to show
00:16:47.840 | is just to say, yeah, it's worth as you move to these longer, you know, 10 minutes, one
00:16:52.400 | hour, 10 hours, these experiments, it's worth setting up a bit of infrastructure for yourself
00:16:56.920 | so that you know what were the settings I used. And maybe you're saving the model so
00:17:01.200 | you have the artifact as a result. And yeah, I like this Weights and Biases approach, but there's
00:17:06.080 | lots of others. The main thing is that you're doing something to track these experiments
00:17:10.280 | beyond just, you know, creating forty different versions of your notebook.
00:17:14.600 | I love it. One thing I was going to note that, I don't know if many people know, but like
00:17:19.040 | Weights and Biases can also save the exact code that you used to run for that run. So like
00:17:25.320 | if you make any changes to your code, and then you know that you don't know which version
00:17:29.200 | of your code you use for this particular experiment, so then you can figure out exactly what code
00:17:33.680 | you use. So it's all completely reproducible. And so I love, you know, weights and biases,
00:17:38.440 | all these different features it has. And I use weights and biases all the time for my
00:17:41.920 | own research, like almost daily. Like I had, you know, a run on just last night, checked
00:17:46.400 | on it this morning. So it's like, I use it all the time for my own research.
00:17:51.040 | And yeah, like I use it, especially to just know like, oh, this run had this particular
00:17:54.840 | config. And then like, yeah, the models go straight into weights and biases. And then
00:17:59.120 | if I want to run a model on the test set, I literally actually take it off of weights
00:18:03.320 | and biases, like download it from Weights and Biases, and run it on the test set. So I use
00:18:07.320 | it all the time. And also just having the ability to have everything reproducible and
00:18:10.960 | know exactly what you were doing is very convenient, instead of having to like manually track it
00:18:15.640 | in some sort of like, I guess, a big Excel sheet or some sort of journal or something
00:18:20.080 | like that. Sometimes this is, you know, this is a lot more convenient, I feel so yeah,
00:18:25.480 | lest we get into too much shilling for Weights and Biases, I'm going to put a slightly alternative
00:18:31.760 | point of view, which is I don't use it or any experiment tracking framework myself,
00:18:40.760 | which is not to say maybe I could get some benefits by doing so, but I fairly intentionally
00:18:46.400 | don't, because I don't want to make it easy for myself to try 1000 different hyperparameters
00:18:52.960 | or do kind of, you know, undirected sampling of things. I like to be very
00:19:02.520 | directed, you know. And so that's kind of the workflow I'm looking for, one
00:19:10.240 | that allows that to happen, right? Constantly going back and refactoring and thinking what
00:19:14.360 | did I learn and how do I change things from here and never kind of doing like 17 learning
00:19:20.320 | rates and six architectures and whatever. Now, obviously, that's not something that
00:19:26.120 | Jono is doing at the moment, but it'd be easy enough for him to do if he wanted to.
00:19:32.440 | I could easily make a script that just does a hundred runs with different models and different tasks,
00:19:37.520 | and then I can look at my Weights and Biases and just filter by the best loss, which is very tempting.
00:19:42.400 | So I would say to people like, yeah, definitely be aware that these tools exist. And I definitely
00:19:47.600 | agree that as we do this, which is early 2023, weights and biases is by far the best one
00:19:54.360 | I've seen. It has by far the best integration with fastai. And as of today, if it's been
00:20:00.400 | pushed yet, it has by far the best integration with miniai. I think also fastai is the
00:20:08.160 | best library for using with Weights and Biases. It works both ways. So yeah, no, it's there.
00:20:16.920 | Consider using it, but also consider not going crazy on on experiments because, you know,
00:20:27.040 | I think experiments have their place clearly, but also definitely thought out hypotheses,
00:20:35.440 | testing them, changing your code is overall the approach that I think is best.
00:20:40.880 | Well, thank you, Jono. I think that's awesome. I've got some fun stuff to share as well, or
00:20:53.840 | at least I think it's fun. And what I wanted to share is like, well, the first of all,
00:21:00.800 | I should say, we had said, we all had said that we were going to look at UNets this week.
00:21:08.840 | We are not going to look at UNets this week, but we have good reason, which is that we
00:21:15.280 | had said we're going to go from foundations to stable diffusion. That was also a lie because
00:21:20.920 | we're actually going beyond stable diffusion. And so we're actually going to start showing
00:21:26.240 | today some new research directions. I'm going to describe the process that I'm using at
00:21:31.800 | the moment to investigate some new research directions. And we're also going to be looking
00:21:35.680 | at some other people's research directions that have gone beyond stable diffusion over
00:21:41.560 | the past few months. So we will get to UNets, but we haven't quite finished, you know, as
00:21:52.600 | it turns out, the training and sampling yet. Now, one challenge that I was having as I
00:22:02.440 | started experimenting with new things was started getting to the point where actually
00:22:09.000 | the generated images looked pretty good and it felt like, you know, almost like being
00:22:20.760 | a parent, you know, each time a new set of images would come out, I would want to convince
00:22:24.880 | myself that these were the most beautiful. And so, yeah, when they're crap, it's obvious
00:22:33.200 | they're crap, you know, but when they're starting to look pretty good, it's very easy to convince
00:22:36.080 | yourself you're improving. So I wanted to have a metric which could tell me how good
00:22:43.120 | they were. Now, unfortunately, there is no such metric. There's no metric that actually
00:22:50.040 | says do these images, would these images look to a human being like pictures of clothes?
00:22:58.080 | Because only talking to a person can do that. But there are some metrics which give you
00:23:02.800 | an approximation of that. And as it turns out, these metrics are not actually a replacement
00:23:13.960 | for human beings looking at things, but they're a useful addition. So, and I certainly found
00:23:22.840 | them useful. So I'm going to show you the two most common, well, there's really the
00:23:26.280 | one most common metric, which is called FID, and I'm going to show another one called KID.
00:23:32.120 | So let me describe and show how they work. And I'm going to demonstrate them using
00:23:46.080 | the model we trained in the last lesson, which was in DDPM2. And you might remember, we trained
00:23:55.040 | one with mixed precision, and we saved it as fashion DDPM MP for mixed precision. Okay,
00:24:04.760 | so this is all the usual imports and stuff. This is all the usual stuff. But there's a
00:24:12.360 | slight difference this time, which is that we're going to try to get the FID for a model
00:24:18.000 | we've already trained. So basically, to get the model we've already trained to get its
00:24:24.400 | FID, we can just torch.load it, right, and then .cuda to pop it on the GPU. So I'm going
00:24:31.200 | to call that the S model, which is the model for samples, the samples model. And this is
00:24:35.320 | just a copied and pasted DDPM from last time. So that's for sampling. So we're going
00:24:40.480 | to do sampling from that model. And so once we've sampled from the model, we're then going
00:24:48.120 | to try and calculate this score called the FID. Now, what the FID is going to do is it's
00:24:56.720 | not going to say how good are these images. It's going to say how similar are they to
00:25:05.360 | real images. And so the way we're going to do that is we're going to actually look specifically
00:25:13.740 | at features of the images that we generated in these samples. We're going to look at some
00:25:21.560 | statistics of some of the activations. So what we're going to do, we've generated these samples,
00:25:31.440 | and we're going to create a new DataLoaders, which contains no training batches, and it
00:25:39.520 | contains one validation batch, which contains the samples. It doesn't actually matter what
00:25:44.760 | the dependent variable is, so I just put in the same dependent variable that we already
00:25:48.720 | had. And then what we're going to do is we're going to use that to extract some features
00:25:58.520 | from a model. Now, what do we mean by that? So if you remember back to notebook 14, we
00:26:06.160 | created this thing called summary. And summary shows us at different blocks of our model,
00:26:15.080 | there are various different output shapes. In this case, it's a batch size of 1024.
00:26:20.120 | And so after the first block, we had 16 channels, 28 by 28, and then we had 32 channels, 14
00:26:27.400 | by 14, and so forth, until just before the final linear layer we had the 1024 batch items, and
00:26:38.640 | we had 512 channels with no height and width. Now, the idea of FID and KID is that the distribution
00:26:49.000 | of these 512 channels for a real image has a particular kind of like signature, right?
00:26:58.480 | It looks a particular way. And so what we're going to do is we're going to take our samples,
00:27:04.200 | we're going to run it through a model that's learned to predict, you know, fashion classes,
00:27:12.600 | and we're going to grab this layer, right? And then we're going to average it across
00:27:18.760 | a batch, right, to get 512 numbers. And that's going to represent the mean of each of those
00:27:26.600 | channels. So those channels might represent, for example, does it have a pointed collar?
00:27:35.480 | Does it have, you know, smooth fabric? Does it have sharp heels and so forth, right? And
00:27:46.240 | you could recognize that something's probably not a normal fashion image if it says, "Oh,
00:27:51.640 | yes, it's got sharp heels and flowing fabric." It's like, "Oh, that doesn't sound like anything
00:27:58.640 | we recognize," right? So there are certain kind of like sets of means of these activations
00:28:05.280 | that don't make sense. So this is a metric for... it's not a metric for an individual
00:28:13.840 | image necessarily, but it's across a whole lot of images. So if I generate a bunch of
00:28:19.520 | fashion images, and I want to say, does this look like a bunch of fashion images? If I
00:28:23.680 | look at the mean, like maybe X percent have this feature and X percent have that feature.
00:28:28.440 | So if I'm looking at those means, as like comparing the distribution within all these
00:28:31.760 | images I generated, do roughly the same amount have sharp collars as those in the training set?
00:28:37.160 | Yeah, that's a very good point, too.
00:28:39.480 | Yeah, and it's actually gonna get even more sophisticated than that. But let's just start
00:28:44.480 | at that level, which is these features' means. So the basic idea here is that we're going
00:28:51.040 | to take our samples and we're going to pass them through a pre-trained model that has
00:28:57.080 | learned to predict what type of fashion something is. And of course, we train some of those
00:29:04.720 | in this notebook. And specifically, we trained a nice 20 epoch one in the data augmentation
00:29:13.160 | section, which had a 94.3% accuracy. And so if we pass our samples through this model,
00:29:22.960 | we would expect to get some, you know, useful features. One thing that I found made this
00:29:29.680 | a bit complicated, though, is that this model was trained using data that had gone through
00:29:36.200 | this transformation of subtracting the mean and dividing by the standard deviation. And
00:29:44.200 | that's not what we're creating in our samples. And so, generally speaking, samples in most
00:29:55.000 | of these kinds of diffusion models tend to be between negative one and one. So I actually
00:30:01.560 | added a new section to the very bottom of this notebook, which simply replaces the transform
00:30:08.520 | with something that goes from negative one to one and just creates those data loaders
00:30:14.400 | and then trains something that can classify fashion. And I save this as not data aug, but
00:30:22.560 | data aug two. So this is just exactly the same as before, but it's a fashion classifier
00:30:28.920 | where the inputs are expected to be between minus one and one. Having said that, it turns
00:30:35.720 | out that our samples are not between minus one and one. But actually, if you go back
00:30:45.360 | and you look at DDPM2, we just used TF.to_tensor, and that actually makes images
00:30:54.520 | that are between zero and one. So actually, that's a bug. Okay, so our images have a bug,
00:31:04.640 | which is they go between zero and one. So we'll look at fixing that in a moment. But
00:31:07.840 | for now, we're just trying to get the FID of our existing model. So let's do that.
00:31:13.320 | So what we need to do is we need to take the output of our model, and we need to multiply
00:31:21.360 | by two, so that'll be between zero and two, and subtract one. So that'll change our samples
00:31:26.520 | to be between minus one and one, and we can now pass them through our pre-trained fashion
00:31:32.480 | classifier. Okay, so now, how do we get the output of that pooling layer? Because that's
00:31:40.480 | actually what we want, to remind you. We want the output of this layer. So just to kind
00:31:55.200 | of flex our PyTorch muscles, I'm going to show a couple of ways to do it. So we're going
00:32:02.440 | to load the model I just trained, the data aug 2 model. And what we could do is, of course,
00:32:08.740 | we could use a hook. And we have a hooks callback. So we could just create a function, which
00:32:17.080 | just appends the output. So very straightforward. Okay, so that's what we want. We want the
00:32:22.960 | output. And specifically, it's, so we've got these are all sequentials. So we can just
00:32:32.320 | go through and go, oh, one, two, three, four, five, the layer that we want. Okay, and so
00:32:37.760 | that's the module that we want to hook. So once we've hooked that, we can pass that as
00:32:43.480 | a callback. And we can then, it's a bit weird calling fit, I suppose, because we're saying
00:32:49.720 | train equals false, but we're just basically capturing. This is just to make one batch
00:32:54.640 | go through and grab the outputs. So this means now in our hook, there's now going to be a thing
00:33:02.400 | called outp, because we put it there. And we can grab, for example, a few of those to
00:33:09.000 | have a look. And yep, here, we've got a 64 by 512 set of features. Okay, so that's one
00:33:15.600 | way we can do it. Another way we could do it is that actually sequential models are what's
00:33:24.640 | called in Python collections; they have a certain API that they're expected to support.
00:33:31.960 | And one thing a collection can do, like a list, is you can call del to delete something.
00:33:39.120 | So we can delete this layer and this layer and be left with just these layers. And once
00:33:48.880 | we do that, that means we can just call capture_preds, because now they don't have the last
00:33:53.260 | two layers. So we can just delete layers eight and seven and call capture_preds. And one nice
00:34:00.440 | thing about this is it's going to give us the entire 10,000 images in the test set.
00:34:08.560 | So that's what I ended up deciding to do. There's lots of other ways I played around with which
00:34:12.320 | worked, but I decided to show these two as being pretty good techniques. Okay,
00:34:17.960 | so now we've got what 10,000 real images look like at the end of the pooling layer.
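Both techniques in sketch form - the layer indices here are illustrative (they depend on the actual model structure), and model/xb stand for the loaded classifier and a batch:

```python
import torch

# 1. Hook: stash the pooled output during a forward pass
feats_list = []
def append_outp(module, inp, outp): feats_list.append(outp.detach())

h = model[6].register_forward_hook(append_outp)   # index of the pooling layer
with torch.no_grad(): model(xb)                   # one batch through captures the features
h.remove()

# 2. Deletion: nn.Sequential supports del, so chop the head off entirely
del model[8]                              # e.g. the final linear layer
del model[7]                              # e.g. the layer after the features we want
with torch.no_grad(): feats = model(xb)   # now returns the (bs, 512) features directly
```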
00:34:25.240 | So now we need to do the same for our sample. So we'll load up our fashion DDPM MP, we'll
00:34:34.240 | sample, let's just grab 256 images for now, make them go between minus one and one, make
00:34:41.840 | sure they look okay. And as I described before, create a data loaders where the validation
00:34:47.440 | set just has one batch, which contains our samples, and call capture_preds. Okay, so that's
00:35:00.000 | going to give us our features. And the reason why is because we're passing the sample to
00:35:09.920 | model and model is the classifier, which we've deleted the last two layers from. So that's
00:35:19.720 | going to give us our 256 by 512. So now we can get the means. That's not really enough
00:35:34.080 | to tell us whether something looks like real images. So maybe I should draw here.
00:35:44.340 | So we started out with our batch of 256, and our channels of 512. And we squished them
00:36:11.680 | by taking their mean. So it's now just a vector. So this is the... sorry, wrong way around.
00:36:27.880 | We squished them this way, to 512, because this is the mean for each channel. Okay. And we
00:36:42.320 | did exactly the same thing for the much bigger, you know, full set of real images. So this
00:36:49.400 | is our samples and this is our real. But when we squish it, that was 10,000 by 512, we get
00:37:02.560 | again 512. So we could now compare these two, right? But, you know, you could absolutely
00:37:14.280 | have some samples that don't look anything like images, but have similar averages for
00:37:21.440 | each channel. So we do a second thing, which is we create a covariance matrix. Now, if
00:37:30.320 | you've forgotten what this is, you should go back to our previous lesson where we looked
00:37:33.600 | at it, but just remind you a covariance matrix says, in this case, we do it across the channels.
00:37:41.200 | So it's going to be 512 by 512. So it's going to take each of these columns, and it says
00:37:51.840 | in each cell, so here's cell one, one, basically it's saying: what's the difference between
00:38:01.680 | each row, each element here, and the
00:38:06.120 | mean of the whole column, multiplied by exactly the same thing for a different column. Now,
00:38:13.880 | on the diagonal, it's the same column twice. So that means that what's on the diagonal is
00:38:19.400 | just the variance, right? But more interestingly, the ones in the off diagonal, like here, is
00:38:27.520 | actually saying, what's the relationship between column one and column two, right? So if column
00:38:34.660 | one and column two are uncorrelated, then this would be zero, right? If they were identical,
00:38:41.760 | right, then it would be the same as the variance in here. So it's how correlated are they.
00:38:49.320 | And why is this interesting? Well, if we do the same, exactly the same thing for the reals,
00:38:55.320 | that's going to give us another 512 by 512. And it's going to say things like, so let's
00:39:01.200 | say this first column was kind of like that, you know, doesn't have pointy heels, right?
00:39:08.920 | And sorry, 'heels', I can't spell. And the second one might be, doesn't have flowing fabric, right?
00:39:17.720 | And this is where we say, okay, if, you know, generally speaking, you would expect these
00:39:23.300 | to be negatively correlated, right? So over here in the reals, this is probably going
00:39:30.640 | to have a negative, right? Whereas if over here it was like zero or even worse if it's
00:39:36.760 | positive, it'd be like, oh, those are probably not real, right? Because it's very unlikely
00:39:41.560 | you're going to have images that have both, where pointy heels are positively associated
00:39:46.280 | with a flowing fabric. So we're basically looking for two data sets where their covariance
00:39:56.560 | matrices are kind of the same and their means are also kind of the same. All right. So there
00:40:08.760 | are ways of comparing these, you know, basically comparing two sets of data to say, are they,
00:40:24.000 | you know, from the same distribution? And you can broadly think of it as being like, oh,
00:40:28.440 | do they have pretty similar covariance matrices? Do they have pretty similar mean vectors?
00:40:35.400 | And so this is basically what the Fréchet Inception Distance does. Does that make sense
00:40:43.320 | so far, guys?
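In code, the two statistics per set of features amount to something like this, where feats is the (n, 512) matrix of pooled activations:

```python
# feats: (n_images, 512) pooled activations for one set of images
means = feats.mean(dim=0)   # (512,)  - average presence of each feature
covs  = feats.T.cov()       # (512, 512) - how features co-vary; the diagonal
                            # holds each channel's variance
```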
00:40:45.600 | Yes. What's striking me now is the similarity to when we were talking
00:40:53.280 | about like the style loss and those kinds of things. How do we get the types of features
00:40:58.600 | that occur together, without worrying about where they are in the image - the Gram
00:41:04.080 | matrices or whatever. Yeah. Now, the particular way of comparing. So, okay. So I've got the
00:41:15.880 | means and I've got the covariances for my samples. And I've actually just created this
00:41:24.080 | little calc_stats, right? So I always, I'm showing you how I build things, not just things
00:41:28.760 | that are built, right? So I always create things step by step and check their shapes,
00:41:32.680 | right? And then I copy the cells and merge them into
00:41:38.440 | functions. So here's something that gets the means and the covariance matrix. So then I
00:41:47.000 | basically call that both for my sample features and for the features of the actual
00:41:53.920 | dataset, or the test set of the dataset. Now, what I do with that, now that I have
00:42:01.600 | those stats, is I can calculate this thing called the Fréchet Inception
00:42:05.080 | distance, which is here. And basically what happens is we multiply together the two covariance
00:42:14.160 | matrices and that's now going to make them like bigger, right? So we now need to basically
00:42:22.640 | scale that down again. Now, if we were working with, you know, non-matrices, you know, if
00:42:31.080 | you kind of like multiply two things together, then to kind of bring it back down to the
00:42:37.240 | original scale, you know, you could kind of like take the square root, right? So particularly
00:42:41.800 | if it was by itself, you took the square root, you get back to the original. And so we need
00:42:45.960 | to do exactly the same thing to renormalize these matrices. The problem is that we've
00:42:52.320 | got matrices and we need to take the matrix square root. Now the matrix square root, you
00:43:02.040 | might not have come across this before, but it exists and it's the thing where the matrix
00:43:09.800 | square root of the matrix A times itself is A. Now, I'm going to slightly cheat because
00:43:19.600 | we've used the float square root before and we did not re-implement it from scratch because
00:43:25.120 | it's in the Python standard library and also it wouldn't be particularly interesting. But
00:43:29.320 | basically the way you can calculate the float square root from scratch is by using, there's
00:43:36.360 | lots of ways, but you know, the classic way that you might have done it in high school
00:43:39.680 | is to use Newton's method, which is where you basically can solve if you're trying to
00:43:45.040 | calculate A equals root x, then you're basically saying A squared equals x, which means you're
00:43:57.360 | saying A squared minus x equals zero. And that's an equation that you can solve and you can
00:44:05.560 | solve it by basically taking the derivative and taking a step along the derivative a bunch
00:44:09.880 | of times. You can basically do the same thing to calculate the matrix square root. And so
00:44:21.880 | here it is, right? It's the Newton method, but because it's for matrices it's slightly
00:44:26.960 | more complicated, so it's a short method and I'm not going to go through it, but it's basically
00:44:32.240 | the same deal. You go through up to 100 iterations and you basically do something like traveling
00:44:40.320 | along that kind of derivative and then you say, okay, well, the result times itself ought
00:44:49.440 | to equal the original matrix. So let's subtract the matrix times itself from the original
00:44:54.960 | matrix and see whether the absolute value is small and if it is, we've calculated it.
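In the scalar case, the Newton iteration being described is only a few lines - a sketch (the notebook's matrix version follows the same pattern but with matrix products and the same "does it square back to the original" check):

```python
def newton_sqrt(x, n_iter=100, eps=1e-8):
    # Solve a^2 - x = 0 with Newton's method: a <- a - (a*a - x) / (2*a)
    a = x
    for _ in range(n_iter):
        a = a - (a * a - x) / (2 * a)
        if abs(a * a - x) < eps: break   # converged: result times itself gives x back
    return a

print(newton_sqrt(2.0))   # ~1.4142135623730951
```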
00:45:00.920 | Okay. So that's basically how we do a matrix square root. So that's that. And so
00:45:07.400 | now that we have strictly speaking, implemented from scratch, we're allowed to use the one
00:45:10.520 | that already exists. PyTorch doesn't have one, sadly, so we have to use the one from SciPy,
00:45:17.520 | scipy.linalg. So this is basically going to give us a measure of similarity between the
00:45:28.280 | two covariance matrices. And then we, here's the measure of similarity between the two
00:45:36.840 | mean vectors, which is just the sum of squared errors. And then basically for reasons that
00:45:42.300 | aren't really interesting, but it's just normalizing, we subtract what's called the trace, which
00:45:46.640 | is the sum of the diagonal elements, and we subtract two times the trace of this thing.
00:45:53.920 | And that's called the Frechet Inception Distance. So a bit hand wavy on the math, because I
00:45:59.400 | don't think it's particularly relevant to anything, but it gives you a number which
00:46:04.720 | represents how similar is, you know, this for the samples to this for some real data.
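A sketch of that calculation, following the standard Fréchet distance between two Gaussians - the notebook's actual calc_fid may differ in details, and this assumes CPU tensors:

```python
import torch
from scipy.linalg import sqrtm

def calc_fid(m1, c1, m2, c2):
    # Frechet distance between N(m1, c1) and N(m2, c2):
    # |m1 - m2|^2 + Tr(c1) + Tr(c2) - 2 Tr(sqrt(c1 @ c2))
    csr = sqrtm((c1 @ c2).numpy(), disp=False)[0].real   # matrix square root via SciPy
    csr = torch.from_numpy(csr)
    return (((m1 - m2)**2).sum() + c1.trace() + c2.trace() - 2 * csr.trace()).item()
```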
00:46:17.640 | Now it's weird, it's called Frechet Inception Distance when we've done nothing to do with
00:46:22.400 | Inception. Well, the reason why is that people do not normally use the fast.ai part 2 custom
00:46:30.960 | Fashion MNIST data aug 2 pickle. They normally use a more famous model. They normally use
00:46:38.200 | the Inception model, which was an ImageNet-winning model from Google Brain from a few
00:46:44.080 | years ago. There's no reason whatsoever that Inception is a good model to use for this,
00:46:51.120 | it just happens to be the one which the original paper used. And as a result, everybody now
00:46:56.880 | uses that, not because they're sheep, but because you want to be able to compare your results
00:47:02.440 | with other papers results, perhaps. We actually don't. We actually want to compare our results
00:47:08.760 | with our other results, and we're going to get a much more accurate metric if we use
00:47:16.200 | a model that's good specifically at recognizing fashion. So that's why we're using this. So
00:47:23.840 | very, very few people bother to do this. Most people just pip install pytorch-fid or
00:47:30.000 | whatever it's called and use Inception. But actually, unless you're
00:47:35.000 | comparing to papers, it's better to use a model that you've trained on your data and
00:47:38.680 | you know is good at that. So I guess this is not a FID, it's a... well, maybe FID now stands
00:47:46.960 | for Fashion... no, Fashion MNIST... I don't know what it stands for. I should have called it something.
00:47:54.920 | I wanted to bring up two other caveats of FID, especially then like in papers is like,
00:48:02.920 | the other thing is that FID is dependent on the number of samples that you use. So as
00:48:08.800 | the number of samples they use for measuring FID, it's more accurate if you use more samples
00:48:15.580 | and it's less accurate if you use less samples. Well, that's actually biased. So if you use
00:48:20.760 | less samples, it's too high specifically. Yeah. So in papers, you'll see them report
00:48:28.880 | how many samples they used. And so even then comparing to other papers and comparing between
00:48:34.360 | different models and different things, you want to make sure that you're comparing with
00:48:37.440 | the same amount of samples. Otherwise, it might just be high because they just use less
00:48:41.400 | number of samples or something like this. So you want to make sure that's comparable.
00:48:45.640 | And then the other thing that is because I guess it's a kind of a side effect of using
00:48:50.680 | the Inception network in these papers is the fact that all of these are at a size 299 by
00:48:58.000 | 299, which is like the size that the Inception model was trained. So actually, when you're
00:49:03.760 | applying this Inception network for measuring this distance, you're going to be resizing
00:49:08.680 | your images to 299 by 299, which in different cases that may not make much sense. So like
00:49:15.520 | in our case, we're working with 32 by 32 or 28 by 28 images. These are very small images
00:49:21.840 | and if you resize it to 299, or in other cases, this is now kind of an issue with some of
00:49:27.960 | these latest models, you have these large 512 by 512 or 1024 by 1024 images. And then
00:49:35.720 | you're, you know, kind of shrinking these images to 299 by 299. And you're losing a
00:49:41.160 | lot of that detail and quality in those images. So actually, it's kind of become a problem
00:49:46.840 | with some of these latest papers, when you look at the FID scores and how they're comparing
00:49:50.720 | them. And then visually, when you see them, you can kind of notice, oh, yeah, these are
00:49:54.460 | much better images, but the FID score doesn't capture that as well, because you're actually
00:49:59.120 | using these much smaller images. So there are a bunch of different caveats. And so FID,
00:50:04.680 | you know, it's very good for like, yeah, it's nice and simple and automated for this sort
00:50:08.720 | of comparison, but you have to be aware of all these different caveats of this metric
00:50:12.720 | as well.
00:50:13.840 | So excellent segue, because we're going to look at exactly those two things right now.
00:50:20.320 | And in fact, there is a metric that compares the two distributions in a way that is not
00:50:27.720 | biased. So it's not necessarily higher or lower if you use more or less samples, and
00:50:33.760 | it's called the KID, which is the Kernel Inception Distance. It's actually significantly
00:50:40.800 | simpler to calculate than the Fréchet Inception Distance. And basically, what you do is you
00:50:48.320 | create a bunch of groups, a bunch of partitions, and you go through each of those partitions
00:50:54.880 | and you grab a few of your X's at a time and a few of your Y's at a time. And then you
00:51:02.040 | calculate something called the MMD, which is here, which is basically that, again, the
00:51:11.560 | details don't really matter. We basically do a matrix product and we actually take the
00:51:20.480 | cube of it. This K is the kernel. And we basically do that for the first sample compared
00:51:27.640 | to itself, the second compared to itself, and the first compared to the second. And
00:51:33.560 | we then normalize them in various ways, add the two self-comparison terms together, and subtract
00:51:41.520 | the cross term. And this one actually does not use the stats. It doesn't use the
00:51:50.320 | means and covariance matrices. It uses the features directly. And the actual final result
00:52:00.720 | is basically the mean of this calculated across different little batches. Yeah, again, the
00:52:08.320 | math doesn't really matter as to exactly why all these are exactly what they are, but it's
00:52:15.440 | going to give you, again, a measure of the similarity of these two distributions.
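A sketch of that MMD with the cubic (degree-3 polynomial) kernel, written from the description above - an illustrative version, not necessarily the notebook's exact code:

```python
import torch

def mmd(x, y):
    # Cubic polynomial kernel: k(a, b) = (a.b / d + 1)^3
    d = x.shape[1]
    k = lambda a, b: (a @ b.T / d + 1) ** 3
    m, n = len(x), len(y)
    kxx, kyy, kxy = k(x, x), k(y, y), k(x, y)
    # Self-similarity terms (diagonals dropped for the unbiased estimate),
    # minus twice the cross term
    c1 = (kxx.sum() - kxx.diagonal().sum()) / (m * (m - 1))
    c2 = (kyy.sum() - kyy.diagonal().sum()) / (n * (n - 1))
    return c1 + c2 - 2 * kxy.mean()
```

KID is then the mean of this quantity over the random partitions of the two feature sets.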
00:52:21.560 | At first I was confused as to why more people weren't using this, because people don't tend
00:52:26.560 | to use this and it doesn't have this, a nasty bias problem. And now that I've been using
00:52:30.720 | it for a while, I know why, which is that it has a very high variance, which means when
00:52:35.440 | I call it multiple times with just like samples with different random seeds, I get very different
00:52:51.560 | values. And so I actually haven't found this useful at all. So we're left in the situation,
00:52:57.340 | which is, yeah, we don't actually have a good unbiased metric. And I think that's the truth
00:53:04.580 | of where we are with best practices. And even if we did, all it would tell you is like how
00:53:04.580 | similar distributions are to each other. It doesn't actually tell you whether they look
00:53:07.700 | any good, really. So that's why pretty much all good papers, they have a section on human
00:53:14.140 | testing. But I've definitely found this fairly useful for me for like comparing fashion images,
00:53:22.040 | which particularly like humans are good at looking at like faces that are reasonably
00:53:26.000 | high resolution and be like, "Oh, that eye looks kind of weird," but we're not good at
00:53:29.880 | looking at 28 by 28 fashion images. So it's particularly helpful for stuff that our brains
00:53:35.560 | aren't good at. So I basically wrapped this up into a class, which I call ImageEval, for
00:53:40.600 | evaluating images. And so what you're going to do is you're going to pass in a pre-trained
00:53:46.520 | model for a classifier and your data loaders, which is the thing that we're going to use
00:53:54.880 | to basically calculate the real images. So that's going to be the data loaders that were
00:54:04.440 | in this learn, so the real images. And so what it's going to do in this class, then again,
00:54:12.520 | this is just copying and pasting the previous lines of code and putting them into a class.
00:54:15.800 | This is going to be then something that we call capture_preds on to get our features
00:54:21.600 | for the real images, and then we can also calculate the stats for the real images. And
00:54:27.680 | so then we can get the FID by calling calc_fid, which is the thing we already had, passing
00:54:34.780 | in the stats for the real images and calculate the stats for the features from our samples,
00:54:42.640 | where the features, the thing that we've seen before, we pass in our samples, any random
00:54:48.560 | Y value is fine, so I just have a single tensor there, and call capture_preds. So we can now
00:54:54.600 | create an ImageEval object, passing in our classifier, passing in our data loaders with
00:55:03.240 | the real data, any other callbacks you want. And if we calculate the FID, it takes about a quarter
00:55:09.200 | of a second, and 33.9 is the FID for our samples. Then KID -
00:55:19.800 | KID's going to be on a very different scale. It's only 0.05; KIDs are generally much
00:55:24.200 | smaller than FIDs. So I'm mainly going to be looking at FIDs. And so here's what happens
00:55:31.320 | if we measure the FID on the samples at step zero, and then step 50, and then step 100, and so forth, all the
00:55:38.760 | way up to 900. And then we also do steps 975, 990, and 999. And so you can see over
00:55:46.760 | time, our samples' FIDs improved. So that's a good little test. There's something curious
00:55:53.640 | about the fact that they stopped improving about here. So that's interesting. I've not
00:55:59.120 | seen anybody plot this graph before. I don't know if Jono or Tanishk, if you guys have,
00:56:03.280 | I feel like it's something people should be looking at, because it's really telling you
00:56:09.260 | whether your sampling is making consistent improvements.
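The plot is cheap to reproduce once the predicted x0 at each step has been kept around - roughly like this, where ie is the ImageEval object just described (the method name and preds structure are assumptions):

```python
import matplotlib.pyplot as plt

# preds[i]: the predicted denoised images captured at sampling step i
steps = [0, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 975, 990, 999]
fids = [ie.fid(preds[i]) for i in steps]   # FID of the predicted x0 at each step
plt.plot(steps, fids)
plt.xlabel("sampling step"); plt.ylabel("FID")
```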
00:56:12.920 | And to clarify, this is like the predicted de-noised sample at the different stages during
00:56:17.860 | sampling, right?
00:56:18.860 | Yes, exactly.
00:56:19.860 | If I was to stop sampling now and just go straight to the predicted x-naught, what would
00:56:23.960 | the FID be?
00:56:24.960 | So I just want to check our samples. Yeah, we saved the x-naught-hat at each
00:56:30.920 | time step. Yep. Yep, exactly. Same for KID. And I was hoping that they would look the same
00:56:38.580 | and they do. So that's encouraging: KID and FID are basically measuring the same thing.
00:56:43.480 | And then something else that I haven't seen people do, but I think it's a very good idea
00:56:47.040 | is to take the FID of an actual batch of data. Okay. And so that tells us how good we could
00:56:53.040 | get. Now that's a bit unfair because of the different sizes - our data batch is 512, our
00:57:02.440 | sample is 256, but anyway, it's a pretty huge difference.
00:57:09.240 | And then, yeah, the second thing that Tanishk talked about, which I thought I'd actually
00:57:12.600 | show is: what does it take to get a real FID, using the Inception network? So I didn't
00:57:20.920 | particularly feel like re-implementing the Inception network. So I guess I'm cheating
00:57:24.240 | here. I'm just going to grab it from pytorch-fid. But there's absolutely no reason to study
00:57:29.800 | the Inception Network because it's totally obsolete at this point. And as Tanishk mentioned,
00:57:37.120 | it wants 299 by 299 images, which actually you can just call resize input to have that
00:57:43.160 | done for you. It also expects three-channel images. So what I did is I created a wrapper
00:57:52.520 | for an Inception v3 model that, when you call forward, takes your batch and replicates
00:58:05.000 | the channel three times. So that's basically creating a three-channel version of a black
00:58:12.160 | and white image just by replicating it three times. With that wrapping — and again, this
00:58:18.480 | is good flexing of your PyTorch muscles; try to make sure you can replicate this —
00:58:26.400 | you can get an Inception model working on your batch of Fashion MNIST samples.
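(A sketch of such a wrapper — the `pytorch-fid` import path is real, but the wrapper class itself is my reconstruction, not necessarily the notebook's exact code:)

```python
import torch
from pytorch_fid.inception import InceptionV3  # pip install pytorch-fid

class IncepWrap(torch.nn.Module):
    "Make InceptionV3 accept single-channel image batches."
    def __init__(self):
        super().__init__()
        # resize_input=True resizes inputs up to the 299x299 Inception expects
        self.m = InceptionV3(resize_input=True)
    def forward(self, x):
        # replicate the single channel three times: (N,1,H,W) -> (N,3,H,W)
        return self.m(x.repeat(1, 3, 1, 1))[0]  # InceptionV3 returns a list of feature blocks
```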
00:58:36.120 | And then from there, we can just pass that to our ImageEval instead. And so on our samples, that
00:58:43.760 | gives us 63.8. And on a real batch of data, it gives 27.9. And I find it a good sign
00:58:51.720 | that this is much less effective than our real Fashion MNIST classifier, because that's
00:58:57.920 | only a difference of a ratio of three or so. The fact that our FID for real data using
00:59:08.000 | a real classifier was 6.6 — I think that's pretty encouraging. Yeah, so that is that.
00:59:18.080 | And we now have a FID. More specifically, we now have an ImageEval. Did you guys have
00:59:26.420 | any questions or comments about that before we keep going?
00:59:31.960 | It's worth bearing in mind that pretty much every other FID you see reported is going to be set up for CIFAR-10:
00:59:41.360 | tiny 32 by 32 pixel images resized up to 299 and fed through an Inception network that was trained on
00:59:46.400 | ImageNet, not CIFAR-10. So, bearing in mind that, once again, this is a slightly
00:59:51.880 | weird metric. And even things like the image resizing algorithms
00:59:57.140 | in PyTorch and TensorFlow might be slightly different. Or if you saved your images as JPEGs and
01:00:02.640 | then reloaded them, your FID might be twice as bad.
01:00:05.880 | Yeah, it makes a big difference. Yeah, exactly.
01:00:09.880 | So just to reiterate, the takeaway from all of this that I get is that it's really useful when
01:00:16.080 | everything's the same — using the same backbone model, using the same approach, the
01:00:20.440 | same number of samples — then you can compare it to other samples. But yeah, for one set
01:00:27.000 | of experiments, a FID might be good because of the way everything's set up, and for
01:00:31.200 | another, that might be terrible. So comparing to a paper or whatever isn't
01:00:34.840 | very easy.
01:00:36.880 | So maybe the approach is that, like, if you're doing your own experiments,
01:00:42.640 | these sorts of metrics are good. But then if you're going to compare to other models,
01:00:46.040 | it's best to rely on human studies if you're comparing to other models. And that, yeah,
01:00:50.600 | I think that's kind of the sort of approach or mindset that we should be having when it
01:00:56.840 | comes to this.
01:00:57.840 | Yeah, or both, you know. But yeah, so we're going to see this is going to be very useful
01:01:02.680 | for us. And we're just going to be using the same, pretty much all the time, we're going
01:01:09.000 | to use the same number of samples, and we're going to use the same Fashion MNIST-specific
01:01:13.960 | classifier.
01:01:16.000 | So the first thing I wanted to do was fix our bug. And to remind you, the bug was that
01:01:21.760 | we had — we were feeding into our UNet, in DDPM v2 and the original DDPM, images that go from
01:01:31.560 | zero to one. And yeah, that's that's wrong. That's like nobody does that. Everybody feeds
01:01:37.880 | in images that are from minus one to one. So that's very easy to fix. You just...
01:01:44.840 | Jono here just asked, like, why is that a bug? Why is it a bug? I mean, it's like everybody
01:01:50.680 | knows it's a bug because that's what everybody does. Like, I've never seen anybody do anything
01:01:54.960 | else and it's very easy to fix. So I fixed it by adding this to DDPM v2 and I reran it
01:02:04.040 | and it didn't work. It made it worse. And this was the start of, you know, a few horrible
01:02:14.400 | days of pain because, like, when you, you know, fix a bug and it makes things worse,
01:02:22.240 | that generally suggests there's some other bug somewhere else that somehow is offset
01:02:26.600 | your first bug. And so I had to go, you know, I basically went back through every other
01:02:31.240 | notebook and every cell, and I did find at least one bug elsewhere, which is that we hadn't
01:02:39.760 | been shuffling our training sets the whole time. So I fixed that, but it's got absolutely
01:02:45.160 | nothing to do with this. And I ended up going through everything from scratch three times,
01:02:49.760 | rerunning everything three times, checking every intermediate output three times. So
01:02:53.160 | days of, you know, depressing and annoying work and made no progress at all. At which
01:03:00.840 | point I then asked Jono's question to myself more carefully and provided a less dismissive
01:03:09.240 | response to myself, which was: well, I don't know why everybody does this, actually. So
01:03:18.440 | I asked Tanishk and Jono, and I was like, have you guys seen
01:03:22.720 | any math, papers, whatever that's based on this particular input range? And yeah, you
01:03:35.920 | guys were both like, no, I haven't. It's just, it's just what everybody does. So at that
01:03:43.920 | point, it raised the possibility that like, okay, maybe what everybody does is not the
01:03:50.840 | right thing to do. And is there any reason to believe it is the right thing to do? Given
01:03:57.840 | that it seemed like fixing the bug made it worse, maybe not. But then it's like, well,
01:04:05.680 | okay, we are pretty confident from everything we've learned and discussed that having centered
01:04:11.000 | data is better than uncentered data. So having data that go from zero to one clearly seems
01:04:16.560 | weird. So maybe the issue is not that we've changed the center, but that we've scaled
01:04:20.480 | it down so that rather than having a range of two, it's got a range of one. So at that
01:04:25.480 | point, you know, I did something very simple, which was I did this, I subtracted 0.5. So
01:04:38.600 | now rather than going from 0 to 1, it goes from minus 0.5 to 0.5. And so the theory here
01:04:45.200 | then was, okay, if our hypothesis is correct, which is that the negative one to one range
01:04:50.400 | has no foundational reason for being. And we've accidentally hit on something, which
01:04:56.400 | is that a range of one is better than a range of two. And this should be better still, because
01:05:01.000 | this is a range of one and it's centered properly. And so this is DDPMv3. And I ran that. And
01:05:08.360 | yes, it appeared to be better. And this is great because now I've got fit. I was able
01:05:14.480 | to run fit on DDPMv2 and on DDPMv3, and it was dramatically, dramatically, dramatically
01:05:20.280 | better. And in fact, I was running a lot of other experiments at the time, which we will
01:05:28.040 | talk about soon. And like all of my experiments are totally falling apart when I fix the bug.
01:05:33.600 | And once I did this, all the things that I thought weren't working suddenly started working.
01:05:42.960 | So this is often the case, I guess, is that bugs can highlight accidental discoveries.
01:05:52.800 | And the trick is always to be careful enough to recognize when that's happened. Some people
01:06:01.480 | might remember the story. This is how the noble gases were discovered. A chemistry experiment
01:06:07.700 | went wrong and left behind some strange bubbles at the bottom of the test tube. And most people
01:06:13.800 | would just be like, huh, whoops, bubbles. But people who are careful enough actually
01:06:19.080 | went, no, there shouldn't be bubbles there. Let's test them carefully. It's like they
01:06:22.400 | don't react. Again, most people would be like, oh, that didn't work. The reaction failed.
01:06:28.160 | But if you're really careful, you'll be like, oh, maybe the fact they don't react is the
01:06:33.160 | interesting thing. So yes, being careful is part of the journey.
01:06:39.040 | When you say things like it didn't work or it was worse — when you first showed us this
01:06:42.600 | thing, I kind of said the images looked fine. The FID was slightly worse, but it was okay.
01:06:49.520 | And if you trained it longer, it eventually got better, mostly. There were some things
01:06:53.040 | where sampling occasionally went wrong — one image in a hundred or something like that.
01:06:58.120 | But it wasn't like everything completely fell apart. It's just that the results
01:07:02.800 | were slightly worse than expected. And if you were doing the run-and-gun, try-
01:07:08.960 | a-bunch-of-things approach, it's like, oh, well, I just doubled my training time and set a few
01:07:12.880 | runs going and looked at the Weights and Biases stats later. And oh, that seems like it's
01:07:15.720 | better now — we just needed to train for longer. If you have infinite GPUs and lots of money,
01:07:21.600 | you would never notice this. So yeah, the fact that you picked up on
01:07:27.640 | it showed that you had this deep intuition for where it should be at this stage in training
01:07:32.040 | versus where it was, what the samples should look like. And you had the FID as well to
01:07:36.120 | say, like, okay, I would have expected a FID of nine and I'm getting 14. What's up here?
01:07:42.080 | And that was enough to start asking these questions and we all jumped on. We all started
01:07:45.760 | to think where this came from.
01:07:48.080 | Yeah, I mean, definitely. I drive the people I work with crazy. I don't know why you guys
01:07:53.720 | haven't gone crazy yet, but with this kind of like, no, I need to know exactly why this is not
01:08:00.920 | exactly what we expected. But yeah, this is why: I find that when something's mysterious
01:08:08.080 | and weird, it means that there's something you didn't understand, and that's an opportunity
01:08:11.400 | to learn something new. So that's what we did. And so that was quite exciting because
01:08:20.720 | yeah, going -0.5 to 0.5 made the FID better still. And I was definitely — yeah, I moved
01:08:29.040 | from this frame of mind of, like, total depression. I was so mad. I still remember when I spoke
01:08:36.360 | to Jono, I was just so upset. And then I suddenly, like, oh my gosh, we're actually
01:08:41.920 | onto something. So I started experimenting more and a bit more confidence at this point,
01:08:48.640 | I guess. And one thing I started looking at was our schedule. We'd always been copying
01:08:55.960 | and pasting this standard, again, set of stuff. And I started questioning everything. Why
01:09:02.080 | is this the standard? Like, why are these numbers here? And we don't see any particular
01:09:08.600 | reason why those numbers were there. And I thought, well, we should maybe experiment
01:09:14.100 | with them. So to make it easier, I created a little function that would return a schedule.
01:09:20.920 | Now you could create a new class for a schedule, but something that's really cool is there's
01:09:25.480 | a thing in Python called SimpleNamespace, which is a lot like a struct in C — it basically
01:09:30.620 | lets you wrap up a little bunch of keys and values as if it's an object. So I created
01:09:37.840 | this little SimpleNamespace, which contains our alphas, our alphabars, and our sigmas for
01:09:46.280 | our usual beta-max-0.02 case. This is what we always do.
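(Roughly, the helper might look like this — the function name and exact fields are my guesses, but the maths follows the DDPM definitions used earlier:)

```python
import torch
from types import SimpleNamespace

def linear_sched(betamax=0.02, n_steps=1000):
    "Sketch: the usual DDPM linear beta schedule, bundled as a struct-like namespace."
    beta = torch.linspace(0.0001, betamax, n_steps)
    alpha = 1. - beta
    return SimpleNamespace(alpha=alpha, alphabar=alpha.cumprod(0), sigma=beta.sqrt())

sched = linear_sched()           # the standard beta-max 0.02 case
sched_low = linear_sched(0.01)   # the lower beta-max tried below
```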
01:09:54.680 | And then, yeah, there's another paper which mentions an alternative approach, the cosine schedule,
01:10:01.760 | which is where you basically set alphabar equal to the cosine of t as a fraction of big T, times pi/2,
01:10:13.560 | squared — that is, ᾱ_t = cos((t/T)·π/2)². And if you make that your alphabar, you can then basically reverse back
01:10:21.000 | out to calculate what alpha must have been. And so we can create a schedule for this cosine
01:10:27.440 | schedule as well. And yeah, this cosine schedule is, I think, pretty recognized as being better
01:10:37.840 | than this linear schedule. And so I thought, okay, it'll be interesting to look at how
01:10:44.600 | they compare. And in fact, really all that matters is the alphabar. The alphabar is the
01:10:54.960 | total amount of noise that you're adding. So in DDPM, when we do noisify, it's alphabar
01:11:06.320 | that we're actually using.
01:11:07.320 | It's the amount of the image — and 1 minus alphabar, that's the amount of noise.
01:11:13.160 | Exactly.
01:11:14.160 | Yeah.
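(And a matching sketch for the cosine schedule, assuming the ᾱ formula above — again my own naming, not the notebook's exact code:)

```python
import math
import torch
from types import SimpleNamespace

def cos_sched(n_steps=1000):
    "Sketch: cosine schedule, abar_t = cos((t/T) * pi/2) ** 2."
    t = torch.linspace(0, 1, n_steps)            # t/T from 0 to 1
    alphabar = (t * math.pi / 2).cos() ** 2
    # recover the per-step alpha from the ratio of consecutive alphabars
    alpha = alphabar / torch.cat([torch.ones(1), alphabar[:-1]])
    return SimpleNamespace(alpha=alpha, alphabar=alphabar, sigma=(1 - alpha).sqrt())

# noisify uses only alphabar: x_t = abar.sqrt()*x0 + (1-abar).sqrt()*noise
```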
01:11:15.160 | Yeah. So I just printed those out, plotted those, for the normal linear schedule and
01:11:21.600 | this cosine schedule. And you can really see the linear schedule. It really sucks badly.
01:11:27.480 | It's got a lot of time steps where it's basically about zero. And that's something we can't really
01:11:39.360 | do anything with, you know, whereas the cosine schedule is really nice and smooth and there's
01:11:47.000 | not many steps which are nearly zero or nearly one. So I thought, so I was kind of inclined
01:11:53.000 | to try using the cosine schedule, but then I thought, well, it'd be easy enough to get
01:11:57.360 | rid of this big flat bit by just decreasing beta max. That'd be another
01:12:01.960 | thing we can do. So I tried, oh, sorry, first of all, I should mention that the other thing
01:12:06.880 | that's really important is the slope of these curves, because that's how much things are
01:12:11.480 | stepping during the sampling process. And so here's the slope of the lin and the cosine.
01:12:18.120 | And you can see the cosine slope, really nice, right? You have this nice smooth curve, whereas
01:12:24.640 | the linear is just a disaster. So yeah, if I change beta max to 0.01, that actually gets
01:12:34.320 | you nearly the same curve as the cosine. So I thought that was very interesting. It kind
01:12:41.480 | of made me think like, why on earth does everybody always use 0.02 as the default? And so we
01:12:47.160 | actually talked to Robin, who is one of the two lead authors on the stable diffusion paper.
01:12:53.600 | And we talked about all of these things and he said, oh yeah, we noticed not exactly this,
01:12:58.520 | but we experimented with everything. And we noticed that when we decreased beta max, we
01:13:03.560 | got better results. And so actually stable diffusion uses beta max at 0.012. I think
01:13:10.200 | that might be a little bit higher than they should have picked, but it's certainly a lot
01:13:12.520 | better than the normal default. So it's interesting talking to Robin to see all of these kinds
01:13:17.600 | of experiments and things that we tried out, they had been there as well and noticed the
01:13:25.600 | same things.
01:13:27.120 | But the input range as well — they have this magic factor of 0.18215 that they scale
01:13:33.960 | the latents by. And if you ask why, they're like, oh yeah, we wanted the latents to
01:13:36.960 | be in a roughly unit range or whatever. But that's also, like, that's reducing the
01:13:41.480 | range of your inputs to a reasonable value.
01:13:44.160 | I think exactly. We independently discovered this idea. Yeah, exactly. Yeah, exactly. So
01:13:54.640 | we'll be talking more about like what's actually going on with that maybe next lesson. Anyway,
01:14:03.160 | so here's the curves as well. They're also pretty close. So at this point I was kind
01:14:06.440 | of thinking, well, I'd like to like change as little as possible. So I'm going to keep
01:14:10.600 | using a linear schedule, but I'm just going to change beta max to 0.01 for my next version
01:14:16.360 | of DDPM. So that's what I've got here: linear schedule, beta max 0.01. And so that I wouldn't
01:14:22.360 | really have to change any of my code, I just put those in the same variable
01:14:25.520 | names that I've always used. So then noisify is exactly the same as it always has been.
01:14:31.680 | So now I just repeat everything that I've done before. So now when I show a batch of
01:14:37.800 | data, I can already see that there are more actually recognizable images, which I think
01:14:45.440 | is very encouraging. Previously, almost all of them had been pure noise, which is
01:14:50.480 | not a good sign. So, okay. So now I just train it exactly the same as DDPM v2, and save
01:15:00.800 | this as fashion DDPM 3. Oh, and then the other thing I've done here is — you know, this did
01:15:08.040 | turn out to work pretty well. I actually decided let's keep going even further. So I actually
01:15:13.040 | doubled all of my channels from before, and I also increased the number of epochs by 3
01:15:19.960 | because things are going so well. I was like, how well could they go? So we've got a bigger
01:15:23.400 | model trained for longer. It takes a few minutes. That's what the 25 here is the number of epochs.
01:15:31.880 | So samples exactly the same as it always has been. So create 512 samples and here they are.
01:15:41.400 | And they definitely look to me, you know, great. Like they, I'm not sure I could recognize
01:15:49.160 | whether these are real samples or generated samples. But luckily, you know, we can test
01:15:55.760 | them. So we can load up our data_aug2 classifier, delete the last two layers, pass that to ImageEval,
01:16:04.720 | and get a FID for our samples. And it's eight. And then I chose 512 for a reason, because
01:16:12.960 | that's our batch size. So then I can compare that like with like with the FID for the actual
01:16:17.920 | data at 6.6. So this is, like, hugely exciting to me. We've got down to a FID that is nearly
01:16:26.060 | as good as real images. So I feel like this is, you know, in terms of image quality for
01:16:37.040 | small unconditional sampling, I feel like we're done, you know, pretty much. And so
01:16:46.960 | at this point, I was like, OK, well, can we make it faster? You know, at the same quality.
01:16:51.960 | And I just wanted to experiment with a few things, like really obvious ideas. And in
01:16:55.560 | particular, I thought: we're calling this a thousand times, which means we're running
01:17:09.160 | the model a thousand times. And that's slow. And most of the time you
01:17:15.600 | just move a tiny bit, so the model output is pretty much the same — the noise being
01:17:21.040 | predicted is pretty much the same. So I just did something really obvious, which is I decided
01:17:26.720 | let's only call the model every third time, you know — and maybe also every time for just the last 50
01:17:34.000 | steps, to help with the fine detail. I don't know if that's necessary. Other than that, it's exactly the
01:17:39.280 | same. So now this is basically three times faster. And yeah, samples look basically the
01:17:47.600 | same. So the FID is 9.78 versus 8.1. And this is, like, within
01:17:52.880 | the normal variance of FID. So I don't know — you'd have to run this a few times or
01:17:57.360 | use bigger samples. But this is basically saying, like, yeah, you probably don't need
01:18:01.800 | to call the model a thousand times. I did something else slightly weird, which is I
01:18:08.080 | basically said, like, oh, let's create a different schedule for how often we call the model.
01:18:14.200 | So I created a little sampling schedule which basically said: for the
01:18:18.520 | first few time steps, just call it every 10; then for the next few, every nine; and
01:18:23.800 | then every eight, and so forth. And just for the last hundred, call it every single step.
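(A sketch of that kind of call-skipping sampler — `ddpm_step` stands in for the DDPM update from earlier, and all the names here are my own, not the notebook's:)

```python
import torch

@torch.no_grad()
def sample_skip(model, sched, sz, call_every=3, always_last=50):
    "Sketch: reuse the predicted noise instead of calling the model every step."
    x_t = torch.randn(sz)
    noise = None
    for t in reversed(range(len(sched.alphabar))):
        if noise is None or t % call_every == 0 or t < always_last:
            t_batch = torch.full((sz[0],), t)
            noise = model(x_t, t_batch)         # predicted epsilon
        x_t = ddpm_step(x_t, noise, t, sched)   # assumed DDPM update helper
    return x_t
```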
01:18:28.480 | So that makes it even faster. Um, samples look good. This is, you know, it's definitely
01:18:35.720 | worse though now, you know, but it's still not bad. So, um, yeah, I kind of felt like,
01:18:42.380 | all right, this is encouraging — and before we fixed the minus-one-to-one
01:18:47.840 | thing, this stuff looked really bad, you know; that's why I was thinking that my code
01:18:53.040 | was full of bugs. So at this point I'm thinking, okay: we can create
01:18:57.560 | extremely high quality samples using DDPM. What's the like, you know, best paper out
01:19:04.720 | there for doing it faster. And the most popular paper for doing it faster is DDIM. So I thought
01:19:12.920 | we might switch to this next. So we're now at the point where we're not actually going
01:19:18.580 | to retrain our model at all, right? If you noticed with these different sampling approaches,
01:19:26.000 | I didn't retrain the model at all. We're just saying, okay, we've got a model. The model
01:19:30.560 | knows how to estimate the noise in an image. How do we use that to call it multiple times
01:19:37.440 | to denoise using iterative refinement as Jono calls it. Um, and so DDIM is a, another way
01:19:50.640 | of doing that. So, um, what we're going to do, so I'm going to show you how I built my
01:20:00.200 | own DDIM from scratch. And, um, I kind of cheated, which is I, there's already an existing
01:20:08.280 | one in diffusers. So I decided I will use that first, make sure everything works, and
01:20:15.520 | then I'll try and re-implement it from scratch myself. Um, so that's kind of like when there's
01:20:20.720 | an existing thing that works, you know, that's what I like to do. And it's been really good
01:20:25.400 | to have my own DDIM from scratch because now I can modify it, you know, and I've made it
01:20:30.780 | much more concise code than the diffusers version. So, um, now, um, we had created this
01:20:41.320 | class called UNet, which passed the tuple of x's through as individual parameters and returned
01:20:52.920 | the .sample. But not surprisingly — given that this comes from diffusers and
01:20:58.840 | we want to use the diffusers schedulers — the diffusers schedulers assume this has
01:21:06.560 | not happened. They want the x, you know, as a tuple, and they expect to find the thing called
01:21:11.560 | .sample. So here's something crazy. When we save this thing, this pickle, it doesn't
01:21:21.560 | really know anything about the code, right? It just knows that it's from a class called
01:21:28.360 | UNet. So we can actually lie. We can say: oh yeah, that class called UNet? It's actually
01:21:35.760 | the same as UNet2DModel, with no other changes — and Python doesn't know or care, right?
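(In code, the lie is basically one line — the saved filename here is hypothetical:)

```python
import torch
from diffusers import UNet2DModel

# Pickle stores only the class *name* for the saved model, so redefining
# `UNet` before loading swaps in plain UNet2DModel behavior, no other changes.
class UNet(UNet2DModel): pass

model = torch.load('models/fashion_ddpm3.pkl')  # hypothetical path
```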
01:21:43.120 | So we can now load up this model and it's going to use this UNet. Okay. So this
01:21:49.200 | is where it's useful to understand how Python works behind the scenes, right? It's
01:21:53.720 | a very simple programming language. So we've now got a model which we've trained,
01:22:00.600 | and callers are just going to, you know, use the .sample on its output. That means we can
01:22:04.320 | use it directly with the diffusers schedulers. So we'll start by actually repeating what
01:22:10.040 | we already know how to do, which is use a DDPM scheduler. So we have to tell it what
01:22:13.680 | beta we used to train. Um, and so we can grab some random data and so it could say, okay,
01:22:22.440 | we're going to start at time step 999. So let's create a batch of data and then predict
01:22:29.800 | the noise. And then this is the way the diffusers thing works: you call scheduler.step, and
01:22:37.560 | that's the thing which does, um, those lines. That's the thing that calculates x t minus one given
01:22:48.480 | x t and the noise. So that's what scheduler.step does. So that's why you pass in x t and the time
01:22:54.700 | step and the noise. And that's going to give you a new sample. And so I ran that as usual,
01:23:03.200 | first cell by cell to make sure I understood how it all worked. I then copied those cells
01:23:07.240 | and merged them together and chucked them in a loop. So this is now going to go through
01:23:12.560 | all the time steps, use a progress bar to see how we're going, get the noise, call step
01:23:19.780 | and append. So this is just DDPM, but using diffusers.
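(The loop, sketched with the real diffusers API — the batch shape is hypothetical, and `model` is the UNet loaded above:)

```python
import torch
from diffusers import DDPMScheduler

sched = DDPMScheduler(beta_end=0.01)   # tell it the beta we trained with
x_t = torch.randn(256, 1, 32, 32)      # hypothetical batch of pure noise

for t in sched.timesteps:              # 999, 998, ..., 0
    with torch.no_grad():
        noise_pred = model(x_t, t).sample                # predicted epsilon
    x_t = sched.step(noise_pred, t, x_t).prev_sample     # x_{t-1}
```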
01:23:28.240 | And not surprisingly, it produces basically the same very nice results that
01:23:33.480 | we got from our own DDPM. And so we can now use the same code we've used before to create
01:23:40.040 | our image evaluator. And, um, I decided, yeah, we're now going to go right up to 2048 images
01:23:48.880 | at a time. This is the size I found is big enough that it's reasonably stable.
01:23:54.860 | And so we're now down to 3.7 for our FID, whereas the data itself has a FID of 1.9. So
01:24:02.760 | again, it's showing that our DDPM is basically very nearly unrecognizably different from real
01:24:11.600 | data using its distribution of those activations. So then we can switch to DDIM by just saying
01:24:20.160 | DDIM scheduler. And so with DDIM, you can say, I don't want to do all thousand steps.
01:24:25.120 | I just want to do 333 steps to every third. So that's basically a bit like, um, a bit
01:24:35.680 | like this sample skip of doing every third. But DDIM as we'll see, does it in a smarter
01:24:42.320 | way. Um, and so here's exactly the same code basically as before, but I put it into a little
01:24:51.000 | function. Okay. So I can basically pass in my model, the size, the scheduler. Um, and
01:24:58.840 | then there's a parameter called ADA, which is basically how much noise to add. Um, so
01:25:04.840 | uh, just add all the noise. Um, and so this is now going to take three times. This three
01:25:10.920 | times faster. And yeah, the fit's basically the same. That's encouraging. So they weren't
01:25:16.680 | added 200 steps. If it's basically the same, 100 steps. And at this point, okay, the fit's
01:25:25.240 | getting worse. Um, and then 50 steps. We're still 25 steps. We're still, that's interesting.
01:25:38.160 | Like when you get down to 25 steps, like what does it look like? And you can see that they're
01:25:42.360 | kind of like, they're too smooth. You know, they don't have interesting, you know, fabric
01:25:49.880 | swirls so much or buckles or logos or patterns as much, you know, as the, these ones, they've
01:25:58.920 | got a lot more texture to them. So that's kind of what tends to happen. So you can still
01:26:03.440 | like get something out pretty fast, but that's, that's kind of how they suffer. So, okay.
01:26:12.480 | So how does DDIM work? Well, DDIM, it's nice. It's actually, in my opinion, it makes things
01:26:19.300 | a lot easier than DDPM. So there's basically an equation from the paper, which Tanishk
01:26:28.880 | will explain shortly. But basically, what I did was I actually grabbed the sample
01:26:38.200 | function from here and split it out into two bits. One bit is the bit that says
01:26:48.920 | what the time steps are, creates that random starting point, loops through, finds what
01:26:56.300 | my current alpha bar is, gets the noise, and then basically does the same as sched.step — calls
01:27:04.400 | some function, right? And then that's been pulled out. So this allows me to now create
01:27:10.120 | my own different steps. So I've got the DDIM step, and basically all I did was take this
01:27:22.820 | equation and turn it into code. Actually, this one is the second equation from the
01:27:28.600 | paper. Now it's a bit confusing, which is that the notation here is different: DDPM,
01:27:36.800 | what it calls alpha bar, this paper calls alpha. So you've got to look out for that.
01:27:43.280 | So basically you'll see: I've got here x t; one minus alpha
01:27:50.800 | bar, we're going to call that beta bar, so beta bar square root times the noise — and this
01:27:58.040 | here is the neural net, so this here is the predicted noise. Okay. And here I've got:
01:28:12.840 | my next x t is alpha bar t minus one, square root, times this. And you can
01:28:23.040 | see here it says "predicted x naught" — so here's my predicted x naught — plus beta bar t minus one, minus
01:28:31.240 | sigma squared, square root, times the noise again. That's the same thing as here. And then
01:28:42.480 | plus a bit of random noise, which we only add if we're not at the last step.
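(That step, sketched as code using DDPM-style ᾱ notation — the function and variable names are mine, not the notebook's:)

```python
import torch

def ddim_step(x_t, noise_pred, abar_t, abar_t1, eta, last=False):
    "Sketch of one DDIM update from x_t to x_{t-1} (the paper's eq. 12)."
    sigma = eta * ((1-abar_t1)/(1-abar_t)).sqrt() * (1 - abar_t/abar_t1).sqrt()
    x0_hat = (x_t - (1-abar_t).sqrt()*noise_pred) / abar_t.sqrt()    # predicted x0
    x_t1 = abar_t1.sqrt()*x0_hat + (1 - abar_t1 - sigma**2).sqrt()*noise_pred
    if not last: x_t1 = x_t1 + sigma*torch.randn_like(x_t)  # skipped on the final step
    return x_t1
```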
01:28:48.720 | So yeah, I can call that. And rather than saying a hundred steps, I said
01:28:56.120 | skip 10 steps — so do 10 steps at a time, so it's basically going to be a
01:29:00.240 | hundred steps. And you can see here, actually, this happened to do a bit better for my
01:29:05.640 | hundred steps. It's not bad at all. So yeah, I mean, getting to this point
01:29:19.000 | has been a bit of a lifesaver, to be honest, because, you know, I can now run
01:29:24.720 | a batch of 2048 samples — I can sample them in under a minute, which doesn't
01:29:31.520 | feel painful. And so, you know, I'm now at a point where I can actually get a pretty good measure
01:29:38.400 | of how I'm doing in a pretty reasonable amount of time, and I can, you know, easily
01:29:46.160 | compare it. And I've got to admit, you know, the difference between a FID of five and eight
01:29:51.440 | and 11 — I can't necessarily tell the difference. So for Fashion MNIST, I think FID is better than
01:29:57.640 | my eyes for this, as long as I use a consistent sample size. So yeah, Tanishk, did you want
01:30:06.320 | to talk a bit about, you know, the ideas of why we do this or where it comes from or what
01:30:13.880 | the notation means?
01:30:14.880 | Can I say a little bit before we do that, which is just that what you have there, Jeremy,
01:30:19.040 | which is like a screenshot from the paper and then the code that as close as possible
01:30:24.480 | tries to follow that, like the difference that makes for people is huge. Like I've got
01:30:30.120 | a little research team that I'm doing some, you know, contract work with. And the fact
01:30:35.040 | that, like, it's called alpha in the DDIM paper and alpha bar elsewhere. And then in the
01:30:38.520 | code that they were copying and pasting from, it was called a and b for alpha and beta.
01:30:44.160 | And it's like you can get things kind of working by copying and pasting into things. And it's
01:30:48.720 | all just sort of kind of works. But just spending the time to actually take two screenshots
01:30:52.640 | of equations 14 and 16 from the paper and put them in there, and rewrite the code
01:30:57.720 | with some comments and things to say, like, this is what this is,
01:31:00.840 | this is that part from the equation — it's like, you know, the look of pain on their faces
01:31:06.520 | when I said, oh, by the way, did you notice that, like, it's called alpha there and alpha
01:31:09.400 | bar there? They're like, yes, how could they do that? You know, you could
01:31:12.880 | just tell how many hours had been spent, you know, grinding through code trying to work out what's
01:31:16.960 | wrong here.
01:31:17.960 | Yeah, and building this stuff in notebooks is such a good idea — like we're doing with mini
01:31:23.920 | AI — because, you know, the next engineer to come along and work on that can see the
01:31:31.960 | equation right there, and you can add prose and stuff. So I think, you know, nbdev works
01:31:39.000 | particularly well for this, this kind of development.
01:31:42.520 | Yeah.
01:31:43.520 | Yeah, before, before, before I talk about this, I just wanted to briefly, in the context
01:31:51.720 | of all of these different notations, I recently created this meme, which I thought was, was
01:31:58.320 | relevant, in terms of, like, each paper basically having a different diffusion model notation.
01:32:03.320 | So it's just like this: they all try to come up with their own universal
01:32:07.440 | notation and it just keeps proliferating.
01:32:10.360 | If you ask me, we should all use APL.
01:32:14.360 | Yes, exactly. We need to implement diffusion models in APL. So yeah, the paper that Jeremy
01:32:25.400 | had implemented was this denoising diffusion implicit model paper. And if you look at the
01:32:33.400 | paper again, you can see like the notation could be again, a little bit intimidating,
01:32:38.680 | but when we walk through it, we'll see it's not too bad actually. So I'll just bring up,
01:32:44.680 | I guess, some of the important equations and also comparing and contrasting, you know, DDPM
01:32:53.120 | and the notation of DDPM and the equations with DDIM.
01:32:57.280 | Not only is it not too bad — I actually discovered it makes life a lot better. The DDIM notation
01:33:03.640 | and equations are a lot easier to work with than DDPM. So I found my life is better since
01:33:10.440 | I discovered DDIM.
01:33:12.160 | Yes, yes. I think a lot of people prefer to use DDIM as well. Yeah. So yeah, basically
01:33:20.280 | in — let's see here. So yeah, in both DDIM and DDPM, we have this same
01:33:32.120 | sort of equation. This equation is exactly the same. This is telling us the predicted
01:33:40.920 | denoised image. So basically we predict — you can see my pointer, right?
01:33:49.260 | Just want to confirm.
01:33:50.260 | By the way, so the little double-headed arrow in the top right, does that, if you click
01:33:55.080 | that, do you get more room for us to see what's going on?
01:33:58.960 | I'm sorry?
01:33:59.960 | Yeah, I see. Yeah, that works much better. Yeah. So we have our predicted noise. So our
01:34:11.600 | model is predicting the noise in the image. It is also passed the time step, but this
01:34:17.880 | is just omitted from the notation. It basically kind of is implied by the x t, but our model also takes in the
01:34:23.320 | time step. And so it's predicting the noise in this XT, our noisy image. And we are trying
01:34:29.940 | to remove the noise. That's what this whole term here is, remove the noise. So because
01:34:39.240 | our noise that we're predicting is unit variance noise. So we have to scale the variance of
01:34:44.080 | our noise appropriately to remove it from our noisy image. So we have to scale the noise
01:34:51.080 | and subtract it out of the original image. And that's how we would get our predicted
01:35:00.000 | denoised image.
01:35:01.000 | And I think we derived this one before by looking at the equation for x t in the noisify
01:35:10.680 | function and rearranging it to solve for x naught. And that's what you get.
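(Written out, reconstructing the two equations being pointed at, in DDPM-style notation:)

```latex
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon
\;\;\Longrightarrow\;\;
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}
```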
01:35:17.960 | Yes, that's basically what this is. So basically the idea is, okay, instead of noisifying it
01:35:25.400 | where we're starting out with X0 and some noise and get an XT, we're doing the opposite
01:35:30.780 | where we have some noise and we have XT. So how can we get X0? So that's what this equation
01:35:36.000 | is. So that's the predicted X0 or our predicted clean image. And this equation is the same
01:35:43.000 | for both DDPM and DDIM, but these distributions are what's different between DDPM and DDIM.
01:35:50.840 | So we have this distribution, which tells us: okay, if we have x t, which is our current
01:35:58.320 | noisy image, and x0, which is our clean image, can we find out what some sort of intermediate
01:36:05.360 | noisy image is in that process? And that's x t minus one. So we have a distribution for
01:36:12.760 | that, and it tells us how to get such an image. And this is in the DDPM paper —
01:36:19.000 | they define this distribution and explain the math behind it. But yeah, basically, they
01:36:25.200 | have some equations. So you have, again, a Gaussian distribution with some sort of mean
01:36:32.000 | and variance, and it's, again, some form of interpolation between
01:36:38.840 | your original clean image and your noisy image. And that gives you your intermediate, slightly
01:36:47.800 | less noisy image. So that's what this is giving: given a clean image and a noisy image, you get a slightly less noisy image.
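(I believe this is the standard DDPM posterior, reconstructed here for reference:)

```latex
q(x_{t-1}\mid x_t,x_0)=\mathcal{N}\!\big(\tilde\mu_t(x_t,x_0),\,\tilde\beta_t I\big),\qquad
\tilde\mu_t=\frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0
 +\frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t
```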
01:36:54.240 | And so the sampling procedure that we do with DDPM basically is:
01:37:03.440 | predict the x0, and then plug it into this distribution to give you
01:37:10.140 | your slightly less noisy image. So maybe it's worth drawing that out. So like, if we had —
01:37:19.280 | let's say, some sort of curve — I'm just making something up here.
01:37:24.120 | So then, in this case, I'm showing
01:37:30.480 | a one-dimensional example; it's kind of a one-dimensional
01:37:35.920 | example that still sits in a 2D space. But let's say you have any point on this:
01:37:42.200 | it represents an actual image that you want to sample from, right? So this is where your
01:37:46.680 | distribution of actual images would lie. And you want to estimate this. So when this sort
01:37:53.400 | of algorithm that we've been seeing here says that, okay, if we take some random point,
01:38:00.160 | this is some random point that we choose, you know, when we start out. And what we did
01:38:05.240 | is we learned this function, the score function, to take us to this manifold, but it's only
01:38:10.400 | going to be accurate in some space. So it's going to be accurate, you know, it's only,
01:38:15.000 | it would be accurate in some area. So we get an estimate of the score function and it tells
01:38:20.040 | us the direction to move in. And it's going to give us the direction to predict our denoised
01:38:26.880 | image, right? So basically, let's say your
01:38:33.600 | score function is actually, in reality, some curve,
01:38:38.520 | okay? Some curve that points here. So that's your
01:38:43.640 | score function. And, you know, the score function basically means
01:38:48.280 | your gradient. Yeah. Yes, it's a gradient. So, you know, we're, again, doing some form
01:38:54.720 | of — in this case, I guess you would say — gradient ascent, because you're not really minimizing
01:39:00.120 | the score, you're maximizing it; sorry, you're maximizing the likelihood
01:39:05.360 | of that data point being an actual data point. You want to go towards it. So you're doing
01:39:10.840 | the sort of gradient ascent process. So you're following the gradient to get to that. So
01:39:17.920 | when we estimate epsilon theta and predict our noise, what we're doing is we're getting
01:39:24.280 | the score value here. And then so we can, you know, follow that. And we follow it to
01:39:30.720 | some point. I'm being kind of exaggerating here. But this point will now represent our
01:39:38.920 | x zero hat. Yeah — our x zero hat. And in reality, you know, that's maybe
01:39:47.980 | not going to be an actual point; it wouldn't be on the
01:39:52.920 | distribution. So, you know, it's not going to be a very good estimate of a clean
01:39:58.240 | image at the beginning. But, you know, we only have that estimate at the beginning at
01:40:02.920 | this point, and we have to follow it all the way to to some place. So this is where we
01:40:07.880 | follow it to. And then we want to find some sort of x t minus one. So that's what our
01:40:16.160 | next point is. And so that's what our second distribution tells us. And it basically takes
01:40:22.900 | us all the way, takes us all the way back to maybe some point here. And now we can re-estimate
01:40:30.080 | the the the score function or our gradient over there, you know, do this prediction of
01:40:35.200 | noise. And, you know, it may be more accurate of a of a score function. And maybe we go
01:40:40.880 | somewhere here. And then we re-estimate and get another point and then we follow it. And
01:40:47.240 | so that's kind of this iterative process where we're trying to follow this the score function
01:40:52.740 | to your own point. And in order to do so, we first have to estimate our x zero hat and
01:40:58.360 | then and then basically add back some noise and to get, you know, a little bit again,
01:41:07.120 | a new estimate and keep falling and add back a little bit more noise and keep estimating.
01:41:11.400 | So that's what we're doing here in these two steps. We have our x zero hat, and then we
01:41:16.920 | have this distribution. And that's how we do regular DDPM. And I think that's
01:41:23.160 | maybe where breaking it up into two steps makes it a bit clearer. I don't think
01:41:30.040 | the DDPM paper really clarifies that or talks about it too much. But the DDIM paper
01:41:36.740 | really hammers that point home, I think, especially in their update equation.
01:41:42.720 | So that's DDPM — but then with DDPM, the one thing is that when
01:41:52.320 | you look at your prediction and use that to make a step, you also add back some additional
01:41:56.920 | noise that's always fixed. Right. There's no parameter to control how much extra noise
01:42:01.960 | you add back at each step. Right, exactly. So yeah, let's see
01:42:10.200 | here. Yeah, so basically, you won't be exactly at this point — you're
01:42:14.400 | in that general vicinity — and, you know, adding that noise also helps:
01:42:19.520 | you don't want to fall into specific modes where it's like, oh,
01:42:26.280 | this is the most likely data point; you want to add some noise so you can
01:42:30.200 | explore other data points as well. So yeah, the noise can
01:42:36.080 | help, and that's something you really can't control with DDPM. And this is something that
01:42:40.920 | DDIM explores a little bit further is in terms of the noise and even trying to get rid of
01:42:45.600 | the noise altogether in DDIM. So with the DDIM paper, the main difference is literally
01:42:54.160 | this one equation — that's really all it is — changing this distribution
01:43:01.360 | where you predict the less noisy image.
01:43:10.640 | And basically, as you can see, you just have this additional parameter now, which
01:43:18.360 | is sigma. And the sigma controls how much noise, like we were just mentioning, is going
01:43:25.280 | to be part of this process. And you can actually, for example, if you want, you could set sigma
01:43:29.800 | to zero. And then you can see here, now you have a variance, that would be zero. And so
01:43:34.440 | this becomes a completely deterministic process. So if you want, this could be completely deterministic.
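(The modified distribution, as I read it from the paper — note the paper's α is DDPM's ᾱ:)

```latex
q_\sigma(x_{t-1}\mid x_t,x_0)=\mathcal{N}\!\left(
\sqrt{\alpha_{t-1}}\,x_0+\sqrt{1-\alpha_{t-1}-\sigma_t^2}\cdot
\frac{x_t-\sqrt{\alpha_t}\,x_0}{\sqrt{1-\alpha_t}},\;\sigma_t^2 I\right)
```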
01:43:41.840 | So that's one aspect of it. And then, yeah, the other aspect of that:
01:43:52.880 | the reason it's called DDIM and not DDPM is because it's not probabilistic anymore —
01:43:57.840 | it can be made deterministic. So the name was changed for that reason. But the other thing
01:44:04.280 | is like, you would think that you've kind of changed the model altogether with a new
01:44:08.920 | distribution altogether. And so you say, oh, wouldn't you have to, like, train a different
01:44:14.160 | model for this purpose? But it turns out the math works out where the same model objective
01:44:19.840 | would work well with this distribution as well.
01:44:23.120 | And in fact, I think that's what they were setting out from the very beginning is what
01:44:27.840 | kind of other models can we get with the same objective. And so this is what they're able
01:44:34.600 | to do is you can make some, you can have this new parameter that you can introduce, in this
01:44:40.040 | case, kind of controlling the stochasticity of the model. And it still can be, you can
01:44:50.000 | still use the same exact trained model that you had.
01:44:55.320 | So what this means is that this actually is just a new sampling algorithm, and not anything
01:45:00.480 | new with the training itself. And this is just, yeah, just like we talked about a new
01:45:05.680 | way of sampling the model. And then, so yeah, this is how, you know, given now this equation,
01:45:13.800 | then you can rewrite your x t minus one term. And again, we're doing the same sort of thing
01:45:18.360 | where we split it up into predicting the x zero, and then adding back to go back to your
01:45:27.400 | x t. And also, if you need to add a little bit of noise back in, like Jonathan was saying,
01:45:35.080 | you can do so, you have this extra term here, and the sigma controls that term.
01:45:41.440 | And again, like we said, you have to be, again, looking at the DDI equation versus the DDPM
01:45:48.080 | equation, you have to be careful of the alphas here are referring to alpha bars in the DDPM
01:45:53.960 | equation. So that's the other caveat. So yeah — and setting this sigma t to this
01:46:02.840 | particular value will give you back DDPM. So sometimes instead they will write it with,
01:46:10.240 | basically, the eta Jeremy mentioned:
01:46:20.520 | sigma is equal to eta times this coefficient. So
01:46:28.840 | basically, yeah, you have eta here —
01:46:43.600 | this is where eta would go. So if it's one, it becomes regular DDPM,
01:46:48.800 | and if it's zero, of course, that's the deterministic case. So this eta is what, you know,
01:46:53.320 | all these APIs, and the code that we have — also the code that Jeremy was showing — they
01:46:59.240 | have eta equal to one, which they say corresponds to regular DDPM.
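(Concretely, I believe the relationship is, in DDPM-style ᾱ notation:)

```latex
\sigma_t=\eta\,\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}\,
\sqrt{1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}}
\qquad(\eta=1\Rightarrow\text{DDPM},\quad \eta=0\Rightarrow\text{deterministic DDIM})
```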
01:47:06.280 | So that's actually where the eta goes in the equation. So finally, it's like, yeah,
01:47:14.200 | you could pass in sigma directly, right? Like, if you weren't trying to match existing papers,
01:47:18.080 | you could just say, oh, well, we have this parameter sigma that controls the amount of
01:47:20.760 | noise — so let's just take a noise scale as an argument. But for convenience, they said,
01:47:25.760 | let's create this new thing, eta, where zero means sigma is equal to zero, which if you
01:47:31.320 | look at the equation that works, one means we match the amount of noise that's in like
01:47:37.600 | vanilla DDPM. And so then that gives you, like, a nice scale. So you could say eta equals
01:47:42.240 | two or eta equals 0.7 or whatever — but with a meaningful unit, where one equals the
01:47:48.480 | same as this previous reference work. Well, it's also convenient because it's sigma t,
01:47:54.360 | which is to say different time steps, unless you choose eta equals zero, in which case it
01:48:00.240 | doesn't matter. Different time steps probably want different amounts of noise. And so here's
01:48:07.040 | a reasonable way of scaling that noise. Then the last thing of importance, which is of
01:48:16.240 | course, one of the reasons that we were exploring this as the first place is to be able to do
01:48:21.560 | this sort of rapid sampling. So the basic idea here is that you can define a similar
01:48:32.440 | distribution where again, the math works out similarly, where now you have, let's say you
01:48:40.080 | have some subset of diffusion steps. So in this case, it uses a tau variable. So for example,
01:48:46.080 | if you take a subset of, say, 10 diffusion steps,
01:48:52.760 | then tau one would just be zero, tau two would be a hundred, and you just keep going all the
01:48:57.920 | way up to a thousand.
01:49:02.720 | So you'd get 10 diffusion steps. That's
01:49:08.200 | what they're referring to when they have this tau variable here.
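(As a tiny sketch of the idea — my own illustration, not the paper's code:)

```python
# choose a 10-step subsequence tau of the 1000 training steps
n_steps = 10
tau = list(range(0, 1000, 1000 // n_steps))   # [0, 100, ..., 900]
# at sampling time, walk tau in reverse, indexing alphabar at tau[i] and tau[i-1]
```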
01:49:16.880 | And so you can do a similar derivation to show that this distribution
01:49:24.400 | here again meets the sort of objective that you used for training. And you can now use
01:49:31.160 | this for faster sampling, where basically all you have to do is select
01:49:37.400 | the appropriate alpha bar. And this one I've written out — here,
01:49:43.480 | the alpha bar is the regular alpha bar that we've talked about; sorry,
01:49:47.960 | it's a little bit confusing switching between different notations. But basically, you have
01:49:54.760 | this distribution, and then you just have to select the appropriate alpha bars, and the
01:50:01.000 | math follows the same, in terms of giving you the appropriate sampling process. So yeah,
01:50:11.600 | I guess it makes it a lot simpler in terms of doing this accelerated sampling. Yeah —
01:50:21.080 | any other notes, maybe other comments that you guys had?
01:50:27.240 | Yeah, the key for me is that in this equation, we just have one, we only need one parameter,
01:50:39.360 | which is the alpha bar (or alpha, depending which notation is used), and everything else
01:50:43.480 | is calculated from that. And so we don't have the, what DDPM calls the alpha or beta anymore.
01:50:54.720 | And that's more convenient for doing this kind of smaller number of steps, because we can
01:51:01.960 | just jump straight from time step to alpha bar. And it's particularly
01:51:10.240 | convenient with the cosine schedule, because you can calculate the inverse of the cosine
01:51:16.240 | schedule function, which means you can also go from an alpha bar to a t. So it's really
01:51:20.960 | easy to say, like, oh, what would alpha bar be 10 time steps previous to this one?
01:51:27.680 | You could just call a function. We don't need anything else.
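(The inverse is straightforward to write down, assuming the ᾱ formula from earlier:)

```latex
\bar\alpha(t)=\cos^2\!\left(\frac{t}{T}\cdot\frac{\pi}{2}\right)
\quad\Longleftrightarrow\quad
t = T\cdot\frac{2}{\pi}\,\arccos\!\sqrt{\bar\alpha}
```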
01:51:31.440 | And so actually the original cosine schedule paper has to fuss around with
01:51:38.960 | various like kind of epsilon style small numbers that they add to things to avoid getting weird
01:51:47.920 | numerical problems. And so yeah, when we only deal with alpha bar, all that stuff also goes
01:51:53.640 | away. So yeah — if you're looking at the DDIM code, you know, it's simpler code
01:52:04.600 | with fewer parameters than our DDPM code. And of course, it's dramatically faster. And it's
01:52:11.400 | also more flexible because we've got this eta thing we can play with.
01:52:14.800 | Yes. Yeah, that's the other thing. It's like this idea of like, yeah, controlling stochasticity.
01:52:21.800 | I think that's something that's interesting to explore. And we've been exploring that
01:52:26.040 | a bit now. And I think we'll continue to explore that in terms of deterministic versus stochasticity.
01:52:31.480 | So yeah.
01:52:32.480 | So it's worth talking about this, the sigma in the middle equation you've got there. So
01:52:37.280 | you've got the sigma t times epsilon t adding the random noise. And intuitively, it makes sense that
01:52:42.880 | if you're adding random noise there, you want to move less back
01:52:48.640 | towards x t, which is your noisy image. So that's why, you know, you've got the 1 minus alpha bar
01:52:55.460 | t minus 1, minus sigma squared — and then you're taking the square root of that. So basically,
01:53:02.160 | you're subtracting sigma t's worth from
01:53:08.520 | the direction pointing to x t and adding it back as the random noise, or vice versa. So yes,
01:53:14.920 | you know, that everything's there for a reason, you know, yes. And the predicted x0, that entire
01:53:24.960 | equation we've derived previously.
01:53:28.840 | And it remains the same in pretty much any diffusion model methodology.
01:53:32.360 | Well — we'll actually be talking about some places where it's going to change, probably
01:53:38.760 | next week.
01:53:39.760 | Well, yeah, I guess it holds as long as you're predicting the noise.
01:53:44.980 | If you're predicting the noise, yes.
01:53:47.000 | Okay.
01:53:48.000 | Yeah.
01:53:49.000 | So I think, you know, we'll probably, yeah, let's wrap it up here so that we leave ourselves
01:53:58.060 | plenty of time to cover the kind of new research directions next lesson more in more detail.
01:54:03.560 | But as I mentioned, in terms of where we're at: just like we hit a point a few weeks ago of, okay,
01:54:09.960 | we can really predict classes for Fashion MNIST — I think we're
01:54:15.600 | there now for stable diffusion sampling and UNets, except for the UNet architecture.
01:54:24.520 | For unconditional generation now, we can basically do Fashion MNIST almost so it's unrecognizably
01:54:32.440 | different to the real samples, and DDIM is the scheduler that the original stable diffusion
01:54:39.480 | paper used.
01:54:43.080 | So yeah, you know, we're actually about to go beyond stable diffusion for our sampling
01:54:53.140 | and UNet training now. So I think we're, yeah, definitely meeting our stretch goals so far —
01:55:03.440 | and all from scratch, with Weights and Biases experiment logging.
01:55:11.600 | And you know, if you wanted to have fun, there's no reason you couldn't like have a little
01:55:17.400 | callback that instead logs things into a SQLite database, and then you could write
01:55:21.600 | a little front end to show your experiments, you know, that'd be fun as well.
01:55:27.920 | Yeah, I mean, you could also have it send you a text message when the loss gets good enough.
01:55:33.160 | Yeah.
01:55:34.160 | Alright, well, thanks, guys. That was really fun.
01:55:38.760 | Thanks, everybody.
01:55:40.200 | Alright, bye.
01:55:42.200 | Okay, talk to you later, then. Bye.