
Vision Transformers (ViT) Explained + Fine-tuning in Python


Chapters

0:00 Intro
0:58 In this video
1:12 What are transformers and attention?
1:39 Attention explained simply
4:15 Attention used in CNNs
5:24 Transformers and attention
7:01 What the vision transformer (ViT) does differently
7:28 Images to patch embeddings
8:22 1. Building image patches
10:23 2. Linear projection
10:57 3. Learnable class embedding
13:30 4. Adding positional embeddings
16:37 ViT implementation in Python with Hugging Face
16:45 Packages, dataset, and Colab GPU
18:42 Initialize Hugging Face ViT Feature Extractor
22:48 Hugging Face Trainer setup
25:14 Training and CUDA device error
26:27 Evaluation and classification predictions with ViT
28:54 Final thoughts

Whisper Transcript | Transcript Only Page

00:00:00.000 | Vision and language are the two big domains in machine learning: two distinct disciplines,
00:00:06.320 | each with its own problems, best practices, and model architectures. Or at least, that used to be the case.
00:00:14.240 | The vision transformer, or ViT, marks the first step towards a merger of both fields into a single
00:00:24.240 | unified discipline. For the first time in the history of machine learning, we have a single
00:00:31.920 | model architecture that is on track to become the dominant model in both language and vision.
00:00:39.760 | Before the vision transformer, transformers were known as language models and nothing more.
00:00:47.040 | But since the introduction of the vision transformer, there has been further work
00:00:52.320 | that has almost solidified their position as the state of the art in vision. In this video, we're
00:00:59.840 | going to dive into the vision transformer. We're going to explain what it is, how it works, and why it
00:01:07.360 | works, and we're going to look at how we can actually use it and implement it in Python.
00:01:12.400 | So let's get started with a very quick introduction to transformers and the attention
00:01:20.160 | mechanism. Transformers were introduced in 2017 in the well-known paper
00:01:27.840 | "Attention Is All You Need". Transformers quite literally changed the entire landscape of NLP,
00:01:34.640 | and this was very much thanks to something called the attention mechanism. In NLP, attention
00:01:43.360 | allows us to embed contextual meaning into the word-level or sub-word-level token embeddings
00:01:53.280 | within a model. What I mean by that is: say you have a sentence, and you have two words in
00:01:59.840 | that sentence that are closely related. Attention allows the model to identify that relationship and then
00:02:07.840 | understand those words with respect to one another within the wider
00:02:14.960 | sentence. Within the transformer model, this starts with tokens simply being embedded into a
00:02:22.640 | vector space based purely on what each token is. So the token for "bank" will be mapped to a particular
00:02:30.480 | vector that represents the word "bank" without any consideration of the words surrounding it.
00:02:38.000 | With these token embeddings, we can calculate dot products between them,
00:02:44.480 | which return a high score when two embeddings are aligned and a low score when they are not.
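The dot-product scoring described here can be sketched in a few lines of plain Python. This is a minimal, single-head version with stated simplifications: it uses the raw token embeddings directly as queries, keys, and values (real transformers first apply learned projection matrices and use several heads), it scales scores by the square root of the embedding size as in "Attention Is All You Need", and the toy two-dimensional vectors are made up for illustration.

```python
import math


def softmax(scores):
    """Turn a list of raw scores into weights that sum to 1."""
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(embeddings):
    """Single-head scaled dot-product self-attention without learned projections.

    Returns (outputs, weights): one contextualised vector per input token, and the
    attention weights each token assigns to every token (including itself).
    """
    dim = len(embeddings[0])
    outputs, weights = [], []
    for query in embeddings:
        # Dot product with every token: large when the two vectors are aligned.
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(dim) for key in embeddings]
        row = softmax(scores)
        weights.append(row)
        # The new vector is a weighted average of all token vectors.
        outputs.append([sum(w * vec[i] for w, vec in zip(row, embeddings)) for i in range(dim)])
    return outputs, weights


if __name__ == "__main__":
    # Hypothetical toy embeddings: the first two point in similar directions.
    toy = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
    _, attn = self_attention(toy)
    for row in attn:
        print([round(w, 3) for w in row])
```

With these toy vectors, the first token puts more weight on the second (aligned) token than on the third (orthogonal) one, which is the "pulling together" effect described in the video.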
00:02:51.360 | By doing this within the attention mechanism, we can essentially identify which words should
00:02:58.080 | be placed closer together within that vector space. For example, take three sentences:
00:03:04.880 | "a plane banks", "the grassy bank", and "the Bank of England". The initial embedding of the token "bank"
00:03:12.960 | is the same for all of those sentences. But then, through the encoder's attention mechanism, we
00:03:21.040 | essentially move the "bank" token embedding closer to the region of the vector space occupied by the other relevant words
00:03:28.720 | within each of those sentences. In the case of "a plane banks", the word
00:03:36.480 | "bank" (or "banks") is moved closer towards words like aeroplane, plane, airport, flight, and so on.
00:03:44.800 | For the phrase "a grassy bank", the token embedding for "bank" gets moved towards
00:03:51.200 | the embedding space for grass, nature, and fields. And for "the Bank of England", the
00:03:58.080 | word "bank" gets moved towards finance, money, and so on. So as we go through the many encoder
00:04:05.680 | blocks, each containing the attention mechanism, we are essentially embedding more contextual meaning
00:04:12.000 | into those initial token embeddings. Now, attention did occasionally find use
00:04:19.840 | in convolutional neural networks, which were the previous state of the art in computer vision.
00:04:25.600 | Generally speaking, this produced some benefit, but it is somewhat limited. Attention is
00:04:32.560 | a heavy operation when there are a large number of items to compare, because essentially
00:04:39.920 | with attention you're comparing every item against every other item within your input sequence. So if
00:04:45.920 | your input sequence is a relatively large image and you're comparing pixels to pixels with
00:04:53.680 | your attention mechanism, the number of comparisons that you need to do becomes incredibly large very
00:05:01.040 | quickly. So in convolutional neural networks, attention can only really be
00:05:06.880 | applied in the later layers of the model, where there are fewer activations to be
00:05:14.640 | compared after a few convolutions. That's better than nothing, but it does limit the use
00:05:20.880 | of attention because you can't use it throughout the entire network. Transformer models in NLP
00:05:26.960 | have not had that limitation and can instead apply attention over many layers, literally from the very
00:05:34.240 | starting point of the model. The setup used by BERT, which is again a very well-known
00:05:40.960 | transformer model, involves several encoder layers. Within these encoder layers, or encoder blocks,
00:05:47.280 | we have a few different things going on. There is a normalization component; a multi-head attention
00:05:54.400 | component, which is essentially many attention operations happening in parallel; and a multi-layer
00:06:02.000 | perceptron layer. Through each of these blocks, we encode more and more information
00:06:07.120 | into our token embeddings, and at the end of this process we get these super-rich vector embeddings.
00:06:14.800 | These embeddings are the ultimate output of the core of a transformer model, including the vision
00:06:22.000 | transformer. From there, what we tend to find with transformer models is that we add a few more
00:06:29.120 | layers onto the end, which act as the head of the transformer. The head takes these
00:06:36.400 | information-rich vector embeddings and translates them into predictions for a particular
00:06:43.680 | task. So you might have a classification head, an NER head, or a question-answering head, and they
00:06:51.120 | will all be slightly different in some way, but at their core they all translate those
00:06:57.200 | information-rich embeddings into some sort of meaningful prediction. Now, the vision
00:07:02.960 | transformer works in exactly the same way; the only difference is how we preprocess things
00:07:09.600 | before they are fed into the vision transformer. Whereas BERT and other language
00:07:15.920 | transformers consume word or sub-word tokens, the vision transformer consumes image patches.
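To make the encoder block concrete, here is a toy, dependency-free sketch of one block in the pre-norm arrangement the ViT paper uses: LayerNorm, then attention, then a residual connection; LayerNorm, then an MLP, then another residual connection. It is a simplification, not the Hugging Face implementation: a single attention head with no learned query/key/value projections, no biases or LayerNorm scale/shift parameters, and small random MLP weights standing in for trained ones.

```python
import math
import random


def layer_norm(x, eps=1e-5):
    """Normalise one vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def attention(xs):
    """Single-head scaled dot-product self-attention, no learned projections."""
    dim = len(xs[0])
    out = []
    for q in xs:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) for k in xs]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        out.append([sum(e / total * v[i] for e, v in zip(exps, xs)) for i in range(dim)])
    return out


def gelu(v):
    """tanh approximation of the GELU activation."""
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))


def mlp(x, w_in, w_out):
    """Two linear layers with a GELU in between (biases omitted)."""
    hidden = [gelu(sum(w * xi for w, xi in zip(row, x))) for row in w_in]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w_out]


def add(a, b):
    return [x + y for x, y in zip(a, b)]


def encoder_block(tokens, w_in, w_out):
    """One pre-norm encoder block; the output has the same shape as the input."""
    attended = attention([layer_norm(t) for t in tokens])
    tokens = [add(t, a) for t, a in zip(tokens, attended)]              # residual 1
    return [add(t, mlp(layer_norm(t), w_in, w_out)) for t in tokens]    # residual 2


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    dim, hidden = 4, 16  # the MLP expands to `hidden`, then projects back to `dim`
    tokens = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(5)]
    w_in, w_out = random_matrix(hidden, dim, rng), random_matrix(dim, hidden, rng)
    # Reusing the same toy weights for brevity; real models give each block its own.
    for _ in range(3):
        tokens = encoder_block(tokens, w_in, w_out)
    print(len(tokens), len(tokens[0]))
```

Because each block maps a sequence of vectors to a sequence of the same shape, blocks can be stacked as deep as needed, and a task head only has to read the final vectors.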
00:07:24.080 | The rest of the transformer works in exactly the same way. So let's take a look at how we go from
00:07:30.720 | images to image patches, and then from there to patch embeddings. The high-level process for doing
00:07:39.120 | this is relatively simple. First, we split an image into image patches. Second, we process those patches
00:07:46.720 | through a linear projection layer to get our initial patch embeddings. Third, we prepend something
00:07:53.040 | called a class embedding to those patch embeddings, and finally, we sum the patch embeddings and
00:07:59.440 | something called positional embeddings. There are a lot of parallels between this process and what we
00:08:04.560 | see in language, and we'll relate to those where relevant. So after all these steps we have our
00:08:12.240 | patch embeddings and we just process them in the exact same way that we would token embeddings with
00:08:17.280 | a language transformer. But let's dive into each one of these steps in a little more detail. Our
00:08:22.400 | first step is the transformation of our image into image patches. In NLP we actually do the same
00:08:30.960 | thing. We take a sentence and we translate it into a list of tokens. So in this respect images are
00:08:39.040 | sentences and image patches are word or sub-word tokens. Now if we didn't create these image
00:08:46.720 | patches, we could alternatively feed in the full set of pixels from an image. But as I mentioned
00:08:53.760 | before, that basically makes it impractical to use attention, because the number
00:09:00.240 | of computations we would need to compare all pixels would severely restrict the size of
00:09:06.880 | images that we could input. We could essentially only input very, very small images. So
00:09:11.360 | consider that attention requires the comparison of everything to everything else,
00:09:18.560 | and we're using pixels here. A 224 by 224 pixel image has 50,176 pixels, which means we would have to perform
00:09:27.520 | 50,176 squared, or 224 to the power of 4, comparisons. That is about 2.5 billion comparisons, which is pretty insane, and
00:09:38.240 | that's for a single attention layer. In transformers, we have multiple attention layers.
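These counts can be checked with a little arithmetic. The sketch below treats each input (a pixel, or a square patch) as one item and counts one comparison per ordered pair, which is how the figures in the video are computed. It also covers the 14 by 14 pixel patch split that the video turns to next, plus the 16 by 16 patches of the checkpoint used later on.

```python
def attention_cost(image_size, patch_size=1):
    """Return (num_inputs, pairwise_comparisons) for one attention layer.

    Each non-overlapping patch_size x patch_size square of a square image is one
    input; patch_size=1 means attending over raw pixels.
    """
    if image_size % patch_size:
        raise ValueError("patch_size must divide image_size exactly")
    num_inputs = (image_size // patch_size) ** 2
    return num_inputs, num_inputs ** 2  # every input is compared with every input


if __name__ == "__main__":
    for patch in (1, 14, 16):
        n, pairs = attention_cost(224, patch)
        print(f"patch {patch:>2}: {n:>6,} inputs -> {pairs:>13,} comparisons")
```

Raw pixels give 50,176 inputs and 224⁴ = 2,517,630,976 comparisons; 14 by 14 patches give 256 inputs and 65,536 comparisons; 16 by 16 patches give 196 inputs and 38,416 comparisons.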
00:09:43.520 | So it's already just far too much. If instead we split our 224 by 224 pixel image into image patches
00:09:54.480 | of 14 by 14 pixels, that would leave us with 16 by 16, or 256, of these patches. With that,
00:10:04.960 | a single attention layer requires a much more manageable 256 squared, or 65,536, comparisons, which is
00:10:13.680 | a lot easier to do. With that, we can have a huge number of attention layers and still not even get
00:10:19.920 | close to the cost of a single attention layer over our full image. Now, after building these image patches, we
00:10:26.560 | move on to the linear projection step. For this we use a linear projection layer, which simply
00:10:32.000 | maps our image patch arrays into image patch vectors. By mapping these patches to
00:10:42.720 | patch embeddings, we are reformatting them into the correct dimensionality to be input into
00:10:51.680 | our vision transformer. But we're not putting these into the vision transformer just yet, because
00:10:56.560 | there are two more steps. Our third step is the learnable embedding, or the class token. This
00:11:04.640 | is an idea that comes from BERT. BERT introduced the use of something called a CLS, or classifier,
00:11:13.600 | token. The CLS token was a special token prepended to every sentence that was input
00:11:21.040 | into BERT. Like every other token, this CLS token was converted into an embedding and passed
00:11:28.160 | through several encoder layers. There are two things that make CLS special. First, it does not
00:11:35.040 | represent a real word, so it acts almost like a blank slate being input into the model. And second,
00:11:44.240 | the CLS token embedding, after the many encoder blocks, is the embedding that is fed into the
00:11:54.880 | classification head used as part of the pre-training process. So essentially,
00:12:02.080 | we end up embedding a general representation of the full sentence into
00:12:09.200 | this single token embedding. In order for the model to make a good prediction about what
00:12:16.000 | the sentence is, it needs a general embedding of the whole sentence in that single
00:12:22.400 | token, because only that single token embedding is passed into the classification
00:12:30.000 | head. The vision transformer applies the same logic: it prepends something called a learnable
00:12:36.320 | embedding, or class embedding, to the patch embeddings before they are processed by the first layer of the
00:12:44.080 | model. This learnable embedding is practically the same thing as the CLS token in BERT. It's
00:12:50.800 | also worth noting that it is potentially even more important for the vision transformer than it is
00:12:57.520 | for BERT. For BERT, the main mode of pre-training is something called masked language
00:13:03.200 | modeling, which doesn't rely on the classification token, whereas for the vision transformer the
00:13:10.880 | main mode of pre-training is actually a classification task. So in that sense, we can
00:13:18.640 | think of this CLS token, or CLS embedding, as being critical for the overall
00:13:26.800 | performance and training of the vision transformer. The final step that we need to
00:13:32.960 | apply to our patch embeddings before they are actually fed into the model is to add
00:13:40.800 | something called positional embeddings. Positional embeddings are commonly
00:13:47.680 | added in transformers, because transformers by default don't actually have
00:13:54.000 | any mechanism for tracking the position of inputs, so no order is being considered. And
00:14:03.680 | that is a problem because, in language and also in vision (but let's think in
00:14:11.200 | terms of language for now), the order of words in a sentence is incredibly important. If you mix
00:14:17.280 | up the order of words, it's hard for a person to understand what the sentence is supposed to mean.
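The claim that transformers have no built-in sense of order can be checked directly. In this sketch (plain self-attention without learned projections, a simplifying assumption), shuffling the inputs merely shuffles the outputs in the same way, so the layer cannot tell one ordering from another. Once a distinct position vector (hypothetical values here) is added to each slot, the same shuffle changes the result.

```python
import math


def self_attention(xs):
    """Single-head scaled dot-product self-attention with no learned projections."""
    dim = len(xs[0])
    out = []
    for q in xs:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) for k in xs]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        out.append([sum(e / total * v[i] for e, v in zip(exps, xs)) for i in range(dim)])
    return out


def shuffle(xs, order):
    return [xs[i] for i in order]


def add_positions(xs, positions):
    """Add one position vector per slot (slot 0 gets positions[0], and so on)."""
    return [[a + b for a, b in zip(x, p)] for x, p in zip(xs, positions)]


def same(a, b, tol=1e-9):
    return all(abs(x - y) <= tol for ra, rb in zip(a, b) for x, y in zip(ra, rb))


tokens = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]     # toy token/patch embeddings
positions = [[0.0, 0.3], [0.3, 0.0], [0.6, 0.6]]  # hypothetical position vectors
order = [2, 0, 1]

# Without positions: attention(shuffled input) equals shuffled(attention(input)).
no_pos_equivariant = same(self_attention(shuffle(tokens, order)),
                          shuffle(self_attention(tokens), order))

# With a position vector added per slot, the shuffle is no longer invisible.
with_pos_equivariant = same(self_attention(add_positions(shuffle(tokens, order), positions)),
                            shuffle(self_attention(add_positions(tokens, positions)), order))

print(no_pos_equivariant, with_pos_equivariant)
```

The first check comes out true and the second false: without positional information, attention treats its input as an unordered set.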
00:14:23.760 | And it can even mean something completely different. So the order of words is clearly
00:14:28.320 | super important, and that applies to images as well. If we start mixing up the image patches,
00:14:35.120 | there's a good chance that we won't be able to understand what that image represents anymore.
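Putting the four steps together (patches, linear projection, class embedding, positional embeddings), a toy version of ViT's input pipeline might look like the sketch below. It is written in plain Python for readability; the tiny 8 by 8 image, the 4-pixel patches, the embedding size of 6, and the random initial values are illustrative assumptions. In a real model, the projection weights, class embedding, and positional embeddings are learned during training.

```python
import random


def image_to_patches(image, patch):
    """Split an H x W x C image (nested lists) into flattened patches, row by row."""
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            patches.append([value
                            for row in image[top:top + patch]
                            for pixel in row[left:left + patch]
                            for value in pixel])
    return patches


def init_params(patch_len, num_patches, dim, rng):
    """Random stand-ins for parameters that training would learn."""
    def vec(n):
        return [rng.gauss(0.0, 0.02) for _ in range(n)]
    return {
        "proj_w": [vec(patch_len) for _ in range(dim)],      # linear projection (dim x patch_len)
        "proj_b": [0.0] * dim,
        "cls": vec(dim),                                      # learnable class embedding
        "pos": [vec(dim) for _ in range(num_patches + 1)],    # +1 slot for the class embedding
    }


def embed_image(image, patch, params):
    patches = image_to_patches(image, patch)                          # 1. patches
    tokens = [[sum(w * x for w, x in zip(row, p)) + b                  # 2. linear projection
               for row, b in zip(params["proj_w"], params["proj_b"])]
              for p in patches]
    tokens = [params["cls"]] + tokens                                  # 3. prepend class embedding
    return [[t + q for t, q in zip(tok, pos)]                          # 4. add positional embeddings
            for tok, pos in zip(tokens, params["pos"])]


if __name__ == "__main__":
    rng = random.Random(0)
    size, patch, channels, dim = 8, 4, 3, 6
    image = [[[rng.random() for _ in range(channels)] for _ in range(size)] for _ in range(size)]
    num_patches = (size // patch) ** 2
    params = init_params(patch * patch * channels, num_patches, dim, rng)
    sequence = embed_image(image, patch, params)
    print(len(sequence), "tokens of size", len(sequence[0]))  # 1 class token + 4 patch tokens
```

With 224 by 224 images and 16 by 16 patches instead, the same steps give 196 patch embeddings plus one class embedding, so 197 vectors enter the encoder.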
00:14:40.080 | And in fact this is what we get with jigsaw puzzles. We get a ton of little image patches
00:14:47.040 | and we need to put them together in a certain order. And it takes people a long time to figure
00:14:50.640 | out what that order actually is. So the order of our image patches is obviously quite important,
00:14:58.240 | but by default transformers don't have a way of handling this. So that's where the positional
00:15:04.000 | embeddings come in. For the vision transformer, these positional embeddings are learned embeddings
00:15:10.080 | that are summed with the incoming patch embeddings. Now, as I mentioned, these positional
00:15:17.600 | embeddings are learned. So during pre-training they are adjusted, and what we can see
00:15:23.680 | if we visualize the cosine similarity between these embeddings is that positional
00:15:30.480 | embeddings for patches that are close to one another have a higher similarity. In particular,
00:15:36.320 | positional embeddings for patches in the same row or the same column as one another also have
00:15:42.720 | a higher similarity. So there seems to be a logical structure to these
00:15:50.480 | positional embeddings: patches that are within a similar area of the image are pushed
00:15:56.320 | into a similar region of the vector space, and patches that are in dissimilar areas are pushed away from each
00:16:02.880 | other within that vector space. So there's a sense of locality being introduced within these
00:16:08.400 | positional embeddings. Now, after adding our positional embeddings and patch embeddings together,
00:16:14.800 | we have our final patch embeddings, which are then fed into our vision transformer and they're
00:16:20.880 | processed through that sort of encoder attention mechanism that we described before, which is just
00:16:26.160 | a typical transformer approach. So that is the logic behind the vision transformer and the
00:16:34.320 | innovations that it has brought. Now I want to go through an example
00:16:40.720 | implementation of the vision transformer and how we can actually use it. Okay, so we start
00:16:47.200 | by installing the prerequisites. So here we've got pip install datasets,
00:16:53.760 | transformers, and also PyTorch. We run this, and then what we want to do is download a dataset that
00:17:02.400 | we can test all of this on and also fine-tune with. We're going to be using the CIFAR-10
00:17:08.880 | dataset, which we'll get from Hugging Face Datasets: from datasets import load_dataset.
00:17:14.560 | Let this run. One thing to check before we go through
00:17:20.560 | everything is that we're using a GPU runtime. Save that, and we will have to rerun everything.
00:17:28.320 | Okay, so after that's downloaded, we'll see that we have 50,000 images with classification labels
00:17:33.440 | within our training data. And we also download the test split as well. That has 10,000 of these.
00:17:41.120 | Then we want to have a quick look at the labels. So let's see
00:17:46.720 | what we have in there. We have 10 labels; that's why it's called CIFAR-10.
00:17:53.200 | And of those 10 labels, we have these particular classes
00:18:00.400 | within the dataset: airplane, automobile, and so on. From there, we can have a look at
00:18:06.240 | what is within a single item of that dataset. We have this PIL object, which is
00:18:13.120 | essentially an image, and then also the label. That label corresponds to airplane in this
00:18:20.480 | case, because it's number zero. And we can check that. So run this, and here it is. We can't
00:18:28.240 | really see it very well. It's very small, but that is an airplane. And we can map the label
00:18:35.360 | zero through labels.names to get the human-readable class label. Okay, cool. So
00:18:43.200 | what we're going to do is load the vision transformer feature extractor.
00:18:48.720 | We're going to be using this model here from the Hugging Face Hub, and we can see it over here:
00:18:55.120 | google/vit-base-patch16-224-in21k. What that means is that we have patches that
00:19:06.560 | are 16 by 16 pixels, extracted (at least during pre-training) from
00:19:14.720 | 224 by 224 pixel images. And "in21k" just says that this model has been pre-trained
00:19:25.040 | on the ImageNet-21k dataset. So that is the model we'll be using. And we use this feature extractor,
00:19:32.720 | which is almost like a preprocessor for this particular model. So we can run that, and this
00:19:38.560 | will just download that feature extractor for us. That's pretty quick. And we can see the
00:19:44.480 | configuration within that feature extractor here. So what is this feature extractor doing exactly?
00:19:50.080 | It is taking an image. Our image can be any size and in a lot of different formats. And what it's
00:19:58.080 | going to do is just normalize and resize that image into something that we can then process
00:20:03.200 | with our vision transformer. So we can see here that it will normalize the pixel values within
00:20:09.360 | the image, and it will resize the image as well, to 224 by 224
00:20:17.680 | pixels. For normalization, it uses these mean and standard deviation values for each of the color
00:20:23.040 | channels: red, green, and blue. And that's pretty much what it's going
00:20:31.520 | to be doing. So if we take a look at the first image, we can use the feature extractor
00:20:36.800 | on that image, which is the plane. And we're going to return
00:20:42.240 | PyTorch tensors, because we'll be using PyTorch later on. So we run this, and what we get back is
00:20:50.800 | a dictionary containing a single key-value pair, pixel_values,
00:20:55.760 | which maps to this tensor here. We can go down and have a look at the shape of
00:21:02.160 | that, and we see that we have a 224 by 224 pixel values tensor. That is
00:21:12.560 | different from the original image, which was train[0]["img"].
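What the feature extractor does to each image can be approximated in plain Python as follows. This is an assumption-laden sketch, not the library's code: it assumes a per-channel mean and standard deviation of 0.5 (what I would expect this checkpoint's configuration to show; check the printed config), uses nearest-neighbour upsampling in place of the library's own resampling filter, and omits the batch dimension.

```python
def resize_nearest(image, size):
    """Nearest-neighbour resize of an H x W x C image (nested lists) to size x size."""
    height, width = len(image), len(image[0])
    return [[image[r * height // size][c * width // size] for c in range(size)]
            for r in range(size)]


def normalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Rescale 0-255 values to [0, 1], then standardise each colour channel."""
    return [[[(value / 255 - m) / s for value, m, s in zip(pixel, mean, std)]
             for pixel in row] for row in image]


def preprocess(image, size=224):
    """Resize + normalise, returning channels-first [channel][row][col] like pixel_values."""
    pixels = normalize(resize_nearest(image, size))
    channels = len(pixels[0][0])
    return [[[pixel[ch] for pixel in row] for row in pixels] for ch in range(channels)]


if __name__ == "__main__":
    # A hypothetical 32 x 32 RGB image: left half black, right half white.
    tiny = [[[0, 0, 0] if c < 16 else [255, 255, 255] for c in range(32)] for _ in range(32)]
    out = preprocess(tiny)
    print(len(out), len(out[0]), len(out[0][0]))  # 3 224 224
    print(out[0][0][0], out[0][0][-1])            # -1.0 1.0
```

With a mean and standard deviation of 0.5, pixel values end up in the range [-1, 1], and a 32 by 32 CIFAR-10 image is upscaled by a factor of 7 to reach 224 by 224.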
00:21:22.560 | What's the shape of this? I think we can check its size attribute.
00:21:35.120 | Okay. 32 by 32. So it's been resized up to 224 by 224, which is the format that the
00:21:43.200 | vision transformer needs. Now, when we are doing this, what we're going to want to do later on
00:21:49.280 | is train everything on GPU, not CPU. By default, this tensor is on
00:21:56.480 | CPU. We don't want that; we need to be using a GPU where possible. So we say, if a
00:22:03.600 | CUDA-enabled GPU is available, use it. We can see here that there is one available,
00:22:10.000 | since we're on Colab, which is great: it means everything will be much faster. The reason
00:22:16.720 | we need that is that we're going to need to move everything to that device.
00:22:23.040 | So when we use the feature extractor here, we add .to(device).
00:22:32.160 | That will move everything to the GPU for us. Then we use with_transform to
00:22:38.240 | apply that to both the training and the testing datasets. In reality, we're going to be using the
00:22:43.280 | test dataset more as a validation dataset. Now, after all that, we're ready to move on to
00:22:49.680 | the model fine tuning step. So with this, there are a few things we're going to need to define.
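The preprocessing transform described just before this point could be written along these lines. This is a sketch under assumptions: it uses the CIFAR-10 column names "img" and "label", takes the feature extractor as an argument so it can be swapped or stubbed, and deliberately leaves out the .to(device) call, since the Hugging Face Trainer moves each batch to the model's device itself (the video also removes that call later, after hitting an error).

```python
def make_transform(feature_extractor):
    """Build a batch transform for `datasets.Dataset.with_transform`.

    `feature_extractor` is any callable with the Hugging Face feature-extractor
    calling convention, e.g. ViTFeatureExtractor.from_pretrained(model_id)
    (newer transformers releases call the class ViTImageProcessor).
    """
    def transform(batch):
        # `batch` is a dict of columns; "img" holds a list of PIL images.
        inputs = feature_extractor(list(batch["img"]), return_tensors="pt")
        # No .to(device) here: the Trainer handles device placement.
        inputs["labels"] = batch["label"]
        return inputs

    return transform
```

Applied to a split, this would look like `prepared = dataset_split.with_transform(make_transform(feature_extractor))`, where `dataset_split` is a hypothetical name for the loaded training or test split. Because `with_transform` runs lazily on each batch as it is read, no preprocessed copy of the dataset is stored.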
00:22:55.760 | So training and testing data set, we've already done that. It's not a problem. Feature extractor,
00:23:00.000 | we have already done that as well. Not a problem. The model, we will define that. It's pretty easy.
00:23:05.120 | We also need a collate function, an evaluation metric, and some other training arguments. So
00:23:11.520 | let's start with the collate function. When we're training with
00:23:18.000 | the Hugging Face Trainer, we need a way to collate all of our data into batches in a way that makes
00:23:25.600 | sense, which requires this dictionary format: each record is represented by a dictionary,
00:23:33.120 | and each record contains the inputs, which are the pixel values, and also the labels. So we run this.
00:23:39.520 | We then need to define our evaluation metric; I'm using accuracy.
00:23:45.680 | You can read the code if you want, but it's pretty straightforward. So we define that.
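The accuracy metric's code isn't read out in the video, so here is a dependency-free sketch of a compute_metrics function with the shape the Trainer expects: it receives an object with .predictions (the logits) and .label_ids, and returns a dict of named metrics. Taking the argmax over each row of logits gives the predicted class.

```python
def compute_metrics(eval_pred):
    """Accuracy from a Trainer EvalPrediction (uses .predictions and .label_ids)."""
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    # argmax per row: the index of the largest logit is the predicted class
    preds = [max(range(len(row)), key=lambda i: row[i]) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return {"accuracy": correct / len(labels)}
```

It would be passed to the trainer as `Trainer(..., compute_metrics=compute_metrics)`. It only indexes and iterates, so it works whether the logits arrive as nested lists or as a NumPy array.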
00:23:55.360 | And then we have all these training arguments. So these are essentially just the training
00:24:00.800 | parameters that we're going to use to actually train our model. So we have the batch size that
00:24:06.960 | we want to use, where we're going to output the model, the number of training epochs that we want
00:24:12.880 | to use, how often we want to evaluate the model (that is, run it on the validation/test dataset
00:24:20.160 | that we have), what learning rate we want to use, and so on. Run that. That just
00:24:25.840 | sets up the configuration for our training. Then we move on to initializing our model. Again,
00:24:34.240 | this uses the same approach as before: when we had that feature extractor,
00:24:39.680 | we initialized it with from_pretrained and the model name or path, the model ID.
00:24:47.040 | That is just the vit-base-patch16-224-in21k model that you saw before. One thing that we do need to add here is,
00:24:55.360 | because we're doing ViT image classification, we need to specify the number of labels or
00:25:01.680 | classes that will be output from the classification head, which in this case is 10.
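Specifying the ten output classes can be done by passing num_labels to from_pretrained, as in the video. A common optional addition, not shown in the video, is to also pass id2label and label2id so the saved model config carries human-readable class names. The sketch builds those arguments from the CIFAR-10 class list, whose order matches the dataset's label indices (0 is airplane, as seen earlier); `labels.names` from the dataset would work equally well as the input list.

```python
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def classification_head_kwargs(class_names):
    """Keyword arguments describing a classification head's output labels."""
    return {
        "num_labels": len(class_names),
        "id2label": {i: name for i, name in enumerate(class_names)},
        "label2id": {name: i for i, name in enumerate(class_names)},
    }


# Intended use (needs `transformers` and a model download, so left as a comment):
# model = ViTForImageClassification.from_pretrained(
#     "google/vit-base-patch16-224-in21k", **classification_head_kwargs(CIFAR10_CLASSES))
```

Keeping the mapping in the config means a model downloaded later from the Hub can turn its argmax index straight into a class name without access to the original dataset.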
00:25:08.400 | So we define that as well. We move the model to our GPU, and with that, we are ready to initialize
00:25:16.400 | our trainer with all of those things that we just defined. So we run that, and then to actually
00:25:23.680 | train the model, we call trainer.train(). After that, we can save the model, we can log
00:25:30.800 | our metrics, save our metrics, and then just save the current state of the trainer at that point.
00:25:37.360 | So I'm going to run that very briefly and then stop. Okay, so it seems we're getting this error,
00:25:43.280 | which I think might be because we're trying to move the input tensors to the GPU twice. I think
00:25:50.960 | the trainer does it by default, but earlier on we added .to(device), so we need to remove
00:25:56.640 | that and run it again. So up here, within the preprocessing function, we just remove this, run it again,
00:26:05.200 | and then rerun everything. Then we pass everything to the trainer and try to
00:26:11.600 | train again. Okay, it looks like we're having a little more luck this time. We can see
00:26:17.200 | that the model is training. Actually, it doesn't take too long, but what I'm going to do is just
00:26:23.440 | skip forward. So I'm going to stop this. You can run this to get and view your
00:26:31.600 | evaluation metrics. Your model will be evaluated periodically as it goes
00:26:37.520 | through training, thanks to the trainer, but if you would like to check again, you can just
00:26:45.040 | use this. For now, let's just have a look at a specific example. So what we're going to do is
00:26:50.800 | load this image. I mean, I can't really tell what that image is. If we come down here,
00:26:58.400 | it should be a cat, yeah? So run this, and we can see that it's actually supposed to be a cat.
00:27:03.120 | It's very blurry, I can't personally tell. But what we're going to do is load a fine-tuned model.
00:27:09.120 | So this is the model that has been fine-tuned using this same process. So we can download that
00:27:16.000 | from the Hugging Face Hub. We can also download the feature extractor, which we don't strictly need to do,
00:27:23.840 | because it uses the same feature extractor (the feature extractor itself isn't fine-tuned), but in a real use case,
00:27:31.280 | you would probably just download everything from a particular model that is hosted on
00:27:36.480 | the Hugging Face Hub, so this is what you would do. So run that.
00:27:44.640 | That will just download the fine-tuned model.
00:27:49.120 | And you can see here that we have the exact same feature extractor configuration. We process
00:27:59.360 | our image through the feature extractor, return PyTorch tensors, and then we use
00:28:04.160 | with torch.no_grad(), which essentially makes sure that we're not tracking gradients
00:28:09.760 | for the model like we would during fine-tuning, because we're just making a prediction
00:28:14.560 | here. We don't want to train anything. We pass the inputs through the model and extract the
00:28:21.840 | logits, which are just the output activations. Then we take the argmax: the
00:28:28.480 | position of the maximum logit is the class with the highest predicted probability.
00:28:35.280 | So we extract that, map it to its label name, and output the label. If we run that,
00:28:41.680 | we see that we get cat. Okay, so it looks like we have fine-tuned a vision transformer
00:28:49.040 | using that same process, and the predictions look pretty accurate. Now, before 2021, which really
00:28:57.760 | is not that long ago, transformers were known only as language models and were not
00:29:02.960 | used for anything else. But now, as we can see, we're able to use transformers and get
00:29:09.280 | really good results within the field of computer vision. And we're seeing them used in a
00:29:15.920 | lot of places. The vision transformer is a key component of OpenAI's CLIP model, and
00:29:22.560 | CLIP is a key component of many of the diffusion models that we've seen pop up everywhere, which
00:29:28.400 | the world is going crazy over right now. Transformers are also a key component of Tesla's
00:29:35.360 | self-driving system. They are finding use in a huge number of places where they would have been
00:29:43.680 | completely unexpected a year, or even two or three years, ago. And I think as time progresses, we will
00:29:50.320 | undoubtedly see more use of transformers within computer vision, along with their continued use
00:29:56.240 | within the field of language, and the two fields will undoubtedly become more and more unified
00:30:03.920 | over time. For now, that's it for this video. I hope all of this has been useful and interesting.
00:30:11.280 | So, thank you very much for watching, and I'll see you again in the next one. Bye.