Mistral / Mixtral Explained: Sliding Window Attention, Sparse Mixture of Experts, Rolling Buffer
Chapters
0:00 Introduction
2:09 Transformer vs Mistral
5:35 Mistral 7B vs Mistral 8x7B
8:25 Sliding Window Attention
33:44 KV-Cache with Rolling Buffer Cache
49:27 Pre-Fill and Chunking
57:00 Sparse Mixture of Experts (SMoE)
64:22 Model Sharding
66:14 Pipeline Parallelism
71:11 xformers (block attention)
84:07 Conclusion
00:00:00.000 |
Hello guys, welcome back to my channel. Today we are going to talk about Mistral 00:00:03.780 |
So as you know Mistral is a new language model that came out a few months ago from Mistral AI 00:00:09.280 |
which is one of the hottest startups right now in Europe for language models 00:00:13.720 |
It also became a unicorn recently, and we will be exploring both the models 00:00:18.580 |
they released: one is the 7 billion and one is the 8x7 billion model 00:00:23.960 |
The first thing I will introduce is the architectural differences between the vanilla transformer and the architecture of Mistral 00:00:30.920 |
We will see what is the sliding window attention and how it is related to the concept of receptive field, a concept that usually we find in convolutional neural networks 00:00:39.440 |
I will briefly review the KV cache because I want to introduce the concept of rolling buffer cache and also how it is used with pre-fill and chunking 00:00:48.480 |
We will see what is a sparse mixture of experts, and we will look at model sharding with a brief introduction to pipeline parallelism 00:00:59.840 |
Last but not least we will also go through the code of Mistral because there is a lot of innovations in the code 00:01:06.160 |
Especially when they use the Xformers library with the block attention 00:01:09.440 |
So I want to guide you into understanding the code because it can be really hard for beginners to understand and find themselves around 00:01:16.880 |
There are some topics that are related to Mistral 00:01:19.880 |
but they will not be covered in this current video because I already covered them in my previous video about Llama. In particular 00:01:25.960 |
I will not be talking about RMS normalization, the rotary positional encoding and the grouped query attention because I already 00:01:31.200 |
taught them in depth in my previous video on Llama 00:01:34.240 |
So if you want to know about them, please watch my previous video on Llama 00:01:37.600 |
The only prerequisite that I hope you have before watching this video, because the topics we are going to touch are quite advanced, 00:01:45.360 |
is that you are familiar with the transformer model 00:01:48.080 |
So if you are not familiar with the transformer model and the attention mechanism, in particular the self attention mechanism, 00:01:54.160 |
please go watch my video on the transformer, in which I teach all these concepts very thoroughly, very much in detail 00:02:00.360 |
These are really a prerequisite for watching this video because the topics here are quite advanced 00:02:08.360 |
So let's look at the differences between the vanilla transformer and Mistral at the architecture level 00:02:16.200 |
You can see the image here, which I built by myself using the code, because they didn't release any architecture picture in the paper 00:02:24.000 |
This is the architecture of Mistral. First of all, let's talk about some terminology 00:02:29.040 |
When you have a model like this made up of many encoder layers plus linear and the softmax 00:02:35.200 |
We are talking about a decoder only model because this part this model here looks like the decoder of the vanilla 00:02:42.840 |
Transformer you can see here except for the cross attention because as you can see here, there is no cross attention 00:02:49.360 |
When we have a model without the linear and the softmax we call it an encoder only model 00:02:54.920 |
For example BERT is an encoder only model because BERT has some heads at the end 00:03:00.080 |
Which is one or more linear layers depending on the application 00:03:04.080 |
But itself BERT doesn't need a head because it can be used for multiple downstream tasks 00:03:08.560 |
so it's called an encoder only model because it resembles the 00:03:11.800 |
Encoder side of the transformer because as you can see in the encoder side, there is no linear and softmax 00:03:17.240 |
So Mistral is a decoder only model and it's very similar, if not equal, to Llama 00:03:23.320 |
The differences between Llama and Mistral are highlighted here in red 00:03:28.000 |
The first difference between Llama and Mistral is that 00:03:31.480 |
In the self attention we use the sliding window attention and we still use the grouped query attention 00:03:42.320 |
Then there is a rolling buffer KV cache, and it's actually related to the fact that we are using sliding window attention 00:03:47.680 |
So later we will see all these concepts. Another difference is that the feedforward layer here, instead of using the ReLU function that we use 00:03:55.880 |
in the vanilla transformer or the SwiGLU function that we use in Llama, here in Mistral uses the SiLU function 00:04:02.480 |
And there is one feedforward network in the case of Mistral 7B, 00:04:08.960 |
so the first model they released, and there can be eight feedforward 00:04:13.840 |
networks in parallel with each other, which are the experts of this mixture of experts, in the case of Mistral 8x7B 00:04:24.520 |
So for now, you just need to understand that Mistral is made up of the input tokens, which are converted into embeddings 00:04:31.080 |
Then we have this block which is repeated n times, and we will see that in the case of Mistral it is repeated 32 times 00:04:40.120 |
Output of each layer is fed to the next layer as input and the output of the last layer is then sent to this 00:04:47.920 |
RMS norm to the linear and to the softmax to produce the output of the model and 00:04:52.760 |
This is exactly the same as what we do with any other transformer model 00:04:57.880 |
Usually we have many of these blocks here. Now in the code of Mistral this part here is known as transformer block 00:05:05.560 |
But it's also known as encoder block or decoder block depending on the contents of this 00:05:12.960 |
Block here. I will refer to it as an encoder block because if you look at it 00:05:17.640 |
It looks like exactly as the block of the encoder side 00:05:20.560 |
So it has a multi-head attention, add & norm, a feedforward and add & norm 00:05:24.720 |
The only difference is that the normalization here comes before the block of the feedforward and the self-attention 00:05:41.160 |
The parameter dim indicates the dimensions of the embedding vector 00:05:47.660 |
So how big is the embedding vector? Each token is represented by an embedding vector of size 4096 00:05:55.360 |
We have 32 of the encoder layers. So this block here is repeated 32 times 00:06:02.880 |
The head dimension: as you remember, in the multi-head attention each head does not watch the full embedding of each token 00:06:11.960 |
but only a part of the embedding of each token, and this indicates how many 00:06:17.280 |
dimensions each head will attend to 00:06:21.920 |
in the multi-head attention, and the hidden dimension here indicates the hidden dimension of the feedforward layer 00:06:28.760 |
So in the case of the feedforward layer, we have two linear layers, one that converts the 00:06:33.600 |
Dimension of the embedding vector into the hidden size then another one that converts the hidden size back into the embedding vector dimensions 00:06:42.440 |
So in the case of the Mistral they are using as a hidden size 00:06:45.980 |
14336; usually this is a multiple of the dimension, and here it looks like it's 3.5 times the dimension 00:06:55.520 |
The number of heads of attention for the query is a 32 while the number of heads for the K and V 00:07:01.680 |
So the key and values is 8 and they are not equal because of the grouped query attention 00:07:06.320 |
So if you remember from my previous video on Llama, in which we talked about the grouped query attention, 00:07:10.420 |
In the very simple case of the grouped query attention 00:07:14.360 |
We have the multi query attention which means that only the query have the multi head while the key and V don't have the multi head 00:07:21.680 |
which means that you may have eight heads for the query and only one head for the K and V. In the case of the 00:07:27.600 |
grouped query attention, it means that each group of 00:07:36.000 |
four query heads has one attention head for the keys and values. If this concept is not clear, 00:07:42.280 |
I describe it very thoroughly in my previous video on Llama 00:07:46.240 |
the window size is the size of the sliding window that we used in the 00:07:51.420 |
calculation of the attention, and we will see later how it works. The context length is 00:07:56.920 |
the context size upon which the model was trained 00:08:04.720 |
The vocabulary size is the same for both and then the last two parameters 00:08:08.280 |
You can see here are related to the sparse mixture of experts and we will see later 00:08:13.440 |
How it works, but we just remember that we have eight experts and for each token we use two experts 00:08:21.800 |
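To make these hyperparameters concrete, here is a minimal sketch of a configuration object (my own illustration, not the official params.json): dim, the number of layers, the hidden dimension, the numbers of query and KV heads and the expert counts are the values mentioned in this video, while head_dim, the window size, the context length and the vocabulary size are shown as placeholders that should be checked against the released checkpoints.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    dim: int = 4096               # size of the embedding vector of each token
    n_layers: int = 32            # number of stacked encoder (transformer) blocks
    head_dim: int = 128           # dimensions each attention head attends to (assumed dim // n_heads)
    hidden_dim: int = 14336       # inner size of the feed-forward layer (about 3.5x dim)
    n_heads: int = 32             # attention heads for the queries
    n_kv_heads: int = 8           # attention heads for keys/values (grouped query attention)
    sliding_window: int = 4096    # W, the sliding window attention size (placeholder)
    context_len: int = 8192       # context size the model was trained upon (placeholder)
    vocab_size: int = 32000       # vocabulary size (placeholder)
    # Only set for Mistral 8x7B (sparse mixture of experts)
    moe_num_experts: Optional[int] = 8          # total experts, i.e. parallel feed-forward networks
    moe_num_experts_per_tok: Optional[int] = 2  # experts used for each token
```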
Let's proceed further. So let's talk about the sliding window attention. But before we talk about the sliding window attention 00:08:28.440 |
I need to review a little bit of the self attention mechanism. So what is self attention? 00:08:33.720 |
Self attention is a mechanism that allows the model to relate tokens to each other 00:08:39.160 |
So tokens that are in the same sentence are related with each other through the self attention mechanism 00:08:44.160 |
This is why it's called self attention, because each token is watching other tokens of the same sentence 00:08:50.640 |
And this basically means that the query, key and values are the same matrix 00:08:57.000 |
So imagine we have the following sentence the cat is on a chair 00:09:03.840 |
We have our query which is a matrix made up of six tokens each token represented by 00:09:09.760 |
4096 dimensions, which is the dim parameter that we saw before 00:09:13.480 |
This is multiplied by the transpose of the keys, which is 00:09:17.800 |
4096 by 6 but it's just the query matrix transpose because the query key and values are the same matrix in the case of self attention 00:09:29.840 |
This produces a 6 by 6 matrix, because the inner two dimensions kind of cancel out and the outer dimensions indicate the dimension of the output matrix here 00:09:29.840 |
Now, what are the values in this matrix representing? 00:09:41.840 |
The first value here indicates the dot product of 00:09:47.920 |
the first row of the query with the first column of the keys 00:09:52.280 |
So basically the dot product of the embedding of the first token with itself 00:09:56.840 |
The second value here indicates the dot product of the first row of the query matrix 00:10:02.560 |
With the second column of the key matrix here the transpose of the keys matrix here 00:10:07.400 |
Which basically means that it's the dot product of the embedding of the first token 00:10:12.040 |
So the with the embedding of the second token, which is cat and etc, etc for all the other values 00:10:18.480 |
Don't concentrate too much on the values because all the values I put here are random 00:10:22.240 |
And also the fact that these numbers are less than one 00:10:25.120 |
It's not necessary because the dot product can be bigger than one. It's not a condition of the dot product 00:10:30.440 |
Usually in the formula we also normalize: here we divide by the square root of dk 00:10:38.160 |
dk basically is the size of the part of the embedding to which this particular attention head will attend 00:10:49.680 |
In this case we only have one head, so dk is equal to d model. So basically this head will watch the full embedding of each token 00:10:56.880 |
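Putting this together, the self-attention being described here is the standard scaled dot-product attention, where M is the mask (causal, and later causal plus sliding window) added before the softmax and d_k is the per-head size:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}} + M\right)V
```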
Okay, usually we train autoregressive models. So language model is an autoregressive model 00:11:03.160 |
It means that the output depends on the previous 00:11:05.840 |
The next token depends only on the previous tokens 00:11:09.240 |
And this is why we apply a causal mask. Causal mask means basically that in the attention mechanism 00:11:15.280 |
We don't want to relate a word with future words, 00:11:19.000 |
So words that come after it but only with the words that come before it 00:11:23.280 |
so for example, we don't want the word "the" to be related to the word "cat" because 00:11:28.680 |
The word "cat" comes after the word "the" but on the other hand 00:11:33.840 |
We want the word "cat" to be related to the word "the" because it comes before it 00:11:38.480 |
And for this reason we apply this causal mask 00:11:40.960 |
Because the attention mechanism uses the softmax function we can see here. The softmax function basically 00:11:49.440 |
will transform all this minus infinity into zero because the formula of the softmax has at numerator an e to the power of X and 00:11:57.640 |
When X goes to minus infinity e to the power of minus infinity will go to zero 00:12:02.360 |
So this is why we apply a mask in which we put all the values that we don't want, all the interactions that we don't 00:12:08.320 |
Want between tokens. We just mask them out by replacing them with minus infinity 00:12:12.920 |
So that when we apply the softmax, the softmax will take care of turning them into zero 00:12:20.320 |
Okay, also the softmax will do another thing because it will not only convert this minus infinities to zero 00:12:27.880 |
but it will also modify the other values for each row such that they sum up to one 00:12:35.280 |
Right now they don't sum up to one for each row, right? Because this is 0.2, 0.1 and 0.2 00:12:39.760 |
They don't sum up to one, but the softmax will convert the minus infinities into zero and change the remaining values for each row such that they sum up to one 00:12:48.320 |
Now let's talk about sliding window attention 00:12:50.960 |
So we applied the causal mask to hide the interactions between 00:12:55.120 |
a word and all the future words, but with the sliding window attention we also hide the interactions between a word and 00:13:03.160 |
other words that are outside its local context 00:13:08.880 |
In the previous case when we only applied the causal mask, the word chair for example was being related to all the previous 00:13:15.120 |
Tokens as you can see, so the token chair here is related to itself 00:13:19.000 |
but also to "a", "on", "is", "cat" and "the", so it could watch basically all the sentence 00:13:25.120 |
But in the case of sliding window attention, we don't want the word "chair" to watch words that are further than the sliding window size 00:13:35.840 |
The sliding window size in this case is three, so tokens whose distance is more than three from the word 00:13:42.800 |
we are considering. So the word "chair" should not be related to the word "is" because the distance is four, and the word 00:13:48.240 |
"a" should not be related to the word "cat" because the distance is four. And of course 00:13:52.800 |
we still want the mask to be causal, because we don't want 00:13:57.140 |
each token to watch future words, because we are training an autoregressive model 00:14:01.880 |
So the sliding window attention basically reduces the number of dot products that we are performing 00:14:09.720 |
And this will improve the performance during the training and the inference because as you can see when we only apply the causal mask 00:14:15.680 |
We are performing all these dot products you see here, but with the sliding window attention 00:14:20.720 |
We are performing less dot products because all the other will be masked out 00:14:24.240 |
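To make the masking concrete, here is a small sketch (mine, not the actual Mistral code) of a causal mask combined with a sliding window of size W = 3 for the six tokens of "the cat is on a chair"; the masked positions become zero after the softmax:

```python
import torch

seq_len, window = 6, 3
i = torch.arange(seq_len).unsqueeze(1)   # query positions (rows)
j = torch.arange(seq_len).unsqueeze(0)   # key positions (columns)

# allowed: j <= i (causal) and i - j < window (inside the sliding window)
allowed = (j <= i) & (i - j < window)
mask = torch.zeros(seq_len, seq_len).masked_fill(~allowed, float("-inf"))

scores = torch.randn(seq_len, seq_len)           # stand-in for Q @ K.T / sqrt(d_k)
weights = torch.softmax(scores + mask, dim=-1)   # -inf entries become 0, each row sums to 1
print(weights.round(decimals=2))
# The last row (the token "chair") has non-zero weights only at positions 3, 4 and 5,
# so its output will be a combination of only the last three rows of V.
```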
Sliding window attention however may lead to degradation of the performance of the model because as you can see here 00:14:31.640 |
The word "chair" and the word "the" are not related to each other anymore, right? So the information 00:14:38.680 |
will not be conveyed between the word "the" and the word "chair": the word "chair" will only be related to other tokens that belong 00:14:45.600 |
to the local context of this particular token, so only the 00:14:49.160 |
tokens that are inside this sliding window 00:14:53.260 |
If this window is too small, it may reduce the performance of the model 00:14:58.480 |
But it may also be beneficial because for example imagine you are reading a book you don't care about 00:15:04.180 |
Relating the word in chapter 5 with the words in chapter 1 because most of the books 00:15:09.120 |
they could be talking about totally different things, and you don't even care about relating these two chapters; you mostly care about relating the tokens 00:15:17.700 |
in chapter 5 with other tokens in chapter 5, because the local context matters 00:15:22.980 |
But I want to introduce you the concept of receptive field because when we use sliding window attention 00:15:29.580 |
Even if the word chair and the are not related to each other actually 00:15:34.520 |
Because in Mistral and in all transformer models we use multiple layers of encoders 00:15:40.460 |
we will see that the word "chair" and the word "the" will still be kind of related to each other, 00:15:49.940 |
in a way that is very similar to the receptive field of the convolutional neural networks 00:15:55.600 |
So let's talk about the receptive field as you remember in convolutional neural networks 00:16:01.520 |
We have a mask a kernel that we run through an image. So imagine this is our original image 00:16:07.760 |
This one here and we run a mask that is a kernel that is 3x3 this one here 00:16:13.600 |
That when we run a kernel it will produce an output feature. So for example this feature here 00:16:19.120 |
This is the output produced by applying the kernel to the first 3x3 grid here 00:16:25.440 |
This value here the second value here in yellow 00:16:29.480 |
It will be produced when we will move our kernel to the next group of 3x3 pixels. So let me draw 00:16:41.720 |
will be produced when we will move our kernel in this grid here and 00:16:50.920 |
This value here is also an output feature of a kernel 00:16:59.160 |
applied to this layer 2. So this is a 3x3 kernel that is applied to this layer 2 00:17:05.840 |
So apparently there is no connection between this one this pixel here and this one 00:17:11.820 |
But because this output feature depends on a kernel applied in this grid, and this grid 00:17:19.520 |
includes this feature here, which depends on this pixel here, we can safely say that this feature here indirectly depends on this pixel, 00:17:30.520 |
even if they are not directly related to each other. And this is the concept of the receptive field. So basically 00:17:36.200 |
one feature of the convolutional neural networks can watch a much bigger receptive field thanks to the 00:17:49.400 |
sequential application of kernels in the convolutional layers 00:17:53.240 |
Let's see how this concept is related to the sliding window attention now 00:18:00.280 |
After we apply the softmax to the mask that we have seen before as I told you before all the minus infinities are 00:18:06.840 |
Converted into zero and all the other values are changed in such a way that they sum up to one 00:18:12.100 |
So let's go back as you remember here. We have the minus infinities here here here here here and here 00:18:17.920 |
So now we apply the softmax and it will become zeros zeros here. Also, let me 00:18:24.680 |
Okay, all the zeros here all the zeros here and all the other values are changed in such a way that they sum up to 00:18:30.560 |
One what is the next operation that we do in the self-attention? 00:18:33.640 |
We then take the output of the softmax and multiply it by the V matrix. So let's do it 00:18:38.660 |
The V matrix is basically the same as the initial sequence because I told you this is self-attention 00:18:46.360 |
So the query key and values are the same matrix 00:18:49.160 |
So this means let's analyze what happens by hand when we do this multiplication 00:18:54.640 |
So let me change to the pen. Okay, the V matrix here is 00:18:59.200 |
is a sequence of tokens where each token is a vector represented by 00:19:04.960 |
4096 dimensions so we can say that it's the output of the self-attention if you watch the 00:19:12.760 |
Dimensions of these two matrices. So it's a 6 by 6 and the 6 by 00:19:16.040 |
4096 the output will be another matrix that is 6 by 00:19:19.320 |
4096 so it will have the same dimension as the V matrix and also as the Q and the 00:19:24.640 |
Query matrix because they have the same dimensions. So it will be six tokens as output 00:19:31.120 |
Okay, let's analyze. What is the first dimension of the output to this one here? 00:19:37.720 |
So this first value of the output so the value on the row 1, column 1 of the output matrix will be the dot 00:19:45.200 |
Product of the first row of this matrix here. So this row here 00:19:50.640 |
With the first column of this matrix, so the first column we can see here and 00:19:58.040 |
As you can see, most of the values here are 0, which means that most of the rows of the V matrix will not be used: 00:20:09.340 |
only the values of the first row will be used, because if you remember the dot product is the 00:20:15.660 |
first dimension with the first dimension of this column and 00:20:19.560 |
The second dimension of this row with the second dimension of this column 00:20:23.960 |
The third dimension of this row with the third dimension of this column and then we sum up all these values 00:20:29.360 |
So this first value of the output will only depend on the first token of the V matrix 00:20:36.440 |
You can see here. Let's check the second one: 00:20:42.400 |
the first dimension of the second row of the output matrix will be the dot product of the second row of this matrix with the first column of the V matrix, 00:20:55.760 |
but most of the values are 0, which means that this 00:21:00.140 |
dimension here, and all the dimensions in this row, will depend only on the first two tokens of the V matrix, and 00:21:08.120 |
We can say the same for the third. Let's analyze the sixth one here 00:21:12.240 |
So the first dimension of the sixth row of the output matrix 00:21:17.040 |
So this value here comes from the dot product of this row with the first column of the V matrix, 00:21:26.920 |
but most of the values at the beginning are 0, which means that it will only depend on the 00:21:32.860 |
4, 5 and 6th token of the V matrix and so will be all the dimensions here because in each column 00:21:41.880 |
Whatever the column we use from the V matrix the first values will always be multiplied by 0, 0, 0 00:21:48.600 |
So it will only use the values in these three rows here 00:21:51.880 |
So we can safely say that the 6th token of the output matrix 00:21:57.720 |
of this self-attention mechanism will be a vector that will only depend on the last three tokens of the V matrix and 00:22:05.680 |
Because we are talking about self-attention the V matrix is equal to query matrix 00:22:10.120 |
So we can say that the output of the self-attention is a matrix that has the same shape as the input sequence 00:22:17.480 |
But where each token now captures some more information about other tokens 00:22:22.480 |
Which tokens depending on the mask we have applied 00:22:26.080 |
So our mask says that the first token can only watch itself 00:22:30.520 |
So the first output token will be an embedding that will only depend on itself 00:22:35.200 |
The second token will only depend on the first two tokens 00:22:39.560 |
The third output token will only depend on the first three tokens 00:22:44.840 |
The fourth will depend on the token number two because the first token is not used 00:22:49.960 |
The token number two, the token number three and the token number four, etc, etc 00:22:55.160 |
The last token will depend only on the last three tokens because the first three tokens are masked out 00:23:01.320 |
And this is the importance of the mask that we apply in the self-attention mechanism 00:23:05.960 |
This concept that I show you now is very important to understand the rest of the video 00:23:10.240 |
So please if you didn't understand it, you can take a little pause 00:23:13.000 |
You can try to do it by yourself because it's really important that you understand 00:23:17.200 |
How the self-attention mechanism works with the mask 00:23:25.040 |
So as we saw before the output of the self-attention mechanism is another 00:23:29.360 |
Matrix with the same shape as the query matrix in which each token is represented by an embedding of size 00:23:36.240 |
4096 but each embedding now captures information also about other tokens 00:23:42.320 |
According to the mask and if we check the this mask here, so the output here 00:23:48.360 |
We can safely say that the input of our sliding window attention was the initial sequence of tokens, 00:23:58.560 |
But after applying the self-attention the first token is now related to itself 00:24:03.640 |
The second token is related to itself and the token before it 00:24:07.800 |
The third is related to the token before it and the one also before it 00:24:11.920 |
The last one only depends on the previous two tokens, etc. According to the mask, right? 00:24:17.040 |
Now, what happens if we feed this one to another layer? Because as you know, in the transformer world, and also in Mistral and also in Llama, 00:24:24.440 |
we have many layers of encoders one after another which are also called the transformer block in the code and 00:24:30.480 |
The output of each layer is fed to the next one. So this is the first layer of the transformer 00:24:37.680 |
So we take the input sequence and we feed it to the first layer which will produce a list of tokens where each token now captures 00:24:44.000 |
Information about other tokens, but this will become the input of the next layer where we it will produce an output 00:24:51.680 |
This output I will prove you that will capture information about even more tokens 00:24:56.640 |
Even if the sliding window attention says that they should only be able to watch the previous two tokens 00:25:02.440 |
Because the sliding window size we chose three as a sliding window size 00:25:09.440 |
Imagine this is the output of the first layer 00:25:13.400 |
So it's a list of tokens that capture information about other tokens and it's the the matrix that we built in the previous slide 00:25:21.120 |
Let's use it as an input for another layer of the encoder. So we multiply the query 00:25:27.160 |
We multiply the query and the transposed of the keys which will produce a matrix like this one in which each token is not only 00:25:34.920 |
One token, but it's capturing already information about multiple tokens, right according to the mask 00:25:39.560 |
So I'm taking this one and this one will become query key and values 00:25:44.480 |
So if we multiply the query by the key, it will return a matrix like this 00:25:48.480 |
So the first token only depends on itself. The second one depends on himself and the previous one 00:25:54.080 |
So the embedding of this token captures information about two tokens and the embedding of this token capture information about three tokens, etc 00:26:06.560 |
We have that our V matrix is again a list of tokens and the output 00:26:14.160 |
Will also be a list of tokens, but each one will capture information about other tokens 00:26:27.520 |
So the first dimension of the first row of the output matrix will be the dot product of the first row of this 00:26:33.960 |
Matrix here. So this row here with the first column of this matrix here. So this column here 00:26:41.320 |
But because of this causal mask with the sliding window attention mask that we can see here 00:26:47.480 |
It will the output will only depend on the first row of the V matrix 00:26:52.400 |
But because the V matrix is a matrix that is made of these tokens here 00:26:57.240 |
it will only depend on the word "the". So as we can see here, the first output of the second layer only depends on the 00:27:05.800 |
word "the", and so will this second one. Now let's check the fourth token, for example, this one here 00:27:15.800 |
So this value here will be the product of the fourth row of this matrix dot product of 00:27:23.160 |
This row with the first column of the V matrix. So this column here 00:27:28.480 |
But the first token will not be used, because we are multiplying it with zero, whatever value we have here; 00:27:35.960 |
we are using the second token the third token and the fourth token and 00:27:41.840 |
each of these tokens is actually already aggregating information from multiple tokens: this one is 00:27:49.240 |
aggregating the value of two tokens, which is "the" and "cat", so this embedding here is already talking about "the" and "cat"; 00:27:59.960 |
this token here is aggregating the information about "the", "cat" and "is"; and this one is 00:28:12.280 |
aggregating the information of "cat", "is" and "on", 00:28:18.480 |
because as we saw before, the fourth token here is "cat is on", which is the result of the previous self-attention that we have done 00:28:25.940 |
So this output value here will depend on three tokens that already include information about other tokens 00:28:37.800 |
Information about the union of all these tokens 00:28:41.160 |
So it will for sure depend on the word "the", because it's included in the second token we are multiplying it with; 00:28:46.680 |
it for sure will include information about the word "cat", because it's included in this token as well; for sure 00:28:54.000 |
it will include information about "is", because it's included in the second value; 00:28:59.640 |
and it will include information about the token "on", because it's present in the fourth token of the V matrix with which we are 00:29:06.620 |
Multiplying it because this value is not zero 00:29:09.600 |
So as you can see after applying another layer of the encoder 00:29:14.960 |
the fourth token now includes another token in its information: before, it was only depending on "cat", "is" and "on", 00:29:23.000 |
but now it also depends on a new token, which is the word "the", and 00:29:26.840 |
We can prove the same for the fifth token and the sixth token 00:29:30.600 |
so at every application of the encoder layer one after another we keep increasing the number of tokens that get 00:29:37.820 |
Accumulated in these dot products and I made a notebook in Python to visualize this 00:29:50.880 |
Notebook called the sliding window attention in which I help you visualize this process and I also share the code on how I do this 00:29:58.280 |
self-attention. Basically, I represent each token as a set: 00:30:03.560 |
each token, instead of being represented as an embedding, is represented as a set of all the words upon which that token 00:30:10.920 |
depends. Then I apply the sliding window attention, which basically means that I take the tokens from the sequence and I 00:30:20.520 |
accumulate, I make the union of the two sets they contain 00:30:23.520 |
Because I am multiplying two vectors that already include information about multiple tokens 00:30:29.440 |
So what is the output is the union of the two sets when I multiply by V 00:30:36.640 |
So after we apply the first layer, we will see that the input of the first layer is just our normal sequence 00:30:42.160 |
So the cat is on a chair the output of the first layer will be another sequence 00:30:46.360 |
There in which each position includes information about multiple tokens 00:30:50.120 |
Depending on the mask that we have applied and I also show the mask that we apply 00:30:54.600 |
After we apply the second layer, we can see that the information increases 00:30:59.560 |
So this last token now is not watching only the previous 00:31:05.340 |
two tokens, but the previous four tokens 00:31:09.820 |
So at every step we do, with a sliding window size of three, we include two more tokens at every layer, and 00:31:19.180 |
I show it for five layers, but it's not necessary because after a while the sequence will reach its maximum length 00:31:25.580 |
If you want you can increase the sequence's length here by including more tokens 00:31:35.580 |
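If you prefer not to open the notebook, here is a tiny reconstruction of the same idea (my own sketch, not the author's code): each position is represented by the set of original words it depends on, and every layer replaces each set with the union of the sets inside its sliding window:

```python
def apply_layer(token_sets, window):
    new_sets = []
    for i in range(len(token_sets)):
        start = max(0, i - window + 1)   # causal + sliding window of size `window`
        merged = set()
        for s in token_sets[start:i + 1]:
            merged |= s                  # the union mirrors what the masked dot products aggregate
        new_sets.append(merged)
    return new_sets

tokens = ["the", "cat", "is", "on", "a", "chair"]
sets = [{t} for t in tokens]
for layer in range(1, 3):
    sets = apply_layer(sets, window=3)
    print(f"after layer {layer}:", sets)
# After layer 1 the last position depends on {'on', 'a', 'chair'};
# after layer 2 it already depends on {'cat', 'is', 'on', 'a', 'chair'}.
```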
This is the concept of the receptive field applied to the sliding window attention. So basically with the sliding window attention, we are not 00:31:41.420 |
Directly connecting two tokens with each other 00:31:46.180 |
but if we apply multiple layers one after another, this information 00:31:51.860 |
will get captured by the embeddings in successive applications of the layers, such that the last layer 00:31:58.340 |
basically will be able to watch all the sentence, even if it's very long. And 00:32:05.500 |
this is also shown in the Mistral paper in this picture you can see here. So basically this is our input sequence. So let me write 00:32:17.500 |
The original sentence so the cat is on a chair 00:32:21.340 |
The fourth token of the first layer. So this is the output of the first layer. So layer one 00:32:30.380 |
We have seen that the fourth token here, with a sliding window size of four, will depend on itself, 00:32:37.460 |
on the previous token, on the one before, and also on this token here, and it will produce 00:32:43.180 |
this embedding here in the fourth position, which includes information about the previous tokens as well 00:32:49.600 |
But then this will become the input of the next layer, which is the layer number two 00:32:57.300 |
this will produce an embedding at this position for example that will depend for sure on the previous four tokens because the sliding window size is 00:33:05.100 |
Four but because for example this token here is already the aggregation of the previous four tokens 00:33:10.860 |
it will effectively extend the visibility beyond its sliding window 00:33:14.620 |
So this token here is not related directly to the first one 00:33:19.380 |
we can see here, but indirectly, through this intermediate token we can see here. 00:33:27.140 |
I hope this concept is clear. If it's not clear, I 00:33:33.340 |
recommend using my notebook so that you can experiment by playing with multiple 00:33:38.220 |
Sequences and you can see how the information flow will go through all the layers 00:33:47.420 |
Now let's review the next topic, which is the KV cache. I want to introduce the KV cache, which I already explained in my previous video on Llama, 00:33:52.620 |
But I want to introduce it again and review it because I want to introduce later the rolling buffer cache 00:33:57.140 |
So let's start by talking about first of all how we train language models because this is needed to understand the KV cache 00:34:04.540 |
So the language models are trained using what is known as the next token prediction task 00:34:09.660 |
so given a prompt the goal of the language model is to predict what is the next token that makes sense with the prompt that 00:34:15.900 |
We have given and imagine we want to train a language model on Dante Alighieri's poem 00:34:21.820 |
Divine Comedy, and in particular we will be training it on a line that you can see here, 00:34:27.020 |
this one in English: "love that can quickly seize the gentle heart". How does it work? 00:34:36.580 |
We build the input, which is the line that we want to teach it, with a token 00:34:40.300 |
prepended called the start of sentence, and then we build the target, which is the same line 00:34:45.240 |
But with a token at the end called end of sentence. We run the input through this transformer model 00:34:53.260 |
so as we saw before, the output of the self attention is another sequence with the same length as the input sequence, but 00:35:01.460 |
each embedding is modified in such a way that each token captures information about other tokens. And this is what we do: 00:35:11.120 |
if we feed the model with the nine tokens, the model will produce nine tokens as output, and we prepare the 00:35:22.660 |
input and output such that if we give to the model as input the token start of sentence only, it will produce 00:35:30.660 |
The first token as output which is the word love if we give to the model as input the first two tokens 00:35:37.300 |
So start of sentence love the model will produce the two tokens as output 00:35:41.960 |
so "love that". If we feed the model as input three tokens, 00:35:46.020 |
so start of sentence love that, the model will produce "love that can". So when we train the model, we train it like this: 00:35:52.760 |
we prepare the input like this, the target like this, we calculate the output, 00:35:56.500 |
we calculate the loss using the cross entropy loss, and then we run back propagation, and this is all done in one step 00:36:01.880 |
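As a minimal sketch of this single training step (a toy setup I am assuming, not the actual training code), the input is the line with the start-of-sentence token prepended, the target is the same line with the end-of-sentence token appended, and the cross entropy is computed over all positions at once:

```python
import torch
import torch.nn.functional as F

# toy vocabulary ids for "love that can quickly seize the gentle heart"
sos, eos = 0, 1
line = [2, 3, 4, 5, 6, 7, 8, 9]

input_ids  = torch.tensor([[sos] + line])   # (1, 9): <SOS> + line
target_ids = torch.tensor([[*line, eos]])   # (1, 9): line + <EOS>

vocab_size = 10
# stand-in for model(input_ids): one vector of logits per position
logits = torch.randn(1, input_ids.shape[1], vocab_size, requires_grad=True)

loss = F.cross_entropy(logits.view(-1, vocab_size), target_ids.view(-1))
loss.backward()   # all nine positions are trained in this single step
print(loss.item())
```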
When we do the inference we do it in multiple step 00:36:04.540 |
So when we do the inference at time step one, we feed the model only the first token 00:36:09.060 |
So the start of sentence and the model will produce the output love 00:36:12.460 |
Then we take the last token of the output and we append it to the input, which becomes the input at time step 00:36:19.620 |
two, so it becomes start of sentence love. So the model will produce love that 00:36:24.300 |
We take the last token of the output and we append it to the input for the time step three 00:36:29.880 |
and this will become the new input, which will produce "love that can". Then we take the last 00:36:35.060 |
token of the output and we append it to the input for the time step four 00:36:39.820 |
So the new output will become "love that can quickly". Then we take this word quickly, 00:36:45.060 |
We append it to the input for the next time step and which will produce the next token as output, etc 00:36:49.980 |
etc., until we see the end of sentence token as output. Then we know that the model has 00:36:56.340 |
stopped producing new tokens and we can stop the inference 00:37:01.740 |
Now, at every step of the inference we are only interested in the last token output by the model, because we already have the 00:37:09.740 |
previous ones, but of course we need to feed all the previous tokens 00:37:17.260 |
belonging to the prompt, because the model needs to access the prompt to understand which token to produce next 00:37:23.500 |
So for example, we cannot produce the word gentle only by giving the word "the"; we need to give all this sentence to produce it 00:37:33.020 |
But at the same time we are only interested in the last word gentle 00:37:38.620 |
And this is the reason we introduce the KVCache because the KVCache allow us to reduce the computations that we are doing 00:37:45.700 |
by only producing one output at a time the one that we need but 00:37:51.020 |
Without doing all the intermediate computations for all the other tokens that we never use 00:37:58.540 |
We don't want to produce the output for the word love that can quickly seize the gentle because we already have them in the prompt 00:38:07.060 |
We just want to produce the output for the token heart. So we want to reduce the computation that we are doing 00:38:17.020 |
You know that we multiply the query which can be thought of as a list of tokens where each token is an embedding of size 00:38:22.820 |
4096, and the transpose of the keys 00:38:29.580 |
is multiplied by the queries to produce this matrix here 00:38:33.140 |
And then we multiply it by the V matrix to produce the output of the self-attention. You can see here 00:38:40.900 |
So when we inference a language model, we start with our first token, which is the start of sentence. This is one token 00:38:50.060 |
4096 we multiplied by the transposed of the keys which is again one token because it's a self-attention. So 00:38:57.000 |
The query the key and the value are the same matrix. So this is just the transposed of the query 00:39:03.220 |
Basically, and so it's a column vector and it will produce a one by one matrix 00:39:08.260 |
We multiply it by V and it will produce an output token. We take this output token 00:39:12.460 |
We send it to the linear layer and then to the softmax to understand which token this corresponds to in our vocabulary 00:39:19.900 |
We take this token from our vocabulary and we append it to the query for the next 00:39:25.660 |
Inference step to the keys and the values and then we compute again the product of the query multiplied by the keys 00:39:32.920 |
We multiply then the result by V and it will produce an output made up of two tokens because we have two tokens as input 00:39:39.020 |
It will produce two tokens as output, but we are all interested in the last token 00:39:43.100 |
So we take this output token too, we send it to the linear layer then to the softmax 00:39:47.220 |
This will result in what token is corresponding to in our vocabulary. We take this token from our vocabulary 00:39:53.340 |
We append it for the next step to the query key and values 00:39:57.060 |
We do again this process and then we take the last token as output 00:40:03.380 |
You can see here. We may send it to the linear layer then the softmax 00:40:06.600 |
We understand which token it corresponds to, we append it to our query key and values and then we compute again the self attention 00:40:14.140 |
but we already start to notice something because 00:40:17.340 |
First of all, in this matrix here, which is the result of the query multiplied by the transpose of the keys 00:40:24.540 |
We have a lot of dot products at each step that were already computed at the previous step 00:40:29.680 |
Let me show you. At the time step 4, we are computing all these dot products 00:40:34.340 |
As you can see at the time step 3, we already computed these dot products and at the time step 4 00:40:40.220 |
We are computing them again as you can see these dot products here 00:40:49.380 |
Moreover, since we deal with a language model, we have a causal mask, 00:40:52.860 |
so we do not even care about computing the dot products that we see here in the dark violet, 00:40:58.520 |
because they will be anyway masked out by the causal mask 00:41:03.940 |
Because we don't want the first token to watch the token number 2, the token number 3, the token number 4 00:41:09.460 |
We only want the token number 4 to watch the previous one 00:41:12.580 |
So the token number 4 should be related to itself, the previous one, the token number 2 and the token number 1 00:41:19.100 |
And also we don't want to produce all these output tokens because we are only interested in the last one 00:41:24.460 |
We are only interested in knowing what is the last 00:41:26.740 |
token produced by the attention so that we can send it to the linear layer and then to the softmax to understand what is the 00:41:33.860 |
Word corresponding in our vocabulary so that we can use it for the prompt to inference the next token again 00:41:39.900 |
So now let's introduce the KVCache and how the KVCache solve this problem 00:41:44.400 |
What we do with the KVCache, again, we start from our first step of the inferences 00:41:50.020 |
So we start from our start of sentence token, which is multiplied 00:41:53.380 |
So the query is only the start of sentence token. We multiply it by the transpose of the keys 00:41:59.940 |
Then we multiply it by the V and it will produce our first token as output 00:42:04.260 |
We send it to the linear layer then to the softmax then we know which token it corresponds to 00:42:09.040 |
Now in the KVCache instead of appending this new token that we have produced as output to the query key and value 00:42:16.100 |
we only append it to the key and the value and 00:42:19.140 |
Replace entirely the previous query with this new token. So 00:42:23.500 |
Before without the KVCache we were appending the every output token 00:42:28.900 |
So the last token of the output to the query key and values, but in with the KVCache 00:42:34.420 |
We don't append it to query key and value, but only to the key and values 00:42:38.960 |
And we only use the last output token as query for the next step 00:42:43.980 |
So if this is the output of the first step, so the output corresponding to the token start of sentence 00:42:51.060 |
We take it we use it as query for the next step, but we append it to the key and the values 00:42:57.800 |
So this is why it's called the KVCache because at each step 00:43:01.600 |
We are keeping a cache of the previous K and V 00:43:05.200 |
But not for the query because we are entirely replacing all the queries with the last token 00:43:13.640 |
So this matrix multiplied by this matrix will produce a matrix that is 1 by 2 00:43:17.560 |
We multiply it by V and we will see that this produces only one token as output 00:43:24.400 |
we send it to the linear layer to the softmax then we know which token it corresponds to then we use it as 00:43:29.500 |
query for the next iteration, but we append it only to the K and the V matrix 00:43:37.720 |
The product of the query with the keys is then multiplied by the V, which will produce this output token 00:43:44.140 |
This is the one we are interested in basically, then we use it as query for the next iteration 00:43:49.680 |
But we append it to the K and the V etc. So as you can see at the fourth step of the inference 00:43:56.160 |
We are producing only the last row that we were interested in 00:44:00.120 |
When we didn't have the KVCache. So let me show you this is the 00:44:04.200 |
Fourth time step with the KVCache. Let's look at the fourth time step without the KVCache 00:44:09.920 |
As you can see we are only producing this row here 00:44:15.000 |
This is the only one we are interested in to produce this last token 00:44:18.920 |
So with the KVCache basically we reduce the number of computations that we are doing at every step 00:44:23.640 |
because some of the dot products we have already done in the previous steps, and we only produce one token as output, 00:44:29.820 |
Which is exactly the one that we need for predicting the next token 00:44:33.920 |
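Here is a simplified sketch of this decoding loop with a KV cache; `model` is a hypothetical callable that takes the new tokens plus the cache and returns the logits and the updated cache (the real interface of the Mistral code is different):

```python
import torch

def generate(model, prompt_ids, max_new_tokens, eos_id):
    kv_cache = None                # will hold the K and V of all the past tokens
    tokens = prompt_ids            # the first forward pass processes the whole prompt
    generated = []
    for _ in range(max_new_tokens):
        # only `tokens` is used as the query; keys/values come from the cache plus `tokens`
        logits, kv_cache = model(tokens, kv_cache)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # last position only
        generated.append(next_token)
        if next_token.item() == eos_id:
            break
        tokens = next_token        # the query for the next step is just the new token
    return torch.cat(generated, dim=-1)
```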
Ok, now let's talk about the rolling buffer cache 00:44:38.200 |
So since we are using the sliding window attention with a size of W 00:44:42.520 |
And in the examples I showed you before I was using a sliding window size with the size of 3 00:44:48.040 |
We don't need to keep all the possible K and V in the cache 00:44:53.120 |
But we can limit the K and the V only to W tokens because anyway, we will not be computing 00:45:00.600 |
Attention outside of this W window. So we do not need imagine our window is 10 tokens 00:45:06.000 |
We do not keep the previous 1000 tokens because anyway, our attention will only be calculated on the previous 10 tokens 00:45:12.200 |
So this is the idea behind the rolling buffer cache. Let's see how it works 00:45:15.600 |
Imagine we arrive at the token 8 of inference using the KVCache 00:45:21.360 |
If we have a KVCache and we are using the sliding window size of 4 for example 00:45:27.760 |
We will see that as query we will use the output of the previous step and as key and values 00:45:34.760 |
We will use the entire cache which is made up of 8 tokens 00:45:37.920 |
But because of the mask that we are using with the sliding window attention 00:45:43.760 |
We are not interested in the computation of these dot products because anyway 00:45:48.240 |
They will be masked out because the distance between this token and this token is outside of the sliding window attention 00:45:54.320 |
so we are not interested in calculating these dot products, and 00:46:02.400 |
we are not interested in keeping these values, because anyway they will be masked 00:46:08.000 |
By our mask for the sliding window attention, which basically will result in zeros here 00:46:13.920 |
We do not care about producing these first four rows in the value matrix because anyway 00:46:20.960 |
they will be multiplied by zeros, so they will not contribute to the output token. 00:46:29.360 |
Here you have to imagine that the mask will take care of making this one zero, this one zero, this one zero, this one zero 00:46:37.000 |
And this one will be a dot product. This one will be a dot product. This one will be a dot product 00:46:41.560 |
This one will be a dot product. So whatever value there is here 00:46:44.560 |
whatever value there is here, here or here will not contribute to the output of this token, because anyway it will be multiplied by zero 00:46:52.320 |
So we do not need to keep these values in the V matrix or in the K matrix, because anyway 00:46:58.720 |
They will not be used by the sliding window attention. So that's why we can limit the size of our K and V 00:47:05.380 |
Cache only to W tokens where W is the size of the sliding window attention that we are using 00:47:12.520 |
Now let's see how this rolling buffer cache is implemented. So basically the rolling buffer cache is a way of 00:47:17.880 |
keeping the size of the cache limited to a fixed size, in this case W 00:47:26.040 |
Imagine we have a sentence "the cat is on a chair" and we want to use it for our KV cache 00:47:32.720 |
At the first inference using the KV cache, we will add 00:47:37.840 |
the first token to the KV cache, then we will add the second one, the third one and the fourth one 00:47:43.600 |
But now the KV cache is full. How do we proceed further? 00:47:46.720 |
Basically, we keep track of where we added the last item 00:47:50.480 |
Using a pointer that we keep track of and when we will arrive at the next token, which is the token A 00:47:56.800 |
We basically replace the oldest value here starting from the beginning and we update the value of the right pointer 00:48:05.740 |
Now let's go back, because now the order of the tokens is not matching the sentence: as you can see, now the 00:48:11.200 |
cache contains "a cat is on", but this is not the order in the original sentence. In the original sentence 00:48:17.320 |
the order should be "cat is on a" so what we do is we do the unrolling or unrotation and 00:48:23.960 |
How do we do it? Basically because we kept track of this right pointer 00:48:28.960 |
We just need to take all the values after the right pointer and then we put the values from 0 to the right pointer itself 00:48:36.120 |
So all the values after the right pointer and then all the values before the right pointer and this is how we unrotate 00:48:42.180 |
And this operation is done in the code in the function called unrotate. You can see here 00:48:47.320 |
Which basically will have this condition. So if the cache is not full we can just ignore the unfilled item 00:48:54.480 |
So if the cache is in this situation, then we take all the values from 0 up to the right pointer; 00:49:00.340 |
if the cache is full, then we take the values from 0 up to the end; but 00:49:08.560 |
if the value of the right pointer is already overwriting some values, then we need to unrotate, and this is done in the last case: 00:49:18.160 |
we take all the values after the pointer and then the values up to the pointer, and this is how we unrotate the cache 00:49:26.000 |
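Here is a small reconstruction of this rolling buffer idea (a sketch of mine, not the original implementation): a fixed-size buffer of W slots, a write position that wraps around, and an unrotate step that restores the chronological order:

```python
class RollingBufferCache:
    def __init__(self, window: int):
        self.window = window
        self.buffer = [None] * window   # fixed-size storage for the last W keys/values
        self.seen = 0                   # how many tokens have been written in total

    def add(self, item):
        self.buffer[self.seen % self.window] = item   # overwrite the oldest slot
        self.seen += 1

    def unrotate(self):
        if self.seen < self.window:                   # cache not full: ignore the unfilled slots
            return self.buffer[: self.seen]
        pivot = self.seen % self.window               # the "right pointer"
        return self.buffer[pivot:] + self.buffer[:pivot]   # oldest first, newest last

cache = RollingBufferCache(window=4)
for tok in ["the", "cat", "is", "on", "a"]:
    cache.add(tok)
print(cache.unrotate())   # ['cat', 'is', 'on', 'a']
```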
Okay, let's talk about another concept that is very important, which is chunking and pre-filling 00:49:31.860 |
Basically when we generate a text using a language model 00:49:35.120 |
we use a prompt, and then we use this prompt to generate future tokens. When dealing with a KV cache, 00:49:42.640 |
we need to add the tokens of our prompt to the KV cache, so that we can then exploit this KV cache 00:49:51.280 |
Now the prompt is known in advance, right? Because it's the input of our user 00:49:57.040 |
It's what you ask to ChatGPT, for example, right? "Write me a poem" or "tell me a joke" 00:50:02.320 |
This is our prompt. So it's known in advance, so we don't need to generate it 00:50:06.960 |
Okay, so what we can do is we can pre-fill the KV cache using the tokens of the prompt 00:50:12.320 |
But there are many ways to do it like we were doing before when I was teaching you about the KV cache 00:50:17.600 |
we work with one token at a time. So one way 00:50:20.160 |
to add the tokens to the KV cache is to add one token at a time, 00:50:25.200 |
but this can be very time consuming, because imagine you have a very large prompt, which happens 00:50:30.080 |
with retrieval augmented generation, where we have very big prompts like 5,000-6,000 tokens or even bigger 00:50:38.400 |
It will mean that we have to take 5,000 or 6,000 forward steps in our network, 00:50:43.600 |
which can be very time consuming and also doesn't exploit our GPU very much 00:50:48.400 |
The other way is to take all these tokens and feed them all at once to the model 00:50:53.760 |
But that may be limited by the size of our gpu because imagine we have 10,000 tokens as our prompt 00:51:00.080 |
Then maybe our gpu cannot even hold 10,000 tokens. Maybe it can only hold 4,000 tokens or 2,000 tokens 00:51:06.240 |
depending also on the size W of the sliding window attention that we have chosen 00:51:11.120 |
The solution in this case is to use chunking. Basically, we divide our prompt into chunks of a fixed size 00:51:18.160 |
And this size is equal to w which is the sliding window attention size 00:51:24.240 |
So imagine we have a very big prompt and we choose a sliding window size of 4 for the calculation of the attention 00:51:29.920 |
And imagine that the prompt is this one. So can you tell me? 00:51:33.280 |
Can you tell me who is the richest man in history? The way we work is this 00:51:38.880 |
Basically, we take our first chunk of the prompt 00:51:43.200 |
So because we chose a sliding window size of 4, we also will choose the chunk size to be 4 00:51:51.280 |
So "can you tell me", and we compute the 00:51:56.720 |
self attention in the first layer of the model. How do we build the attention mask? 00:52:02.320 |
Basically as queries we take all the incoming tokens in this chunk 00:52:08.400 |
So you can think of this column as the queries and this column as the keys 00:52:14.240 |
And this is the result of the query multiplied by the transposed of the keys plus the mask 00:52:20.320 |
So our query we take the first incoming chunk and as keys I will show you later 00:52:25.200 |
We take the current content of the kvcache, but initially it is empty plus the incoming tokens of the current chunk 00:52:33.200 |
And this is made for a very specific reason that I will show you in the next step 00:52:40.080 |
So in the next step, basically we take the current chunk, which is the tokens who is the richest 00:52:48.800 |
And we aggregate it with the content of the kvcache using the tokens of the previous chunk. So let me go back 00:52:55.520 |
At the first step of this prefilling we take the first chunk of the prompt 00:53:00.880 |
So can you tell me we calculated the attention mask using as query 00:53:04.480 |
The first four tokens and as keys and values the content of the kvcache 00:53:11.120 |
Which is empty plus the tokens of the first chunk and then we update the content of the kvcache 00:53:17.680 |
Using this the tokens of this chunk after we have computed the attention 00:53:21.840 |
So at the next step the kvcache now contains the previous the tokens of the previous chunk 00:53:27.760 |
So can you tell me but now the current chunk has become who is the richest? 00:53:32.320 |
so as query again, we take the tokens of the current chunk, but as keys and values we take the 00:53:40.000 |
The content of the kvcache plus the tokens of the current chunk. Why? 00:53:48.080 |
As you can see, when we were doing token generation, when I was teaching you the KV cache, 00:53:54.960 |
we appended the output token to the K and the V and we used it as the query for the next iteration; here instead, for each chunk, we 00:54:03.120 |
calculate the attention and then we update the KV cache 00:54:06.660 |
And when we build the query, we use only the tokens of the current chunk, 00:54:13.280 |
And as key and values we take the content of the kvcache 00:54:16.580 |
So the content of the previous chunk plus the tokens of the current chunk. Why? 00:54:26.240 |
If we didn't use the content of the previous chunk, what would happen is this: we would have an 00:54:32.160 |
attention mask that only comprises the tokens of the current chunk. So it would be only limited to this matrix here 00:54:44.800 |
But if we only used this matrix here, the word "who" 00:54:50.240 |
would not be related to the words "me", 00:54:54.480 |
"tell" and "you", even if, with this sliding window size, they should be able to watch each other 00:55:00.640 |
So because we want to relate the current chunk to the previous chunk 00:55:05.680 |
We basically take as a key and value the content of the kvcache plus the tokens of the current chunk 00:55:12.960 |
So that we can build this attention between chunks 00:55:16.320 |
Otherwise this attention would not be built and as query we always use the tokens of the current chunk 00:55:23.120 |
Let's review how this mechanism is built in the code 00:55:28.800 |
So basically the pre-filling is done by chunks 00:55:32.560 |
There is the first chunk and then there are subsequent chunks 00:55:35.600 |
And finally there is token generation after we have a pre-filled our kvcache with the prompt 00:55:40.320 |
During the first pre-fill, which means that we are doing it for the first chunk of our prompt, 00:55:48.080 |
As attention mask, we only consider the size of the incoming tokens in the current chunk 00:55:54.960 |
But for any subsequent chunk, so after the first chunk, to build the attention mask 00:56:01.440 |
For the query we just use the size of the incoming chunk 00:56:05.920 |
But for the k and v we use the size of the kvcache, which is this one 00:56:10.560 |
So cached as you can see here plus the size of the current chunk, which is this s variable you can see here 00:56:17.120 |
And for token generation, we do the same system that we did before when I was teaching with the kvcache 00:56:23.360 |
So one token at a time, we take it, we append it to the key 00:56:27.600 |
We append it to the value and we replace the query with the output token from the previous step 00:56:33.760 |
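As a rough illustration of these three cases, here is a minimal sketch of how queries, keys and values could be assembled for a pre-fill chunk. This is a simplified toy example, not the actual Mistral code: single head, no rotary embeddings, no rolling buffer, and the mask itself is only described in a comment; all names and sizes are made up.

```python
import torch

d = 16                                    # toy embedding size
wq, wk, wv = (torch.randn(d, d) for _ in range(3))
k_cache = torch.empty(0, d)               # the cache is empty before the first chunk
v_cache = torch.empty(0, d)

def prefill_chunk(chunk_emb, k_cache, v_cache):
    # Queries come only from the tokens of the current chunk.
    q = chunk_emb @ wq                                  # (chunk, d)
    # Keys/values: the current KV-cache content (empty for the first chunk)
    # concatenated with the keys/values of the current chunk.
    k = torch.cat([k_cache, chunk_emb @ wk], dim=0)     # (cached + chunk, d)
    v = torch.cat([v_cache, chunk_emb @ wv], dim=0)
    # Scores have shape (chunk, cached + chunk); a causal / sliding-window
    # mask of that same shape would be added here before the softmax.
    out = torch.softmax(q @ k.t() / d**0.5, dim=-1) @ v
    # After computing attention, the cache now holds this chunk as well.
    return out, k, v

# First chunk ("can you tell me" -> 4 tokens), then the second chunk (4 tokens):
out, k_cache, v_cache = prefill_chunk(torch.randn(4, d), k_cache, v_cache)
out, k_cache, v_cache = prefill_chunk(torch.randn(4, d), k_cache, v_cache)
```

The point is simply that the query side always has the size of the current chunk, while the key/value side has the size of the cache plus the current chunk, which matches the mask shapes described above.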
So the last chunk in our case will be the tokens "man in history" 00:56:38.900 |
And what we do is basically we take the current chunk, so "man in history", which becomes the query 00:56:45.520 |
While the key becomes basically the previous chunk plus the tokens of the current chunk 00:56:51.280 |
So "who is the richest" plus the tokens of the current chunk, so "man in history" 00:56:56.000 |
And the reason we do it is that otherwise the words in the current chunk would not be able to 00:57:00.320 |
Relate to the words of the previous chunk, which is necessary. Okay guys, let's talk about sparse 00:57:08.960 |
So mixture of experts is an ensemble technique in which we have multiple expert models 00:57:14.400 |
Where each of these models is trained on a subset of the data, such that each model will specialize on that subset 00:57:21.760 |
And then when we produce the output of this mixture of experts, we take the output for each of these experts 00:57:28.880 |
We combine it usually by using a weighted sum or by averaging to produce one single output 00:57:34.960 |
In the case of Mistral, we do not talk about only mixture of experts 00:57:38.960 |
But we talk about a sparse mixture of experts because we have many expert models, but we only use some of them 00:57:45.280 |
Let me show you. In the case of Mistral, we have eight experts, which take the place of the feed-forward layer 00:57:52.080 |
So after we calculate the self-attention, as you remember, we have this feed-forward network. In the case of Mistral 00:57:58.000 |
8x7b, we have eight feed-forward layers. We can think of them as being in parallel 00:58:05.040 |
And the gate is a function that basically will decide for each token which expert, so which feed-forward 00:58:11.920 |
network, should be working with that token and it will choose two 00:58:16.400 |
feed-forward networks for each token. It will run the token through these feed-forward networks, 00:58:22.100 |
will take their output and will weight it according to the logits this gate produces 00:58:28.020 |
to produce a weighted sum, which will become the output of this feed-forward (mixture of experts) stage for that particular token 00:58:37.280 |
So this is the architecture of Mistral. As you can see, we have the input of this 00:58:42.160 |
encoder layer. We first run the self-attention using the sliding window attention and the KV cache, etc, etc 00:58:48.800 |
Then we run the normalization and finally we have this gate function here, which is basically just a linear layer that will produce 00:58:48.800 |
Eight values, let's call them score values, one for each expert. The two best performing 00:59:07.220 |
Experts, so the two with the highest scores, will indicate which experts that token should work with 00:59:13.040 |
Then we run each token through its own two best-performing experts 00:59:18.560 |
Then we take the output of these two experts and we combine it with a weighted sum. Suppose the 00:59:26.400 |
Logits produced by the gate are these eight values here 00:59:30.080 |
Yeah, I draw only four because I don't have space, but you imagine you have eight values 00:59:34.080 |
Then we take the top two, so 1.5 and 3.4 in this case 00:59:39.600 |
These are the two experts through which we will run the token 00:59:42.240 |
We take the softmax of the two best performing values; these will be the weights that we'll be using for the weighted sum 00:59:59.360 |
Why do we do all of this? Because by using a sparse mixture of experts, we can have many expert models 01:00:04.640 |
But during inference only two out of eight will be activated for each token 01:00:08.900 |
So as you remember, the feed-forward network is basically two 01:00:13.040 |
linear layers. So the linear layer can be thought of as a matrix multiplication of a weight matrix with the input 01:00:20.560 |
So if we didn't use a sparse mixture of experts, we would run the token through all the eight experts 01:00:25.840 |
Which means that we need to compute eight matrix multiplications 01:00:29.060 |
But by using sparse mixture of experts for each token, we are only doing two matrix multiplications 01:00:37.920 |
But at the same time it allows us to increase the power of the model and the number of parameters of the model 01:00:42.240 |
Because we are only using some parameters for 01:00:46.960 |
A subset of the tokens. So some tokens will use expert number one, some tokens will use 01:00:51.520 |
Experts number two and three, some tokens will use experts number eight and three 01:00:56.880 |
Or some others, for example number six and four, etc 01:01:00.560 |
So we are not using all the experts for each token, but only two of them 01:01:08.720 |
Each expert may also become specialized on a subset of tokens. For example, imagine the model has been trained on multiple languages 01:01:15.280 |
What could happen is that basically some experts, so some feed-forward networks are specialized on Japanese tokens 01:01:22.320 |
Some feed-forward networks are specialized on English tokens, for example 01:01:26.320 |
It could also happen that some are specialized in verbs, some are specialized in 01:01:30.720 |
Nouns, some are specialized in adjectives, etc, etc 01:01:34.960 |
So this is why we use a mixture of experts because we want to increase the size of the parameters of our model 01:01:42.000 |
So the model becomes more powerful at capturing information 01:01:45.600 |
But at the same time we don't sacrifice on performance because we only use a subset of the experts for each token 01:01:53.360 |
And this is the implementation as done in the code 01:01:57.120 |
So as you can see in the case of Mistral 7b, we have as feed-forward just a feed-forward neural network 01:02:05.520 |
In the case of Mistral 8x7b, it's not only one feed-forward network 01:02:10.160 |
But eight feed-forward networks. So this, as you can see, is 01:02:13.520 |
An array of eight feed-forward networks with a gating function, which is just a linear layer 01:02:19.840 |
Which converts from the embedding size to eight, which is the number of experts 01:02:24.820 |
So for each embedding, so for each token, it produces logits which indicate which 01:02:33.520 |
Experts this token should run through, and the token will be run through the top two experts 01:02:43.520 |
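As a rough sketch of this routing logic (a simplified illustration, not the actual Mistral 8x7B implementation: the real experts are gated SwiGLU feed-forward blocks and the routing is vectorized; all class and variable names here are made up):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoE(nn.Module):
    def __init__(self, dim: int, hidden: int, n_experts: int = 8, top_k: int = 2):
        super().__init__()
        # Eight independent feed-forward networks: the "experts".
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, dim))
            for _ in range(n_experts)
        ])
        # The gate: a linear layer from the embedding size to the number of experts.
        self.gate = nn.Linear(dim, n_experts, bias=False)
        self.top_k = top_k

    def forward(self, x):                       # x: (tokens, dim)
        logits = self.gate(x)                   # (tokens, n_experts)
        top_vals, top_idx = logits.topk(self.top_k, dim=-1)
        weights = F.softmax(top_vals, dim=-1)   # softmax over the top-2 scores only
        out = torch.zeros_like(x)
        for t in range(x.shape[0]):             # naive per-token loop, for clarity
            for w, e in zip(weights[t], top_idx[t]):
                out[t] += w * self.experts[int(e)](x[t])   # weighted sum of 2 experts
        return out

moe = SparseMoE(dim=32, hidden=128)
y = moe(torch.randn(5, 32))                     # 5 tokens, each routed to 2 of 8 experts
```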
Okay, why do we apply the softmax after selecting the top-k experts? As I showed you 01:02:51.600 |
Here we have the gating function that produces some logits. We select the top two 01:02:57.920 |
logits to understand which experts we should run our token through 01:03:02.800 |
And then we take the scores of the two best performing experts and 01:03:09.820 |
Take the softmax of them to create the weights 01:03:15.600 |
But why do we take the softmax of only the two best performing instead of taking the softmax of all of them? Well, the first problem is that 01:03:23.280 |
If we take the softmax of all of the logits, then the two best performing may not sum up to one 01:03:32.640 |
Which is a condition that we need in case we want to train multiple models and compare them, because I'm pretty sure that the guys 01:03:38.640 |
At Mistral did not only train one model. Maybe they trained multiple models with multiple hyper parameters 01:03:43.760 |
Maybe they tried with a mixture of four experts, but also with three experts or two experts 01:03:50.640 |
So if you want to compare models, you want 01:03:56.560 |
The sum of the weights to always be one. Otherwise the output range may change from model to model, and usually it's not a good idea 01:04:03.200 |
To have the range of the output to change from one model to the next 01:04:07.520 |
So to keep the range of the output stable, they apply the softmax after they have selected the 01:04:14.000 |
Experts they want to work with, choosing the logits of the two best performing 01:04:23.680 |
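A quick numeric check of this argument, using made-up gate logits in which 1.5 and 3.4 are the top two scores out of eight:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.2, 1.5, -0.7, 3.4, 0.1, -1.2, 0.9, 0.4])  # made-up gate logits
top_vals, top_idx = logits.topk(2)              # values [3.4, 1.5], indices [3, 1]

w_after = F.softmax(top_vals, dim=0)            # ~[0.87, 0.13] -> sums to 1.0
w_before = F.softmax(logits, dim=0)[top_idx]    # ~[0.72, 0.11] -> sums to ~0.83
print(w_after.sum(), w_before.sum())
```

Taking the softmax after the top-2 selection guarantees the two weights always sum to one, regardless of how many experts the model was configured with.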
The next thing we are talking about is model sharding which is also implemented in the code of the Mistral model 01:04:32.720 |
When we have a model that is too big to fit in a single gpu 01:04:36.560 |
We can divide the model into groups of layers and place each group of layers in a single gpu 01:04:42.800 |
For example in the case of Mistral, we have 32 layers of encoders 01:04:47.280 |
You can see here one after another I didn't do all 32 of them 01:04:51.520 |
Just think of this as layers 1 to 8, this as 9 to 16, this as 17 to 24 01:04:56.800 |
From 25 to 32 and we put each group of layers in a different gpu. So we have four gpus 01:05:05.280 |
The way we inference a model like this is as follows 01:05:08.320 |
So we have our input we convert it into embeddings and we run it through the first eight layers in the first gpu 01:05:14.320 |
The first gpu will produce an output which will be the output of the eighth layer 01:05:19.600 |
We transfer this output to the second gpu and we use it as input for the ninth layer 01:05:24.880 |
Then we run this input through all the layers one after another until it 01:05:29.760 |
Arrives at layer number 16, which will produce an output. We take this output, we move it to the next GPU 01:05:36.000 |
So it will become the input of the layer number 17 01:05:39.540 |
And then we run iteratively to all the layers until the layer number 24, which will produce an output 01:05:45.280 |
We move it to the next gpu. We run it through iteratively until the layer number 32 01:05:49.840 |
Then we take the last linear layer and then the softmax to produce the output of the model 01:05:57.040 |
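A minimal sketch of this naive sharded inference, assuming four GPUs are available and the 32 layers have already been split into four groups of eight (the layer stand-ins and names here are hypothetical, not Mistral's sharding code):

```python
import torch
import torch.nn as nn

# Hypothetical setup: 32 decoder layers split into 4 groups of 8, one group per GPU.
layers = [nn.Linear(16, 16) for _ in range(32)]     # stand-ins for the real decoder layers
n_gpus = 4
per_gpu = len(layers) // n_gpus                     # 8 layers per GPU
for i, layer in enumerate(layers):
    layer.to(f"cuda:{i // per_gpu}")                # place each group on its own GPU

def naive_sharded_forward(x):
    for g in range(n_gpus):
        x = x.to(f"cuda:{g}")                       # move the activations to the next GPU
        for layer in layers[g * per_gpu:(g + 1) * per_gpu]:
            x = layer(x)                            # only this one GPU is busy right now
    return x                                        # final norm, linear and softmax follow
```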
You can notice that this method is not very efficient because at any time only one gpu is working 01:06:02.880 |
A better approach which is not implemented in the code of Mistral, but they reference it in the paper 01:06:08.720 |
So I will talk about it, is pipeline parallelism. Let's see how it works 01:06:15.040 |
For this pipeline parallelism, I will talk about the algorithm that was introduced in this paper, GPipe 01:06:21.220 |
Basically, it works as follows first. Let me introduce you the problem 01:06:25.600 |
This is actually usually used when we are training a model, not when we are doing inference 01:06:32.340 |
Imagine we want to train a sharded model, so a model that is split into multiple 01:06:39.040 |
Groups of layers, where each group of layers is present on a different GPU 01:06:44.160 |
Imagine we have four gpus each one with its own group of layers 01:06:48.000 |
Imagine we want to train this model. So we run our input to the first gpu 01:06:52.880 |
So we run the forward step to the first gpu. We take this output and we feed it to the next gpu 01:07:00.080 |
We take the output and we run it through the next gpu the gpu number three 01:07:03.600 |
We take the output we run it to the next gpu the gpu number four 01:07:06.800 |
Now we have the output of the model. We compute the loss and then we can run backpropagation. The backpropagation 01:07:13.200 |
Is basically just the opposite: we go from the last GPU to the first GPU. So we run backpropagation on the fourth GPU 01:07:19.760 |
Then we have calculated the gradients at the fourth gpu and we use them to calculate the previous gradients at the third gpu 01:07:27.200 |
And then we take these gradients and we use them to calculate the previous gradients and then we 01:07:31.760 |
Take these gradients and use them to compute the previous gradients 01:07:35.520 |
So the forward step goes from the input to the loss the backward step goes from the loss to the input 01:07:43.360 |
And to all the parameters, which are also known as the leaf nodes in the computational graph 01:07:47.200 |
However, also in this case, you can see that at each step 01:07:52.080 |
We are only utilizing one single GPU and all the other GPUs are idle 01:08:04.660 |
So imagine that the previous step of training was done using a very big batch 01:08:12.640 |
What we do with pipeline parallelism is we take this batch and we split it into micro batch 01:08:18.480 |
So instead of one batch of eight items, we create micro-batches: four micro-batches of two items each 01:08:27.120 |
We run the first micro batch in the first gpu 01:08:30.880 |
This will produce the output for the first micro batch and we can feed it to the next gpu 01:08:36.160 |
But now at time step one we realize that GPU one is free 01:08:39.840 |
So it can already start working on the second micro-batch 01:08:43.120 |
Meanwhile, the second GPU is working on the first micro-batch, and when it finishes it can send it to the next GPU 01:08:50.800 |
Meanwhile, we realize that now the second GPU is free 01:08:54.240 |
So if GPU one has finished, we can take the output of GPU one and transfer it to GPU two 01:09:00.880 |
And GPU one will be free, so it can work on the third micro-batch you can see here 01:09:05.920 |
Then after the third GPU has finished, we take the output of the third GPU 01:09:11.600 |
We send it to the fourth gpu, but we realize that the third gpu is now free 01:09:15.520 |
So if the previous gpu have finished we can transfer the second micro batch to the third gpu 01:09:20.160 |
The third micro batch to the second gpu and the first gpu which will be free can start working on a new micro batch 01:09:26.720 |
Which is the fourth micro-batch, and basically we keep going like this 01:09:32.900 |
And this will result in a better utilization of the GPUs, because now 01:09:38.800 |
At this time step, for example, all four GPUs are working, and also at this time step here in the backward pass 01:09:44.720 |
And for each micro-batch we calculate the gradient, but we do not update the parameters immediately 01:09:53.360 |
Which basically means that we calculate the gradient for each micro-batch and we keep summing it to the existing gradients 01:09:59.280 |
And we only update the parameters of the model after all the micro-batches have finished processing the forward and the backward pass 01:10:07.520 |
The gradient accumulation is a technique that I have introduced in my previous video on 01:10:12.000 |
Distributed training so if you want to understand how it works 01:10:15.440 |
I refer you to my previous video on distributed training in which I explain also the math behind gradient accumulation and how it works 01:10:22.080 |
But basically this is the solution with pipeline parallelism 01:10:25.760 |
So we can actually divide our batch into micro-batches, and this can also work with inference, because when we inference we only do the forward 01:10:35.120 |
Step here, right? So we just delete this second half of the table 01:10:38.880 |
But we can still take our big batch at the beginning. We split it into micro batches and we time shift them 01:10:48.880 |
And this pipeline parallelism basically still introduces some 01:10:55.920 |
Time steps in which not all gpus are working and these are called bubbles 01:10:59.920 |
To avoid these big bubbles here, what we can do is 01:11:04.400 |
Use a bigger initial batch size, so we have more micro-batches 01:11:09.940 |
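To make the micro-batch bookkeeping concrete, here is a simplified sketch of gradient accumulation over micro-batches. It is a single-device illustration of the accumulation logic only, not an actual GPipe/pipeline implementation, and all names are made up:

```python
import torch
import torch.nn as nn

def train_step(model, optimizer, loss_fn, batch_x, batch_y, n_micro=4):
    """One optimizer step over a big batch, split into micro-batches."""
    optimizer.zero_grad()
    for mx, my in zip(batch_x.chunk(n_micro), batch_y.chunk(n_micro)):
        loss = loss_fn(model(mx), my) / n_micro   # scale so the sum matches the full-batch loss
        loss.backward()                           # .grad buffers accumulate across micro-batches
    optimizer.step()                              # parameters are updated once, at the very end

# Toy usage: a batch of 8 items split into 4 micro-batches of 2 items each.
model = nn.Linear(16, 4)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
train_step(model, opt, nn.MSELoss(), torch.randn(8, 16), torch.randn(8, 4))
```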
Okay guys now let's go to the last part of this video 01:11:16.800 |
The mistral code is much more complicated to understand compared to the lama code and I will show you why but I will also help 01:11:23.280 |
You understand the most complex topic in the code, which is the xformers library 01:11:27.760 |
Which is a trick they use to improve the inference performance and it's actually a very advanced technique 01:11:34.000 |
And I want to give you a glimpse into how it works 01:11:37.360 |
So basically imagine you are running an ai company and you are providing llm inference service 01:11:45.440 |
So for example you provide an API, and you have customers that send their prompts to your API 01:11:52.960 |
And then want to run inference through your large language models 01:11:56.080 |
Each prompt of course may have a different length, because each customer may be using the model for a different task 01:12:08.000 |
So suppose you have three customer the first customer says write a poem 01:12:11.920 |
The second customer says write a historical novel and the third customer says tell me a funny joke 01:12:18.880 |
of course you could process all these prompts one by one, but that would not be very efficient because 01:12:24.080 |
The two other two customer would be waiting for the first customer to finish and when you have a lot of customers 01:12:29.700 |
That's not good. And secondly, you may not be fully utilizing the memory of your gpu 01:12:34.720 |
So the best thing that you can do is batching: you take all these prompts and you create one big batch 01:12:40.800 |
But the problem is that the prompt have different lengths 01:12:44.880 |
So the first prompt is made up of three tokens the second prompt of four tokens and the third prompt of five tokens 01:12:51.600 |
One solution is to add padding to these tokens. So basically we create a batch in which we append 01:12:59.040 |
Padding tokens to the input sequence until they all reach the same size 01:13:04.880 |
Then we can run these sequences this batch through our large language model, which could be for example Lama or Mistral 01:13:14.800 |
As we saw before when we have a input sequence of n tokens 01:13:19.120 |
The attention mechanism produces an output sequence of n tokens 01:13:22.800 |
And we usually take the embedding of the last token 01:13:26.640 |
Send it to the linear layer then the softmax to understand what is the next token from our vocabulary 01:13:34.400 |
In the first prompt we see that we have added two padding tokens 01:13:38.000 |
So we cannot use the embedding corresponding to the last tokens because they correspond to the padding token 01:13:44.160 |
What we should do is we should take the embedding corresponding to the last non-padding token 01:13:49.200 |
And then send it to the linear layer and then to the softmax to understand what is the next token 01:13:54.960 |
And in the case of the second prompt we should be using the fourth token not the last one 01:13:59.760 |
Only in the last prompt can we use the last token, because it is a non-padding token 01:14:09.600 |
And how do we actually create an attention mask for this batch? 01:14:13.200 |
We basically just create an attention mask that is causal that will make each token only 01:14:19.440 |
Visualize the previous tokens. So each token will be able to relate to previous tokens, but not to future tokens 01:14:26.160 |
And this mask here will work fine for all the three scenarios. You can see here and I will show you later how 01:14:34.880 |
We cannot use a different mask for each prompt because, after padding, all the sequences have the same length 01:14:40.160 |
So all the masks must be 5x5 because we cannot use a 3x3 mask for this prompt a 4x4 01:14:46.960 |
Mask for this prompt and the 5x5 for this prompt because the input sequence is 5 01:14:53.200 |
So we must use a 5x5 mask and we have to use a 5x5 mask that is causal 01:15:00.400 |
And we can also mask out some additional values: for example 01:15:04.080 |
Imagine the sliding window size is 4, then we can also mask out this value here, because we don't want 01:15:10.000 |
This token here to watch tokens that are at a distance of more than 4 01:15:15.760 |
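Putting the padding approach together, here is a small sketch of batching three prompts of lengths 3, 4 and 5, building the shared 5x5 causal mask, and picking the last non-padding embedding per prompt. This is a toy illustration with made-up token ids, not Mistral code:

```python
import torch

PAD = 0
prompts = [[11, 12, 13], [21, 22, 23, 24], [31, 32, 33, 34, 35]]   # lengths 3, 4, 5
lengths = torch.tensor([len(p) for p in prompts])
max_len = int(lengths.max())                                       # 5

# Right-pad every prompt to the same length so they fit in one batch.
batch = torch.full((len(prompts), max_len), PAD)
for i, p in enumerate(prompts):
    batch[i, :len(p)] = torch.tensor(p)

# One shared 5x5 additive causal mask (a sliding-window restriction could also
# be baked in by masking entries further than the window below the diagonal).
mask = torch.full((max_len, max_len), float("-inf")).triu(1)

# After the forward pass, pick the embedding of the LAST NON-PADDING token of each prompt.
hidden = torch.randn(len(prompts), max_len, 16)          # stand-in for the model output
last = hidden[torch.arange(len(prompts)), lengths - 1]   # positions 2, 3 and 4
```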
So the problem here is that we are calculating a lot of dot products, especially for the first and the second prompt 01:15:27.760 |
When we apply this mask, so the 5x5 mask you can see here, to this input sequence here 01:15:27.760 |
It will produce the following attention mask 01:15:36.640 |
In which all these values will be masked out, so set to minus infinity, because of the causality of the mask 01:15:40.720 |
But we cannot mask these values here, because they are needed for the last prompt, for example 01:15:52.320 |
And we also cannot mask this value here, which is needed for the last prompt 01:16:00.560 |
But for the first and the second prompt we are doing a lot of dot products 01:16:04.800 |
These ones between padding tokens and other tokens that we will not be using 01:16:08.400 |
Because I want to remind you that for the first prompt, as the output of the model 01:16:13.200 |
We will be using the output at the third token 01:16:16.880 |
For the second prompt the output at the fourth token, and only for the last prompt will we be checking the last output 01:16:23.760 |
Of the self-attention, but for the first two prompts 01:16:27.520 |
We will not even be checking the last token output from the self-attention, because it corresponds to the padding token 01:16:33.360 |
So is there a way to avoid these padding tokens? 01:16:37.680 |
Being introduced in our calculation and calculating all these dot products which will result in output tokens that we will not even use 01:16:46.320 |
Well, there is a better solution and the solution is this 01:16:49.280 |
The solution is to combine all the tokens of all the prompts into one big sequence 01:16:55.920 |
Consecutively and we also keep track of what is the actual size of each prompt 01:17:02.880 |
So the prompts are coming from our API, because we are running an AI company and we have this API 01:17:09.200 |
So we know that the first customer has a prompt of three tokens, the second one has 01:17:15.680 |
Four tokens and the third one has five tokens 01:17:18.240 |
So we can keep track of these sizes in an array, for example 01:17:22.240 |
And then we build this sequence which is a concatenation of all the prompts that we receive 01:17:27.280 |
We take this mega sequence. We run it through our LLM model. So it could be Mistral or it could be Lama 01:17:38.000 |
As we saw before, an input sequence of n tokens in a transformer will result in n output tokens 01:17:45.200 |
So we have here we have 3 plus 4 so 7, 7 plus 5, 12 01:17:50.800 |
Tokens as input it will produce 12 tokens as output 01:17:54.320 |
To understand what is the next token for each prompt, we need to check the embedding 01:18:03.200 |
Corresponding to token number 3 for the first 01:18:05.840 |
Prompt, to token number 7 for the second prompt, and the last token for the third prompt 01:18:12.560 |
So we take all these embeddings we run them through the linear layer 01:18:16.480 |
Then we apply the softmax and then we understand what is the next 01:18:20.480 |
Token from our vocabulary, but you may be wondering 01:18:24.480 |
How do we even produce an attention mask that can work with multiple prompts that are combined into one sequence? 01:18:31.920 |
Such that the tokens of one prompt should not attend 01:18:37.840 |
To the tokens of another prompt, but only to the tokens of the same prompt, right? 01:18:42.640 |
Well, the xformers library allows us to do that using a method called block diagonal causal mask 01:18:51.120 |
Which is also used in the source code of Mistral. So I want to show you how it works 01:18:58.400 |
This method called the block diagonal causal mask will produce a mask like this. It will be 01:19:08.880 |
Such that each token can only attend to the tokens in the same group. Here we have three prompts 01:19:15.920 |
So the token "poem", for example, can only attend to the tokens of the same 01:19:20.880 |
Prompt. The token "novel", for example, cannot be related to the token "poem", so it will put minus infinity here 01:19:28.480 |
But all the tokens of the same prompt will be able to be attended to by the other tokens of that prompt 01:19:36.000 |
While the tokens in the last prompt will only be able to attend to 01:19:40.560 |
The other tokens in the same prompt and this is a special mask built by using the Xformers library 01:19:50.720 |
Okay, I want to show you how it actually works. So in the Mistral source code, they are using this xformers library, which 01:20:03.920 |
Allows us to compute very complex attention masks and also to calculate the attention in a very efficient way 01:20:10.560 |
Using the memory efficient attention calculation, which I will not show in this video. Maybe I will make a future video about it 01:20:16.480 |
But basically what they do in the Mistral source code 01:20:19.920 |
If you have multiple prompts, they will create one big sequence and then keep track of the number of tokens of each prompt 01:20:30.400 |
And then they use the methods made available by the xformers library to build these complex attention masks that 01:20:30.400 |
Take into account the different sizes of the KV cache, because each prompt may have a KV cache that is different from another prompt's 01:20:43.200 |
Because imagine you have a prompt with 5,000 tokens and one prompt with only 10 tokens 01:20:47.680 |
Of course, you will have a KVCache that is 5,000 tokens in one case and 10 in another case 01:20:53.280 |
So the mask, attention mask that we build should take care of this 01:20:58.400 |
And the second thing is that each group of tokens should only be able to relate 01:21:03.040 |
To the tokens of the same group not to other groups. So not of tokens from another prompt 01:21:09.760 |
And this is done with the block diagonal causal mask 01:21:19.360 |
Here, for example, the first prompt is made up of seven tokens, the second prompt is made up of five tokens and the third prompt is made up of six tokens 01:21:24.640 |
And we are also using a sliding window attention with a sliding window size of three and basically this will create the complex mask 01:21:33.040 |
This is the first group of tokens: from 0 to 6 is the first prompt, from 7 to 11 is the second prompt 01:21:42.560 |
And from 12 to, let me check, 17 is the third prompt. And as you can see, it also takes into consideration the sliding window size 01:21:49.680 |
So each token can only watch at most the two previous tokens 01:21:54.400 |
That is, the tokens contained in the sliding window of size three 01:21:58.880 |
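A sketch of how such a mask can be built and used, assuming an xformers version whose attn_bias API matches the one used by the Mistral reference code (exact module paths and requirements, such as fp16 tensors on a GPU for the fast kernels, may vary between versions; the sizes are the made-up ones from this example):

```python
import torch
from xformers.ops import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

seqlens = [7, 5, 6]                       # the three prompt lengths in this example
total = sum(seqlens)                      # 18 tokens concatenated into one long sequence

# Causal attention restricted to each prompt's own block, plus a sliding window of 3.
mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(3)

# Dummy projections: batch of 1, 18 tokens, 8 heads, 64 dimensions per head.
q = torch.randn(1, total, 8, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = memory_efficient_attention(q, k, v, attn_bias=mask)    # (1, 18, 8, 64)
```

Here `make_local_attention(3)` is what adds the sliding-window restriction on top of the per-prompt causal blocks.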
The second one they use is the block diagonal mask. While the previous one is used for the first chunk of the pre-filling 01:22:08.320 |
This one is used for subsequent chunks in the pre-filling 01:22:11.940 |
And basically it also takes the KV cache into account, because during the first pre-filling we don't have the KV cache, since it's initially empty 01:22:18.400 |
But during the subsequent steps, it's not empty anymore 01:22:21.520 |
So we need to take into consideration also the different size of the KV cache 01:22:24.880 |
So for example, the first prompt may have a KV cache of size 10, because the prompt is very short 01:22:31.120 |
But the second prompt may be very big, suppose 5,000 tokens, so it may have a KV cache of size 5,000 01:22:36.740 |
So it takes into consideration also the size of the KV cache 01:22:40.640 |
And it will produce a mask that takes into consideration also the size of the KV cache 01:22:47.040 |
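A sketch of this second case, again assuming the xformers attn_bias API used by the Mistral reference code; the prompt and cache sizes here are the made-up ones from the example above:

```python
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

# Hypothetical subsequent pre-fill chunk: two prompts, each sending 4 new tokens,
# but with caches of very different sizes (10 cached tokens vs 5000 cached tokens).
q_seqlen = [4, 4]                    # incoming (query) tokens per prompt
kv_seqlen = [10 + 4, 5000 + 4]       # cached tokens + tokens of the current chunk, per prompt
mask = BlockDiagonalMask.from_seqlens(q_seqlen=q_seqlen, kv_seqlen=kv_seqlen)
```

The reference code additionally restricts this mask to the sliding window, but the key point is that the query side and the key/value side can have different lengths per prompt.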
The last method they use is this one block diagonal causal with offset padded keys mask 01:22:52.320 |
Because each prompt may have a different size for the KV cache 01:22:57.280 |
But the KV cache size is fixed: it's a tensor of fixed size W 01:23:05.840 |
And only some positions in this KV cache may actually be filled 01:23:11.040 |
So maybe the KV cache size is, let's say, 10 01:23:14.240 |
But because the first prompt is very short only three tokens are actually part in the KV cache 01:23:19.760 |
But when we pass the KV cache to the calculation of the attention 01:23:25.120 |
We pass the whole tensor, which is all the 10 items, so we need to tell the attention computation 01:23:31.360 |
That it should only use the first three items from the KV cache and not all of it 01:23:37.440 |
Not the whole tensor. And this is done with the block diagonal causal with offset padded keys mask 01:23:42.560 |
So this method here has a very long name and looks complicated, but this is why they use it 01:23:48.880 |
So it takes into consideration the actual size of the KV cache 01:23:52.560 |
Even if all the KV caches have the same size, because they are fixed-size tensors 01:23:57.760 |
It tells you how many items it should actually use from each cache 01:24:03.120 |
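And a sketch of this last mask, with the caveat that I am reading the xformers signature the way the Mistral reference code appears to use it, so the exact argument names may differ across versions; the sizes are made up:

```python
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask

W = 10                                 # fixed number of slots in each rolling KV-cache tensor
mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1, 1, 1],                # one new (query) token per prompt during generation
    kv_padding=W,                      # every cache tensor is allocated with W slots
    kv_seqlen=[3, 7, 10],              # how many slots are actually filled for each prompt
)
```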
Okay guys it has been a very demanding video I have to say 01:24:15.520 |
I actually had to cut some parts because I even I got confused sometimes 01:24:20.160 |
It's very complicated topics. It's a lot of things that you have to grasp 01:24:25.760 |
But I hope that it will make your life easier when you want to understand the Mistral code 01:24:34.400 |
I will also share my notes, the ones that you have seen, so the two notebooks that I have shown you 01:24:39.120 |
Plus also the code annotated by me on the Mistral source code 01:24:42.880 |
Now, the Mistral source code, I actually never ran it, because my computer is not very powerful 01:24:48.960 |
So I never run the actual model on my computer 01:24:51.920 |
What I did to study the model was to run some random tensors through it: I basically created a model with 01:25:01.200 |
Weights but with a smaller number of layers, so it could fit in my GPU, and then I just ran some random tensors 01:25:07.280 |
To study all the shapes of the tensor and all the information passing 01:25:11.040 |
So, I don't know if the code works, but I hope it will. I mean, I didn't touch the logic, I just added some comments 01:25:17.200 |
Anyway, you can use the code commented by me 01:25:21.360 |
As a learning tool to complement the official code of Mistral, so that you can understand 01:25:27.380 |
More about the inner workings of this great model. I actually really enjoyed studying it. I really enjoyed 01:25:33.680 |
Studying the code and I learned a lot of stuff, you know 01:25:37.760 |
I think it's very very good when you are doing something that is very complicated 01:25:41.860 |
Because it teaches you a lot; if something is simple, then you don't learn much at the end of the day 01:25:46.800 |
Anyway, guys, thank you for watching my video. I hope you also enjoyed this journey with me 01:25:53.600 |
I hope that you liked this video and you will subscribe to my channel if you didn't please do it 01:25:58.640 |
And the best way to support me guys is to share this video with all the people, you know 01:26:03.200 |
So share it on social media share it on linkedin on twitter, etc 01:26:07.040 |
Because this is the best way you can help me grow my channel, and 01:26:12.160 |
Please let me know if there is something that you don't understand. I am always available to help