
Segment Anything - Model explanation with code


Chapters

0:00 Introduction
1:20 Image Segmentation
3:28 Segment Anything
6:58 Task
8:20 Model (Overview)
9:51 Image Encoder
10:07 Vision Transformer
12:30 Masked Autoencoder Vision Transformer
15:32 Prompt Encoder
21:15 Positional Encodings
24:52 Mask Decoder
35:43 Intersection Over Union
37:08 Loss Functions
39:10 Data Engine and Dataset
41:35 Non Maximal Suppression

Whisper Transcript

00:00:00.000 | Hello guys, welcome to my new video about the new model from Meta called Segment Anything.
00:00:06.240 | As you have heard from the internet, Segment Anything is a model that allows you to segment
00:00:11.660 | an image into masks without caring about what kind of image we are talking about.
00:00:18.500 | So before, for example, we had segmentation models for medical applications or for pedestrian
00:00:24.420 | detection or for some other objects.
00:00:27.420 | But Segment Anything can work with any kind of image.
00:00:30.840 | And the second novelty is that it allows you to work with prompts.
00:00:34.540 | So just like you work in NLP.
00:00:36.420 | So given a prompt, like a list of points or a bounding box or a text, it can segment the
00:00:42.300 | image given your input.
00:00:44.920 | And this makes it a powerful foundation model just like BERT, just like GPT for NLP applications.
00:00:51.240 | So it means that it can later be fine-tuned for specific tasks by working on the prompt and
00:00:55.880 | not only on the data.
00:00:59.920 | In this video, we will look at what the model is, how it works, what its composition is,
00:01:05.720 | the dataset it was trained on, and I will also draw a parallel with the code.
00:01:09.900 | So I will also show you the code of this model along with an explanation of how it works,
00:01:14.840 | so that you can see things from the higher level down to the lower level.
00:01:21.240 | Let's start.
00:01:22.880 | So first of all, what is image segmentation?
00:01:25.420 | Image segmentation is the process of partitioning a digital image into multiple regions such
00:01:30.600 | that pixels that belong to the same region share some characteristics.
00:01:35.360 | For example, if we are given this image, I think it was a painting from Van Gogh, and
00:01:41.160 | if we partition it using segmentation, it will be partitioned into these masks, in which
00:01:46.120 | for example, we have one mask for the grass, one for this house, one for this tree, one
00:01:50.920 | for this other tree, etc.
00:01:53.900 | And before we had segment anything, we had many models, one specifically tuned for each
00:01:57.840 | application.
00:01:58.840 | For example, for medical imaging, we may want to locate, given an image of cells, which
00:02:04.000 | ones are tumor cells and which ones are not.
00:02:08.520 | Or in object detection, we may want to know where the pedestrians are in our image for
00:02:13.760 | self-driving cars, or, for example, in satellite images, we may want to segment rivers,
00:02:20.880 | mountains, urban areas, etc.
00:02:25.480 | But this task also had many challenges because, first of all, creating a dataset
00:02:30.240 | for image segmentation was very expensive.
00:02:32.440 | I mean, imagine you need an operator who, pixel by pixel, has to define which object
00:02:39.600 | each pixel belongs to.
00:02:43.400 | So it takes a lot of time to annotate images for image segmentation.
00:02:48.920 | Just as I said before, the models usually were application-specific, and also the previous
00:02:53.200 | models were not promptable.
00:02:54.840 | That is, we could not tell the model, ah, please just select all the masks for cats
00:03:00.200 | or for dogs or for trees or for houses.
00:03:03.880 | So we could not build a prompt.
00:03:05.120 | If the model was trained to detect that kind of mask, it detected it.
00:03:10.200 | Otherwise it didn't.
00:03:11.200 | So it was all or nothing.
00:03:13.680 | But now we can actually ask the model which kind of object we want to build the mask for.
00:03:18.200 | And we can do that by using points, by using bounding boxes, and also by using text.
00:03:25.300 | And let's have a look at the website from Meta.
00:03:28.640 | So if we go to segmentanything.com, we have this page called demo, in which, okay, we
00:03:34.240 | accept the conditions.
00:03:36.960 | And we can select any image, let's say one of these bears.
00:03:41.640 | And the model can work with clicks, as I said before.
00:03:45.240 | So if we click here, it tells the model that we want something from here.
00:03:50.840 | But imagine the model selected too much.
00:03:53.700 | Maybe we wanted only the face of the bear.
00:03:58.000 | So we can remove some area by clicking on something that we want to remove.
00:04:02.920 | So if we click the belly of this bear, it will remove the bottom part of the body.
00:04:08.920 | The second thing we can do is use a box.
00:04:11.200 | For example, we may say, okay, select all the animals in this case.
00:04:15.520 | But now it only selected the box.
00:04:17.360 | Then we can guide the model by adding some points.
00:04:21.280 | For example, this point was not included even if we wanted all the animals in this box.
00:04:25.800 | So we can tell it to add this animal.
00:04:28.740 | But suppose that the model included the ears, or we also wanted to exclude something from
00:04:37.600 | here.
00:04:38.600 | For example, we want to exclude, let's say, this paw here.
00:04:43.140 | So we can add another point with remove area and put it here.
00:04:47.280 | And hopefully it will remove the paw.
00:04:49.580 | So of course, the model is not perfect, because the prompt can be ambiguous in
00:04:55.360 | some cases, even for us humans.
00:04:58.040 | And so of course, the model is not perfect.
00:05:01.080 | But it's, I mean, it looks very good.
00:05:02.800 | And the second thing to notice is that the model is running in my browser.
00:05:06.280 | There is no back-end processing on a server.
00:05:09.760 | It's happening in real time in my browser.
00:05:11.800 | So it's quite fast.
00:05:13.720 | And let's go back to our slides.
00:05:18.320 | So Segment Anything introduces three innovations.
00:05:21.360 | The first is the task itself.
00:05:23.240 | It's called a promptable segmentation task, which can work with points, with boxes, with
00:05:27.960 | text, or a combination of the above, for example, a box and a few points.
00:05:33.920 | The second is the model, which is a fast model.
00:05:36.560 | It's an encoder decoder model that takes around 50 milliseconds in the web browser to generate
00:05:41.920 | a mask given a prompt.
00:05:44.320 | And it's also ambiguity aware.
00:05:46.160 | For example, given a point, that point may correspond to multiple objects.
00:05:50.540 | For example, if we click here, for example, in this area, it may indicate this vegetable,
00:05:57.480 | or all this vegetable, or only the white part of this vegetable.
00:06:02.520 | And this means that the model cannot know, of course, what is our intent.
00:06:06.040 | So the model will return the three most likely masks, indicating the part, the sub part and
00:06:11.360 | the whole.
00:06:13.960 | Then of course, the model was trained on data.
00:06:16.920 | And this data was a big dataset composed of 1.1 billion masks.
00:06:22.880 | And the interesting thing is that these masks were actually generated by the model itself.
00:06:27.700 | So they started the model with a very small dataset.
00:06:31.360 | Then they used this model, which was trained on a small dataset, to create an even bigger
00:06:36.480 | dataset with the help of operators, of course, manual operators.
00:06:40.200 | And then, after a while, they asked the model to generate all the masks automatically without
00:06:46.520 | any human help, and then trained the model on these automatically generated masks.
00:06:52.760 | And the result is the one you just saw on the browser.
00:06:55.240 | It's a model that can segment anything with a very high precision.
00:06:58.280 | The authors took inspiration from NLP.
00:07:00.840 | As you remember, in NLP, we have the next token prediction task, which is used by most
00:07:05.000 | language models.
00:07:06.760 | And so basically, we give a prompt to a model and the model has to complete
00:07:13.280 | the sentence with something meaningful.
00:07:15.120 | This is what happens with GPT.
00:07:16.500 | This is what happens with BERT and all the other language models.
00:07:20.240 | And this is exactly what they wanted to do here.
00:07:23.580 | They wanted to use a prompt to build a foundation model.
00:07:27.120 | So a foundation model is a model that is trained on a lot of data and that can be fine-tuned,
00:07:33.000 | let's say, for a specific task by working on the prompt.
00:07:37.160 | And this prompt was made so that it can handle ambiguity.
00:07:42.500 | For example, if the single click that we make is referring to multiple objects, the model
00:07:49.240 | must return at least one mask that is reasonable for that click.
00:07:54.120 | So this is the requirement that the authors set for their model.
00:07:58.260 | And we saw one ambiguous mask before, but for example, here we can see another case.
00:08:02.720 | For example, if we click on the Z here, this point could refer to the Z itself or to the
00:08:10.120 | entire text here or to the entire wall.
00:08:14.840 | So the model has to return at least one of these three.
00:08:18.360 | And in the best case, of course, all of them.
00:08:21.320 | Now let's overview the model.
00:08:24.840 | What is the model architecture?
00:08:26.640 | The model, as we saw before, is an encoder-decoder model.
00:08:30.360 | And it's composed of these parts.
00:08:32.120 | There is an image encoder that, given an image, creates an embedding.
00:08:38.080 | Then we have a prompt encoder that can encode the prompts given by the user, which can be
00:08:44.160 | points, boxes, text, or a combination.
00:08:47.960 | Then we will see later what is this mask here, but basically it means that if we run the
00:08:52.440 | model with an initial prompt, for example, a single point, the model will build an initial
00:08:58.360 | mask.
00:08:59.360 | Then if we want to modify our prompt by adding another point, instead of letting
00:09:04.680 | the model guess what we want, we can reuse the previous output to tell
00:09:10.160 | the model that, okay, the previous mask was good, but not perfect.
00:09:14.400 | So I am giving you, again, the previous mask as a hint, as a starting point, plus a
00:09:19.380 | few points to guide you, telling you what I want you to remove or add to the mask.
00:09:25.920 | This is the whole idea of this mask we can see here.
00:09:29.040 | So this mask is the result of a previous prompting of this model.
00:09:34.920 | And then the model has a decoder that, given the prompt, the previous mask, and the embedding
00:09:40.160 | of the image, has to predict the mask along with a confidence score.
00:09:46.880 | Now let's go and watch what the image encoder is and how it works.
00:09:52.360 | In the paper, they say that they use the MAE pre-trained vision transformer.
00:09:57.640 | So MAE means that it's a masked autoencoder and a vision transformer.
00:10:02.720 | So let's review these two terms, what they mean and how they work.
00:10:05.560 | Let's first review the vision transformer.
00:10:08.440 | The vision transformer was introduced in a very famous paper, I think a few years ago.
00:10:13.920 | The paper name is "An image is worth 16 by 16 words" and it's from Google Research, Google
00:10:19.560 | Brain.
00:10:21.360 | Basically what they did is, they take a picture, in this case, this one, they divide it into
00:10:26.160 | patches of 16 by 16, and then they flatten these patches, so they create a sequence of these
00:10:32.680 | patches.
00:10:33.680 | For example, this is the first patch, this is the second one, this is the third one,
00:10:36.720 | the fourth, etc.
00:10:38.800 | And then they create embeddings by using a linear projection.
00:10:42.480 | So embeddings that somehow capture the information from each of these patches.
00:10:49.120 | They feed this whole sequence of patches, actually the embeddings of these patches, along
00:10:55.480 | with the positional encodings, to the transformer.
00:10:59.860 | But they not only feed the list of patches, but also another token here that is prepended
00:11:08.400 | to this sequence of patches.
00:11:10.260 | And this token is called the class embedding, the class token.
00:11:16.040 | And the idea comes from the BERT paper.
00:11:19.160 | Basically if you remember the transformer, when we have a sequence in the transformer
00:11:23.160 | encoder, the transformer basically allows, with its self-attention mechanism, to relate
00:11:28.360 | the tokens to each other.
00:11:30.040 | So in the output we will have again a sequence of tokens, but the embedding of each token
00:11:37.240 | will somehow capture the interaction of that token with all the other tokens.
00:11:42.620 | And this is the idea behind adding this class token we have here.
00:11:46.920 | So we send it to the transformer, and at the output of the transformer, since we will
00:11:52.680 | get another sequence at the output of the transformer encoder, we just take the first
00:11:57.740 | token here, and we map this token to a multilayer perceptron that has
00:12:06.480 | to predict the class.
00:12:08.640 | Why do we use this token?
00:12:10.080 | Because this token, because of the self-attention mechanism, has interacted with all of the
00:12:14.480 | other patches.
00:12:15.920 | So this token somehow captures the information of the other patches, and then we force the
00:12:23.040 | model to convey all this information to this single token.
00:12:27.120 | So this is the idea of the vision transformer and of the class token.
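To make this concrete, here is a minimal sketch of the patchify, project, prepend-class-token pipeline. This is not the actual ViT code; the class name, dimensions, and number of layers are made up for illustration only.

```python
import torch
import torch.nn as nn

class TinyViT(nn.Module):
    # Minimal sketch of the ViT front-end: patchify, linearly project,
    # prepend a learnable class token, add positional embeddings.
    def __init__(self, img_size=224, patch_size=16, in_ch=3, dim=256):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # A strided convolution is equivalent to "split into 16x16 patches + linear projection".
        self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True),
            num_layers=2,
        )
        self.head = nn.Linear(dim, 1000)  # classification head applied to the class token

    def forward(self, x):                                     # x: (B, 3, H, W)
        patches = self.proj(x).flatten(2).transpose(1, 2)     # (B, num_patches, dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)       # (B, 1, dim)
        tokens = torch.cat([cls, patches], dim=1) + self.pos_embed
        out = self.encoder(tokens)
        return self.head(out[:, 0])        # predict the class from the first (class) token
```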
00:12:34.180 | They took the vision transformer and they transformed it into a masked autoencoder
00:12:38.080 | vision transformer.
00:12:39.240 | And this happened in another paper called Masked Autoencoders are Scalable Vision Learners
00:12:44.160 | from Facebook, from Meta.
00:12:46.920 | Now in this case they still have an input image, in this case this one, but what they
00:12:53.360 | did is they split it into patches, but then they masked it out: they deleted some patches
00:12:59.120 | and replaced them with zeros, so they hide some patches here, and if I remember correctly,
00:13:06.840 | 75% of the patches are masked out.
00:13:10.920 | Then they only take the visible patches, create a sequence, a linear sequence here, they give
00:13:15.880 | it to the encoder of a transformer, which still produces as output a sequence as we
00:13:21.960 | saw before.
00:13:24.640 | Then what they do is they take this sequence that is the output of the encoder of the transformer,
00:13:31.160 | they again recreate the original image sequence, so if we knew that the first patch was empty,
00:13:40.160 | then they put an empty space here, then an empty, then the third one was visible, so
00:13:44.680 | they used the first embedding.
00:13:46.800 | Then the fourth, the fifth and the sixth were deleted, so four, five and six were deleted,
00:13:52.440 | then they take the next one visible and they put it here.
00:13:55.000 | So basically to the decoder they give the visible patches and the non-visible patches
00:14:00.920 | along with the geometric information about which ones are visible.
00:14:04.540 | So they are added in the same order in which they were masked out in the original
00:14:07.880 | image, and then they ask the decoder to predict the original image, only being able to see
00:14:14.600 | the embeddings of the visible patches.
00:14:16.960 | So basically the decoder has to come up with a full image, being only able to access 25%
00:14:23.120 | of the image.
00:14:24.320 | And what they saw is that the decoder was actually able to rebuild the original image,
00:14:29.120 | maybe not with the perfect quality, but a reasonable quality.
00:14:33.800 | And what the authors of the Segment Anything paper did is they took this part of the masked
00:14:43.200 | autoencoder vision transformer, because they are interested in this, the embedding
00:14:48.680 | learned by the encoder, so the output of the encoder.
00:14:52.160 | Because if the model is able to predict the original image, given only this embedding
00:14:57.440 | of the visible patches, it means that these embeddings capture most of the information
00:15:02.320 | of the image, which can be then reused to rebuild the original image.
00:15:06.320 | So this is what they want, and this is what we want from an encoder.
00:15:09.960 | We want the encoder to create a representation of something that captures most of its salient
00:15:14.840 | information without caring about the extra information that is not necessary.
00:15:20.160 | So this allows you to reduce the dimensionality of the original image while preserving the
00:15:25.280 | information.
00:15:26.280 | And this is why they use the encoder of the masked autoencoder vision transformer.
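As a rough sketch of the masking step described above, assuming the 75% mask ratio mentioned in the transcript (the `random_masking` helper and the shapes are hypothetical, not the actual MAE code):

```python
import torch

def random_masking(patches, mask_ratio=0.75):
    # patches: (B, N, D) sequence of patch embeddings.
    # Keep a random 25% of the patches; return them plus the shuffled indices
    # needed to put everything back in the original order for the decoder.
    B, N, D = patches.shape
    num_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N)                      # one random score per patch
    ids_shuffle = noise.argsort(dim=1)            # random permutation of patch indices
    ids_keep = ids_shuffle[:, :num_keep]          # indices of the visible patches
    visible = torch.gather(patches, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    return visible, ids_shuffle                   # the encoder sees only `visible`

patches = torch.randn(2, 196, 256)                # e.g. 14x14 patches of a 224x224 image
visible, ids = random_masking(patches)
print(visible.shape)                              # torch.Size([2, 49, 256])
```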
00:15:33.880 | Now that we have seen what the image encoder is, which is basically creating an embedding,
00:15:39.600 | this one here, now we go to the next part, which is the prompt encoder.
00:15:43.240 | So the job of the prompt encoder is also to encode the prompt, which is the list of points
00:15:48.000 | chosen by the user, the boxes selected by the user and the text.
00:15:52.640 | We will not go into detail on the text encoder, which is basically just the text encoder of the
00:15:56.960 | CLIP model.
00:15:58.160 | So if you are not familiar with the CLIP model, I suggest you watch my previous video about
00:16:03.640 | the CLIP model; it's quite interesting in itself, so it deserves its
00:16:10.280 | own video.
00:16:11.560 | But basically the idea is the same as with the image encoder.
00:16:14.520 | So we have a text and we want some representation that captures most of the information about
00:16:19.600 | this text.
00:16:21.000 | And this is done by the CLIP text encoder.
00:16:26.400 | Let's have a look at the prompt encoder now.
00:16:29.220 | Now in their paper, Segment Anything, they say that they consider two
00:16:33.680 | types of prompts: the sparse prompts, which are the points, boxes and text, and the dense
00:16:38.600 | prompts, which are the masks we saw before.
00:16:41.040 | For the text prompt, they just use the text encoder from CLIP, as we can see here, while
00:16:46.160 | for the other two prompts, so the points and the boxes, basically they take the points and
00:16:52.440 | they create a representation of each point.
00:16:54.800 | So an embedding that tells the model what this point is referring to inside the image,
00:17:00.220 | using the positional encoding.
00:17:02.980 | Let's see how it works at the code level.
00:17:06.420 | Here we can see that basically, they take the sparse prompts, and they are mapped to
00:17:10.660 | 256 dimensional vector embeddings.
00:17:14.520 | So 256 dimensional vector embeddings.
00:17:19.620 | Basically here is how we encode the points.
00:17:23.960 | We have the points, and then we have labels.
00:17:27.020 | The points are a sequence of X and Ys, while the labels indicate if the point is additive,
00:17:34.700 | so we want the model to add something to our mask, or subtractive, we want the model to
00:17:40.180 | remove something.
00:17:41.660 | Here they are called foreground or background.
00:17:44.480 | Foreground means that we want the model to add something, background we want the model
00:17:47.660 | to remove something, just like we did with the example on the website before.
00:17:51.540 | So the first thing they do is they create the positional encoding of these points.
00:17:56.500 | So they convert the X and the Y into positional encodings, exactly like we do in the transformer
00:18:01.860 | model.
00:18:02.860 | As you remember in the transformer model we have positional encodings.
00:18:06.180 | They are special vectors that are combined with the embedding of each
00:18:11.620 | token to tell the model what the position of the token inside the sentence is.
00:18:16.620 | And here the idea is the same, even if the positional encodings are different.
00:18:22.220 | I mean the idea is the same, so that it's a vector with the dimension 256, but they
00:18:27.640 | are built in a different way, and we will see why.
00:18:31.580 | But we have to think that they transform the X and the Y into vectors, each of them representing
00:18:37.400 | the position of the point inside of the image using the positional encoding.
00:18:42.860 | Then they need to tell the model what is this point.
00:18:46.700 | The model cannot know if that point is foreground or background.
00:18:50.700 | So how do we do that?
00:18:52.420 | Basically all the foreground points are summed to another embedding here.
00:18:58.500 | That indicates it's an embedding that indicates that that is a foreground point.
00:19:03.460 | And all the background points are summed to another embedding that indicates to the model
00:19:08.580 | that that point is a background point.
00:19:10.700 | And if the point is a padding, because we don't have enough points, then they use another
00:19:15.420 | special embedding here.
00:19:18.060 | And this is how they build the embedding for the points.
00:19:21.660 | While for the boxes, the boxes are defined using the top left corner, so X and Y of the
00:19:27.540 | top left corner, and the bottom right corner.
00:19:31.360 | And they do the same with the boxes.
00:19:33.760 | So basically they transform these two points, so the top left and the bottom right, using
00:19:40.100 | the positional encodings to tell the model what is that X and Y corresponding to inside
00:19:44.700 | of the image.
00:19:46.780 | And then they sum one embedding to indicate that it's a top left point and another embedding
00:19:52.660 | to indicate that it's a bottom right point.
00:19:55.660 | And this is how they build the encoding for the prompt.
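Here is a simplified sketch of this idea: a positional encoding of the (x, y) coordinates plus a learned embedding per point type. The class name, the label convention (-1 for padding, 0 for background, 1 for foreground), and the shapes are my own assumptions for illustration, not the actual SAM prompt encoder.

```python
import math
import torch
import torch.nn as nn

class TinyPointEncoder(nn.Module):
    # Sketch: each point becomes a 256-dim positional encoding of (x, y)
    # summed with a learned embedding that says "foreground" or "background";
    # padding points get their own special embedding instead.
    def __init__(self, dim=256):
        super().__init__()
        self.pe_matrix = nn.Parameter(torch.randn(2, dim // 2), requires_grad=False)
        self.point_type_embed = nn.Embedding(2, dim)   # 0 = background, 1 = foreground
        self.not_a_point_embed = nn.Parameter(torch.zeros(1, dim))  # used for padding points

    def positional_encoding(self, coords):             # coords in [0, 1], shape (..., 2)
        proj = 2 * math.pi * coords @ self.pe_matrix   # random Fourier projection
        return torch.cat([proj.sin(), proj.cos()], dim=-1)  # (..., dim)

    def forward(self, points, labels):                 # points: (B, N, 2), labels: (B, N)
        emb = self.positional_encoding(points)
        emb = torch.where(labels.unsqueeze(-1) == -1,   # label -1 marks a padding point
                          self.not_a_point_embed.expand_as(emb),
                          emb + self.point_type_embed(labels.clamp(min=0)))
        return emb                                      # (B, N, 256) sparse prompt embeddings
```

A box would be handled analogously: its top-left and bottom-right corners are positionally encoded and each corner gets its own learned "corner type" embedding.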
00:19:59.540 | Why do we want to create 256 dimensional vector embeddings?
00:20:04.380 | Because the 256 dimension is also the one used for the image embedding.
00:20:11.420 | Because then we can combine them using a transformer.
00:20:16.120 | So the mask we saw before, what is the role of the mask?
00:20:18.740 | Now let's go into the detail of how it's combined with the image.
00:20:23.200 | So basically, the masks are called a dense prompt in the Segment Anything model.
00:20:28.900 | And what they do is, if the mask is specified, so here, okay, if the mask is specified, basically
00:20:37.820 | they run it through a sequence of convolutional layers to downscale this mask.
00:20:48.540 | And then if no mask is specified, they create a special embedding called no mask.
00:20:55.860 | And it's defined here.
00:20:57.220 | So as you can see, it's just an embedding with a given dimension.
00:21:01.260 | How do they combine this mask with the image? They just use a pointwise sum, as you can
00:21:07.140 | see here.
00:21:08.460 | So they take the image embedding and they just add the dense prompt embeddings, which are the mask
00:21:15.580 | embeddings.
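A minimal sketch of this dense path, assuming illustrative channel counts and resolutions (the real downscaling stack may differ):

```python
import torch
import torch.nn as nn

# Illustrative sketch of the dense (mask) prompt path: a few strided
# convolutions downscale the input mask to the image-embedding resolution,
# then it is added element-wise to the image embedding.
mask_downscaler = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=2, stride=2),
    nn.GELU(),
    nn.Conv2d(4, 16, kernel_size=2, stride=2),
    nn.GELU(),
    nn.Conv2d(16, 256, kernel_size=1),          # project to the 256-channel embedding space
)

image_embedding = torch.randn(1, 256, 64, 64)   # output of the image encoder (illustrative size)
mask_prompt = torch.randn(1, 1, 256, 256)       # previous low-res mask given as a dense prompt

dense_embedding = mask_downscaler(mask_prompt)  # (1, 256, 64, 64)
combined = image_embedding + dense_embedding    # element-wise sum, as described above
print(combined.shape)
```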
00:21:16.580 | Now, as we saw before, we have to use the positional encoding to tell the model what
00:21:21.320 | the points that we are feeding to it are.
00:21:25.180 | So the model cannot work with raw X and Y; the model needs to know some other information.
00:21:31.700 | We cannot just feed a list of X and Y to the model.
00:21:34.320 | We need to tell the model something more, something that can be learned by the model.
00:21:39.520 | And the transformer models were already very good at encoding the position using the sinusoidal
00:21:46.740 | positional encodings of the vanilla transformer.
00:21:49.060 | So if you remember in the vanilla transformer, so the transformer that was introduced in
00:21:52.780 | the paper, attention is all you need, if you remember, the positional encodings were built
00:21:57.980 | using sinusoidal functions.
00:22:00.380 | So sines and cosines combined together.
00:22:03.980 | And these vectors told the model what is the position of the token inside of the sentence.
00:22:09.060 | Now, this was fine as long as we worked with text, because text only move along one dimension,
00:22:14.780 | that is, we have the token number zero, the token number one, the token number two, etc.
00:22:19.860 | But pixels don't move in one direction, they extend in two dimensions.
00:22:23.700 | So, of course, one could think: why not use one positional encoding for
00:22:30.300 | the X coordinate to convert the X coordinate into a vector and another to map the Y coordinate
00:22:37.140 | into a vector?
00:22:38.140 | Yeah, we could do this.
00:22:39.980 | But the problem is, if we do in this way, suppose we convert the center position of
00:22:46.980 | the image into two vectors, one encoded using the X coordinate and one encoded using the
00:22:53.140 | Y coordinate.
00:22:54.820 | What we do if we check the similarity with the other position in the image is we get
00:22:58.780 | some hitmap like this, in which the zero, this position here is very similar to this
00:23:04.300 | position here.
00:23:05.980 | But it's not similar to this position here, which is not good, because in the in an image,
00:23:12.580 | we have the Euclidean distance.
00:23:14.420 | So pixel at the same Euclidean distance should have similarity with another point at the
00:23:21.020 | same distance.
00:23:22.020 | So basically, this point and this point should have a similarity that is the same as this
00:23:27.220 | point and this point.
00:23:30.860 | Because we have a spatial representation, so pixels that are close to each other should
00:23:35.060 | be very similar, pixels that are far from each other should not be very similar.
00:23:40.100 | But this is not what happens in this heatmap.
00:23:43.340 | What we want is something like this, that is, if we have a point here, all the points
00:23:47.540 | in the radius of, let's say, 10 pixels are very similar.
00:23:51.500 | All the points in the radius of 20, so distance 20, are less similar, but still, depending
00:24:01.740 | on the radius, they are similar in the same way as the other points with the same radius.
00:24:07.380 | And the farther we go from the center, the more distant we become, the more
00:24:13.220 | different we become.
00:24:14.220 | And this is what we want from positional encodings for an image.
00:24:18.620 | And this idea was introduced in this paper.
00:24:22.180 | You can see, Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding.
00:24:28.020 | And however, this is not the paper used by Segment Anything.
00:24:31.020 | For Segment Anything, they use this paper here, but you understood why we needed a new
00:24:35.780 | kind of positional encoding.
00:24:38.100 | Basically because we need to give the model a joint mapping of
00:24:47.340 | X and Y.
00:24:49.820 | We cannot just give them independently.
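Here is a small sketch of a Fourier-feature style encoding of (x, y). The projection matrix and scales are made up, but it illustrates why the similarity between two encodings decays with Euclidean distance rather than depending on x and y separately.

```python
import math
import torch

# Sketch of a Fourier-feature positional encoding for 2D coordinates:
# x and y are projected jointly through a fixed random Gaussian matrix,
# so the dot-product similarity between two encodings depends on the
# Euclidean distance between the points.
torch.manual_seed(0)
dim = 256
B = torch.randn(2, dim // 2)                 # fixed random projection, shape (2, 128)

def encode_xy(coords):                       # coords: (..., 2), normalized to [0, 1]
    proj = 2 * math.pi * coords @ B
    return torch.cat([proj.sin(), proj.cos()], dim=-1)   # (..., 256)

center = encode_xy(torch.tensor([[0.5, 0.5]]))
near   = encode_xy(torch.tensor([[0.52, 0.5]]))
far    = encode_xy(torch.tensor([[0.9, 0.9]]))
print((center @ near.T).item(), (center @ far.T).item())  # the nearby point scores much higher
```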
00:24:52.300 | And now let's look at the most important part of the model, which is the decoder.
00:24:56.900 | Now before we look at the decoder, I want to remind you that in the model that we saw
00:25:01.860 | before on the web browser, we could add the points in real time in the browser.
00:25:09.940 | That is, we loaded the image, and then I clicked on the image, and basically after a few milliseconds,
00:25:16.860 | let's say 100 milliseconds or half a second, I saw the output of the model.
00:25:21.700 | This could happen because the decoder is very fast, and the prompt encoder is very fast.
00:25:27.660 | But the image encoder doesn't have to be very fast, because we only encode the image once
00:25:32.800 | when we load the image.
00:25:34.040 | Then we can do multiple prompts, so we can save the image embeddings, and then we just
00:25:40.580 | change the prompt embeddings and run them through the decoder to get the new masks.
00:25:45.580 | Which means basically that the image encoder can be powerful, even if it's slow, but the
00:25:51.380 | mask decoder has to be lightweight and fast.
00:25:55.740 | And the same goes for the prompt encoder, and this is actually the case, because that's
00:25:59.340 | why we could use it on my browser in a reasonable time.
00:26:04.740 | So the mask decoder is made in this way.
00:26:09.460 | It's made of two layers, so we have to think that this block here is repeated
00:26:16.980 | again with another block that comes after this one, where the output of this big block is
00:26:22.860 | fed to the other block, and the output of that block is then sent onward here.
00:26:28.460 | Let me delete that.
00:26:33.300 | Okay.
00:26:34.980 | Now let's look at the input of this decoder.
00:26:38.540 | First of all, we have the prompts here.
00:26:41.020 | So the prompts sent by the user, so the clicks, the boxes, and then we have the image embedding
00:26:47.260 | plus the mask we can see in the picture here.
00:26:50.540 | So the image has already been combined with the mask through this layer here, through
00:26:55.420 | this addition, element-wise addition.
00:27:00.060 | The first thing the decoder does is the self-attention.
00:27:04.100 | So the self-attention of the prompt tokens with the prompt tokens themselves.
00:27:10.360 | But here we can also see that there are these output tokens.
00:27:14.000 | So before we proceed to see these steps, let's watch what are the output tokens.
00:27:19.880 | The output tokens take the idea also from BERT.
00:27:23.180 | So as you remember before, we saw the vision transformer, right?
00:27:26.760 | In the vision transformer, when they fed the patches to the transformer encoder, they also
00:27:32.060 | prepended another token called the class.
00:27:35.380 | And the same idea is reused by Segment Anything, in which they prepend some tokens before the
00:27:41.860 | prompt tokens, so the boxes, the clicks made by the user.
00:27:47.940 | And then at the output of this decoder, they check again only these tokens and force the
00:27:54.220 | model to put all the information into these tokens.
00:27:57.740 | So in this case, we have one token that tells the IOU, so the intersection over union scores
00:28:05.080 | of the predicted masks.
00:28:08.040 | And we will see later what is the IOU, if you're not familiar with it.
00:28:11.880 | And then there are three mask tokens, so one token for each mask.
00:28:17.700 | And basically, we feed these four tokens, so one IOU and three masks to the model.
00:28:24.340 | We take them here at the output.
00:28:26.920 | And then we use the first token, so the IOU token, and we map it to a multi-layer perceptron,
00:28:34.580 | to force the model to learn the IOU score into this token.
00:28:38.580 | And then the other three are used to predict the three masks as output.
00:28:44.580 | And we can see that here, they use the idea just like in the BERT paper.
00:28:51.900 | And this is the reference to the BERT paper in which they introduced the CLS token, in
00:28:56.020 | which in BERT was used for classification tasks.
00:28:59.060 | So basically, also in BERT, they prepended this token called the CLS.
00:29:05.140 | And then, at the output of the transformer, they just took the token corresponding to
00:29:09.980 | this CLS, which was the first one, and they forced the model to learn all the information
00:29:16.340 | it needed to classify into this CLS token.
00:29:19.580 | Why this works?
00:29:20.580 | Because the CLS token could interact with all the other tokens through the self-attention
00:29:24.860 | mechanism.
00:29:26.040 | And the same idea is reused here.
00:29:28.140 | So we feed the model with output tokens combined with the prompt tokens.
00:29:33.260 | And you can see here, they just concatenate the two.
00:29:36.060 | They take the IOU token and the mask tokens and concatenate them together.
00:29:39.420 | And then they concatenate these output tokens with the prompt tokens you can see here.
00:29:44.820 | Now the second part, they run attention.
00:29:50.360 | So first, they run the self-attention with the tokens, which are the output tokens plus
00:29:55.400 | the prompt tokens.
00:29:56.620 | And this part is here.
00:29:58.540 | You can see that the query, the key, and the values are the same.
00:30:01.820 | And they are the prompt tokens.
00:30:05.140 | The comments have been added by me to make your life easier; they are not present in the original
00:30:08.800 | code.
00:30:10.180 | Even if I found it really hard to follow this nomenclature, because sometimes a variable
00:30:14.660 | is called Q, but then they pass it as K or P, etc.
00:30:18.940 | But hopefully it's clear enough.
00:30:21.420 | What we want to get from this code is actually not the single instructions, but the overall
00:30:29.500 | concepts.
00:30:30.700 | So the output of this self-attention is then fed to a cross-attention.
00:30:35.820 | What is a cross-attention?
00:30:36.820 | Basically, a cross-attention is an attention in which the query comes from one side, and
00:30:42.220 | the key and the value come from another side.
00:30:44.820 | If you remember my video about the transformer model, in a translation task usually, imagine
00:30:51.760 | we are translating from English to Italian, or English to French, or English to Chinese.
00:30:57.340 | Basically, we first run in the encoder a self-attention between all the input sentence, so all the
00:31:03.380 | tokens of the input sentence related to all the other tokens of the input sentence.
00:31:08.180 | And then in the decoder, we have this cross-attention in which we take the queries coming from one
00:31:14.560 | language and the key and the values coming from another language.
00:31:18.780 | And this is usually done to combine two different sequences, to relate two different sequences
00:31:26.800 | with each other.
00:31:28.120 | In this case, what we want is to relate the tokens, so our prompt, with the image embeddings.
00:31:33.180 | This is why we do a cross-attention.
00:31:35.300 | So the first cross-attention is the tokens used as queries, while the image embeddings
00:31:41.140 | are used as keys and values.
00:31:44.220 | And this is the first cross-attention.
00:31:47.220 | Then there is a multilayer perceptron, so it's just linear layers.
00:31:52.880 | And finally, we have another cross-attention, but this time the opposite.
00:31:56.740 | So in this case, the queries are the image embeddings and the keys and the values are
00:32:00.900 | the prompt tokens.
00:32:03.160 | Why do we want two cross-attentions?
00:32:05.100 | Because we want two outputs from this transformer.
00:32:08.020 | One will be the sequence of tokens of the prompt, from which we will extract the output
00:32:12.700 | tokens, one indicating the IOU score and three indicating the mask.
00:32:17.600 | And one is the image embedding that we will then combine with the output tokens of the
00:32:22.860 | mask to build the mask.
00:32:24.880 | But we will see this later.
00:32:27.980 | Another thing highlighted here in the paper is that, to ensure that the decoder has critical
00:32:32.760 | geometric information, the positional encodings are added to the image embedding whenever
00:32:37.540 | they participate in an attention layer.
00:32:39.940 | And as you can see, this is done not only for the image, but also for the prompt.
00:32:43.820 | So every time to the prompt, they add the positional encoding, and every time to the
00:32:48.020 | image embedding, they also add the positional encoding.
00:32:51.440 | Why do we keep adding them?
00:32:52.960 | Because we don't want to lose this information after all these layers.
00:32:58.080 | And this is usually done with a skip connection.
00:33:00.760 | And in this case, they just add them back.
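A condensed sketch of one such two-way block, with layer norms omitted and shapes simplified; this is my own re-telling of the structure described above, not the actual SAM implementation:

```python
import torch
import torch.nn as nn

class TwoWayBlock(nn.Module):
    # Condensed sketch of one decoder block: self-attention on the tokens,
    # cross-attention tokens -> image, an MLP, then cross-attention image -> tokens.
    def __init__(self, dim=256, heads=8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_t2i = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim))
        self.cross_i2t = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, tokens, image, token_pe, image_pe):
        # Positional encodings are re-added at every attention layer, as the
        # paper describes, so the geometric information is not lost.
        q = tokens + token_pe
        tokens = tokens + self.self_attn(q, q, tokens)[0]
        tokens = tokens + self.cross_t2i(tokens + token_pe, image + image_pe, image)[0]
        tokens = tokens + self.mlp(tokens)
        image = image + self.cross_i2t(image + image_pe, tokens + token_pe, tokens)[0]
        return tokens, image   # both sequences are updated and passed to the next block
```

Here `tokens` stands for the concatenated output tokens plus prompt tokens, and `image` for the flattened image embedding, so the block really does produce the two output sequences discussed above.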
00:33:04.560 | And now let's have a look at the output.
00:33:07.160 | As we saw before, we have a special set of tokens added to the input of the transformer.
00:33:15.440 | One is for the IOU prediction and three are for the mask prediction.
00:33:19.040 | And you can see them here.
00:33:22.040 | Because our transformer model has two cross-attentions, we have two outputs, two sequences outputs.
00:33:28.720 | One is the output sequence of the tokens, and one is the output sequence of the embeddings
00:33:33.240 | of the image.
00:33:35.980 | And they extract the IOU token, which is the first token added to the sequence, and then
00:33:41.680 | they extract the mask tokens, which are the next
00:33:47.040 | three tokens, skipping the first one.
00:33:49.380 | Then what do they do?
00:33:50.380 | They take the IOU token and just give it to a prediction head to predict the IOU scores,
00:33:55.760 | and we can see that here.
00:34:00.040 | And then they take the output tokens for the masks, we can see them here, and they combine
00:34:09.680 | them with the upscaled embedding of the image.
00:34:14.280 | So they take the output of the transformer for the image, so SRC in this case, the variable
00:34:21.060 | name is SRC, they upscale it here, and then they run each of the mask output tokens through
00:34:32.400 | its own MLP layer.
00:34:34.520 | So here you can see we have multiple MLP blocks here.
00:34:38.880 | Each of the tokens has its own MLP block.
00:34:42.000 | They run each of them through its own MLP, they get the output, and then they combine this
00:34:47.360 | output of this MLP, one for each token, one for each mask, with the upscaled embedding
00:34:53.160 | of the image to produce the output masks here.
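A sketch of this last step, with made-up sizes: each mask token, after its own MLP, acts as a set of per-channel weights that is dotted with the upscaled image embedding at every pixel to produce that mask's logits.

```python
import torch
import torch.nn as nn

# Sketch of the final mask prediction: the upscaled image embedding is combined
# with one MLP-projected output token per mask via a spatial dot product.
B, C, H, W = 1, 32, 256, 256          # illustrative sizes after upscaling
num_masks = 3

upscaled = torch.randn(B, C, H, W)                     # upscaled image embedding
mask_tokens = torch.randn(B, num_masks, 256)           # mask output tokens from the decoder
token_mlps = nn.ModuleList([nn.Linear(256, C) for _ in range(num_masks)])  # one MLP per token

hyper = torch.stack([mlp(mask_tokens[:, i]) for i, mlp in enumerate(token_mlps)], dim=1)
# hyper: (B, num_masks, C); each row acts as per-pixel weights over the C channels.
masks = torch.einsum("bmc,bchw->bmhw", hyper, upscaled)   # (B, num_masks, H, W) mask logits
print(masks.shape)
```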
00:34:58.000 | Another interesting part of the paper is this section here, making the model ambiguity aware.
00:35:03.200 | So as I was saying before, we not only predict one mask, but we predict three masks.
00:35:08.200 | And this happens when we do not have more than one prompt.
00:35:13.760 | So if we only click one, for example, once, the model will produce three masks.
00:35:18.960 | But if we have more than one prompt, because the ambiguity becomes less, at least theoretically,
00:35:26.000 | the authors decided to add a fourth token that predicts another mask that is used only
00:35:33.960 | when we have more than one prompt.
00:35:37.380 | And this mask is never returned for the single prompt.
00:35:44.360 | Now let's have a look at what is intersection over union.
00:35:47.560 | Intersection over union allows us to understand how good our prediction is given the ground
00:35:54.040 | truth, especially in segmentation models or object detection models.
00:36:00.440 | So for example, imagine we are using object detection and our ground truth box is this
00:36:06.200 | green box here, but our model produced this red prediction as output.
00:36:12.680 | You can see that even if there is some overlap, it doesn't cover
00:36:16.680 | the entire ground truth, so the prediction is quite poor.
00:36:21.580 | But this improves when the box becomes bigger.
00:36:26.000 | So the red box becomes bigger.
00:36:27.700 | So there is more intersection, but also more union.
00:36:32.080 | And finally, it becomes excellent when the intersection is covering the whole box and it's
00:36:38.840 | covering the whole union.
00:36:40.120 | So the two are basically the same box and they overlap as much as possible.
00:36:45.520 | This area here, so the area that is predicted but is not in the ground truth, is called
00:36:52.520 | the false positive area,
00:36:53.520 | while the area that should have been predicted,
00:36:59.240 | but was not predicted, is called the false negative area.
00:37:02.960 | These are commonly used terms in this kind of scenario.
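As a quick example of how IoU can be computed for two binary masks (a generic helper, not SAM's code):

```python
import torch

def mask_iou(pred, target, eps=1e-6):
    # pred, target: boolean or {0,1} tensors of shape (H, W).
    # IoU = |intersection| / |union|.
    pred, target = pred.bool(), target.bool()
    intersection = (pred & target).sum().float()
    union = (pred | target).sum().float()
    return (intersection / (union + eps)).item()

pred = torch.zeros(100, 100); pred[20:80, 20:80] = 1     # predicted mask
gt   = torch.zeros(100, 100); gt[30:90, 30:90] = 1       # ground-truth mask
print(mask_iou(pred, gt))   # ~0.53: decent overlap, far from a perfect 1.0
```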
00:37:10.080 | Now let's have a look at the loss.
00:37:11.680 | The loss of the model is a combination of two losses.
00:37:14.120 | One is called the focal loss and one is the dice loss, and they are used in a ratio of
00:37:18.040 | 20 to one.
00:37:20.100 | Let's have a look at the focal loss.
00:37:22.440 | The focal loss takes its idea from the cross entropy, but with a modification: the
00:37:28.440 | focal loss is adjusted for class imbalance.
00:37:32.440 | So why do we have a class imbalance in this case?
00:37:35.560 | Because imagine we are doing segmentation.
00:37:38.280 | We are trying to predict the mask for a particular object in our image.
00:37:45.160 | But of course, usually the mask is not covering the entire image; only very few pixels
00:37:51.080 | compared to the total image are actually participating in this mask.
00:37:55.960 | And the instances of big masks are actually not so many.
00:37:59.360 | So we have a class imbalance here because most of our pixels will be non-mask and only
00:38:03.600 | a small percentage of our pixels will be mask.
00:38:06.600 | So we cannot use cross entropy in this case because the cross entropy doesn't pay attention
00:38:10.120 | to this class imbalance.
00:38:11.560 | So this is why they use focal loss to pay attention to this class imbalance.
00:38:15.760 | But the focal loss derives from the cross entropy and it was introduced in this paper
00:38:20.640 | by Facebook research, you can see here focal loss for dense object detection.
00:38:26.720 | The next loss is the dice loss.
00:38:29.100 | The dice loss comes from the Sørensen-Dice coefficient, which is also called the F1 score.
00:38:34.880 | And it's calculated as twice the intersection,
00:38:38.600 | so twice the area of overlap, divided by the total area.
00:38:43.360 | And this is actually a measure of similarity between two sets of data.
00:38:49.640 | To get the loss, we just do one minus the dice score.
00:38:53.180 | If you want more information about this dice score, which is very commonly used, I suggest
00:38:57.480 | you click on this link.
00:38:58.820 | It's on Medium.
00:39:00.740 | It's a nice article on how it works.
00:39:05.160 | And the dice loss was introduced in this paper VNet, I think it's from 2015.
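Here is a sketch of the two losses and the 20:1 combination described above; the alpha and gamma values are the common defaults from the focal loss paper, used here only for illustration, and the exact reduction used by SAM may differ.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # Binary focal loss: cross entropy down-weighted for easy, well-classified
    # pixels, which compensates for the mask / non-mask class imbalance.
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()

def dice_loss(logits, targets, eps=1e-6):
    # 1 - Dice coefficient, where Dice = 2*|A∩B| / (|A| + |B|).
    p = torch.sigmoid(logits).flatten(1)
    t = targets.flatten(1)
    dice = (2 * (p * t).sum(-1) + eps) / (p.sum(-1) + t.sum(-1) + eps)
    return (1 - dice).mean()

logits = torch.randn(2, 256, 256)            # predicted mask logits
targets = torch.randint(0, 2, (2, 256, 256)).float()
loss = 20.0 * focal_loss(logits, targets) + 1.0 * dice_loss(logits, targets)  # 20:1 ratio
print(loss.item())
```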
00:39:11.740 | Another interesting thing is that Segment Anything built its own dataset.
00:39:15.840 | And this is remarkable because we saw before that Segment Anything has been trained
00:39:19.720 | on 1.1 billion masks using millions of images.
00:39:25.360 | And the data engine that was used to build this dataset of 1.1 billion masks is composed
00:39:32.200 | of three stages.
00:39:33.880 | The first one was a manual stage, then a semi-automatic stage and then a fully automatic
00:39:38.160 | stage.
00:39:39.160 | Let's review them.
00:39:41.920 | In the assisted manual stage, so the manual stage, basically they hired a team of professional
00:39:46.280 | annotators that manually labeled the images using only the brush and the eraser tool.
00:39:52.700 | So basically you have to think that there are many people who are, pixel by
00:39:57.480 | pixel, mapping these pixels to masks.
00:40:00.920 | So this is what we would do to create a dataset from scratch.
00:40:05.120 | Then they trained the model on these manually created masks.
00:40:09.680 | And then they went to the semi-automatic stage, that is, some of the masks were already generated
00:40:15.240 | by the model, which was trained on the manually generated masks.
00:40:19.720 | And then the operators only had to adjust these masks and annotate any additional
00:40:26.580 | objects that were missed by the model.
00:40:30.120 | And this created even more samples and they trained the model on these samples.
00:40:34.480 | And finally, they moved to the fully automatic stage.
00:40:40.720 | In this fully automatic stage, there is no operator.
00:40:44.760 | The model is building the dataset by itself.
00:40:48.000 | How does it do that?
00:40:49.000 | They take an image and they create a grid of 32 by 32 points.
00:40:53.760 | And then for each of these points, they ask the model to predict the masks.
00:40:59.480 | Of course, this will produce a large number of masks.
00:41:02.600 | So they only take the ones with the highest confidence scores and also only the ones that
00:41:07.460 | are stable.
00:41:08.640 | And by stable, they mean that if they threshold the probability map at 0.5 minus delta and
00:41:14.680 | 0.5 plus delta, they result in similar masks.
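A small sketch of this stability check, following the description above; the delta and the score threshold are made-up values for illustration.

```python
import torch

def stability_score(prob_map, delta=0.05, eps=1e-6):
    # Threshold the predicted probability map at (0.5 - delta) and (0.5 + delta);
    # if the two resulting masks are almost identical (high IoU), the mask is "stable".
    loose = prob_map > (0.5 - delta)
    tight = prob_map > (0.5 + delta)
    intersection = (loose & tight).sum().float()
    union = (loose | tight).sum().float()
    return (intersection / (union + eps)).item()

prob_map = torch.sigmoid(torch.randn(256, 256))     # hypothetical predicted mask probabilities
score = stability_score(prob_map)
if score > 0.95:                                    # keep only sufficiently stable masks
    print("keep this mask", score)
else:
    print("discard this mask", score)
```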
00:41:19.360 | Next, because we have a lot of masks and some of them may be overlapping with each other,
00:41:24.520 | and some of them may actually be duplicates, we need to remove some of them.
00:41:29.440 | So we use an algorithm called non-maximal suppression.
00:41:32.360 | This is very famous also in object detection.
00:41:34.760 | Let's review how it works.
00:41:36.800 | So non-maximal suppression usually works like this.
00:41:40.480 | Imagine we have an object detection model.
00:41:43.760 | Usually when we detect a bounding box for an object, we get a lot of bounding boxes.
00:41:49.020 | And how do we only select one?
00:41:51.120 | Well, basically, we take the one with the highest confidence score, and then we delete
00:41:56.480 | all the other bounding boxes that have an IOU with the one that we selected
00:42:04.680 | higher than a threshold that is given as a parameter.
00:42:08.740 | This allows us to eliminate all the bounding boxes that are similar to the one we have
00:42:12.960 | selected.
00:42:13.960 | And which one did we select?
00:42:15.160 | The one with the highest score.
00:42:17.240 | And then we do this for all the remaining boxes.
00:42:21.220 | And the algorithm is very simple, and it's also very effective.
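Here is a minimal greedy NMS sketch for boxes, just to make the procedure concrete (a generic illustration, not the exact implementation used for the masks in the data engine; in practice `torchvision.ops.nms` provides the same behavior):

```python
import torch

def box_iou(a, b):
    # a, b: (4,) boxes in (x1, y1, x2, y2) format.
    x1, y1 = torch.max(a[0], b[0]), torch.max(a[1], b[1])
    x2, y2 = torch.min(a[2], b[2]), torch.min(a[3], b[3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-6)

def nms(boxes, scores, iou_threshold=0.5):
    # Greedy non-maximal suppression: repeatedly keep the highest-scoring box
    # and drop every remaining box whose IoU with it exceeds the threshold.
    order = scores.argsort(descending=True).tolist()
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        order = [i for i in order if box_iou(boxes[best], boxes[i]) < iou_threshold]
    return keep

boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])
print(nms(boxes, scores))   # the second box overlaps the first too much and is dropped
```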
00:42:26.720 | Thank you guys for watching my video about Segment Anything.
00:42:30.640 | I hope that most of the information was clear.
00:42:34.440 | If not, please let me know in the comments.
00:42:36.160 | I will try to correct my errors or clear up some misunderstanding or something that I should
00:42:42.320 | have said better.
00:42:45.000 | Please subscribe to my channel because I will be uploading more videos in the future.
00:42:51.360 | And hopefully see you again.