Segment Anything - Model explanation with code
Chapters
0:00 Introduction
1:20 Image Segmentation
3:28 Segment Anything
6:58 Task
8:20 Model (Overview)
9:51 Image Encoder
10:07 Vision Transformer
12:30 Masked Autoencoder Vision Transformer
15:32 Prompt Encoder
21:15 Positional Encodings
24:52 Mask Decoder
35:43 Intersection Over Union
37:08 Loss Functions
39:10 Data Engine and Dataset
41:35 Non Maximal Suppression
00:00:00.000 |
Hello guys, welcome to my new video about the new model from Meta called Segment Anything. 00:00:06.240 |
As you have heard from the internet, Segment Anything is a model that allows you to segment 00:00:11.660 |
an image into masks, without caring about what kind of image we are talking about. 00:00:18.500 |
So before, for example, we had segmentation models for medical applications or for pedestrian detection. 00:00:27.420 |
But Segment Anything can work with any kind of image. 00:00:30.840 |
And the second novelty is that it allows you to work with prompts. 00:00:36.420 |
So given a prompt, like a list of points or a bounding box or a text, it can segment the part of the image indicated by the prompt. 00:00:44.920 |
And this makes it a powerful foundation model just like BERT, just like GPT for NLP applications. 00:00:51.240 |
So it means that it can later be fine-tuned by working on the prompt and not only on the model weights. 00:00:59.920 |
In this video, we will look at what the model is, how it works, what its composition is, 00:01:05.720 |
the data set it was trained upon, and also I will create a parallel with the code. 00:01:09.900 |
So I will also show you the code of this model along with an explanation of how it works 00:01:14.840 |
so that you can see things from a high level down to the lower level. 00:01:25.420 |
Image segmentation is the process of partitioning a digital image into multiple regions such 00:01:30.600 |
that pixels that belong to the same region share some characteristics. 00:01:35.360 |
For example, if we are given this image, I think it was a painting from Van Gogh, and 00:01:41.160 |
if we partition it using segmentation, it will be partitioned into these masks, in which 00:01:46.120 |
for example, we have one mask for the grass, one for this house, one for this tree, and so on. 00:01:53.900 |
And before we had Segment Anything, we had many models, one specifically tuned for each application. 00:01:58.840 |
For example, for medical imaging, we may want to locate, given an image of cells, which 00:02:04.000 |
ones are the tumor cells and which ones are not. 00:02:08.520 |
Or in object detection, we may want to know where the pedestrians are in our image for 00:02:13.760 |
self-driving cars, or, for example, in satellite images, we may want to segment rivers, and so on. 00:02:25.480 |
But this task also had many challenges, because, first of all, creating a dataset for segmentation is very expensive. 00:02:32.440 |
I mean, imagine you need an operator who, pixel by pixel, has to define which object 00:02:39.600 |
each single pixel belongs to, one pixel after another. 00:02:43.400 |
So it takes a lot of time to annotate images for image segmentation. 00:02:48.920 |
Just as I said before, the models usually were application-specific, and also the previous models were not promptable. 00:02:54.840 |
That is, we could not tell the model, ah, please just select all the masks for cats in this image. 00:03:05.120 |
If the model was trained to detect that kind of mask, it detected it. 00:03:13.680 |
But now we can actually ask the model which kind of object we want to build the mask for. 00:03:18.200 |
And we can do that by using points, by using bounding box, by using text also. 00:03:25.300 |
And let's have a look at the website from Meta. 00:03:28.640 |
So if we go to segmentanything.com, we have this page called Demo, in which we can try the model. 00:03:36.960 |
And we can select any image, let's say one of these bears. 00:03:41.640 |
And the model can work with clicks, as I said before. 00:03:45.240 |
So if we click here, it tells the model that we want something from here. 00:03:53.700 |
Maybe the model selected we wanted only the face of the bear. 00:03:58.000 |
So we can remove some area by clicking on something that we want to remove. 00:04:02.920 |
So if we click the belly of this bear, it will remove the bottom part of the body. 00:04:08.920 |
The second thing we can do is by using a box. 00:04:11.200 |
For example, we may say, okay, select all the animals in this case. 00:04:17.360 |
Then we can guide the model by adding some points. 00:04:21.280 |
For example, this point was not included even though we wanted all the animals in this box. 00:04:28.740 |
But suppose that the model included the ears, or we also wanted to exclude something from the mask. 00:04:38.600 |
For example, we want to exclude, let's say, this paw here. 00:04:43.140 |
So we can add another point with remove area and put it here. 00:04:49.580 |
So of course, the model is not perfect, because the prompt can be ambiguous in some cases. 00:05:02.800 |
And the second thing to notice is that the model is running in my browser. 00:05:18.320 |
So segment anything introduces three innovations. 00:05:23.240 |
The first is called a promptable segmentation task, which can work with points, with boxes, with 00:05:27.960 |
text, or a combination of the above, for example, a box and a few points. 00:05:33.920 |
The second is a model that is fast. 00:05:36.560 |
It's an encoder-decoder model that takes around 50 milliseconds in the web browser to generate a mask, and it is ambiguity-aware. 00:05:46.160 |
For example, given a point, that point may correspond to multiple objects. 00:05:50.540 |
For example, if we click here, for example, in this area, it may indicate this vegetable, 00:05:57.480 |
or all this vegetable, or only the white part of this vegetable. 00:06:02.520 |
And this means that the model cannot know, of course, what is our intent. 00:06:06.040 |
So the model will return the three most likely masks, indicating for example the whole object, a part, and a subpart. 00:06:13.960 |
Then of course, the model was trained on data. 00:06:16.920 |
And this data was a big dataset composed of 1.1 billion masks. 00:06:22.880 |
And the interesting thing is that these masks were actually generated by the model itself. 00:06:27.700 |
So they started by training the model on a very small dataset. 00:06:31.360 |
Then they used this model, which was trained on the small dataset, to create an even bigger 00:06:36.480 |
dataset with the help of operators, of course, human annotators. 00:06:40.200 |
And then after a while, they asked the model to generate all the masks automatically without 00:06:46.520 |
any human help, and then trained the model on these automatically generated masks. 00:06:52.760 |
And the result is the one you just saw on the browser. 00:06:55.240 |
It's a model that can segment anything with a very high precision. 00:07:00.840 |
As you remember, in NLP, we have the next token prediction task, which is used by most language models. 00:07:06.760 |
And so basically, we give a prompt to a model and the model has to come up with a completion of that prompt. 00:07:16.500 |
This is what happens with BERT and all the other language models. 00:07:20.240 |
And this is exactly what they wanted to do here. 00:07:23.580 |
They wanted to use a prompt to build a foundation model. 00:07:27.120 |
So a foundation model is a model that is trained on a lot of data and that can be fine-tuned, 00:07:33.000 |
let's say, for a specific task by working on the prompt. 00:07:37.160 |
And this prompt was made so that it can handle ambiguity. 00:07:42.500 |
For example, if the single click that we make is referring to multiple objects, the model 00:07:49.240 |
must return at least one mask that is reasonable for that click. 00:07:54.120 |
So these are the requirements that the authors set for their model. 00:07:58.260 |
And we saw one ambiguous mask before, but for example, here we can see another case. 00:08:02.720 |
For example, if we click on the Z here, this point could refer to the Z itself or to the larger objects that contain it. 00:08:14.840 |
So the model has to return at least one of these three. 00:08:18.360 |
And in the best case, of course, all of them. 00:08:26.640 |
The model, as we saw before, is an encoder-decoder model. 00:08:32.120 |
There is an image encoder that, given an image, creates an embedding. 00:08:38.080 |
Then we have a prompt encoder that can encode the prompts given by the user, which can be points, boxes, or text. 00:08:47.960 |
Then we will see later what this mask here is, but basically it means that if we run the 00:08:52.440 |
model with an initial prompt, for example, a single point, the model will build an initial mask. 00:08:59.360 |
Then if we want to modify our prompt by adding another point, we can, instead of letting 00:09:04.680 |
the model guess what we want, we can reuse the previous output to guide the model into 00:09:10.160 |
telling the model that, okay, the previous mask was good, but not perfect. 00:09:14.400 |
So I am giving you, again, the previous mask as a hint, as a starting point, plus some 00:09:19.380 |
few points to guide you into telling you what I want you to remove or add to the mask. 00:09:25.920 |
This is the whole idea of this mask we can see here. 00:09:29.040 |
So this mask is the result of a previous prompting of this model. 00:09:34.920 |
And then the model has a decoder that, given the prompt, the previous mask, and the embedding 00:09:40.160 |
of the image, has to predict the mask along with a confidence score. 00:09:46.880 |
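As a rough sketch of how these three components fit together: the module names and signatures below are assumptions for illustration, not the actual SAM code.

```python
# A sketch of the encoder / prompt encoder / decoder composition described above.
import torch.nn as nn

class TinySAM(nn.Module):
    def __init__(self, image_encoder, prompt_encoder, mask_decoder):
        super().__init__()
        self.image_encoder = image_encoder    # heavy ViT, run once per image
        self.prompt_encoder = prompt_encoder  # points / boxes / previous mask -> embeddings
        self.mask_decoder = mask_decoder      # lightweight, run once per prompt

    def forward(self, image, points, labels, prev_mask=None):
        image_embedding = self.image_encoder(image)                    # (B, 256, H', W')
        sparse, dense = self.prompt_encoder(points, labels, prev_mask)
        masks, iou_scores = self.mask_decoder(image_embedding, sparse, dense)
        return masks, iou_scores            # candidate masks plus their confidence scores
```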
Now let's go and see what the image encoder is and how it works. 00:09:52.360 |
In the paper, they say that they use the MAE pre-trained vision transformer. 00:09:57.640 |
So MAE means that it's a masked autoencoder and a vision transformer. 00:10:02.720 |
So let's review these two terms, what they mean and how they work. 00:10:08.440 |
The vision transformer was introduced in a very famous paper, I think a few years ago. 00:10:13.920 |
The paper name is "An Image is Worth 16 by 16 Words" and it's from Google Research, Google Brain. 00:10:21.360 |
Basically what they did is, they take a picture, in this case, this one, they divide it into 00:10:26.160 |
patches of 16 by 16, and then they flatten these patches, so they create a sequence of these patches. 00:10:33.680 |
For example, this is the first patch, this is the second one, this is the third one, and so on. 00:10:38.800 |
And then they created embeddings by using a linear projection. 00:10:42.480 |
So an embedding that somehow captures the information from each of these patches. 00:10:49.120 |
They feed this whole sequence of patches, or rather the embeddings of these patches, along 00:10:55.480 |
with the positional encodings, to the transformer. 00:10:59.860 |
But they not only feed the list of patches, but also another token here that is prepended to the sequence. 00:11:10.260 |
And this token is called the class embedding, the class token. 00:11:19.160 |
Basically if you remember the transformer, when we have a sequence in the transformer 00:11:23.160 |
encoder, the transformer, with its self-attention mechanism, basically allows each token to relate to all the other tokens. 00:11:30.040 |
So in the output we will have again a sequence of tokens, but the embedding of each token 00:11:37.240 |
will somehow capture the interaction of that token with all the other tokens. 00:11:42.620 |
And this is the idea behind adding this class token we have here. 00:11:46.920 |
So we send it to the transformer, and at the output of the transformer, since we will 00:11:52.680 |
get another sequence from the transformer encoder, we just take the first 00:11:57.740 |
token here, and we map this token to a multilayer perceptron that has to predict the class of the image. 00:12:10.080 |
Because this token, thanks to the self-attention mechanism, has interacted with all of the other patches. 00:12:15.920 |
So this token somehow captures the information of the other patches, and then we force the 00:12:23.040 |
model to convey all this information to this single token. 00:12:27.120 |
So this is the idea of the vision transformer and of the class token. 00:12:34.180 |
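A minimal sketch of the patching step just described, assuming a 224x224 image, 16x16 patches and a 768-dimensional embedding (the usual ViT-Base numbers, used here only as an assumption):

```python
# Split into 16x16 patches, project each patch to an embedding, prepend a
# learnable class token, and add positional embeddings.
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # a strided convolution is the usual way to do "flatten + linear projection"
        self.proj = nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, dim))

    def forward(self, x):                    # x: (B, 3, 224, 224)
        x = self.proj(x)                     # (B, dim, 14, 14)
        x = x.flatten(2).transpose(1, 2)     # (B, 196, dim): the sequence of patches
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)       # prepend the class token
        return x + self.pos_embed            # add positional encodings
```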
They took the vision transformer and they transformed it into a masked autoencoder vision transformer. 00:12:39.240 |
And this happened in another paper called Masked Autoencoders are Scalable Vision Learners 00:12:46.920 |
Now in this case they still have an input image, in this case this one, but what they 00:12:53.360 |
did is they split it into patches, but then they masked it out, they deleted some patches 00:12:59.120 |
and replaced them with zeros, so they hide some patches here, and if I remember correctly they mask out around 75% of the patches. 00:13:10.920 |
Then they only take the visible patches, create a sequence, a linear sequence here, they give 00:13:15.880 |
it to the encoder of a transformer, which still produces a sequence as output, as we saw before. 00:13:24.640 |
Then what they do, they take this sequence that is the output of the encoder of the transformer, 00:13:31.160 |
they again recreate the original image sequence, so if we knew that the first patch was empty, 00:13:40.160 |
then they put an empty space here, then another empty one, then the third one was visible, so they put it here. 00:13:46.800 |
Then the fourth, the fifth and the sixth were deleted, so four, five and six were deleted, 00:13:52.440 |
then they take the next one visible and they put it here. 00:13:55.000 |
So basically to the decoder they give the visible patches and the non-visible patches 00:14:00.920 |
along with the geometric information of which patches were visible and where. 00:14:04.540 |
So they are added in the same sequence in which they were cancelled out in the original 00:14:07.880 |
image, and then they ask the decoder to predict the original image, only being able to see the visible patches. 00:14:16.960 |
So basically the decoder has to come up with a full image, being only able to access 25% of the original patches. 00:14:24.320 |
And what they saw is that the decoder was actually able to rebuild the original image, 00:14:29.120 |
maybe not with the perfect quality, but a reasonable quality. 00:14:33.800 |
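A minimal sketch of the random masking step just described, in the spirit of the MAE paper rather than its exact implementation:

```python
# Keep only ~25% of the patch embeddings for the encoder and remember the
# ordering so the decoder can later put mask tokens back in the missing positions.
import torch

def random_masking(patches, keep_ratio=0.25):
    # patches: (B, N, dim) sequence of patch embeddings
    B, N, dim = patches.shape
    num_keep = int(N * keep_ratio)
    noise = torch.rand(B, N)                      # one random score per patch
    ids_shuffle = noise.argsort(dim=1)            # a random permutation of the patches
    ids_keep = ids_shuffle[:, :num_keep]          # indices of the visible patches
    visible = torch.gather(patches, 1, ids_keep.unsqueeze(-1).expand(-1, -1, dim))
    ids_restore = ids_shuffle.argsort(dim=1)      # lets the decoder undo the shuffle
    return visible, ids_keep, ids_restore         # the encoder only sees `visible`
```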
And what the authors of the Segment Anything paper did is they took the encoder part of the masked 00:14:43.200 |
autoencoder vision transformer, because they are interested in this: the embeddings 00:14:48.680 |
learned by the encoder, so the output of the encoder. 00:14:52.160 |
Because if the model is able to predict the original image, given only this embedding 00:14:57.440 |
of the visible patches, it means that these embeddings capture most of the information 00:15:02.320 |
of the image, which can be then reused to rebuild the original image. 00:15:06.320 |
So this is what they want, and this is what we want from an encoder in general. 00:15:09.960 |
We want the encoder to create a representation of something that captures most of its salient 00:15:14.840 |
information without caring about the extra information that is not necessary. 00:15:20.160 |
So this allows you to reduce the dimensionality of the original image while preserving the most important information. 00:15:26.280 |
And this is why they use the encoder of the masked autoencoder vision transformer. 00:15:33.880 |
Now that we have seen what is the image encoder, which is basically creating an embedding, 00:15:39.600 |
this one here, now we go to the next part, which is the prompt encoder. 00:15:43.240 |
So the job of the prompt encoder is also to encode the prompt, which is the list of points 00:15:48.000 |
chosen by the user, the boxes selected by the user and the text. 00:15:52.640 |
We will not go into the details of the text encoder, which is basically just the text encoder of the CLIP model. 00:15:58.160 |
So if you are not familiar with the CLIP model, I suggest you watch my previous video about 00:16:03.640 |
the CLIP model; it's quite interesting per se, so it deserves its own video. 00:16:11.560 |
But basically the idea is the same as with the image encoder. 00:16:14.520 |
So we have a text and we want some representation that captures most of the information about 00:16:21.000 |
And this is done by the CLIP text encoder. 00:16:29.220 |
Now in their paper, Segment Anything, they say that they consider two 00:16:33.680 |
types of prompts: the sparse prompts, which are the points, boxes and text, and the dense prompts, which are the masks. 00:16:41.040 |
For the text prompts, they just use the text encoder from CLIP, as we can see here, while 00:16:46.160 |
for the other two sparse prompts, so the points and the boxes, they basically map them to embeddings. 00:16:54.800 |
So an embedding that tells the model what this point is referring to inside of the image. 00:17:06.420 |
Here we can see that basically, they take the sparse prompts, and they are mapped to 256-dimensional embeddings. 00:17:27.020 |
The points are a sequence of X and Ys, while the labels indicate if the point is additive, 00:17:34.700 |
so we want the model to add something to our mask, or subtractive, so we want the model to remove something. 00:17:41.660 |
Here they are called foreground or background. 00:17:44.480 |
Foreground means that we want the model to add something, background we want the model 00:17:47.660 |
to remove something, just like we did with the example on the website before. 00:17:51.540 |
So the first thing they do is they create the positional encoding of these points. 00:17:56.500 |
So they convert the X and the Y into positional encodings, exactly like we do in the transformer model. 00:18:02.860 |
As you remember in the transformer model we have positional encodings. 00:18:06.180 |
They are special vectors that are combined with the embedding of each 00:18:11.620 |
token to tell the model the position of that token inside of the sentence. 00:18:16.620 |
And here the idea is the same, even if the positional encodings are different. 00:18:22.220 |
I mean the idea is the same, so it's still a vector with dimension 256, but they 00:18:27.640 |
are built in a different way, and we will see why. 00:18:31.580 |
But we have to think that they transform the X and the Y into vectors, each of them representing 00:18:37.400 |
the position of the point inside of the image using the positional encoding. 00:18:42.860 |
Then they need to tell the model what is this point. 00:18:46.700 |
The model cannot know if that point is foreground or background. 00:18:52.420 |
Basically all the foreground points are summed to another embedding here. 00:18:58.500 |
It's an embedding that indicates that that is a foreground point. 00:19:03.460 |
And all the background points are summed to another embedding that indicates to the model that it is a background point. 00:19:10.700 |
And if the point is a padding point, because we don't have enough points, then they use another special embedding for padding. 00:19:18.060 |
And this is how they build the embedding for the points. 00:19:21.660 |
While for the boxes, the boxes are defined using the top left corner, so X and Y of the 00:19:27.540 |
top left corner, and the bottom right corner. 00:19:33.760 |
So basically they transform these two points, so the top left and the bottom right, using 00:19:40.100 |
the positional encodings to tell the model what is that X and Y corresponding to inside 00:19:46.780 |
And then they sum one embedding to indicate that it's the top left point and another embedding to indicate that it's the bottom right point. 00:19:55.660 |
And this is how they build the encoding for the prompt. 00:19:59.540 |
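A minimal sketch of how such a sparse prompt encoder could look; `pe_layer` stands for the positional encoding module discussed below, and the label conventions and structure here are assumptions for illustration, not the exact SAM code.

```python
# Positional encoding of each (x, y) plus a learned embedding for its role
# (foreground point, background point, box corner, padding).
import torch
import torch.nn as nn

class SparsePromptEncoder(nn.Module):
    def __init__(self, pe_layer, dim=256):
        super().__init__()
        self.pe_layer = pe_layer                      # maps (..., 2) coords to (..., 256)
        # 0: background point, 1: foreground point, 2: box top-left, 3: box bottom-right
        self.role_embed = nn.Embedding(4, dim)
        self.not_a_point = nn.Embedding(1, dim)       # used for padding points

    def forward(self, coords, labels):                # coords: (B, N, 2), labels: (B, N)
        emb = self.pe_layer(coords)                   # (B, N, 256) positional encodings
        fg = self.role_embed.weight[1]                # foreground embedding, (256,)
        bg = self.role_embed.weight[0]                # background embedding, (256,)
        emb = emb + torch.where(labels.unsqueeze(-1) == 1, fg, bg)
        emb[labels == -1] = self.not_a_point.weight   # padding points get a special embedding
        return emb
```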
Why do we want to create 256 dimensional vector embeddings? 00:20:04.380 |
Because the 256 dimension is also the one used for the image embedding. 00:20:11.420 |
Because then we can combine them using a transformer. 00:20:16.120 |
So the mask we saw before, what is the role of the mask? 00:20:18.740 |
Now let's go into the detail of how it's combined with the image. 00:20:23.200 |
So basically, the masks are called a dense prompt in the Segment Anything model. 00:20:28.900 |
And what they do is, if the mask is specified, so here, basically 00:20:37.820 |
they run it through a sequence of layers of convolutions to downscale this mask. 00:20:48.540 |
And then if no mask is specified, they create a special embedding called no mask. 00:20:57.220 |
So as you can see, it's just an embedding with a given dimension. 00:21:01.260 |
How do they combine this mask with the image? They just use a pointwise sum, as you can see here. 00:21:08.460 |
So they take the image embedding and they just add the dense prompt embeddings, which is the mask embedding. 00:21:16.580 |
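A minimal sketch of this dense prompt path; the channel sizes and layer choices are illustrative assumptions:

```python
# Downscale a previous low-resolution mask with a few convolutions to the
# spatial size of the image embedding and add it pointwise; if there is no
# mask, broadcast a learned "no mask" embedding instead.
import torch.nn as nn

class DensePromptEncoder(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.downscale = nn.Sequential(        # e.g. a 256x256 mask -> a 64x64 embedding
            nn.Conv2d(1, 4, kernel_size=2, stride=2), nn.GELU(),
            nn.Conv2d(4, 16, kernel_size=2, stride=2), nn.GELU(),
            nn.Conv2d(16, dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, dim)

    def forward(self, image_embedding, mask=None):    # image_embedding: (B, dim, H, W)
        B, _, H, W = image_embedding.shape
        if mask is not None:
            dense = self.downscale(mask)              # (B, dim, H, W)
        else:
            dense = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(B, -1, H, W)
        return image_embedding + dense                # the pointwise sum described above
```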
Now, as we saw before, we have to use the positional encoding to tell the model what 00:21:21.320 |
are the points that we are feeding to the model itself. 00:21:25.180 |
So the model cannot just work with raw X and Y; the model needs to know some other information. 00:21:31.700 |
We cannot just feed a list of X and Y to the model. 00:21:34.320 |
We need to tell the model something more, something that can be learned by the model. 00:21:39.520 |
And transformer models were already very good at encoding positions using the sinusoidal 00:21:46.740 |
positional encodings of the vanilla transformer. 00:21:49.060 |
So if you remember in the vanilla transformer, so the transformer that was introduced in 00:21:52.780 |
the paper "Attention Is All You Need", if you remember, the positional encodings were built using sines and cosines of different frequencies. 00:22:03.980 |
And these vectors told the model what is the position of the token inside of the sentence. 00:22:09.060 |
Now, this was fine as long as we worked with text, because text only move along one dimension, 00:22:14.780 |
that is, we have the token number zero, the token number one, the token number two, etc. 00:22:19.860 |
But pixels don't move in one direction, they move into two directions. 00:22:23.700 |
So of course, one could think: why not use one positional encoding for 00:22:30.300 |
the X coordinate, to convert the X coordinate into a vector, and another to map the Y coordinate into another vector? 00:22:39.980 |
But the problem is, if we do in this way, suppose we convert the center position of 00:22:46.980 |
the image into two vectors, one encoded using the X coordinate and one encoded using the Y coordinate. 00:22:54.820 |
What happens if we check the similarity with the other positions in the image is that we get 00:22:58.780 |
some heatmap like this, in which this position here is very similar to this position here. 00:23:05.980 |
But it's not similar to this position here, which is not good, because in an image 00:23:14.420 |
pixels at the same Euclidean distance should have the same similarity. 00:23:22.020 |
So basically, this point and this point should have a similarity that is the same as this point and this point. 00:23:30.860 |
Because we have a spatial representation, so pixels that are close to each other should 00:23:35.060 |
be very similar, pixels that are far from each other should not be very similar. 00:23:43.340 |
What we want is something like this, that is, if we have a point here, all the points 00:23:47.540 |
in the radius of, let's say, 10 pixels are very similar. 00:23:51.500 |
All the points in the radius of 20, so distance 20, are less similar, but still, depending 00:24:01.740 |
on the radius, they are similar in the same way as the other points with the same radius. 00:24:07.380 |
And the farther we go from the center, the more distant we become, the more dissimilar we become. 00:24:14.220 |
And this is what we want from positional encodings for an image. 00:24:22.180 |
Here you can see the paper Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding. 00:24:28.020 |
And however, this is not the paper used by Segment Anything. 00:24:31.020 |
For Segment Anything, they use this paper here, but you understood why we needed a new kind of positional encoding. 00:24:38.100 |
Basically because we need to map the two-dimensional X and Y to an encoding that respects the two-dimensional geometry of the image. 00:24:52.300 |
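A minimal sketch of a positional encoding built from random Fourier features, which is the kind of 2D-aware encoding Segment Anything uses; the exact scale and feature sizes here are assumptions:

```python
# (x, y) normalized to [0, 1] is projected with a fixed random Gaussian matrix
# and passed through sin/cos, so similarity decays with Euclidean distance in
# the image plane rather than separately along each axis.
import math
import torch
import torch.nn as nn

class RandomFourierPE(nn.Module):
    def __init__(self, num_feats=128, scale=1.0):
        super().__init__()
        # fixed (not trained) random projection from 2-D coordinates to num_feats frequencies
        self.register_buffer("gaussian", scale * torch.randn(2, num_feats))

    def forward(self, coords):              # coords: (..., 2), normalized to [0, 1]
        coords = 2.0 * coords - 1.0         # map to [-1, 1]
        proj = 2.0 * math.pi * (coords @ self.gaussian)
        return torch.cat([proj.sin(), proj.cos()], dim=-1)   # (..., 2 * num_feats) = 256
```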
And now let's look at the most important part of the model, which is the decoder. 00:24:56.900 |
Now before we look at the decoder, I want to remind you that in the model that we saw 00:25:01.860 |
before on the web browser, we could add the points in real time in the browser. 00:25:09.940 |
That is, we loaded the image, and then I clicked on the image, and basically after a few milliseconds, 00:25:16.860 |
let's say 100 milliseconds or half a second, I saw the output of the model. 00:25:21.700 |
This could happen because the decoder is very fast, and the prompt encoder is very fast. 00:25:27.660 |
But the image encoder doesn't have to be very fast, because we only encode the image once 00:25:34.040 |
Then we can do multiple prompts, so we can save the image embeddings, and then we just 00:25:40.580 |
change the prompt embeddings and run them through the decoder to get the new masks. 00:25:45.580 |
Which means basically that the image encoder can be powerful even if it's slow, but the decoder must be fast. 00:25:55.740 |
And the same goes for the prompt encoder, and this is actually the case, because that's 00:25:59.340 |
why we could use it on my browser in a reasonable time. 00:26:09.460 |
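This encode-once, prompt-many workflow is also how Meta's released `segment_anything` package is used; a hedged example follows, where the checkpoint path and the image are placeholders:

```python
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Checkpoint path is an assumption: use the downloaded SAM checkpoint file.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")
predictor = SamPredictor(sam)

image_rgb = np.zeros((768, 1024, 3), dtype=np.uint8)  # stand-in for a real RGB image
predictor.set_image(image_rgb)            # slow step: runs the ViT image encoder once

# cheap step: can be repeated many times with different prompts
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),  # one click at (x, y)
    point_labels=np.array([1]),           # 1 = foreground, 0 = background
    multimask_output=True,                # return the three candidate masks
)
```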
The decoder is made of two layers, so we have to think that this block here is repeated 00:26:16.980 |
again with another block that is after this one, where the output of this big block is 00:26:22.860 |
fed to the other block, and the output of that block is actually sent to the output here. 00:26:41.020 |
The inputs are the prompts sent by the user, so the clicks and the boxes, and then we have the image embedding 00:26:47.260 |
plus the mask we can see in the picture here. 00:26:50.540 |
So the image has already been combined with the mask through this layer here, through the pointwise sum we saw before. 00:27:00.060 |
The first thing the model, the decoder does is the self-attention. 00:27:04.100 |
So a self-attention of the prompt tokens with the prompt tokens themselves. 00:27:10.360 |
But here we can also see that there are these output tokens. 00:27:14.000 |
So before we proceed to see these steps, let's watch what are the output tokens. 00:27:19.880 |
The output tokens take the idea also from BERT. 00:27:23.180 |
So as you remember before, we saw the vision transformer, right? 00:27:26.760 |
In the vision transformer, when they fed the patches to the transformer encoder, they also prepended the class token. 00:27:35.380 |
And the same idea is reused by Segment Anything, in which they prepend some tokens before the 00:27:41.860 |
prompt tokens, so the boxes and the clicks made by the user. 00:27:47.940 |
And then at the output of this decoder, they check again only these tokens and force the 00:27:54.220 |
model to put all the information into these tokens. 00:27:57.740 |
So in this case, we have one token that predicts the IOU, so the intersection over union scores, of the predicted masks. 00:28:08.040 |
And we will see later what is the IOU, if you're not familiar with it. 00:28:11.880 |
And then there are three mask tokens, so one token for each mask. 00:28:17.700 |
And basically, we feed these four tokens, so one IOU token and three mask tokens, to the model. 00:28:26.920 |
And then we take the first token, so the IOU token, and we map it to a multi-layer perceptron, 00:28:34.580 |
to force the model to put the IOU scores into this token. 00:28:38.580 |
And then the other three are used to predict the three masks as output. 00:28:44.580 |
And we can see that here, they use the idea just like in the BERT paper. 00:28:51.900 |
And this is the reference to the BERT paper in which they introduced the CLS token, in 00:28:56.020 |
which in BERT was used for classification tasks. 00:28:59.060 |
So basically, also in BERT, they prepended this token called the CLS. 00:29:05.140 |
And then, at the output of the transformer, they just took the token corresponding to 00:29:09.980 |
this CLS, which was the first one, and they forced the model to put all the information needed for classification into this token. 00:29:20.580 |
Because the CLS token could interact with all the other tokens through the self-attention mechanism. 00:29:28.140 |
So we feed the model with the output tokens combined with the prompt tokens. 00:29:33.260 |
And you can see here, they just concatenate the two. 00:29:36.060 |
They take the IOU token and the mask token, they concatenate together. 00:29:39.420 |
And then they concatenate these output tokens with the prompt tokens, as you can see here. 00:29:50.360 |
So first, they run the self-attention on the tokens, which are the output tokens plus the prompt tokens. 00:29:58.540 |
You can see that the query, the key, and the values are the same. 00:30:05.140 |
The comments have been added by me, they are not present in the original code, to make your understanding easier. 00:30:10.180 |
Even if I found it really hard to follow the nomenclature, because sometimes a variable 00:30:14.660 |
is called Q, but then it is passed as the key or the value, etc. 00:30:21.420 |
What we want to get from this code is actually not the single instructions, but the overall structure. 00:30:30.700 |
So the output of this self-attention is then fed to a cross-attention. 00:30:36.820 |
Basically, a cross-attention is an attention in which the query comes from one side, and 00:30:42.220 |
the key and the value come from another side. 00:30:44.820 |
If you remember my video about the transformer model, in a translation task usually, imagine 00:30:51.760 |
we are translating from English to Italian, or English to French, or English to Chinese. 00:30:57.340 |
Basically, we first run in the encoder a self-attention between all the input sentence, so all the 00:31:03.380 |
tokens of the input sentence related to all the other tokens of the input sentence. 00:31:08.180 |
And then in the decoder, we have this cross-attention in which we take the queries coming from one 00:31:14.560 |
language and the key and the values coming from another language. 00:31:18.780 |
And this is usually done to combine two different sequences, to relate two different sequences with each other. 00:31:28.120 |
In this case, what we want is to relate the tokens, so our prompt, with the image embeddings. 00:31:35.300 |
So in the first cross-attention the tokens are used as queries, while the image embeddings are used as keys and values. 00:31:47.220 |
Then there is a multilayer perceptron, so it's just linear layers. 00:31:52.880 |
And finally, we have another cross-attention, but this time the opposite. 00:31:56.740 |
So in this case, the queries are the image embeddings and the keys and the values are the tokens. 00:32:05.100 |
Because we want two outputs from this transformer. 00:32:08.020 |
One will be the sequence of tokens of the prompt, from which we will extract the output 00:32:12.700 |
tokens, one indicating the IOU score and three indicating the mask. 00:32:17.600 |
And one is the image embedding that we will then combine with the mask output tokens to produce the masks. 00:32:27.980 |
Another thing highlighted here in the paper is that, to ensure that the decoder has access to critical 00:32:32.760 |
geometric information, the positional encodings are added to the image embedding whenever it participates in an attention layer. 00:32:39.940 |
And as you can see, this is done not only for the image, but also for the prompt. 00:32:43.820 |
So every time to the prompt, they add the positional encoding, and every time to the 00:32:48.020 |
image embedding, they also add the positional encoding. 00:32:52.960 |
Because we don't want to lose this information after all these layers. 00:32:58.080 |
And this is usually done with a skip connection. 00:33:07.160 |
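A minimal sketch of one such decoder block, with the positional encodings re-added to the queries and keys at every attention step; layer norms and exact sizes are simplified assumptions, not the actual SAM code:

```python
# Self-attention over the tokens, cross-attention from tokens to image, an MLP
# on the tokens, and cross-attention back from the image to the tokens.
import torch.nn as nn

class TwoWayBlock(nn.Module):
    def __init__(self, dim=256, heads=8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_tokens_to_image = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 2048), nn.ReLU(), nn.Linear(2048, dim))
        self.cross_image_to_tokens = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, tokens, image, token_pe, image_pe):
        # 1) the tokens (output tokens + prompt tokens) attend to themselves
        tokens = tokens + self.self_attn(tokens + token_pe, tokens + token_pe, tokens)[0]
        # 2) the tokens (queries) attend to the image embedding (keys/values)
        tokens = tokens + self.cross_tokens_to_image(tokens + token_pe, image + image_pe, image)[0]
        # 3) an MLP updates the tokens
        tokens = tokens + self.mlp(tokens)
        # 4) the image (queries) attends to the tokens (keys/values)
        image = image + self.cross_image_to_tokens(image + image_pe, tokens + token_pe, tokens)[0]
        return tokens, image
```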
As we saw before, we have a special layer of tokens added to the input of the transformer. 00:33:15.440 |
One is for the IOU prediction and three are for the mask prediction. 00:33:22.040 |
Because our transformer model has two cross-attentions, we have two outputs, two output sequences. 00:33:28.720 |
One is the output sequence of the tokens, and one is the output sequence of the image embeddings. 00:33:35.980 |
And they extract the IOU token, which is the first token added to the sequence, and then 00:33:41.680 |
they extract the mask tokens, which are the next three tokens after the first one. 00:33:50.380 |
They take the IOU token and just give it to a prediction head to predict the IOU scores, one for each mask. 00:34:00.040 |
And then they take the output tokens for the masks, we can see them here, and they combine 00:34:09.680 |
them with the upscaled embedding of the image. 00:34:14.280 |
So they take the output of the transformer for the image, so SRC in this case, the variable 00:34:21.060 |
name is SRC, they upscale it here, and then they run each of the mask output tokens through its own MLP. 00:34:34.520 |
So here you can see we have multiple MLP blocks here. 00:34:42.000 |
They run each of them through its own MLP, they get the output, and then they combine the 00:34:47.360 |
output of each MLP, one for each token, so one for each mask, with the upscaled embedding 00:34:53.160 |
of the image to produce the output masks here. 00:34:58.000 |
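A minimal sketch of this last step: upscale the image embedding, run each mask token through its own small MLP, and take a per-pixel dot product to get one mask per token; the channel sizes are illustrative assumptions.

```python
# Mask head sketch: the IOU token goes to its own prediction head, the mask
# tokens are combined with the upscaled image embedding via a dot product.
import torch
import torch.nn as nn

class MaskHead(nn.Module):
    def __init__(self, dim=256, num_masks=3):
        super().__init__()
        self.upscale = nn.Sequential(           # e.g. 64x64 -> 256x256
            nn.ConvTranspose2d(dim, dim // 4, kernel_size=2, stride=2), nn.GELU(),
            nn.ConvTranspose2d(dim // 4, dim // 8, kernel_size=2, stride=2), nn.GELU(),
        )
        self.token_mlps = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim // 8))
            for _ in range(num_masks)
        ])
        self.iou_head = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, num_masks))

    def forward(self, iou_token, mask_tokens, image_embedding):
        # iou_token: (B, dim), mask_tokens: (B, num_masks, dim), image_embedding: (B, dim, H, W)
        up = self.upscale(image_embedding)                          # (B, dim // 8, 4H, 4W)
        B, C, H, W = up.shape
        hyper = torch.stack(
            [mlp(mask_tokens[:, i]) for i, mlp in enumerate(self.token_mlps)], dim=1
        )                                                           # (B, num_masks, dim // 8)
        masks = (hyper @ up.flatten(2)).view(B, -1, H, W)           # dot product per pixel
        iou_scores = self.iou_head(iou_token)                       # one confidence per mask
        return masks, iou_scores
```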
Another interesting part of the paper is this section here, making the model ambiguity aware. 00:35:03.200 |
So as I was saying before, we not only predict one mask, but we predict three masks. 00:35:08.200 |
And this happens when we do not have more than one prompt. 00:35:13.760 |
So if we only click one, for example, once, the model will produce three masks. 00:35:18.960 |
But if we have more than one prompt, because the ambiguity becomes less, at least theoretically, 00:35:26.000 |
the authors decided to add a fourth token that predicts another mask, which is used only when more than one prompt is given. 00:35:37.380 |
And this mask is never returned for the single prompt. 00:35:44.360 |
Now let's have a look at what is intersection over union. 00:35:47.560 |
Intersection over union allows us to understand how good our prediction is given the ground 00:35:54.040 |
truth, especially in segmentation models or object detection models. 00:36:00.440 |
So for example, imagine we are using object detection and our ground truth box is this 00:36:06.200 |
green box here, but our object, our model produced this red prediction as output. 00:36:12.680 |
As you can see, even if there is some overlapping, it doesn't cover 00:36:16.680 |
the entire ground truth box, so the prediction is quite poor. 00:36:21.580 |
But this improves when the box becomes bigger. 00:36:27.700 |
So there is more intersection, but also more union. 00:36:32.080 |
And finally, it becomes excellent when the intersection covers the whole box and 00:36:40.120 |
the union of the two is basically the same box, so they overlap as much as possible. 00:36:45.520 |
This area here, so the area that is predicted but is not in the ground truth, is called 00:36:53.520 |
a false positive, while the area that should have been predicted, 00:36:59.240 |
but was not predicted, is called a false negative. 00:37:02.960 |
These are commonly used terms in this kind of scenario. 00:37:11.680 |
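As a quick reference, intersection over union for two binary masks (the same formula applies to bounding boxes) can be computed like this:

```python
# IOU between two boolean masks of the same shape.
import torch

def mask_iou(pred: torch.Tensor, gt: torch.Tensor) -> float:
    intersection = (pred & gt).sum().item()
    union = (pred | gt).sum().item()
    return intersection / union if union > 0 else 0.0
```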
The loss of the model is a combination of two losses. 00:37:14.120 |
One is called the focal loss and one is the dice loss, and they are used in a ratio of 20 to 1 (focal to dice). 00:37:22.440 |
The focal loss takes its idea from the cross entropy, but with a modification to deal with class imbalance. 00:37:32.440 |
So why do we have a class imbalance in this case? 00:37:38.280 |
We are trying to predict the mask for a particular object in our image. 00:37:45.160 |
But of course, usually the mask does not cover the entire image; only very few pixels 00:37:51.080 |
compared to the total image actually belong to this mask. 00:37:55.960 |
And the instances of big mask are actually not so many. 00:37:59.360 |
So we have a class imbalance here because most of our pixels will be non-mask and only 00:38:03.600 |
a small percentage of our pixels will be mask. 00:38:06.600 |
So we cannot use plain cross entropy in this case because the cross entropy doesn't pay attention to this class imbalance. 00:38:11.560 |
So this is why they use focal loss to pay attention to this class imbalance. 00:38:15.760 |
But the focal loss derives from the cross entropy and it was introduced in this paper 00:38:20.640 |
by Facebook research, you can see here focal loss for dense object detection. 00:38:29.100 |
The dice loss comes from the Sørensen-Dice coefficient, which is also called the F1 score. 00:38:34.880 |
And it's calculated as twice the intersection divided by the total. 00:38:38.600 |
So twice the area of overlap divided by the total area. 00:38:43.360 |
And this is actually a measure of similarity between two sets of data. 00:38:49.640 |
To get the loss, we just do one minus the dice score. 00:38:53.180 |
If you want more information about this dice score, which is very commonly used, I suggest you look it up further. 00:39:05.160 |
And the dice loss was introduced in this paper VNet, I think it's from 2015. 00:39:11.740 |
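A minimal sketch of the combined loss just described, written for a single predicted mask given as logits and a binary (float) ground-truth mask; the alpha and gamma values are common defaults assumed for illustration, not the actual training code:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, alpha=0.25, gamma=2.0):
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = p * target + (1 - p) * (1 - target)            # probability of the true class
    alpha_t = alpha * target + (1 - alpha) * (1 - target)
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()    # down-weights easy pixels

def dice_loss(logits, target, eps=1e-6):
    p = torch.sigmoid(logits)
    intersection = (p * target).sum()
    return 1 - (2 * intersection + eps) / (p.sum() + target.sum() + eps)

def combined_loss(logits, target):
    # focal and dice combined in the 20:1 ratio mentioned above
    return 20.0 * focal_loss(logits, target) + dice_loss(logits, target)
```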
Another interesting thing is that Segment Anything built its own dataset. 00:39:15.840 |
And this is remarkable because we saw before that Segment Anything has been trained 00:39:19.720 |
on 1.1 billion masks using millions of images. 00:39:25.360 |
And the data engine that was used to build this dataset of 1.1 billion masks is composed of three stages. 00:39:33.880 |
The first one was a manual stage, then a semi-automatic stage and then a fully automatic stage. 00:39:41.920 |
In the assisted manual stage, so the manual stage, basically they hired a team of professional 00:39:46.280 |
annotators that manually labeled the images using only the brush and the eraser tool. 00:39:52.700 |
So basically you have to think that there are many people who are annotating the masks pixel by pixel. 00:40:00.920 |
So this is what we would do to create a data set from zero. 00:40:05.120 |
Then they trained the model on these manually created masks. 00:40:09.680 |
And then they went to the semi automatic stage, that is, some of the masks were already generated 00:40:15.240 |
by the model, which was trained on the manually generated masks. 00:40:19.720 |
And then the operators only had to adjust these masks and annotate any additional objects that were not yet annotated. 00:40:30.120 |
This created even more samples and they trained the model on these samples. 00:40:34.480 |
And finally, they created the fully automatic stage. 00:40:40.720 |
In this fully automatic stage, the model, there is no operator. 00:40:44.760 |
The model is building the data set by itself. 00:40:49.000 |
They take an image, they create a grid of 32 by 32 points. 00:40:53.760 |
And then for each of these points, they ask the model to predict the masks. 00:40:59.480 |
Of course, this will produce a lot, a large number of masks. 00:41:02.600 |
So they only take the ones with the highest confidence score and also only the ones that are considered stable. 00:41:08.640 |
And by stable, they mean that if they threshold the probability map at 0.5 minus delta and 00:41:14.680 |
0.5 plus delta, they result in similar masks. 00:41:19.360 |
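A minimal sketch of this stability check; the delta and the cutoff used here are assumptions:

```python
# Threshold the predicted probability map at 0.5 - delta and 0.5 + delta and
# measure the IOU between the two binary masks; keep only stable masks.
import torch

def stability_score(prob_map: torch.Tensor, delta: float = 0.05) -> float:
    loose = prob_map > (0.5 - delta)       # slightly permissive mask
    tight = prob_map > (0.5 + delta)       # slightly strict mask
    intersection = (loose & tight).sum().item()
    union = (loose | tight).sum().item()
    return intersection / union if union > 0 else 0.0

# e.g. keep a mask only if stability_score(prob_map) >= 0.95
```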
Next, because we have a lot of masks and some of them may be overlapping with each other, 00:41:24.520 |
some of them may be duplicate, actually, we need to remove some of them. 00:41:29.440 |
So we use an algorithm called the non-maximal suppression. 00:41:32.360 |
This is very famous also in object detection. 00:41:36.800 |
So non-maximal suppression usually works like this. 00:41:43.760 |
Usually when we detect a bounding box for an object, we get a lot of bounding boxes. 00:41:51.120 |
Well, basically, we take the one with the highest confidence score, and then we delete 00:41:56.480 |
all the other bounding boxes that have an IOU with the one that we selected 00:42:04.680 |
higher than a threshold that is given as a parameter. 00:42:08.740 |
This allows us to eliminate all the bounding boxes that are similar to the one we have selected. 00:42:17.240 |
And then we do this for all the remaining boxes. 00:42:21.220 |
And the algorithm is very simple, and it's also very effective. 00:42:26.720 |
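A minimal sketch of this procedure for boxes given as (x1, y1, x2, y2) tuples with one confidence score each:

```python
# Non-maximal suppression: keep the highest-scoring box, drop every remaining
# box whose IOU with it exceeds the threshold, and repeat on what is left.
def box_iou(a, b):
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def nms(boxes, scores, iou_threshold=0.7):
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        order = [i for i in order if box_iou(boxes[best], boxes[i]) <= iou_threshold]
    return keep    # indices of the boxes that survive
```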
Thank you guys for watching my video about Segment Anything. 00:42:30.640 |
I hope that most of the information was clear. 00:42:36.160 |
I will try to correct my errors or any misunderstanding, or anything that I should have explained better. 00:42:45.000 |
Please subscribe to my channel because I will be uploading more videos in the future.