
Mistral / Mixtral Explained: Sliding Window Attention, Sparse Mixture of Experts, Rolling Buffer


Chapters

0:00 Introduction
2:09 Transformer vs Mistral
5:35 Mistral 7B vs Mistral 8x7B
8:25 Sliding Window Attention
33:44 KV-Cache with Rolling Buffer Cache
49:27 Pre-Fill and Chunking
57:00 Sparse Mixture of Experts (SMoE)
64:22 Model Sharding
66:14 Pipeline Parallelism
71:11 xformers (block attention)
84:07 Conclusion

Whisper Transcript

00:00:00.000 | Hello guys, welcome back to my channel today, we are gonna talk about Mistral
00:00:03.780 | So as you know Mistral is a new language model that came out a few months ago from Mistral AI
00:00:09.280 | Which is one of the hottest startups right now in Europe for language models
00:00:13.720 | It also became a unicorn recently and we will be exploring both of the models
00:00:18.580 | They released: one is the 7 billion model and one is the 8x7 billion model
00:00:21.840 | So let's review the topics of today
00:00:23.960 | the first thing I will introduce you is the architectural differences between the vanilla transformer and the
00:00:29.240 | architecture of Mistral later
00:00:30.920 | We will see what is the sliding window attention and how it is related to the concept of receptive field a concept that usually we
00:00:37.520 | find in convolutional neural networks
00:00:39.440 | I will briefly review the KV cache because I want to introduce the concept of the rolling buffer cache and also how it is done
00:00:46.480 | with the pre-filling and chunking
00:00:48.480 | We will see what is a sparse mixture of experts, and model sharding, with a
00:00:54.160 | very brief introduction to
00:00:57.840 | pipeline parallelism, and
00:00:59.840 | Last but not least we will also go through the code of Mistral because there is a lot of innovations in the code
00:01:06.160 | Especially when they use the Xformers library with the block attention
00:01:09.440 | So I want to guide you into understanding the code because it can be really hard for beginners to understand and find themselves around
00:01:16.880 | There are some topics that are related to Mistral
00:01:19.880 | But will not be covered in this current video because I already covered them in my previous video about Llama, and in particular
00:01:25.960 | I will not be talking about the RMS normalization, the rotary positional encoding and the grouped query attention because I already
00:01:31.200 | Taught them in depth in my previous video on Llama
00:01:34.240 | So if you want to know about them, please watch my previous video on Llama
00:01:37.600 | The only prerequisite that I hope you have before watching this video, because the topics we are going to touch are quite advanced,
00:01:45.360 | Is that you are familiar with the transformer model
00:01:48.080 | So if you are not familiar with the transformer model, and in particular with the attention mechanism, especially the self-attention mechanism
00:01:54.160 | Please go watch my video on the transformer in which I teach all these concepts very thoroughly and in detail
00:02:00.360 | These are really a prerequisite for watching this video because the topics here are quite advanced
00:02:05.720 | Okay, let's proceed further
00:02:08.360 | So let's watch the differences between the vanilla transformer and Mistral at the architecture level
00:02:13.480 | as you can see from the
00:02:16.200 | Image here, which I built by myself using the code because they didn't release any architecture picture in the paper
00:02:24.000 | Of the architecture of Mistral. First of all, let's talk about some terminology
00:02:29.040 | When you have a model like this made up of many encoder layers plus linear and the softmax
00:02:35.200 | We are talking about a decoder only model because this part this model here looks like the decoder of the vanilla
00:02:42.840 | Transformer you can see here except for the cross attention because as you can see here, there is no cross attention
00:02:49.360 | When we have a model without the linear and the softmax we call it an encoder only model
00:02:54.920 | For example BERT is an encoder only model. BERT can have some heads at the end
00:03:00.080 | Which are one or more linear layers depending on the application
00:03:04.080 | But by itself BERT doesn't have a head, because it can be used for multiple downstream tasks
00:03:08.560 | so it's called an encoder only model because it resembles the
00:03:11.800 | Encoder side of the transformer because as you can see in the encoder side, there is no linear and softmax
00:03:17.240 | So Mistral is a decoder only model and it's very similar, if not equal, to Llama
00:03:23.320 | The differences between Llama and Mistral are highlighted here in red
00:03:28.000 | The first difference between Llama and Mistral is that
00:03:31.480 | In the self attention we use the sliding window attention, and we still use the grouped query attention
00:03:39.400 | And also the KV cache for inferencing
00:03:42.320 | But this is a rolling buffer KV cache, and it's actually related to the fact that we are using sliding window attention
00:03:47.680 | So later, we will see all these concepts. And another difference is that the feedforward layer here, instead of using the ReLU function that we use
00:03:55.880 | In the vanilla transformer or the SwiGLU function that we use in Llama, here in Mistral we use the SiLU function
00:04:02.480 | And there is one feedforward network in the case of Mistral 7B
00:04:08.960 | So the first model they released, or there can be eight feedforward
00:04:13.840 | Networks in parallel with each other, which are the experts of this mixture of experts, in the case of Mistral 8x7B
00:04:21.760 | We will see later how it works
00:04:24.520 | So for now, you just need to understand that Mistral is made up of okay the input which are converted into embeddings
00:04:31.080 | Then we have this block which is repeated n times and we will see that in the case of Mistral is repeated
00:04:36.800 | 32 times one after another such that the
00:04:40.120 | Output of each layer is fed to the next layer as input and the output of the last layer is then sent to this
00:04:47.920 | RMS norm to the linear and to the softmax to produce the output of the model and
00:04:52.760 | This is exactly the same as what we do with any other transformer model
00:04:57.880 | Usually we have many of these blocks here. Now in the code of Mistral this part here is known as transformer block
00:05:05.560 | But it's also known as encoder block or decoder block depending on the contents of this
00:05:12.960 | Block here. I will refer to it as an encoder block because if you look at it
00:05:17.640 | It looks like exactly as the block of the encoder side
00:05:20.560 | So it has a multi-head attention, add & norm, a feedforward and add & norm
00:05:24.720 | The only difference is that the normalization here comes before the block of the feedforward and the self-attention
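To make the layout concrete, here is a minimal sketch of one such block in PyTorch: pre-normalization with RMSNorm, self-attention plus a residual connection, then a SiLU-gated feedforward plus a residual connection. It is only an illustration of the structure described above, not the actual Mistral code: rotary positional encodings, grouped query attention, the sliding window mask and the KV cache are all omitted, and the stock `nn.MultiheadAttention` stands in for Mistral's own attention.

```python
# Rough sketch of a Mistral-style decoder block (pre-norm, attention + residual,
# pre-norm, SiLU-gated feedforward + residual). Illustrative only.
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)   # dim -> hidden_dim
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)   # hidden_dim -> dim
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)   # gate projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))    # SiLU activation, as described

class TransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, hidden_dim):
        super().__init__()
        self.attention_norm = RMSNorm(dim)
        self.attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)  # stand-in
        self.ffn_norm = RMSNorm(dim)
        self.feed_forward = FeedForward(dim, hidden_dim)

    def forward(self, x, attn_mask=None):
        h = self.attention_norm(x)                          # normalization BEFORE attention
        h, _ = self.attention(h, h, h, attn_mask=attn_mask)
        x = x + h                                           # residual connection
        return x + self.feed_forward(self.ffn_norm(x))      # normalization BEFORE feedforward
```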
00:05:32.000 | Okay, let's move forward
00:05:35.520 | Now let's compare the two models
00:05:38.080 | So one is Mistral 7B and one is Mistral 8x7B
00:05:41.160 | The parameter dim indicates the dimensions of the embedding vector
00:05:47.660 | So how big is the embedding vector? So each token is represented by an embedding vector of size
00:05:52.720 | 4096 dimensions
00:05:55.360 | We have 32 of the encoder layers. So this block here is repeated 32 times
00:06:02.880 | The head dimension indicates as you remember in the multi-head attention we have
00:06:07.540 | Each head watching the entire sentence
00:06:11.960 | But only a part of the embedding of each token, and this indicates
00:06:17.280 | How many dimensions each head will attend to
00:06:21.920 | In the multi-head attention. And the hidden dimension here indicates the hidden dimension of the feedforward layer
00:06:28.760 | So in the case of the feedforward layer, we have two linear layers: one that converts the
00:06:33.600 | Dimension of the embedding vector into the hidden size then another one that converts the hidden size back into the embedding vector dimensions
00:06:42.440 | So in the case of the Mistral they are using as a hidden size
00:06:45.980 | 14336. Usually this is a multiple of the dimension, and it looks like it's 3.5 times the dimension here
00:06:55.520 | The number of heads of attention for the query is a 32 while the number of heads for the K and V
00:07:01.680 | So the key and values is 8 and they are not equal because of the grouped query attention
00:07:06.320 | So if you remember from my previous video on Llama in which we talked about the grouped query attention:
00:07:10.420 | In the very simple case of the grouped query attention
00:07:14.360 | We have the multi query attention, which means that only the query has multiple heads while the K and V don't have multi-head
00:07:20.760 | attention
00:07:21.680 | Which means that you may have eight heads for the query and only one head for the K and V. In the case of the
00:07:27.600 | grouped query attention, it means that each group of
00:07:29.800 | Query will have one
00:07:32.600 | Attention head for the K and V
00:07:34.920 | So in this case
00:07:36.000 | Every four queries have one attention head for the keys and values. If this concept is not clear
00:07:42.280 | I describe it very thoroughly in my previous video on Llama
00:07:46.240 | the window size is the size of the sliding window that we used in the
00:07:51.420 | Calculation of the attention and we will see later how it works. The context length is
00:07:56.920 | The context size upon which the model was trained
00:08:01.480 | And it's much bigger for the 8x7B
00:08:04.720 | The vocabulary size is the same for both and then the last two parameters
00:08:08.280 | You can see here are related to the sparse mixture of experts and we will see later
00:08:13.440 | How it works, but we just remember that we have eight experts and for each token we use two experts
00:08:19.800 | But later I will clarify how it works
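As a rough summary, the parameters just discussed can be collected in a configuration object like this sketch; only the values explicitly mentioned above are filled in, everything else is left as a placeholder rather than guessed.

```python
# Illustrative summary of the hyper-parameters discussed above; values not
# mentioned in the discussion are left as None on purpose.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    dim: int = 4096                   # size of each token's embedding vector
    n_layers: int = 32                # number of stacked transformer blocks
    head_dim: int = 128               # dim / n_heads: slice of the embedding each head attends to
    hidden_dim: int = 14336           # inner size of the feedforward layer (~3.5x dim)
    n_heads: int = 32                 # attention heads for the queries
    n_kv_heads: int = 8               # heads for keys/values (grouped query attention)
    window_size: Optional[int] = None # sliding window size (see the released configs)
    context_len: Optional[int] = None # training context length (much bigger for 8x7B)
    vocab_size: Optional[int] = None  # same for both models (value not discussed here)
    # sparse mixture-of-experts parameters, only used by Mistral 8x7B
    moe_num_experts: Optional[int] = 8        # eight experts
    moe_experts_per_token: Optional[int] = 2  # two experts used per token
```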
00:08:21.800 | Let's proceed further. So let's talk about the sliding window attention. But before we talk about the sliding window attention
00:08:28.440 | I need to review a little bit of the self attention mechanism. So what is self attention?
00:08:33.720 | Self attention is a mechanism that allows the model to relate tokens to each other
00:08:39.160 | So tokens that are in the same sentence are related with each other through the self attention mechanism
00:08:44.160 | This is why it's called self attention, because each token is watching other tokens of the same sentence
00:08:50.640 | And this basically means that the query, key and values are the same matrix
00:08:57.000 | So imagine we have the following sentence the cat is on a chair
00:09:03.840 | We have our query which is a matrix made up of six tokens each token represented by
00:09:09.760 | 4096 dimensions, which is the dim parameter that we saw before
00:09:13.480 | This is multiplied by the transpose of the keys, which is
00:09:17.800 | 4096 by 6 but it's just the query matrix transpose because the query key and values are the same matrix in the case of self attention
00:09:26.600 | This will produce a matrix that is 6 by 6
00:09:29.840 | Because the inner two dimensions kind of cancel out and the outer dimensions indicate the dimension of the output matrix here
00:09:37.320 | Now, what are the values in this matrix representing?
00:09:41.840 | The first value here indicates the dot product of the first token with the first
00:09:47.920 | The first row of the query with the first column of the keys
00:09:52.280 | So basically the dot product of the embedding of the first token with itself
00:09:56.840 | The second value here indicates the dot product of the first row of the query matrix
00:10:02.560 | With the second column of the key matrix here the transpose of the keys matrix here
00:10:07.400 | Which basically means that it's the dot product of the embedding of the first token
00:10:12.040 | So the with the embedding of the second token, which is cat and etc, etc for all the other values
00:10:18.480 | Don't concentrate too much on the values because all the values I put here are random
00:10:22.240 | And also the fact that these numbers are less than one
00:10:25.120 | It's not necessary because the dot product can be bigger than one. It's not a condition of the dot product
00:10:30.440 | Usually in the formula we also normalize: here we divide by the square root of dk, where
00:10:38.160 | dk basically is the size of the part of the embedding to which this particular attention head will attend
00:10:46.160 | But let's pretend that we only have one
00:10:49.680 | One head so dk is equal to d model. So basically this head will watch the full embedding of each token
00:10:56.880 | Okay, usually we train autoregressive models. So language model is an autoregressive model
00:11:03.160 | It means that the output depends on the previous
00:11:05.840 | The next token depends only on the previous tokens
00:11:09.240 | And this is why we apply a causal mask. Causal mask means basically that in the attention mechanism
00:11:15.280 | We don't want to relate a word with future words
00:11:19.000 | So words that come after it but only with the words that come before it
00:11:23.280 | so for example, we don't want the word "the" to be related to the word "cat" because
00:11:28.680 | The word "cat" comes after the word "the" but on the other hand
00:11:33.840 | We want the word "cat" to be related to the word "the" because it comes before it
00:11:38.480 | And for this reason we apply this causal mask
00:11:40.960 | Because the attention mechanism uses the softmax function we can see here. The softmax function basically
00:11:49.440 | will transform all this minus infinity into zero because the formula of the softmax has at numerator an e to the power of X and
00:11:57.640 | When X goes to minus infinity e to the power of minus infinity will go to zero
00:12:02.360 | So this is why we apply a mask in which we put all the values that we don't want, all the interactions that we don't
00:12:08.320 | Want between tokens. We just mask them out by replacing them with minus infinity
00:12:12.920 | So that when we apply the softmax, the softmax will take care of
00:12:18.320 | transforming them into zeros
00:12:20.320 | Okay, also the softmax will do another thing because it will not only convert this minus infinities to zero
00:12:27.880 | But it will also modify the other value for each row such that they sum up to one
00:12:32.720 | So as you can see now these values here
00:12:35.280 | They don't sum up to one for each row, right? Because this is 0.2, 0.1 and 0.2
00:12:39.760 | They don't sum up to one but the softmax will convert the minus infinities into zero and the remaining values for each row such
00:12:46.560 | that they sum up to one
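As a tiny sketch of what was just described, here is the causal mask in a few lines of numpy: random scores stand in for Q·Kᵀ/√dk, minus infinity goes above the diagonal, and a row-wise softmax turns those entries into zeros while making each row sum to one.

```python
# Minimal demo of causal masking in self-attention (random values, as in the slide).
import numpy as np

seq_len, d_k = 6, 4096                     # "the cat is on a chair", a single head
rng = np.random.default_rng(0)
q = rng.standard_normal((seq_len, d_k))
keys = q                                   # self-attention: Q, K, V come from the same sequence

scores = q @ keys.T / np.sqrt(d_k)         # (6, 6) matrix of dot products

causal_mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)  # -inf above the diagonal
masked = scores + causal_mask

# row-wise softmax: e^(-inf) becomes 0, the remaining values of each row sum to 1
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(np.round(weights, 2))                # upper triangle is all zeros, every row sums to 1
```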
00:12:48.320 | Now let's talk about sliding window attention
00:12:50.960 | So we applied the causal mask to hide the interactions between
00:12:55.120 | The words, a word and all the future words, but with the sliding window attention
00:13:00.840 | We also don't want the word to watch
00:13:03.160 | Other words that are outside its local context
00:13:07.400 | What do I mean by this?
00:13:08.880 | In the previous case when we only applied the causal mask, the word chair for example was being related to all the previous
00:13:15.120 | Tokens as you can see, so the token chair here is related to itself
00:13:19.000 | But also to 'a', 'on', 'is', 'cat', 'the', so it could watch basically the whole sentence
00:13:25.120 | But in the case of sliding window attention, we don't want the word chair to watch words that are further than
00:13:32.520 | the sliding window size from itself, so
00:13:35.840 | The sliding window size in this case is three, so tokens that are at a distance of more than three from the word
00:13:42.800 | We are considering. So the word 'chair' should not be related to the word 'is' because the distance is four, and the word
00:13:48.240 | 'a' should not be related to the word 'cat' because the distance is four, and of course
00:13:52.800 | We still want the mask to be causal because we don't want
00:13:57.140 | Each token to watch future words, because we are training an autoregressive model
00:14:01.880 | So the sliding window attention basically reduces the number of dot products that we are performing
00:14:09.720 | And this will improve the performance during the training and the inference because as you can see when we only apply the causal mask
00:14:15.680 | We are performing all these dot products you see here, but with the sliding window attention
00:14:20.720 | We are performing less dot products because all the other will be masked out
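The sliding window mask is just the causal mask with one extra condition: a token at position i may only attend to positions j with j ≤ i and i − j < W. A small sketch, with W = 3 as in the example:

```python
# Building the combined causal + sliding-window mask (window W = 3, as in the example).
import numpy as np

def sliding_window_mask(seq_len: int, window: int) -> np.ndarray:
    """Return a (seq_len, seq_len) mask: 0 where attention is allowed, -inf where it is not."""
    i = np.arange(seq_len)[:, None]          # query positions (rows)
    j = np.arange(seq_len)[None, :]          # key positions (columns)
    allowed = (j <= i) & (i - j < window)    # causal AND within the sliding window
    return np.where(allowed, 0.0, -np.inf)

print(sliding_window_mask(6, 3))
# the last row ("chair") keeps only positions 3, 4, 5 -> "on", "a", "chair"
```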
00:14:24.240 | Sliding window attention however may lead to degradation of the performance of the model because as you can see here
00:14:31.640 | The word 'chair' and the word 'the' are not related to each other anymore, right? So the information
00:14:38.680 | Will not be conveyed between the word 'the' and the word 'chair'. The word 'chair' will only be related to other tokens that belong
00:14:45.600 | To the local context of this particular token, so only the
00:14:49.160 | Tokens that are inside this sliding window
00:14:53.260 | If this window is too small, it may reduce the performance of the model
00:14:58.480 | But it may also be beneficial because for example imagine you are reading a book you don't care about
00:15:04.180 | Relating the word in chapter 5 with the words in chapter 1 because most of the books
00:15:09.120 | They could be talking about totally different things and you don't even care about relating these two tokens
00:15:15.080 | But for sure you want to relate the tokens
00:15:17.700 | In the chapter 5 with other tokens in the chapter 5 because the local context matters
00:15:22.980 | But I want to introduce you to the concept of receptive field, because when we use sliding window attention
00:15:29.580 | Even if the word 'chair' and the word 'the' are not related to each other directly
00:15:34.520 | Because in Mistral and in all transformer models we use multiple layers of encoders
00:15:40.460 | We will see that the word 'chair' and the word 'the' will still be kind of related
00:15:47.280 | To each other, not directly but indirectly
00:15:49.940 | In a concept that is very similar to the receptive field of the convolutional neural networks
00:15:55.600 | So let's talk about the receptive field as you remember in convolutional neural networks
00:16:01.520 | We have a mask a kernel that we run through an image. So imagine this is our original image
00:16:07.760 | This one here and we run a mask that is a kernel that is 3x3 this one here
00:16:13.600 | That when we run a kernel it will produce an output feature. So for example this feature here
00:16:19.120 | This is the output produced by applying the kernel to the first 3x3 grid here
00:16:25.440 | This value here the second value here in yellow
00:16:29.480 | It will be produced when we will move our kernel to the next group of 3x3 pixels. So let me draw
00:16:36.920 | Let's use the pen. So this value here
00:16:41.720 | will be produced when we will move our kernel in this grid here and
00:16:50.920 | This value here is also an output feature of a
00:16:56.140 | convolutional kernel that is a 3x3
00:16:59.160 | Applied to this layer 2. So this is a 3x3 kernel that is applied to this layer 2
00:17:05.840 | So apparently there is no connection between this one this pixel here and this one
00:17:11.820 | But because this output feature depends on a kernel applied in this grid, and this grid
00:17:19.520 | Includes this feature here which depends on this pixel here. We can safely say that this feature here
00:17:26.720 | Depends indirectly also on this feature here
00:17:30.520 | Even if they are not directly related to each other and this is the concept of the receptive field. So basically
00:17:36.200 | One feature of a convolutional neural network can watch a much bigger receptive field
00:17:42.900 | Further down in the layers because of this
00:17:49.400 | Sequential application of convolutional kernels
00:17:53.240 | Let's see how this concept is related to the sliding window attention now
00:18:00.280 | After we apply the softmax to the mask that we have seen before as I told you before all the minus infinities are
00:18:06.840 | Converted into zero and all the other values are changed in such a way that they sum up to one
00:18:12.100 | So let's go back as you remember here. We have the minus infinities here here here here here and here
00:18:17.920 | So now we apply the softmax and it will become zeros zeros here. Also, let me
00:18:24.680 | Okay, all the zeros here all the zeros here and all the other values are changed in such a way that they sum up to
00:18:30.560 | One what is the next operation that we do in the self-attention?
00:18:33.640 | We then take the output of the softmax and multiply it by the V matrix. So let's do it
00:18:38.660 | The V matrix is basically the same as the initial sequence because I told you this is self-attention
00:18:46.360 | So the query key and values are the same matrix
00:18:49.160 | So this means let's analyze what happens by hand when we do this multiplication
00:18:54.640 | So let me change to the pen. Okay, the V matrix here is
00:18:59.200 | is a sequence of tokens where each token is a vector represented by
00:19:04.960 | 4096 dimensions so we can say that it's the output of the self-attention if you watch the
00:19:12.760 | Dimensions of these two matrices. So it's a 6 by 6 and the 6 by
00:19:16.040 | 4096 the output will be another matrix that is 6 by
00:19:19.320 | 4096 so it will have the same dimension as the V matrix and also as the Q and the
00:19:24.640 | Query matrix because they have the same dimensions. So it will be six tokens as output
00:19:31.120 | Okay, let's analyze. What is the first dimension of the output to this one here?
00:19:37.720 | So this first value of the output so the value on the row 1, column 1 of the output matrix will be the dot
00:19:45.200 | Product of the first row of this matrix here. So this row here
00:19:50.640 | With the first column of this matrix, so the first column we can see here and
00:19:58.040 | As you can see most of the values here are 0 which means that all the rows from the 1 to 5
00:20:06.280 | Sorry from 2 to 6 will not be used
00:20:09.340 | but only the first row here fully the values of the first row will be used because if you remember the dot product is the
00:20:15.660 | first dimension with the first dimension of this column and
00:20:19.560 | The second dimension of this row with the second dimension of this column
00:20:23.960 | The third dimension of this row with the third dimension of this column and then we sum up all these values
00:20:29.360 | So this first value of the output will only depend on the first token of the V matrix
00:20:36.440 | You can see here. Let's check the second one: the first
00:20:40.040 | Dimension of
00:20:42.400 | The second row of the output matrix will be the dot product of the second row of
00:20:49.080 | this matrix here
00:20:51.600 | With the first column of the V matrix
00:20:55.760 | but most of the values are 0 which means that this
00:21:00.140 | dimension here and all the dimensions in this row will depend only on the first two tokens of the V matrix and
00:21:08.120 | We can say the same for the third. Let's analyze the sixth one here
00:21:12.240 | So the first dimension of the sixth row of the output matrix
00:21:17.040 | So this value here comes from the dot product of this row
00:21:23.480 | And the first column of the V matrix
00:21:26.920 | but most of the values at the beginning are 0 which means that it will only depend on the
00:21:32.860 | 4th, 5th and 6th tokens of the V matrix, and so will all the dimensions here, because whatever
00:21:41.880 | Column we use from the V matrix, the first values will always be multiplied by 0, 0, 0
00:21:48.600 | So it will only use the values in these three rows here
00:21:51.880 | So we can safely say that the 6th token of the output matrix
00:21:57.720 | of this self-attention mechanism will be a vector that will only depend on the last three tokens of the V matrix and
00:22:05.680 | Because we are talking about self-attention the V matrix is equal to query matrix
00:22:10.120 | So we can say that the output of the self-attention is a matrix that has the same shape as the input sequence
00:22:17.480 | But where each token now captures some more information about other tokens
00:22:22.480 | Which tokens depending on the mask we have applied
00:22:26.080 | So our mask says that the first token can only watch itself
00:22:30.520 | So the first output token will be an embedding that will only depend on itself
00:22:35.200 | The second token will only depend on the first two tokens
00:22:39.560 | The third output token will only depend on the first three tokens
00:22:44.840 | The fourth will depend on the token number two because the first token is not used
00:22:49.960 | The token number two, the token number three and the token number four, etc, etc
00:22:53.160 | Until the last here
00:22:55.160 | The last token will depend only on the last three tokens because the first three tokens are masked out
00:23:01.320 | And this is the importance of the mask that we apply in the self-attention mechanism
00:23:05.960 | This concept that I show you now is very important to understand the rest of the video
00:23:10.240 | So please if you didn't understand it, you can take a little pause
00:23:13.000 | You can try to do it by yourself because it's really important that you understand
00:23:17.200 | How the self-attention mechanism works with the mask
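To see this claim numerically, here is a small sketch: after the masked softmax, every output row of the attention is a weighted sum of only those rows of V whose weight is non-zero, so each output token depends exactly on the tokens its row of the mask allows (random numbers again, window W = 3).

```python
# With a sliding-window causal mask (W = 3), output token i of the attention is a
# weighted combination of only the V rows its mask row allows.
import numpy as np

tokens = ["the", "cat", "is", "on", "a", "chair"]
seq_len, d = len(tokens), 8
rng = np.random.default_rng(0)
v = rng.standard_normal((seq_len, d))            # V matrix: one row per token

i_idx = np.arange(seq_len)[:, None]
j_idx = np.arange(seq_len)[None, :]
allowed = (j_idx <= i_idx) & (i_idx - j_idx < 3)

scores = rng.standard_normal((seq_len, seq_len)) # stand-in for Q @ K.T / sqrt(d_k)
scores[~allowed] = -np.inf
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)   # zeros exactly where the mask forbids

output = weights @ v                             # (6, d): one output embedding per position
for i in range(seq_len):
    used = [tokens[j] for j in range(seq_len) if weights[i, j] > 0]
    print(f"output token {i} ({tokens[i]!r}) depends on {used}")
# the first output depends only on 'the', the last one only on ['on', 'a', 'chair']
```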
00:23:19.560 | Okay, now that we have seen this concept
00:23:23.000 | I want to introduce you to the next one
00:23:25.040 | So as we saw before the output of the self-attention mechanism is another
00:23:29.360 | Matrix with the same shape as the query matrix in which each token is represented by an embedding of size
00:23:36.240 | 4096 but each embedding now captures information also about other tokens
00:23:42.320 | According to the mask and if we check the this mask here, so the output here
00:23:48.360 | We can safely say that the input of our sliding window attention was the initial
00:23:55.600 | Sequence 'the cat is on a chair'
00:23:58.560 | But after applying the self-attention the first token is now related to itself
00:24:03.640 | The second token is related to itself and the token before it
00:24:07.800 | The third is related to the token before it and the one also before it
00:24:11.920 | The last one only depends on the previous two tokens, etc. According to the mask, right?
00:24:17.040 | Now what happens if we feed this one, because as you know in the transformer world, and also in Mistral and also in Llama
00:24:24.440 | we have many layers of encoders one after another which are also called the transformer block in the code and
00:24:30.480 | The output of each layer is fed to the next one. So this is the first layer of the transformer
00:24:37.680 | So we take the input sequence and we feed it to the first layer which will produce a list of tokens where each token now captures
00:24:44.000 | Information about other tokens, but this will become the input of the next layer where we it will produce an output
00:24:51.680 | This output I will prove you that will capture information about even more tokens
00:24:56.640 | Even if the sliding window attention says that they should only be able to watch the previous two tokens
00:25:02.440 | Because the sliding window size we chose three as a sliding window size
00:25:07.440 | I want to prove it. So
00:25:09.440 | Imagine this is the output of the first layer
00:25:13.400 | So it's a list of tokens that capture information about other tokens and it's the the matrix that we built in the previous slide
00:25:21.120 | Let's use it as an input for another layer of the encoder. So we multiply the query
00:25:27.160 | We multiply the query and the transposed of the keys which will produce a matrix like this one in which each token is not only
00:25:34.920 | One token, but it's capturing already information about multiple tokens, right according to the mask
00:25:39.560 | So I'm taking this one and this one will become query key and values
00:25:44.480 | So if we multiply the query by the key, it will return a matrix like this
00:25:48.480 | So the first token only depends on itself. The second one depends on himself and the previous one
00:25:54.080 | So the embedding of this token captures information about two tokens and the embedding of this token capture information about three tokens, etc
00:26:01.080 | Let's try to do the multiplication again
00:26:06.560 | We have that our V matrix is again a list of tokens and the output
00:26:14.160 | Will also be a list of tokens, but each one will capture information about other tokens
00:26:21.320 | Okay, let's analyze the dot product here
00:26:24.840 | So the first value of the first row
00:26:27.520 | So the first dimension of the first row of the output matrix will be the dot product of the first row of this
00:26:33.960 | Matrix here. So this row here with the first column of this matrix here. So this column here
00:26:41.320 | But because of this causal mask with the sliding window attention mask that we can see here
00:26:47.480 | The output will only depend on the first row of the V matrix
00:26:52.400 | But because the V matrix is a matrix that is made of these tokens here
00:26:57.240 | It will only depend on the word 'the'. So as we can see here, the output of the second layer only depends on the
00:27:03.800 | Word 'the', and
00:27:05.800 | So will this second one. Let's check the fourth token, for example, this one here
00:27:12.400 | Let's check this fourth token here
00:27:15.800 | So this value here will be the product of the fourth row of this matrix dot product of
00:27:23.160 | This row with the first column of the V matrix. So this column here
00:27:28.480 | But the first token will not be used because it's we are multiplying it with zero whatever value we have here
00:27:34.760 | We will not be using it
00:27:35.960 | we are using the second token the third token and the fourth token and
00:27:45.400 | Each token actually is aggregating
00:27:49.240 | These values here. This token here is already
00:27:57.360 | Aggregating the value of two tokens, which is 'the' and 'cat'. So this embedding here is already talking about 'the' and
00:27:59.960 | 'cat', and
00:28:07.720 | This token here is aggregating the information about 'the', 'cat' and 'is', so
00:28:12.280 | 'the', 'cat' and 'is', and the fourth token is
00:28:16.020 | Aggregating the information of 'cat', 'is', 'on', so 'cat',
00:28:18.480 | 'is' and
00:28:25.940 | 'on', because as we saw before the fourth token here is 'cat is on', which is the result of the previous self-attention that we have done
00:28:25.940 | So this output value here will depend on three tokens that already include information about other tokens
00:28:35.080 | So this value here will aggregate
00:28:37.800 | Information about the union of all these tokens
00:28:41.160 | So it will for sure depend on the word 'the' because it's included in the second token we are multiplying it with
00:28:46.680 | It will for sure include information about the word 'cat' because it's included in this token as well. For sure
00:28:54.000 | It will include information about 'is' because it's included in the second value
00:28:57.740 | We are multiplying it with, and for sure
00:28:59.640 | It will include information about the token 'on' because it's present in the fourth token of the V matrix with which we are
00:29:06.620 | Multiplying it, because this value is not zero
00:29:09.600 | So as you can see after applying another layer of the encoder
00:29:14.960 | The fourth token now includes another token in its information before it was only
00:29:20.720 | including these three tokens
00:29:23.000 | But now it also depends on a new token, which is the word 'the', and
00:29:26.840 | We can prove the same for the fifth token and the sixth token
00:29:30.600 | so at every application of the encoder layer one after another we keep increasing the number of tokens that get
00:29:37.820 | Accumulated in these dot products and I made a notebook in Python to visualize this
00:29:44.720 | So if you look at my github repository
00:29:47.440 | You will see this
00:29:50.880 | Notebook called the sliding window attention in which I help you visualize this process and I also share the code on how I do this
00:29:58.280 | Self-attention. Basically I represent each token as a set, so each
00:30:03.560 | Token, instead of being represented as an embedding, is represented as a set of all the words upon which that token
00:30:10.920 | Depends. Then I apply the sliding window attention, which basically means that I take two tokens from the sequence and I
00:30:20.520 | Accumulate, I make the union of the two sets they contain
00:30:23.520 | Because I am multiplying two vectors that already include information about multiple tokens
00:30:29.440 | So what is the output is the union of the two sets when I multiply by V
00:30:34.320 | I do the same thing and I can visualize it
00:30:36.640 | So after we apply the first layer, we will see that the input of the first layer is just our normal sequence
00:30:42.160 | So the cat is on a chair the output of the first layer will be another sequence
00:30:46.360 | There in which each position includes information about multiple tokens
00:30:50.120 | Depending on the mask that we have applied and I also show the mask that we apply
00:30:54.600 | After we apply the second layer, we can see that the information increases
00:30:59.560 | So this last token now is not watching only the previous three tokens, but the previous four tokens
00:31:05.340 | Sorry, not only the previous two tokens, but the previous four tokens
00:31:09.820 | so every step we do with the sliding window size of three we include two tokens at every layer and
00:31:16.660 | Here I
00:31:19.180 | Show it for five layers, but it's not necessary because after a while the sequence will reach the maximum length
00:31:25.580 | If you want you can increase the sequence's length here by including more tokens
00:31:30.620 | So this is the concept of the
00:31:35.580 | Receptive field applied to the sliding window attention. So basically with the sliding window attention, we are not
00:31:41.420 | Directly connecting two tokens with each other
00:31:46.180 | But if we apply multiple layers after one after another this information will get
00:31:51.860 | Will get captured by the embedding in successive applications of the layers such that the last layer
00:31:58.340 | basically will be able to watch all the sentence even if it's very long and
00:32:03.500 | this is actually shown by the
00:32:05.500 | Mistral paper in this picture you can see here. So basically this is our input sequence. So let me write
00:32:12.820 | So this is our input, which is
00:32:17.500 | The original sentence, so 'the cat is on a chair'
00:32:21.340 | The fourth token of the first layer. So this is the output of the first layer. So layer one
00:32:30.380 | We have seen that the fourth token here, with a sliding window size of four, will depend on itself,
00:32:37.460 | On the previous token, on the one before, and also on this token here, and it will produce
00:32:43.180 | This embedding here in the fourth position, which includes information about the previous tokens as well
00:32:49.600 | But then this will become the input of the next layer, which is the layer number two
00:32:57.300 | this will produce an embedding at this position for example that will depend for sure on the previous four tokens because the sliding window size is
00:33:05.100 | Four but because for example this token here is already the aggregation of the previous four tokens
00:33:10.860 | It will actually multiply the visibility of its sliding window
00:33:14.620 | So this token here is not related directly to the first one
00:33:19.380 | We can see here, but indirectly through this intermediate token we can see here
00:33:27.140 | I hope this concept is clear. If it's not clear, I
00:33:33.340 | Recommend using my notebook so that you can experiment by playing with multiple
00:33:38.220 | Sequences and you can see how the information flow will go through all the layers
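The notebook itself is not reproduced here, but the idea it visualizes can be sketched in a few lines: represent each position as the set of original words whose information has reached it, and at every layer take the union over the positions allowed by the causal sliding window (this is a re-implementation of the idea for illustration, not the author's notebook).

```python
# Sketch of the receptive-field effect of stacking sliding-window attention layers:
# each position is the set of original tokens whose information has reached it.
def propagate(sets, window):
    out = []
    for i in range(len(sets)):
        merged = set()
        for j in range(max(0, i - window + 1), i + 1):   # causal, within the window
            merged |= sets[j]                            # union of the sets, like the notebook
        out.append(merged)
    return out

tokens = ["the", "cat", "is", "on", "a", "chair"]
state = [{t} for t in tokens]        # before any layer, each position only "knows" itself
for layer in range(1, 4):
    state = propagate(state, window=3)
    print(f"after layer {layer}, the last token has seen {sorted(state[-1])}")
# with window W = 3, each extra layer extends the receptive field by W - 1 = 2 tokens,
# so after 3 layers the last position has indirectly seen the whole sentence
```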
00:33:42.700 | All right, let's talk about our next topic
00:33:47.420 | Which is the KV cache, because I want to introduce the KV cache which I already explained in my previous video on Llama
00:33:52.620 | But I want to introduce it again and review it because I want to introduce later the rolling buffer cache
00:33:57.140 | So let's start by talking about first of all how we train language models because this is needed to understand the KV cache
00:34:04.540 | So the language models are trained using what is known as the next token prediction task
00:34:09.660 | so given a prompt the goal of the language model is to predict what is the next token that makes sense with the prompt that
00:34:15.900 | We have given and imagine we want to train a language model on Dante Alighieri's poem
00:34:21.820 | Divine comedy and in particular we will training it on a line that you can see here in
00:34:27.020 | This one in English. So love that can quickly seize the gentle heart. How does it work?
00:34:33.540 | We prepare an input for our language model
00:34:36.580 | Which is the line that we want to teach it with a token
00:34:40.300 | Prepended called the start of sentence and then we build the target which is the same line
00:34:45.240 | But with a token at the end called end of sentence. We run the input through this transformer model
00:34:51.260 | It will produce an output sequence
00:34:53.260 | So as we saw before, the output of the self attention is another sequence with the same length as the input sequence, but
00:35:01.460 | Each embedding is modified in such a way that each token captures information about other tokens. And this is what we do
00:35:09.140 | To actually train a model
00:35:11.120 | So if we feed the model with the nine tokens, the model will produce nine tokens as output. And how
00:35:18.060 | does it work basically the
00:35:20.660 | model will learn a mapping between
00:35:22.660 | Input and output such that if we give to the model as input the token start of sentence only it will produce
00:35:30.660 | The first token as output which is the word love if we give to the model as input the first two tokens
00:35:37.300 | So start of sentence, love, the model will produce two tokens as output
00:35:41.960 | So love, that. If we feed the model as input three tokens
00:35:46.020 | So start of sentence, love, that, the model will produce love that can. So when we train the model, we train it like this
00:35:52.760 | We prepare the input like this the target like this. We calculated the output
00:35:56.500 | We calculated the loss using the cross entropy loss and then we run back propagation and this is done in all one step
00:36:01.880 | When we do the inference we do it in multiple step
00:36:04.540 | So when we do the inference at time step one, we feed the model only the first token
00:36:09.060 | So the start of sentence and the model will produce the output love
00:36:12.460 | Then we take the last token of the output and we append it to the input, which becomes the input at time step
00:36:19.620 | Two, so it becomes start of sentence, love. So the model will produce love that
00:36:24.300 | We take the last token of the output and we append it to the input for the time step three
00:36:29.880 | And this will become the new input which will produce love that can then we take the last
00:36:35.060 | Token of the output and we append it to the input for the time step four
00:36:39.820 | So the new output will become love that can quickly. Then we take this word quickly
00:36:45.060 | We append it to the input for the next time step and which will produce the next token as output, etc
00:36:49.980 | Etc until the last token until we see the end of sentence token as output then we know that the model has stopped
00:36:56.340 | Has stopped producing new tokens and we can stop the inference
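The inference loop just described, still without any KV cache, looks roughly like this; `model` and `tokenizer` are placeholders for any decoder-only language model that returns one logit vector per input position and any tokenizer with encode/decode and an end-of-sentence id.

```python
# Naive autoregressive decoding without a KV cache: at every step the whole growing
# sequence is fed to the model and only the LAST output position is kept.
import torch

@torch.no_grad()
def generate_greedy(model, tokenizer, prompt: str, max_new_tokens: int = 50):
    # `model` and `tokenizer` are placeholders for a decoder-only LM and its tokenizer
    tokens = tokenizer.encode(prompt)                  # e.g. [SOS, ...prompt tokens]
    for _ in range(max_new_tokens):
        logits = model(torch.tensor([tokens]))         # assumed shape: (1, seq_len, vocab_size)
        next_token = int(logits[0, -1].argmax())       # only the last position matters
        if next_token == tokenizer.eos_id:             # stop once the model emits EOS
            break
        tokens.append(next_token)                      # re-feed the whole sequence next step
    return tokenizer.decode(tokens)
```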
00:37:01.740 | Now at every step of the inference we are only interested in the last token output by the model, because we already have the
00:37:09.740 | Previous ones, but of course, we need to feed all the previous tokens
00:37:14.700 | To the model, which are
00:37:17.260 | Belonging to the prompt, because the model needs to access the prompt to understand which token to produce next
00:37:23.500 | So for example, we cannot produce the word gentle only by giving the word the we need to give all this sentence to produce
00:37:31.020 | this output gentle here
00:37:33.020 | But at the same time we are only interested in the last word gentle
00:37:38.620 | And this is the reason we introduce the KVCache because the KVCache allow us to reduce the computations that we are doing
00:37:45.700 | by only producing one output at a time the one that we need but
00:37:51.020 | Without doing all the intermediate computations for all the other tokens that we never use
00:37:55.620 | So basically when we want the word heart
00:37:58.540 | We don't want to produce the output for the word love that can quickly seize the gentle because we already have them in the prompt
00:38:05.340 | We don't need to produce all these tokens
00:38:07.060 | We just want to produce the output for the token heart. So we want to reduce the computation that we are doing
00:38:11.820 | Let's see how it works
00:38:13.940 | Now in the self-attention mechanism
00:38:17.020 | You know that we multiply the query which can be thought of as a list of tokens where each token is an embedding of size
00:38:22.820 | 4096, and the transpose of the keys
00:38:29.580 | Is multiplied by the queries to produce this matrix here
00:38:33.140 | And then we multiply it by the V matrix to produce the output of the self-attention. You can see here
00:38:38.540 | Let's do this one token at a time
00:38:40.900 | So when we inference a language model, we start with our first token, which is the start of sentence. This is one token
00:38:47.520 | represented by an embedding of size
00:38:50.060 | 4096 we multiplied by the transposed of the keys which is again one token because it's a self-attention. So
00:38:57.000 | The query the key and the value are the same matrix. So this is just the transposed of the query
00:39:03.220 | Basically, and so it's a column vector and it will produce a one by one matrix
00:39:08.260 | We multiply it by V and it will produce an output token. We take this output token
00:39:12.460 | We send it to the linear layer and then to the softmax to understand which token this corresponds to in our vocabulary
00:39:19.900 | We take this token from our vocabulary and we append it to the query for the next
00:39:25.660 | Inference step to the keys and the values and then we compute again the product of the query multiplied by the keys
00:39:32.920 | We multiply then the result by V and it will produce an output made up of two tokens because we have two tokens as input
00:39:39.020 | It will produce two tokens as output, but we are only interested in the last token
00:39:43.100 | So we take this output token 2, we send it to the linear layer, then to the softmax
00:39:47.220 | This will tell us which token it corresponds to in our vocabulary. We take this token from our vocabulary
00:39:53.340 | We append it for the next step to the query key and values
00:39:57.060 | We do again this process and then we take the last token as output
00:40:03.380 | You can see here. We then send it to the linear layer, then the softmax
00:40:06.600 | We understand which token it corresponds to, we append it to our query, key and values and then we compute again the self attention
00:40:14.140 | but we already start to notice something because
00:40:17.340 | First of all, in this matrix here, which is the result of the query multiplied by the transpose of the keys
00:40:24.540 | We have a lot of dot products at each step that were already computed at the previous step
00:40:29.680 | Let me show you. At the time step 4, we are computing all these dot products
00:40:34.340 | As you can see at the time step 3, we already computed these dot products and at the time step 4
00:40:40.220 | We are computing them again as you can see these dot products here
00:40:45.300 | The second thing is that usually when we
00:40:49.380 | Deal with the language model, we have a causal mask
00:40:52.860 | So we do not even care about computing the dot products that we see here in the dark violet
00:40:58.520 | because they will be anyway masked out by the
00:41:01.220 | causal mask that we apply
00:41:03.940 | Because we don't want the first token to watch the token number 2, the token number 3, the token number 4
00:41:09.460 | We only want the token number 4 to watch the previous one
00:41:12.580 | So the token number 4 should be related to itself, the previous one, the token number 2 and the token number 1
00:41:17.100 | But not the opposite
00:41:19.100 | And also we don't want to produce all these output tokens because we are only interested in the last one
00:41:24.460 | We are only interested in knowing what is the last
00:41:26.740 | token produced by the attention so that we can send it to the linear layer and then to the softmax to understand what is the
00:41:33.860 | Word corresponding in our vocabulary so that we can use it for the prompt to inference the next token again
00:41:39.900 | So now let's introduce the KVCache and how the KVCache solve this problem
00:41:44.400 | What we do with the KVCache, again, we start from our first step of the inferences
00:41:50.020 | So we start from our start of sentence token, which is multiplied
00:41:53.380 | So the query is only the start of sentence token. We multiply it by the transpose of the keys
00:41:57.800 | This will produce a 1 by 1 matrix here
00:41:59.940 | Then we multiply it by the V and it will produce our first token as output
00:42:04.260 | We send it to the linear layer then to the softmax then we know which token it corresponds to
00:42:09.040 | Now in the KVCache instead of appending this new token that we have produced as output to the query key and value
00:42:16.100 | we only append it to the key and the value and
00:42:19.140 | Replace entirely the previous query with this new token. So
00:42:23.500 | Before, without the KVCache, we were appending every output token
00:42:28.900 | So the last token of the output, to the query, key and values, but with the KVCache
00:42:34.420 | We don't append it to query key and value, but only to the key and values
00:42:38.960 | And we only use the last output token as query for the next step
00:42:43.980 | So if this is the output of the first step, so the output corresponding to the token start of sentence
00:42:51.060 | We take it we use it as query for the next step, but we append it to the key and the values
00:42:57.800 | So this is why it's called the KVCache because at each step
00:43:01.600 | We are keeping a cache of the previous K and V
00:43:05.200 | But not for the query because we are entirely replacing all the queries with the last token
00:43:09.960 | Anyway, this will produce a product
00:43:13.640 | So this matrix multiplied by this matrix will produce a matrix that is 1 by 2
00:43:17.560 | We multiply it by V and we will see that this produces only one token as output
00:43:22.440 | Then this we take this token
00:43:24.400 | we send it to the linear layer to the softmax then we know which token it corresponds to then we use it as
00:43:29.500 | Query for the next iteration, but we append it only to the K and the V matrix
00:43:35.360 | This will produce a 1 by 3 matrix
00:43:37.720 | Which is then multiplied by the V, which will produce this output token
00:43:44.140 | This is the one we are interested in basically, then we use it as query for the next iteration
00:43:49.680 | But we append it to the K and the V etc. So as you can see at the fourth step of the inference
00:43:56.160 | We are producing only the last row that we were interested in
00:44:00.120 | When we didn't have the KVCache. So let me show you this is the
00:44:04.200 | Fourth time step with the KVCache. Let's look at the fourth time step without the KVCache
00:44:09.920 | As you can see we are only producing this row here
00:44:15.000 | This is the only one we are interested in to produce this last token
00:44:18.920 | So with the KVCache basically we reduce the number of computations that we are doing at every step
00:44:23.640 | Because some of the dot products we have already done in the previous steps, and we only produce one token as output
00:44:29.820 | Which is exactly the one that we need for predicting the next token
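As a sketch of the idea (a toy single-head version, not the Mistral implementation): at every step the new token's key and value are appended to the cache, while only the new token is used as the query, so exactly one output embedding is produced.

```python
# Toy single-head KV cache: keys and values accumulate across steps, the query is
# only the newest token, so each step produces exactly one output embedding.
import torch
import torch.nn.functional as F

class KVCache:
    def __init__(self):
        self.k = None                                  # (cached_len, head_dim)
        self.v = None

    def append(self, k_new, v_new):
        self.k = k_new if self.k is None else torch.cat([self.k, k_new], dim=0)
        self.v = v_new if self.v is None else torch.cat([self.v, v_new], dim=0)
        return self.k, self.v

def attend_one_token(q_new, k_new, v_new, cache):
    """q_new, k_new, v_new: (1, head_dim) projections of the newest token only."""
    k, v = cache.append(k_new, v_new)                  # the cache grows by one entry
    scores = (q_new @ k.T) / k.shape[-1] ** 0.5        # (1, cached_len): one new row of dot products
    weights = F.softmax(scores, dim=-1)                # no mask needed: all cached keys are in the past
    return weights @ v                                 # (1, head_dim): the only output we need
```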
00:44:33.920 | Ok, now let's talk about the rolling buffer cache
00:44:38.200 | So since we are using the sliding window attention with a size of W
00:44:42.520 | And in the examples I showed you before I was using a sliding window size with the size of 3
00:44:48.040 | We don't need to keep all the possible K and V in the cache
00:44:53.120 | But we can limit the K and the V only to W tokens because anyway, we will not be computing
00:45:00.600 | Attention outside of this W window. So, imagine our window is 10 tokens:
00:45:06.000 | We do not keep the previous 1000 tokens because anyway, our attention will only be calculated on the previous 10 tokens
00:45:12.200 | So this is the idea behind the rolling buffer cache. Let's see how it works
00:45:15.600 | Imagine we arrive at the token 8 of inference using the KVCache
00:45:21.360 | If we have a KVCache and we are using the sliding window size of 4 for example
00:45:27.760 | We will see that as query we will use the output of the previous step and as key and values
00:45:34.760 | We will use the entire cache which is made up of 8 tokens
00:45:37.920 | But because of the mask that we are using with the sliding window attention
00:45:43.760 | We are not interested in the computation of these dot products because anyway
00:45:48.240 | They will be masked out because the distance between this token and this token is outside of the sliding window attention
00:45:54.320 | So we are not interested in calculating these dot products because
00:45:58.600 | They will be masked out and secondly
00:46:02.400 | We are not interested in keeping these because anyway these values will be masked
00:46:08.000 | By our mask for the sliding window attention, which basically will result in zeros here
00:46:13.920 | We do not care about producing these first four rows in the value matrix because anyway
00:46:20.960 | They will be multiplied by zeros. So they will not contribute to the output token. So here you have to imagine that
00:46:26.680 | Let me draw
00:46:29.360 | Here you have to imagine that the mask will take care of making this one zero, this one zero, this one zero, this one zero
00:46:37.000 | And this one will be a dot product. This one will be a dot product. This one will be a dot product
00:46:41.560 | This one will be a dot product. So whatever value there is here
00:46:44.560 | Whatever value there is here, here or here will not contribute to the output of this token because anyway
00:46:50.640 | They will be multiplied by zeros here
00:46:52.320 | So we do not need to keep this value also in the V matrix or in the K matrix because anyway
00:46:58.720 | They will not be used by the sliding window attention. So that's why we can limit the size of our K and V
00:47:05.380 | Cache only to W tokens where W is the size of the sliding window attention that we are using
00:47:12.520 | Now let's see how this rolling buffer cache was implemented. So basically rolling buffer cache is a way of
00:47:17.880 | Limiting a cache to a fixed size, in this case W
00:47:24.040 | So imagine our W is only four
00:47:26.040 | Imagine we have a sentence "the cat is on a chair" and we want to use it for our KV cache
00:47:32.720 | At the first inference using the KV cache, we will add the first
00:47:37.840 | The first token to the KV cache, then we will add the second one the third one and the fourth one
00:47:43.600 | But now the KV cache is full. How do we proceed further?
00:47:46.720 | Basically, we keep track of where we added the last item
00:47:50.480 | Using a pointer that we keep track of and when we will arrive at the next token, which is the token A
00:47:56.800 | We basically replace the oldest value here starting from the beginning and we update the value of the right pointer
00:48:03.880 | but now how do we
00:48:05.740 | Go back because now the order of the tokens is not matching the sentence because as you can see now the
00:48:11.200 | Cache contains "a cat is on" but this is not the order in the original sentence in the original sentence
00:48:17.320 | the order should be "cat is on a" so what we do is we do the unrolling or unrotation and
00:48:23.960 | How do we do it? Basically because we kept track of this right pointer
00:48:28.960 | We just need to take all the values after the right pointer and then we put the values from 0 to the right pointer itself
00:48:36.120 | So all the values after the right pointer and then all the values before the right pointer and this is how we unrotate
00:48:42.180 | And this operation is done in the code in the function called unrotate. You can see here
00:48:47.320 | Which basically has these conditions. So if the cache is not full we can just ignore the unfilled items
00:48:54.480 | So if the cache is in this situation, then we take all the values from the 0 up to the right pointer
00:49:00.340 | If the cache is full, then we take the values from 0 up to
00:49:04.640 | The value of the right pointer, and
00:49:08.560 | if the value of the right pointer is already overwriting some value then we need to unrotate and this is done in the
00:49:16.140 | Third condition here
00:49:18.160 | So we take all the values after the pointer and then the value up to the pointer and this is how we
00:49:23.440 | unrotate this buffer cache
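Here is a toy version of the same idea, to make the wrap-around and the unrotation concrete; it mirrors the logic described above but is not the actual unrotate function from the Mistral repository.

```python
# Toy rolling-buffer cache of size W, with the "unrotate" step described above.
class RollingBufferCache:
    def __init__(self, window):
        self.window = window
        self.buffer = [None] * window
        self.count = 0                                  # how many tokens were ever added

    def add(self, item):
        self.buffer[self.count % self.window] = item    # overwrite the oldest slot
        self.count += 1

    def unrotate(self):
        pos = self.count % self.window                  # the "right pointer"
        if self.count < self.window:                    # cache not full: ignore unfilled slots
            return self.buffer[:self.count]
        if pos == 0:                                    # full, and the pointer wrapped exactly to 0
            return self.buffer[:]
        # full and wrapped: entries after the pointer are older than entries before it
        return self.buffer[pos:] + self.buffer[:pos]

cache = RollingBufferCache(window=4)
for tok in ["the", "cat", "is", "on", "a"]:
    cache.add(tok)
print(cache.buffer)      # ['a', 'cat', 'is', 'on']  - physical layout after wrapping around
print(cache.unrotate())  # ['cat', 'is', 'on', 'a']  - original order restored
```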
00:49:26.000 | Okay, let's talk about another concept that is very important, which is chunking and pre-filling
00:49:31.860 | Basically when we generate a text using a language model
00:49:35.120 | We use a prompt and then we use this prompt to generate future tokens when dealing with a KV cache
00:49:40.800 | We need to build up this KV cache
00:49:42.640 | So we need to add the tokens of our prompt to the KV cache, so that we can then exploit this KV cache
00:49:48.560 | To build new, future tokens
00:49:51.280 | Now the prompt is known in advance, right? Because it's the input of our user
00:49:57.040 | It's what you ask ChatGPT, for example, right? Write me a poem, or tell me a joke
00:50:02.320 | This is our prompt. So it's known in advance, so we don't need to generate it
00:50:06.960 | Okay, so what we can do is we can pre-fill the KV cache using the tokens of the prompt
00:50:12.320 | But there are many ways to do it like we were doing before when I was teaching you about the KV cache
00:50:17.600 | We work with one token at a time. So one way to
00:50:20.160 | To add the tokens to the KV cache is to add one token at a time
00:50:25.200 | But this can be very time consuming because imagine you have a very large prompt which happens
00:50:30.080 | With retrieval augmented generation, which we have very big prompts like 5,000 6,000 tokens or even bigger
00:50:36.080 | So this if we add one token at a time
00:50:38.400 | It will mean that we have to take 5,000 or 6,000 forward steps in our network
00:50:43.600 | Which can be very time consuming and also doesn't exploit our GPU very much
00:50:48.400 | The other way is to take all these tokens and feed them all at once to the model
00:50:53.760 | But that may be limited by the size of our gpu because imagine we have 10,000 tokens as our prompt
00:51:00.080 | Then maybe our gpu cannot even hold 10,000 tokens. Maybe it can only hold 4,000 tokens or 2,000 tokens
00:51:06.240 | depending also on the size W of the sliding window attention that we have chosen
00:51:11.120 | The solution in this case is to use chunking. Basically, we divide our prompt into chunks of a fixed size
00:51:18.160 | And this size is equal to w which is the sliding window attention size
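As a tiny sketch (with made-up variable names), splitting the prompt into chunks of size W could look like this:

```python
def split_into_chunks(prompt_tokens: list[int], window_size: int) -> list[list[int]]:
    # Divide the prompt into chunks whose size equals the sliding window size W.
    return [prompt_tokens[i : i + window_size]
            for i in range(0, len(prompt_tokens), window_size)]

# For example, a 10-token prompt with window_size = 4 gives chunks of size 4, 4 and 2.
```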
00:51:24.240 | So imagine we have a very big prompt and we choose a sliding window size of 4 for the calculation of the attention
00:51:29.920 | And imagine that the prompt is this one. So can you tell me?
00:51:33.280 | Can you tell me who is the richest man in history? The way we work is this
00:51:38.880 | Basically, we take our first chunk of the prompt
00:51:43.200 | So because we chose a sliding window size of 4, we also will choose the chunk size to be 4
00:51:48.320 | So we take our first chunk of the prompt,
00:51:51.280 | so "can you tell me", and we compute the
00:51:56.720 | self-attention in the first layer of the model. How do we build the attention mask?
00:52:02.320 | Basically as queries we take all the incoming tokens in this chunk
00:52:08.400 | So you can think of this column as the queries and this column as the keys,
00:52:14.240 | and this is the result of the queries multiplied by the transpose of the keys, plus the mask
00:52:20.320 | So as queries we take the first incoming chunk, and as keys, as I will show you later,
00:52:25.200 | we take the current content of the KV cache (which initially is empty) plus the incoming tokens of the current chunk
00:52:33.200 | And this is made for a very specific reason that I will show you in the next step
00:52:40.080 | So in the next step, basically we take the current chunk, which is the tokens who is the richest
00:52:48.800 | And we aggregate it with the content of the kvcache using the tokens of the previous chunk. So let me go back
00:52:55.520 | At the first step of this prefilling we take the first chunk of the prompt
00:53:00.880 | So can you tell me we calculated the attention mask using as query
00:53:04.480 | The first four tokens and as keys and values the content of the kvcache
00:53:11.120 | Which is empty plus the tokens of the first chunk and then we update the content of the kvcache
00:53:17.680 | Using this the tokens of this chunk after we have computed the attention
00:53:21.840 | So at the next step the kvcache now contains the previous the tokens of the previous chunk
00:53:27.760 | So can you tell me but now the current chunk has become who is the richest?
00:53:32.320 | so as query again, we take the tokens of the current chunk, but as keys and values we take the
00:53:40.000 | The content of the kvcache plus the tokens of the current chunk. Why?
00:53:46.800 | because
00:53:48.080 | As you can see when we were doing token generation when I was teaching you the kvcache
00:53:54.960 | we first take the last output token,
00:54:00.560 | we append it to the K and the V, and we use it as the query for the next iteration
00:54:03.120 | This is not what we do here. Here we first
00:54:06.660 | calculate the attention and then we update the KV cache
00:54:13.280 | And when we build the query, we use only the tokens of the current chunk
00:54:13.280 | And as key and values we take the content of the kvcache
00:54:16.580 | So the content of the previous chunk plus the tokens of the current chunk. Why?
00:54:23.120 | Because imagine if we
00:54:26.240 | didn't use the content of the previous chunk. What would happen is this: we would have an
00:54:32.160 | attention mask that only comprises the tokens of the current chunk. So it would be limited to only this matrix here
00:54:39.760 | Let me draw it. So
00:54:42.640 | Only this matrix here
00:54:44.800 | But if we only use this matrix here, the word "who"
00:54:50.240 | would not be related to the words "me",
00:54:54.480 | "tell" and "you", even though with this sliding window size they should be able to watch each other
00:55:00.640 | So because we want to relate the current chunk to the previous chunk
00:55:05.680 | We basically take as a key and value the content of the kvcache plus the tokens of the current chunk
00:55:12.960 | So that we can build this attention between chunks
00:55:16.320 | Otherwise this attention would not be built and as query we always use the tokens of the current chunk
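Here is a rough, single-head PyTorch sketch of the attention for one pre-fill chunk, where the queries come only from the current chunk and the keys and values are the KV cache content concatenated with the current chunk; the sliding window restriction is omitted for brevity and this is not the actual Mistral implementation.

```python
import math
import torch

def chunk_attention(q_chunk, k_cache, v_cache, k_chunk, v_chunk):
    # q_chunk:           (chunk_len, d)  queries: only the tokens of the current chunk
    # k_cache / v_cache: (cache_len, d)  keys/values already in the KV cache (previous chunk)
    # k_chunk / v_chunk: (chunk_len, d)  keys/values of the current chunk
    k = torch.cat([k_cache, k_chunk], dim=0)   # (cache_len + chunk_len, d)
    v = torch.cat([v_cache, v_chunk], dim=0)

    cache_len, chunk_len, d = k_cache.shape[0], q_chunk.shape[0], q_chunk.shape[1]
    scores = q_chunk @ k.T / math.sqrt(d)      # (chunk_len, cache_len + chunk_len)

    # Causal mask: the i-th query of the chunk may attend to every cached token
    # and to chunk tokens 0..i, but not to future tokens of the chunk.
    q_pos = torch.arange(chunk_len).unsqueeze(1) + cache_len
    k_pos = torch.arange(cache_len + chunk_len).unsqueeze(0)
    scores = scores.masked_fill(k_pos > q_pos, float("-inf"))

    return torch.softmax(scores, dim=-1) @ v   # (chunk_len, d)
```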
00:55:23.120 | Let's review how this mechanism is built in the code
00:55:28.800 | So basically the pre-filling is done by chunks
00:55:32.560 | There is the first chunk and then there are subsequent chunks
00:55:40.320 | And finally there is token generation, after we have pre-filled our KV cache with the prompt
00:55:40.320 | During the first pre-fill, which means that we are doing it for the first chunk of our
00:55:46.080 | prompt
00:55:48.080 | As attention mask, we only consider the size of the incoming tokens in the current chunk
00:56:01.440 | But for any subsequent chunks, so after the first chunk, to build the attention mask,
00:56:01.440 | For the query we just use the size of the incoming chunk
00:56:05.920 | But for the k and v we use the size of the kvcache, which is this one
00:56:10.560 | So cached as you can see here plus the size of the current chunk, which is this s variable you can see here
00:56:17.120 | And for token generation, we do the same system that we did before when I was teaching with the kvcache
00:56:23.360 | So one token at a time, we take it, we append it to the key
00:56:27.600 | We append it to the value and we replace the query with the output token from the previous step
00:56:33.760 | So the last chunk in our case will be the tokens man in history
00:56:38.900 | And what we do is basically we take the current chunk, so man in history, which becomes the query
00:56:45.520 | While the key becomes basically the previous chunk plus the tokens of the current chunk
00:56:51.280 | So the who is the richest plus the tokens of the current chunk, so man in history
00:56:56.000 | And the reason we do it is because otherwise the words in the current chunk would not be able to
00:57:00.320 | be related to the words of the previous chunk, which is necessary. Okay guys, let's talk about sparse
00:57:06.240 | mixture of experts
00:57:08.960 | So mixture of experts is an ensemble technique in which we have multiple expert models,
00:57:14.400 | where each of these models is trained on a subset of the data, such that each model will specialize on a subset of this data
00:57:21.760 | And then when we produce the output of this mixture of experts, we take the output of each of these experts
00:57:28.880 | and we combine it, usually by using a weighted sum or by averaging, to produce one single output
00:57:34.960 | In the case of Mistral, we do not talk about only mixture of experts
00:57:38.960 | But we talk about a sparse mixture of experts because we have many expert models, but we only use some of them
00:57:45.280 | Let me show you. In the case of Mistral, we have eight experts which are present as the feed-forward layer
00:57:52.080 | So after we calculate the self-attention, as you remember, we have this feed-forward network. In the case of Mistral
00:57:58.000 | 8x7B, we have eight feed-forward layers. We have to think of them as being in parallel
00:58:05.040 | And the gate is a function that basically will decide for each token which expert, so which feed-forward
00:58:11.920 | network, should be working with that token and it will choose two
00:58:16.400 | feed-forward networks for each token. It will run the token through these feed-forward networks,
00:58:22.100 | will take their output and will weight it according to the logits this gate produces
00:58:28.020 | to produce a weighted sum, which will become the output of this layer for that particular token
00:58:35.520 | Let me show you with an example
00:58:37.280 | So this is the architecture of Mistral. As you can see, we have the input of this
00:58:42.160 | encoder layer. We first run the self-attention using the sliding window attention and the KV cache, etc, etc
00:58:48.800 | Then we run the normalization and finally we have this gate function here, which is basically just a linear layer that will produce
00:58:56.720 | logits
00:58:58.400 | eight logits, which
00:59:01.040 | will be, let's call them, score values for our experts. The two best performing
00:59:07.220 | experts, so the two with the highest scores, will indicate which experts that token should work with
00:59:13.040 | Then we run each token through its own two best performing experts
00:59:18.560 | Then we take the output of these two experts. We combine it with the weight
00:59:22.720 | What is the weight? Basically the
00:59:26.400 | logits produced by the gate are, suppose, eight values here
00:59:30.080 | Yeah, I drew only four because I don't have space, but imagine you have eight values
00:59:34.080 | Then we take the top two, so 1.5 and 3.4 in this case
00:59:39.600 | These are the two experts through which we will run the token
00:59:42.240 | We take the softmax of the two best performing values. This will be the weight that we'll be using for the weighted sum
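As a tiny numerical sketch of this weighting step (the two top scores 3.4 and 1.5 come from the example above, the other six values are made up just to fill the vector):

```python
import torch

gate_logits = torch.tensor([0.2, 1.5, -0.7, 3.4, 0.1, -1.2, 0.8, 0.3])  # one score per expert
top_values, top_indices = gate_logits.topk(2)   # -> values [3.4, 1.5]
weights = torch.softmax(top_values, dim=-1)     # -> roughly [0.87, 0.13]
# output = weights[0] * experts[top_indices[0]](x) + weights[1] * experts[top_indices[1]](x)
```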
00:59:49.920 | And basically, why do we do it?
00:59:56.880 | Why do we do it?
00:59:59.360 | Because by using a sparse mixture of experts, we can have many expert models
01:00:04.640 | But during inferencing only two out of eight will be activated
01:00:08.900 | So as you remember the feed-forward network is basically two
01:00:13.040 | linear layers. So the linear layer can be thought of as a matrix multiplication of a weight matrix with the input
01:00:20.560 | So if we didn't use a sparse mixture of experts, we would run the token through all the eight experts
01:00:25.840 | Which means that we need to compute eight matrix multiplications
01:00:29.060 | But by using sparse mixture of experts for each token, we are only doing two matrix multiplications
01:00:35.460 | Which makes the inference faster
01:00:37.920 | But at the same time allows us to increase the power of the model and the parameter of the model
01:00:42.240 | Because we are only using some parameters for
01:00:46.960 | a subset of the tokens. So some tokens will use the expert number one, some tokens will be using
01:00:51.520 | the experts number two and three, some tokens will be using the experts number eight and three,
01:00:56.880 | or some others, for example the six and the four, etc, etc
01:01:00.560 | So we are not using all the experts for each token, but only two of them
01:01:04.880 | This allows us to have each expert
01:01:08.720 | Specialized on a subset of tokens. For example, imagine the model has been trained on multiple languages
01:01:15.280 | What could happen is that basically some experts, so some feed-forward networks are specialized on Japanese tokens
01:01:22.320 | Some feed-forward networks are specialized on English tokens or some
01:01:26.320 | It could also happen that some are specialized in verbs, some are specialized in
01:01:30.720 | Nouns, some are specialized in adjectives, etc, etc
01:01:34.960 | So this is why we use a mixture of experts because we want to increase the size of the parameters of our model
01:01:42.000 | So the model becomes more powerful at capturing information
01:01:45.600 | But at the same time we don't sacrifice on performance because we only use a subset of the experts for each token
01:01:53.360 | And this is the implementation as done in the code
01:01:57.120 | So as you can see in the case of Mistral 7b, we have as feed-forward just a feed-forward neural network
01:02:03.520 | Which is two linear layers
01:02:05.520 | In the case of Mistral 8x7b, it's not only one feed-forward network
01:02:10.160 | but it's eight feed-forward networks. So this, as you can see,
01:02:13.520 | is an array of eight feed-forward networks with a gating function, which is just a linear layer
01:02:19.840 | which converts from the embedding size to eight, which is the number of experts
01:02:24.820 | So for each embedding, so for each token, it produces logits which indicate through which
01:02:33.520 | experts this token should run, and it will run through the top two experts,
01:02:39.920 | so the two experts with the top logits score
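Here is a minimal sketch of such a sparse mixture-of-experts block; the module names are hypothetical, each expert is a plain two-layer feed-forward network, and the real Mistral 8x7B implementation differs in its details.

```python
import torch
import torch.nn as nn

class SparseMoeBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, n_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, dim))
             for _ in range(n_experts)]
        )
        self.gate = nn.Linear(dim, n_experts, bias=False)  # one logit per expert
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_tokens, dim)
        logits = self.gate(x)                                   # (n_tokens, n_experts)
        top_logits, top_idx = logits.topk(self.top_k, dim=-1)   # pick the 2 best experts per token
        weights = torch.softmax(top_logits, dim=-1)             # softmax only over the selected logits
        out = torch.zeros_like(x)
        for token in range(x.shape[0]):
            for k in range(self.top_k):
                expert = self.experts[int(top_idx[token, k])]
                out[token] += weights[token, k] * expert(x[token])
        return out
```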
01:02:43.520 | Okay, why do we apply the softmax after selecting the top-k experts? So as I showed you,
01:02:51.600 | Here we have the gating function that produces some logits. We select the top two
01:02:57.920 | logits to understand which expert we should run through our
01:03:02.800 | Token and then we take the score of the best two performing
01:03:07.060 | experts and we
01:03:09.820 | Take the softmax of them to create the weights
01:03:12.960 | That we will use to create the weighted sum
01:03:15.600 | But why do we take the softmax of the two best performing instead of taking the softmax of everyone? Well, the first problem is that
01:03:23.280 | if we take the softmax of all of the logits, then the two best performing may not sum up to one,
01:03:32.640 | which is a condition that we need in case we want to train multiple models and compare them, because I'm pretty sure that the guys
01:03:38.640 | At Mistral did not only train one model. Maybe they trained multiple models with multiple hyper parameters
01:03:43.760 | Maybe they tried with a mixture of four experts, but also with three experts or two experts
01:03:48.960 | Then they choose the best one
01:03:50.640 | So if you want to compare models, you want, when performing the weighted sum,
01:03:56.560 | the sum of the weights to always be one. Otherwise the output range may change from model to model, and usually it's not a good idea
01:04:03.200 | To have the range of the output to change from one model to the next
01:04:07.520 | So to keep the range of the output stable, they apply the softmax after they have selected how many
01:04:14.000 | Experts they want to work with and choosing the logits of the best two performing
01:04:18.640 | experts
01:04:23.680 | The next thing we are talking about is model sharding which is also implemented in the code of the Mistral model
01:04:30.640 | So let's talk about it
01:04:32.720 | When we have a model that is too big to fit in a single gpu
01:04:36.560 | We can divide the model into groups of layers and place each group of layers in a single gpu
01:04:42.800 | For example in the case of Mistral, we have 32 layers of encoders
01:04:47.280 | You can see here one after another I didn't do all 32 of them
01:04:51.520 | You just think that this is layer from 1 to 8. This is from 9 to 16 from 17 to 24
01:04:56.800 | From 25 to 32 and we put each group of layers in a different gpu. So we have four gpus
01:05:05.280 | The way we inference a model like this is as follows
01:05:08.320 | So we have our input we convert it into embeddings and we run it through the first eight layers in the first gpu
01:05:14.320 | The first gpu will produce an output which will be the output of the eighth layer
01:05:19.600 | We transfer this output to the second gpu and we use it as input for the ninth layer
01:05:24.880 | Then we run this input through all the layers one after another until it
01:05:29.760 | arrives at the layer number 16, which will produce an output. We take this output, we move it to the next GPU,
01:05:36.000 | So it will become the input of the layer number 17
01:05:39.540 | And then we run iteratively to all the layers until the layer number 24, which will produce an output
01:05:45.280 | We move it to the next gpu. We run it through iteratively until the layer number 32
01:05:49.840 | Then we take the last linear layer and then the softmax to produce the output of the model
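A naive sketch of this kind of layer-wise sharding, assuming four GPUs are available and `layers` is a hypothetical stand-in for the 32 encoder layers:

```python
import torch
import torch.nn as nn

class ShardedModel(nn.Module):
    def __init__(self, layers: list[nn.Module], devices: list[str]):
        super().__init__()
        assert len(layers) % len(devices) == 0
        per_gpu = len(layers) // len(devices)     # e.g. 32 layers / 4 GPUs = 8 layers each
        self.devices = devices
        self.groups = nn.ModuleList()
        for i, device in enumerate(devices):
            group = nn.Sequential(*layers[i * per_gpu : (i + 1) * per_gpu]).to(device)
            self.groups.append(group)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for group, device in zip(self.groups, self.devices):
            x = group(x.to(device))   # move the activations to the next GPU, then run its layers
        return x

# model = ShardedModel(layers, ["cuda:0", "cuda:1", "cuda:2", "cuda:3"])
```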
01:05:55.040 | however
01:05:57.040 | You can notice that this method is not very efficient because at any time only one gpu is working
01:06:02.880 | A better approach which is not implemented in the code of Mistral, but they reference it in the paper
01:06:08.720 | so I will be talking about it: it is pipeline parallelism. Let's see how it works
01:06:15.040 | For this pipeline parallelism, I will be talking about the algorithm that was introduced in this paper, GPipe
01:06:21.220 | Basically, it works as follows first. Let me introduce you the problem
01:06:25.600 | This is actually usually used when we are training a model, not when we are inferencing,
01:06:30.160 | But it can also be applied to the inference
01:06:32.340 | Imagine we want to train a sharded model, so a model that is split into multiple
01:06:39.040 | groups of layers, where each group of layers is present on a different GPU
01:06:44.160 | Imagine we have four gpus each one with its own group of layers
01:06:48.000 | Imagine we want to train this model. So we run our input to the first gpu
01:06:52.880 | So we run the forward step to the first gpu. We take this output and we feed it to the next gpu
01:06:57.680 | So then we run forward from there
01:07:00.080 | We take the output and we run it through the next gpu the gpu number three
01:07:03.600 | We take the output we run it to the next gpu the gpu number four
01:07:06.800 | Now we have the output of the model. We compute the loss and then we can run backpropagation. The backpropagation
01:07:13.200 | is basically just the opposite: we go from the last GPU to the first GPU. So we run backpropagation on the fourth GPU
01:07:19.760 | Then we have calculated the gradients at the fourth gpu and we use them to calculate the previous gradients at the third gpu
01:07:27.200 | And then we take these gradients and we use them to calculate the previous gradients and then we
01:07:31.760 | Take these gradients and we use to compute the previous gradients
01:07:35.520 | So the forward step goes from the input to the loss the backward step goes from the loss to the input
01:07:43.360 | and all the parameters, which are also known as the leaf nodes in the computational graph
01:07:47.200 | However, also in this case, you can see that at each step
01:07:52.080 | we are only utilizing one single GPU, and all the other GPUs are
01:07:58.800 | not working. They are idle
01:08:01.520 | A better way is to use pipeline parallelism
01:08:04.660 | So imagine that the previous step of training was done using a very big batch
01:08:09.840 | Suppose this batch is made up of eight items
01:08:12.640 | What we do with pipeline parallelism is we take this batch and we split it into micro batch
01:08:18.480 | So instead of eight items, we create micro-batches: four micro-batches of two items each
01:08:24.720 | What we do is
01:08:27.120 | We run the first micro batch in the first gpu
01:08:30.880 | This will produce the output for the first micro batch and we can feed it to the next gpu
01:08:36.160 | But now at time step one we realize that GPU one is free,
01:08:39.840 | so it can already start working on the second micro-batch
01:08:43.120 | Meanwhile, the second GPU is working on the first micro-batch, and when it finishes it can send it to the next GPU, and
01:08:50.800 | meanwhile, we realize that the second GPU is now free
01:08:54.240 | So if GPU one has finished, we can take the output of GPU one and transfer it to GPU two,
01:09:00.880 | and GPU one will be free so it can work on the third micro-batch, as you can see here
01:09:05.920 | Then after the third GPU has finished, we take the output of the third GPU
01:09:11.600 | and send it to the fourth GPU, but we realize that the third GPU is now free
01:09:15.520 | So if the previous GPUs have finished, we can transfer the second micro-batch to the third GPU,
01:09:20.160 | the third micro-batch to the second GPU, and the first GPU, which will be free, can start working on a new micro-batch,
01:09:26.720 | Which is the fourth micro batch and basically we do this
01:09:30.000 | job of time shifting the micro batches
01:09:32.900 | And this will result in a better utilization of the GPUs, because now,
01:09:38.800 | at this time step for example, all four GPUs are working, and also at this time step here at the backward step
01:09:44.720 | And for each micro batch we calculate the gradient, but we do not update the parameters
01:09:51.280 | We do what is called gradient accumulation
01:09:53.360 | Which basically means that we calculate the gradient for each micro batch and we keep summing it to the existing gradients
01:09:59.280 | but we do not update the parameters of the model. After all the micro-batches have finished processing the forward and the backward,
01:10:05.440 | We update the parameters of the model
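The gradient-accumulation part of this can be sketched as follows; the sketch ignores the actual pipelined scheduling across GPUs (which a GPipe-style framework handles), and `model`, `optimizer` and `loss_fn` are placeholders.

```python
def train_step(model, optimizer, loss_fn, batch_x, batch_y, n_micro_batches: int = 4):
    optimizer.zero_grad()
    micro_xs = batch_x.chunk(n_micro_batches)
    micro_ys = batch_y.chunk(n_micro_batches)
    for mx, my in zip(micro_xs, micro_ys):
        loss = loss_fn(model(mx), my) / n_micro_batches  # scale so the sum matches the full batch
        loss.backward()            # gradients accumulate in .grad, no update yet
    optimizer.step()               # single parameter update after all micro-batches
```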
01:10:07.520 | The gradient accumulation is a technique that I have introduced in my previous video on
01:10:12.000 | Distributed training so if you want to understand how it works
01:10:15.440 | I refer you to my previous video on distributed training in which I explain also the math behind gradient accumulation and how it works
01:10:22.080 | But basically this is the solution with pipeline parallelism
01:10:25.760 | So we can actually divide our batch into micro batches and this can also work with inferencing because when we inference
01:10:32.880 | We just don't have this backward
01:10:35.120 | Step here, right? So we just delete this second half of the table
01:10:38.880 | But we can still take our big batch at the beginning. We split it into micro batches and we time shift them
01:10:45.920 | according to the availability of the gpu
01:10:48.880 | And this pipeline parallelism basically introduces still some
01:10:55.920 | Time steps in which not all gpus are working and these are called bubbles
01:10:59.920 | To avoid these big bubbles here, what we can do is
01:11:04.400 | use a bigger initial batch size, so we have more micro-batches
01:11:09.940 | Okay guys now let's go to the last part of this video
01:11:14.720 | I know that
01:11:16.800 | The Mistral code is much more complicated to understand compared to the Lama code, and I will show you why, but I will also help
01:11:23.280 | you understand the most complex topic in the code, which is the xformers library,
01:11:27.760 | Which is a trick they use to improve the inference performance and it's actually a very advanced technique
01:11:34.000 | And I want to give you a glimpse into how it works
01:11:37.360 | So basically imagine you are running an ai company and you are providing llm inference service
01:11:45.440 | So, for example, you provide an API and you have customers that send their prompts to your API
01:11:52.960 | And then want to run inference through your large language models
01:11:56.080 | Each prompt of course may have different length because each customer may be using the
01:12:01.600 | large language model for different purposes
01:12:04.160 | For
01:12:05.600 | simplicity, suppose that each word is a token
01:12:08.000 | So suppose you have three customers: the first customer says "write a poem",
01:12:11.920 | The second customer says write a historical novel and the third customer says tell me a funny joke
01:12:18.880 | of course you could process all these prompts one by one, but that would not be very efficient because
01:12:24.080 | the other two customers would be waiting for the first customer to finish, and when you have a lot of customers
01:12:29.700 | That's not good. And secondly, you may not be fully utilizing the memory of your gpu
01:12:34.720 | So the best thing that you can do is to do batching: you take all these prompts and you create one big batch
01:12:40.800 | But the problem is that the prompt have different lengths
01:12:44.880 | So the first prompt is made up of three tokens the second prompt of four tokens and the third prompt of five tokens
01:12:51.600 | One solution is to add padding to these tokens. So basically we create a batch in which we append
01:12:59.040 | Padding tokens to the input sequence until they all reach the same size
01:13:04.880 | Then we can run these sequences this batch through our large language model, which could be for example Lama or Mistral
01:13:14.800 | As we saw before when we have a input sequence of n tokens
01:13:19.120 | The attention mechanism produces an output sequence of n tokens
01:13:22.800 | And we usually take the embedding of the last token
01:13:26.640 | Send it to the linear layer then the softmax to understand what is the next token from our vocabulary
01:13:34.400 | In the first prompt we see that we have added two padding tokens
01:13:38.000 | So we cannot use the embedding corresponding to the last tokens because they correspond to the padding token
01:13:44.160 | What we should do is we should take the embedding corresponding to the last non-padding token
01:13:49.200 | And then send it to the linear layer and then to the softmax to understand what is the next token
01:13:54.960 | And in the case of the second prompt we should be using the fourth token not the last one
01:13:59.760 | Only in the last prompt can we use the last token, because it's a non-padding token
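A small sketch of picking the right embedding for each prompt in a right-padded batch; the token ids, the padding id and the hidden size are made up for illustration.

```python
import torch

PAD_ID = 0
# Batch of 3 right-padded prompts (token ids are made up), shape (3, 5).
tokens = torch.tensor([[11, 12, 13, PAD_ID, PAD_ID],   # "write a poem"
                       [11, 12, 21, 22, PAD_ID],       # "write a historical novel"
                       [31, 32, 12, 33, 34]])          # "tell me a funny joke"

# hidden: output of the model, shape (batch, seq_len, dim).
hidden = torch.randn(3, 5, 8)

# Index of the last non-padding token for each prompt: 2, 3 and 4 here.
last_non_pad = (tokens != PAD_ID).sum(dim=-1) - 1
# Take the embedding at that position for each prompt, shape (batch, dim).
last_hidden = hidden[torch.arange(tokens.shape[0]), last_non_pad]
```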
01:14:05.600 | Now we have done this
01:14:09.600 | And how do we actually create an attention mask to run it?
01:14:13.200 | We basically just create an attention mask that is causal that will make each token only
01:14:19.440 | Visualize the previous tokens. So each token will be able to relate to previous tokens, but not to future tokens
01:14:26.160 | And this mask here will work fine for all the three scenarios. You can see here and I will show you later how
01:14:34.880 | We cannot use a different mask for each prompt because all the prompts are of the same length
01:14:40.160 | So all the masks must be 5x5 because we cannot use a 3x3 mask for this prompt a 4x4
01:14:46.960 | Mask for this prompt and the 5x5 for this prompt because the input sequence is 5
01:14:53.200 | So we must use a 5x5 mask and we have to use a 5x5 mask that is causal
01:15:00.400 | And we can also mask out more, for example:
01:15:04.080 | imagine the sliding window size is 4, then we can mask out this value here also, because we don't want
01:15:10.000 | this token here to watch tokens that are at a distance of more than 4, for example
01:15:15.760 | So the problem here is that we are calculating a lot of dot products, especially for the first and the second prompt
01:15:23.360 | That will not be used. Let me show you why
01:15:27.760 | When we apply this mask, so the 5x5 mask you can see here to this input sequence here
01:15:34.640 | which, I want to remind you, is a batch,
01:15:36.640 | Which will produce the following attention mask
01:15:40.720 | in which all these values will be masked out, because they are minus infinity, due to the causality of the mask
01:15:48.720 | We cannot mask
01:15:52.320 | this value here because it is needed for the last prompt, for example,
01:15:56.160 | And we also cannot mask this value here, which is needed for the last prompt
01:16:00.560 | But for the first and the second prompt we are doing a lot of dot products
01:16:04.240 | for example
01:16:04.800 | These ones between padding tokens and other tokens that we will not be using
01:16:08.400 | Because I want to remind you that for the first prompt, as the output of the model,
01:16:13.200 | we will be using the output at the third token,
01:16:16.880 | for the second prompt the output at the fourth token, and only for the last prompt will we be checking the last output
01:16:23.760 | of the self-attention, but for the first two prompts
01:16:27.520 | we will not even be checking the last token output from the self-attention, because it corresponds to a padding token
01:16:33.360 | So is there a way to avoid these padding tokens?
01:16:37.680 | Being introduced in our calculation and calculating all these dot products which will result in output tokens that we will not even use
01:16:46.320 | Well, there is a better solution and the solution is this
01:16:49.280 | The solution is to combine all the tokens of all the prompts into one big sequence
01:16:55.920 | Consecutively and we also keep track of what is the actual size of each prompt
01:17:02.880 | So we know that the prompts are coming from our API, because we are running an AI company and we have this API,
01:17:09.200 | so we know that the first customer has a prompt of size three tokens. The second one has
01:17:15.680 | Four tokens and the third one has five tokens
01:17:18.240 | So we can keep track of these sizes in an array, for example
01:17:22.240 | And then we build this sequence which is a concatenation of all the prompts that we receive
01:17:27.280 | We take this mega sequence. We run it through our LLM model. So it could be Mistral or it could be Lama
01:17:35.120 | As I told you before,
01:17:38.000 | an input sequence of n tokens in a transformer will result in n output tokens
01:17:45.200 | So here we have 3 plus 4, so 7, plus 5, so 12
01:17:50.800 | tokens as input, and it will produce 12 tokens as output
01:17:54.320 | To understand what is the next token for each prompt, we need to check
01:18:00.240 | the embedding
01:18:03.200 | corresponding to the token number 3 for the first
01:18:05.840 | Prompt to the token number 7 for the second prompt and the last token for the third prompt
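A minimal sketch of picking those positions when the prompts are concatenated into one big sequence; the prompt lengths and hidden size are just illustrative.

```python
import torch

seqlens = [3, 4, 5]                    # "write a poem", "write a historical novel", "tell me a funny joke"
hidden = torch.randn(sum(seqlens), 8)  # model output: one embedding per token of the big sequence

# The last token of each prompt sits at the cumulative sums minus one: positions 2, 6 and 11.
last_positions = torch.cumsum(torch.tensor(seqlens), dim=0) - 1
last_hidden = hidden[last_positions]   # (n_prompts, dim) -> fed to the linear layer + softmax
```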
01:18:12.560 | So we take all these embeddings we run them through the linear layer
01:18:16.480 | Then we apply the softmax and then we understand what is the next
01:18:20.480 | Token from our vocabulary, but you may be wondering
01:18:24.480 | How do we even produce an attention mask that can work with multiple prompts that are combined into one sequence?
01:18:31.920 | such that the tokens of one prompt should not attend
01:18:37.840 | to the tokens of another prompt, but only to the tokens of the same prompt, right?
01:18:42.640 | Well, the xformers library allows us to do that using a method called block diagonal causal mask,
01:18:51.120 | Which is also used in the source code of Mistral. So I want to show you how it works
01:18:55.680 | Basically, in xformers
01:18:58.400 | this method called block diagonal causal mask will produce a mask like this. It will
01:19:05.200 | group, basically, all the prompts into groups,
01:19:08.880 | Such that each token only can attend to the tokens in the same group here. We have three prompts
01:19:15.920 | So the token "poem", for example, can only attend to the tokens of the same
01:19:20.880 | prompt. The token "novel", for example, cannot be related to the token "poem", so it will put minus infinity here,
01:19:28.480 | but all the tokens of the same prompt will be able to be attended to by
01:19:34.000 | the token "novel",
01:19:36.000 | while the tokens in the last prompt will only be able to attend to
01:19:40.560 | The other tokens in the same prompt and this is a special mask built by using the Xformers library
01:19:47.200 | Let me show you how it works in the code
01:19:50.720 | Okay, I want to show you actually how it works. So in the Mistral source code, they are using this
01:19:57.760 | Library called Xformers
01:20:02.240 | Xformers library
01:20:03.920 | allows us to compute very complex attention masks and also to calculate the attention in a very efficient way,
01:20:10.560 | Using the memory efficient attention calculation, which I will not show in this video. Maybe I will make a future video about it
01:20:16.480 | But basically what they do in the Mistral source code
01:20:19.920 | If you have multiple prompts, they will create one big sequence and then keep track of the number of tokens of each prompt
01:20:27.280 | And then they use these methods
01:20:30.400 | made available by the xformers library to build these complex attention masks that
01:20:34.880 | Keep track of the different size of the KVCache because each prompt may have a KVCache that is different from another prompt
01:20:43.200 | Because imagine you have a prompt with 5,000 tokens and one prompt with only 10 tokens
01:20:47.680 | Of course, you will have a KVCache that is 5,000 tokens in one case and 10 in another case
01:20:53.280 | So the mask, attention mask that we build should take care of this
01:20:58.400 | And the second thing is that each group of tokens should only be able to relate
01:21:03.040 | To the tokens of the same group not to other groups. So not of tokens from another prompt
01:21:09.760 | And this is done with the block diagonal causal mask
01:21:14.720 | So basically we tell it, okay,
01:21:16.880 | The first prompt is made up of seven tokens
01:21:19.360 | The second prompt is made up of five tokens and the third prompt is made up of six tokens
01:21:24.640 | And we are also using a sliding window attention with a sliding window size of three and basically this will create the complex mask
01:21:31.280 | That we can see here
01:21:33.040 | This is the first group of tokens from 0 to 6 is the first prompt from 7 to 11 is the second prompt
01:21:40.240 | and from 12 to
01:21:42.560 | Let me check 17 is the third prompt and as you can see it also takes into consideration the sliding window size
01:21:49.680 | So each token can only watch at most two previous tokens,
01:21:54.400 | so the tokens contained in the sliding window of size three
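In code, building such a mask with xformers looks roughly like this; the prompt lengths and window size follow the example above, and the import path assumes a recent xformers version, so treat it as a sketch rather than a guaranteed API.

```python
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

seqlens = [7, 5, 6]       # three prompts of 7, 5 and 6 tokens, concatenated into one sequence
sliding_window = 3

# Causal mask that keeps each prompt in its own block, then restricts
# each token to the previous tokens inside the sliding window.
mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(sliding_window)

# This mask can then be passed to xformers' memory-efficient attention as the attention bias.
```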
01:21:58.880 | The second one they use is the block diagonal mask. Okay, this one is used for the first
01:22:05.680 | chunk during the pre-filling,
01:22:08.320 | while this one is used for subsequent chunks in the pre-filling
01:22:11.940 | And basically it also takes into account that during the first pre-filling we don't have the KV cache, because it's initially empty,
01:22:18.400 | But during the subsequent steps, it's not empty anymore
01:22:21.520 | So we need to take into consideration also the different size of the KV cache
01:22:24.880 | So for example, the first prompt may have a KV cache of size 10 because the prompt is very short,
01:22:31.120 | But the second prompt may be very big suppose 5,000 tokens. So it may have a KV cache of size 5,000
01:22:36.740 | So it takes into consideration also the size of the KV cache
01:22:40.640 | And it will produce a mask that takes into consideration also the size of the KV cache
01:22:47.040 | The last method they use is this one block diagonal causal with offset padded keys mask
01:22:52.320 | Because each prompt may have a different size for the KV cache
01:22:57.280 | But the KV cache size is fixed: it's a tensor that is of fixed size W,
01:23:05.840 | and only some tokens may actually be filled in this KV cache
01:23:11.040 | So maybe the KV cache size is, let's say, 10,
01:23:14.240 | but because the first prompt is very short, only three tokens are actually present in the KV cache
01:23:19.760 | But when we pass the KV cache to the calculation of the attention
01:23:25.120 | We pass all the tensor which is all the 10 items
01:23:28.960 | So we need a way to tell the mask
01:23:31.360 | that it should only use the first three items from the KV cache, and not all the KV cache,
01:23:37.440 | not all the tensor, and this is done with the block diagonal causal with offset padded keys mask
01:23:42.560 | So this method here has a very long name, very complicated, but this is why they use it,
01:23:46.400 | And it will produce a mask like this
01:23:52.560 | So it takes into consideration the actual size of the KV cache:
01:23:57.760 | even if all the KV caches have the same size, because they are fixed-size tensors,
01:24:03.120 | it tells you how many items it should actually use from each cache
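And a rough sketch of how this last mask can be built; the numbers are illustrative, and the keyword arguments reflect how I recall the xformers API being called in the Mistral code, so they may differ slightly between versions.

```python
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask

sliding_window = 10        # each prompt's KV cache is a fixed-size tensor of W = 10 slots
q_seqlen = [1, 1, 1]       # during generation, one new token per prompt
kv_seqlen = [3, 10, 7]     # how many slots are actually filled in each prompt's cache

mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=q_seqlen,
    kv_padding=sliding_window,  # every cache is padded up to the same fixed size...
    kv_seqlen=kv_seqlen,        # ...but only this many entries per cache should be attended to
)
```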
01:24:03.120 | Okay guys it has been a very demanding video I have to say
01:24:12.400 | I had to record it more than once
01:24:15.520 | I actually had to cut some parts because I even I got confused sometimes
01:24:20.160 | It's very complicated topics. It's a lot of things that you have to grasp
01:24:25.760 | But I hope that it will make your life easier when you want to understand the Mistral code
01:24:31.760 | I actually am also putting online
01:24:34.400 | My notes the one that you have seen so the two notebooks that I have shown you
01:24:39.120 | Plus also the code annotated by me on the Mistral source code
01:24:42.880 | Now, the Mistral source code, I actually never ran it, because my computer is not very powerful,
01:24:48.960 | so I never ran the actual model on my computer
01:24:51.920 | What I did to study the model was to run some random tensors through it: I basically created a model with
01:24:58.960 | randomly initialized
01:25:01.200 | Weights but with the less number of layers so it could fit in my GPU and then I just run some random tensors
01:25:07.280 | To study all the shapes of the tensor and all the information passing
01:25:11.040 | So, I don't know if the code works, but I hope it will work. I mean, I didn't touch the logic, I just added some comments
01:25:17.200 | Anyway, you can use the code commented by me
01:25:21.360 | as a learning tool to complement the official code of Mistral, so that you can understand
01:25:27.380 | more about the inner workings of this great model. I actually really enjoyed studying it. I really enjoyed
01:25:33.680 | Studying the code and I learned a lot of stuff, you know
01:25:37.760 | I think it's very very good when you are doing something that is very complicated
01:25:41.860 | because it teaches you a lot, because if something is simple, then you don't learn much at the end of the day
01:25:46.800 | Anyway, guys, thank you for watching my video. I hope you also enjoyed this journey with me,
01:25:51.840 | Even if it was very complicated
01:25:53.600 | I hope that you liked this video and you will subscribe to my channel if you didn't please do it
01:25:58.640 | And the best way to support me guys is to share this video with all the people, you know
01:26:03.200 | So share it on social media share it on linkedin on twitter, etc
01:26:07.040 | because this is the best way you can help me grow my channel, and
01:26:12.160 | Please let me know if there is something that you don't understand. I am always available to help
01:26:17.280 | And connect with me on linkedin. Bye. Bye