
GUI-based Few Shot Classification Model Trainer | Demo


Chapters

0:00 Intro
1:14 Classification
2:49 Better Classifier Training
6:33 Classification as Vector Search
8:47 How Fine-tuning Works
10:50 Identifying Important Samples
12:39 CODE IMPLEMENTATION
13:13 Indexing
18:59 Fine-tuning the Classifier
27:37 Classifier Predictions
30:43 Closing Notes

Transcript

Today we're going to talk about a more effective way of training classification models. Nowadays, pre-trained models dominate the field of machine learning. Very few ML projects start with us training a model from scratch. Instead, we usually start by looking for an off-the-shelf pre-trained model, whether from an online platform like PyTorch Hub or Hugging Face Hub or from our own in-house models that we've already trained.

The ecosystem of these pre-trained models, whether external or internal, has allowed us to push the limits of what is possible in machine learning. That doesn't mean, however, that everything is easy or that everything works all the time. There are always going to be challenges. Fortunately, many of these problems are shared across a huge number of pre-trained models, because they tend to have similar points of failure, so we can tackle them in similar ways.

One of those is the excessive compute and data needed to fine-tune one of these models. Focusing on classification, a very typical scenario is that we have some big model like BERT or T5 and we want to fine-tune it for classification.

One way to do that is to add a simple linear layer onto the end of the model and then fine-tune that linear layer. What I want us to focus on here is that the model that comes before it doesn't really matter; we only really care about this linear layer.

We can fine-tune that layer for many different use cases without even touching the weights of the big pre-trained model that comes before it. It's the classification layer that actually produces the final prediction, and because of this, that classification layer can become the single point of failure in producing our predictions.
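As a minimal sketch of this setup, assume the big pre-trained model has already turned each input into a fixed embedding. The random embeddings below are placeholders for real model outputs, and LinearHead is a hypothetical name; in PyTorch the head would typically be a single torch.nn.Linear layer. Only w and b would ever be trained.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder for the frozen model's output: 4 inputs, 512-dim embeddings.
# In practice these come from BERT, CLIP, etc., and are never updated.
embeddings = rng.normal(size=(4, 512))

class LinearHead:
    """A single linear layer: the only trainable part of the classifier."""

    def __init__(self, dim: int):
        self.w = np.zeros(dim)  # weight vector, living in the same space as the embeddings
        self.b = 0.0            # bias

    def logits(self, x: np.ndarray) -> np.ndarray:
        # One score per input: positive leans "in class", negative leans "not in class"
        return x @ self.w + self.b

head = LinearHead(dim=embeddings.shape[1])
scores = head.logits(embeddings)
print(scores.shape)  # (4,)
```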

So we focus on fine-tuning that classification layer, and a common approach looks something like this. First, we collect a data set that helps the model adapt to a new domain or deal with data drift. Then we slog through that data set, which, if it's going to work well, is usually large, labeling every record with its class. Once all records have been labeled, we fine-tune the classifier.

This approach works, but it is not efficient, and there's a much better way of doing it: focus our fine-tuning efforts on the essential records that actually matter. Otherwise we waste our own time and compute annotating and fine-tuning across the entire data set, when the vast majority of it probably doesn't matter.

So the question is: how do we decide which samples are essential and which are not? That's where vector search comes in. We can use vector search to search through our data set before we annotate everything and identify the records that will have the biggest impact on model performance.

That means we save our own time and a lot of compute by simply skipping the non-essential records. Some of you may be thinking: what does vector search have to do with training a classification model? It's actually very important. Many state-of-the-art models are available as pre-trained models, such as BERT, T5, EfficientNet, and OpenAI's CLIP.

These models use an enormous number of parameters and perform many complex operations. Yet when they are applied to classification, we rely on the final layers added onto the end of these huge models, such as a few simple feed-forward layers or just a linear classification layer.

The reason is that these models are not trained to produce class predictions. We can think of them as being trained to produce vector embeddings: the idea is that after pre-training, these models produce very information-rich vector embeddings.

Then, for different tasks, we add a task-specific head onto the end. That head takes the vector embedding (or embeddings) from the model and runs it through a smaller network, which, as I said, can be just a linear layer, and outputs something else: the predictions.

So the power of these models is not that they can do classification, question answering, and so on. Their power is that they produce very information-rich vectors that smaller, simpler models can then use to perform tasks like question answering and classification.

The vectors these models produce are full of useful information that has been encoded into a vector space. Imagine a 2D space with vector A here and vector B here. These two are very close to each other, and therefore they share some similar meaning.

Vector C, over here, is very far away from A and B, so it shares less meaning with them. The result is that these models are essentially creating a map of information. Using this map, they can consume data like images or text and output useful, information-rich vector representations.
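To make the 2D picture concrete, here is a small NumPy sketch. The coordinates for A, B, and C are made up, and cosine similarity is used as the closeness measure, an illustrative choice that matches the normalized dot product discussed later.

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine of the angle between a and b: 1 means same direction, -1 opposite
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Made-up 2D positions: A and B point in similar directions, C points elsewhere
A = np.array([1.0, 0.9])
B = np.array([0.9, 1.0])
C = np.array([-1.0, 0.2])

print(f"sim(A, B) = {cosine_similarity(A, B):.3f}")  # close to 1
print(f"sim(A, C) = {cosine_similarity(A, C):.3f}")  # negative
```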

So our task in classification is not to consume data, abstract meaning from it, and then classify that abstraction. The abstraction of meaning is already handled by the big models. Instead, our task is to teach a smaller model to identify the different regions within that map, the vector space.

A typical architecture for classification is a pre-trained model followed by a linear layer. We can think of the internal weights of this classifier as a vector within the wider vector space. Edo Liberty, the founder and CEO of Pinecone and former head of Amazon AI Labs, explained to me that we can use this fact and couple it with vector search to massively optimize the learning process for our classifier.

So we need to picture this problem as living within a vector space, or map. We have the internal model weights w, and we have all of these vectors x that are not yet annotated and that we haven't fine-tuned on. We want to calculate the dot product between w and each x.

If w and x point in a similar direction, the dot product is positive; if their directions are opposite, it is negative. There is one problem with the dot product here, though: it considers both direction and magnitude. That means a vector x with a large magnitude can produce a larger dot product score than another vector that points in exactly (or very nearly) the same direction as our model weights.

So we normalize all of the vectors we're comparing. This removes the magnitude problem, so we compare only the directions of the vectors. When we fine-tune the linear classifier with these vectors, it learns to align itself with vectors we label as positive and to move away from vectors we label as negative.
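The magnitude problem and its fix can be illustrated with a few made-up 2D vectors in NumPy. This is only an illustration of the idea, not code from the video.

```python
import numpy as np

def normalize(v: np.ndarray) -> np.ndarray:
    # Rescale to unit length so that only direction remains
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

w = np.array([1.0, 0.0])       # classifier weights
x_same = np.array([1.0, 0.0])  # points exactly the same way as w
x_big = np.array([5.0, 4.0])   # a different direction, but a much larger magnitude

# Raw dot product: the large vector wins even though its direction differs
print(w @ x_same, w @ x_big)   # 1.0 5.0

# After normalization the dot product equals cosine similarity,
# so the vector aligned with w now scores highest
print(normalize(w) @ normalize(x_same), normalize(w) @ normalize(x_big))
```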

This works really well, but there are still some improvements we can make. First, imagine that a single training batch contains only irrelevant samples. They will all be labeled negative one, and the classifier knows to move away from them, but it doesn't know in which direction.

Especially in a high-dimensional space, there are many directions the classifier could move in. This is problematic because the classifier ends up moving more or less at random away from those negative vectors. Another problem is that some samples can be more or less relevant than others.

Imagine we had the query "dogs in the snow" and two pieces of text: "a dog" and "a dog in the snow". Both are relevant to some degree, but "a dog in the snow" is more relevant. These two are not equally relevant, yet at the moment all we can do is label one as negative and one as positive, or both as positive. Neither option is ideal, because neither shows the full picture: both are relevant, but one more than the other.

What we need is a gradient of relevance: a continuous range from negative (e.g. -1) to positive (e.g. +1). Even if all of a batch's labels fall between -1 and -0.8, the differences between them still give the model a direction it can work out from that range of values.
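One way to feed graded scores in [-1, 1] into a sigmoid-based loss such as binary cross-entropy with logits is to rescale them to [0, 1] targets. That rescaling, and the scores themselves, are assumptions of this sketch, not necessarily what the video's notebook does. The NumPy snippet shows that the gradient of the loss with respect to each logit, sigmoid(z) - target, is graded too, so samples are pushed by different amounts rather than all equally.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce_with_logits(z, t):
    # Numerically stable binary cross-entropy on raw logits z, targets t in [0, 1]
    return np.maximum(z, 0) - z * t + np.log1p(np.exp(-np.abs(z)))

# Made-up graded relevance for the query "dogs in the snow":
# "a dog in the snow", "a dog", and something irrelevant
scores = np.array([1.0, 0.4, -1.0])
targets = (scores + 1.0) / 2.0   # map [-1, 1] onto [0, 1]

z = np.zeros(3)                  # logits from an untrained classifier
loss = bce_with_logits(z, targets)
grad = sigmoid(z) - targets      # dLoss/dz: sign and size differ per sample
print(targets, grad)
```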

Together, all of this allows our linear classifier to learn where to place itself within the vector space produced by the model layers preceding it. That describes a fine-tuning process, but we can't apply it across our entire data set. If we have a big data set, which we probably do, annotating everything would take too much time and would be a waste of effort.

To do this efficiently, we capitalize on the idea of identifying relevant versus irrelevant vectors within the proximity of the model's learned weights w, focusing our efforts on the specific area that will actually help. For an already trained classifier, those are the false positives and false negatives it predicts.

However, we usually don't have a list of false negatives and false positives. What we do know is that the solvable errors will be present near the classifier's decision boundary, i.e. the line (in many dimensions, the hyperplane) that separates positive predictions from negative predictions. So we use vector search to pull in the nearby samples that are most similar to the model weights w.

We then label those vectors and use them to train our model. The model optimizes its internal weights w; we extract them, perform another vector search with them, and keep repeating this process until the linear classifier has been optimized and produces the predictions we need.
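The loop above can be sketched end to end in NumPy under some loud assumptions: the embeddings and the "true" direction are synthetic, label() is a hypothetical stand-in for a human annotator, search() is a brute-force in-memory replacement for a vector database query with a seen filter, and scores are rescaled to [0, 1] targets for a logistic (BCE-with-logits) update. It is a sketch of the idea, not the video's implementation.

```python
import numpy as np

rng = np.random.default_rng(42)

def normalize(v):
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

# Synthetic stand-ins: 1,000 normalized embeddings and a hidden direction
# that our pretend annotator considers "relevant"
X = normalize(rng.normal(size=(1000, 64)))
true_dir = normalize(rng.normal(size=64))

def label(vectors):
    # Hypothetical annotator: graded scores in [-1, 1]
    return np.clip(3.0 * vectors @ true_dir, -1.0, 1.0)

def search(w, seen, top_k=10):
    # Brute-force dot-product search over unseen vectors only (like a seen == 0 filter)
    scores = X @ normalize(w)
    scores[seen] = -np.inf
    return np.argsort(-scores)[:top_k]

w = normalize(rng.normal(size=64))  # current classifier weights double as the query
b = 0.0
seen = np.zeros(len(X), dtype=bool)
lr = 0.5

for step in range(20):
    idx = search(w, seen)               # 1. query with w
    t = (label(X[idx]) + 1.0) / 2.0     # 2. annotate, rescaled to [0, 1] targets
    seen[idx] = True                    #    never return these records again
    for _ in range(5):                  # 3. a few gradient steps on this batch
        p = 1.0 / (1.0 + np.exp(-(X[idx] @ w + b)))
        w -= lr * X[idx].T @ (p - t) / len(idx)
        b -= lr * np.mean(p - t)
    # 4. the updated w becomes the next query on the next pass

print(f"labeled {seen.sum()} of {len(X)} records")
```

Only 200 of the 1,000 synthetic records are ever labeled; the rest are never looked at, which is the point of the method.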

By focusing annotation and training on these essential samples, we avoid wasting time and compute on vectors that don't make much of a difference. That's the general idea behind this process. Now let's look at how to put it all together and fine-tune a classifier with vector search.

There are two parts to the training process. First, we index our data: we embed everything using the preceding model layers (e.g. BERT, CLIP, and so on) and store those embeddings in a vector database. Second, we fine-tune the classifier.

That means querying with the model weights w, returning the most similar records, annotating them, and using them to fine-tune the classifier. Let's start with indexing. Given a data set of images or other formats, we first process everything through the big model preceding our linear classifier to create the vector embeddings.

For our example we're going to use a model called CLIP, which can understand both text and images. It has been trained on text-image pairs and has learned to encode them into a shared vector space, placing matching text and images as close together as possible. So before indexing anything, we initialize a data set that we can then encode with CLIP.

We're going to use this data set from the Hugging Face datasets hub, and we can pip install everything we need here. We take the train split, which contains about 9,500 images. Some of those are radios, like you can see here; there are also pictures of dogs, trucks, and a few other things.

We can see an array of one of those images right there, but that's not so important for what we're doing. What we want to do is initialize both the model and the preprocessing steps applied before the data is fed into the model, which we do here.

We initialize CLIP using this model ID, which is one version of the CLIP model. The preprocessor takes images and processes them so that CLIP can read them. Here we go through all of these steps.

First comes the preprocessing, and from that we get the image features, a vector representation of the image. In this case we've embedded the Sony radio image, which gives us a 512-dimensional vector embedding. The embeddings from CLIP are not normalized, and since we're going to use dot product both within the model and during our vector search, we should really normalize them. We do that here, and now all the values lie between negative one and plus one.

That's how we create a vector embedding for a single item, but we want to do this for loads of items, and we also want to index them and store them inside a vector database.
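The embedding-plus-normalization step might look like this with the Hugging Face transformers library. The checkpoint "openai/clip-vit-base-patch32" is an assumption (it produces 512-dim embeddings, matching the text), since the video's exact model ID isn't shown. The transformers and torch imports are placed inside the functions so the file loads even where those libraries aren't installed.

```python
import numpy as np

def l2_normalize(v: np.ndarray) -> np.ndarray:
    # CLIP embeddings are not unit length; rescale each row to length 1
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

def load_clip(model_id: str = "openai/clip-vit-base-patch32"):
    # Assumed checkpoint; imported lazily so this module loads without transformers
    from transformers import CLIPModel, CLIPProcessor
    return CLIPModel.from_pretrained(model_id), CLIPProcessor.from_pretrained(model_id)

def embed_images(images, model, processor) -> np.ndarray:
    """Return an (n, dim) array of normalized CLIP image embeddings."""
    import torch

    inputs = processor(images=images, return_tensors="pt")  # resize, crop, rescale
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    return l2_normalize(features.numpy())
```

With that checkpoint, calling model, processor = load_clip() and then embed_images([image], model, processor) should return a (1, 512) array whose entries all lie in [-1, 1].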

We're going to use Pinecone for this. You may need to sign up for a free API key if you haven't already. We initialize our connection to Pinecone here; you just put your API key here, and it's all free. Then we create an index, and a few things are important here.

The index name doesn't actually matter; you can use whatever you want. What you do need is the correct dimensionality, which is the 512 we saw earlier, so that's what we put here. We also need to make sure we're using dot product similarity.

We also include this metadata config. Once we see an image and label it, we're going to tell Pinecone that we don't want that image returned again, so that we keep moving through the data instead of over-optimizing on the same ten or so images. Then we connect to the index after creating it.
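To show what these index settings are for without needing a Pinecone account, here is a tiny in-memory stand-in. ToyIndex is hypothetical and is not the Pinecone client API; it only mirrors the three settings just discussed: a fixed dimensionality, dot-product similarity, and a seen metadata field that queries can filter on.

```python
import numpy as np

class ToyIndex:
    """In-memory stand-in for a vector index; NOT the Pinecone client API."""

    def __init__(self, dimension: int):
        self.dimension = dimension  # every vector must have exactly this length
        self.vectors = {}           # id -> np.ndarray
        self.metadata = {}          # id -> dict, e.g. {"seen": 0}

    def upsert(self, records):
        for rid, values, meta in records:
            values = np.asarray(values, dtype=float)
            if values.shape != (self.dimension,):
                raise ValueError(f"expected {self.dimension} dims, got {values.shape}")
            self.vectors[rid] = values
            self.metadata[rid] = dict(meta)

    def update(self, rid, set_metadata):
        self.metadata[rid].update(set_metadata)

    def query(self, vector, top_k=10, seen=None):
        # Dot-product similarity, optionally restricted to a given "seen" value
        q = np.asarray(vector, dtype=float)
        hits = [
            (rid, float(v @ q))
            for rid, v in self.vectors.items()
            if seen is None or self.metadata[rid]["seen"] == seen
        ]
        return sorted(hits, key=lambda hit: hit[1], reverse=True)[:top_k]

index = ToyIndex(dimension=3)
index.upsert([("a", [1.0, 0.0, 0.0], {"seen": 0}), ("b", [0.6, 0.8, 0.0], {"seen": 0})])
index.update("a", {"seen": 1})               # "a" has been labeled already
print(index.query([1.0, 0.0, 0.0], seen=0))  # [('b', 0.6)]
```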

To add the single image embedding we just created, we do this: we give it an ID, convert the embedding into a list, and upsert it. With that, we have one embedding in our vector index.

Of course, we want our full data set in there so that we can search it, add data, and so on. To do that, we use this loop here. I won't go through it, because it's literally what we've just done; the only thing I've added is this check for grayscale versus RGB images.

Otherwise it's exactly the same, just at a larger scale, and we also add metadata here: a "seen" field. We set it to zero for all images to start with. Once we've seen a set of images, marked them as positive or negative, and trained with them, we set "seen" to one so that we don't return them again.

Now, with this radio, let's have a quick look at how we might query. We create our query vector xq here, doing the same thing as before: we embed the image, normalize it, and then query with it. That returns these items from Pinecone, so let's see what they look like. The first one is obviously that radio: it is the most similar vector to itself, so naturally it's the first thing returned.
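The batched indexing loop might look roughly like this. It assumes images arrive as NumPy arrays (the video checks grayscale versus RGB, but its exact check isn't shown), and embed_fn is a hypothetical callable, such as a CLIP wrapper, that returns normalized embeddings. Every record starts with seen set to 0 so that the first filtered queries can return it.

```python
import numpy as np

def to_rgb(image: np.ndarray) -> np.ndarray:
    # Grayscale arrays are (H, W); repeat the single channel to get (H, W, 3)
    if image.ndim == 2:
        return np.stack([image] * 3, axis=-1)
    return image

def batched_records(images, embed_fn, batch_size=64):
    """Yield lists of (id, values, metadata) records ready to upsert."""
    for start in range(0, len(images), batch_size):
        batch = [to_rgb(img) for img in images[start:start + batch_size]]
        vectors = embed_fn(batch)  # hypothetical: (len(batch), dim) normalized array
        yield [
            (str(start + i), vec.tolist(), {"seen": 0})  # every image starts unseen
            for i, vec in enumerate(vectors)
        ]

def fake_embed(batch):
    # Placeholder embedder so the sketch runs without a model
    return np.ones((len(batch), 4)) / 2.0

images = [np.zeros((2, 2)), np.zeros((2, 2, 3)), np.zeros((2, 2))]
batches = list(batched_records(images, fake_embed, batch_size=2))
print([rid for batch in batches for rid, _, _ in batch])  # ['0', '1', '2']
```

Each yielded list would then be passed to the index's upsert call.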

Next we have a car radio, then another Sony radio (I think it's even the same model), another Sony radio that also seems to be the same model, and then another radio that looks very similar. So clearly those CLIP embeddings are pretty good. Now what we want to do is fine-tune a linear classifier on top of them to classify these different images.

To do that, I'm going to start from scratch in a new notebook. You can find links to all of these notebooks in the video description or, if you're reading the article, in the resources section at the bottom.

Here we initialize the connection to the index again. You don't need to do this if you just ran the previous code; you can keep your existing connection to the index. Again, we load the data set, which you also don't need to do if you've already done it.

Then we initialize the model, CLIP, and the processor. One thing is different here: you could tokenize using the other preprocessor, but for the sake of covering everything I'm also showing how to do it with the fast CLIP tokenizer (CLIPTokenizerFast). Here we initialize just the tokenizer part of the CLIP preprocessor.

We set up the prompt "dogs in the snow", tokenize it to get a set of token IDs, and then use the model's get_text_features method to get a vector embedding of that prompt. Then we come down here.

We create the query vector from that embedding, retrieve the top 10 most similar records, and store them in xc, the returned contents. There are a few things in xc that we don't actually need; what we want are the IDs and the values.
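Pulling the IDs and values out of a query response could be done with a helper like the one below. The response shape (a "matches" list whose items have "id", "score", and "values" keys) is an assumption modeled on dict-style access to Pinecone query results; values are only present if the query asked for them. The example response is made up.

```python
import numpy as np

def ids_and_values(response):
    """Split a query response into record IDs and an (n, dim) array of vectors."""
    matches = response["matches"]
    ids = [m["id"] for m in matches]
    values = np.array([m["values"] for m in matches], dtype=float)
    return ids, values

# Made-up response with two matches
xc = {"matches": [
    {"id": "12", "score": 0.31, "values": [0.6, 0.8]},
    {"id": "7", "score": 0.29, "values": [1.0, 0.0]},
]}
ids, values = ids_and_values(xc)
print(ids, values.shape)  # ['12', '7'] (2, 2)
```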

First we get the IDs, then the values, and we can see what it has returned for "dogs in the snow". This one is not a dog in the snow, but you can see why it's a bit confused: the sand in the background does look white and snowy.

The rest of these are dogs in the snow, other than this one. So it's returning the right thing, but let's say we don't want dogs in the snow and instead want something slightly different, for example dogs at dog shows. Let's go through that.

This code isn't really that important; it's just a little interface I built within Jupyter so that we can quickly go through and label the images. I won't run it again, so I'll just run this part here. Basically, it shows an image and asks you to rate it.

For example, it shows this image here and asks how you would rate it from negative one to one. You go through and give each image your rating, and it produces a dictionary that maps these IDs to the scores you gave them.

You can see all the scores I gave the last time I ran this, and you can double-check that the IDs and scores are aligned here. They are, so you don't need to worry about that. Then we get the values, which are the training inputs for the linear classifier, and the labels, i.e. the scores.

Next, we initialize a PyTorch linear classifier layer. In most cases, I imagine we're going to have a trained linear classifier already, so I'm emulating that here: I take the query vector, reshape it, and insert it as the initial model weights w. Then we initialize the loss function, BCEWithLogitsLoss, and the optimizer, stochastic gradient descent.
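Before training, the annotator's ratings need to line up row for row with the retrieved vectors. A minimal sketch, assuming the labeling widget's output is a dict from record ID to score; training_arrays is a hypothetical helper name and the data is made up.

```python
import numpy as np

def training_arrays(ids, values, scores):
    """Align annotator scores with retrieved vectors, row for row.

    `scores` maps record ID -> rating in [-1, 1]; the row order follows `ids`.
    """
    missing = [rid for rid in ids if rid not in scores]
    if missing:
        raise KeyError(f"no score for IDs: {missing}")
    X = np.asarray(values, dtype=float)
    y = np.array([scores[rid] for rid in ids], dtype=float)
    return X, y

ids = ["12", "7"]
values = [[0.6, 0.8], [1.0, 0.0]]
scores = {"7": -0.5, "12": 0.9}  # dict order need not match ids
X, y = training_arrays(ids, values, scores)
print(y)  # [ 0.9 -0.5]
```

Looking scores up by ID, rather than relying on dictionary order, is what guarantees the alignment the video double-checks by eye.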

You'll probably find this learning rate quite high, and it is. We set it high so that we see a lot of quick movement through the data set. If you're actually implementing something like this, you might want to use a lower learning rate.

With that, we create this function, fit, which is just a training loop, and we can set the number of iterations per training loop. Again, you might want to lower this if you don't want to move so quickly through the vector space and want to keep things a bit more stable.

Then we call fit. The model weights are optimized and change, and they represent the next query that we pass to our vector database. We convert them into a flat list so that we can query Pinecone with them.
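The training step can be sketched without PyTorch as a NumPy stand-in for a torch.nn.Linear(dim, 1) layer trained with BCE-with-logits loss and plain SGD. Two assumptions are worth flagging: ratings are rescaled from [-1, 1] to [0, 1] targets, and lr = 0.2 and iters = 5 are illustrative values, not the video's settings. The weights start from the query vector, and after fit() the flattened weights become the next query.

```python
import numpy as np

class LinearClassifier:
    """NumPy stand-in for torch.nn.Linear(dim, 1) + BCE-with-logits loss + SGD."""

    def __init__(self, init_w, lr=0.2):
        self.w = np.array(init_w, dtype=float)  # start from the query vector
        self.b = 0.0
        self.lr = lr

    def logits(self, X):
        return X @ self.w + self.b

    def fit(self, X, ratings, iters=5):
        # Assumption: ratings in [-1, 1] are rescaled to BCE targets in [0, 1]
        t = (np.asarray(ratings, dtype=float) + 1.0) / 2.0
        for _ in range(iters):
            p = 1.0 / (1.0 + np.exp(-self.logits(X)))
            err = p - t                              # dLoss/dlogit for each sample
            self.w -= self.lr * X.T @ err / len(X)  # gradient step on the mean loss
            self.b -= self.lr * err.mean()

    def next_query(self):
        # Flattened weights as a plain list, ready to use as the next search vector
        return self.w.tolist()

clf = LinearClassifier(init_w=[1.0, 0.0])
X = np.array([[0.0, 1.0], [0.0, -1.0]])  # one liked sample, one disliked sample
clf.fit(X, ratings=[1.0, -1.0])
xq = clf.next_query()
print(xq)  # the weights have rotated towards the positively rated sample
```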

So that we don't return the same records we just went through, we update the metadata attached to each of the vectors we've just seen, setting seen = 1. We do that because we now add a filter to the next query where seen = 0.

Then we return the next set of results, and we can see some other images here. What I'm doing here is trying to optimize for dogs in fields, and then from dogs in fields we'll try to move to dogs at dog shows.

We'll go through this part quickly. This is just tuning: I've put everything we just did into a single function to keep things simpler. As we go through, you can see things shifting more towards dogs in fields here.

Here it goes a bit crazy, because I'm labeling a lot of dogs as negative, so the classifier concludes that maybe I don't want to see any dogs at all and pushes away from them. Obviously I don't want that to happen, so here I label everything as negative except, I think, this image with a field (or maybe this one) and also this image of a dog.

Then we move towards dogs again, focusing on them and pushing towards dogs. Here in the middle is the first image of dogs at a dog show. Actually, I think this one is also a dog show, so technically that would be the first one.

But this is what I'm looking for: more of this sort of image. So we focus on that and push for it a little more. In the next step we have a few more dog shows, here and here. We keep pushing for that, and you can see that with each step there are more dogs at dog shows, because that's what I'm labeling as more relevant. Now we're really getting into that region of the space. We keep going, and now pretty much everything we return is a dog show.

That's the final step. Now that we've done that, we want to set all of the seen labels in our vector database back to not seen, because we want to search again. We could either search without the filter just to check that the classifier has been trained, or reset all of those seen labels. If you want to go through the data again and focus more on those records, that's when you'd reset all the labels back to zero.

To do that, I use a while loop. We keep searching for everything where seen = 1, get those IDs, and mark them as not seen. Once no more items are returned, everything has been set to not seen, because nothing is left with seen equal to true, so at that point we break.

After that, if we search again, we get a completely unfiltered view of the search results. And here we go: we can see loads of dogs at dog shows. There's one here that isn't a dog at a dog show, but I think the rest of them are. With that, we've fine-tuned our classifier. Now that we've finished optimizing those model weights, we can save them to file, which we do here. Next, let's look at how the model performs at actually classifying images.
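The reset loop can be sketched against a plain dictionary of metadata. query_seen is a hypothetical function standing in for a filtered query that returns up to limit IDs with seen == 1; against a real index, each pass would re-query, update the returned records, and stop once nothing comes back.

```python
def reset_seen(metadata, query_seen, batch_size=100):
    """Set seen = 0 on every record, one filtered batch at a time."""
    total = 0
    while True:
        ids = query_seen(batch_size)
        if not ids:        # nothing left with seen == 1
            break
        for rid in ids:
            metadata[rid]["seen"] = 0
        total += len(ids)
    return total

# In-memory example: 250 records, the first 120 already seen
metadata = {str(i): {"seen": int(i < 120)} for i in range(250)}

def query_seen(limit):
    return [rid for rid, meta in metadata.items() if meta["seen"] == 1][:limit]

print(reset_seen(metadata, query_seen))  # 120
```

Because each pass only returns records that are still marked as seen, the loop is guaranteed to terminate once every record has been flipped.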

Again, we move to another notebook, number 02, classifier tests. Here we test the classifier on a set of images it has not seen before. We initialize everything again; if you've already loaded everything and you're in the same notebook, you don't need to do this.

We need to load the validation split from ImageNet. You can see that before this was the train split; now it's validation, so you will need to rerun this bit. There are about 4,000 images there. Let's start by checking the predictions for some specific images. This one is a dog at a dog show.

We preprocess it, get the image features from CLIP, and then make a prediction by passing the vector output by CLIP into the classifier. We can see a pretty positive value there. Remember, positive means a true prediction and negative means a not-true prediction.

So that's correct: it has predicted that this is a dog show. Now let's look at this one, which is not a dog show, so we should see a negative value. We run it, and yes, we get a pretty negative value there.

Now we can label the full data set, using a cutoff point between what is viewed as relevant and what is irrelevant: basically, anything positive counts as relevant. We do that here. I won't go through it in detail, since it's essentially the same thing we just did.

I'm just making a list of these predictions and adding a column called predictions to the ImageNet data set, so we now have these three columns. Then we filter out any results where the prediction is not positive, which leaves 23 results. Let's have a look at what those are.
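Saving the weights and then scoring a whole split with a cutoff at zero could look like this. The weights and validation embeddings are synthetic stand-ins, and an .npz file via NumPy is just one convenient format; the video's saving code isn't shown.

```python
import os
import tempfile

import numpy as np

# Synthetic stand-ins for trained weights and validation-set embeddings
rng = np.random.default_rng(7)
w, b = rng.normal(size=8), 0.0
X_val = rng.normal(size=(100, 8))

# Save the weights, then reload them as a fresh notebook session would
path = os.path.join(tempfile.mkdtemp(), "classifier_weights.npz")
np.savez(path, w=w, b=b)
with np.load(path) as saved:
    w_loaded, b_loaded = saved["w"], float(saved["b"])

# Score the whole split; a logit above zero counts as a positive prediction
predictions = X_val @ w_loaded + b_loaded
positive_idx = np.flatnonzero(predictions > 0)

# Rank positives from most to least confident, as when browsing the results
ranked = positive_idx[np.argsort(-predictions[positive_idx])]
print(f"{len(ranked)} positives; top indices: {ranked[:5]}")
```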

Of those 23 positive results, I think almost all are dog shows. As we go through, each one has been scored less highly than the last, but all of these are still scored very highly. Then, right at the bottom of the positively labeled images, we get this, I don't know, chainsaw thing.

It's kind of random; I don't know why it's in there. So other than these two images right at the end, everything is a true positive: it predicted everything correctly except these two. This one, I have no idea why. This one I kind of understand: dogs in a field.

Generally speaking, I think these are very good results, and we got them from fine-tuning our classifier on not that many images, maybe around 50. So we get really good results on a very small amount of data, because we're using vector search to focus our annotation and training on the most important part of the data set.

Doing this for an image classifier is just one example. We can do it with text, in recommendation engines, or even for anomaly detection; there's a huge number of use cases. Basically, whenever you need to classify something efficiently, you can use this approach, as long as you're using something like a linear classifier.

For me, this is a really cool method for efficiently training classification models. Thanks a lot to Edo for sharing it with me, explaining it, and walking me through everything. I think this is a really useful method, and I hope you find it useful as well.

So thank you very much for watching and I will see you again in the next one. Bye.