OpenAI's CLIP for Zero Shot Image Classification
Chapters
0:00
3:54 Zero Shot Classification
6:17 How CLIP Makes Zero Shot Learning So Effective
7:53 Pre-Training
16:17 Image Embeddings
00:00:00.000 |
Today we're going to talk about how to use CLIP for zero shot image classification. That is image 00:00:06.960 |
classification without needing to fine-tune your model on a particular image data set. For me this 00:00:15.600 |
is one of the most interesting use cases that really demonstrates the power of these multimodal 00:00:22.560 |
models. Those are models that understand different domains of information like for CLIP for example 00:00:29.120 |
that is image and text. But how about other state-of-the-art computer vision models? How do 00:00:34.800 |
they perform in comparison? Well these other computer vision models they are characterized 00:00:40.880 |
by the fact that they really focus on one thing in an image. So for example let's say we had a 00:00:48.320 |
computer vision model that was trained to classify an image as to whether it contained a dog, 00:00:55.280 |
a car, or a bird. Okay that computer vision model doesn't really need to know anything other than 00:01:01.760 |
what a car looks like, what a dog looks like, and what a bird looks like. It can ignore everything 00:01:06.880 |
else in those images and that's one of the limitations of these state-of-the-art computer 00:01:12.560 |
vision models. They have been pre-trained on a huge amount of data but they've been pre-trained 00:01:18.240 |
for classification of these particular elements and even if you have like a thousand or ten 00:01:24.160 |
thousand classes in there that hardly represents the real world which has a huge variety of concepts 00:01:31.520 |
and objects for us as humans to understand, see, and categorize. So what we find with these 00:03:40.080 |
computer vision models is that they perform very well on specific data sets but they're not so good 00:01:50.080 |
at handling new classes for example. For that model to understand a new class of objects and 00:01:57.760 |
images it will need to be fine-tuned further and it's going to need a lot of data and it's just 00:02:04.000 |
not very easy to do. Ideally we want a computer vision model to just understand 00:02:08.400 |
more. Not everything, that would be perfect and obviously we're not quite there yet, but if it can just 00:02:14.320 |
understand the visual world relatively well then we're sort of on the way 00:02:20.800 |
to a more robust model. So for example this image of a dog we want the model to understand that dog 00:02:27.200 |
is in the image, and your usual convolutional neural network, the sort of model 00:02:32.800 |
used for image classification, will know that there is a dog in the image, but it doesn't care 00:02:36.880 |
about everything else in the image. Ideally we want it to understand that the dog is running 00:02:40.320 |
towards the camera, it's on a grassy field, it's sunrise or 00:02:45.280 |
sunset, and there are some blurry trees in the background. Unfortunately, with 00:02:51.120 |
classification training we don't get that. Instead the models are essentially just learning to push 00:02:57.200 |
their internal representations of images with dogs in them all towards the same area of vector space. 00:03:04.160 |
That's essentially how we can think about this, and then, for example, earlier I said cars and 00:03:09.760 |
birds as well, so we could imagine they have a cluster of dog images, they have a cluster of 00:03:14.080 |
bird images and they have a cluster of car images but they don't really have anything else in 00:03:19.360 |
between that. So this is ideal if we just want a yes or no answer for a specific data set and we 00:03:24.800 |
want good performance we can do that as long as we have the training data and the compute and the 00:03:29.200 |
time to actually do that. However if we don't have all of that and we just want good performance 00:03:36.720 |
maybe not state-of-the-art but good performance across a whole range of data sets that's where 00:03:41.600 |
we use CLIP. CLIP has proved itself to be an incredibly flexible model that can work with both text and 00:03:51.040 |
images and is amazing at something we call zero shot classification. Zero shot basically 00:03:58.080 |
means you need zero training examples for the model to adapt to a new domain. So before we dive 00:04:05.120 |
into CLIP let's just explain this zero shot thing in a little more detail. So zero shot comes from 00:04:11.440 |
something called n-shot learning. N-shot, you may have guessed, is basically the number 00:04:16.320 |
of training examples that you need for your model to perform on a particular new domain 00:04:21.280 |
on your data set. Many state-of-the-art image classification models they tend to be pre-trained 00:04:27.120 |
on like ImageNet and then they're fine-tuned for a specific task so that they have the pre-training 00:04:32.960 |
and the pre-training basically sets up the internal weights of that model to understand 00:05:38.640 |
the visual world at least within the scope of the ImageNet classification set which is fairly big 00:04:46.800 |
but it's obviously not as big as the actual world and then those models are usually fine-tuned 00:04:54.400 |
on a particular data set and to fine-tune that pre-trained image classification model on a new 00:05:02.560 |
domain you are going to need a lot of examples. Let's say as a rule of thumb maybe you need 10,000 00:05:11.680 |
images for each class or each label within your data set. That may be excessive it may be too 00:05:18.880 |
little I'm not sure but you do need something within that ballpark in order to get good 00:05:25.760 |
performance. We could refer to these models, like ResNet and BERT, as many shot 00:05:32.960 |
learners they need many many training examples in order to learn a new domain. Ideally we want to 00:05:39.040 |
maximize model performance whilst minimizing the n in n shot okay so minimizing the number of 00:05:46.800 |
training examples needed for the model to perform well. Now, as I was noting, CLIP is not 00:05:52.560 |
achieving state-of-the-art performance on any particular data sets or benchmarks, other than one. 00:05:59.440 |
Surprisingly, without seeing any training data for that particular data set, CLIP did actually get 00:06:06.480 |
state-of-the-art performance on that one data set, without seeing any of the 00:06:12.320 |
training data, and that shows how useful this sort of thing is. Let's talk about 00:06:18.640 |
how CLIP makes zero-shot learning so effective. So CLIP stands for Contrastive Language Image 00:06:24.480 |
Pre-training it was released by OpenAI in 2021 and since then it has done pretty well we can find it 00:06:33.120 |
in a lot of different use cases this is just one of them. So CLIP itself actually consists of two 00:06:37.760 |
models I've discussed this in a previous video and article in a lot more detail so if you're 00:06:44.480 |
interested go and have a look at that for now we're going to keep things pretty light on how 00:06:50.160 |
CLIP works but in this version of CLIP those two models are going to consist of a typical text 00:06:57.040 |
transformer model for dealing with the text encoding and a vision transformer model for 00:07:04.480 |
dealing with the image encoding. Both of these models within CLIP are optimized during training 00:07:10.800 |
in order to encode similar text and image pairs into the same vector space whilst also separating 00:07:19.360 |
dissimilar text and image pairs so they are further away in vector space so essentially 00:07:24.160 |
in that vector space similar items are together, whether they are images or text. 00:07:30.800 |
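To make that concrete, here is a minimal sketch of the kind of symmetric contrastive objective used during that training, assuming a batch of matching image-text pairs (the function name and temperature value are illustrative, not OpenAI's exact training code):

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_emb: torch.Tensor, text_emb: torch.Tensor, temperature: float = 0.07):
    """Simplified CLIP-style loss for a batch of N matching (image, text) pairs."""
    # Normalize so that dot products are cosine similarities
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)

    # (N, N) similarity matrix; entry [i, j] compares image i with text j
    logits = image_emb @ text_emb.T / temperature

    # The matching text for image i sits at column i, so the "class" labels are 0..N-1
    targets = torch.arange(len(image_emb), device=image_emb.device)

    # Pull matching pairs together, push everything else apart, in both directions
    loss_i = F.cross_entropy(logits, targets)    # images -> texts
    loss_t = F.cross_entropy(logits.T, targets)  # texts -> images
    return (loss_i + loss_t) / 2
```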
Now CLIP distinguishes itself from typical image classification models for a few reasons. First, 00:07:37.520 |
it isn't trained for image classification and it was also trained on a very big data set of 400 00:07:44.080 |
million image to text pairs with this contrastive learning approach. So from this we get a few 00:07:50.960 |
benefits. First, for actually pre-training the model, CLIP only requires image 00:07:58.400 |
to text pairs which in today's age of social media they're pretty easy to get any post on 00:08:06.160 |
Instagram, for example, has an image and usually a little caption of someone describing 00:08:10.640 |
what is in the image we have stock photo websites social media you know just everything everywhere 00:08:16.480 |
we have images and text usually tied together, so there's a lot of data for us to pull from. 00:08:22.880 |
Because of the large data set sizes that we can use with CLIP, CLIP is able to get a really good 00:08:30.160 |
general understanding of the concepts between language and images and just a general 00:08:36.560 |
understanding of the world through these two modalities and as well within these pairs the 00:08:43.360 |
text descriptions often describe the image, not just one part of the image like "there's a dog 00:08:50.320 |
in the image" but something else, like "a dog is running in a grassy field". They describe 00:08:57.440 |
something more and sometimes even describe very abstract things like the sort of feeling or mood 00:09:02.240 |
of the photo so you get a lot more information from these image text pairs than you do with a 00:09:08.080 |
typical classification data set and it's these three benefits of CLIP that have led to its 00:09:14.640 |
pretty outstanding zero shot performance across a huge number of data sets. Now the authors of CLIP 00:09:21.200 |
in the original CLIP paper they draw a really good example using CLIP and the ResNet101 model 00:09:30.560 |
trained for ImageNet classification. Now CLIP was not trained specifically for ImageNet classification 00:09:37.840 |
but they showed that zero shot performance with CLIP versus this state of the art model trained 00:09:44.400 |
for ImageNet was comparable on the actual ImageNet data set, and then when we compare them on other 00:09:52.480 |
data sets that are derived from ImageNet so you have ImageNet V2, ImageNet R, ObjectNet, ImageNet 00:09:59.200 |
Sketch and ImageNet A, CLIP outperforms the model that was specifically trained for ImageNet on 00:10:06.720 |
every single one of those data sets which is really impressive. Okay let's talk about how 00:10:13.600 |
CLIP is actually doing zero shot classification and how we can use it for that as well. So CLIP 00:10:20.480 |
well, the two models within CLIP both output a 512 dimensional vector. Now the text encoder 00:10:28.560 |
it can consume any piece of text right and then it will output a vector representation of that text 00:10:36.160 |
within sort of CLIP vector space. Then if you compare that text to an image also encoded with 00:10:43.920 |
CLIP, what you should find is that text and images that are more similar are closer together. So now 00:10:51.520 |
imagine we do that, but instead we have our images from an image classification data set 00:10:59.840 |
and then for the text we actually feed in the class labels for that classification task. 00:11:06.720 |
Then you process all that and calculate similarity between the outputs, and whichever 00:11:13.680 |
of your text embeddings has the highest similarity to each image, that is your 00:11:21.520 |
predicted class. Okay so let's move on to an actual applied example and implementation of 00:11:28.560 |
zero shot learning with CLIP. Okay so to start we will need to pip install datasets torch and 00:11:35.280 |
transformers, and what we're going to do is download a dataset. So this is the frgfm/imagenette dataset; 00:11:45.040 |
we've used this a couple of times before and it just contains 10 different classes not too much 00:11:52.080 |
data here; we're looking at the validation set. So we have just under 4000 items here, and if we have 00:11:58.480 |
a look at what we have in the features, in the image feature we obviously have the 00:12:03.520 |
images themselves, and in the label feature we have these 10 labels, but they're just numbers, integer values. 00:12:11.600 |
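As a rough sketch of that setup, assuming the 320px configuration of the frgfm/imagenette dataset on the Hugging Face Hub (the config name is an assumption here), the install and download step might look like this:

```python
# Install dependencies first (run once):
#   pip install datasets torch transformers

from datasets import load_dataset

# Load the validation split of Imagenette (10 ImageNet classes);
# "320px" is one of the dataset's configurations (an assumption here)
imagenette = load_dataset("frgfm/imagenette", "320px", split="validation")

print(imagenette)              # ~3,900 rows with "image" and "label" features
print(imagenette[0]["label"])  # labels are plain integers, e.g. 0
```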
We obviously need text for this to work with CLIP, so we need to modify these, 00:12:18.480 |
or we need to map these to the actual text labels. Now we do that by taking a look at the 00:12:26.000 |
Hugging Face dataset's info.features and then the label names. Okay so most Hugging Face datasets will have 00:12:33.840 |
a format similar to this where you can find extra data set information like the label names. Okay 00:12:40.640 |
and then from there we can see we have tench, English springer, cassette player, and a few 00:12:45.760 |
different things; all of these map directly to the integer values here. So for zero we'd have tench, for one we'd 00:12:52.720 |
have English springer, and so on. So as before we're going to convert these into sentences: 00:12:59.680 |
"a photo of a tench", "a photo of an English springer", and so on. 00:13:06.640 |
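In code, reading the label names from the dataset metadata and turning them into prompts might look roughly like this (the exact prompt wording is just the one described above):

```python
# Read the human-readable class names stored in the dataset metadata
labels = imagenette.info.features["label"].names
print(labels[:3])  # ['tench', 'English springer', 'cassette player']

# Turn each class name into a short sentence prompt for CLIP's text encoder
clip_labels = [f"a photo of a {label}" for label in labels]
```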
Okay so from here, before we can compare the labels and the images, we actually need CLIP. So we can initialize CLIP through 00:13:14.080 |
Hugging Face, so we use a model ID, and then the CLIP processor, which is going to pre-process 00:13:20.480 |
our images and text, and then the CLIP model here, which is the actual model itself. 00:13:25.680 |
And then we can also run it on CUDA if you have a CUDA enabled GPU. For me I'm just running this 00:13:32.960 |
on Mac, so CPU. MPS, as far as I know, is not supported in full, or CLIP is not supported in 00:13:41.520 |
full with MPS yet. So that's like the Mac M1 version of CUDA. For now CPU is fast enough; it's not slow, so it's not a problem. 00:13:49.520 |
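The walkthrough doesn't spell out the model ID; a commonly used checkpoint is openai/clip-vit-base-patch32 (an assumption on my part), and the initialization looks roughly like this:

```python
import torch
from transformers import CLIPProcessor, CLIPModel

# Checkpoint name is an assumption; any CLIP checkpoint on the Hub works the same way
model_id = "openai/clip-vit-base-patch32"

processor = CLIPProcessor.from_pretrained(model_id)  # pre-processes both text and images
model = CLIPModel.from_pretrained(model_id)

# Use a CUDA GPU if one is available, otherwise stay on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
```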
Now one thing here is that text transformers don't read 00:13:56.320 |
text directly like we do. They need like a translation from text into what are called 00:14:03.760 |
input ids or token ids which are just integer representations of either words or sub words 00:14:11.680 |
from the original text. So we do that with the processor here, passing our text with padding 00:14:16.960 |
set to true so that everything is the same size; we need this when we're running 00:14:21.520 |
multiple inputs in parallel, essentially when we're using batches. We're not passing our 00:14:27.440 |
images here and we're going to return PyTorch tensors. Okay and we're going to move all that 00:14:33.600 |
to our device for me it's just CPU so it doesn't actually matter but it's fine. And then here we 00:14:38.240 |
can see those tokens so we have a starter sequence token here and then we could imagine this is 00:14:44.160 |
something like a photo of a tench something along those lines and then end of sequence over there as 00:14:51.760 |
well. So we can encode these tokens into sentence embeddings; all we do is pass our 00:14:58.240 |
label tokens in here. Now in here we have input ids and also another tensor 00:15:05.760 |
called attention mask, and that's wrapped within a dictionary, which is why we're using 00:15:10.960 |
these two asterisks here, to unpack and pass both of those tensors as 00:15:15.600 |
individual items to the get text features function. 00:15:20.320 |
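Putting that tokenization and encoding step together, a sketch might look like this (variable names are just how I'd write it, not necessarily the originals):

```python
# Tokenize the label prompts; padding makes every sequence the same length
label_tokens = processor(
    text=clip_labels,
    images=None,
    padding=True,
    return_tensors="pt",
).to(device)

print(label_tokens["input_ids"].shape)  # (10, seq_len) integer token IDs

# Encode the tokens into 512-dimensional text embeddings;
# ** unpacks input_ids and attention_mask as keyword arguments
label_emb = model.get_text_features(**label_tokens)
```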
And then after we have our label embeddings over here we just want to detach them from PyTorch 00:15:27.200 |
gradient computation of the model and convert them into NumPy, and then we can see that we 00:15:32.960 |
get 10 512-dimensional embeddings. Okay so these are now the text embeddings within that CLIP vector 00:15:41.120 |
space, and one thing to note here is that they're not normalized, 00:15:46.960 |
so we can either use cosine similarity to compare them or we can normalize them and then we can use 00:15:54.240 |
dot product similarity. Now if we normalize first I find the code later on to be simpler so we will 00:16:01.440 |
do that here; the normalization is pretty simple, and then we can see straight 00:16:06.960 |
away they're normalized, and we can just use dot product similarity now. 00:16:13.120 |
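A sketch of that detach-and-normalize step (normalizing each row with NumPy; this mirrors the description rather than being the exact original code):

```python
import numpy as np

# Detach from the computation graph, move to CPU, and convert to NumPy
label_emb = label_emb.detach().cpu().numpy()
print(label_emb.shape)  # (10, 512)

# L2-normalize each embedding so a dot product equals cosine similarity
label_emb = label_emb / np.linalg.norm(label_emb, axis=1, keepdims=True)
```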
That's the text embedding, or the label embedding, part; now what we want to do is have a look at how we do the image embeddings 00:16:18.880 |
and then how we compare them. So we're going to start with this image first just a single image 00:16:24.240 |
we'll go through the whole data set in full later so we just have a cassette recorder here 00:16:29.600 |
now we go down here we're just going to process the image so using the same processor 00:16:33.520 |
we set text to None because there's no text this time, and what we want to do is just pass that 00:16:38.480 |
image in here to images, and we're going to return PyTorch tensors and extract the pixel 00:16:44.320 |
values. Now the reason that we have to process the image is that CLIP expects every tensor it sees to 00:16:50.880 |
be normalized, which is the first thing it does, and also to have a particular shape. It expects 00:16:58.080 |
three color channels, a 224 pixel wide image, and a 224 pixel high image. 00:17:08.480 |
Okay so all we're doing there is normalization and resizing, and then from there we can pass it to get 00:17:16.880 |
image features. Again, here we didn't include the two asterisks because this is just a 00:17:24.400 |
single tensor, not a dictionary, and we get a single embedding, 00:17:31.600 |
one vector which is 512 dimensions. As with our label embeddings we're going to detach it, move 00:17:37.360 |
it to CPU and then convert to NumPy. I already have it on CPU so this part isn't necessary, 00:17:43.600 |
but if you're using CUDA it will be. And we don't need to normalize this one, because when we're doing 00:17:49.040 |
dot product similarity we just need one side of the calculation to be normalized, not both. 00:17:54.560 |
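Sketching that single-image path (the dataset index used for the cassette player image is illustrative):

```python
# Take one example image from the dataset (index chosen for illustration)
image = imagenette[3]["image"]

# Resize to 3 x 224 x 224 and normalize pixel values the way CLIP expects
image_inputs = processor(
    text=None,
    images=image,
    return_tensors="pt",
)["pixel_values"].to(device)
print(image_inputs.shape)  # torch.Size([1, 3, 224, 224])

# Encode the image into a single 512-dimensional embedding
img_emb = model.get_image_features(image_inputs)
img_emb = img_emb.detach().cpu().numpy()
print(img_emb.shape)  # (1, 512)
```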
Okay, so with that we do np.dot with our image embedding and the transposed label 00:18:01.360 |
embeddings, and you see that we get scores. The first dimension of the shape is one, which is not 00:18:07.280 |
important; here we have those 10 similarity scores, one similarity value for each of our 10 labels. 00:18:17.520 |
Okay, so then we can take the index of the highest score, which happens to be index two, and then we 00:18:24.320 |
find out which label that is: it is cassette player, so it's correct. That's pretty cool. 00:18:31.200 |
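The scoring step then comes down to a dot product and an argmax (the exact index and predicted label depend on which image you picked):

```python
# Dot product between the (1, 512) image embedding and the transposed (512, 10) label embeddings
scores = np.dot(img_emb, label_emb.T)
print(scores.shape)  # (1, 10) -> one similarity score per label

# The label with the highest score is the zero shot prediction
pred = np.argmax(scores)
print(labels[pred])  # e.g. 'cassette player'
```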
Now let's have a look at how we do that for the whole data set. So all I'm going to do is 00:18:36.560 |
go through the whole data set in a loop, in batches of 32, and 00:18:42.320 |
process everything. This is all just the same stuff: process, get the embeddings, get the dot product. 00:18:49.040 |
We don't need to redo this for the labels because we just have the 10 labels throughout 00:18:53.520 |
the whole thing, so it's not necessary. We'll get the argmax and we're just going to append or 00:18:59.440 |
extend a prediction list with all of those predictions. Okay, let's see what we get. 00:19:08.800 |
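A minimal sketch of that evaluation loop, assuming the variables defined above (batch size of 32 as described; the accuracy calculation at the end is just one straightforward way to do it):

```python
preds = []
batch_size = 32

for i in range(0, len(imagenette), batch_size):
    # Pre-process a batch of images exactly as we did for the single image
    batch = imagenette[i : i + batch_size]["image"]
    pixels = processor(text=None, images=batch, return_tensors="pt")["pixel_values"].to(device)

    # Embed the batch, score against the (already normalized) label embeddings, take the argmax
    batch_emb = model.get_image_features(pixels).detach().cpu().numpy()
    scores = np.dot(batch_emb, label_emb.T)
    preds.extend(np.argmax(scores, axis=1).tolist())

# Accuracy: fraction of predictions matching the true integer labels
true_labels = np.array(imagenette["label"])
print((np.array(preds) == true_labels).mean())  # ~0.987 in the walkthrough above
```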
So here we're calculating how many of the predictions align with the 00:19:14.000 |
true values, and we can see we get 0.987, so that means we get 98.7% accuracy, which is pretty 00:19:26.400 |
insane when you consider that we have done no training for CLIP here; it has not seen any of 00:19:32.000 |
these labels, it has not seen any of these images. This is out-of-the-box zero shot classification, 00:19:39.840 |
and it's scoring 98.7% accuracy, which I think is really very impressive. So this is, I think, 00:19:50.080 |
a good example of why zero shot image classification with CLIP is such an 00:19:57.040 |
interesting use case, and it's just so easy. You can do this for a whole ton of data sets and 00:20:03.680 |
get good performance; it's not going to be state-of-the-art performance, but pretty good 00:20:07.840 |
performance, like this, super easily. So before CLIP, as far as I'm aware, this sort of 00:20:16.000 |
thing wasn't possible. Every domain adaptation to a new classification task needed 00:20:22.720 |
training data, it needed training, and so on. With this, it's just a case of writing some 00:20:29.680 |
labels, maybe modifying them into sentences, and then you're good to go. So that's why I think 00:20:37.440 |
CLIP has created a pretty big leap forward in quite a few areas, such as image 00:20:45.680 |
classification. So when I think of CLIP, in the short time that it's been around, 00:20:52.560 |
we have multi-modal search, zero shot image classification, object localization or image 00:20:59.040 |
localization, and object detection, also zero shot, and we'll go into that in more detail pretty soon. 00:21:05.200 |
Even industry-changing tools like OpenAI's DALL-E include a CLIP model; 00:21:10.400 |
Stable Diffusion, as far as I know, also includes a CLIP model. So there's this massive range of 00:21:16.800 |
use cases that CLIP is being used for, and I think that's super interesting. So that's it for this 00:21:22.400 |
video. I hope you have found all this as interesting as I do. Thank you very much for watching.