
OpenAI's CLIP for Zero Shot Image Classification


Chapters

0:0
3:54 Zero Shot Classification
6:17 How CLIP Makes Zero Shot Learning So Effective
7:53 Pre-Training
16:17 Image Embeddings

Whisper Transcript | Transcript Only Page

00:00:00.000 | Today we're going to talk about how to use CLIP for zero shot image classification. That is image
00:00:06.960 | classification without needing to fine-tune your model on a particular image data set. For me this
00:00:15.600 | is one of the most interesting use cases that really demonstrates the power of these multimodal
00:00:22.560 | models. Those are models that understand different domains of information like for CLIP for example
00:00:29.120 | that is image and text. But how about other state-of-the-art computer vision models? How do
00:00:34.800 | they perform in comparison? Well these other computer vision models they are characterized
00:00:40.880 | by the fact that they really focus on one thing in an image. So for example let's say we had a
00:00:48.320 | computer vision model that was trained to classify an image as to whether it contained a dog,
00:00:55.280 | a car, or a bird. Okay that computer vision model doesn't really need to know anything other than
00:01:01.760 | what a car looks like, what a dog looks like, and what a bird looks like. It can ignore everything
00:01:06.880 | else in those images and that's one of the limitations of these state-of-the-art computer
00:01:12.560 | vision models. They have been pre-trained on a huge amount of data but they've been pre-trained
00:01:18.240 | for classification of these particular elements and even if you have like a thousand or ten
00:01:24.160 | thousand classes in there that hardly represents the real world which has a huge variety of concepts
00:01:31.520 | and objects for us to as humans to to understand and see and categorize. So what we find with these
00:01:40.080 | computer vision models is that they perform very well on specific data sets but they're not so good
00:01:50.080 | at handling new classes for example. For that model to understand a new class of objects and
00:01:57.760 | images it will need to be fine-tuned further and it's going to need a lot of data and it's just
00:02:04.000 | it's not very easy to do. Ideally we want a computer vision model to just understand
00:02:08.400 | kind of not everything would be perfect obviously we're not quite there yet but if it can just
00:02:14.320 | understand the world and the visual world relatively well then we're sort of on the way
00:02:20.800 | to a more robust model. So for example this image of a dog we want the model to understand that dog
00:02:27.200 | is in the image and a you know usually use convolutional neural network that sort of model
00:02:32.800 | for image classification will know that there is a dog in the image but it doesn't it doesn't care
00:02:36.880 | about everything else in the image. Ideally we want it to understand that there are trees in
00:02:40.320 | the background the dog is running towards the camera it's on a grassy field it's sunrise or
00:02:45.280 | sunset and that there are some blurry trees in the background the blue street. Unfortunately
00:02:51.120 | classification training we don't get that. Instead the models are essentially just learning to push
00:02:57.200 | their internal representations of images with dogs in all towards the same sort of vector space.
00:03:04.160 | That's essentially how we can we can think about this and then for example earlier I said cars and
00:03:09.760 | birds as well we could imagine so they have like a cluster of dog images they have a cluster of
00:03:14.080 | bird images and they have a cluster of car images but they don't really have anything else in
00:03:19.360 | between that. So this is ideal if we just want a yes or no answer for a specific data set and we
00:03:24.800 | want good performance we can do that as long as we have the training data and the compute and the
00:03:29.200 | time to actually do that. However if we don't have all of that and we just want good performance
00:03:36.720 | maybe not state-of-the-art but good performance across a whole range of data sets that's where
00:03:41.600 | we use CLIP. CLIP has proved itself as incredibly flexible model that can work in both text and
00:03:51.040 | images and is amazing at what something we call zero shot classification. Zero shot basically
00:03:58.080 | saying you need zero training examples for this model to adapt to a new domain. So before we dive
00:04:05.120 | into CLIP let's just explain this zero shot thing in a little more detail. So zero shot comes from
00:04:11.440 | something called end shot learning. End shot you may have guessed is basically the number
00:04:16.320 | of training examples that you need for your model to perform on a particular new domain
00:04:21.280 | on your data set. Many state-of-the-art image classification models they tend to be pre-trained
00:04:27.120 | on like ImageNet and then they're fine-tuned for a specific task so that they have the pre-training
00:04:32.960 | and the pre-training basically sets up the internal model ways of that model to understand
00:04:38.640 | the visual world at least within the scope of the ImageNet classification set which is fairly big
00:04:46.800 | but it's obviously not as big as the actual world and then those models are usually fine-tuned
00:04:54.400 | on a particular data set and to fine-tune that pre-trained image classification model on a new
00:05:02.560 | domain you are going to need a lot of examples. Let's say as a rule of thumb maybe you need 10,000
00:05:11.680 | images for each class or each label within your data set. That may be excessive it may be too
00:05:18.880 | little I'm not sure but you do need something within that ballpark in order to get good
00:05:25.760 | performance. We could refer to these models so these are like ResNet and BERT as many shot
00:05:32.960 | learners they need many many training examples in order to learn a new domain. Ideally we want to
00:05:39.040 | maximize model performance whilst minimizing the n in n shot okay so minimizing the number of
00:05:46.800 | training examples needed for the model to perform well. Now so as I was noting that CLIP is not
00:05:52.560 | achieving state-of-the-art performance on any particular data sets or benchmarks other than one
00:05:59.440 | surprisingly without seeing any training data for this particular data set CLIP did actually get
00:06:06.480 | state-of-the-art performance on that one data set which is surprising without seeing any of the
00:06:12.320 | training data but here we go this is this is how useful this sort of thing is. Let's talk about
00:06:18.640 | how CLIP makes zero-shot learning so effective. So CLIP stands for Contrastive Language Image
00:06:24.480 | Pre-training it was released by OpenAI in 2021 and since then it has done pretty well we can find it
00:06:33.120 | in a lot of different use cases this is just one of them. So CLIP itself actually consists of two
00:06:37.760 | models I've discussed this in a previous video and article in a lot more detail so if you're
00:06:44.480 | interested go and have a look at that for now we're going to keep things pretty light on how
00:06:50.160 | CLIP works but in this version of CLIP those two models are going to consist of a typical text
00:06:57.040 | transformer model for dealing with the text encoding and a vision transformer model for
00:07:04.480 | dealing with the image encoding. Both of these models within CLIP are optimized during training
00:07:10.800 | in order to encode similar text and image pairs into the same vector space whilst also separating
00:07:19.360 | dissimilar text and image pairs so they are further away in vector space so essentially
00:07:24.160 | in that vector space similar items are together whether they are images or text. Now CLIP
00:07:30.800 | distinguishes itself from typical image classification models for a few reasons first
00:07:37.520 | it isn't trained for image classification and it was also trained on a very big data set of 400
00:07:44.080 | million image to text pairs with this contrastive learning approach. So from this we get a few
00:07:50.960 | a few benefits first for actually pre-training the models training the model CLIP only requires image
00:07:58.400 | to text pairs which in today's age of social media they're pretty easy to get any post on
00:08:06.160 | Instagram for example there's a image and there's usually a little caption of someone describing
00:08:10.640 | what is in the image we have stock photo websites social media you know just everything everywhere
00:08:16.480 | we have images and text usually tied together so there's a lot of data for us to to pull that.
00:08:22.880 | Because of the large data set sizes that we can use with CLIP, CLIP is able to get a really good
00:08:30.160 | general understanding of the concepts between language and images and just a general
00:08:36.560 | understanding of the world through these two modalities and as well within these pairs the
00:08:43.360 | text descriptions often describe the image not not just one part of the image like okay there's a dog
00:08:50.320 | in the image but something else like dog is running in a in a grassy field okay they describe
00:08:57.440 | something more and sometimes even describe very abstract things like the sort of feeling or mood
00:09:02.240 | of the photo so you get a lot more information from these image text pairs than you do with a
00:09:08.080 | typical classification data set and it's these three benefits of CLIP that have led to its
00:09:14.640 | pretty outstanding zero shot performance across a huge number of data sets. Now the authors of CLIP
00:09:21.200 | in the original CLIP paper they draw a really good example using CLIP and the ResNet101 model
00:09:30.560 | trained for ImageNet classification. Now CLIP was not trained specifically for ImageNet classification
00:09:37.840 | but they showed that zero shot performance with CLIP versus this state of the art model trained
00:09:44.400 | for ImageNet was comparable on the actual ImageNet data set and then we when we compare them on other
00:09:52.480 | data sets that are derived from ImageNet so you have ImageNet V2, ImageNet R, ObjectNet, ImageNet
00:09:59.200 | Sketch and ImageNet A, CLIP outperforms the model that was specifically trained for ImageNet on
00:10:06.720 | every single one of those data sets which is really impressive. Okay let's talk about how
00:10:13.600 | CLIP is actually doing zero shot classification and how we can use it for that as well. So CLIP
00:10:20.480 | well the two models within CLIP they both output a 512 dimensional vector. Now the text encoder
00:10:28.560 | it can consume any piece of text right and then it will output a vector representation of that text
00:10:36.160 | within sort of CLIP vector space. Then if you compare that text to an image also encoded with
00:10:43.920 | CLIP what you should find is that text and images that are more similar are closer together. So now
00:10:51.520 | imagine we do that but instead of so we have our images from an image classification data set
00:10:59.840 | and then for the text we actually feed in the class labels for that classification task.
00:11:06.720 | Then you process all that and then calculate similarity between the the outputs and whichever
00:11:13.680 | of your text embeddings has a high similarity to each image that is like your class okay your
00:11:21.520 | predicted class. Okay so let's move on to an actual applied example and implementation of
00:11:28.560 | zero shot learning with CLIP. Okay so to start we will need to pip install datasets torch and
00:11:35.280 | transformers and what we're going to do is download a dataset. So this is the frgfm image net dataset
00:11:45.040 | we've used this a couple of times before and it just contains 10 different classes not too much
00:11:52.080 | data here we're looking at a validation set. So we have just under 4000 items here and if we have
00:11:58.480 | a look at what we have in the labels feature so in the images image feature we obviously have the
00:12:03.520 | images themselves label feature we have these 10 labels okay but they're just numbers they're
00:12:11.600 | integer values. We obviously need text for this to work with CLIP so we need to modify these
00:12:18.480 | or we need to map these to the actual text labels. Now we do that by taking a look at the
00:12:26.000 | hugging face dataset info features and then label names. Okay so most hugging face datasets will have
00:12:33.840 | a format similar to this where you can find extra data set information like the label names. Okay
00:12:40.640 | and then from there we can see we have tench english springer cassette player you know a few
00:12:45.760 | different things all of these map directly to the values here. So for zero we'd have tench one we'd
00:12:52.720 | have english springer and so on. So as before we're going to convert these into sentences.
00:12:59.680 | So a photo of a tench photo of a english springer and so on and so on. Okay so from here before we
00:13:06.640 | can compare the labels and the images we actually need CLIP. So we can initialize CLIP through
00:13:14.080 | hugging face so we use this model id and then we use model processor which is going to pre-process
00:13:20.480 | our images and text and then we also click model here which is the actual model itself.
00:13:25.680 | And then we can also run it on CUDA if you have a CUDA enabled GPU. For me I'm just running this
00:13:32.960 | on Mac so CPU. NPS as far as I know it's not supported in full or CLIP is not supported in
00:13:41.520 | full with NPS yet. So that's like the the Mac M1 version of CUDA. For now CPU is fast enough it's
00:13:49.520 | not it's not slow so it's not a problem. Now one thing here is that text transformers don't read
00:13:56.320 | text directly like we do. They need like a translation from text into what are called
00:14:03.760 | input ids or token ids which are just integer representations of either words or sub words
00:14:11.680 | from the original text. So we do that with the processor here it's passing our text padding
00:14:16.960 | we set to true so that everything is the same size we need this when we're running things in
00:14:21.520 | parallel multiple inputs in parallel essentially when we're using batches. We're not passing our
00:14:27.440 | images here and we're going to return PyTorch tensors. Okay and we're going to move all that
00:14:33.600 | to our device for me it's just CPU so it doesn't actually matter but it's fine. And then here we
00:14:38.240 | can see those tokens so we have a starter sequence token here and then we could imagine this is
00:14:44.160 | something like a photo of a tench something along those lines and then end of sequence over there as
00:14:51.760 | well. So we can encode these tokens into sentence embeddings all we do is this so pass our sentence
00:14:58.240 | our sorry our tokens in here label tokens now in here we have input ids and also another tensor
00:15:05.760 | called attention mask and that's kind of wrapped within a dictionary which is why we're using
00:15:10.960 | these two asterisks here to iteratively pass both of those tensors as
00:15:15.600 | individual items to the get text features function.
00:15:20.320 | And then after we have our label embeddings over here we just want to detach them from PyTorch
00:15:27.200 | gradient computation of the model and convert that into NumPy and then we can see from that that we
00:15:32.960 | get 10 512 dimensional embeddings. Okay so they're now the text embeddings within that click vector
00:15:41.120 | space and one thing to note here is that they're not normalized okay we can see they're not normalized
00:15:46.960 | so we can either use cosine similarity to compare them or we can normalize them and then we can use
00:15:54.240 | dot product similarity. Now if we normalize first I find the code later on to be simpler so we will
00:16:01.440 | do that here so we're going to normalize here it's pretty simple and then we can see straight
00:16:06.960 | away they're normalized we can just use dot product similarity now. That's the text embedding or the
00:16:13.120 | label embedding part now what we want to do is have a look at how we do the image embeddings
00:16:18.880 | and then how we compare them. So we're going to start with this image first just a single image
00:16:24.240 | we'll go through the whole data set in full later so we just have a cassette recorder here
00:16:29.600 | now we go down here we're just going to process the image so using the same processor
00:16:33.520 | we set text to non because there's no text this time and what we want to do is just pass that
00:16:38.480 | image in here to images and we're going to return tensors it's pytorch tensors and extract the pixel
00:16:44.320 | values. Now the reason that we have to process the image is clip expects every tensor that it sees to
00:16:50.880 | be normalized which is the first thing it does and also a particular shape it expects this shape here
00:16:58.080 | so three color channels which is through here a 224 pixel wide image and 224 pixel height image
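A rough sketch of what that normalization and target shape look like. The per-channel mean and standard deviation below are the values commonly used for CLIP inputs, but treat them as an assumption and check the preprocessing config of whichever checkpoint you load. The helper names are hypothetical, and the real processor also handles resizing and cropping, which this sketch skips.

```python
# Per-channel (R, G, B) statistics commonly used for CLIP inputs (assumed here).
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def normalize_pixel(rgb):
    """Scale 0-255 RGB values to [0, 1], then standardize each channel."""
    return tuple(
        (channel / 255.0 - mean) / std
        for channel, mean, std in zip(rgb, CLIP_MEAN, CLIP_STD)
    )


def output_shape(size=224, channels=3):
    """Shape of one processed image: channels first, then height and width."""
    return (channels, size, size)


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
shape = output_shape()
```

After this step a dark pixel maps to negative values and a bright pixel to positive ones, centred roughly on the average colour of the training images.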
00:17:08.480 | Okay, so all we're doing there is normalization and resizing, and then from there we can pass it to
00:17:16.880 | get_image_features. Again, here we didn't unpack a dictionary, because this is just a
00:17:24.400 | single tensor, so we don't need the two asterisks here, and we get a single embedding here,
00:17:31.600 | one vector which is 512 dimensions. As with our label embeddings, we're going to detach it, move
00:17:37.360 | it to CPU, and then convert to NumPy. I already have it on CPU, so this part isn't necessary,
00:17:43.600 | but if you're using CUDA it will be. And then we don't need to normalize it: when we're using
00:17:49.040 | dot product similarity to rank labels for one image, we just need one side of the calculation to be normalized, not both.
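Why normalizing only one side is enough for ranking can be checked with a few toy vectors. With unit-length label embeddings, the dot product equals the cosine similarity multiplied by the length of the image embedding, and that length is the same for every label. The 2-dimensional vectors and helper names below are made up, and the sketch uses plain Python rather than the NumPy calls from the video.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(vec):
    """Divide a vector by its Euclidean length."""
    norm = math.sqrt(dot(vec, vec))
    return [x / norm for x in vec]


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


# Hypothetical, unnormalized embeddings.
label_embs = [[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]]
image_emb = [6.0, 8.0]  # length 10

# Normalize the label side only, then take plain dot products.
norm_labels = [l2_normalize(v) for v in label_embs]
scores_one_side = [dot(image_emb, v) for v in norm_labels]

# Full cosine similarity for comparison.
scores_cosine = [cosine(image_emb, v) for v in label_embs]

best_one_side = scores_one_side.index(max(scores_one_side))
best_cosine = scores_cosine.index(max(scores_cosine))
```

The two score lists differ only by a constant factor (the image embedding's length), so the winning label is the same. The raw scores are not comparable across different images, though, so normalize both sides if you need that.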
00:17:54.560 | Okay, so with that we do NumPy dot: we have our image embedding and then we transpose the label
00:18:01.360 | embeddings, and you see that we get scores. The shape of those scores has one dimension that's not
00:18:07.280 | important here; we have those 10 similarity scores, so one similarity value for each of our 10 labels.
00:18:17.520 | Okay, so then we can take the index of the highest score, which happens to be index two, and then we
00:18:24.320 | find out which label that is. It is cassette player, okay, so it's correct. That's
00:18:31.200 | pretty cool. Now let's have a look at how we do that for the whole data set. So all I'm going to do is
00:18:36.560 | go through the whole data set in a loop, and we're going to do it in batches of 32.
00:18:42.320 | Process everything; this is all just the same stuff. Okay: process, get the embeddings, get the dot product.
00:18:49.040 | Okay, we don't need to redo this for the labels, because we just have the 10 labels throughout
00:18:53.520 | the whole thing, so it's not necessary. We'll get the argmax, and we're just going to
00:18:59.440 | extend a prediction list with all of those predictions. Okay, and let's see what we get.
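The shape of that loop can be sketched as follows, with pre-computed toy embeddings in place of the processor and model calls. In the real notebook each batch of images would be processed and embedded by CLIP inside the loop; here that step is assumed to have already happened. The final accuracy check is a simple fraction of matching predictions. All data and names are invented for illustration.

```python
def batches(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def predict_all(image_embs, label_embs, batch_size=32):
    """Argmax over label similarities, accumulated batch by batch."""
    preds = []
    for batch in batches(image_embs, batch_size):
        # The label embeddings are computed once and reused for every batch.
        batch_preds = []
        for emb in batch:
            scores = [dot(emb, label) for label in label_embs]
            batch_preds.append(scores.index(max(scores)))
        preds.extend(batch_preds)
    return preds


def accuracy(preds, targets):
    """Fraction of predictions that match the true labels."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


# Hypothetical unit-length label embeddings and five image embeddings.
label_embs = [[1.0, 0.0], [0.0, 1.0]]
image_embs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7], [0.55, 0.45]]
true_labels = [0, 1, 0, 1, 1]

preds = predict_all(image_embs, label_embs, batch_size=2)
acc = accuracy(preds, true_labels)
```

Batching only matters for throughput and memory. The predictions are the same as if every image were scored one at a time.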
00:19:08.800 | Well, let's see what the performance is there. So here we're calculating how many of them align to the
00:19:14.000 | true values, and we can see we get 0.987. Okay, so that means we get 98.7% accuracy, which is pretty
00:19:26.400 | insane when you consider that we have done no training for CLIP here. It has not seen any of
00:19:32.000 | these labels, it has not seen any of these images. This is out-of-the-box zero-shot classification,
00:19:39.840 | and it's scoring 98.7% accuracy, which is I think really very impressive. So this is I think
00:19:50.080 | a good example of why zero-shot classification, or image classification with CLIP, is such an
00:19:57.040 | interesting use case and it's just so easy right you can do this for a whole ton of data sets and
00:20:03.680 | get good performance it's not going to be state-of-the-art performance but pretty good
00:20:07.840 | performance like this, super easily. So before CLIP, as far as I'm aware, this sort of
00:20:16.000 | thing wasn't possible. Every domain adaptation to a new classification task needed
00:20:22.720 | training data it needed training and so on with this it's just a case of you need to write some
00:20:29.680 | labels maybe modify them into sentences and then you're you're good to go so that's why i think
00:20:37.440 | CLIP has created a pretty big leap forward in quite a few areas such as image
00:20:45.680 | classification. So when I think of CLIP, in the short time that it's been around
00:20:52.560 | we have multi-modal search, now zero-shot image classification, object localization or image
00:20:59.040 | localization, object detection, also zero-shot, and we'll go into that in more detail pretty soon.
00:21:05.200 | And even industry-changing tools like OpenAI's DALL-E include a CLIP model.
00:21:10.400 | Stable Diffusion, as far as I know, also includes a CLIP model. So there's this massive range of
00:21:16.800 | use cases that CLIP is being used for, and I think that's super interesting. So that's it for this
00:21:22.400 | video. I hope you have found all this as interesting as I do. So now thank you very much for watching,
00:21:29.440 | and I will see you again in the next one. Bye.