
OpenAI's CLIP for Zero Shot Image Classification


Chapters

0:00
3:54 Zero Shot Classification
6:17 How CLIP Makes Zero Shot Learning So Effective
7:53 Pre-Training
16:17 Image Embeddings

Transcript

Today we're going to talk about how to use CLIP for zero shot image classification. That is image classification without needing to fine-tune your model on a particular image data set. For me this is one of the most interesting use cases that really demonstrates the power of these multimodal models.

Those are models that understand different domains of information; for CLIP, that is images and text. But how about other state-of-the-art computer vision models, how do they perform in comparison? Well, these other computer vision models are characterized by the fact that they really focus on one thing in an image.

So for example let's say we had a computer vision model that was trained to classify an image as to whether it contained a dog, a car, or a bird. Okay that computer vision model doesn't really need to know anything other than what a car looks like, what a dog looks like, and what a bird looks like.

It can ignore everything else in those images, and that's one of the limitations of these state-of-the-art computer vision models. They have been pre-trained on a huge amount of data, but they've been pre-trained for classification of these particular elements, and even if you have a thousand or ten thousand classes in there, that hardly represents the real world, which has a huge variety of concepts and objects for us as humans to understand, see, and categorize.

So what we find with these computer vision models is that they perform very well on specific data sets, but they're not so good at handling new classes. For that model to understand a new class of objects and images, it will need to be fine-tuned further; it's going to need a lot of data, and it's just not very easy to do.

Ideally we want a computer vision model to just understand, well, not everything, that would be perfect and we're not quite there yet, but if it can just understand the visual world relatively well, then we're on the way to a more robust model. So for example, with this image of a dog, we want the model to understand that a dog is in the image. The usual convolutional neural network, the sort of model typically used for image classification, will know that there is a dog in the image, but it doesn't care about anything else in the image.

Ideally we want it to understand that there are blurry trees in the background, the dog is running towards the camera, it's on a grassy field, and it's sunrise or sunset. Unfortunately, with classification training we don't get that. Instead, the models are essentially just learning to push their internal representations of images with dogs in them towards the same sort of vector space.

That's essentially how we can think about this. And, as with the cars and birds I mentioned earlier, we could imagine they have a cluster of dog images, a cluster of bird images, and a cluster of car images, but they don't really have anything in between.

So this is ideal if we just want a yes or no answer for a specific data set and we want good performance; we can do that as long as we have the training data, the compute, and the time to actually do it. However, if we don't have all of that and we just want good performance, maybe not state-of-the-art, but good performance across a whole range of data sets, that's where we use CLIP.

CLIP has proved itself as an incredibly flexible model that can work with both text and images and is amazing at something we call zero-shot classification. Zero-shot basically means you need zero training examples for the model to adapt to a new domain. So before we dive into CLIP, let's just explain this zero-shot thing in a little more detail.

So zero-shot comes from something called n-shot learning. The n, you may have guessed, is basically the number of training examples you need for your model to perform on a particular new domain or data set. Many state-of-the-art image classification models tend to be pre-trained on something like ImageNet and then fine-tuned for a specific task. The pre-training sets up the internal weights of the model to understand the visual world, at least within the scope of the ImageNet classification set, which is fairly big but obviously not as big as the actual world. Those models are then usually fine-tuned on a particular data set, and to fine-tune that pre-trained image classification model on a new domain you are going to need a lot of examples.

Let's say, as a rule of thumb, maybe you need 10,000 images for each class or label within your data set. That may be excessive or it may be too little, I'm not sure, but you do need something in that ballpark to get good performance. We could refer to these models, things like ResNet and BERT, as many-shot learners: they need many, many training examples in order to learn a new domain.

Ideally we want to maximize model performance whilst minimizing the n in n-shot, so minimizing the number of training examples needed for the model to perform well. Now, as I was noting, CLIP does not achieve state-of-the-art performance on any particular data set or benchmark, with one surprising exception: without seeing any of that data set's training data, CLIP did actually get state-of-the-art performance on it. That gives you a sense of how useful this sort of approach is.

Let's talk about how CLIP makes zero-shot learning so effective. CLIP stands for Contrastive Language-Image Pre-training. It was released by OpenAI in 2021, and since then it has done pretty well; we can find it in a lot of different use cases, and this is just one of them. CLIP itself actually consists of two models. I've discussed this in a previous video and article in a lot more detail, so if you're interested go and have a look at that; for now we're going to keep things pretty light on how CLIP works. In this version of CLIP, those two models are a typical text transformer model for the text encoding and a vision transformer (ViT) model for the image encoding.

Both of these models within CLIP are optimized during training to encode similar text and image pairs into the same vector space, whilst also pushing dissimilar text and image pairs further apart in that vector space. Essentially, in that vector space, similar items end up close together, whether they are images or text.
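To make that training objective a bit more concrete, here is a minimal sketch of a symmetric contrastive loss of the kind CLIP uses, assuming we already have a batch of L2-normalized image and text embeddings. The function and variable names here are illustrative, not taken from the CLIP codebase.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """Symmetric contrastive loss over a batch of paired embeddings.

    image_embeds, text_embeds: (batch_size, dim), assumed L2-normalized.
    Matching image/text pairs sit on the diagonal of the similarity matrix.
    """
    # Pairwise cosine similarities, scaled by a temperature
    logits = image_embeds @ text_embeds.T / temperature
    targets = torch.arange(logits.shape[0], device=logits.device)
    # Cross-entropy in both directions: image-to-text and text-to-image
    loss_i = F.cross_entropy(logits, targets)
    loss_t = F.cross_entropy(logits.T, targets)
    return (loss_i + loss_t) / 2

# Toy usage with random, normalized embeddings
img = F.normalize(torch.randn(8, 512), dim=-1)
txt = F.normalize(torch.randn(8, 512), dim=-1)
print(contrastive_loss(img, txt))
```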

Now, CLIP distinguishes itself from typical image classification models for a few reasons. First, it isn't trained for image classification; it was trained on a very big data set of 400 million image-text pairs with this contrastive learning approach. From this we get a few benefits. First, for pre-training the model, CLIP only requires image-text pairs, which in today's age of social media are pretty easy to get. Any post on Instagram, for example, has an image and usually a little caption describing what is in the image; we have stock photo websites, social media, everywhere we have images and text tied together, so there's a lot of data for us to pull from.

Because of the large data set sizes that we can use with CLIP, CLIP is able to get a really good general understanding of the concepts shared between language and images, and a general understanding of the world through these two modalities. As well, within these pairs the text often describes the whole image, not just one part of it. It's not just "there's a dog in the image" but something more like "a dog is running in a grassy field"; the captions describe something more, and sometimes even very abstract things like the feeling or mood of the photo. So you get a lot more information from these image-text pairs than you do with a typical classification data set, and it's these three benefits of CLIP that have led to its pretty outstanding zero-shot performance across a huge number of data sets.

Now, the authors of CLIP, in the original CLIP paper, draw a really good comparison between CLIP and a ResNet-101 model trained for ImageNet classification. CLIP was not trained specifically for ImageNet classification, but they showed that zero-shot performance with CLIP was comparable to this state-of-the-art model on the actual ImageNet data set. Then, when we compare them on other data sets derived from ImageNet, so ImageNet-V2, ImageNet-R, ObjectNet, ImageNet Sketch, and ImageNet-A, CLIP outperforms the model that was specifically trained for ImageNet on every single one of those data sets, which is really impressive.

Okay, let's talk about how CLIP actually does zero-shot classification and how we can use it for that as well. The two models within CLIP both output a 512-dimensional vector. The text encoder can consume any piece of text, and it will output a vector representation of that text within CLIP's vector space.

Then, if you compare that text to an image also encoded with CLIP, what you should find is that text and images that are more similar are closer together. So now imagine we do that, but with the images coming from an image classification data set, and for the text we feed in the class labels for that classification task.

Then you process all of that, calculate the similarity between the outputs, and whichever of your text embeddings has the highest similarity to a given image, that is your predicted class. Okay, so let's move on to an actual applied example and implementation of zero-shot classification with CLIP.

Okay, so to start we will need to pip install datasets, torch, and transformers, and then download a data set. This is the frgfm Imagenette data set; we've used it a couple of times before. It contains just 10 different classes, so not too much data, and we're looking at the validation set.
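In code, the setup looks roughly like this. I'm assuming the data set referred to here is the one published on the Hugging Face Hub as frgfm/imagenette with its full_size configuration; if you're using a different id or config, adjust accordingly.

```python
# pip install datasets torch transformers

from datasets import load_dataset

# Assumed Hub id/config for the 10-class Imagenette subset described here;
# we take the validation split only (just under 4,000 images).
imagenette = load_dataset(
    "frgfm/imagenette",
    "full_size",
    split="validation",
)
print(imagenette)  # features: 'image' and 'label'
```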

So we have just under 4,000 items here. If we have a look at the features: in the image feature we obviously have the images themselves, and in the label feature we have these 10 labels, but they're just integer values. We obviously need text for this to work with CLIP, so we need to map these to the actual text labels.

We do that by taking a look at the Hugging Face dataset's info, then features, then the label names. Most Hugging Face datasets have a format similar to this, where you can find extra data set information like the label names. From there we can see we have tench, English springer, cassette player, and a few other things; all of these map directly to the integer values.

So for zero we'd have tench, for one we'd have English springer, and so on. As before, we're going to convert these into sentences: "a photo of a tench", "a photo of a English springer", and so on.
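Pulling the label names out of the dataset metadata and turning them into sentences looks roughly like this:

```python
# Pull the human-readable class names from the dataset metadata,
# then turn each one into a short caption for CLIP to embed.
labels = imagenette.info.features["label"].names
print(labels)  # e.g. ['tench', 'English springer', 'cassette player', ...]

clip_labels = [f"a photo of a {label}" for label in labels]
```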

From here, before we can compare the labels and the images, we actually need CLIP. We can initialize CLIP through Hugging Face: we use the model id, then the CLIP processor, which is going to pre-process our images and text, and then the CLIP model itself. We can also run it on CUDA if you have a CUDA-enabled GPU.

For me, I'm just running this on a Mac, so CPU. MPS, which is like the Mac M1 version of CUDA, doesn't fully support CLIP yet as far as I know. For now CPU is fast enough; it's not slow, so it's not a problem.
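Initializing CLIP might look like the following. The exact checkpoint isn't spelled out in the transcript, so openai/clip-vit-base-patch32 here is an assumption.

```python
import torch
from transformers import CLIPModel, CLIPProcessor

# Assumed checkpoint: the usual CLIP ViT-B/32 weights from OpenAI.
model_id = "openai/clip-vit-base-patch32"

processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id)

# Run on a CUDA GPU if available, otherwise stay on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
```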

Now, one thing here is that text transformers don't read text directly like we do. They need a translation from text into what are called input ids or token ids, which are just integer representations of either words or sub-words from the original text. We do that with the processor: we pass our text and set padding to true so that everything is the same size; we need this when we're running multiple inputs in parallel, essentially when we're using batches.

We're not passing any images here, and we're going to return PyTorch tensors. Then we move all of that to our device; for me it's just CPU, so it doesn't actually matter, but it's fine. Here we can see those tokens: we have a start-of-sequence token, then something like "a photo of a tench", and then an end-of-sequence token as well.
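Tokenizing the label sentences with the processor looks something like this:

```python
# Tokenize the label sentences: the processor returns input_ids and an
# attention_mask, padded so every sentence in the batch is the same length.
label_tokens = processor(
    text=clip_labels,
    images=None,
    padding=True,
    return_tensors="pt",
).to(device)

print(label_tokens["input_ids"][0])  # start token, caption tokens, end token
```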

To encode these tokens into sentence embeddings, all we do is pass our tokens, the label tokens, into get_text_features. In there we have input_ids and another tensor called attention_mask, wrapped within a dictionary, which is why we're using the two asterisks: to unpack both of those tensors and pass them as individual keyword arguments to the get_text_features function.

After we have our label embeddings, we just want to detach them from the model's PyTorch gradient computation and convert them into NumPy, and we can see that we get 10 512-dimensional embeddings. These are now the text embeddings within that CLIP vector space. One thing to note here is that they're not normalized, so we can either use cosine similarity to compare them, or we can normalize them and then use dot product similarity.
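Encoding those tokens into label embeddings, roughly:

```python
# Encode the tokenized labels into 512-dimensional text embeddings.
# The ** unpacks input_ids and attention_mask as keyword arguments.
label_emb = model.get_text_features(**label_tokens)

# Detach from the gradient graph, move to CPU, convert to NumPy.
label_emb = label_emb.detach().cpu().numpy()
print(label_emb.shape)  # (10, 512)
print(label_emb.min(), label_emb.max())  # values show these aren't normalized
```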

If we normalize first, I find the code later on to be simpler, so we will do that here. The normalization is pretty simple, and we can see straight away that the embeddings are normalized, so from now on we can just use dot product similarity. That's the text embedding, or label embedding, part. Now what we want to do is look at how we create the image embeddings and then how we compare them.
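A simple way to do that normalization with NumPy:

```python
import numpy as np

# L2-normalize each label embedding so a plain dot product against the
# image embeddings ranks labels the same way cosine similarity would.
label_emb = label_emb / np.linalg.norm(label_emb, axis=1, keepdims=True)
```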

We're going to start with a single image first; we'll go through the whole data set later. Here we just have a cassette player. We process the image using the same processor, setting text to None because there's no text this time, passing the image to images, returning PyTorch tensors, and extracting the pixel values.
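Pre-processing a single image might look like this; the index used here is just illustrative (in the video the example shown happens to be a cassette player).

```python
# Grab one example image from the validation split (arbitrary index).
image = imagenette[0]["image"]

# The processor resizes to 224x224 and normalizes the pixel values,
# returning a tensor of shape (1, 3, 224, 224).
image_inputs = processor(
    text=None,
    images=image,
    return_tensors="pt",
)["pixel_values"].to(device)

print(image_inputs.shape)  # torch.Size([1, 3, 224, 224])
```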

Now, the reason we have to process the image is that CLIP expects every tensor it sees to be normalized, which is the first thing the processor does, and to have a particular shape: three color channels, 224 pixels wide, and 224 pixels high. So all we're doing there is normalization and resizing. From there we can pass it to get_image_features. Unlike before, we don't include the two asterisks, because this is just a single tensor rather than a dictionary of tensors. We get a single embedding, one vector of 512 dimensions. As with our label embeddings, we detach it, move it to CPU, and convert it to NumPy; I'm already on CPU so that part isn't strictly necessary, but if you're using CUDA it will be. We don't need to normalize it either: when we're doing dot product similarity we just need one side of the calculation to be normalized, not both.

With that, we take the NumPy dot product of our image embedding and the transpose of the label embeddings, and we get scores with shape (1, 10); the leading dimension isn't important, what matters is that we have 10 similarity values, one for each of our 10 labels. We then take the index of the highest score, which happens to be index two, and look up which label that is: cassette player. So it's correct, which is pretty cool.

Now let's look at how we do that for the whole data set. All we're going to do is loop through the whole data set in batches of 32 and process everything in the same way as before: process the images, get the embeddings, and take the dot product. We don't need to redo the label embeddings, because we have the same 10 labels throughout. We take the argmax and extend a prediction list with all of those predictions. Then we check the performance by calculating how many of the predictions align with the true values, and we get 0.987, which means 98.7% accuracy. That is pretty insane when you consider that we have done no training for CLIP here: it has not seen any of these labels, it has not seen any of these images. This is out-of-the-box zero-shot classification, and it's scoring 98.7% accuracy, which I think is really very impressive.
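Putting that together, first for the single image and then for the whole validation split, a sketch of the flow described above might look like this. The 0.987 figure is the result reported in the video, not something guaranteed by this exact code.

```python
import numpy as np

# Encode the single image into one 512-dimensional CLIP embedding.
img_emb = model.get_image_features(image_inputs)
img_emb = img_emb.detach().cpu().numpy()

# Dot product against the normalized label embeddings gives one
# similarity score per label; the argmax is the predicted class.
scores = np.dot(img_emb, label_emb.T)  # shape (1, 10)
pred = np.argmax(scores)
print(labels[pred])  # predicted class name for this image

# Now the same thing across the whole validation split, in batches of 32.
preds = []
batch_size = 32
for i in range(0, len(imagenette), batch_size):
    batch = imagenette[i:i + batch_size]["image"]
    pixel_values = processor(
        text=None,
        images=batch,
        return_tensors="pt",
    )["pixel_values"].to(device)

    batch_emb = model.get_image_features(pixel_values).detach().cpu().numpy()
    batch_scores = np.dot(batch_emb, label_emb.T)
    preds.extend(np.argmax(batch_scores, axis=1).tolist())

# Accuracy: fraction of predictions matching the true integer labels.
true_labels = np.array(imagenette["label"])
print((np.array(preds) == true_labels).mean())  # ~0.987 in the video's run
```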
So this is, I think, a good example of why zero-shot image classification with CLIP is such an interesting use case, and it's just so easy. You can do this for a whole ton of data sets and get good performance; not state-of-the-art performance, but pretty good performance like this, super easily. Before CLIP, as far as I'm aware, this sort of thing wasn't possible: every adaptation to a new classification domain needed training data, training, and so on. With this, it's just a case of writing some labels, maybe modifying them into sentences, and then you're good to go. That's why I think CLIP has created a pretty big leap forward in quite a few areas, such as image classification. When I think of CLIP, in the short time that it's been around we have multi-modal search, zero-shot image classification, object localization or image localization, zero-shot object detection (we'll go into that in more detail pretty soon), and even industry-changing tools: OpenAI's DALL-E includes a CLIP model, and Stable Diffusion, as far as I know, also includes a CLIP model. So there's this massive range of use cases that CLIP is being used for, and I think that's super interesting.

That's it for this video. I hope you have found all of this as interesting as I do. Thank you very much for watching, and I will see you again in the next one. Bye.