
LLaMA explained: KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU


Chapters

0:00 Introduction
2:20 Transformer vs LLaMA
5:20 LLaMA 1
6:22 LLaMA 2
6:59 Input Embeddings
8:52 Normalization & RMSNorm
24:31 Rotary Positional Embeddings
37:19 Review of Self-Attention
40:22 KV Cache
54:00 Grouped Multi-Query Attention
64:07 SwiGLU Activation function

Whisper Transcript

00:00:00.000 | Hello guys! Welcome to my new video about LLaMA. In this video we will be seeing what
00:00:05.400 | is LLaMA, how it is made, how it is structurally different from the transformer and we will
00:00:12.560 | be building each block that makes up LLaMA. So I will not only explain to you concept-wise
00:00:18.400 | what is each block doing but we will also explore it from the mathematical point of
00:00:22.480 | view and also from the coding point of view, so that we can unify theory with practice.
00:00:29.440 | I can guarantee that if you watch this video you will have a deep understanding of what
00:00:35.380 | makes Lama the model it is. So you will not only understand how the blocks interact with
00:00:42.960 | each other but how they function and why we needed these blocks in the first place. In
00:00:49.440 | this video we will be reviewing a lot of topics, so we will start from the architectural differences
00:00:55.360 | between the vanilla transformer and the Lama model. We will be watching what is the new
00:00:59.600 | normalization, the RMS normalization, rotary positional embedding, KV cache, multi-query
00:01:04.240 | attention, grouped multi-query attention, the SwiGLU activation function for the feed-forward
00:01:08.400 | layer. But of course I take for granted that you have some background knowledge. First
00:01:14.280 | of all I highly recommend that you watch my previous video about the transformer because
00:01:18.440 | you need to know how the transformer works. And in my previous video I also explored the
00:01:23.040 | concept of training and inferencing a transformer model. It's about 45 minutes and I think it's
00:01:29.200 | worth a watch because it will really give you a deep understanding of the transformer.
00:01:33.560 | After you have that knowledge you can watch this video. Anyway, for those who have already
00:01:38.120 | watched the video but forgot some things, I will review most of the concepts as we proceed
00:01:44.520 | through the topics. I also take for granted that you have some basic linear algebra knowledge,
00:01:51.240 | so matrix multiplication, dot product, basic stuff anyway. And also, because we will be
00:01:56.600 | using the rotary positional embeddings, some knowledge about the complex numbers, even
00:02:00.440 | if it's not fundamental. So if you don't remember the complex numbers or how they work
00:02:06.040 | or the Euler formula, it doesn't matter. You will understand the concept, not the math.
00:02:10.680 | It's not really fundamental. Sometimes I will be reviewing topics that maybe you are already
00:02:16.040 | familiar with, so feel free to skip those parts. Let's start our journey by reviewing
00:02:21.800 | the architectural differences between the vanilla transformer and Lama. This picture
00:02:27.400 | was built by me on the right side because I couldn't find the architectural picture on
00:02:33.400 | the paper. So let's review the differences. As you remember, in the vanilla transformer
00:02:39.320 | we have an encoder part and a decoder part. And let me highlight it. So this is the
00:02:46.040 | encoder and the right side here is the decoder. While in LLaMA, we only have an encoder. First
00:02:53.720 | of all, because LLaMA is a large language model, it has been trained on the next token
00:02:59.240 | prediction task. So basically, we only need the self-attention to predict the next token.
00:03:04.120 | And we will see all these concepts. So we will see what is the next prediction task,
00:03:07.400 | how it works, and how this new self-attention works. The second difference that we can see
00:03:13.080 | from these pictures is that we have here, at the beginning, we have the embedding and
00:03:19.080 | also we had the embedding here on the original transformer. But right after the embedding,
00:03:23.640 | we don't have the positional encoding, but we have this RMS norm. And actually, all the
00:03:28.600 | norms have been moved before the blocks. So before we had the multi-head attention,
00:03:34.280 | and then we had the Add & Norm, which is this plus sign here. So it's the addition
00:03:40.120 | of a skip connection and the output of the multi-head attention, followed by the normalization.
00:03:44.840 | And we also have this normalization here, here, here. So after every block. But here in Lama,
00:03:50.360 | we have it before every block. And we will review what is the normalization and why
00:03:55.080 | it works like the way it is. Right after the normalization, we have this query,
00:04:01.400 | key, and values input for the self-attention. One thing we can notice is that the positional
00:04:07.960 | encodings are not anymore the positional encodings of the transformer, but they have become the
00:04:13.000 | rotary positional encodings, and they are only applied to the query and the keys, but not the
00:04:17.960 | values. And we will see also why. Another thing is the self-attention is now the self-attention
00:04:25.320 | with KV cache. We will see what is the KV cache and how it works. And also we have this grouped
00:04:31.560 | multi-query attention. Another thing that changed is this feed-forward layer. In the original
00:04:38.440 | feed-forward layer of the vanilla transformer, we had the relu activation function for the
00:04:45.640 | feed-forward block. But in LLaMA, we are using the SwiGLU function, and we will see why.
00:04:51.160 | This nx means that this block here in the dashed lines is repeated n times one after another,
00:04:59.800 | such that the output of the last layer is then fed to this rms norm, then to the linear layer,
00:05:05.400 | and then to the softmax. And we will build each of these blocks from the bottom. So I will show
00:05:11.960 | you exactly what these blocks do, how they work, how they interact with each other, what is the
00:05:16.680 | math behind, what is the problem they were trying to solve. So we will have a deep knowledge of
00:05:21.560 | these models. Let's start our journey with reviewing the models introduced by Lama.
00:05:27.880 | So LLaMA 1 came out in February 2023, and it came in four sizes. One model was
00:05:37.160 | with 6.7 billion parameters, then 13, 32 and 65 billion. And then we have these numbers. What do they mean?
00:05:43.880 | The dimension here indicates the size of the embedding vector. So as you can see here,
00:05:50.760 | we have these input embeddings that we will review later. This is basically, they convert
00:05:55.960 | each token into a vector of size indicated by this dimension. Then we have the number of heads.
00:06:01.560 | So how many heads the attention has, the number of layers. If you remember from the original
00:06:08.680 | transformer, the dimension was 512. The number of heads was eight. The number of layers, I think,
00:06:14.600 | was six. And then we have the number of tokens each model was trained upon. So 1 trillion and
00:06:22.280 | 1.4 trillion. With Lama2, most of the numbers have doubled. So the context length is basically
00:06:29.160 | the sequence length. So what is the longest sequence the model can be fed? And then the
00:06:37.400 | number of tokens upon which the model has been trained is also doubled. So from 1 to 2 trillion
00:06:42.120 | for each size of the model, while the parameters more or less remain the same. Then we have this
00:06:47.560 | column here, GQA, that indicates that these two sizes of the model, so the 34 billion and 70
00:06:54.040 | billion, they use the grouped query attention. And we will see how it works. Let's start by
00:07:00.200 | reviewing what is the embeddings layer here. And for that, I will use the slides from my
00:07:05.480 | previous video. If you remember my previous video, we introduced the embedding like this.
00:07:10.040 | So we have a sentence that is made of six words. What we do is we tokenize the sentence,
00:07:15.880 | so it converts into tokens. The tokenization usually is done not by space, but by the BPE
00:07:22.440 | tokenizer. So actually, each word will be split into subwords also. But for clarity, for simplicity,
00:07:28.760 | we just tokenize our sentence by using the whitespace as separator. So each token is
00:07:35.880 | separated by whitespace from other tokens. And each token is then mapped into its position
00:07:42.680 | into the vocabulary. So the vocabulary is the list of the words that our model recognizes.
00:07:51.480 | They don't have to be words, of course. They could be anything. They are just tokens.
00:07:56.920 | So each token occupies a position in this vocabulary, and the input IDs indicate the number
00:08:02.520 | occupied by each token in the vocabulary. Then we map each input ID into a vector of size 512
00:08:11.320 | in the original transformer. But in Lama, it becomes 4096. And these embeddings are vectors
00:08:19.880 | that are learnable. So they are parameters for the model. And while the model will be trained,
00:08:24.920 | this embedding will change in such a way that they will capture the meaning of the word they
00:08:29.720 | are mapping. So we hope that, for example, the words "cat" and "dog" will have similar embeddings,
00:08:37.320 | because they map to similar concepts, or at least they are in the same semantic group. And also,
00:08:43.880 | the word "house" and "building," they will be very close to each other if we check the two vectors.
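To make the flow from tokens to input IDs to embedding vectors concrete, here is a minimal PyTorch sketch; the toy vocabulary and the embedding size of 8 are made up for illustration (LLaMA's embedding size is 4096).

```python
import torch
import torch.nn as nn

# Toy vocabulary: token -> position (input ID) in the vocabulary.
vocab = {"love": 0, "that": 1, "can": 2, "quickly": 3, "seize": 4, "the": 5, "gentle": 6, "heart": 7}

# One learnable vector per token; LLaMA uses embedding_dim=4096, here 8 keeps it readable.
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=8)

sentence = ["love", "that", "can", "quickly", "seize", "the", "gentle", "heart"]
input_ids = torch.tensor([vocab[token] for token in sentence])   # shape: (8,)
vectors = embedding(input_ids)                                   # shape: (8, 8), one vector per token

print(input_ids.shape, vectors.shape)
```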
00:08:49.880 | And this is the idea behind the embedding. Now let's check what is normalization. Because this
00:08:57.480 | is the layer right after the embeddings. And for that, let's introduce some review of the
00:09:03.880 | neural networks and how they work. So suppose we have a feed-forward neural network with an input,
00:09:10.760 | a hidden layer made of neurons, another hidden layer made of another five neurons,
00:09:18.680 | which then maps to an output. We usually have a target. And comparing the output with the target,
00:09:24.680 | we produce a loss. The loss is then propagated back to the two hidden layers by means of back
00:09:31.400 | propagation. So what we do is we calculate the gradient of the loss with respect to each weight
00:09:37.240 | of these two hidden layers. And we modify these weights of the hidden layer accordingly,
00:09:43.800 | also according to the learning rate that we have set. To check why we need to normalize and what is
00:09:50.760 | the need of normalization, I will make a simplification of the neural network. So let's
00:09:56.680 | suppose our neural network is actually a factory, a factory that makes phones. So to make a phone,
00:10:02.600 | we start with some raw material that are given to a hardware team that will take the raw material
00:10:08.760 | and produce some hardware. For example, they may select the Bluetooth device, they may select the
00:10:14.760 | display, they may select the microphone, the camera, etc, etc. And they make up the hardware
00:10:20.920 | of this phone. The hardware team then gives this prototype to the software team, which then creates
00:10:26.840 | the software for this hardware. And then the output of the software team is the complete phone
00:10:32.120 | with hardware and software and is given as the output. The output is then compared with what was
00:10:38.760 | the original design of the phone. And then we compute a loss. So what is the difference between
00:10:44.440 | the target we had for our phone and what we actually produced? So suppose the loss is our CEO.
00:10:50.680 | And the loss is quite big, suppose. So our CEO will talk with the hardware team and with the
00:10:57.880 | software team and will tell them to adjust their strategy so as to go closer to the target next
00:11:03.720 | time. So suppose that the hardware was too expensive. So the CEO will tell the hardware
00:11:08.600 | team to use maybe a smaller display, to use a cheaper camera, to change the Bluetooth to a
00:11:14.200 | lower range one, or to change the Wi-Fi to a low energy one, to change the battery, etc, etc.
00:11:19.320 | And we'll also talk with the software team to adjust their strategy and maybe tell the software
00:11:25.160 | team to concentrate less on refactoring, to concentrate less on training, to hire more
00:11:32.120 | interns and not care too much about the employees because the costs are too high, etc. And he will
00:11:39.880 | adjust the strategy of the software and the hardware team. So the next time we start with the
00:11:45.560 | raw material again. So let's go back. We start with the raw material again. And the hardware
00:11:53.800 | team, according to the new strategy set by the CEO, will produce a new hardware. Now the problem
00:12:00.520 | arises. The software team now will receive a hardware that the software team has never seen
00:12:05.800 | before because the display has been changed, the Bluetooth has been changed, the Wi-Fi has been
00:12:11.160 | changed, everything has been changed. So the software team needs to redo a lot of work and
00:12:17.480 | especially they need to adjust their strategy a lot because they are dealing with something they
00:12:23.400 | have never seen before. So the output of the software team will be much different compared
00:12:28.840 | to what they previously output. And maybe it will be even further from the target because
00:12:35.960 | the software team was not ready to make all these adjustments. So maybe they wasted a lot of time,
00:12:40.280 | so maybe they wasted a lot of resources, so they maybe could not even reach the target,
00:12:45.320 | even get closer to the target. So this time maybe the loss is even higher. So as you can see,
00:12:50.920 | the problem arises by the fact that the loss function modifies the weights of the hardware
00:12:57.000 | team and the software team. But then the software team at the next iteration receives an input that
00:13:04.520 | it has never seen before, and this input makes it produce an output that is very different compared
00:13:12.520 | to the one it used to produce before. This will kind of make the loss of the model oscillate
00:13:18.760 | and will make the training much slower. Now let's look at what happens at the math level to understand
00:13:25.160 | how the normalization works. So let's review some maths. Suppose that we have a linear layer
00:13:31.320 | defined as nn.linear with three input features and five output features with bias. This is the
00:13:39.000 | linear layer as defined in PyTorch. The linear layer will create two matrices, one called W,
00:13:45.960 | the weight, and one called B, the bias. Suppose we have an input of shape 10 rows by 3 columns,
00:13:53.240 | the output of this linear layer with this input x will be 10 rows by 5 columns. But how does this
00:14:00.920 | happen mathematically? Let's review it. So imagine we have our input which is 10 by 3,
00:14:06.440 | which means that we have 10 items and each item has 3 features. The W matrix created by the
00:14:13.400 | linear layer will be 5 by 3, so the output features by the 3 input features. And we can think of each
00:14:21.880 | of this row as one neuron, each of them having three weights, one weight for each of the input
00:14:29.560 | features of the x input. Then we have the bias vector and the bias vector is one weight for each
00:14:39.240 | neuron because the bias is one for every neuron. And this will produce an output which is 10 by 5,
00:14:46.920 | which means we have 10 items with 5 features. Let's try to understand what is the flow of
00:14:54.360 | information in these matrices. The flow of information is governed by this expression,
00:15:01.800 | so the output is equal to the x multiplied by the transpose of the W matrix plus B.
00:15:09.880 | So let's suppose we have this input x and we have one item and the item 1 has three features,
00:15:17.640 | A1, A2 and A3. The transpose of W, so W^T, is this matrix here, in which we swap the rows with the
00:15:25.240 | columns, because according to the formula we need to take the transpose of that matrix.
00:15:29.240 | So we have neuron 1 with the three weights W1, W2, W3. We multiply the two and we obtain this
00:15:36.120 | matrix, so x multiplied by the transpose of W produces this matrix here, in which this row 1
00:15:43.640 | is the dot product of this row vector with this column vector. Then we add the B row vector.
00:15:55.160 | As you can see, to add two matrices they need to have the same dimension, but in PyTorch,
00:16:01.960 | because of broadcasting, this row will be added to this row here and then to independently to
00:16:08.360 | this row and to this row etc etc because of the broadcasting. And then we will have this output.
00:16:15.000 | And the first item here will be Z1. What is Z1? Well, Z1 is equal to R1 plus B1. But what is R1?
00:16:25.480 | R1 is the dot product of this column with this row or this row with this column. So it's this
00:16:31.320 | expression here. So the output of the neuron 1 for the item 1 only depends on the features of
00:16:38.040 | the item 1. Usually after this output we also apply a non-linearity like the ReLU function,
00:16:44.280 | and the argument of the ReLU function is referred to as the activation of the neuron 1.
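Here is a quick PyTorch sanity check of the shapes and of the formula output = x · W^T + b described above, with the same numbers as in the example (3 input features, 5 output features, 10 items).

```python
import torch
import torch.nn as nn

linear = nn.Linear(in_features=3, out_features=5, bias=True)
x = torch.randn(10, 3)                          # 10 items, each with 3 features

out = linear(x)                                 # shape: (10, 5)
manual = x @ linear.weight.T + linear.bias      # same result computed by hand; the bias is broadcast to every row

print(linear.weight.shape, linear.bias.shape)   # torch.Size([5, 3]) torch.Size([5])
print(out.shape, torch.allclose(out, manual))   # torch.Size([10, 5]) True
```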
00:16:51.000 | Now, as we can see, the output of the neuron 1 only depends on the input features of each item.
00:16:59.480 | So the output of a neuron for a data item depends on the features of the input data
00:17:04.200 | item and the neuron's parameter. We can think of the input to a neuron as the output of a previous
00:17:09.960 | layer. So, for example, that input that we saw before, the X, it may as well be the output of
00:17:14.840 | the previous layer. If the previous layer, after its weight are updated because of the gradient
00:17:21.320 | descent, changes drastically the output, like we did before, for example, because the CEO realigned
00:17:27.640 | the strategy of the hardware team, so the previous layer, the hardware team, will produce an output
00:17:31.720 | that is drastically different compared to what it used to produce, the next layer will have its
00:17:38.040 | output changed also drastically. So, because it will be forced to readjust its weight drastically
00:17:45.320 | at the next step of the gradient descent. So what we don't like is the fact that the weight,
00:17:50.760 | the output of the previous layer changes too much, so that the next layer also has to change
00:17:56.440 | its output a lot, because it has to adhere to the strategy defined by the loss function.
00:18:03.720 | So this phenomenon, by which the distribution of the internal nodes of a neuron change, is referred
00:18:09.960 | to as internal covariate shift. And we want to avoid it, because it makes training the network
00:18:15.240 | slower, as the neurons are forced to readjust drastically their weights in one direction
00:18:20.280 | or another, because of drastic changes in the output of the previous layers.
00:18:25.000 | So what do we do? We do layer normalization, at least in the vanilla transformer. So let's
00:18:30.440 | review how the layer normalization works. Imagine we still have our input x defined with 10 rows by
00:18:37.640 | 3 columns, and for each of these items, independently, we calculate two statistics.
00:18:46.760 | One is the mu, so the mean, and one is the sigma, so the variance. And then we normalize the values
00:18:55.960 | in this matrix according to this formula. So we take basically x minus its mu, so each item minus
00:19:04.120 | the mu, divided by the square root of the variance plus epsilon, where epsilon is a very small number,
00:19:11.160 | so that we never divide by zero in this way, even if the variance is very small.
00:19:15.160 | And each of these numbers is then multiplied with the two parameters,
00:19:20.840 | one is gamma, and one is beta. They are both learnable by the model, and they are useful,
00:19:26.760 | because the model can adjust this gamma and beta to amplify the values that it needs.
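Here is a minimal sketch of this per-row normalization, comparing the formula written by hand with PyTorch's nn.LayerNorm (10 items with 3 features, as in the example above).

```python
import torch
import torch.nn as nn

x = torch.randn(10, 3)                             # 10 items, 3 features each
eps = 1e-5

mu = x.mean(dim=-1, keepdim=True)                  # one mean per row
var = x.var(dim=-1, unbiased=False, keepdim=True)  # one variance per row
x_hat = (x - mu) / torch.sqrt(var + eps)           # zero mean and unit variance per row

layer_norm = nn.LayerNorm(normalized_shape=3, eps=eps)  # gamma (weight) and beta (bias) are learnable
out = layer_norm(x)

print(torch.allclose(out, x_hat * layer_norm.weight + layer_norm.bias, atol=1e-6))  # True
```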
00:19:32.040 | So before we had layer normalization, we used to normalize with batch normalization,
00:19:40.600 | and with batch normalization, the only difference is that instead of calculating the statistics by
00:19:45.400 | rows, we calculated them by columns. So the feature 1, feature 2, and feature 3.
00:19:51.000 | With layer normalization, we do it by row. So each row will have its own mu and sigma.
00:19:57.000 | So by using the layer normalization, basically, we transform the initial distribution of features,
00:20:03.320 | no matter what they are, into normalized numbers that are distributed with 0 mean and 1 variance.
00:20:10.840 | So this formula actually comes from probability statistics, and if you remember,
00:20:14.360 | let me use the pen, okay, if you remember, basically, if we have a variable x, which is
00:20:23.160 | distributed like a normal variable with a mean, let's say 5, and a variance of 36,
00:20:31.000 | if we do x minus its mean, so 5 divided by the square root of the variance, so 36,
00:20:41.800 | this one, this variable here, let's call it z, will be distributed like n, 0, 1.
00:20:50.600 | So it will become a standard Gaussian, and this is exactly what we are doing here. So we are
00:20:56.600 | transforming them into standard Gaussians, so that this value, most of the times will be close to 0,
00:21:03.000 | I mean, will be distributed around 0. Now let's talk about root-mean-square
00:21:09.080 | normalization, the one used by Lama. The root-mean-square normalization was introduced
00:21:17.240 | in this paper, root-mean-square layer normalization, from these two researchers,
00:21:23.320 | and let's read the paper together. A well-known explanation of the success of layer norm is its
00:21:29.800 | re-centering and re-scaling invariance property. So what do they mean? What is the re-centering
00:21:36.040 | and the re-scaling invariance? The fact that the features, no matter what they are, they will be
00:21:41.080 | re-centered around the zero mean, and re-scaled to have a variance of 1. The former enables the
00:21:47.800 | model to be insensitive to shift noises on both input and weights, and the latter keeps the output
00:21:54.120 | representations intact when both input and weight are randomly scaled. In this paper, we hypothesize
00:22:00.520 | that the re-scaling invariance is the reason for success of layer norm, rather than the re-centering
00:22:06.680 | invariance. So what they claim in this paper is that, basically, the success of layer norm is not
00:22:14.360 | because of the re-centering and the re-scaling, but mostly because of the re-scaling, so this
00:22:21.240 | division by the variance, basically, so to have a variance of 1. And what they do is, basically,
00:22:27.960 | they said, okay, can we find another statistic that doesn't depend on the mean because we believe
00:22:33.720 | that it's not necessary? Well, yes. They use this root-mean-square statistic, so the statistic
00:22:41.480 | defined here, and as you can see from the expression of this statistic,
00:22:51.960 | we don't use the mean to calculate it anymore, because the previous statistics here, so the
00:22:56.840 | variance, to be calculated you need the mean, because if you remember, the variance to be
00:23:03.320 | calculated needs the mean, so the variance is equal to the summation of x minus mu to the power of 2
00:23:11.240 | divided by n. So we need the mean to calculate the variance. So what the authors wanted to do
00:23:18.040 | in this paper, they said, okay, because we don't need to re-center, because we believe, we
00:23:23.400 | hypothesize that the re-centering is not needed to obtain the effect of the layer normalization,
00:23:28.680 | we want to find a statistic that doesn't depend on the mean, and the RMS statistic doesn't depend
00:23:33.400 | on the mean. So they do exactly the same thing that they did in the layer normalization, so they
00:23:39.800 | calculate the RMS statistic by rows, so one for each row, and then they normalize according to
00:23:46.840 | this formula here, so they just divide by the statistic, RMS statistic, and then multiply by
00:23:52.120 | this gamma parameter, which is learnable. Now, why root-mean-square normalization? Well,
00:24:00.200 | it requires less computation compared to layer normalization, because we are not computing
00:24:05.240 | two statistics, so we are not computing the mean and the sigma, we are only computing one,
00:24:10.360 | so it gives you a computational advantage. And it works well in practice, so actually what the
00:24:17.480 | authors of the paper hypothesized is actually true, we only need the re-scaling invariance to obtain
00:24:23.720 | the effect made by the layer normalization, we don't need the re-centering. At least, this is what happens with LLaMA.
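Below is a minimal sketch of root-mean-square normalization as just described: divide each row by its RMS statistic and multiply by a learnable gamma. The epsilon value and the class structure are illustrative choices, not the official LLaMA code.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))  # learnable re-scaling parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One RMS statistic per row; no mean and no re-centering.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.gamma

x = torch.randn(10, 4096)          # 10 tokens with the LLaMA embedding size
print(RMSNorm(4096)(x).shape)      # torch.Size([10, 4096])
```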
00:24:29.080 | The next topic we will be talking about is the positional encodings,
00:24:34.200 | but before we introduce the rotary positional encodings, let's review the positional encodings
00:24:39.320 | in the vanilla transformer. As you remember, after we transform our tokens into embeddings,
00:24:45.320 | so vectors of size 512, in the vanilla transformer, then we sum another vector to these
00:24:51.320 | embeddings, that indicate the position of each token inside the sentence,
00:24:59.000 | and these positional embeddings are fixed, so they are not learned by the model, they are computed
00:25:04.600 | once and then they are reused for every sentence during training and inference, and each word gets
00:25:12.280 | his own vector of size 512. We have a new kind of positional encoding called rotary positional
00:25:18.760 | encoding, so absolute positional encodings are fixed vectors that are added to the embedding
00:25:24.680 | of a token to represent its absolute position in the sentence, so the token number 1 gets its own
00:25:30.600 | vector, the token number 2 gets its own vector, the token number 3 gets its own vector, so the
00:25:36.600 | absolute positional encoding deals with one token at a time. You can think of it as the pair latitude
00:25:42.680 | and longitude on a map, each point on the earth will have its own unique latitude and longitude,
00:25:49.000 | so that's an absolute indication of the position of each point on the earth, and this is the same
00:25:54.520 | what happens with absolute positional encoding in the vanilla transformer. We have one vector
00:25:59.000 | that represents exactly that position, which is added to that particular token in that position.
00:26:04.280 | With relative positional encodings, on the other hand, it deals with two tokens at a time,
00:26:10.760 | and it is involved when we calculate the attention. Since the attention mechanism
00:26:15.320 | captures the intensity of how much two words are related to each other, relative positional
00:26:20.600 | encodings tell the attention mechanism the distance between the two words involved in
00:26:25.880 | this attention mechanism. So, given two tokens, we create a vector that represents their distance.
00:26:33.000 | This is why it's called relative, because it's relative to the distance between two tokens.
00:26:38.440 | Relative positional encodings were first introduced in the following paper
00:26:42.520 | from Google, and you can notice that Vaswani is also one of the authors of the transformer paper.
00:26:50.200 | So, now, with absolute positional encoding, so from the attention is all you need,
00:26:56.520 | when we calculate the dot product in the attention mechanism, so if you remember the
00:27:04.200 | attention mechanism, the formula, let me write it, the attention is equal to the query multiplied by
00:27:18.040 | the transpose of the key divided by the square root of d model, d model, all of this, then we
00:27:28.040 | do the softmax, and then we multiply it by v, etc., etc., but we only concentrate on the q
00:27:33.160 | multiplied by the k transposed in this case, and this is what we see here. So, when we calculate
00:27:41.080 | this dot product, the attention mechanism is calculating the dot product between two tokens,
00:27:47.800 | that already have the absolute position encoded into them, because we already added the absolute
00:27:54.440 | positional encoding to each token. So, in this attention mechanism from the vanilla transformer,
00:27:59.480 | we have two tokens and the attention mechanism, while in relative positional encodings,
00:28:03.880 | we have three vectors. We have the token one, the token two, and then we have this vector here,
00:28:15.960 | we have this vector here, that represents the distance between these two tokens,
00:28:21.480 | and so we have three vectors involved in this attention mechanism, and we want the attention
00:28:28.200 | mechanism to actually match this token differently based on this vector here. So, this vector will
00:28:35.160 | indicate to the attention mechanism, so to the dot product, how to relate these two words that
00:28:40.840 | are at this particular distance. With rotary positional embeddings, we do a similar job,
00:28:48.120 | and they were introduced with this paper, so RoFormer, and the authors are from a Chinese company.
00:28:55.160 | So, the dot product used in the attention mechanism is a type of inner product. So,
00:29:01.720 | if you remember from linear algebra, the dot product is a kind of operation that has some
00:29:07.480 | properties, and these properties are the kind of properties that every inner product must have.
00:29:13.400 | So, the inner product can be thought of as a generalization of the dot product.
00:29:17.400 | What the authors of the paper wanted to do is, can we find an inner product over the two-vector
00:29:25.640 | query and key used in the attention mechanism that only depends on the two vectors themselves
00:29:32.280 | and the relative distance of the token they represent. That is, given two vectors, query
00:29:39.480 | and key, that only contain the embedding of the word that they represent, and their position
00:29:46.120 | inside of the sentence, so this m is actually an absolute number, so it's a scalar, it represents
00:29:52.440 | the position of the word inside of the sentence, and this n represents the position of the second
00:29:57.240 | word inside of the sentence. What they wanted to say is, can we find an inner product, so this
00:30:03.080 | particular parenthesis we see here is an inner product between these two vectors,
00:30:09.000 | that behaves like this function g, that only depends on the embedding of xn, so the first
00:30:17.080 | token, of xn, the second token, and the relative distance between them, and no other information.
00:30:24.760 | So this function will be given only the embedding of the first token, the embedding of the second
00:30:29.640 | token, and a number that represents the relative position of these two tokens, relative distance
00:30:35.960 | of these two tokens. Yes, we can find such a function, and the function is the one defined
00:30:43.320 | here. So we can define a function g, like the following, that only needs, only depends on the
00:30:49.640 | two embedding vectors q and k, and the relative distance. And this function is defined in the
00:30:56.840 | complex number space, and it can be converted by using the Euler formula into this form.
00:31:03.160 | And another thing to notice is that this function here, the one we are watching, is defined for
00:31:09.720 | vectors of dimension 2. Of course later we will see what happens when the dimension is bigger.
00:31:17.560 | And when we convert this expression here, which is in the complex number space,
00:31:22.680 | into its matrix form, through Euler's formula, we can recognize this matrix here
00:31:29.160 | as the rotation matrix. So this matrix here basically represents the rotation of a vector.
00:31:34.760 | For example, this one here, so this product here, will be a vector, and this rotation matrix will
00:31:42.920 | rotate this vector into the space by the amount described by m theta, so the angle m theta.
00:31:50.840 | Let's see an example. So imagine we have a vector v0, and we want to rotate it by theta,
00:31:58.840 | by an angle theta here, to arrive to the vector v prime. So what we do is, we multiply the vector
00:32:05.320 | v0 with this matrix, exactly this one, in which the values are calculated like this,
00:32:11.000 | cosine of theta, minus sine of theta, sine of theta, and cosine of theta.
00:32:15.080 | And the resulting vector will be the same vector, so the same length, but rotated by this angle.
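As a tiny numeric check of this, the rotation matrix built from cos(theta) and sin(theta) keeps the length of the vector and only changes its angle (the numbers below are arbitrary).

```python
import math
import torch

theta = 0.5                                   # rotation angle in radians
rotation = torch.tensor([[math.cos(theta), -math.sin(theta)],
                         [math.sin(theta),  math.cos(theta)]])

v0 = torch.tensor([2.0, 1.0])
v_prime = rotation @ v0                       # the rotated vector

print(torch.linalg.norm(v0).item(), torch.linalg.norm(v_prime).item())  # same length
```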
00:32:22.040 | And this is why they are called rotary positional embeddings,
00:32:26.440 | because this matrix represents a rotation.
00:32:29.160 | Now, when the vector is not two-dimensional, but we have n dimensions, for example in the
00:32:36.680 | original transformer model our embedding size is 512, and in Lama it's 4096, we need to use this
00:32:44.680 | form. Now, I want you to notice not what are the numbers in this matrix, but the fact that this
00:32:51.480 | matrix is sparse, so it is not convenient to use it to compute the positional embeddings,
00:32:57.080 | because if we multiply by this matrix, our TensorFlow, our GPU, our computer will do a lot
00:33:02.680 | of operations that are useless, because we already know that most of the products will be zero.
00:33:07.160 | So, is there a better way, a more computationally efficient way to do this computation?
00:33:12.200 | Well, there is, this form here. So, given a token with the embedding vector x,
00:33:19.160 | and the position m of the token inside the sentence, this is how we compute the position
00:33:24.760 | embedding for the token. We take the dimensions of the token, we multiply by this matrix here,
00:33:31.640 | computed like the following, where the theta are fixed, m is the position of the token, x1, x2,
00:33:38.680 | x3 are the dimensions of the embedding, so the first dimension of the embedding, the second
00:33:42.360 | dimension of the embedding, etc., plus, minus the second embedding, this vector computed like
00:33:50.120 | with the following positions, so minus x2, which is the negative value of the second
00:33:57.400 | dimension of the embedding of the vector x, multiplied by this matrix here.
00:34:03.240 | So, there is nothing we have to learn in this matrix, everything is fixed, because if we watch
00:34:07.240 | the previous slide, we can see that this theta actually is computed like this, one for each
00:34:13.240 | dimension, and so there is nothing to learn. So, basically, they are just like the absolute
00:34:20.040 | positional encoding, so we compute them once, and then we can reuse them for all the sentences that
00:34:27.080 | we will train the model upon. Another interesting property of the
00:34:30.760 | rotary positional embeddings is the long-term decay. So, what the authors did, they calculated
00:34:37.000 | an upper bound for the inner product that we saw before, so the g function, by varying the distance
00:34:42.760 | between the two tokens, and then they proved that no matter what are the two tokens, there is an
00:34:48.840 | upper bound that decreases as the distance between the two tokens grow. And if you remember that the
00:34:57.560 | inner product or the dot product that we are computing is for the calculation of the attention,
00:35:02.760 | this dot product represents the intensity of relationship between the two tokens for which
00:35:07.640 | we are computing the attention. And what these rotary positional embeddings do, they will
00:35:12.920 | basically decay this relationship, the strength of this relationship between the two tokens,
00:35:19.560 | if the two tokens that we are matching are distant from each other. And this is actually
00:35:28.120 | what we want. So, we want two words that are very far from each other to have a less strong
00:35:33.240 | relationship, and two words that are close to each other to have a stronger relationship.
00:35:37.480 | And this is a desired property that we want from these rotary positional embeddings.
00:35:42.200 | Now, the rotary positional embeddings are only applied to the query and the keys,
00:35:48.200 | but not to the values. Let's see why. Well, the first consideration is that they basically,
00:35:54.440 | they come into play when we are calculating the attention. So, when we calculate the attention,
00:35:59.080 | it's the attention mechanism that will change the score. So, as you remember,
00:36:05.400 | the attention mechanism is kind of a score that tells how much strong is the relationship between
00:36:10.920 | two tokens. So, this relationship will be stronger or less stronger or will change according to also
00:36:18.920 | the position of these two tokens inside of the sentence and the relative distance between these
00:36:24.760 | two tokens. Another thing is that the rotation, rotary positional embeddings are applied after
00:36:30.040 | the vector Q and K have been multiplied by the W matrix in the attention mechanism,
00:36:35.480 | while in the vanilla transformer, they are applied before. So, in the vanilla transformer,
00:36:40.200 | the position embeddings are applied right after we transform the tokens into embeddings.
00:36:47.240 | But in the rotary positional embeddings, so in Lama, we don't do this. We basically,
00:36:53.240 | right after we multiply by the W matrix in the attention mechanism. So, the W matrix,
00:36:59.080 | if you remember, is the matrix of parameters that each head has, each attention head has.
00:37:07.160 | And so, in the Lama, basically, we apply the rotary position encoding after we multiply the
00:37:16.200 | vectors Q and K by the W matrix. Now comes the interesting part, in which we will watch how the
00:37:21.880 | self-attention works in Lama. But before we can talk about the self-attention as used in Lama,
00:37:27.560 | we need to review, at least briefly, the self-attention in the vanilla transformer.
00:37:33.000 | So, if you remember the self-attention in the vanilla transformer, we start with the matrix Q,
00:37:38.760 | which is a matrix of sequence by the model, which means that we have on the rows, the tokens,
00:37:45.480 | and on the columns, the dimensions of the embedding vector. So, we can think of it like the following.
00:37:51.400 | Let me. Okay. So, we can think of it like having six rows, one, and each of these rows is a vector
00:38:01.240 | of dimension 512 that represents the embedding of that token. And now, let me delete.
00:38:08.040 | And then, we multiply according to this formula. So, Q multiplied by the transpose of the K. So,
00:38:16.280 | transpose of the K divided by the square root of 512, which is the dimension of the embedding vector,
00:38:22.040 | where the K is equal to Q and V is also equal to Q, because this is a self-attention. So,
00:38:29.000 | the three matrices are actually the same sequence. Then, we apply the softmax and we obtain this
00:38:36.120 | matrix. So, we had the matrix that was 6 by 512 multiplied by another one that is 512 by 6. We
00:38:43.400 | will obtain a matrix that is 6 by 6, where each item in this matrix represents the dot product
00:38:50.440 | of the first token with itself, then the first token with the second token, the first token with
00:38:56.760 | the third token, the first token with the fourth token, etc. So, this matrix captures the intensity
00:39:03.720 | of relationship between two tokens. Then, the output of this softmax is multiplied by the V
00:39:13.640 | matrix to obtain the attention sequence. So, the output of the self-attention is another matrix
00:39:21.080 | that has the same dimensions as the initial matrix. So, it will produce a sequence where
00:39:28.280 | the embeddings now not only capture the meaning of each token and the position of each token,
00:39:34.600 | but they also capture kind of the relationship between that token and every other token.
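A minimal sketch of this formula, softmax(Q · K^T / sqrt(d)) · V, for a single head with 6 tokens and d = 512, as in the example above (no masking shown here).

```python
import math
import torch

seq_len, d_model = 6, 512
x = torch.randn(seq_len, d_model)          # the input sequence of 6 token embeddings

q = k = v = x                              # self-attention: Q, K and V come from the same sequence
scores = q @ k.T / math.sqrt(d_model)      # (6, 6): intensity of relationship between every pair of tokens
weights = torch.softmax(scores, dim=-1)    # each row sums to 1
output = weights @ v                       # (6, 512): same shape as the input sequence

print(scores.shape, output.shape)
```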
00:39:40.760 | If you didn't understand this concept, please go back and watch my previous video about
00:39:45.800 | the transformer where I explain it very carefully and in much more detail. Now, let's have a look
00:39:51.320 | at the multi-head attention very briefly. So, the multi-head attention basically means that we have
00:39:57.720 | an input sequence, we take it, we copy it into Q, K, and V, so they are the same matrix, we multiply
00:40:05.400 | by parameter matrices, and then we split into multiple smaller matrices, one for each head,
00:40:12.200 | and we calculate the attention between these heads. So, head 1, head 2, head 3, head 4.
00:40:17.160 | Then, we concatenate the output of these heads, we multiply by the output matrix W_O,
00:40:22.920 | and finally we have the output of the multi-head attention.
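Here is a compact sketch of this multi-head mechanism (projections, a split into heads, attention per head, concatenation, and the output matrix W_O); the dimensions are the vanilla-transformer ones (d_model = 512, 8 heads), and the code is an illustration rather than the exact LLaMA implementation.

```python
import math
import torch
import torch.nn as nn

d_model, n_heads, seq_len = 512, 8, 6
d_head = d_model // n_heads                      # 64 dimensions per head

w_q = nn.Linear(d_model, d_model, bias=False)
w_k = nn.Linear(d_model, d_model, bias=False)
w_v = nn.Linear(d_model, d_model, bias=False)
w_o = nn.Linear(d_model, d_model, bias=False)

x = torch.randn(seq_len, d_model)                # the same sequence is copied into Q, K and V

# Project, then split the 512 dimensions into 8 heads of 64 dimensions each.
q = w_q(x).view(seq_len, n_heads, d_head).transpose(0, 1)   # (8, 6, 64)
k = w_k(x).view(seq_len, n_heads, d_head).transpose(0, 1)
v = w_v(x).view(seq_len, n_heads, d_head).transpose(0, 1)

scores = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_head), dim=-1)  # (8, 6, 6)
heads = scores @ v                                                           # (8, 6, 64)

# Concatenate the heads back together and apply the output matrix W_O.
out = w_o(heads.transpose(0, 1).reshape(seq_len, d_model))                   # (6, 512)
print(out.shape)
```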
00:40:28.520 | Let's now look at what the KV cache is. So, before we introduce the KV cache, we need to understand how LLaMA was trained,
00:40:35.320 | and we need to understand what is the next token prediction task. So, Lama, just like most of the
00:40:42.120 | large language models, have been trained on the next token prediction task, which means that given
00:40:47.880 | a sequence, it will try to predict what is the next token, the most likely next token, to continue
00:40:54.760 | the prompt. So, for example, if we give it a poem without the last word,
00:41:01.880 | probably it will come up with the last word that is missing from that poem. In this case,
00:41:07.640 | I will be using one very famous passage from Dante Alighieri, and I will not use the Italian
00:41:13.640 | original, but we will use the English translation here. So, I will only deal with
00:41:17.640 | the first line you can see here, "Love that can quickly seize the gentle heart".
00:41:21.480 | So, let's train Lama on this sentence. How does the training work? Well, we give the input to
00:41:29.480 | the model, the input is built in such a way that we first prepend the start of sentence token,
00:41:35.400 | and then the target is built such that we append an end of sentence token. Why? Because the
00:41:43.160 | model, this transformer model, is a sequence-to-sequence model, which maps each
00:41:50.040 | position in the input sequence into another position in the output sequence. So, basically,
00:41:58.280 | the first token of the input sequence will be mapped to the first token of the output sequence,
00:42:04.040 | and the second token of the input sequence will be mapped to the second token of the output
00:42:08.440 | sequence, etc., etc., etc. This also means that if we give our model the input "sos",
00:42:14.920 | it will produce the first token as output, so "love", then if we give the first two tokens,
00:42:21.320 | it will produce the second token as output, so "love that", and if we give the first three tokens,
00:42:29.400 | it will produce the output, the third token as output. Of course, the model will also produce
00:42:35.960 | the output for the previous two tokens, but let's see it with an example. So, if you remember from
00:42:42.440 | my previous video, also in which I do the inferencing, when we train the model, we only do
00:42:47.080 | it in one step, so we give the input, we give the target, we calculate the loss, and we don't have
00:42:53.320 | any for loop to train the model for one single sentence, but for the inference, we need to do it
00:43:01.240 | token by token. So, in this inferencing, we start with a time step, time step one, in which we only
00:43:10.200 | give the input "sos", so start of sentence, and the output is "love". Then, we take the output token
00:43:17.720 | here, "love", and we append it to the input, and we give it again to the model, and the model will
00:43:23.960 | produce the next token, "love that". Then, we take the last token output by the model, "that",
00:43:31.080 | we append it again to the input, and the model will produce the next token. And then, we again
00:43:36.920 | take the next token, so "can", we append it to the input, and we feed it again to the model,
00:43:43.160 | and the model will output the next token quickly. And we do it for all the steps that are necessary
00:43:49.880 | until we reach the end of sentence token. Then, that's when we know that the model has finished
00:43:56.120 | outputting its output. Now, this is not how Lama was trained, actually, but this is a good example
00:44:04.840 | to show you how the next token prediction task works. Now, there is a problem with this approach.
00:44:13.320 | Let's see why. At every step of the inference, we are only interested in the last token output by
00:44:20.920 | the model, because we already have the previous ones. However, the model needs to access all the
00:44:27.480 | previous tokens to decide on which token to output, since they constitute its context, or the prompt.
00:44:33.320 | So, what I mean by this is that to output, for example, the word "the", the model has to see all
00:44:40.280 | the input here. We cannot just give it "seize". The model needs to see all the input to output this
00:44:45.960 | last token, "the". But, the point is, this is a sequence-to-sequence model, so it will produce
00:44:52.600 | this sequence as output, even if we only care about the last token. So, there is a lot of
00:44:58.040 | unnecessary computation we are doing to calculate these tokens, again, that we already actually have
00:45:03.480 | from the previous time steps. So, let's find a way to not do this useless computation.
00:45:08.760 | And this is what we do with the KVCache. So, the KVCache is a way to do less computation
00:45:16.120 | on the tokens that we have already seen during inferencing. So, it's only applied during
00:45:22.680 | inferencing in a transformer model, and it not only applies to the transformer like the one in
00:45:31.160 | Lama, but to all transformer models, because all transformer models work in this way. This is a
00:45:36.520 | description, it's a picture of how the self-attention works during the next token prediction
00:45:42.360 | task. So, as you saw also in my previous slides, we have a query matrix here with N tokens, then we
00:45:50.120 | have the transposed of the keys, so the query can be thought of as rows of vectors, where the first
00:45:56.680 | vector represents the first token, the second token, etc. Then the transposed of the keys is
00:46:01.240 | the same tokens but transposed, so the rows become columns. This produces a matrix that is N by N,
00:46:08.360 | so if the initial input matrix has 9 tokens, the output matrix will be 9 by 9. Then we multiply it by
00:46:15.560 | the V matrix, and this will produce the attention. The attention is then fed to the linear layer of
00:46:23.080 | the transformer, then the linear layer will produce the logits, and the logits are fed to
00:46:29.000 | the softmax, and the softmax allow us to decide which is the token from our vocabulary. Again,
00:46:36.040 | if you are not familiar with this, please watch my previous video of the transformer about the
00:46:40.840 | inferencing of the transformer, and you will see this clearly. So, this is a description of what
00:46:48.120 | happens at a general level in the self-attention. Now, let's watch it step by step. So, imagine at
00:46:54.840 | inference step 1, we only have the first token. If you remember before, we were only using the
00:47:00.840 | start of sentence token. So, we take the start of sentence token, we multiply it by itself,
00:47:06.120 | so the transposed, it will produce a matrix that is 1 by 1, so this matrix is 1 by 4096,
00:47:12.680 | multiplied by another matrix that is 4096 by 1, it will produce a 1 by 1 matrix.
00:47:17.640 | Why 4096? Because the embedding vector in Lama is 4096. Then the output, so this 1 by 1,
00:47:25.560 | is multiplied by the V, and it will produce the output token here, and this will be our first
00:47:32.040 | token of the output. And then we take the output token, this one, and we append it to the input
00:47:40.040 | at the next step. So, now we have two tokens as input. They are multiplied by itself, but with
00:47:46.600 | the transposed version of itself, and it will produce a 2 by 2 matrix, which is then multiplied
00:47:52.280 | by the V matrix, and it will produce two output tokens. But we are only interested in the last
00:47:57.240 | token's output by the model, so this one, attention 2, which is then appended to the input
00:48:03.720 | matrix at the time step 3. So, now we have three tokens in the time step 3, which are multiplied
00:48:10.200 | by the transposed version of itself, and it will produce a 3 by 3 matrix, which is then multiplied
00:48:16.200 | by the V matrix, and we have these three tokens as output. But we are only interested in the last
00:48:23.880 | token output by the model, so we append it again as input to the Q matrix, which is now four tokens,
00:48:30.440 | which is multiplied by the transposed version of itself, and it will produce a 4 by 4 matrix
00:48:37.080 | as output, which is then multiplied by this matrix here, and it will produce this attention matrix
00:48:43.000 | here. But we are only interested in the last attention, which will be then added again to the
00:48:48.520 | input of the next step. But we notice already something. First of all, we already here in this
00:48:55.720 | matrix, where we compute the dot product between this token and this, this token and this, this
00:49:01.240 | token and this. So this matrix is all the dot products between these two matrices. We can see
00:49:07.000 | something. The first thing is that we already computed these dot products in the previous step.
00:49:13.160 | Can we cache them? So let's go back. As you can see, this matrix is growing. Two, three,
00:49:19.880 | four. See, there is a lot of attention, because every time we are inferencing the transformer,
00:49:26.360 | we are giving the transformer some input, so it's re-computing all these dot products,
00:49:32.280 | which is inconvenient, because we actually already computed them in the previous time step. So
00:49:37.080 | is there a way to not compute them again? Can we kind of cache them? Yes, we can. And then,
00:49:44.360 | since the model is causal, we don't care about the attention of a token with its successors,
00:49:50.520 | but only with the tokens before it. So as you remember, in the self-attention, we apply a mask,
00:49:56.600 | right? So the mask is basically, we don't want the dot product of one word with the word that
00:50:02.280 | comes after it, but only the one that comes before it. So basically, we don't want all the numbers
00:50:08.760 | above the principal diagonal of this matrix. And that's why we applied the mask in the
00:50:15.400 | self-attention. But okay, the point is, we don't need to compute all these dot products. The only
00:50:21.800 | dot products that we are interested in is this last row. So because we added the token 4 as input
00:50:29.880 | compared to the last time step, so we only have this new token, token 4, and we want this token 4
00:50:35.800 | how it is interacting with all the other tokens. So basically, we are only interested in this last
00:50:42.440 | row here. And also, as we only care about the attention of the last token, because we want
00:50:49.960 | to select the word from the vocabulary, so we only care about the last row, we don't care about
00:50:54.680 | producing these two, these three attention score here in the output sequence of the self-attention,
00:51:02.280 | we only care about the last one. So is there a way to remove all these redundant calculations?
00:51:07.880 | Yes, we can do it with the KV cache. Let's see how. So with the KV cache, basically, what we do
00:51:16.280 | is we cache the query, sorry, the keys and the values. And every time we have a new token,
00:51:25.560 | we append it to the key and the values, while the query is only the output of the previous step.
00:51:32.360 | So at the beginning, we don't have any output from the previous step, so we only use the first token.
00:51:37.880 | So the first, the time step one of the inference is the same as without the cache. So we have the
00:51:44.360 | token one with itself, will produce a matrix one by one, multiplied with one token, and it will
00:51:50.040 | produce one attention. However, at the time step two, we don't append it to the previous query,
00:51:59.240 | we just replace the previous token with the new token we have here. However, we keep the cache
00:52:05.240 | of the keys. So we keep the previous token in the keys, and we append the last output to the keys
00:52:11.880 | here, and also to the values. And if you do this multiplication, it will produce a matrix that is
00:52:18.680 | one by two, where the first item is the dot product of the token two with the token one and
00:52:24.840 | the token two with the token two. This is actually what we want. And if we then multiply with the V
00:52:30.760 | matrix, it will only produce one attention score, which is exactly the one we want. And we do again,
00:52:36.920 | so we take this attention two, and this will become the input of the next inference step.
00:52:43.720 | So this token three, we append it to the previously cached K matrix and also to the
00:52:49.560 | previously cached V matrix. This multiplication will produce an output matrix that we can see
00:52:55.960 | here. The multiplication of this output matrix with this V matrix will produce one token in the
00:53:03.000 | output, which is this one, and we know which token to select using this one. Then we use it
00:53:08.680 | as an input for the next inferencing step by appending it to the cached keys and appending
00:53:14.280 | to the cached V matrix. We do this multiplication, and we will get this matrix, which is
00:53:23.160 | one by four, which is the dot product of the token four with the token one, the token four
00:53:28.600 | with the token two, token four with the token three, and the token four with itself. We multiply
00:53:33.560 | by the V matrix, and this will only produce one attention, which is exactly what we want to select
00:53:39.080 | the output token. This is the reason why it's called the KV cache, because we are keeping a
00:53:44.920 | cache of the keys and the values. As you can see, the KV cache allows us to save a lot of computation,
00:53:52.120 | because we are not doing a lot of dot products that we used to do before, and this makes the inference faster.
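Here is a minimal single-head sketch of this mechanism: the keys and values of past tokens are kept in a cache, each new step appends only the new token's key and value, and the query is just the new token. A real implementation (LLaMA pre-allocates the cache tensors per layer and per head, with the usual projections) is more involved, so treat this as an illustration.

```python
import math
import torch

class KVCacheAttention:
    """Single-head self-attention with a KV cache, processing one new token per step."""

    def __init__(self):
        self.k_cache = None   # (num_cached_tokens, d)
        self.v_cache = None

    def step(self, x_new: torch.Tensor) -> torch.Tensor:
        """x_new: (1, d), the embedding of the newest token only."""
        q = x_new                                                   # query = only the new token
        k_new, v_new = x_new, x_new                                 # projections omitted for brevity
        if self.k_cache is None:
            self.k_cache, self.v_cache = k_new, v_new
        else:
            self.k_cache = torch.cat([self.k_cache, k_new], dim=0)  # append, don't recompute
            self.v_cache = torch.cat([self.v_cache, v_new], dim=0)

        d = q.shape[-1]
        scores = torch.softmax(q @ self.k_cache.T / math.sqrt(d), dim=-1)  # (1, num_cached_tokens)
        return scores @ self.v_cache                                       # (1, d): one attention output

attn = KVCacheAttention()
for t in range(4):                      # four inference steps, one new token each
    out = attn.step(torch.randn(1, 4096))
print(out.shape, attn.k_cache.shape)    # torch.Size([1, 4096]) torch.Size([4, 4096])
```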
00:53:59.720 | The next layer that we will be talking about is the grouped multi-query
00:54:04.440 | attention, but before we talk about the grouped multi-query attention, we need to introduce its
00:54:09.560 | predecessor, the multi-query attention. Let's see. So let's start with the problem. The problem is
00:54:16.680 | that the GPUs are too fast. If you watch this datasheet, this is from the A100 GPU from NVIDIA,
00:54:26.280 | we can see that the GPU is very fast at computing, at performing calculations,
00:54:31.160 | but not so much, not so fast at transferring data from its memory. That means, for example,
00:54:39.880 | that the A100 can do 19.5 tera floating point operations per second by using a 32-bit precision,
00:54:50.840 | while it can only transfer 1.9 thousand gigabytes per second. It's nearly 10 times
00:55:00.920 | while it can only transfer 1.9 thousand gigabytes per second. It's nearly 10 times
00:55:10.840 | slower at transferring data than it is at performing calculations, and this means that
00:55:16.840 | operations need, and that depends on the size and the quantity of the tensors involved in our
00:55:22.680 | calculations. For example, if we compute the same operations on the same tensor n times,
00:55:29.240 | it may be faster than computing the same operations on n different tensors, even if they
00:55:34.840 | have the same size. This is because the GPU may need to move these tensors around. So this means
00:55:41.560 | that our goal should not only be to optimize the number of operations we do with our algorithms,
00:55:48.040 | but also minimize the memory access and the memory transfers that our algorithms perform,
00:55:53.960 | because the memory access and memory transfer are more expensive in terms of time compared to the
00:56:02.440 | computations. And this also happens with software when we do I/O, for example: if we compare
00:56:08.840 | doing some multiplications on the CPU with reading some data from the hard disk, reading from
00:56:14.920 | the hard disk is much slower than doing a lot of computations on the CPU. And this is a problem.
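As a rough back-of-the-envelope illustration (using the A100 figures just cited as assumed values), we can estimate how many floating-point operations the GPU can perform in the time it takes to move a single byte:

```python
# Back-of-the-envelope estimate with the A100 figures cited above (assumed values).
flops_per_second = 19.5e12   # ~19.5 TFLOPS at FP32
bytes_per_second = 1.9e12    # ~1.9 TB/s of memory bandwidth

flops_per_byte = flops_per_second / bytes_per_second
print(f"~{flops_per_byte:.1f} floating-point operations per byte moved")  # ~10.3

# An algorithm that performs fewer operations than this per byte it reads or writes
# will be limited by memory bandwidth rather than by compute.
```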
00:56:21.880 | Now, this paper introduced the multi-query attention. This paper is from Noam Shazeer,
00:56:27.960 | who is also one of the authors of the attention paper, "Attention Is All You Need". And in this
00:56:34.360 | paper, he introduced the problem. He said, well, let's look at the multi-head attention. So the
00:56:41.720 | batched multi-head attention. This is the multi-head attention as presented in the original
00:56:46.840 | paper, "Attention Is All You Need". Let's look at the algorithm and let's calculate the number of
00:56:51.880 | arithmetic operations performed and also the total memory involved in these operations.
00:56:58.440 | So he calculated that the number of arithmetic operations performed is O(b*n*d^2),
00:57:05.320 | where b is the batch size, n is the sequence length, and d is the size of the embedding vector.
00:57:11.960 | While the total memory involved in the operations, given by the sum of all the tensors involved in
00:57:19.160 | the calculations, including the derived ones, is equal to O(b*n*d + b*h*n^2 + d^2), where h is the
00:57:29.640 | number of heads in this multi-head attention. Now, if we compute the ratio between
00:57:36.920 | the total memory and the number of arithmetic operations, we get this expression here,
00:57:43.240 | O(1/k + 1/(b*n)), where k = d/h is the dimension of each head. In this case, the ratio is much smaller than 1, which means that the number of memory
00:57:50.920 | accesses that we perform is much less than the number of arithmetic operations. So the memory
00:57:56.040 | access in this case is not the bottleneck. So what I mean to say is that the bottleneck of
00:58:04.120 | this algorithm is not the memory access, it is actually the number of computations. And as you
00:58:09.240 | saw before, when we introduced the KV cache, the problem we were trying to solve is the number of
00:58:13.880 | computations, but by introducing the KV cache, we created a new bottleneck, and it's not the
00:58:27.560 | computation anymore. So this algorithm here is the multi-head self-attention, but using the KV cache,
00:58:35.000 | and this reduces the number of operations performed. So if we look at the number of
00:58:40.120 | arithmetic operations performed, it's O(b*n*d^2). The total memory involved in the operations is O(b*n^2*d
00:58:48.280 | + n*d^2), and the ratio between the two is O(n/d + 1/b), so the ratio between the total memory
00:58:58.840 | and the number of arithmetic operations. This means that when n gets close to d,
00:59:06.600 | this ratio approaches 1, or when the batch size b gets close to 1,
00:59:13.960 | this ratio approaches 1. And this is a problem, because whenever this condition holds,
00:59:20.360 | the memory access becomes the bottleneck of the algorithm. And this also means
00:59:28.040 | that either we keep the dimension of the embedding vector much bigger than the sequence length,
00:59:36.760 | or, if we increase the sequence length without making the dimension of the embedding vector
00:59:42.120 | much bigger, the memory access will become the bottleneck. So what we can do is, we need to
00:59:49.800 | find a better way. To solve the problem of the previous algorithm, in which the memory became
00:59:54.920 | the bottleneck, the author introduced the multi-query attention. What he did was to remove
01:00:01.080 | the h dimension from the k and the v, while keeping it for the q. So it's still a multi-head
01:00:09.880 | attention, but only with respect to q, that's why it's called multi-query attention. So we will
01:00:15.800 | have multiple heads only for the q, but the k and v will be shared by all the heads. And if we use
01:00:22.680 | this algorithm, the ratio becomes O(1/d + n/(d*h) + 1/b). If we compare it to the previous one,
01:00:33.960 | the term that was n/d is now n/(d*h). So we reduced the n/d term by a factor of h,
01:00:46.040 | because we removed the h heads from the k and v. So the performance gains are
01:00:53.400 | significant, because now it is much less likely that this ratio will become 1.
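To get a feeling for these asymptotic ratios, here is a small sketch that plugs some illustrative values of b, n, d and h into the three expressions discussed above; the values are made up for illustration and are not taken from the paper.

```python
# Memory-to-arithmetic ratios from the analysis above (up to constant factors).
# b = batch size, n = sequence length, d = embedding size, h = number of heads, k = d / h.
b, n, d, h = 32, 2048, 4096, 32
k = d // h

batched_multi_head = 1 / k + 1 / (b * n)          # << 1: compute-bound
incremental_mha_kv = n / d + 1 / b                # approaches 1 as n grows toward d
incremental_mqa    = 1 / d + n / (d * h) + 1 / b  # the n/d term is divided by h

print(f"batched multi-head:          {batched_multi_head:.4f}")
print(f"incremental MHA + KV cache:  {incremental_mha_kv:.4f}")
print(f"incremental multi-query:     {incremental_mqa:.4f}")
```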
01:01:01.320 | But of course, by removing the heads from the k and v, our model will also have fewer parameters;
01:01:10.840 | it will also have fewer degrees of freedom and less complexity, which may degrade the quality
01:01:17.320 | of the model. And it actually does degrade the quality of the model, but only slightly,
01:01:22.120 | and we will see. So if we compare, for example, the BLEU score on a translation task from English
01:01:27.320 | to German, we can see that the multi-head attention, so the attention that was in the
01:01:32.040 | original attention paper, has a BLEU score of 26.7, while the multi-query has a BLEU score of 26.5.
01:01:41.560 | The author also compared it with the multi-head local and multi-query local, where local means
01:01:49.720 | that they restrict the attention calculation only to the previous 31 positions of each token.
01:01:58.600 | And we can see it here. But the performance gain from reducing the heads of the k and v is great,
01:02:07.160 | because you can see the inference time, for example, of the original multi-head attention
01:02:11.560 | and the multi-query attention. The inference time went from 1.7 microseconds plus 46 microseconds
01:02:20.200 | for the decoder to 1.5 microseconds plus 3.8 microseconds for the decoder. So in total here,
01:02:28.040 | more or less, we took 48 microseconds, while here we take around 5 microseconds for
01:02:36.360 | the multi-query. So it's a great benefit from a performance point of view during inference.
01:02:45.160 | Let's talk about grouped multi-query attention, because now we have just introduced the KV cache
01:02:51.560 | and the multi-query attention. But the next evolution of multi-query attention
01:02:56.520 | is the grouped multi-query attention, which is the one that is used in LLaMA.
01:03:00.360 | So let's have a look at it. With multi-query, we have multiple heads only for the queries,
01:03:06.440 | but only one head for the keys and the values. With grouped multi-query attention, basically,
01:03:12.520 | we divide the queries into groups. So for example, this is the group 1, this is the group 2,
01:03:18.760 | group 3 and group 4. And for each group, we have one different head of k and v.
01:03:26.600 | This is a good compromise between the multi-head, in which there is a one-to-one correspondence,
01:03:32.280 | and the multi-query, where there is a n-to-one correspondence.
01:03:36.680 | So in this case, we still have multiple heads for the keys and values, but there are
01:03:42.600 | fewer of them than there are heads for the queries.
01:03:46.040 | And this is a good compromise between the quality of the model and the speed of the model,
01:03:52.840 | because we still get the computational benefit of the reduction in
01:04:01.320 | the number of heads of keys and values, but we don't sacrifice too much on the quality side.
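Here is a minimal sketch of the idea (a simplified illustration, not LLaMA's actual implementation): each key/value head is shared by a group of query heads, which we can express by simply repeating the K and V heads before the usual attention computation. With one K/V head this reduces to multi-query attention, and with as many K/V heads as query heads it reduces to standard multi-head attention.

```python
import torch

def grouped_query_attention(q, k, v):
    # q: (seq, n_heads, head_dim); k, v: (seq, n_kv_heads, head_dim).
    # Each group of n_heads // n_kv_heads query heads shares one K/V head.
    n_heads, n_kv_heads = q.shape[1], k.shape[1]
    group_size = n_heads // n_kv_heads

    # Repeat every K/V head group_size times so the head dimensions match.
    k = k.repeat_interleave(group_size, dim=1)   # (seq, n_heads, head_dim)
    v = v.repeat_interleave(group_size, dim=1)

    q, k, v = q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)  # (n_heads, seq, head_dim)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5              # (n_heads, seq, seq)
    weights = torch.softmax(scores, dim=-1)                            # causal mask omitted for brevity
    return (weights @ v).transpose(0, 1)                               # (seq, n_heads, head_dim)

# Toy usage: 8 query heads sharing 2 K/V heads (4 query heads per group).
seq, head_dim = 5, 16
out = grouped_query_attention(torch.randn(seq, 8, head_dim),
                              torch.randn(seq, 2, head_dim),
                              torch.randn(seq, 2, head_dim))
print(out.shape)  # torch.Size([5, 8, 16])
```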
01:04:07.560 | And now the last part of the model. As you can see here, the feedforward layer in the LLaMA model
01:04:13.880 | has had its activation function changed to the SwiGLU function.
01:04:20.760 | Let's have a look at how it works. So the SwiGLU function was analyzed
01:04:25.640 | in this famous paper from Noam Shazeer, who is also one of the authors of the attention paper
01:04:30.920 | and of the multi-query attention paper that we saw before.
01:04:36.200 | So let's have a look at this paper. So the author compared the performance of
01:04:42.120 | the transformer model by using different activation functions in the feedforward
01:04:46.680 | layer of the transformer architecture. And the one we are interested in is this
01:04:51.800 | SwiGLU here, which is basically the Swish function with beta equal to one evaluated
01:04:58.600 | at x multiplied by a parameter matrix W, which is then multiplied element-wise with
01:05:05.080 | x multiplied by V, another parameter matrix, and the result is multiplied by W2, a third parameter matrix.
01:05:11.480 | So compare this with the original feedforward network: here we have three parameter matrices,
01:05:19.000 | while in the original feedforward network we only had two. So to make the comparison fair,
01:05:26.040 | the author reduced the size of these matrices (the hidden dimension is scaled by 2/3) so that the
01:05:33.720 | model's total number of parameters remains the same as in the vanilla transformer.
01:05:38.120 | In the vanilla transformer, we had this feedforward network, which used the ReLU function.
01:05:42.840 | So this max(0, ...) term is the ReLU function. And we only had two parameter matrices.
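Written out (following the notation of the GLU-variants paper, with the circled-times symbol denoting element-wise multiplication), the two feed-forward variants being compared are:

```latex
\mathrm{FFN}(x) = \max(0,\, xW_1 + b_1)\, W_2 + b_2
\qquad
\mathrm{FFN}_{\text{SwiGLU}}(x) = \left(\mathrm{Swish}_1(xW) \otimes xV\right) W_2
```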
01:05:48.760 | Actually, some later versions of the transformer don't have the bias terms.
01:05:54.200 | I took this formula from the paper, but there are many implementations
01:05:57.640 | without the bias. And in LLaMA, we use this
01:06:02.840 | computation for the feedforward network. And this is the code I took from the LLaMA repository.
01:06:09.560 | As you can see, it's just what the formula says. It's the SiLU function. Why the SiLU function?
01:06:15.400 | Because it's the Swish function with beta equal to one. When the Swish function,
01:06:20.040 | which has this expression, is given beta equal to one, it's called the Sigmoid Linear Unit,
01:06:26.360 | which has this graph and is called SiLU. So the SiLU function is evaluated at w1(x),
01:06:35.080 | then multiplied by w3(x), and then we apply w2 to the result. So we have three matrices.
01:06:41.960 | And these three matrices are basically linear layers. Now they use the parallelized version
01:06:47.720 | of these linear layers, but they are linear layers. And if we look at the graph of this SiLU function,
01:06:54.360 | we can see that it's kind of like a ReLU, but here, before zero, we don't cancel out
01:07:04.760 | the activation immediately. We keep a little tail so that even values that are slightly
01:07:10.920 | negative, close to zero, are not automatically canceled out by the function.
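Putting the pieces together, here is a minimal sketch of this kind of feed-forward block (a simplified version for illustration, not the parallelized linear layers used in the actual LLaMA repository); the hidden dimension chosen below is arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    # Sketch of a SwiGLU-style feed-forward block: w2(SiLU(w1(x)) * w3(x)).
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gated branch
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # linear branch
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection

    def forward(self, x):
        # SiLU (Swish with beta = 1) is applied to the w1 branch, multiplied
        # element-wise by the w3 branch, then projected back down by w2.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# Toy usage; the hidden size here is arbitrary, chosen only for illustration.
ffn = SwiGLUFeedForward(dim=512, hidden_dim=1376)
y = ffn(torch.randn(2, 10, 512))
print(y.shape)  # torch.Size([2, 10, 512])
```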
01:07:16.840 | So let's see how it performs. This SwiGLU function actually performs very well.
01:07:22.200 | Here they evaluate the log perplexity of the model when we use this
01:07:30.200 | particular function. And we can see that the perplexity here is the lowest. The perplexity
01:07:36.360 | basically measures how unsure the model is about its choices. And the SwiGLU function is performing well.
01:07:45.880 | Then they also run the comparison on many benchmarks. And we see that this SwiGLU function
01:07:53.160 | is performing quite well on a lot of them. So why is this SwiGLU activation function working so well?
01:08:01.080 | If we look at the conclusion of this paper, we see: "We offer no explanation as to why these
01:08:07.000 | architectures seem to work; we attribute their success, as all else, to divine benevolence."
01:08:13.640 | Actually, this is kind of funny, but it's also kind of true. Because in most of the
01:08:18.600 | deep learning research, we do not know why things work in the way they do. Because imagine you have
01:08:25.640 | a model of 70 billion parameters. How can you prove what is happening to each one of them
01:08:32.440 | after you modify one activation function? It's not easy to come up with a model that
01:08:38.840 | can explain why the model is reacting in a particular way. What we usually do is
01:08:45.880 | either simplify the model, so we work with a very small model, and then make some assumptions
01:08:51.320 | about why things work the way they do. Or we can just do it on a practical level: we take a model,
01:08:58.440 | we modify it a little bit, we do some ablation studies, and we check which one is performing better.
01:09:03.560 | And this also happens in a lot of areas of machine learning. For example,
01:09:07.560 | we do a lot of grid search to find the right hyperparameters for a model, because we cannot know
01:09:13.000 | beforehand which one will work well, or which one to increase, or which one to decrease. Because it
01:09:17.960 | depends on a lot of factors, not only on the algorithm used, but also on the data, also on
01:09:23.480 | the particular computations used, also on the normalization used. So there are a lot of factors,
01:09:28.120 | and there is no formula that explains everything. So this is why researchers need
01:09:35.240 | to study a lot of model variants, to come up with something that maybe works in
01:09:40.920 | one domain and doesn't work well in other domains. So in this case, we use SwiGLU
01:09:45.160 | mostly because, in practice, it works well with this kind of model.
01:09:48.920 | Thank you guys for watching this long video. I hope that you learned at a deeper level what
01:09:56.200 | happens in LLaMA, and why it is different from a standard transformer model. I know that the
01:10:01.240 | video has been quite long, and I know that some parts have been hard to follow,
01:10:05.880 | so I actually suggest re-watching it multiple times, especially the parts that you
01:10:10.920 | are less familiar with, and integrating this video with my previous video about the transformer.
01:10:15.560 | I will put in the chapters so you can easily find the part that you want, but this is
01:10:22.360 | what you need to do: you need to watch the same concept multiple times to actually master it.
01:10:27.960 | And I hope to make another video in which we code the LLaMA model from scratch, so we can
01:10:32.120 | put all this theory into practice. But as you know, I am doing this in my free time,
01:10:39.320 | and I don't have much of it. So thank you guys for watching my video, and please subscribe
01:10:44.200 | to my channel, because this is the best motivation for me to keep posting amazing content on AI and
01:10:50.360 | machine learning. Thank you for watching, and have an amazing rest of the day.