Stanford CS25: V1 | Transformers in Vision: Tackling problems in Computer Vision
Chapters
0:00
0:34 General Visual Representation
4:08 The Visual Task Adaptation Benchmark
7:26 Self-Supervised Pre-Training
7:58 Semi-Supervised Training
21:22 Synthetic Images
26:33 Applying Transformers to Vision
26:49 Embedding Space
42:05 Early Convolutions
45:28 Patch Size
46:24 Inference Speed
59:31 Scaling the Data Set
Today, I'm going to talk to you about vision transformers, 00:00:20.560 |
and specifically also on the vision part of things, 00:00:24.320 |
because I think the majority of what you have seen 00:00:37.000 |
and you're going to soon see what that means and why, 00:00:56.680 |
because if you have a good understanding of what you see, 00:01:00.240 |
then you can much quicker understand what's going on 00:01:04.640 |
And eventually -- I now have a little kid, since a year ago, 00:01:09.720 |
and so I really want that when he's grown up, 00:01:16.080 |
It doesn't need to be nice and pretty like in movies, 00:01:18.080 |
just maybe an arm or whatever, that my kid could teach, 00:01:36.000 |
It's not all that's required, but it's one part, 00:01:45.080 |
and one good example of a general visual representation 00:01:49.920 |
and I'm going to show you what I mean by that. 00:01:58.120 |
and I give you five images of each class, okay? 00:02:04.040 |
and I'm sure that by now you all know which class it is. 00:02:09.320 |
I'm not going to ask because I don't actually see you. 00:02:11.760 |
If I was in the room, I would do the raised hands, 00:02:17.160 |
Okay, this is fine. We have seen millions of flowers 00:02:26.960 |
Some people may have never seen it sometimes, 00:02:29.520 |
like when you fly or maybe on TV or in the Internet or so, 00:02:35.680 |
Three classes, class A, B, C, five images of each, 00:02:41.720 |
This might be a little bit less trivial than the flower, 00:02:44.720 |
but I think I've spent enough time talking that by now, 00:02:48.520 |
most of you should know that this is class B. 00:02:51.160 |
Shows a, what is it, basketball court, right? 00:03:00.200 |
but still, I give you images of class A and B. 00:03:05.320 |
because you need to use your brain a little bit more, 00:03:10.680 |
and now I should do a little bit of small talk 00:03:15.960 |
like you see that there is like spheres, boxes, and whatnot, 00:03:26.880 |
and class B is always, what is it, five objects, 00:03:29.640 |
no matter what they are, what they look like. 00:03:33.080 |
Okay, I think by now, you more or less understand 00:03:37.680 |
what I mean by a good visual representation, 00:03:46.640 |
in your brain, in your eyes such that you can quickly see 00:03:53.760 |
with just a few examples, and then generalize from that, 00:04:08.800 |
which we call the Visual Task Adaptation Benchmark. 00:04:10.920 |
It's kind of a formalization of the little game 00:04:20.280 |
or anybody who participates in the benchmark does, 00:04:24.800 |
We don't really care what data, what model, how, what not. 00:04:30.440 |
Then we come with this landscape of all possible visual tasks 00:04:35.280 |
that kind of make sense, which is a vague statement, 00:04:41.160 |
and this is kind of the task that you have just seen. 00:04:44.840 |
They were actually taken out of this Task Adaptation Benchmark, 00:04:49.160 |
and we have, for a first step, made 19 such tasks 00:04:53.000 |
where we try to cover broad types of visual tasks, 00:05:01.200 |
but also of very specialized images like satellite images, 00:05:04.480 |
also non-classification tasks that involve counting, 00:05:09.360 |
but that can be expressed in this simple classification API, 00:05:13.200 |
but that logically requires some more thinking. 00:05:15.840 |
Some things like distance, we have something with cars 00:05:19.760 |
and with distance of the closest car and things like that. 00:05:27.000 |
and then with the model that you came to this benchmark, 00:05:32.000 |
you can do some adaptation step on each of the datasets, 00:05:40.720 |
a model of this dataset, which is very small. 00:05:44.040 |
It just has seen a few examples for each class 00:05:56.120 |
we judge how good of a general visual representation 00:06:01.040 |
does your model and adaptation algorithm have, 00:06:03.320 |
and now just for some nomenclature, this preparation, 00:06:08.560 |
we have words that we often use, like pre-training. 00:06:12.560 |
like upstream data, upstream training, something, 00:06:15.640 |
so I may use this word interchangeably with pre-training, 00:06:23.720 |
and the adaptation, in principle, it's whatever you want, 00:06:29.440 |
but for our work, we almost always just use very simple, 00:06:36.720 |
In general, we try to do things as simple as possible. 00:06:39.600 |
It still works well, and so sometimes I even just say, 00:06:44.280 |
That means moving from this pre-training to the transfer. 00:06:47.120 |
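To make the benchmark protocol concrete, here is a minimal sketch of the evaluation loop described above. The function and attribute names (`pretrain_model`, `adapt`, `evaluate`, `task.sample_train`) are illustrative placeholders, not VTAB's actual API; the 1,000-example budget matches the VTAB-1k variant, while the few-shot game above uses only a handful of images per class.

```python
# Minimal sketch of the VTAB-style protocol; names are hypothetical placeholders.
def run_benchmark(pretrain_model, adapt, evaluate, tasks, examples_per_task=1000):
    """Pre-train once (any data, model, or recipe), then adapt and evaluate
    independently on each downstream task with a small training set."""
    model = pretrain_model()
    scores = []
    for task in tasks:                              # e.g. 19 tasks: natural, specialized, structured
        small_train = task.sample_train(examples_per_task)
        adapted = adapt(model, small_train)         # e.g. fine-tune, or fit a linear head
        scores.append(evaluate(adapted, task.test_set))
    return sum(scores) / len(scores)                # benchmark score: mean accuracy over tasks
```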
All right, so so far for the settings, so far so good? 00:06:52.440 |
Good. Then the question is, how do we get there, 00:06:58.600 |
and we spend a lot of time thinking about this 00:07:07.520 |
which doesn't mean we're going to cover everything, 00:07:10.120 |
so I'm not going to go, like, through the outline exactly, 00:07:17.040 |
the transformer only comes a little bit later. 00:07:25.280 |
is that we spend some time trying self-supervised pre-training 00:07:30.200 |
and in vision only recently has become popular, 00:07:42.880 |
That's the VTAB score for this few-shot VTAB, 00:07:47.000 |
and self-supervised learning performs like this bar. 00:07:49.680 |
We tried multiple methods and multiple models and so on. 00:07:55.560 |
Then we moved on to semi-supervised training, 00:08:00.080 |
so a few labeled examples and a ton of unlabeled examples. 00:08:19.960 |
- Yeah, so then semi-supervised is that blue bar 00:08:22.520 |
which is a lot higher than this other blue bar, 00:08:33.240 |
Then I'm not going to spend more time on this 00:08:43.600 |
if we just scale up fully-supervised pre-training, 00:08:47.280 |
then we get really much better representations 00:08:51.800 |
and here I want to briefly spend some time on that one 00:09:04.880 |
for semi-supervised or unsupervised learning, right? 00:09:10.160 |
there's almost always some extra information, 00:09:18.440 |
that you could use as some weak source of information 00:09:25.080 |
there's some team that actually does this for production, 00:09:28.080 |
and they have collected already a large dataset 00:09:31.240 |
with some pipeline that from the surrounding signals 00:09:38.240 |
and we wanted to figure out how far can we go 00:09:43.760 |
Then, long story short, you need a couple of ingredients. 00:09:50.680 |
This is one of the curves of just pre-training 00:09:57.640 |
The gist is that if I zoom into this little box, 00:10:00.360 |
I see this here, and this is the metric for the training, 00:10:06.280 |
Then I see after spending eight GPU weeks of compute, 00:10:13.920 |
one GPU for eight weeks or eight GPUs for one week 00:10:21.560 |
But this looks flat. A reasonable person would say, 00:10:24.000 |
"Yeah, there's no progress for a week on eight GPUs. 00:10:26.760 |
This is flat. I'm going to stop and try something else," 00:10:31.600 |
and this is what the exact same spot looks like 00:10:35.880 |
and you can clearly see the things are progressing, right? 00:10:39.480 |
So it may not always be obvious, and you need patience. 00:10:52.760 |
The x-axis is the number of images available. 00:10:55.000 |
In vision, there is this ImageNet dataset, 00:10:57.000 |
which is a very common, super common dataset for pre-training, 00:11:02.560 |
There's another one which has 10 times more images 00:11:04.560 |
that's still public, and then there is one subset 00:11:11.240 |
so the y-axis is measure of accuracy on some tasks, 00:11:22.000 |
The blue dot is the standard ResNet 50 that everybody uses. 00:11:29.040 |
but if you go to even more data, it looks like, 00:11:30.960 |
oh, okay, this doesn't really seem that useful, 00:11:34.280 |
and this is what most people have been doing for a long time, 00:11:37.280 |
and a lot of people, even in Google, were like, 00:11:39.960 |
yeah, I tried this internal checkpoint on these tons of data. 00:11:45.800 |
However, what we found out, and in hindsight, 00:11:48.000 |
it's kind of obvious, is that you actually need to scale 00:11:53.760 |
Here, this blue dot is a gigantic ResNet that is slow as hell, 00:11:57.520 |
but when you scale this up together with the data, 00:11:59.720 |
you keep getting benefit with adding more data, 00:12:05.720 |
"be patient" could also be phrased as "scale up your patience." 00:12:14.040 |
so here there is few-shot transfer learning. 00:12:22.400 |
on the y-axis is the accuracy on one of these tasks, 00:12:33.240 |
you don't really see benefit or small benefit 00:12:43.760 |
you start getting better and better and better 00:12:52.000 |
Second benefit that we did not anticipate really at all, 00:12:55.560 |
but then found out is that these models are super robust 00:13:07.120 |
like a chair in the bathtub and things like that, 00:13:13.920 |
Here, the pink dots are basically how existing models, 00:13:17.480 |
and x-axis is, again, how large is the model, 00:13:19.920 |
and pink dot is existing ones from the literature, 00:13:31.840 |
like in this case, out-of-distribution robustness. 00:13:38.280 |
Scale up everything, be patient, and get huge benefit. 00:13:45.040 |
but there is a question from a student in the class. 00:13:49.400 |
Do you want to unmute yourself and ask it yourself? 00:13:59.840 |
what work has been done characterizing the parameters 00:14:04.040 |
Like, the reason why I'm motivating this question is, 00:14:06.960 |
it seems like we do this tremendous amount of pre-training, 00:14:12.000 |
if we just have smarter initialization schemes. 00:14:19.560 |
And they've come to conclude that I think not. 00:14:33.080 |
You know, that everything is in a nice range, 00:14:35.200 |
such that it can have nice input/output functions, 00:14:38.320 |
and so on, and that your optimizer can do steps 00:14:41.160 |
that make reasonable change to the input/output function, 00:15:05.280 |
remembering similarity to things they've seen in training. 00:15:11.160 |
they have more memory, and they have seen more things, 00:15:13.840 |
so they should be better on more newer things, 00:15:16.480 |
because there's more similar things they have seen. 00:15:24.400 |
But I don't have the immediate pointer to a paper 00:15:29.320 |
at the top of my head now to answer your question. 00:15:36.200 |
so has posted on the chat and is raising his hand. 00:15:40.160 |
Maybe in this order, you wanna ask your question first? 00:15:46.000 |
So I just have a quick clarification on this chart right here, 00:16:02.720 |
all the way to the 300 million image dataset for BiT-L? 00:16:16.000 |
And then the different points are random restarts, 00:16:27.640 |
And as you go to the right, the model gets larger. 00:16:30.320 |
And so you can see that for this little data, 00:16:32.840 |
going to larger model doesn't really help you much 00:16:42.200 |
- Right, that makes a lot of sense, thank you. 00:16:54.320 |
What is the intuition for the upstream performance 00:17:12.920 |
that just seems like an odd looking training curve. 00:17:19.000 |
- Yeah, this is old school computer vision thing, 00:17:26.240 |
In computer vision, it used to be very common 00:17:28.440 |
to have the learning rate in a kind of staircase pattern. 00:17:31.520 |
So it's constant for a while, and then you stop, 00:17:40.680 |
And nowadays, people don't use this much anymore. 00:17:42.880 |
And this work was like three years ago, I think, 00:17:48.960 |
And nowadays, people use more continuously changing 00:17:51.800 |
learning rate schedule, and then you don't really have 00:17:55.960 |
But if you would overlay it, it would be like 00:17:57.760 |
more continuously, but going roughly the same. 00:18:05.560 |
learning rate schedule, where also you don't see 00:18:07.240 |
this effect, because learning rate continuously decreases. 00:18:12.680 |
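For illustration, here is what the two kinds of schedule look like in code; the boundaries, decay factor, and step counts are made-up values, not the ones used in the speaker's experiments.

```python
import math

def staircase_lr(step, base_lr=0.1, boundaries=(30_000, 60_000, 80_000), factor=0.1):
    """Old-school step schedule: constant for a while, then dropped at each boundary.
    The drop is what produces the sudden jump in the training curve."""
    drops = sum(step >= b for b in boundaries)
    return base_lr * factor ** drops

def cosine_lr(step, total_steps=90_000, base_lr=0.1):
    """A continuously decaying alternative; no visible jumps in the curve."""
    return 0.5 * base_lr * (1 + math.cos(math.pi * step / total_steps))
```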
- And then this is what, because you asked for, 00:18:18.640 |
Actually here, if you're like here, you could say, 00:18:26.400 |
Maybe you could have started the decay earlier, 00:18:28.920 |
and earlier, and earlier, and then you would get the same, 00:18:35.840 |
And you do land at much worse place in the end 00:18:55.960 |
- Yeah, it's fine, we can coordinate that with this. 00:19:05.000 |
So basically what you're trying to do is multitask learning 00:19:08.320 |
with convolutional neural networks/LSTMs, right? 00:19:13.320 |
But you're doing multitask learning, correct? 00:19:21.920 |
- Because like, initially, like you showed like different, 00:19:33.400 |
And this pre-training, I didn't mention it yet. 00:19:36.320 |
I just said, I don't care what you do in the pre-training, 00:19:38.960 |
just pre-train somehow, and give me the model. 00:19:41.600 |
And then I test it on multiple tasks independently. 00:19:49.080 |
which in our case means fine-tune it just on the task, 00:19:55.520 |
Like later we moved to just learning a linear regression 00:20:03.240 |
what we do is just regular supervised learning, 00:20:14.800 |
a couple labels or not, but it usually doesn't have. 00:20:30.360 |
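As a concrete example of the cheap adaptation step, here is a sketch of a closed-form linear probe on frozen features, in the spirit of the linear regression the speaker alludes to; the regularization constant and the {-1, +1} target encoding are illustrative choices, not necessarily the exact recipe used.

```python
import numpy as np

def linear_probe(features, labels, num_classes, l2=1e-3):
    """Fit a ridge-regression readout from frozen features to per-class targets.
    features: (N, D) array from the pre-trained model; labels: (N,) integer classes."""
    X = np.asarray(features, dtype=np.float64)
    Y = np.eye(num_classes)[labels] * 2 - 1                   # {-1, +1} one-vs-rest targets
    W = np.linalg.solve(X.T @ X + l2 * np.eye(X.shape[1]), X.T @ Y)
    return W                                                  # predict via argmax(X_new @ W)
```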
like the discussion rather than started about this, 00:20:33.960 |
it's like memorization, or it's more memorizing the data 00:20:42.080 |
that you can pre-train on a synthetic language 00:20:45.280 |
that's, it doesn't have any semantic meaning, 00:20:52.600 |
And that actually gives you almost the same boost 00:20:56.280 |
in your downstream transfer as a normal pre-training. 00:21:04.600 |
the structure seems to make a lot of contribution, 00:21:11.040 |
it's a different case, maybe. Have people done 00:21:13.640 |
maybe some synthetic pre-training dataset for images? 00:21:24.600 |
and like not even rendering of some realistic things, 00:21:27.440 |
but just completely patterns, waves, and shapes and so on, 00:21:33.760 |
And then it shows that they get almost the same performance 00:21:38.080 |
they actually do this with vision transformers. 00:21:41.520 |
But yeah, they never go further or it is not clear, 00:21:45.600 |
you know, they kind of show that you can almost get 00:21:49.200 |
That is not clear how much further can you go with this. 00:21:56.000 |
but it's just me guessing that not much further, 00:22:07.360 |
Said that you think like the large vision models 00:22:21.720 |
Essentially like when you're doing few-shot learning, 00:22:24.760 |
you just say like, "I'm going to learn a network." 00:22:40.560 |
because this is just some intuitive guess that I have. 00:22:52.480 |
when we do something like prototypical networks 00:22:54.680 |
for few-shot learning with these pre-trained models, 00:22:57.560 |
we do get worse performance than when we do fine-tuning. 00:23:12.160 |
Okay, yeah, so, ah, right, and I didn't mention, 00:23:23.520 |
in computer vision, with this work, with the Big Transfer, 00:23:30.600 |
after there was a long period of a couple of years 00:23:46.080 |
Yeah, that's, okay, this is just a little aside, 00:23:51.360 |
that if you are in the setting that I mentioned 00:23:59.520 |
that you don't have images from the other tasks 00:24:05.360 |
Otherwise, you have seen them during training, 00:24:09.240 |
and you're just fooling yourself with good scores. 00:24:12.520 |
And this is a real danger when we get huge amounts of data, 00:24:15.240 |
because, like, ImageNet images can totally be 00:24:24.920 |
and also near-duplicates, like when they are shifted, 00:24:27.760 |
rotated, squeezed, color changed a bit, whatnot. 00:24:32.000 |
And we use this to completely remove all images 00:24:34.920 |
from the test data sets that we test on later. 00:24:44.920 |
between the training set of ImageNet and CIFAR, 00:24:50.960 |
So near-duplicates are quite a widespread problem in vision. 00:24:54.600 |
And this slide is just to say, hey, there are problems, 00:24:58.720 |
we actually took care that in the pre-training, 00:25:01.120 |
as best as we can, we don't have near-duplicates. 00:25:12.760 |
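The speaker does not describe the actual de-duplication pipeline, but one simple way to flag near-duplicates is to compare image embeddings (or perceptual hashes) between the pre-training set and every downstream test set, as in this sketch; the threshold is an arbitrary illustrative value.

```python
import numpy as np

def flag_near_duplicates(train_emb, test_emb, threshold=0.95, batch=1024):
    """Return indices of pre-training images whose embedding is suspiciously close
    to any test-set embedding (illustration only, not the production pipeline)."""
    train = train_emb / np.linalg.norm(train_emb, axis=1, keepdims=True)
    test = test_emb / np.linalg.norm(test_emb, axis=1, keepdims=True)
    flagged = []
    for start in range(0, len(train), batch):                 # chunk to limit memory
        sims = train[start:start + batch] @ test.T            # cosine similarities
        hits = np.where(sims.max(axis=1) > threshold)[0]
        flagged.extend((start + hits).tolist())
    return flagged
```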
And that's how we got to transformers, basically. 00:25:16.440 |
In computer vision, everything was convolutional networks 00:25:20.480 |
And basically there was nothing else, CNN is king. 00:25:23.440 |
However, in language, we saw a transformation recently, 00:25:29.320 |
everywhere LSTM was king, and then came the transformer. 00:25:32.880 |
And in the case when there is a lot of data available, 00:25:35.880 |
suddenly transformer worked much better than LSTM. 00:25:39.400 |
For little data, that was still not the case exactly. 00:25:45.600 |
so we are now in this regime where we have tons of data 00:26:06.360 |
because I don't want to point fingers too much, 00:26:09.120 |
but they were all not really using transformers 00:26:14.920 |
It was always like, get something out of a ResNet first, 00:26:21.560 |
or high-level feature maps or things like that, 00:26:30.240 |
And so we came up with the simplest and most natural, 00:26:33.080 |
I believe, way of applying transformers to vision, 00:26:36.440 |
which is you take the image, you cut it into pieces, 00:26:48.320 |
and you project it into your embedding space, 00:27:18.960 |
You can, and people later did, go on and say, 00:27:29.480 |
This is just the simplest way to do it first. 00:27:40.080 |
and then give them to exactly the BERT transformer 00:27:45.960 |
And just like in language, we add this class token, 00:27:49.400 |
or I think in language it's like the end-of-sentence token 00:27:53.960 |
And we add the position embeddings to the tokens 00:27:59.960 |
And then we feed all of this to a transformer encoder, 00:28:02.760 |
which has a MLP head, which reads out this class token, 00:28:12.720 |
And that's it. That is the vision transformer. 00:28:25.280 |
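Here is a minimal PyTorch sketch of the architecture just described: cut the image into patches, linearly project them, prepend a class token, add learned position embeddings, run a transformer encoder, and read the class token with a head. The dimensions are deliberately tiny, and the stock `nn.TransformerEncoder` differs in small details (for example, norm placement) from the paper's implementation.

```python
import torch
import torch.nn as nn

class TinyViT(nn.Module):
    """Minimal sketch of the vision transformer described above (illustrative sizes,
    much smaller than the real ViT-B/16)."""
    def __init__(self, image_size=224, patch=16, dim=192, depth=4, heads=3, classes=10):
        super().__init__()
        num_patches = (image_size // patch) ** 2                  # 14 * 14 = 196
        self.patchify = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)  # linear projection of patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim)) # learned position embeddings
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=4 * dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(dim, classes)                       # head reading out the class token

    def forward(self, images):                                    # images: (B, 3, H, W)
        tokens = self.patchify(images).flatten(2).transpose(1, 2) # (B, 196, dim)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)
        x = torch.cat([cls, tokens], dim=1) + self.pos_embed      # prepend class token, add positions
        x = self.encoder(x)
        return self.head(x[:, 0])                                 # classify from the class token
```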
And then just same story as before, scale everything up. 00:28:28.360 |
Compute, data set, model size, patience, everything. 00:28:41.880 |
The gray area is actually where all of the BiT dots were before. 00:28:51.000 |
And the bubble is kind of the size of the model, 00:28:56.320 |
And what you can see first is that with little data, 00:29:05.440 |
and just try this, we're like, "Okay, this is a crap idea." 00:29:20.760 |
then we actually start outperforming this ResNet. 00:29:33.840 |
Then we did more controlled studies and everything. 00:29:35.880 |
And one of them is like using subset of the same data set. 00:29:53.320 |
which is a ResNet variant, and ViT, the vision transformer. 00:30:04.360 |
But as we start having a lot of data, actually, 00:30:12.840 |
and a lot and so on now, in five or 10 years, 00:30:17.000 |
Like 10 years ago, ImageNet seemed to be huge 00:30:33.200 |
- Because we, yeah, yeah, we have some questions. 00:30:45.880 |
if you want to unmute yourself and ask the questions. 00:30:51.280 |
And I think Dimal already answered part of the question, 00:30:54.080 |
but I was wondering in the input to this transformer, 00:30:58.600 |
into little puzzle pieces and then feeding them in, 00:31:02.920 |
does the order of feeding these patches in matter? 00:31:13.480 |
And I actually have a slide on something like this, 00:31:21.120 |
if the order is consistent during training, right? 00:31:24.840 |
And you don't shuffle the order again for each new image, 00:31:41.000 |
This slide was on my plan to present anyways. 00:31:53.280 |
we had 14 by 14 patches that we cut the image in. 00:31:56.880 |
So it means we have also 14 by 14 position embeddings. 00:32:01.160 |
Although we just see them as one long sequence of, 00:32:04.480 |
190-something -- 196, to be exact. 00:32:09.720 |
And now each of these pictures shows the position embedding, 00:32:15.400 |
How similar is it to all the other position embeddings? 00:32:20.560 |
Yellow means perfectly similar, like exactly the same. 00:32:23.680 |
And blue means opposite in terms of cosine similarity. 00:32:27.320 |
So this position embedding is most similar to itself, 00:32:32.360 |
And then the neighboring pixels is how similar is it 00:32:35.520 |
to the position embeddings that correspond originally 00:32:44.760 |
to the embedding from its surrounding patches. 00:33:13.560 |
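The similarity maps described here can be reproduced with a few lines: take the learned grid of position embeddings, compute the cosine similarity of each one against all others, and reshape back to the patch grid. This sketch assumes a 14x14 grid with the class-token embedding removed.

```python
import numpy as np

def pos_embed_similarity(pos_embed, grid=14):
    """Cosine similarity of each position embedding to all others, reshaped to the
    patch grid, reproducing the kind of plot described above.
    pos_embed: (grid*grid, D) learned embeddings (class-token entry removed)."""
    normed = pos_embed / np.linalg.norm(pos_embed, axis=1, keepdims=True)
    sims = normed @ normed.T                        # (196, 196) cosine similarities
    return sims.reshape(grid, grid, grid, grid)     # sims[i, j] is the (grid, grid) map for patch (i, j)
```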
But it also means that if you take the trained model now 00:33:26.480 |
We did try also to implement, like, position embeddings 00:33:30.360 |
which encode the location as hardcoded by us, 00:33:35.440 |
and other fancy position embeddings like relative ones. 00:33:39.520 |
But basically, none of that really outperformed 00:33:48.200 |
And so we go with that, and so just like that. 00:34:05.400 |
and scaling up the model would be fun as well. 00:34:08.240 |
But it seems like you're reaching an asymptote, right, 00:34:13.400 |
So I'm curious if you have any thoughts on that. 00:34:18.160 |
or is there kind of a best you can sort of do 00:34:21.480 |
where when you're pre-training the data or the parameters, 00:34:32.040 |
where I would like to not jump on it, if you don't mind. 00:34:36.640 |
And then maybe in 10, 15 minutes, we will be there. 00:34:57.000 |
Are there any more questions before we proceed? 00:35:03.120 |
-So what I'm curious to know is how does this ViT 00:35:08.680 |
so, for example, ResNet, with an attention mechanism? 00:35:13.360 |
-Like, how much of this is due to the structure of a transformer 00:35:18.800 |
that a vanilla ConvNet does not have access to? 00:35:22.040 |
-Yeah, so this has been tried many times before, 00:35:27.800 |
was actually from -- I mispronounce his name, 00:35:34.120 |
and some of his colleagues, they called it non-local networks. 00:35:37.360 |
This was way -- I think even before the transformer paper, 00:35:59.400 |
you can imagine if you place the attention just on the pixels 00:36:03.360 |
this is way too expensive computation-wise, right? 00:36:07.080 |
If you have 224 by 224 pixels, 00:36:08.920 |
that's like -- yeah, I cannot do this in my head. 00:36:14.320 |
Attending to 40,000 others, that doesn't work, 00:36:27.280 |
but then you don't really get much benefit of scaling 00:36:43.920 |
and that is also kind of a form of attention, 00:36:53.520 |
But yeah, it has been tried many times before, 00:36:57.640 |
or it hasn't been shown to have this scaling benefit 00:37:03.280 |
-So I think I'm missing something critical here, 00:37:09.920 |
to do an attention layer at a low level in the ResNet, 00:37:12.920 |
but why is it any different than doing an attention layer 00:37:28.160 |
Like, you could imagine, not at a high level, 00:37:34.280 |
after you've applied, like, one or two convolutional filters -- 00:37:39.440 |
then you have something the size of the patches. 00:37:43.080 |
-That's still 50 by 50 at the early layers, and that's -- 00:37:52.320 |
-But it's still 2,500 tokens attending to 2,500 tokens, 00:38:12.880 |
where we do try something almost like what you said, 00:38:23.640 |
but, like, the full transformer encoder on top of it, 00:38:32.080 |
And this is this process, and we call them hybrid, 00:38:35.280 |
but it's almost literally what you said, actually, 00:38:42.000 |
and then stick the whole transformer encoder. 00:38:53.360 |
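A sketch of that hybrid idea, assuming a torchvision ResNet-50 truncated after its third stage, so a 224x224 image becomes a 14x14 feature map whose spatial positions are used as tokens; position embeddings and the classification head are omitted for brevity.

```python
import torch
import torch.nn as nn
from torchvision.models import resnet50

class HybridBackbone(nn.Module):
    """Sketch of the 'hybrid' variant: ResNet stages produce a feature map, and each
    spatial position becomes one token for the transformer (illustrative dimensions)."""
    def __init__(self, dim=768, depth=4, heads=12):
        super().__init__()
        cnn = resnet50(weights=None)
        # Keep everything up to layer3: a 224x224 input -> (B, 1024, 14, 14) feature map.
        self.stem = nn.Sequential(cnn.conv1, cnn.bn1, cnn.relu, cnn.maxpool,
                                  cnn.layer1, cnn.layer2, cnn.layer3)
        self.proj = nn.Linear(1024, dim)                    # map CNN channels to token dim
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=4 * dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, images):
        fmap = self.stem(images)                            # (B, 1024, 14, 14)
        tokens = self.proj(fmap.flatten(2).transpose(1, 2)) # (B, 196, dim)
        return self.encoder(tokens)
```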
so for the little compute, it seems to work well. 00:38:55.920 |
But then the scaling behavior of the pure ResNet 00:39:00.680 |
I think we later tried also hybrid further to the right, 00:39:03.280 |
and it was a bit lower, but it was after the paper, 00:39:05.880 |
so it's not on this plot, which I just cut out of the paper. 00:39:15.520 |
then this is a totally reasonable thing to do, 00:39:32.960 |
basically, there's like a short section of paper 00:39:35.640 |
about, like, fine-tuning and, like, higher resolution, 00:39:38.120 |
and in that case, right, like, the pre-trained, 00:39:40.640 |
like, position embeddings, sorry, are, like, skewed, right? 00:39:45.480 |
And it basically says that you guys are, like, interpolating. 00:39:50.480 |
Like, how do you interpolate what's going on? 00:39:52.400 |
-Yeah. Actually, when I checked the slides earlier today, 00:39:55.840 |
I was like, "Oh, it would be cool to have a slide on that." 00:40:00.440 |
And we don't have a nice visualization in the paper, 00:40:02.520 |
either, because it's a bit difficult to explain, 00:40:08.320 |
So if you want to increase the resolution of the image, 00:40:13.520 |
it means you have more patches suddenly, right? 00:40:15.720 |
And then, as you say, the patch embeddings, like, 00:40:18.480 |
what do you even use as position embeddings, right? 00:40:25.480 |
that they learn a very regular structure, right? 00:40:41.680 |
kind of imagine these boxes, they slide apart, 00:40:50.040 |
And that's basically what we do with the position embeddings. 00:40:54.520 |
We create new ones where there are missing ones, 00:41:02.240 |
Or more precisely, we basically see them as a picture, 00:41:06.040 |
in this case, 14 by 14, with 700-something channels, 00:41:14.320 |
like you would resize a picture by interpolation. 00:41:19.000 |
And that way, we get more and new position embeddings 00:41:24.040 |
but they follow the same pattern as the learned ones, 00:41:39.080 |
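In code, the interpolation amounts to reshaping the learned grid embeddings into a small "image" and resizing it, roughly like this sketch; the interpolation mode and the handling of the class token here are illustrative details.

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, old_grid=14, new_grid=24):
    """Interpolate learned position embeddings to a new grid when fine-tuning at a
    higher resolution. pos_embed: (1, 1 + old_grid**2, D) with the class token first."""
    cls_tok, grid_tok = pos_embed[:, :1], pos_embed[:, 1:]
    dim = grid_tok.shape[-1]
    # Treat the grid embeddings as an old_grid x old_grid "image" with D channels ...
    grid_tok = grid_tok.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    # ... and resize it like a picture, here with bicubic interpolation.
    grid_tok = F.interpolate(grid_tok, size=(new_grid, new_grid),
                             mode="bicubic", align_corners=False)
    grid_tok = grid_tok.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_tok, grid_tok], dim=1)
```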
So when you're creating the embeddings as input, 00:41:50.440 |
Has there been work to do to memorize the other way, 00:41:52.440 |
'cause there's a lot of pixels that are close to each other? 00:42:04.880 |
it's called "Early Convolutions Help Transformers See Better," 00:42:16.240 |
we replace it by a stack of three-by-three convolution 00:42:22.280 |
And then they have also nonlinearities between them, 00:42:32.160 |
So the outcome would then be the same dimensionality 00:42:36.040 |
as after this patch cutting and then projecting. 00:42:40.440 |
supposedly it makes it a bit easier to optimize 00:42:45.160 |
in the sense that more optimizer settings are good settings. 00:43:08.880 |
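A sketch of such a convolutional stem, under the assumption of four stride-2 3x3 convolutions with nonlinearities and a final 1x1 projection, so that a 224x224 image again ends up as a 14x14 grid of token vectors; the channel widths are illustrative, not the paper's exact recipe.

```python
import torch.nn as nn

def conv_stem(dim=768):
    """Sketch of a convolutional stem in the spirit of 'Early Convolutions Help
    Transformers See Better': a stack of stride-2 3x3 convs with nonlinearities,
    ending at the same 16x downsampling and channel count as a 16x16 patchify."""
    chans = [3, 64, 128, 256, 512]
    layers = []
    for c_in, c_out in zip(chans[:-1], chans[1:]):
        layers += [nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                   nn.BatchNorm2d(c_out), nn.ReLU(inplace=True)]
    layers.append(nn.Conv2d(512, dim, kernel_size=1))   # final 1x1 to the token dimension
    return nn.Sequential(*layers)                        # 224x224 input -> (dim, 14, 14)
```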
I have played a bit with it and tried to reproduce it. 00:43:16.120 |
but I don't see as much benefit as in the paper yet. 00:43:19.040 |
But that's not to say that the paper is wrong, 00:43:42.640 |
Yeah, I have like three more interesting details 00:43:49.600 |
I have more content, like also the question about, 00:44:05.880 |
is like how should we scale these transformers? 00:44:19.080 |
So we started with the reasonable medium-sized transformer, 00:44:32.760 |
if we go to the right, this point increases the width, 00:44:39.040 |
X-axis is compute relative to this starting point. 00:44:46.240 |
There's the width, which is how wide are the vectors 00:45:08.760 |
or some people call it the one-by-one convolution 00:45:12.840 |
And this seems to scale a bit nicer, this orange part. 00:45:21.840 |
or if we just didn't scale it down, but anyways. 00:45:26.000 |
which does not exist in the transformers from text 00:45:37.520 |
This is the green one, which also seems to scale nicely. 00:45:42.120 |
Then the depth is an interesting one, this yellow one. 00:45:53.640 |
And it scales really badly if you decrease the depth. 00:45:59.040 |
However, the width seems to be a good thing to decrease 00:46:03.160 |
And then the blue is just scaling everything together 00:46:09.960 |
That seems to scale nicely as well as the rest 00:46:14.920 |
and is relatively simple, or at least conceptually. 00:46:23.440 |
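For reference, these are the standard ViT sizes from the original paper, which scale depth, width, MLP width, and heads together in roughly the "scale everything" spirit described above; the patch size is chosen separately and controls the sequence length.

```python
# Standard ViT configurations (from the original ViT paper).
VIT_CONFIGS = {
    "ViT-Base":  dict(layers=12, width=768,  mlp=3072, heads=12),
    "ViT-Large": dict(layers=24, width=1024, mlp=4096, heads=16),
    "ViT-Huge":  dict(layers=32, width=1280, mlp=5120, heads=16),
}
```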
And this one I really like is the inference speed, 00:46:26.560 |
because if you have the image size of 224 pixels, 00:46:29.640 |
it actually means you have 224 by 224 pixels. 00:46:32.520 |
So if you have, then you patchify it with 16 by 16 patch, 00:46:37.000 |
for example, patch size, then you have 14 by 14 patches. 00:46:42.520 |
So that is, the sequence length is actually 196. 00:46:54.760 |
the self-attention operation is to the fourth power, 00:47:03.000 |
Like everybody who sees all of something to the fourth 00:47:09.760 |
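The back-of-the-envelope arithmetic looks like this: the number of tokens grows quadratically with the image side length, and self-attention is quadratic in the number of tokens, hence quartic in the side length.

```python
def attention_cost(image_size=224, patch=16):
    """Sequence length and pairwise-attention count for a given image and patch size."""
    tokens_per_side = image_size // patch      # 224 / 16 = 14
    seq_len = tokens_per_side ** 2             # 14 * 14 = 196 tokens (plus one class token)
    attn_pairs = seq_len ** 2                  # attention cost is quadratic in tokens
    return seq_len, attn_pairs

print(attention_cost(224, 16))   # (196, 38416)
print(attention_cost(448, 16))   # (784, 614656): 2x the resolution, 16x the attention cost
```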
So we checked what does it look like in practice 00:47:25.920 |
And this, what this means, it doesn't look so bad yet. 00:47:37.320 |
actually start going down a lot more than the ResNets. 00:47:52.120 |
But as we go larger, it will likely be a problem, 00:48:01.440 |
Then, this is the last one from the original paper. 00:48:07.560 |
This is looking at the input's receptive field size. 00:48:17.920 |
And here on the x-axis, we see the layer in the network. 00:48:21.600 |
To the right is more towards the output, the classes, 00:48:24.040 |
and to the left is more towards the input, the patches. 00:48:35.320 |
And does look means that the peak of the self-attention 00:48:45.800 |
because we can use multi-head self-attention. 00:48:48.560 |
And so what this shows is that in the early layers, 00:48:53.360 |
but also a lot of heads that look very nearby them, 00:48:59.600 |
we only are left with heads that, on average, look further. 00:49:07.000 |
There is not immediately action to take about this, 00:49:09.840 |
but it's interesting to see that earlier layers, 00:49:12.640 |
they learn a mixture of looking to a local neighborhood 00:49:24.280 |
So that is about the original vision transformers. 00:49:35.320 |
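The quantity behind that plot is the mean attention distance per head: how far, on average in patch units, each head attends. A sketch of computing it from one layer's attention weights follows; the shapes assume a 14x14 grid of patch tokens with the class token dropped.

```python
import numpy as np

def mean_attention_distance(attn, grid=14):
    """Average spatial distance (in patch units) that each head attends over.
    attn: (heads, grid*grid, grid*grid) attention weights over the patch tokens."""
    coords = np.stack(np.meshgrid(np.arange(grid), np.arange(grid), indexing="ij"),
                      axis=-1).reshape(-1, 2)                        # (196, 2) patch coordinates
    dists = np.linalg.norm(coords[:, None] - coords[None, :], axis=-1)   # (196, 196) pairwise distances
    return (attn * dists).sum(axis=(1, 2)) / attn.sum(axis=(1, 2))       # one value per head
```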
I have a couple of options that I can talk about, 00:49:38.760 |
which is one project that was further scaling updates, 00:49:52.720 |
There is another project about how to train vision transformers 00:50:08.360 |
I talk all about these benefits of a really large model 00:50:13.640 |
Okay, that's nice. That's how we get a good model. 00:50:17.000 |
But then actually using a model that is massive 00:50:22.960 |
You need, like, multiple TPUs to even use it. 00:50:29.080 |
and usually still go back to small-ish models, 00:50:32.000 |
even though they know, like, larger models should be better. 00:50:37.800 |
That's another project we had, which is about distillation. 00:50:41.520 |
So I would say it's up to you guys what you prefer to do. 00:50:49.200 |
because I think now the original one hour would be over, right? 00:50:56.880 |
and we'll also be recording it so people can, like, 00:50:59.360 |
just, like, go and see it if they miss out something. 00:51:04.240 |
-Yeah, the other thing is two people have their hands raised, 00:51:28.000 |
So if an object lies on the border between the patches, 00:51:32.280 |
does that impact the model's performance in any way? 00:51:45.560 |
So one is we didn't specifically go and test this. 00:51:48.960 |
It would be an interesting thing to test in a very controlled way 00:51:57.520 |
The other thing is that when you have a massive data set, 00:52:01.360 |
like 300 million images, it's an insane amount. 00:52:03.920 |
I used to try to conceptualize how much is ImageNet, 00:52:07.960 |
1 million images, and I think I did the math. 00:52:10.920 |
It's like if you go through ImageNet and look at all of the images, 00:52:17.200 |
you are sitting there for a month or something like that. 00:52:27.160 |
random augmentations, like random crop out of the image. 00:52:30.920 |
So I would say it's the default that you see objects 00:52:34.040 |
that don't fall on a patch during the training already. 00:52:40.600 |
this is the standard model, like how the patches are. 00:52:44.360 |
When we have 14 by 14, they look roughly this size also. 00:52:50.240 |
Then an object is usually scattered across many patches, 00:52:59.280 |
People don't take a picture where the object of interest 00:53:04.400 |
So that's the default that you see during pre-training. 00:53:13.640 |
Then the other answer to the question is like, OK, 00:53:16.480 |
maybe if you did some nicer thing than this very crude 00:53:20.720 |
patch cutting, like for example, this stack of convolutions 00:53:31.480 |
So you mentioned that we're using transformers, 00:53:49.800 |
I was just thinking, are these sort of properties 00:53:54.400 |
that you probably [INAUDIBLE] and especially when 00:54:02.720 |
So why is it that we would prefer [INAUDIBLE] 00:54:14.920 |
Is that we say that transformers lack locality bias, or prior, 00:54:21.280 |
And why is this even something that we want, right? 00:54:25.120 |
Wouldn't we want our models to know about locality 00:54:27.560 |
if they are about pictures in the first place? 00:54:32.680 |
So that's why I gave the context in the beginning. 00:54:35.520 |
This is all about what happens when you scale things up. 00:54:39.760 |
And specifically, in the ideal world, at least in our mind, 00:54:54.000 |
And there will be more and more data just generally there. 00:55:04.720 |
Because what we may think that is good to solve the task 00:55:14.880 |
AlphaGo that made some moves that experts would say, 00:55:22.240 |
And in a similar way, we want to encode as little as possible 00:55:28.400 |
throw massive amounts of data in the difficult task at it, 00:55:31.520 |
that it might think things that are even better that we 00:55:37.960 |
Because we believe that, as I mentioned, I think, already, 00:55:47.560 |
So that's where we want to go and look what's the direction. 00:55:51.520 |
However, if you want to just get something working now 00:55:59.880 |
for some reason, which always use a pre-trained model, 00:56:08.800 |
of your prior intuition and knowledge of what should 00:56:22.280 |
What sort of [INAUDIBLE] like any vision task? 00:56:53.840 |
is powerful enough to learn about this concept itself 00:57:01.320 |
If it's not useful to solve the task, then if we had put it in, 00:57:05.840 |
there is no way for the model not to do this, right? 00:57:19.480 |
the from-left-to-right direction of text, like in RNNs. 00:57:19.480 |
And works much better if you throw a lot of data at it. 00:57:29.600 |
And it recovers that plus some more or a more flexible variant 00:57:42.840 |
as smart to design the thing, the model in the way that 00:57:49.160 |
Let's rather give it all the flexibility and all the data 00:57:55.200 |
I mean, it is a philosophy of approaching it. 00:58:05.640 |
I'm not saying this is the only true way, right? 00:58:18.720 |
And Lucas, we want to be mindful of your time 00:58:21.240 |
as well, because it is evening where you are. 00:58:30.240 |
So you could quickly go over the last few bits, 00:58:41.080 |
Those two that are still very, very tight to transformers 00:58:44.360 |
and answer some questions that happened before. 00:58:46.880 |
Like the first question was like, OK, are we saturating? 00:59:04.280 |
when we use them, we just notice they have really nice scaling 00:59:08.920 |
to scale up without paying massive compute as much 00:59:12.920 |
as ResNet, just from gut feeling from us having experience 00:59:20.120 |
if we scale vision transformer just as far up 00:59:25.640 |
And we spent quite a lot of our blood into making this happen. 00:59:36.680 |
that this 300 million data set is just one out of many 00:59:43.480 |
had the 3 billion, like 10 times larger data set 00:59:52.360 |
And this is just showing, yes, just scaling up the data set 01:00:00.120 |
Then the next thing is we needed to figure out 01:00:03.080 |
how to use less memory on device, like on GPU or TPU, 01:00:10.280 |
we fitted the model as large as we could fit. 01:00:13.200 |
So we did a lot of tricks that I will skip for now 01:00:24.520 |
that I mentioned before, like the width of the MLP on the x-axis, 01:00:29.720 |
and then the different plots are different layers for the depth. 01:00:37.800 |
And then boom, one step further and two steps further, 01:00:51.640 |
Then, yeah, some learning rate stuff, and it is really cool. 01:00:54.360 |
I recommend people to look at square root learning rate 01:00:56.840 |
schedule, which is cool, and often just mentioned 01:01:17.840 |
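A sketch of the kind of reciprocal-square-root schedule being recommended: linear warmup, a long ~1/sqrt(step) middle that stays relatively flat so you can keep training without committing to an end point, and a linear cooldown at the end. The constants here are illustrative, not the ones from the paper.

```python
import math

def rsqrt_schedule(step, peak_lr=1e-3, warmup=10_000, total=200_000, cooldown=20_000):
    """Reciprocal-square-root schedule with linear warmup and linear cooldown."""
    lr = peak_lr * math.sqrt(warmup / max(step, warmup))   # ~1/sqrt(step) after warmup
    if step < warmup:
        lr *= step / warmup                                # linear warmup
    if step > total - cooldown:
        lr *= max(0.0, (total - step) / cooldown)          # linear cooldown to zero
    return lr
```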
This is actually plus 2% on what we had before, 01:01:20.200 |
which is very significant in this high percentage range 01:01:35.800 |
And for example, it's just 10 images per image net class, 01:01:40.560 |
which means 10,000 images total because 1,000 classes. 01:01:46.280 |
We get 85% top-1 accuracy, which is what you typically 01:01:58.600 |
It actually makes few-shot work significantly better. 01:02:04.240 |
Well, this actually has an interesting message. 01:02:22.080 |
and the base vision transformer and the large one. 01:02:33.880 |
But still, you need to see a lot fewer images 01:03:01.400 |
sorry, I had the order of the slides mixed up in my head. 01:03:05.120 |
But then another thread was that besides further scaling up 01:03:11.040 |
into this direction of less hand engineering of things 01:03:21.360 |
transformer in general, what is the obviously most hand-engineered part, 01:03:29.440 |
and can we replace it with something more generic than that and less smart than that, basically? 01:03:34.240 |
And we ended up by replacing it, essentially, 01:03:36.480 |
with just a multi-layer perceptron that, however, 01:03:46.120 |
So I will skip the structure, for the sake of time. 01:03:46.120 |
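For concreteness, here is a PyTorch sketch of one block of that MLP-only architecture (MLP-Mixer): a token-mixing MLP applied across patch positions in place of self-attention, followed by a channel-mixing MLP, each with a residual connection. The hidden sizes are illustrative.

```python
import torch
import torch.nn as nn

class MixerBlock(nn.Module):
    """One MLP-Mixer block: mix across tokens, then mix across channels."""
    def __init__(self, tokens=196, dim=512, token_hidden=256, channel_hidden=2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.token_mlp = nn.Sequential(nn.Linear(tokens, token_hidden), nn.GELU(),
                                       nn.Linear(token_hidden, tokens))
        self.norm2 = nn.LayerNorm(dim)
        self.channel_mlp = nn.Sequential(nn.Linear(dim, channel_hidden), nn.GELU(),
                                         nn.Linear(channel_hidden, dim))

    def forward(self, x):                        # x: (B, tokens, dim)
        y = self.norm1(x).transpose(1, 2)        # (B, dim, tokens): mix across token positions
        x = x + self.token_mlp(y).transpose(1, 2)
        return x + self.channel_mlp(self.norm2(x))
```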
And we're coming back to this plot, where the question was, 01:03:56.040 |
We, again, have this BiT ResNet here in black. 01:03:59.040 |
And the full green line is the vision transformer. 01:04:05.800 |
So it is exactly the same numbers as from before. 01:04:09.000 |
However, now we also throw in this mixer architecture, 01:04:11.800 |
which we believe is even more flexible and less 01:04:16.240 |
And as you see, with less data, it's even worse. 01:04:22.520 |
be surpassing the transformer, or it may be random noise. 01:04:29.000 |
Because it's the only point where this happens. 01:04:35.960 |
for example, from the previous paper that I mentioned here, 01:04:40.160 |
and try to extend these lines to the right to see what happens. 01:04:54.720 |
And that, first of all, yes, the vision transformer 01:04:59.880 |
We don't have such experiment with the ResNet, 01:05:07.520 |
But it also seems that the mixer, what we believe 01:05:11.560 |
actually is consistently above the transformer now, 01:05:19.200 |
So we're now right at the time when I should stop, right? 01:05:37.640 |
to model sizes for BERT or other natural language models. 01:05:37.640 |
Like, especially when we're going from smaller models 01:05:46.080 |
to much bigger models, are they comparable at all 01:05:53.880 |
what is the [INAUDIBLE] models for these two different tasks? 01:05:57.280 |
Yeah, actually, a colleague of mine has a slide, which I hate. 01:06:01.160 |
But he loves-- it's the model number of parameters 01:06:07.240 |
And the question is, how do you measure model size? 01:06:15.160 |
However, the language models, number of parameters, 01:06:19.200 |
like a huge chunk of it is in the dictionary, 01:06:21.480 |
for example, which for us just doesn't exist. 01:06:23.880 |
It is linear embedding, which is trivial number of parameters. 01:06:29.240 |
So in terms of number of parameters, it's much smaller. 01:06:39.120 |
this maybe in terms of compute, like how much floating point 01:06:46.400 |
And in terms of this, it's in the same ballpark. 01:06:50.120 |
However, last time I checked, which is quite a few months 01:06:55.640 |
like four times more or five times more in the vision model, 01:07:02.200 |
So that's the two ways of measuring model size. 01:07:09.560 |
And I think it's actually an interesting research topic, 01:07:11.920 |
like how to properly measure and order models 01:07:26.520 |
I think it's just there is less interest in it, 01:07:34.360 |
Like in Google, there are many, many more groups 01:07:37.160 |
doing research with language than with vision. 01:07:40.720 |
And I think we are one of the few groups that 01:07:45.640 |
and are interested in scaling up things in vision so much. 01:07:48.920 |
Whereas in language, it seems there are a lot of groups 01:07:56.040 |
It's not that we don't want to go beyond that, 01:08:07.760 |
Right, so we are actually over time at this point. 01:08:10.840 |
So anyone who has to leave, please feel free to do so. 01:08:13.720 |
And before we do that, Lucas, thank you so much for joining, 01:08:21.600 |
And we know it's in the evening, so thank you 01:08:24.080 |
for taking your free time to come and talk to us here.