
Vision Transformers (ViT) Explained + Fine-tuning in Python


Chapters

0:00 Intro
0:58 In this video
1:12 What are transformers and attention?
1:39 Attention explained simply
4:15 Attention used in CNNs
5:24 Transformers and attention
7:01 What vision transformer (ViT) does differently
7:28 Images to patch embeddings
8:22 1. Building image patches
10:23 2. Linear projection
10:57 3. Learnable class embedding
13:30 4. Adding positional embeddings
16:37 ViT implementation in Python with Hugging Face
16:45 Packages, dataset, and Colab GPU
18:42 Initialize Hugging Face ViT Feature Extractor
22:48 Hugging Face Trainer setup
25:14 Training and CUDA device error
26:27 Evaluation and classification predictions with ViT
28:54 Final thoughts

Transcript

Vision and language are the two big domains in machine learning: two distinct disciplines with their own problems, best practices, and model architectures. Or at least that used to be the case. The vision transformer, or ViT, marks the first step towards a merger of both fields into a single unified discipline.

For the first time in the history of machine learning we have a single model architecture that is on track to become the dominant model in both language and vision. Before the vision transformer, transformers were known as those language models and nothing more. But since the introduction of the vision transformer there has been further work that has almost solidified their position as state-of-the-art in vision.

In this video we're going to dive into the vision transformer. We're going to explain what it is, how it works, why it works, and we're going to look at how we can actually use it and implement it with Python. So let's get started with a very quick 101 introduction to transformers and the attention mechanism.

So transformers were introduced in 2017 in the pretty well-known paper called "Attention is all you need". Transformers quite literally changed the entire landscape of NLP and this was very much thanks to something called the attention mechanism. Now in NLP attention allows us to embed contextual meaning into the word level or sub-word level token embeddings within a model.

So what I mean by that is say you have a sentence and you have two words in that sentence that are very much related. Attention allows you to identify that relationship and then allow the model to understand those words with respect to one another within that greater sentence. Now this starts within the transformer model with tokens just being embedded into a vector space based purely on what that token is.

So the token for bank will be mapped to a particular vector that represents the word bank without any consideration of the words surrounding it. Now with these token embeddings what we can do is calculate dot products between their embeddings and we will return a high score when they are aligned and a low score when they are not aligned.

And as we do this within the attention mechanism we can essentially identify which words should be placed closer together within that vector space. So for example, if we had three phrases, "a plane banks", "the grassy bank", and "the bank of England", the initial embedding of the token bank is the same in all of them.

But then through this encoder attention mechanism we essentially map the token embedding bank closer to the vector space of the other relevant words within each one of those sentences. So in the case of a plane banks what we would have is the word bank or banks being moved closer towards words like aeroplane, plane, airport, flight and so on.

For the phrase a grassy bank we will find that the token embedding for bank gets moved towards the embedding space for grass, nature, fields. And for the bank of England we'll find that the word bank gets moved towards finance, money and so on. So as we go through these many encoder blocks which contain the attention mechanism we are essentially embedding more contextual meaning into those initial token embeddings.
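To make that idea concrete, here is a minimal sketch of dot-product self-attention in plain Python. It is a simplification under stated assumptions: the 2-d embeddings are made-up toy values, and the learned query, key, and value projections and the multiple heads of a real transformer are left out. The function names are just for illustration.

```python
import math

def softmax(scores):
    # Subtract the max for numerical stability before exponentiating.
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(embeddings):
    """Each output is a weighted average of all inputs, weighted by
    scaled dot-product similarity (no learned projections here)."""
    dim = len(embeddings[0])
    outputs = []
    for query in embeddings:
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
                  for key in embeddings]
        weights = softmax(scores)
        outputs.append([sum(w * vec[i] for w, vec in zip(weights, embeddings))
                        for i in range(dim)])
    return outputs

# Toy embeddings (made-up values): "bank" starts out neutral.
bank = [0.5, 0.5]
plane = [3.0, 0.0]
money = [0.0, 3.0]

# The same starting vector for "bank" ends up in different places
# depending on which other token shares the sequence with it.
bank_near_plane = self_attention([plane, bank])[1]
bank_near_money = self_attention([money, bank])[1]
print(bank_near_plane, bank_near_money)
```

In the first case the output for bank is pulled towards the plane direction, and in the second towards the money direction, which is the contextual shift described above.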

Now attention did find itself being used occasionally in convolutional neural networks which were the past state of the art in computer vision. Generally speaking this has produced some benefit but it is somewhat limited. Attention is a heavy operation when it comes to having a large number of items to compare because essentially with attention you're comparing every item against every other item within your input sequence.

So if your input sequence is even a relatively large image and you're comparing pixels to pixels with your attention mechanism, the number of comparisons that you need to do becomes incredibly large very, very quickly. So in the case of convolutional neural networks, attention can only really be applied towards the later layers of the model, where you have fewer activations being compared after a few convolutions.

Now that's better than nothing, but it does limit the use of attention because you can't use it throughout the entire network. Now transformer models in NLP have not had that limitation and can instead apply attention over many layers, literally from the very starting point of the model. Now the setup used by BERT, which is again a very well-known transformer model, involves several encoder layers.

Now within these encoder layers or encoder blocks we have a few different things going on. There is a normalization component, a multi-head attention component, which is essentially many attention operations happening in parallel, and a multi-layer perceptron layer. Through each of these blocks we're just encoding more and more information into our token embeddings, and at the end of this process we get these super rich vector embeddings. These embeddings are the ultimate output of the core of a transformer model, including the vision transformer.

And from there what we tend to find with transformer models is that we add another few layers onto the end, which act as the head of the transformer. These take the information-rich vector embeddings and translate them into predictions for a particular task. So you might have a classification head or an NER head or a question answering head, and they will all be slightly different in some way, but at the core what they are doing is translating those information-rich embeddings into some sort of meaningful prediction.

Now the vision transformer actually works in the exact same way. The only difference is how we pre-process things before they are fed into the vision transformer. So whereas BERT and other language transformers consume word or sub-word tokens, the vision transformer consumes image patches. The remainder of the transformer then works in the exact same way.

So let's take a look at how we go from images to image patches and then after that into patch embeddings. The high level process for doing this is relatively simple. First we split an image into image patches. Two, we process those patches through a linear projection layer to get our initial patch embeddings.

Then we prepend something called a class embedding to those patch embeddings, and finally we sum the patch embeddings with something called positional embeddings. Now there are a lot of parallels between this process and what we see in language, and we'll relate to those where relevant. So after all these steps we have our patch embeddings, and we just process them in the exact same way that we would token embeddings with a language transformer.

But let's dive into each one of these steps in a little more detail. Our first step is the transformation of our image into image patches. In NLP we actually do the same thing. We take a sentence and we translate it into a list of tokens. So in this respect images are sentences and image patches are word or sub-word tokens.
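To make the patching step concrete, here is a minimal sketch in plain Python, assuming a single-channel 224 by 224 image and 14 by 14 pixel patches as example sizes. The function name split_into_patches is made up for illustration; real implementations work on tensors and usually fold this step into the projection layer.

```python
def split_into_patches(image, patch_size):
    """Split a 2-d grid (list of rows) into non-overlapping square patches,
    returned in row-major order, each flattened into a 1-d list."""
    height, width = len(image), len(image[0])
    assert height % patch_size == 0 and width % patch_size == 0
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [image[r][c]
                     for r in range(top, top + patch_size)
                     for c in range(left, left + patch_size)]
            patches.append(patch)
    return patches

# A 224 x 224 "image" whose pixel value simply encodes its position.
image = [[row * 224 + col for col in range(224)] for row in range(224)]
patches = split_into_patches(image, patch_size=14)
print(len(patches), len(patches[0]))  # 256 196
```

Each patch plays the role that a token plays in a sentence: the image becomes a sequence of 256 items rather than 50,176 individual pixels.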

Now if we didn't create these image patches, we could alternatively feed in the full set of pixels from an image. But as I mentioned before, that basically makes it so that we can't use attention, because the number of computations that we need to do to compare all pixels would be very restrictive on the size of images that we could input.

We could only essentially input very very small images. So if we consider that attention requires the comparison of everything to everything else and we're using pixels here. If we have a 224 by 224 pixel image that means we would have to perform 224 to the power of 4 comparisons.

Which is 2.5 billion comparisons. Which is pretty insane and that's for a single attention layer. In transformers we have multiple attention layers. So it's already just far too much. If instead we split our 224 by 224 pixel image into image patches where we have 14 by 14 pixel patches that would leave us with 256 of these patches.

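And with that, a single attention layer only has to compare 256 patches with one another, which is 256 squared, or roughly 65,000 comparisons. Even if we compared every pixel with every other pixel inside each patch, that would still only be about 9.8 million. Which is a lot easier to do. With that we can have a huge number of attention layers and still not even get close to the cost of a single attention layer over the full image. Now, after building these image patches, we move on to the linear projection step.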
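The arithmetic here is easy to check. The sketch below counts pixel-to-pixel pairs for the full image, patch-to-patch pairs after splitting, and, for reference, pixel pairs counted only inside each patch, which is where a figure of about 9.8 million comes from. The variable names are just for illustration.

```python
image_side = 224
patch_side = 14

num_pixels = image_side ** 2                   # 50,176 pixels
pixel_pairs = num_pixels ** 2                  # every pixel vs every pixel

num_patches = (image_side // patch_side) ** 2  # 16 x 16 = 256 patches
patch_pairs = num_patches ** 2                 # every patch vs every patch

# Pixel-to-pixel comparisons restricted to within each patch.
within_patch_pairs = num_patches * (patch_side ** 2) ** 2

print(f"{pixel_pairs:,}")         # 2,517,630,976
print(f"{patch_pairs:,}")         # 65,536
print(f"{within_patch_pairs:,}")  # 9,834,496
```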

For this we use a linear projection layer which is simply going to map our image patch arrays into image patch vectors. By mapping these patches to the patch embeddings we are reformatting them into the correct dimensionality to be input into our vision transformer. But we're not putting these into the vision transformer just yet because there's two more steps.
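As a rough sketch of that projection, assuming flattened RGB patches of 14 by 14 pixels and a toy embedding size of 8 (ViT-Base uses 768), the step is a single matrix multiplication plus a bias. The weights here are random placeholders standing in for learned parameters, and linear_projection is a hypothetical helper name.

```python
import random

def linear_projection(patch, weights, bias):
    """Map a flattened patch (length n_in) to an embedding (length n_out):
    out[j] = sum_i patch[i] * weights[i][j] + bias[j]."""
    n_out = len(bias)
    return [sum(p * w_row[j] for p, w_row in zip(patch, weights)) + bias[j]
            for j in range(n_out)]

random.seed(0)
patch_pixels = 14 * 14 * 3  # flattened RGB patch: 588 values
embed_dim = 8               # toy size for illustration

# Random weights stand in for parameters that would be learned in training.
weights = [[random.gauss(0, 0.02) for _ in range(embed_dim)]
           for _ in range(patch_pixels)]
bias = [0.0] * embed_dim

patch = [random.random() for _ in range(patch_pixels)]
embedding = linear_projection(patch, weights, bias)
print(len(embedding))  # 8
```

The same weights are applied to every patch, so all patches end up as vectors of the same length, ready to be treated like token embeddings.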

Our third step is the learnable embedding or the class token. Now this is an idea that comes from BERT. So BERT introduced the use of something called a CLS or classifier token. Now the CLS token was a special token prepended to every sentence that was input into BERT. This CLS token was, as with every other token, converted into an embedding and passed through several encoder layers.

Now there are two things that make CLS special. First it does not represent a real word so it almost acts as like a blank slate being input into the model. And second the CLS token embedding after the many encoder blocks is that embedding that is input into the classification head which is used as a part of the pre-training process.

So essentially what we end up doing there is we end up embedding like a general representation of the full sentence into this single token embedding. Because in order for the model to make a good prediction about what this sentence is it needs to have a general embedding of the whole sentence in that single token.

Because it's only that single token embedding that is passed into the classification head. Now the vision transformer applies the same logic and it adds something called a learnable embedding or a class embedding to the embeddings as they are processed by the first layers of the model. And this learnable embedding is practically the same thing as the CLS token in BERT.
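In code, adding the class embedding amounts to putting one extra vector at the front of the sequence. This is a minimal sketch with toy sizes and placeholder values; in a real model the class embedding is a trained parameter.

```python
embed_dim = 4
num_patches = 256

# Stand-in patch embeddings (zeros here; really the projected patches).
patch_embeddings = [[0.0] * embed_dim for _ in range(num_patches)]

# One extra vector, shared by every image and updated during training.
class_embedding = [0.1] * embed_dim

# Prepend it so that it sits at position 0 of the sequence.
sequence = [class_embedding] + patch_embeddings
print(len(sequence))  # 257

# After the encoder blocks, only the output at position 0 is passed
# to the classification head.
```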

Now it's also worth noting that it is potentially even more important for the vision transformer than it is for BERT. Because for BERT the main mode of pre-training is something called masked language modeling, which doesn't rely on the classification token. Whereas with the vision transformer the ideal mode of pre-training is actually a classification task.

So in that sense we can think of this CLS token or CLS embedding as actually being very critical for the overall performance and overall training of the vision transformer. Now the final step that we need to apply to our patch embeddings before they are actually fed into the model is to add something called the positional embeddings.

Now positional embeddings are a common thing to be added to transformers. And that's because transformers by default don't actually have any mechanism for tracking the position of inputs. So there's no order that is being considered. And that is difficult because when it comes to language and also vision, but let's think in the sense of language for now, the order of words in a sentence is incredibly important.

If you mix up the order of words as a person it's hard to understand what this sentence is supposed to mean. And it can even mean something completely different. So obviously the order of words is super important and that applies as well to images. If we start mixing the image patches there's a good chance that we won't be able to understand what that image represents anymore.

And in fact this is what we get with jigsaw puzzles. We get a ton of little image patches and we need to put them together in a certain order. And it takes people a long time to figure out what that order actually is. So the order of our image patches is obviously quite important, but by default transformers don't have a way of handling this.

So that's where the positional embeddings come in. For the vision transformer, these positional embeddings are learned embeddings that are summed with the incoming patch embeddings. Now, as I mentioned, these positional embeddings are learned. So during pre-training these are adjusted and what we can actually see if we visualize this similarity or the cosine similarity between embeddings is that positional embeddings that are close to one another actually have a higher similarity.

And in particular, positional embeddings that exist within the same row and the same column as one another also have a higher similarity. So it seems like there's this logical thing going on here with these positional embeddings, where patches that are within a similar area are pushed into a similar vector space and patches that are in dissimilar areas are pushed away from each other within that vector space.

So there's a sense of locality being introduced within these positional embeddings. Now, after adding our positional embeddings and patch embeddings together, we have our final patch embeddings, which are then fed into our vision transformer and they're processed through that sort of encoder attention mechanism that we described before, which is just a typical transformer approach.
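The sum itself is a plain element-wise addition, one positional vector per slot in the sequence. Here is a small sketch with toy sizes and random placeholder values standing in for learned embeddings; add_positions is a hypothetical helper name. It also shows why the sum matters: without it, shuffling patches gives the model the same set of vectors.

```python
import random

random.seed(0)
embed_dim = 4
seq_len = 257  # 1 class embedding + 256 patch embeddings

def random_vectors(n, dim):
    return [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]

def add_positions(tokens, positions):
    # Element-wise sum: same shape in, same shape out.
    return [[t + p for t, p in zip(tok, pos)]
            for tok, pos in zip(tokens, positions)]

token_embeddings = random_vectors(seq_len, embed_dim)
# One vector per position; random here as a placeholder for learned values.
position_embeddings = random_vectors(seq_len, embed_dim)

encoder_input = add_positions(token_embeddings, position_embeddings)

# Swap two patches. As a set of vectors the raw embeddings are unchanged,
# but once positions are added the inputs differ, so order becomes visible.
shuffled = token_embeddings[:]
shuffled[1], shuffled[2] = shuffled[2], shuffled[1]
shuffled_input = add_positions(shuffled, position_embeddings)
```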

Now, that is the logic behind vision transformer and the new innovations that it has brought. Now I want to describe or actually go through an example of an implementation of the vision transformer and how we can actually use it. Okay, so we start by just installing any prerequisites that we have.

So here we've got pip install datasets and transformers and also PyTorch. So we run this and then what we want to do is download a dataset that we can actually test all of this on and also fine-tune with. So we're going to be using the CIFAR-10 dataset. We're going to be getting that from Hugging Face Datasets.

So from datasets, import load_dataset. Let this run and we just run this. One thing just to check here before we go through everything is to make sure that we're using GPU. Save and we will have to rerun everything. Okay, so after that's downloaded, we'll see that we have 50,000 images with classification labels within our training data.
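A sketch of what that loading step might look like is shown here. The dataset identifier "cifar10" and the helper name load_cifar10 are assumptions, since the exact call is not spelled out above; the split names follow the usual train and test convention.

```python
def load_cifar10():
    # Imported inside the function so the sketch can be loaded without the
    # library; actually running it requires `pip install datasets`.
    from datasets import load_dataset

    # "cifar10" is an assumed dataset identifier on the Hugging Face Hub.
    train = load_dataset("cifar10", split="train")  # 50,000 images
    test = load_dataset("cifar10", split="test")    # 10,000 images
    return train, test
```

Calling load_cifar10() downloads the data on first use and returns the two splits.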

And we also download the test split as well. That has 10,000 of these. And then what we want to do is we want to just have a look at the labels quickly. So let's see what we have in there. So we have 10 labels. That's why it's called CIFAR-10.

And of those, we have these particular classes within the dataset. Airplane, automobile, so on and so on. So from there, we can have a look at what is within a single item within that dataset. So we have this PIL image.

So a Python PIL object is essentially an image. And then also the label. Now that label corresponds to airplane here in this case, because it's number zero. And we can just check that. So run this. This is the image. We can't really see it very well. It's very small, but that is an airplane.
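The label is just an integer index into the list of class names. As a small sketch, with the ten CIFAR-10 class names typed out by hand (with the Hugging Face dataset they would normally be read from the dataset's label feature), the lookup is:

```python
# The ten CIFAR-10 class names in label order (index 0 is "airplane").
CIFAR10_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
                 "dog", "frog", "horse", "ship", "truck"]

def label_to_name(label_id):
    # Hypothetical helper: turn an integer label into a readable class name.
    return CIFAR10_NAMES[label_id]

print(label_to_name(0))  # airplane
```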

And we can actually map the label, so zero, to labels.names in order to get the actual human-readable class label. Okay, cool. So what we're going to do is we're going to load the vision transformer feature extractor. So we're going to be using this model here from the Hugging Face Hub.

And we can actually see that over here. So we have google/vit-base-patch16-224-in21k. Now what that means is we have patches that are 16 by 16 pixels. They are built, during pre-training at least, from 224 by 224 pixel images.

And this IN 21K is just to say that this has been trained on or pre-trained on the ImageNet 21K dataset. So that is the model we'll be using. And we use this feature extractor, which is almost like a pre-processor for this particular model. So we can run that and this will just download that feature extractor for us.

That's pretty quick. And we can see the configuration within that feature extractor here. So what is this feature extractor doing exactly? It is taking an image. Our image can be any size and in a lot of different formats. And what it's going to do is just normalize and resize that image into something that we can then process with our vision transformer.

So we can see here that it will normalize the pixel values within the image and it will resize the image as well. It will resize the image to this here, 224 by 224 pixels. In terms of normalization, it is using these values here for each of the color channels.

So we have red, green, and blue. And yeah, that's pretty much what it's going to be doing. So if we take a look at the first image, we can use the feature extractor here on our first image, which is that plane. And we're going to just return tensors in PyTorch format because we'll be using PyTorch later on.

So we run this and what we return is a dictionary containing a single tensor or a single key value pair, which is pixel values, which maps to this single tensor here. And we can go down and we can have a look at the shape of that. And we see that we have this 224 by 224 pixel image or pixel values tensor.

Now that is different to the original image because the original image was train zero image or IMG. What's the shape of this? I think we can maybe do this. Maybe size. Okay. 32 by 32. So it's been resized up to 224 by 224, which is the format that the vision transformer needs.
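To give a feel for what the normalization does to each pixel, here is a minimal sketch. The mean and standard deviation of 0.5 per channel are assumed values, so check the printed feature extractor configuration for the real ones. Resizing is left to the library and not reproduced here.

```python
# Assumed config values: mean and std of 0.5 per channel, and a
# 224 x 224 target size.
IMAGE_MEAN = [0.5, 0.5, 0.5]
IMAGE_STD = [0.5, 0.5, 0.5]
TARGET_SIZE = (224, 224)

def normalize_pixel(rgb):
    """Scale 0..255 channel values to 0..1, then standardize per channel."""
    return [((value / 255.0) - mean) / std
            for value, mean, std in zip(rgb, IMAGE_MEAN, IMAGE_STD)]

print(normalize_pixel([0, 255, 255]))  # [-1.0, 1.0, 1.0]

# Shape of the resulting pixel values for one RGB image:
# (batch, channels, height, width).
batch_shape = (1, 3, *TARGET_SIZE)
print(batch_shape)  # (1, 3, 224, 224)
```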

Now, when we are doing this, what we're going to want to do later on is we're going to be training everything on GPU, not CPU. Now, by default, this tensor here is on CPU. We don't want that. We need to be using a GPU where possible. So we say, okay, if a CUDA enabled GPU is available, please use GPU.

Okay. So we can see here, there is one available. So we're on Colab. So that's great. It means everything will be much faster. And the reason why we need that is because here, we're going to need to move everything to that device. So what we'll do is here, as we use feature extractor here, we're going to say to device.

That will just move everything to GPU for us. Okay. And then we use with_transform to apply that to both the training and the testing dataset. Or in reality, we're going to be using the test dataset more as a validation dataset. Now, after all that, we're ready to move on to the model fine-tuning step.

So with this, there are a few things we're going to need to define. So training and testing data set, we've already done that. It's not a problem. Feature extractor, we have already done that as well. Not a problem. The model, we will define that. It's pretty easy. Something called a collate function, evaluation metric, and some other training arguments.

So let's start with the collate function. So here, this is essentially, when we're training with the Hugging Face Trainer, we need a way to collate all of our data into batches in a way that makes sense, which requires this dictionary format. So each record is represented by a dictionary, and each record contains inputs, which is the pixel values, and also the labels.
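A collate function along those lines might look like the following sketch. The key names pixel_values and labels match what the ViT classification model expects as inputs; the per-record key label is an assumption about how the preprocessed records are laid out.

```python
def collate_fn(batch):
    # Imported here so the sketch loads even where PyTorch is not installed.
    import torch

    return {
        # Stack per-image tensors (3, H, W) into one (batch, 3, H, W) tensor.
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        # Assumed record key "label"; collected into a 1-d tensor of class ids.
        "labels": torch.tensor([item["label"] for item in batch]),
    }
```

The Trainer calls this once per batch with a list of records and feeds the returned dictionary straight into the model.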

So we run this. We then need to define our evaluation metric, which I'm using accuracy, which is, you can read that if you want, but it's pretty straightforward. So we define that. And then we have all these training arguments. So these are essentially just the training parameters that we're going to use to actually train our model.
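Accuracy is simple enough to write by hand. This is a hand-rolled sketch rather than the exact metric code used here; it assumes the object passed in has predictions (one row of logits per example) and label_ids attributes, which is how the Trainer hands over evaluation results.

```python
from types import SimpleNamespace

def argmax(values):
    # Index of the largest value in a sequence.
    return max(range(len(values)), key=lambda i: values[i])

def compute_metrics(eval_pred):
    """Accuracy: the fraction of examples whose highest-scoring class
    matches the reference label."""
    predictions = [argmax(row) for row in eval_pred.predictions]
    labels = list(eval_pred.label_ids)
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}

# Toy stand-in for an evaluation result: 3 examples, 2 classes.
fake_eval = SimpleNamespace(
    predictions=[[0.1, 2.0], [1.5, 0.2], [0.3, 0.9]],
    label_ids=[1, 0, 0],
)
print(compute_metrics(fake_eval))  # two of the three predictions are correct
```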

So we have the batch size that we want to use, where we're going to output the model, the number of training epochs that we want to use, how often do we want to evaluate the model. So run it on the validation/test data set that we have, what learning rate do you want to use, and so on and so on.

Rerun that. That just sets up the configuration for our training. And then we move on to initializing our model. Again, this is just using the same thing that we had before. So when we had that feature extractor, we initialized it from pre-trained, and then we had the model name or path, model ID.

So that is just the ViT patch16-224 model that you saw before. One thing that we do need to add here is, because we're doing this ViT image classification, we need to specify the number of labels or classes that will be output from that classification head, which in this case is 10 of those labels.
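A sketch of that initialization is shown here, assuming the google/vit-base-patch16-224-in21k checkpoint and the ViTForImageClassification class from the transformers library. The helper name load_model and its device argument are made up for illustration.

```python
MODEL_ID = "google/vit-base-patch16-224-in21k"
NUM_LABELS = 10  # one output per CIFAR-10 class

def load_model(device="cpu"):
    # Needs `pip install transformers torch`; downloads weights on first call.
    from transformers import ViTForImageClassification

    model = ViTForImageClassification.from_pretrained(
        MODEL_ID,
        num_labels=NUM_LABELS,  # sizes the new classification head
    )
    # Move the model to the chosen device, e.g. "cuda" when a GPU is available.
    return model.to(device)
```

The pre-trained encoder weights are reused, while the classification head on top is freshly initialized with 10 outputs and learned during fine-tuning.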

So we define that as well. We move the model to our GPU, and with that, we are ready to initialize our trainer with all of those things that we just defined. So we run that, and then to actually train the model, we do this. So trainer.train. After that, we can save the model, we can log our metrics, save our metrics, and then just save the current state of the trainer at that point.

So I'm going to run that very briefly and then stop. Okay, so it seems we're getting this error, which I think might be because we're trying to move the input tensors to GPU twice. So I think the trainer is doing it by default, but earlier on we added the to(device) call, so we need to remove that and run it again.

So up here within preprocess, we just remove this, run it again, and then just rerun everything. Then pass everything to the trainer, and then try and train again. Okay, it looks like we're having a little more luck with it this time. So we can see that the model is training.
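After that fix, the preprocessing function might look roughly like this sketch. The column names img and label are assumptions about the dataset layout, and make_preprocess is a hypothetical wrapper that keeps the feature extractor as an explicit argument.

```python
def make_preprocess(feature_extractor):
    """Build a batch transform suitable for `dataset.with_transform(...)`."""
    def preprocess(batch):
        # No `.to(device)` here: tensors stay on the CPU at this stage and
        # the Trainer moves each batch to the GPU itself.
        inputs = feature_extractor(batch["img"], return_tensors="pt")
        inputs["label"] = batch["label"]
        return inputs
    return preprocess
```

It would then be applied with something like dataset.with_transform(make_preprocess(feature_extractor)) on both the training and the test split.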

Actually, it doesn't take too long, but what I'm going to do is just skip forward. So I'm going to stop this, and what we can do is you can run this to get your evaluation metrics and view your evaluation metrics. Your model will be evaluating as it goes through your training set, thanks to the trainer.

But if you would like to check again, you can just use this. But for now, let's just have a look at a specific example. So what we're going to do is load this image. I mean, I can't really tell what that image is. I think, so if we come down here, it should be a cat, yeah?

So run this, we can see that it's actually supposed to be a cat. It's very blurry, I can't personally tell. But what we're going to do is load a fine-tuned model. So this is the model that has been fine-tuned using this same process. So we can download that from Hugging Face Hub.

We can also download the feature extractor, which we don't need to do that because it is actually using the same feature extractor, but in a real use case scenario, you might actually just download everything from a particular model that is hosted within the Hugging Face Hub. So this is what you would do, because it's not really fine-tuned.

So run that. That will just download the fine-tuned model. And you can see here, we have the exact same feature extractor configuration there. We process our image through the feature extractor, return PyTorch tensors, and then we say with torch.no_grad(), which is essentially to make sure that we're not calculating gradients for the model like we would during fine-tuning, because we're actually just making a prediction here.

We don't want to train anything. We use the model, process the inputs, and we extract the logits, which are just the output activations. And what we want to do is take the argmax, so the position where the logits have their maximum value, which is basically the class with the highest predicted probability.
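The argmax step can be sketched in plain Python. The logits here are made-up numbers for one image, not real model output, and the softmax is only included to show that the largest logit is also the class with the highest probability.

```python
import math

CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

def predict_label(logits):
    """Return the most likely class name and its softmax probability."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Softmax keeps the ordering, so argmax over logits is enough.
    best = max(range(len(logits)), key=lambda i: logits[i])
    return CLASS_NAMES[best], probs[best]

# Made-up logits for one image; index 3 ("cat") is the largest.
logits = [-1.2, 0.3, 0.1, 4.7, 0.0, 2.1, -0.5, 0.2, -2.0, 0.4]
name, prob = predict_label(logits)
print(name)  # cat
```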

So we extract that, we get the labels, and then we output labels. And if we run that, we will see that we get cat. Okay, so it looks like we have fine-tuned a vision transformer using that same process, and the performance is pretty accurate. Now, before 2021, which is really not that long ago, transformers were known as just being those language models that were not used in anything else.

But now, as we can see, we're actually able to use transformers and get really good results within the field of computer vision. And we're actually seeing this used in a lot of places. The vision transformer is a key component of OpenAI's CLIP model, and CLIP is a key component of many of the diffusion models that we've seen pop up everywhere, and the world is going crazy over them right now.

Transformers are also a key component in Tesla for self-driving. They are finding use in a huge number of places that would have just been incredibly unexpected a year or even two, three years ago. And I think as time progresses, we will undoubtedly see more use of transformers within computer vision, and of course, the continued use of transformers within the field of language.

And they will undoubtedly become more and more unified over time. For now, that's it for this video. I hope all of this has been useful and interesting. So, thank you very much for watching, and I'll see you again in the next one. Bye.