Vision Transformers (ViT) Explained + Fine-tuning in Python
Chapters
0:00 Intro
0:58 In this video
1:12 What are transformers and attention?
1:39 Attention explained simply
4:15 Attention used in CNNs
5:24 Transformers and attention
7:01 What vision transformer (ViT) does differently
7:28 Images to patch embeddings
8:22 1. Building image patches
10:23 2. Linear projection
10:57 3. Learnable class embedding
13:30 4. Adding positional embeddings
16:37 ViT implementation in python with Hugging Face
16:45 Packages, dataset, and Colab GPU
18:42 Initialize Hugging Face ViT Feature Extractor
22:48 Hugging Face Trainer setup
25:14 Training and CUDA device error
26:27 Evaluation and classification predictions with ViT
28:54 Final thoughts
00:00:00.000 |
Vision and language are the two big domains in machine learning. Two distinct disciplines 00:00:06.320 |
with their own problems, best practices and model architectures or at least that used to be the case. 00:00:14.240 |
The vision transformer, or ViT, marks the first step towards a merger of both fields into a single 00:00:24.240 |
unified discipline. For the first time in the history of machine learning we have a single 00:00:31.920 |
model architecture that is on track to become the dominant model in both language and vision. 00:00:39.760 |
Before the vision transformer, transformers were known as those language models and nothing more. 00:00:47.040 |
But since the introduction of the vision transformer there has been further work 00:00:52.320 |
that has almost solidified their position as state-of-the-art in vision. In this video we're 00:00:59.840 |
going to dive into the vision transformer. We're going to explain what it is, how it works, why it 00:01:07.360 |
works and we're going to look at how we can actually use it and implement it with Python. 00:01:12.400 |
So let's get started with a very quick 101 introduction to transformers and the attention 00:01:20.160 |
mechanism. So transformers were introduced in 2017 in the pretty well-known paper called 00:01:27.840 |
"Attention is all you need". Transformers quite literally changed the entire landscape of NLP 00:01:34.640 |
and this was very much thanks to something called the attention mechanism. Now in NLP attention 00:01:43.360 |
allows us to embed contextual meaning into the word level or sub-word level token embeddings 00:01:53.280 |
within a model. So what I mean by that is say you have a sentence and you have two words in 00:01:59.840 |
that sentence that are very much related. Attention allows you to identify that relationship and then 00:02:07.840 |
allow the model to understand those words with respect to one another within that greater 00:02:14.960 |
sentence. Now this starts within the transformer model with tokens just being embedded into a 00:02:22.640 |
vector space based purely on what that token is. So the token for bank will be mapped to a particular 00:02:30.480 |
vector that represents the word bank without any consideration of the words surrounding it. Now 00:02:38.000 |
with these token embeddings what we can do is calculate dot products between their embeddings 00:02:44.480 |
and we will return a high score when they are aligned and a low score when they are not aligned. 00:02:51.360 |
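As a toy illustration of that dot-product scoring, here is a minimal sketch; the vectors below are made up purely for demonstration, not real BERT embeddings, and the softmax step is the usual normalisation used inside attention.

```python
import torch
import torch.nn.functional as F

# Made-up 4-dimensional token embeddings (illustration only).
bank  = torch.tensor([0.9, 0.1, 0.3, 0.0])
plane = torch.tensor([0.8, 0.2, 0.4, 0.1])   # points in a similar direction -> high score
grass = torch.tensor([-0.5, 0.7, 0.0, 0.2])  # points elsewhere -> low score

# Dot products: aligned vectors give a high score, unaligned a low score.
print(torch.dot(bank, plane))   # larger value
print(torch.dot(bank, grass))   # smaller value

# In attention these scores are softmax-normalised into weights that say
# how strongly each token should "attend" to every other token.
scores = torch.stack([torch.dot(bank, plane), torch.dot(bank, grass)])
weights = F.softmax(scores, dim=0)
print(weights)
```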
And as we do this within the attention mechanism we can essentially identify which words should 00:02:58.080 |
be placed closer together within that vector space. So for example if we had three sentences 00:03:04.880 |
"a plane banks", "the grassy bank" and "the bank of England", the initial embedding of that token bank 00:03:12.960 |
for all of those sentences is equal. But then through this encoder attention mechanism we 00:03:21.040 |
essentially map the token embedding bank closer to the vector space of the other relevant words 00:03:28.720 |
within each one of those sentences. So in the case of a plane banks what we would have is the word 00:03:36.480 |
bank or banks being moved closer towards words like aeroplane, plane, airport, flight and so on. 00:03:44.800 |
For the phrase a grassy bank we will find that the token embedding for bank gets moved towards 00:03:51.200 |
the embedding space for grass, nature, fields. And for the bank of England we'll find that the 00:03:58.080 |
word bank gets moved towards finance, money and so on. So as we go through these many encoder 00:04:05.680 |
blocks which contain the attention mechanism we are essentially embedding more contextual meaning 00:04:12.000 |
into those initial token embeddings. Now attention did find itself being used occasionally 00:04:19.840 |
in convolutional neural networks which were the past state of the art in computer vision. 00:04:25.600 |
Generally speaking this has produced some benefit but it is somewhat limited. Attention is 00:04:32.560 |
a heavy operation when it comes to having a large number of items to compare because essentially 00:04:39.920 |
with attention you're comparing every item against every other item within your input sequence. So if 00:04:45.920 |
your input sequence is an even relatively large image and you're comparing pixels to pixels with 00:04:53.680 |
your attention mechanism the number of comparisons that you need to do becomes incredibly large very 00:05:01.040 |
very quickly. So in the case of convolutional neural networks attention can only really be 00:05:06.880 |
applied towards the later layers of the models where you basically have less activations being 00:05:14.640 |
compared after a few convolutions. Now that's better than nothing but it does limit the use 00:05:20.880 |
of attention because you can't use it throughout the entire network. Now transformer models in NLP 00:05:26.960 |
have not had that limitation and can instead apply attention over many layers literally from the very 00:05:34.240 |
starting point of the model. Now the setup used by BERT, which is again a very well-known 00:05:40.960 |
transformer model, involves several encoder layers. Now within these encoder layers or encoder blocks 00:05:47.280 |
we have a few different things going on. There is a normalization component, a multi-head attention 00:05:54.400 |
component, which is essentially many attention operations happening in parallel, and a multi-layer 00:06:02.000 |
perceptron layer. Through each of these blocks we're just encoding more and more information 00:06:07.120 |
into our token embeddings and at the end of this process we get these super rich vector embeddings 00:06:14.800 |
and these embeddings are the ultimate output of the core of a transformer model, including the vision 00:06:22.000 |
transformer. And from there what we tend to find with transformer models is that we add another few 00:06:29.120 |
layers onto the end which act as the head of the transformer which essentially encode or take these 00:06:36.400 |
vector embeddings, information rich embeddings, and translate them into predictions for a particular 00:06:43.680 |
task. So you might have a classification head or a NER head or a question answering head and they 00:06:51.120 |
will all be slightly different in some way but at the core what they are doing is translating those 00:06:57.200 |
super rich information embeddings into some sort of meaningful prediction. Now the vision 00:07:02.960 |
transformer actually works in the exact same way, the only difference is how we pre-process things 00:07:09.600 |
before they are fed into the vision transformer. So rather than consuming word or sub-word tokens 00:07:15.920 |
like BERT and other language transformers do, the vision transformer consumes image patches. 00:07:24.080 |
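Before walking through each step, here is a rough preview sketch, in PyTorch, of the whole patch-embedding pipeline that the next sections describe. The 16-pixel patch size matches the checkpoint used later in the video, but the hidden dimension and layer names here are assumptions for illustration, not the exact ViT code.

```python
import torch
import torch.nn as nn

batch, channels, height, width = 1, 3, 224, 224
patch, hidden = 16, 768                               # assumed patch size and embedding dim
num_patches = (height // patch) * (width // patch)    # 14 * 14 = 196 patches

image = torch.randn(batch, channels, height, width)

# 1. Split the image into flattened patches: (batch, num_patches, patch*patch*channels)
patches = image.unfold(2, patch, patch).unfold(3, patch, patch)   # (B, C, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(batch, num_patches, -1)

# 2. Linear projection of each patch into the model dimension
projection = nn.Linear(channels * patch * patch, hidden)
patch_embeddings = projection(patches)                            # (B, 196, 768)

# 3. Prepend a learnable class embedding (the CLS-style token)
cls_token = nn.Parameter(torch.zeros(1, 1, hidden))
embeddings = torch.cat([cls_token.expand(batch, -1, -1), patch_embeddings], dim=1)

# 4. Add learned positional embeddings
position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, hidden))
embeddings = embeddings + position_embeddings                     # fed into the encoder
print(embeddings.shape)                                           # torch.Size([1, 197, 768])
```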
Then the remainder of the transformer works in the exact same way. So let's take a look at how we go from 00:07:30.720 |
images to image patches and then after that into patch embeddings. The high level process for doing 00:07:39.120 |
this is relatively simple. One, we split an image into image patches. Two, we process those patches 00:07:46.720 |
through a linear projection layer to get our initial patch embeddings. Three, we prepend something 00:07:53.040 |
called a class embedding to those patch embeddings. And four, we sum the patch embeddings and 00:07:59.440 |
something called positional embeddings. Now there are a lot of parallels between this process and what we 00:08:04.560 |
see in language models, and we will relate to those where relevant. So after all these steps we have our 00:08:12.240 |
patch embeddings and we just process them in the exact same way that we would token embeddings with 00:08:17.280 |
a language transformer. But let's dive into each one of these steps in a little more detail. Our 00:08:22.400 |
first step is the transformation of our image into image patches. In NLP we actually do the same 00:08:30.960 |
thing. We take a sentence and we translate it into a list of tokens. So in this respect images are 00:08:39.040 |
sentences and image patches are word or sub-word tokens. Now if we didn't create these image 00:08:46.720 |
patches we could alternatively feed in the full set of pixels from an image. But as I mentioned 00:08:53.760 |
before that basically makes it so that we can't use attention because the calculation or the number 00:09:00.240 |
of computations that we need to do to compare all pixels would be very restrictive on the size of 00:09:06.880 |
images that we could input. We could only essentially input very very small images. So 00:09:11.360 |
if we consider that attention requires the comparison of everything to everything else 00:09:18.560 |
and we're using pixels here. If we have a 224 by 224 pixel image that means we would have to perform 00:09:27.520 |
224 to the power of 4 comparisons. Which is 2.5 billion comparisons. Which is pretty insane and 00:09:38.240 |
that's for a single attention layer. In transformers we have multiple attention layers. 00:09:43.520 |
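As a quick back-of-the-envelope check of these numbers (including the patch-based figures discussed next), the arithmetic looks like this:

```python
image_size = 224

# Pixel-level attention: every pixel compared against every other pixel.
pixels = image_size * image_size                  # 50,176 pixels
pixel_comparisons = pixels ** 2                   # 50,176^2 = 224^4
print(f"{pixel_comparisons:,}")                   # 2,517,630,976 (~2.5 billion)

# Patch-level attention with 14x14 pixel patches: 16x16 = 256 patches.
patch_size = 14
patches = (image_size // patch_size) ** 2         # 256 patches
patch_comparisons = patches ** 2                  # 65,536 comparisons per attention layer
print(f"{patch_comparisons:,}")
```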
So it's already just far too much. If instead we split our 224 by 224 pixel image into image patches 00:09:54.480 |
where we have 14 by 14 pixel patches that would leave us with 256 of these patches. And with that 00:10:04.960 |
a single attention layer requires a much more manageable 256 squared, around 65,000, comparisons. Which is 00:10:13.680 |
a lot easier to do. With that we can have a huge number of attention layers and still not even get 00:10:19.920 |
close to the single attention layer with our full image. Now after building these image patches we 00:10:26.560 |
move on to the linear projection step. For this we use a linear projection layer which is simply 00:10:32.000 |
going to map our image patch arrays into image patch vectors. By mapping these patches to the 00:10:42.720 |
patch embeddings we are reformatting them into the correct dimensionality to be input into 00:10:51.680 |
our vision transformer. But we're not putting these into the vision transformer just yet because 00:10:56.560 |
there's two more steps. Our third step is the learnable embedding or the class token. Now this 00:11:04.640 |
is an idea that comes from BERT. So BERT introduced the use of something called a CLS or classifier 00:11:13.600 |
token. Now the CLS token was a special token prepended to every sentence that was input 00:11:21.040 |
into BERT. This CLS token was as with every other token converted into an embedding and passed 00:11:28.160 |
through several encoder layers. Now there are two things that make CLS special. First it does not 00:11:35.040 |
represent a real word so it almost acts as like a blank slate being input into the model. And second 00:11:44.240 |
the CLS token embedding after the many encoder blocks is the embedding that is input into the 00:11:54.880 |
classification head which is used as a part of the pre-training process. So essentially what 00:12:02.080 |
we end up doing there is we end up embedding like a general representation of the full sentence into 00:12:09.200 |
this single token embedding. Because in order for the model to make a good prediction about what 00:12:16.000 |
this sentence is it needs to have a general embedding of the whole sentence in that single 00:12:22.400 |
token. Because it's only that single token embedding that is passed into the classification 00:12:30.000 |
head. Now the vision transformer applies the same logic and it adds something called a learnable 00:12:36.320 |
embedding or a class embedding to the embeddings as they are processed by the first layers of the 00:12:44.080 |
model. And this learnable embedding is practically the same thing as the CLS token in BERT. Now it's 00:12:50.800 |
also worth noting that it is potentially even more important for the vision transformer than it is 00:12:57.520 |
for BERT. Because for BERT the main mode of pre-training is something called masked language 00:13:03.200 |
modeling which doesn't rely on the classification token. Whereas with the vision transformer the 00:13:10.880 |
ideal mode of pre-training is actually a classification task. So in that sense we can 00:13:18.640 |
think of this CLS token or CLS embedding as actually being very critical for the overall 00:13:26.800 |
performance and overall training of the vision transformer. Now the final step that we need to 00:13:32.960 |
apply to our patch embeddings before they are actually fed into the model is we need to add 00:13:40.800 |
something called the positional embeddings. Now positional embeddings are a common thing to be 00:13:47.680 |
added to transformers. And that's because transformers by default don't actually have 00:13:54.000 |
any mechanism for tracking the position of inputs. So there's no order that is being considered. And 00:14:03.680 |
that is difficult because when it comes to language and also vision, but let's think in 00:14:11.200 |
the sense of language for now, the order of words in a sentence is incredibly important. If you mix 00:14:17.280 |
up the order of words as a person it's hard to understand what this sentence is supposed to mean. 00:14:23.760 |
And it can even mean something completely different. So obviously the order of words is 00:14:28.320 |
super important and that applies as well to images. If we start mixing the image patches 00:14:35.120 |
there's a good chance that we won't be able to understand what that image represents anymore. 00:14:40.080 |
And in fact this is what we get with jigsaw puzzles. We get a ton of little image patches 00:14:47.040 |
and we need to put them together in a certain order. And it takes people a long time to figure 00:14:50.640 |
out what that order actually is. So the order of our image patches is obviously quite important, 00:14:58.240 |
but by default transformers don't have a way of handling this. So that's where the positional 00:15:04.000 |
embeddings come in. For the vision transformer, these positional embeddings are learned embeddings 00:15:10.080 |
that are summed with the incoming patch embeddings. Now, as I mentioned, these positional 00:15:17.600 |
embeddings are learned. So during pre-training these are adjusted and what we can actually see 00:15:23.680 |
if we visualize this similarity or the cosine similarity between embeddings is that positional 00:15:30.480 |
embeddings that are close to one another actually have a higher similarity. And in particular 00:15:36.320 |
positional embeddings that exist within the same row and the same column as one another also have 00:15:42.720 |
a higher similarity. So it seems like there's this logical thing going on here with these 00:15:50.480 |
positional embeddings, whereby patches that are within a similar area are pushed 00:15:56.320 |
into a similar vector space and patches that are in a dissimilar area are pushed away from each 00:16:02.880 |
other within that vector space. So there's a sense of locality being introduced within these 00:16:08.400 |
positional embeddings. Now, after adding our positional embeddings and patch embeddings together, 00:16:14.800 |
we have our final patch embeddings, which are then fed into our vision transformer and they're 00:16:20.880 |
processed through that sort of encoder attention mechanism that we described before, which is just 00:16:26.160 |
a typical transformer approach. Now, that is the logic behind the vision transformer and the 00:16:26.160 |
new innovations that it has brought. Now I want to describe or actually go through an example 00:16:40.720 |
of an implementation of the vision transformer and how we can actually use it. Okay, so we start 00:16:47.200 |
by just installing any prerequisites that we have. So here we've got pip install datasets and 00:16:53.760 |
transformers and also PyTorch. So we run this and then what we want to do is download a dataset that 00:17:02.400 |
we can actually test all of this on and also fine-tune with. So we're going to be using the CIFAR-10 00:17:08.880 |
dataset. We're going to be getting that from Hugging Face Datasets. So from datasets, import load 00:17:14.560 |
dataset. Let this run and we just run this. One thing just to check here before we go through 00:17:20.560 |
everything is to make sure that we're using GPU. Save and we will have to rerun everything. 00:17:28.320 |
Okay, so after that's downloaded, we'll see that we have 50,000 images with classification labels 00:17:33.440 |
within our training data. And we also download the test split as well. That has 10,000 of these. 00:17:41.120 |
And then what we want to do is we want to just have a look at the labels quickly. So let's see 00:17:46.720 |
what we have in there. So we have 10 labels. That's why it's called CIFAR-10. And of those, we have these particular classes 00:18:00.400 |
within the dataset. Airplane, automobile, so on and so on. So from there, we can have a look at 00:18:06.240 |
what is within a single item within that dataset. So we have this PIL object. So a Python PIL object is 00:18:13.120 |
essentially an image. And then also the label. Now that label corresponds to airplane here in this 00:18:20.480 |
case, because it's number zero. And we can just check that. So run this. We can't 00:18:28.240 |
really see it very well, it's very small, but that is an airplane. And we can actually map the label, 00:18:35.360 |
so 0, to labels.names in order to get the actual human readable class label. Okay, cool. So 00:18:43.200 |
what we're going to do is we're going to load the Vision Transformer feature extractor. So 00:18:48.720 |
we're going to be using this model here from the Hugging Face Hub. And we can actually see that over here. 00:18:55.120 |
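Loading that feature extractor might look roughly like this; ViTFeatureExtractor is the class name the transformers library used at the time (newer releases call it ViTImageProcessor):

```python
from transformers import ViTFeatureExtractor

model_id = 'google/vit-base-patch16-224-in21k'

# Downloads the preprocessing configuration: resize to 224x224 and
# normalise each RGB channel with the stored mean/std values.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
print(feature_extractor)
```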
So we have google/vit-base-patch16-224-in21k. Now what that means is we have patches that 00:19:06.560 |
are 16 by 16 pixels. They are being built, during pre-training at least, from a 00:19:14.720 |
224 by 224 pixel image. And this in21k is just to say that this has been trained on, or pre-trained 00:19:25.040 |
on the ImageNet 21K dataset. So that is the model we'll be using. And we use this feature extractor, 00:19:32.720 |
which is almost like a pre-processor for this particular model. So we can run that and this 00:19:38.560 |
will just download that feature extractor for us. That's pretty quick. And we can see the 00:19:44.480 |
configuration within that feature extractor here. So what is this feature extractor doing exactly? 00:19:50.080 |
It is taking an image. Our image can be any size and in a lot of different formats. And what it's 00:19:58.080 |
going to do is just normalize and resize that image into something that we can then process 00:20:03.200 |
with our vision transformer. So we can see here that it will normalize the pixel values within 00:20:09.360 |
the image and it will resize the image as well. It will resize the image to this here, 224 by 224 00:20:17.680 |
pixels. In terms of normalization, it's using these values here for each of the color 00:20:23.040 |
channels. So we have red, green, and blue. And yeah, that's pretty much, that's what it's going 00:20:31.520 |
to be doing. So if we take a look at the first image, we can use the feature extractor here 00:20:36.800 |
on our first image, which is that plane. And we're going to just return tensors 00:20:42.240 |
in PyTorch format, because we'll be using PyTorch later on. So we run this and what we return is 00:20:50.800 |
a dictionary containing a single tensor or a single key value pair, which is pixel values, 00:20:55.760 |
which maps to this single tensor here. And we can go down and we can have a look at the shape of 00:21:02.160 |
that. And we see that we have this 224 by 224 pixel image or pixel values tensor. Now that is 00:21:12.560 |
different to the original image, because the original image was train[0]['img']. 00:21:22.560 |
What's the shape of this? I think we can maybe do this, maybe .size. 00:21:35.120 |
Okay. 32 by 32. So it's been resized up to 224 by 224, which is the format that the 00:21:43.200 |
vision transformer needs. Now, when we are doing this, what we're going to want to do later on 00:21:49.280 |
is we're going to be training everything on GPU, not CPU. Now, by default, this tensor here is on 00:21:56.480 |
CPU. We don't want that. We need to be using a GPU where possible. So we say, okay, if a CUDA 00:22:03.600 |
enabled GPU is available, please use GPU. Okay. So we can see here, there is one available. So 00:22:10.000 |
we're on Colab. So that's great. It means everything will be much faster. And the reason 00:22:16.720 |
why we need that is because here, we're going to need to move everything to that device. 00:22:23.040 |
So what we'll do is here, as we use the feature extractor here, we're going to say .to(device). 00:22:32.160 |
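A sketch of the preprocessing transform being described, reusing the names from the sketches above; the function name is an assumption, and note that the .to(device) call shown here is the one removed again later in the video once the Trainer error comes up.

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def preprocess(batch):
    # Resize + normalise the PIL images and return PyTorch tensors.
    inputs = feature_extractor(batch['img'], return_tensors='pt')
    # Moving to GPU here is what later causes the Trainer error,
    # so this .to(device) gets removed further on in the video.
    inputs['pixel_values'] = inputs['pixel_values'].to(device)
    inputs['label'] = batch['label']
    return inputs

# Apply lazily to both splits (the test split is used as validation).
prepared_train = train_ds.with_transform(preprocess)
prepared_test = test_ds.with_transform(preprocess)
```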
That will just move everything to GPU for us. Okay. And then we use this with_transform to 00:22:38.240 |
apply that to both the training and the testing dataset. Or in reality, we're going to be using 00:22:43.280 |
test data set more as a validation data set. Now, after all that, we're ready to move on to 00:22:49.680 |
the model fine tuning step. So with this, there are a few things we're going to need to define. 00:22:55.760 |
So training and testing data set, we've already done that. It's not a problem. Feature extractor, 00:23:00.000 |
we have already done that as well. Not a problem. The model, we will define that. It's pretty easy. 00:23:05.120 |
Something called a collate function, evaluation metric, and some other training arguments. So 00:23:11.520 |
let's start with the collate function. So here, this is essentially, when we're training with 00:23:18.000 |
the Hugging Face Trainer, we need a way to collate all of our data into batches in a way that makes 00:23:25.600 |
sense, which requires this dictionary format. So each record is represented by a dictionary, 00:23:33.120 |
and each record contains inputs, which is the pixel values, and also the labels. So we run this. 00:23:39.520 |
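A minimal sketch of that collate function, stacking pixel values and labels into the dictionary format the Trainer expects; the key names follow the usual ViT convention and are assumptions here:

```python
import torch

def collate_fn(batch):
    # Stack the per-example tensors into a single batch tensor,
    # and collect the class labels alongside them.
    return {
        'pixel_values': torch.stack([example['pixel_values'] for example in batch]),
        'labels': torch.tensor([example['label'] for example in batch]),
    }
```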
We then need to define our evaluation metric, which I'm using accuracy, 00:23:45.680 |
which is, you can read that if you want, but it's pretty straightforward. So we define that. 00:23:55.360 |
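The accuracy metric can be sketched roughly as follows; the video uses the ready-made accuracy metric from the datasets library, but a plain NumPy equivalent is shown here to keep the sketch self-contained:

```python
import numpy as np

def compute_metrics(eval_pred):
    # eval_pred contains the raw logits and the true labels.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {'accuracy': (predictions == labels).mean()}
```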
And then we have all these training arguments. So these are essentially just the training 00:24:00.800 |
parameters that we're going to use to actually train our model. So we have the batch size that 00:24:06.960 |
we want to use, where we're going to output the model, the number of training epochs that we want 00:24:12.880 |
to use, how often do we want to evaluate the model. So run it on the validation/test data set 00:24:20.160 |
that we have, what learning rate do you want to use, and so on and so on. Rerun that. That just 00:24:25.840 |
sets up the configuration for our training. And then we move on to initializing our model. Again, 00:24:34.240 |
this is just using the same thing that we had before. So when we had that feature extractor, 00:24:39.680 |
we initialized it from pre-trained, and then we had the model name or path, model ID. 00:24:47.040 |
So that is just the vit-base-patch16-224-in21k model that you saw before. One thing that we do need to add here is, 00:24:55.360 |
because we're doing this ViT image classification, we need to specify the number of labels or 00:25:01.680 |
classes that will be output from that classification head, which in this case is 10 of those labels. 00:25:08.400 |
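Initialising the model as described, with the ten CIFAR-10 classes wired into the classification head, might look like this; the id2label/label2id mappings are an optional extra added here as an assumption so predictions print as readable names later.

```python
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    model_id,                                     # 'google/vit-base-patch16-224-in21k'
    num_labels=10,                                # CIFAR-10 classes
    id2label={i: name for i, name in enumerate(labels.names)},
    label2id={name: i for i, name in enumerate(labels.names)},
)
model.to(device)
```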
So we define that as well. We move the model to our GPU, and with that, we are ready to initialize 00:25:16.400 |
our trainer with all of those things that we just defined. So we run that, and then to actually 00:25:23.680 |
train the model, we do this. So trainer.train. After that, we can save the model, we can log 00:25:30.800 |
our metrics, save our metrics, and then just save the current state of the trainer at that point. 00:25:37.360 |
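Putting the training arguments and the Trainer together might look roughly like this, reusing the earlier names; the specific hyperparameter values and output path are assumptions, not necessarily the ones used in the video.

```python
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir='./vit-cifar10',        # where checkpoints are written (assumed path)
    per_device_train_batch_size=16,
    num_train_epochs=4,
    evaluation_strategy='steps',
    eval_steps=100,
    learning_rate=2e-4,
    save_steps=100,
    logging_steps=10,
    remove_unused_columns=False,       # keep the 'img' column for the transform
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_train,
    eval_dataset=prepared_test,
    tokenizer=feature_extractor,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics('train', train_results.metrics)
trainer.save_metrics('train', train_results.metrics)
trainer.save_state()
```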
So I'm going to run that very briefly and then stop. Okay, so it seems we're getting this error, 00:25:43.280 |
which I think might be because we're trying to move the input tensors to GPU twice. So I think 00:25:50.960 |
the trainer is doing it by default, but earlier on, we added the .to(device), so we need to remove 00:25:56.640 |
that and run it again. So up here within preprocess, we just remove this, run it again, 00:26:05.200 |
and then just rerun everything. Then pass everything to the trainer, and then try and 00:26:11.600 |
train again. Okay, it looks like we're having a little more luck with it this time. So we can see 00:26:17.200 |
that the model is training. Actually, it doesn't take too long, but what I'm going to do is just 00:26:23.440 |
skip forward. So I'm going to stop this, and what we can do is you can run this to get your 00:26:31.600 |
evaluation metrics and view your evaluation metrics. Your model will be evaluating as it goes 00:26:37.520 |
through your training set, thanks to the trainer. But if you would like to check again, you can just 00:26:45.040 |
use this. But for now, let's just have a look at a specific example. So what we're going to do is 00:26:50.800 |
load this image. I mean, I can't really tell what that image is. I think, so if we come down here, 00:26:58.400 |
it should be a cat, yeah? So run this, we can see that it's actually supposed to be a cat. 00:27:03.120 |
It's very blurry, I can't personally tell. But what we're going to do is load a fine-tuned model. 00:27:09.120 |
So this is the model that has been fine-tuned using this same process. So we can download that 00:27:16.000 |
from the Hugging Face Hub. We can also download the feature extractor, which we don't actually need to do 00:27:23.840 |
because it is using the same feature extractor, but in a real use case scenario, 00:27:31.280 |
you might just download everything from a particular model that is hosted within 00:27:36.480 |
the Hugging Face Hub. So this is what you would do, because ours is not really fine-tuned. So run that. 00:27:44.640 |
That will just download the fine-tuned model. 00:27:49.120 |
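A sketch of the inference step described next, again reusing earlier names; the checkpoint path below is a placeholder for whichever fine-tuned model you saved or downloaded, not the actual repository used in the video.

```python
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor

# Placeholder: point this at your own fine-tuned checkpoint or hub repo.
finetuned_id = './vit-cifar10'

finetuned_model = ViTForImageClassification.from_pretrained(finetuned_id).to(device)
finetuned_extractor = ViTFeatureExtractor.from_pretrained(finetuned_id)

image = test_ds[0]['img']                      # an image from the test split (the video picks one that is a cat)
inputs = finetuned_extractor(image, return_tensors='pt').to(device)

# No gradients needed for a plain prediction.
with torch.no_grad():
    logits = finetuned_model(**inputs).logits

predicted_class = logits.argmax(-1).item()
print(finetuned_model.config.id2label[predicted_class])   # hopefully 'cat'
```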
And you can see here, we have the exact same feature extractor configuration there. We process 00:27:59.360 |
our image through the feature extractor, returning PyTorch tensors, and then we say 00:28:04.160 |
with torch.no_grad(), which is essentially to make sure that we're not updating the gradients 00:28:09.760 |
of the model like we would during fine-tuning, because we're actually just making a prediction 00:28:14.560 |
here. We don't want to train anything. We use the model, process the inputs, and we extract the 00:28:21.840 |
logits, which are just the output activations. And what we want to do is take the argmax, so 00:28:28.480 |
the position where the logits are at their maximum value, which is basically the highest probability class being 00:28:35.280 |
predicted. So we extract that, we get the label, and then we output the label. And if we run that, 00:28:41.680 |
we will see that we get cat. Okay, so it looks like we have fine-tuned a vision transformer 00:28:49.040 |
using that same process, and the performance is pretty accurate. Now, before 2021, which is really 00:28:57.760 |
not that long ago, transformers were known as just being those language models that were not 00:29:02.960 |
used in anything else. But now, as we can see, we're actually able to use transformers and get 00:29:09.280 |
really good results within the field of computer vision. And we're actually seeing this use in a 00:29:15.920 |
lot of places. The vision transformer is a key component of OpenAI's CLIP model, and OpenAI's 00:29:22.560 |
CLIP is a key component of all of the diffusion models that we've seen pop up everywhere, and 00:29:28.400 |
the world is going crazy over them right now. Transformers are also a key component in Tesla 00:29:35.360 |
for self-driving. They are finding use in a huge number of places that would have just been 00:29:43.680 |
incredibly unexpected a year or even two, three years ago. And I think as time progresses, we will 00:29:50.320 |
undoubtedly see more use of transformers within computer vision, and of course, the continued use 00:29:56.240 |
of transformers within the field of language. And they will undoubtedly become more and more unified 00:30:03.920 |
over time. For now, that's it for this video. I hope all of this has been useful and interesting. 00:30:11.280 |
So, thank you very much for watching, and I'll see you again in the next one. Bye.