Hello guys! Welcome to my new video about LLaMA. In this video we will see what LLaMA is, how it is made, and how it is structurally different from the transformer, and we will build each block that makes up LLaMA. So I will not only explain concept-wise what each block is doing, but we will also explore it from the mathematical point of view and from the coding point of view, so that we can unify theory with practice.
I can guarantee that if you watch this video you will have a deep understanding of what makes LLaMA the model it is. So you will not only understand how the blocks interact with each other, but also how they function and why we needed these blocks in the first place. In this video we will be reviewing a lot of topics, so we will start from the architectural differences between the vanilla transformer and the LLaMA model.
We will see the new normalization, the RMS normalization, then the rotary positional embeddings, the KV cache, multi-query attention, grouped multi-query attention, and the SwiGLU activation function for the feed-forward layer. But of course I take for granted that you have some background knowledge. First of all I highly recommend that you watch my previous video about the transformer, because you need to know how the transformer works.
And in my previous video I also explored the concept of training and inferencing a transformer model. It's about 45 minutes and I think it's worth a watch because it will really give you a deep understanding of the transformer. After you have that knowledge you can watch this video. Anyway, for those who have already watched the video but forgot some things, I will review most of the concepts as we proceed through the topics.
I also take for granted that you have some basic linear algebra knowledge, so matrix multiplication, dot product, basic stuff anyway. And also, because we will be using the rotary positional embeddings, some knowledge about complex numbers, even if it's not fundamental. So if you don't remember the complex numbers, or how they work, or Euler's formula, it doesn't matter.
You will understand the concept, even without the math. It's not really fundamental. Sometimes I will be reviewing topics that maybe you are already familiar with, so feel free to skip those parts. Let's start our journey by reviewing the architectural differences between the vanilla transformer and LLaMA. The picture on the right side was built by me, because I couldn't find an architectural picture in the paper.
So let's review the differences. As you remember, in the vanilla transformer we have an encoder part and a decoder part. And let me highlight it. So this is the encoder, and the right side here is the decoder. While in LLaMA, we only have one of the two, a decoder-only stack. That's because LLaMA is a large language model, so it has been trained on the next-token prediction task.
So basically, we only need the self-attention to predict the next token. And we will see all these concepts. So we will see what the next-token prediction task is, how it works, and how this new self-attention works. The second difference that we can see from these pictures is that here, at the beginning, we have the embedding, and we also had the embedding here in the original transformer.
But right after the embedding, we don't have the positional encoding; instead we have this RMS norm. And actually, all the norms have been moved before the blocks. So before, we had the multi-head attention, and then we had the Add & Norm, which is this plus sign here. So it's the sum of a skip connection and the output of the multi-head attention, followed by the normalization.
And we also have this normalization here, here, and here. So after every block. But here in LLaMA, we have it before every block. And we will review what the normalization is and why it works the way it does. Right after the normalization, we have the query, key, and value inputs for the self-attention.
One thing we can notice is that the positional encodings are no longer the positional encodings of the transformer; they have become the rotary positional encodings, and they are only applied to the queries and the keys, but not the values. And we will also see why. Another thing is that the self-attention is now the self-attention with KV cache.
We will see what the KV cache is and how it works. And we also have this grouped multi-query attention. Another thing that changed is this feed-forward layer. In the original feed-forward layer of the vanilla transformer, we had the ReLU activation function for the feed-forward block. But in LLaMA, we are using the SwiGLU function, and we will see why.
This Nx means that this block here in the dashed lines is repeated N times, one after another, such that the output of the last layer is then fed to this RMS norm, then to the linear layer, and then to the softmax. And we will build each of these blocks from the bottom.
So I will show you exactly what these blocks do, how they work, how they interact with each other, what the math behind them is, and what problem they were trying to solve. So we will have a deep knowledge of these models. Let's start our journey by reviewing the models introduced with LLaMA.
So LLaMA 1 came out in February 2023, and they had four sizes for this model. One model had 6.7 billion parameters, then 13, 32.5 and 65 billion. And then we have these numbers. What do they mean? The dimension here indicates the size of the embedding vector. So as you can see here, we have these input embeddings that we will review later.
Basically, they convert each token into a vector of the size indicated by this dimension. Then we have the number of heads, so how many heads the attention has, and the number of layers. If you remember from the original transformer, the dimension was 512. The number of heads was eight.
The number of layers, I think, was six. And then we have the number of tokens each model was trained upon. So 1 trillion and 1.4 trillion. With LLaMA 2, most of the numbers have doubled. So the context length is basically the sequence length, that is, the longest sequence the model can be fed.
And then the number of tokens upon which the model has been trained has also doubled. So from 1 to 2 trillion for each size of the model, while the parameters more or less remain the same. Then we have this column here, GQA, that indicates that these two sizes of the model, so the 34 billion and 70 billion, use the grouped query attention.
And we will see how it works. Let's start by reviewing what is the embeddings layer here. And for that, I will use the slides from my previous video. If you remember my previous video, we introduced the embedding like this. So we have a sentence that is made of six words.
What we do is we tokenize the sentence, so we convert it into tokens. The tokenization is usually not done by splitting on spaces, but by a BPE tokenizer. So actually, each word may also be split into subwords. But for clarity, for simplicity, we just tokenize our sentence by using the whitespace as separator.
So each token is separated by whitespace from the other tokens. And each token is then mapped to its position in the vocabulary. So the vocabulary is the list of the words that our model recognizes. They don't have to be words, of course. They could be anything. They are just tokens.
So each token occupies a position in this vocabulary, and the input IDs indicate the number occupied by each token in the vocabulary. Then we map each input ID into a vector of size 512 in the original transformer. But in LLaMA, it becomes 4096. And these embeddings are vectors that are learnable.
So they are parameters of the model. And while the model is trained, these embeddings will change in such a way that they capture the meaning of the word they are mapping. So we hope that, for example, the words "cat" and "dog" will have similar embeddings, because they are, at least, in the same semantic group.
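To make the pipeline just described concrete, here is a minimal NumPy sketch of whitespace tokenization, the mapping to input IDs, and the embedding lookup. The sentence, the tiny vocabulary, and the embedding size of 8 are made up for illustration; a real model would use a BPE tokenizer and a trained embedding table.

```python
import numpy as np

# Hypothetical six-word sentence; whitespace tokenization is a simplification.
sentence = "your cat is a lovely cat"
tokens = sentence.split()

# Vocabulary: each distinct token gets a position, which is its input ID.
vocab = {tok: idx for idx, tok in enumerate(sorted(set(tokens)))}
input_ids = [vocab[tok] for tok in tokens]

# Toy embedding size (512 in the original transformer, 4096 in LLaMA per the text).
d_model = 8
rng = np.random.default_rng(0)

# Embedding table: one learnable row per vocabulary entry, random before training.
embedding_table = rng.normal(size=(len(vocab), d_model))

# The embedding lookup is just row indexing into the table.
embeddings = embedding_table[input_ids]

print(input_ids)         # [4, 1, 2, 0, 3, 1]
print(embeddings.shape)  # (6, 8)
```

Notice that the two occurrences of "cat" share the same input ID, so at this stage they get exactly the same vector.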
And also, the word "house" and "building," they will be very close to each other if we check the two vectors. And this is the idea behind the embedding. Now let's check what is normalization. Because this is the layer right after the embeddings. And for that, let's introduce some review of the neural networks and how they work.
So suppose we have a feed-forward neural network with an input, a hidden layer made of neurons, another hidden layer made of another five neurons, which then maps to an output. We usually have a target. And comparing the output with the target, we produce a loss. The loss is then propagated back to the two hidden layers by means of back propagation.
So what we do is we calculate the gradient of the loss with respect to each weight of these two hidden layers. And we modify these weights of the hidden layer accordingly, also according to the learning rate that we have set. To check why we need to normalize and what is the need of normalization, I will make a simplification of the neural network.
So let's suppose our neural network is actually a factory, a factory that makes phones. So to make a phone, we start with some raw material that are given to a hardware team that will take the raw material and produce some hardware. For example, they may select the Bluetooth device, they may select the display, they may select the microphone, the camera, etc, etc.
And they make up the hardware of this phone. The hardware team then gives this prototype to the software team, which then creates the software for this hardware. And then the output of the software team is the complete phone with hardware and software and is given as the output. The output is then compared with what was the original design of the phone.
And then we compute a loss. So what is the difference between the target we had for our phone and what we actually produced? So suppose the loss is our CEO. And the loss is quite big, suppose. So our CEO will talk with the hardware team and with the software team and will tell them to adjust their strategy so as to go closer to the target next time.
So suppose that the hardware was too expensive. So the CEO will tell the hardware team to use maybe a smaller display, to use a cheaper camera, to change the Bluetooth to a lower range one, or to change the Wi-Fi to a low energy one, to change the battery, etc, etc.
And he will also talk with the software team to adjust their strategy, and maybe tell the software team to concentrate less on refactoring, to concentrate less on training, to hire more interns and not care too much about the employees because the costs are too high, etc. And he will adjust the strategy of the software and the hardware team.
So the next time we start with the raw material again. So let's go back. We start with the raw material again. And the hardware team, according to the new strategy set by the CEO, will produce a new hardware. Now the problem arises. The software team now will receive a hardware that the software team has never seen before because the display has been changed, the Bluetooth has been changed, the Wi-Fi has been changed, everything has been changed.
So the software team needs to redo a lot of work and especially they need to adjust their strategy a lot because they are dealing with something they have never seen before. So the output of the software team will be much different compared to what they previously output. And maybe it will be even further from the target because the software team was not ready to make all these adjustments.
So maybe they wasted a lot of time, so maybe they wasted a lot of resources, so they maybe could not even reach the target, even get closer to the target. So this time maybe the loss is even higher. So as you can see, the problem arises by the fact that the loss function modifies the weights of the hardware team and the software team.
But then the software team at the next iteration receives an input that it has never seen before, and this input makes it produce an output that diverges a lot from the one it used to produce before. This will make the model kind of oscillate in the loss and will make the training much slower.
Now let's look at what happens at the math level to understand how the normalization works. So let's review some maths. Suppose that we have a linear layer defined as nn.Linear with three input features and five output features, with bias. This is the linear layer as defined in PyTorch. The linear layer will create two tensors, one called W, the weight, and one called B, the bias.
Suppose we have an input of shape 10 rows by 3 columns; the output of this linear layer with this input x will be 10 rows by 5 columns. But how does this happen mathematically? Let's review it. So imagine we have our input, which is 10 by 3, which means that we have 10 items and each item has 3 features.
The W matrix created by the linear layer will be 5 by 3, so the 5 output features by the 3 input features. And we can think of each of these rows as one neuron, each of them having three weights, one weight for each of the input features of the x input.
Then we have the bias vector and the bias vector is one weight for each neuron because the bias is one for every neuron. And this will produce an output which is 10 by 5, which means we have 10 items with 5 features. Let's try to understand what is the flow of information in these matrices.
The flow of information is governed by this expression, so the output is equal to x multiplied by the transpose of the W matrix, plus B. So let's suppose we have this input x, and we have one item, and item 1 has three features, A1, A2 and A3.
The transpose of W is this matrix here, in which we swap the rows with the columns, because according to the formula we need to take the transpose of that matrix. So we have neuron 1 with the three weights W1, W2, W3. We multiply the two and we obtain this matrix, so x multiplied by the transpose of W produces this matrix here, in which this entry R1 is the dot product of this row vector with this column vector.
Then we add the B row vector. As you can see, to add two matrices they need to have the same dimensions, but in PyTorch, because of broadcasting, this row will be added to this row here, and then independently to this row, and to this row, etc.
And then we will have this output. And the first item here will be Z1. What is Z1? Well, Z1 is equal to R1 plus B1. But what is R1? R1 is the dot product of this column with this row or this row with this column. So it's this expression here.
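The shapes and the flow of information above can be checked with a small sketch. It uses plain NumPy as a stand-in for PyTorch's nn.Linear, with random values in place of trained weights, so it only illustrates the arithmetic.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features = 3, 5

x = rng.normal(size=(10, in_features))            # 10 items, 3 features each
W = rng.normal(size=(out_features, in_features))  # 5 neurons, 3 weights each
b = rng.normal(size=(out_features,))              # one bias per neuron

# out = x @ W^T + b; the bias row is broadcast over all 10 rows.
out = x @ W.T + b
print(out.shape)  # (10, 5)

# z1: output of neuron 1 for item 1.
# r1 is the dot product of item 1's features with neuron 1's weights.
r1 = np.dot(x[0], W[0])
z1 = r1 + b[0]
```

Here z1 matches out[0, 0], and it is computed only from the features of item 1 and the parameters of neuron 1.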
So the output of neuron 1 for item 1 only depends on the features of item 1. Usually after this output we also apply a non-linearity like the ReLU function, and the argument of the ReLU function is referred to as the activation of neuron 1.
Now, as we can see, the output of the neuron 1 only depends on the input features of each item. So the output of a neuron for a data item depends on the features of the input data item and the neuron's parameter. We can think of the input to a neuron as the output of a previous layer.
So, for example, that input that we saw before, the X, may as well be the output of the previous layer. Suppose the previous layer, after its weights are updated by gradient descent, drastically changes its output, like we saw before, when the CEO realigned the strategy of the hardware team. So the previous layer, the hardware team, will produce an output that is drastically different compared to what it used to produce, and the next layer will also have its output changed drastically.
So it will be forced to readjust its weights drastically at the next step of the gradient descent. So what we don't like is the fact that the output of the previous layer changes too much, so that the next layer also has to change its output a lot, because it has to adhere to the strategy defined by the loss function.
So this phenomenon, by which the distribution of the internal nodes of a neural network changes, is referred to as internal covariate shift. And we want to avoid it, because it makes training the network slower, as the neurons are forced to drastically readjust their weights in one direction or another, because of drastic changes in the output of the previous layers.
So what do we do? We do layer normalization, at least in the vanilla transformer. So let's review how the layer normalization works. Imagine we still have our input x defined with 10 rows by 3 columns, and for each of these items, independently, we calculate two statistics. One is mu, so the mean, and one is sigma squared, so the variance.
And then we normalize the values in this matrix according to this formula. So we take basically x minus its mu, so each item minus the mu, divided by the square root of the variance plus epsilon, where epsilon is a very small number, so that we never divide by zero, even if the variance is very small.
And each of these numbers is then multiplied by gamma and shifted by adding beta, two parameters. They are both learnable by the model, and they are useful because the model can adjust this gamma and beta to amplify the values that it needs. Before we had layer normalization, we used to normalize with batch normalization, and with batch normalization, the only difference is that instead of calculating the statistics by rows, we calculate them by columns.
So the feature 1, feature 2, and feature 3. With layer normalization, we do it by row. So each row will have its own mu and sigma. So by using the layer normalization, basically, we transform the initial distribution of features, no matter what they are, into normalized numbers that are distributed with 0 mean and 1 variance.
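Here is a small NumPy sketch of the difference between the two, under the assumption that gamma starts at 1 and beta at 0 (so they have no effect yet) and that epsilon is 1e-5. The input values are random and only serve to show along which axis the statistics are computed.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Statistics are computed per row, so one mean and one variance per item.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=6.0, size=(10, 3))  # 10 items, 3 features
gamma = np.ones(3)   # learnable scale, identity at initialization
beta = np.zeros(3)   # learnable shift, zero at initialization

y = layer_norm(x, gamma, beta)
print(y.mean(axis=1))  # each row is close to 0
print(y.var(axis=1))   # each row is close to 1

# Batch-norm style statistics, for comparison: one mean/variance per column.
col_mu = x.mean(axis=0)
col_var = x.var(axis=0)
y_batch = (x - col_mu) / np.sqrt(col_var + 1e-5)
```

The only change between the two is the axis: rows (items) for layer normalization, columns (features) for batch normalization.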
So this formula actually comes from probability and statistics, and if you remember, let me use the pen, okay, if you remember, basically, if we have a variable x, which is distributed like a normal variable with a mean of, let's say, 5, and a variance of 36, and we compute x minus its mean, so 5, divided by the square root of the variance, so the square root of 36, which is 6, then this variable here, let's call it z, will be distributed like N(0, 1).
So it will become a standard Gaussian, and this is exactly what we are doing here. So we are transforming them into standard Gaussians, so that these values will be distributed around 0 most of the time. Now let's talk about root-mean-square normalization, the one used by LLaMA.
The root-mean-square normalization was introduced in this paper, root-mean-square layer normalization, from these two researchers, and let's read the paper together. A well-known explanation of the success of layer norm is its re-centering and re-scaling invariance property. So what do they mean? What is the re-centering and the re-scaling invariance? The fact that the features, no matter what they are, they will be re-centered around the zero mean, and re-scaled to have a variance of 1.
The former enables the model to be insensitive to shift noises on both input and weights, and the latter keeps the output representations intact when both input and weight are randomly scaled. In this paper, we hypothesize that the re-scaling invariance is the reason for success of layer norm, rather than the re-centering invariance.
So what they claim in this paper is that, basically, the success of layer norm is not because of both the re-centering and the re-scaling, but mostly because of the re-scaling, so this division by the square root of the variance, basically, so as to have a variance of 1. And what they do is, basically, they said, okay, can we find another statistic that doesn't depend on the mean, because we believe that it's not necessary?
Well, yes. They use this root-mean-square statistic, the one defined here, and as you can see from its expression, we don't use the mean to calculate it anymore. The previous statistic, the variance, needs the mean to be calculated, because if you remember, the variance is equal to the summation of x minus mu to the power of 2, divided by n.
So we need the mean to calculate the variance. So what the authors wanted to do in this paper, they said, okay, because we don't need to re-center, because we believe, we hypothesize that the re-centering is not needed to obtain the effect of the layer normalization, we want to find a statistic that doesn't depend on the mean, and the RMS statistic doesn't depend on the mean.
So they do exactly the same thing that they did in the layer normalization, so they calculate the RMS statistic by rows, so one for each row, and then they normalize according to this formula here, so they just divide by the statistic, RMS statistic, and then multiply by this gamma parameter, which is learnable.
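As a rough sketch of the formula just described, the following NumPy function divides each row by its root-mean-square statistic and multiplies by gamma. The small eps inside the square root is an assumption added here for numerical safety, and gamma is set to ones as it would be before training.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    # One RMS statistic per row; the mean is never computed.
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return gamma * x / rms

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=6.0, size=(10, 3))
gamma = np.ones(3)  # learnable scale, identity at initialization

y = rms_norm(x, gamma)

# Each output row has a root mean square close to 1.
row_rms = np.sqrt(np.mean(y ** 2, axis=-1))

# Re-scaling invariance: scaling the input leaves the output (almost) unchanged.
y_scaled = rms_norm(3.0 * x, gamma)
```

Note that the rows of y are not re-centered around zero, which is exactly the part of layer normalization that RMS normalization drops.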
Now, why root-mean-square normalization? Well, it requires less computation compared to layer normalization, because we are not computing two statistics, so the mean and the variance; we are only computing one, so it gives you a computational advantage. And it works well in practice, so what the authors of the paper hypothesized appears to hold: we only need the re-scaling invariance to obtain the effect of the layer normalization; we don't need the re-centering.
At least, this is what happens with LLaMA. The next topic we will be talking about is the positional encodings, but before we introduce the rotary positional encodings, let's review the positional encodings in the vanilla transformer. As you remember, after we transform our tokens into embeddings, so vectors of size 512 in the vanilla transformer, we add another vector to each of these embeddings that indicates the position of the token inside the sentence. These positional encodings are fixed, so they are not learned by the model: they are computed once and then reused for every sentence during training and inference, and each position gets its own vector of size 512.
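For reference, here is a minimal sketch of how these fixed vectors can be computed. The sine and cosine formula is not spelled out in this video; it is the one from the Attention Is All You Need paper as I recall it, so treat the exact form as an assumption.

```python
import numpy as np

def sinusoidal_positional_encoding(seq_len, d_model):
    # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    positions = np.arange(seq_len)[:, None]        # shape (seq_len, 1)
    even_dims = np.arange(0, d_model, 2)[None, :]  # shape (1, d_model / 2)
    angles = positions / (10000.0 ** (even_dims / d_model))
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)  # even dimensions
    pe[:, 1::2] = np.cos(angles)  # odd dimensions
    return pe

# Computed once, then added to the token embeddings of every sentence.
pe = sinusoidal_positional_encoding(seq_len=6, d_model=512)
print(pe.shape)  # (6, 512)
```

Nothing in this table depends on the tokens themselves, only on their position, which is why it can be reused for every sentence.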
We have a new kind of positional encoding called rotary positional encoding, so absolute positional encodings are fixed vectors that are added to the embedding of a token to represent its absolute position in the sentence, so the token number 1 gets its own vector, the token number 2 gets its own vector, the token number 3 gets its own vector, so the absolute positional encoding deals with one token at a time.
You can think of it as the pair latitude and longitude on a map: each point on the earth will have its own unique latitude and longitude, so that's an absolute indication of the position of each point on the earth, and this is the same as what happens with absolute positional encoding in the vanilla transformer.
We have one vector that represents exactly that position, which is added to that particular token in that position. Relative positional encodings, on the other hand, deal with two tokens at a time, and they are involved when we calculate the attention. Since the attention mechanism captures the intensity of how much two words are related to each other, relative positional encodings tell the attention mechanism the distance between the two words involved in this attention mechanism.
So, given two tokens, we create a vector that represents their distance. This is why it's called relative, because it's relative to the distance between two tokens. Relative positional encodings were first introduced in the following paper from Google, and you can notice that Vaswani, I think, is also one of the authors of the transformer model.
So, now, with absolute positional encodings, so from Attention Is All You Need, when we calculate the dot product in the attention mechanism, so if you remember the formula of the attention mechanism, let me write it, the attention is equal to the query multiplied by the transpose of the key, divided by the square root of d model; to all of this we apply the softmax, and then we multiply it by V. But we only concentrate on the Q multiplied by the K transposed in this case, and this is what we see here.
So, when we calculate this dot product, the attention mechanism is calculating the dot product between two tokens, that already have the absolute position encoded into them, because we already added the absolute positional encoding to each token. So, in this attention mechanism from the vanilla transformer, we have two tokens and the attention mechanism, while in relative positional encodings, we have three vectors.
We have token one, token two, and then we have this vector here that represents the distance between these two tokens, and so we have three vectors involved in this attention mechanism, and we want the attention mechanism to actually match these tokens differently based on this vector here.
So, this vector will indicate to the attention mechanism, so to the dot product, how to relate these two words that are at this particular distance. With rotary positional embeddings, we do a similar job, and they were introduced with this paper, so RoFormer, and they are from a Chinese company.
So, the dot product used in the attention mechanism is a type of inner product. So, if you remember from linear algebra, the dot product is a kind of operation that has some properties, and these properties are the kind of properties that every inner product must have. So, the inner product can be thought of as a generalization of the dot product.
What the authors of the paper wanted to know is: can we find an inner product over the two vectors, query and key, used in the attention mechanism, that only depends on the two vectors themselves and the relative distance of the tokens they represent? That is, we are given two vectors, query and key, that only contain the embedding of the word that they represent and its position inside the sentence. So this m is actually an absolute number, so it's a scalar, it represents the position of the first word inside the sentence, and this n represents the position of the second word inside the sentence.
What they wanted to say is, can we find an inner product, so this particular parenthesis we see here is an inner product between these two vectors, that behaves like this function g, which only depends on the embedding of x_m, so the first token, of x_n, the second token, and the relative distance between them, and no other information?
So this function will be given only the embedding of the first token, the embedding of the second token, and a number that represents the relative position of these two tokens, relative distance of these two tokens. Yes, we can find such a function, and the function is the one defined here.
So we can define a function g, like the following, that only depends on the two embedding vectors q and k and the relative distance. And this function is defined in the complex number space, and it can be converted by using Euler's formula into this form. And another thing to notice is that this function here, the one we are looking at, is defined for vectors of dimension 2.
Of course later we will see what happens when the dimension is bigger. And when we convert this expression here, which is in the complex number space, into its matrix form, through Euler's formula, we can recognize this matrix here as the rotation matrix. So this matrix here basically represents the rotation of a vector.
For example, this one here, so this product here, will be a vector, and this rotation matrix will rotate this vector into the space by the amount described by m theta, so the angle m theta. Let's see an example. So imagine we have a vector v0, and we want to rotate it by theta, by an angle theta here, to arrive to the vector v prime.
So what we do is, we multiply the vector v0 with this matrix, exactly this one, in which the values are calculated like this, cosine of theta, minus sine of theta, sine of theta, and cosine of theta. And the resulting vector will be the same vector, so the same length, but rotated by this angle.
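The rotation example can be reproduced in a few lines of NumPy. The vector and the angle of 90 degrees are arbitrary choices for illustration.

```python
import numpy as np

def rotation_matrix(theta):
    # 2x2 matrix that rotates a vector counter-clockwise by theta radians.
    return np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta),  np.cos(theta)],
    ])

v0 = np.array([1.0, 0.0])
theta = np.pi / 2  # 90 degrees

v_prime = rotation_matrix(theta) @ v0

# The direction changes, the length does not.
length_before = np.linalg.norm(v0)
length_after = np.linalg.norm(v_prime)
```

With these values, v_prime points along the second axis (up to floating-point error), and both lengths are equal to 1.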
And this is why they are called rotary positional embeddings, because this matrix represents a rotation. Now, when the vector is not two-dimensional, but we have n dimensions, for example in the original transformer model our embedding size is 512, and in LLaMA it's 4096, we need to use this form.
Now, I want you to notice not what the numbers in this matrix are, but the fact that this matrix is sparse, so it is not convenient to use it to compute the positional embeddings, because if we multiply by this matrix, our TensorFlow, our GPU, our computer will do a lot of operations that are useless, because we already know that most of the products will be zero.
So, is there a better way, a more computationally efficient way to do this computation? Well, there is, this form here. So, given a token with the embedding vector x, and the position m of the token inside the sentence, this is how we compute the position embedding for the token.
We take the dimensions of the embedding, x1, x2, x3, etc., so the first dimension of the embedding, the second dimension of the embedding, and so on, and we multiply them element-wise by this vector of cosines, computed like the following, where the thetas are fixed and m is the position of the token. Then we add a second vector, built from the same embedding with the dimensions swapped in pairs and with a sign change, so minus x2, which is the negative value of the second dimension of the embedding, then x1, then minus x4, then x3, and so on, multiplied element-wise by this vector of sines.
So, there is nothing we have to learn in this computation, everything is fixed, because if we look at the previous slide, we can see that this theta is actually computed like this, one for each pair of dimensions, and so there is nothing to learn. So, basically, they are just like the absolute positional encodings: we compute them once, and then we can reuse them for all the sentences that we will train the model upon.
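Here is a minimal NumPy sketch of this efficient form for a single vector. The function name rope and the formula for theta, base to the power of minus 2(i-1)/d with base 10000, are my reading of the RoFormer paper and are not stated in this video, so treat them as assumptions. The check at the end illustrates the property we were looking for: the dot product only depends on the relative distance between the two positions.

```python
import numpy as np

def rope(x, m, base=10000.0):
    # Apply rotary positional embedding to a 1-D vector x at position m.
    d = x.shape[-1]
    assert d % 2 == 0, "the embedding dimension must be even"
    # One fixed theta per pair of dimensions: theta_i = base^(-2(i-1)/d).
    theta = base ** (-np.arange(0, d, 2) / d)
    # Repeat each angle twice: [m*t1, m*t1, m*t2, m*t2, ...]
    angles = np.repeat(m * theta, 2)
    # Pairwise swap with sign change: [-x2, x1, -x4, x3, ...]
    x_swapped = np.empty_like(x)
    x_swapped[0::2] = -x[1::2]
    x_swapped[1::2] = x[0::2]
    # Element-wise form, no sparse matrix multiplication needed.
    return x * np.cos(angles) + x_swapped * np.sin(angles)

rng = np.random.default_rng(0)
q = rng.normal(size=8)
k = rng.normal(size=8)

# Same relative distance (2), different absolute positions.
score_a = np.dot(rope(q, 3), rope(k, 1))
score_b = np.dot(rope(q, 10), rope(k, 8))
```

The two scores agree up to floating-point error, and because every pair of dimensions is only rotated, the length of the vector is preserved.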
Another interesting property of the rotary positional embeddings is the long-term decay. So, what the authors did is they calculated an upper bound for the inner product that we saw before, so the g function, by varying the distance between the two tokens, and then they showed that, no matter what the two tokens are, there is an upper bound that decreases as the distance between the two tokens grows.
And if you remember that the inner product or the dot product that we are computing is for the calculation of the attention, this dot product represents the intensity of relationship between the two tokens for which we are computing the attention. And what these rotary positional embeddings do, they will basically decay this relationship, the strength of this relationship between the two tokens, if the two tokens that we are matching are distant from each other.
And this is actually what we want. So, we want two words that are very far from each other to have a less strong relationship, and two words that are close to each other to have a stronger relationship. And this is a desired property that we want from these rotary positional embeddings.
Now, the rotary positional embeddings are only applied to the queries and the keys, but not to the values. Let's see why. Well, the first consideration is that they come into play when we are calculating the attention. So, when we calculate the attention, it's the attention mechanism that will change the score.
So, as you remember, the attention score tells us how strong the relationship between two tokens is. So, this relationship will be stronger or weaker, or will change, according also to the position of these two tokens inside of the sentence and the relative distance between these two tokens.
Another thing is that the rotary positional embeddings are applied after the vectors Q and K have been multiplied by the W matrix in the attention mechanism, while in the vanilla transformer, the positional encodings are applied before. So, in the vanilla transformer, the positional encodings are applied right after we transform the tokens into embeddings.
But with the rotary positional embeddings, so in LLaMA, we don't do this. We apply them right after we multiply by the W matrix in the attention mechanism. So, the W matrix, if you remember, is the matrix of parameters that each attention head has. And so, in LLaMA, basically, we apply the rotary positional encoding after we multiply the vectors Q and K by the W matrix.
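To make this order of operations concrete, here is a minimal NumPy sketch, not the LLaMA code itself. It assumes the common convention of rotating consecutive pairs of dimensions with frequencies `base ** (-2i/d)`; the names `rope`, `W_q`, `W_k` and the toy sizes are made up for the illustration. It projects first, rotates second, and checks that the query-key score only depends on the distance between the two positions.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate consecutive pairs of dimensions of x by angles pos * theta_i."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)   # one frequency per pair (assumed convention)
    angles = pos * theta
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
d_model, d_head = 16, 8                          # hypothetical toy sizes
x_m, x_n = rng.normal(size=d_model), rng.normal(size=d_model)  # two token embeddings
W_q = rng.normal(size=(d_model, d_head))
W_k = rng.normal(size=(d_model, d_head))

# Project first, then rotate. The values would be projected but never rotated.
q, k = x_m @ W_q, x_n @ W_k

score_a = rope(q, pos=5) @ rope(k, pos=2)        # positions 5 and 2, distance 3
score_b = rope(q, pos=105) @ rope(k, pos=102)    # positions 105 and 102, distance 3
print(np.isclose(score_a, score_b))
```

The two scores agree because rotating the query by position m and the key by position n leaves only the rotation by (n - m) inside the dot product.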
Now comes the interesting part, in which we will see how the self-attention works in LLaMA. But before we can talk about the self-attention as used in LLaMA, we need to review, at least briefly, the self-attention in the vanilla transformer. So, if you remember the self-attention in the vanilla transformer, we start with the matrix Q, which is a matrix of sequence length by d_model, which means that we have the tokens on the rows, and the dimensions of the embedding vector on the columns.
So, we can think of it like the following. We have six rows, and each of these rows is a vector of dimension 512 that represents the embedding of that token. And then, we multiply according to this formula.
So, Q multiplied by the transpose of K, divided by the square root of 512, which is the dimension of the embedding vector, where K is equal to Q and V is also equal to Q, because this is a self-attention. So, the three matrices are actually the same sequence.
Then, we apply the softmax and we obtain this matrix. So, we had the matrix that was 6 by 512 multiplied by another one that is 512 by 6. We will obtain a matrix that is 6 by 6, where each item in this matrix represents the dot product of the first token with itself, then the first token with the second token, the first token with the third token, the first token with the fourth token, etc.
So, this matrix captures the intensity of the relationship between two tokens. Then, the output of this softmax is multiplied by the V matrix to obtain the attention output. So, the output of the self-attention is another matrix that has the same dimensions as the initial matrix. So, it will produce a sequence where the embeddings now capture not only the meaning of each token and the position of each token, but also kind of the relationship between that token and every other token.
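The computation just described can be sketched in a few lines of NumPy. This is only an illustration under the same simplification used in the walkthrough, so Q = K = V with no parameter matrices and no mask; the helper names are made up.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)    # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x):
    """Single-head self-attention with Q = K = V = x."""
    d_model = x.shape[-1]
    scores = x @ x.T / np.sqrt(d_model)         # (seq, seq) scaled dot products
    weights = softmax(scores, axis=-1)          # each row sums to 1
    return weights @ x, weights                 # (seq, d_model), (seq, seq)

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 512))                   # 6 tokens, 512-dimensional embeddings
out, weights = self_attention(x)
print(out.shape, weights.shape)                 # (6, 512) (6, 6)
```

The output has the same shape as the input, and the 6 by 6 matrix holds one weight for every pair of tokens.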
If you didn't understand this concept, please go back and watch my previous video about the transformer where I explain it very carefully and in much more detail. Now, let's have a look at the multi-head attention very briefly. So, the multi-head attention basically means that we have an input sequence, we take it, we copy it into Q, K, and V, so they are the same matrix, we multiply by parameter matrices, and then we split into multiple smaller matrices, one for each head, and we calculate the attention within each of these heads.
So, head 1, head 2, head 3, head 4. Then, we concatenate the output of these heads, we multiply by the output matrix W_O, and finally we have the output of the multi-head attention. Now, let's look at what the KV cache is. So, before we introduce the KV cache, we need to understand how LLaMA was trained, and we need to understand what the next token prediction task is.
So, LLaMA, just like most of the large language models, has been trained on the next token prediction task, which means that given a sequence, it will try to predict the next token, the most likely next token, to continue the prompt. So, for example, if we give it a poem without the last word, it will probably come up with the last word that is missing from that poem.
In this case, I will be using one very famous passage from Dante Alighieri, and I will not use the Italian original, but the English translation. So, I will only deal with the first line you can see here, "Love that can quickly seize the gentle heart".
So, let's train LLaMA on this sentence. How does the training work? Well, we give the input to the model, and the input is built in such a way that we first prepend the start of sentence token, and then the target is built such that we append an end of sentence token.
Why? Because the model, this transformer model, is a sequence-to-sequence model, which maps each position in the input sequence into another position in the output sequence. So, basically, the first token of the input sequence will be mapped to the first token of the output sequence, and the second token of the input sequence will be mapped to the second token of the output sequence, etc., etc., etc.
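As a small illustration of how the input and the target line up, here is a sketch that assumes a naive whitespace tokenizer and the placeholder token names `<SOS>` and `<EOS>`; a real model would use its own tokenizer and special tokens.

```python
sentence = "Love that can quickly seize the gentle heart".split()

model_input = ["<SOS>"] + sentence   # prepend the start-of-sentence token
target = sentence + ["<EOS>"]        # append the end-of-sentence token

# Position i of the input is trained to predict position i of the target.
for inp, tgt in zip(model_input, target):
    print(f"{inp:>8} -> {tgt}")
```

Both sequences have the same length, and the target is simply the input shifted by one position.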
This also means that if we give our model the input "sos", it will produce the first token as output, so "love"; then if we give the first two tokens, it will produce the second token as output, so "love that"; and if we give the first three tokens, it will produce the third token as output.
Of course, the model will also produce the output for the previous two tokens, but let's see it with an example. So, if you remember from my previous video, in which I also do the inferencing, when we train the model, we only do it in one step, so we give the input, we give the target, we calculate the loss, and we don't have any for loop to train the model for one single sentence, but for the inference, we need to do it token by token.
So, in this inferencing, we start with a time step, time step one, in which we only give the input "sos", so start of sentence, and the output is "love". Then, we take the output token here, "love", and we append it to the input, and we give it again to the model, and the model will produce the next token, "love that".
Then, we take the last token output by the model, "that", we append it again to the input, and the model will produce the next token. And then, we again take the next token, so "can", we append it to the input, and we feed it again to the model, and the model will output the next token, "quickly".
And we do it for all the steps that are necessary until we reach the end of sentence token. That's when we know that the model has finished producing its output. Now, this is not how LLaMA was trained, actually, but this is a good example to show you how the next token prediction task works.
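The token-by-token loop can be sketched like this. The `toy_model` below is a purely hypothetical stand-in that has memorised the one line of the poem, so it only shows the shape of the loop, not a real language model.

```python
sentence = "Love that can quickly seize the gentle heart".split()
# Hypothetical stand-in for a trained model: a lookup table for this one line.
next_token_table = dict(zip(["<SOS>"] + sentence, sentence + ["<EOS>"]))

def toy_model(tokens):
    """Return one predicted token per input position (sequence-to-sequence)."""
    return [next_token_table[t] for t in tokens]

def generate(max_steps=20):
    tokens = ["<SOS>"]
    for _ in range(max_steps):
        outputs = toy_model(tokens)   # the whole sequence is fed in again
        next_token = outputs[-1]      # only the last output is new
        if next_token == "<EOS>":
            break
        tokens.append(next_token)
    return tokens[1:]                 # drop the start-of-sentence token

print(" ".join(generate()))
```

Notice that at every step the model produces an output for every input position, and we throw away all of them except the last one.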
Now, there is a problem with this approach. Let's see why. At every step of the inference, we are only interested in the last token output by the model, because we already have the previous ones. However, the model needs to access all the previous tokens to decide on which token to output, since they constitute its context, or the prompt.
So, what I mean by this is that to output, for example, the word "the", the model has to see all the input here. We cannot just give it "seize". The model needs to see all the input to output this last token, "the". But the point is, this is a sequence-to-sequence model, so it will produce this whole sequence as output, even if we only care about the last token.
So, there is a lot of unnecessary computation we are doing to calculate these tokens, again, that we already actually have from the previous time steps. So, let's find a way to not do this useless computation. And this is what we do with the KV cache. So, the KV cache is a way to do less computation on the tokens that we have already seen during inferencing.
So, it's only applied during inferencing in a transformer model, and it applies not only to a transformer like the one in LLaMA, but to all transformer models, because all transformer models work in this way. This is a picture of how the self-attention works during the next token prediction task.
So, as you saw also in my previous slides, we have a query matrix here with N tokens, then we have the transpose of the keys, so the queries can be thought of as rows of vectors, where the first vector represents the first token, the second vector the second token, etc. Then the transpose of the keys is the same tokens but transposed, so the rows become columns.
This produces a matrix that is N by N, so if the initial input has 9 tokens, the output matrix will be 9 by 9. Then we multiply it by the V matrix, and this will produce the attention. The attention is then fed to the linear layer of the transformer, then the linear layer will produce the logits, and the logits are fed to the softmax, and the softmax allows us to decide which token to pick from our vocabulary.
Again, if you are not familiar with this, please watch my previous video of the transformer about the inferencing of the transformer, and you will see this clearly. So, this is a description of what happens at a general level in the self-attention. Now, let's watch it step by step. So, imagine at inference step 1, we only have the first token.
If you remember before, we were only using the start of sentence token. So, we take the start of sentence token, we multiply it by itself, so the transposed, it will produce a matrix that is 1 by 1, so this matrix is 1 by 4096, multiplied by another matrix that is 4096 by 1, it will produce a 1 by 1 matrix.
Why 4096? Because the dimension of the embedding vector in LLaMA is 4096. Then the output, so this 1 by 1, is multiplied by the V, and it will produce the output token here, and this will be our first token of the output. And then we take the output token, this one, and we append it to the input at the next step.
So, now we have two tokens as input. They are multiplied by the transposed version of themselves, and this will produce a 2 by 2 matrix, which is then multiplied by the V matrix, and it will produce two output tokens. But we are only interested in the last token output by the model, so this one, attention 2, which is then appended to the input matrix at time step 3.
So, now we have three tokens at time step 3, which are multiplied by the transposed version of themselves, and this will produce a 3 by 3 matrix, which is then multiplied by the V matrix, and we have these three tokens as output. But we are only interested in the last token output by the model, so we append it again as input to the Q matrix, which now has four tokens. This is multiplied by the transposed version of itself, and it will produce a 4 by 4 matrix as output, which is then multiplied by this matrix here, and it will produce this attention matrix here.
But we are only interested in the last attention, which will then be added again to the input of the next step. But we already notice something. Here in this matrix, we compute the dot product between this token and this, this token and this, this token and this.
So this matrix contains all the dot products between these two matrices. The first thing we can see is that we already computed these dot products in the previous step. Can we cache them? So let's go back. As you can see, this matrix is growing. Two, three, four. See, there is a lot of repeated computation, because every time we are inferencing the transformer, we are giving the transformer the whole input, so it's re-computing all these dot products, which is inconvenient, because we actually already computed them in the previous time step.
So is there a way to not compute them again? Can we kind of cache them? Yes, we can. And then, since the model is causal, we don't care about the attention of a token with its successors, but only with the tokens before it. So as you remember, in the self-attention, we apply a mask, right?
So the mask is basically, we don't want the dot product of one word with the word that comes after it, but only the one that comes before it. So basically, we don't want all the numbers above the principal diagonal of this matrix. And that's why we applied the mask in the self-attention.
But okay, the point is, we don't need to compute all these dot products. The only dot products that we are interested in are the ones in this last row. Because we added token 4 as input compared to the last time step, we only have this new token, token 4, and we want to know how this token 4 is interacting with all the other tokens.
So basically, we are only interested in this last row here. And also, we only care about the attention of the last token, because we want to select the word from the vocabulary. So we don't care about producing these first three attention outputs here in the output sequence of the self-attention, we only care about the last one.
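The redundancy is easy to see numerically. The sketch below uses small random vectors as stand-ins for the token embeddings, with a toy dimension of 8 instead of 4096, and compares the masked score matrix at time step 3 with the one at time step 4.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                                    # small embedding size for readability
tokens = rng.normal(size=(4, d))         # 4 token embeddings (random stand-ins)

def causal_scores(x):
    """Q @ K^T with Q = K = x; entries above the diagonal are masked to -inf."""
    scores = x @ x.T
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
    return np.where(mask, -np.inf, scores)

step3 = causal_scores(tokens[:3])        # 3 x 3 matrix at time step 3
step4 = causal_scores(tokens)            # 4 x 4 matrix at time step 4

# The top-left 3 x 3 block of step 4 repeats what step 3 already computed.
print(np.allclose(step4[:3, :3], step3))
# Only the last row is new: token 4 against tokens 1 to 4.
print(step4[-1])
```

Everything except the last row was already available at the previous step, and the masked entries above the diagonal are never used at all.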
So is there a way to remove all these redundant calculations? Yes, we can do it with the KV cache. Let's see how. So with the KV cache, basically, what we do is we cache the keys and the values. And every time we have a new token, we append it to the keys and the values, while the query is only the output of the previous step.
So at the beginning, we don't have any output from the previous step, so we only use the first token. So the first, the time step one of the inference is the same as without the cache. So we have the token one with itself, will produce a matrix one by one, multiplied with one token, and it will produce one attention.
However, at the time step two, we don't append it to the previous query, we just replace the previous token with the new token we have here. However, we keep the cache of the keys. So we keep the previous token in the keys, and we append the last output to the keys here, and also to the values.
And if you do this multiplication, it will produce a matrix that is one by two, where the first item is the dot product of token two with token one and the second is the dot product of token two with token two. This is actually what we want. And if we then multiply by the V matrix, it will only produce one attention output, which is exactly the one we want.
And we do again, so we take this attention two, and this will become the input of the next inference step. So this token three, we append it to the previously cached K matrix and also to the previously cached V matrix. This multiplication will produce an output matrix that we can see here.
The multiplication of this output matrix with this V matrix will produce one token in the output, which is this one, and we know which token to select using this one. Then we use it as an input for the next inferencing step by appending it to the cached keys and appending to the cached V matrix.
We do this multiplication, and we will get this matrix, which is one by four, which is the dot product of the token four with the token one, the token four with the token two, token four with the token three, and the token four with itself. We multiply by the V matrix, and this will only produce one attention, which is exactly what we want to select the output token.
This is the reason why it's called the KV cache, because we are keeping a cache of the keys and the values. As you can see, the KV cache allows us to save a lot of computation because we are not doing a lot of dot products that we used to do before, and this makes the inferencing faster.
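Here is a minimal sketch of the idea, under the same simplification as the pictures: the token vector is used directly as query, key and value, so the parameter matrices and the multiple heads are left out. The `KVCache` class and its `step` method are hypothetical names for this illustration. The check at the end confirms that the cached computation gives the same result as the last row of the full causal attention.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(x):
    """Recompute attention for every position (no cache); Q = K = V = x."""
    n, d = x.shape
    scores = x @ x.T / np.sqrt(d)
    scores = np.where(np.triu(np.ones((n, n), dtype=bool), k=1), -np.inf, scores)
    return softmax(scores) @ x

class KVCache:
    """Grows by one key row and one value row per decoding step."""
    def __init__(self, d):
        self.keys = np.empty((0, d))
        self.values = np.empty((0, d))

    def step(self, token):
        # Append the new token to the cached keys and values.
        self.keys = np.vstack([self.keys, token])
        self.values = np.vstack([self.values, token])
        # The query is only the new token, so the scores are a single row of length t.
        scores = token @ self.keys.T / np.sqrt(token.shape[-1])
        return softmax(scores) @ self.values

rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 16))        # random stand-ins for 4 token embeddings
cache = KVCache(d=16)
cached_out = [cache.step(t) for t in tokens]
full_out = full_causal_attention(tokens)
print(np.allclose(cached_out[-1], full_out[-1]))
```

At step t the cached version computes t dot products instead of t squared, and no mask is needed because the cache only ever contains the current token and the ones before it.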
The next layer that we will be talking about is the grouped multi-query attention, but before we talk about the grouped multi-query attention, we need to introduce its predecessor, the multi-query attention. Let's see. So let's start with the problem. The problem is that the GPUs are too fast. If you look at this datasheet, this is from the A100 GPU from NVIDIA, we can see that the GPU is very fast at performing calculations, but not so fast at transferring data from its memory.
That means, for example, that the A100 can do 19.5 tera floating point operations per second using 32-bit precision, while it can only transfer 1.9 thousand gigabytes per second. It's nearly 10 times slower at transferring data than it is at performing calculations, and this means that sometimes the bottleneck is not how many operations we perform, but how much data transfer our operations need, and that depends on the size and the quantity of the tensors involved in our calculations.
For example, if we compute the same operations on the same tensor n times, it may be faster than computing the same operations on n different tensors, even if they have the same size. This is because the GPU may need to move these tensors around. So this means that our goal should not only be to optimize the number of operations we do with our algorithms, but also minimize the memory access and the memory transfers that our algorithms perform, because the memory access and memory transfer are more expensive in terms of time compared to the computations.
And this also happens with software when we do I/O, for example. If we do some multiplications on the CPU or we read some data from the hard disk, reading from the hard disk is much slower than doing a lot of computations on the CPU.
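A quick back-of-the-envelope calculation with the two figures quoted above shows the gap. The step of converting bytes into 32-bit values is an extra assumption added here to put the two numbers in comparable units.

```python
# Figures quoted above for the NVIDIA A100 (FP32 compute and memory bandwidth).
flops_per_second = 19.5e12       # 19.5 TFLOPS
bytes_per_second = 1.9e12        # roughly 1.9 thousand GB/s
bytes_per_fp32 = 4               # assumption: every value moved is a 32-bit float

values_per_second = bytes_per_second / bytes_per_fp32
# Operations the GPU can perform in the time it takes to load one FP32 value.
ops_per_value_loaded = flops_per_second / values_per_second

print(f"{flops_per_second / bytes_per_second:.1f} FLOPs per byte transferred")
print(f"{ops_per_value_loaded:.0f} FLOPs per FP32 value loaded")
```

So an algorithm that does only a handful of operations on each value it loads will leave the compute units waiting on memory.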
And this is a problem. Now, the multi-query attention was introduced in this paper. This paper is from Noam Shazeer, who is also one of the authors of the attention paper, so "Attention Is All You Need". And in this paper, he introduced the problem. He said, well, let's look at the multi-head attention.
So the batched multi-head attention. This is the multi-head attention as presented in the original paper, "Attention Is All You Need". Let's look at the algorithm and let's calculate the number of arithmetic operations performed and also the total memory involved in these operations. So he calculated that the number of arithmetic operations performed is O(bnd^2), where b is the batch size, n is the sequence length, and d is the size of the embedding vector.
While the total memory involved in the operations, given by the sum of all the tensors involved in the calculations, including the derived ones, is equal to O(bnd + bhn^2 + d^2), where h is the number of heads in this multi-head attention. Now, if we compute the ratio between the total memory and the number of arithmetic operations, we get this expression here, O(1/k + 1/(bn)), where k is the dimension of each head.
In this case, the ratio is much smaller than 1, which means that the number of memory accesses that we perform is much less than the number of arithmetic operations. So the memory access in this case is not the bottleneck. So what I mean to say is that the bottleneck of this algorithm is not the memory access, it is actually the number of computations.
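For reference, the ratio can be checked term by term. This is a sketch that takes the memory as bnd + bhn^2 + d^2 and the operations as bnd^2, and it assumes that k = d/h is the dimension of each head and that the sequence length n is not larger than d.

```latex
\frac{\text{memory}}{\text{operations}}
  = \frac{bnd + bhn^2 + d^2}{bnd^2}
  = \frac{1}{d} + \frac{hn}{d^2} + \frac{1}{bn}
  = \frac{1}{d} + \frac{1}{k}\cdot\frac{n}{d} + \frac{1}{bn}
  = O\!\left(\frac{1}{k} + \frac{1}{bn}\right)
  \quad \text{when } n \le d,\; k = \frac{d}{h}.
```

Both 1/d and (1/k)(n/d) are at most 1/k under these assumptions, which is why only the 1/k and 1/(bn) terms survive.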
And as you saw before, when we introduced the KV cache, the problem we were trying to solve is the number of computations, but by introducing the KV cache, we created a new bottleneck, and it's not the computation anymore. So this algorithm here is the multi-head self-attention, but using the KV cache, and this reduces the number of operations performed.
So if we look at the number of arithmetic operations performed, it's O(bnd^2). The total memory involved in the operations is O(bn^2d + nd^2), and the ratio between the two is this, O(n/d + 1/b), so the ratio between the total memory and the number of arithmetic operations. This means that when n is very similar to d, this ratio will become close to 1, or when b is very similar to 1, or in the limit is 1, so the batch size is 1, this ratio will also become close to 1.
And this is a problem, because when this condition holds, the memory access becomes the bottleneck of the algorithm. And this also means that we need to keep the dimension of the embedding vector much bigger than the sequence length; if we increase the sequence length without making the dimension of the embedding vector much bigger, the memory access will become the bottleneck.
So we need to find a better way. To solve the problem of the previous algorithm, in which the memory became the bottleneck, the multi-query attention was introduced. So what the author did was to remove the h dimension from the K and the V, while keeping it for the Q.
So it's still a multi-head attention, but only with respect to Q, that's why it's called multi-query attention. So we will have multiple heads only for the Q, but the K and V will be shared by all the heads. And if we use this algorithm, the ratio becomes this, O(1/d + n/(dh) + 1/b).
So we compare it to the previous one, in which it was n/d, and now it's n/(dh). So we reduced the ratio n/d by a factor of h, because we removed the h heads for the K and V. So the performance gains are actually important, because now it is less likely that this ratio will become 1.
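To get a feeling for the two expressions, we can plug in some numbers. The sizes below are hypothetical and chosen only for illustration, and the constants hidden by the big-O notation are ignored.

```python
def mha_incremental_ratio(n, d, b):
    """Memory-to-compute ratio n/d + 1/b for multi-head attention with a KV cache."""
    return n / d + 1 / b

def mqa_incremental_ratio(n, d, h, b):
    """Memory-to-compute ratio 1/d + n/(d*h) + 1/b for multi-query attention."""
    return 1 / d + n / (d * h) + 1 / b

# Hypothetical sizes chosen only for illustration.
d, h, b = 4096, 32, 64
for n in (512, 4096, 32768):
    mha = mha_incremental_ratio(n, d, b)
    mqa = mqa_incremental_ratio(n, d, h, b)
    print(f"n={n:>6}  multi-head: {mha:.3f}  multi-query: {mqa:.3f}")
```

With these sizes the multi-head ratio passes 1 as soon as n reaches d, while the multi-query ratio stays well below 1 for much longer sequences.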
But of course, by removing the heads from the K and V, our model will also have fewer parameters, so it will have fewer degrees of freedom and less complexity, which may degrade the quality of the model. And it actually does degrade the quality of the model, but only slightly, and we will see.
So if we compare, for example, the BLEU score on a translation task from English to German, we can see that the multi-head attention, so the attention that was in the original attention paper, has a BLEU score of 26.7, while the multi-query has a BLEU score of 26.5. The author also compared it with the multi-head local and multi-query local, where local means that they restrict the attention calculation only to the previous 31 positions of each token.
And we can see it here. But the performance gains from reducing the heads of the K and V are great, because you can see the inference time, for example, on the original multi-head attention and the multi-query attention. The inference time went from 1.7 microseconds plus 46 microseconds for the decoder to 1.5 microseconds plus 3.8 microseconds for the decoder.
So in total here, more or less, we took 48 microseconds, while here we take more or less 5 microseconds for the multi-query. So it's a great benefit from a performance point of view during the inferencing. Let's talk about grouped multi-query attention, because now we just introduced the KV cache and the multi-query attention.
But the next step of the multi-query attention is the grouped multi-query attention, which is the one that is used in LLaMA. So let's have a look at it. With multi-query, we only have multiple heads for the queries, but only one head for the keys and the values. With grouped multi-query attention, basically, we divide the queries into groups.
So for example, this is the group 1, this is the group 2, group 3 and group 4. And for each group, we have one different head of K and V. This is a good compromise between the multi-head, in which there is a one-to-one correspondence, and the multi-query, where there is an n-to-one correspondence.
So in this case, we still have multiple heads for the keys and values, but they are fewer in number compared to the number of heads of the queries. And this is a good compromise between the quality of the model and the speed of the model, because anyway, here we benefit from the computational benefit of the reduction in the number of heads of keys and values, but we don't sacrifice too much on the quality side.
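The grouping can be sketched in NumPy as follows. This is an illustration rather than the LLaMA implementation: the function name and the sizes (8 query heads sharing 4 key/value heads) are made up, and the causal mask and the projection matrices are left out to keep it short.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(q, k, v):
    """q: (n_q_heads, seq, d_head); k, v: (n_kv_heads, seq, d_head)."""
    n_q_heads, n_kv_heads = q.shape[0], k.shape[0]
    assert n_q_heads % n_kv_heads == 0
    group_size = n_q_heads // n_kv_heads
    # Each key/value head is shared by `group_size` consecutive query heads.
    k = np.repeat(k, group_size, axis=0)
    v = np.repeat(v, group_size, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v                   # (n_q_heads, seq, d_head)

rng = np.random.default_rng(0)
seq, d_head = 5, 16
q = rng.normal(size=(8, seq, d_head))            # 8 query heads
k = rng.normal(size=(4, seq, d_head))            # 4 key/value heads, so 4 groups of 2
v = rng.normal(size=(4, seq, d_head))

out = grouped_query_attention(q, k, v)
print(out.shape)                                 # (8, 5, 16)
```

With one key/value head this reduces to the multi-query attention, and with as many key/value heads as query heads it reduces to the standard multi-head attention. Only the small k and v tensors would need to be stored in the KV cache.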
And now the last part of the model. As you can see here, the feedforward layer in the LLaMA model has had its activation function changed to the SwiGLU function. Let's have a look at how it works. So the SwiGLU function was analyzed in this famous paper from Noam Shazeer, who is also one of the authors of the attention paper and of the multi-query attention that we saw before.
So let's have a look at this paper. So the author compared the performance of the transformer model by using different activation functions in the feedforward layer of the transformer architecture. And the one we are interested in is this SwiGLU here, which is basically the swish function with beta equal to one, evaluated at x multiplied by a W matrix, which is a parameter matrix, which is then multiplied element-wise with x multiplied by V, where V is another parameter matrix, and then by W2, which is another parameter matrix.
So compare this with the original feedforward network: here we have three parameter matrices, while in the original feedforward network, we only had two. So to make the comparison fair, the author reduced the size of these matrices by a factor of two-thirds, such that the model's total number of parameters remains the same as the vanilla transformer.
In the vanilla transformer, we had this feedforward network, which used the ReLU function. So this max zero, et cetera, is the ReLU function. And we only had the two parameter matrices. Actually, some successor versions of the transformer didn't have the bias. So I took this formula from the paper, but there are many implementations without the bias actually.
And in LLaMA, we use this computation for the feedforward network. And this is the code I took from the LLaMA repository. And as you can see, it's just what the formula says. It's the SiLU function. Why the SiLU function? Because it's the swish function with beta equal to one.
And when we give beta equal to one to the swish function, which has this expression, it's called the sigmoid linear unit, SiLU, which has this graph. So the SiLU function is evaluated at w1 of x, then multiplied by w3 of x, and then we apply w2 to the result.
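As a rough NumPy sketch of this computation, with toy sizes and plain weight matrices in place of the linear layers: the names w1, w2 and w3 follow the ones mentioned above, while the shapes and the random initialization are assumptions made for the illustration.

```python
import numpy as np

def silu(x):
    """Swish with beta = 1: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x, w1, w2, w3):
    """Bias-free feed-forward: (silu(x @ w1) * (x @ w3)) @ w2."""
    return (silu(x @ w1) * (x @ w3)) @ w2

rng = np.random.default_rng(0)
d_model, d_hidden = 8, 16                        # hypothetical toy sizes
w1 = rng.normal(size=(d_model, d_hidden))
w3 = rng.normal(size=(d_model, d_hidden))
w2 = rng.normal(size=(d_hidden, d_model))
x = rng.normal(size=(3, d_model))                # 3 tokens

print(swiglu_ffn(x, w1, w2, w3).shape)           # (3, 8)
print(silu(np.array([-1.0, 0.0, 1.0])))          # small negative value at x = -1
```

The second print shows that SiLU does not cut negative inputs to exactly zero the way ReLU does.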
So we have three matrices. And these three matrices are basically linear layers. Now they use the parallelized version of this linear layer, but it's a linear layer. And if we look at the graph of this SiLU function, we can see that it's kind of like a ReLU, but in this part here before the zero, we don't cancel out the activation immediately.
We keep a little tail here so that even values that are very close to zero from the negative side are not automatically canceled out by the function. So let's see how it performs. So this SwiGLU function actually performs very well. Here they evaluate the log perplexity of the model when we use this particular function.
And we can see that the perplexity here is the lowest. The perplexity basically means how unsure the model is about its choices. And the SwiGLU function is performing well. Then they also ran the comparison on many benchmarks. And we see that this SwiGLU function is performing quite well on a lot of them.
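To give an idea of what the perplexity measures, here is a tiny sketch using the usual definition, the exponential of the average negative log-probability that the model assigns to the correct tokens. The probabilities below are invented for the example.

```python
import math

def perplexity(token_probs):
    """exp of the average negative log-probability assigned to the correct tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

confident = [0.9, 0.8, 0.95, 0.85]   # hypothetical probabilities of the correct tokens
unsure = [0.3, 0.2, 0.25, 0.4]

print(perplexity(confident), perplexity(unsure))
```

A model that puts high probability on the correct tokens gets a perplexity close to 1, and the more unsure it is, the higher the number.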
So why is this SwiGLU activation function working so well? If we look at the conclusion of this paper, we see: "We offer no explanation as to why these architectures seem to work; we attribute their success, as all else, to divine benevolence." Actually, this is kind of funny, but it's also kind of true.
Because in most of the deep learning research, we do not know why things work in the way they do. Because imagine you have a model of 70 billion parameters. How can you prove what is happening to each one of them after you modify one activation function? It's not easy to come up with a model that can explain why the model is reacting in a particular way.
What usually we do, we can either simplify the model, so we can work with a very small model, and then make some assumptions on why things work the way they do. Or we can just do it on a practical level. So we take a model, we modify it a little bit, we do some ablation study, and we check which one is performing better.
And this also happens in a lot of areas of machine learning. For example, we do a lot of grid search to find the right parameters for a model, because we cannot know beforehand which one will work well, or which one to increase, or which one to decrease. Because it depends on a lot of factors, not only on the algorithm used, but also on the data, also on the particular computations used, also on the normalization used.
So there are a lot of factors, there is no formula to explain everything. So this is why researchers need to do a lot of studies on the variants of models, to come up with something that works maybe in one domain and doesn't work well in other domains.
So in this case, we use the SwiGLU mostly because in practice it works well with this kind of model. Thank you guys for watching this long video. I hope that you learned at a deeper level what happens in LLaMA, and why it is different from a standard transformer model.
I know that the video has been quite long, and I know that it has been hard on some parts to follow, so I actually kind of suggest to re-watch it multiple times, especially the parts that you are less familiar with, and to integrate this video with my previous video about the transformer.
So I will put the chapters so you can easily find the part that you want, but this is what you need to do. You need to watch the same concept multiple times to actually master it. And I hope to make another video in which we code the LLaMA model from zero, so we can put all this theory into practice.
But as you know, I am doing this in my free time, and my free time is not so much. So thank you guys for watching my video, and please subscribe to my channel, because this is the best motivation for me to keep posting amazing content on AI and machine learning.
Thank you for watching, and have an amazing rest of the day.