LLaMA explained: KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU
Chapters
0:00 Introduction
2:20 Transformer vs LLaMA
5:20 LLaMA 1
6:22 LLaMA 2
6:59 Input Embeddings
8:52 Normalization & RMSNorm
24:31 Rotary Positional Embeddings
37:19 Review of Self-Attention
40:22 KV Cache
54:00 Grouped Multi-Query Attention
64:07 SwiGLU Activation function
00:00:00.000 |
Hello guys! Welcome to my new video about LLaMA. In this video we will be seeing what 00:00:05.400 |
LLaMA is, how it is made, how it is structurally different from the transformer, and we will 00:00:12.560 |
be building each block that makes up LLaMA. So I will not only explain concept-wise 00:00:18.400 |
what each block is doing, but we will also explore it from the mathematical point of 00:00:22.480 |
view and also from the coding point of view, so that we can unify theory with practice. 00:00:29.440 |
I can guarantee that if you watch this video you will have a deep understanding of what 00:00:35.380 |
makes LLaMA the model it is. So you will not only understand how the blocks interact with 00:00:42.960 |
each other but how they function and why we needed these blocks in the first place. In 00:00:49.440 |
this video we will be reviewing a lot of topics, so we will start from the architectural differences 00:00:55.360 |
between the vanilla transformer and the LLaMA model. We will look at the new 00:00:59.600 |
normalization, the RMS normalization, rotary positional embeddings, the KV cache, multi-query 00:01:04.240 |
attention, grouped multi-query attention, and the SwiGLU activation function for the feed-forward 00:01:08.400 |
layer. But of course I take for granted that you have some background knowledge. First 00:01:14.280 |
of all I highly recommend that you watch my previous video about the transformer because 00:01:18.440 |
you need to know how the transformer works. And in my previous video I also explored the 00:01:23.040 |
concept of training and inferencing a transformer model. It's about 45 minutes and I think it's 00:01:29.200 |
worth a watch because it will really give you a deep understanding of the transformer. 00:01:33.560 |
After you have that knowledge you can watch this video. Anyway, for those who have already 00:01:38.120 |
watched the video but forgot some things, I will review most of the concepts as we proceed 00:01:44.520 |
through the topics. I also take for granted that you have some basic linear algebra knowledge, 00:01:51.240 |
so matrix multiplication, dot product, basic stuff anyway. And also, because we will be 00:01:56.600 |
using the rotary positional embeddings, some knowledge about the complex numbers, even 00:02:00.440 |
if it's not fundamental. So if you don't remember the complex numbers or how they work 00:02:06.040 |
or Euler's formula, it doesn't matter. You will understand the concept, not the math. 00:02:10.680 |
It's not really fundamental. Sometimes I will be reviewing topics that maybe you are already 00:02:16.040 |
familiar with, so feel free to skip those parts. Let's start our journey by reviewing 00:02:21.800 |
the architectural differences between the vanilla transformer and LLaMA. The picture 00:02:27.400 |
on the right side was built by me, because I couldn't find an architectural picture in 00:02:33.400 |
the paper. So let's review the differences. As you remember, in the vanilla transformer 00:02:39.320 |
we have an encoder part and a decoder part. And let me highlight it. So this is the 00:02:46.040 |
encoder and the right side here is the decoder. While in LLaMA, we only have an encoder. First 00:02:53.720 |
of all, because LLaMA is a large language model, it has been trained on the next token 00:02:59.240 |
prediction task. So basically, we only need the self-attention to predict the next token. 00:03:04.120 |
And we will see all these concepts. So we will see what is the next prediction task, 00:03:07.400 |
how it works, and how this new self-attention works. The second difference that we can see 00:03:13.080 |
from these pictures is that we have here, at the beginning, we have the embedding and 00:03:19.080 |
also we had the embedding here on the original transformer. But right after the embedding, 00:03:23.640 |
we don't have the positional encoding, but we have this RMS norm. And actually, all the 00:03:28.600 |
norms have been moved before the blocks. So before we had the multi-head attention, 00:03:34.280 |
and then we had the Add & Norm, which is this plus sign here. So it's the sum 00:03:40.120 |
of a skip connection and the output of the multi-head attention, followed by the normalization. 00:03:44.840 |
And we also have this normalization here, here, here. So after every block. But here in LLaMA, 00:03:50.360 |
we have it before every block. And we will review what is the normalization and why 00:03:55.080 |
it works the way it does. Right after the normalization, we have this query, 00:04:01.400 |
key, and values input for the self-attention. One thing we can notice is that the positional 00:04:07.960 |
encodings are not anymore the positional encodings of the transformer, but they have become the 00:04:13.000 |
rotary positional encodings, and they are only applied to the query and the keys, but not the 00:04:17.960 |
values. And we will see also why. Another thing is the self-attention is now the self-attention 00:04:25.320 |
with KV cache. We will see what is the KV cache and how it works. And also we have this grouped 00:04:31.560 |
multi-query attention. Another thing that changed is this feed-forward layer. In the original 00:04:38.440 |
feed-forward layer of the vanilla transformer, we had the ReLU activation function for the 00:04:45.640 |
feed-forward block. But in LLaMA, we are using the SwiGLU function, and we will see why. 00:04:51.160 |
This nx means that this block here in the dashed lines is repeated n times one after another, 00:04:59.800 |
such that the output of the last layer is then fed to this rms norm, then to the linear layer, 00:05:05.400 |
and then to the softmax. And we will build each of these blocks from the bottom. So I will show 00:05:11.960 |
you exactly what these blocks do, how they work, how they interact with each other, what is the 00:05:16.680 |
math behind, what is the problem they were trying to solve. So we will have a deep knowledge of 00:05:21.560 |
these models. Let's start our journey by reviewing the models introduced with LLaMA. 00:05:27.880 |
So LLaMA 1 came out in February 2023, and it came in four sizes. One model was 00:05:37.160 |
with 6.7 billion parameters, then 13, 32, and 65 billion. And then we have these numbers. What do they mean? 00:05:43.880 |
The dimension here indicates the size of the embedding vector. So as you can see here, 00:05:50.760 |
we have these input embeddings that we will review later. This is basically, they convert 00:05:55.960 |
each token into a vector of size indicated by this dimension. Then we have the number of heads. 00:06:01.560 |
So how many heads the attention has, the number of layers. If you remember from the original 00:06:08.680 |
transformer, the dimension was 512. The number of heads was eight. The number of layers, I think, 00:06:14.600 |
was six. And then we have the number of tokens each model was trained upon. So 1 trillion and 00:06:22.280 |
1.4 trillion. With LLaMA 2, most of the numbers have doubled. So the context length is basically 00:06:29.160 |
the sequence length. So what is the longest sequence the model can be fed? And then the 00:06:37.400 |
number of tokens upon which the model has been trained has also doubled. So from 1 to 2 trillion 00:06:42.120 |
for each size of the model, while the parameters more or less remain the same. Then we have this 00:06:47.560 |
column here, GQA, that indicates that these two sizes of the model, so the 34 billion and 70 00:06:54.040 |
billion, they use the grouped query attention. And we will see how it works. Let's start by 00:07:00.200 |
reviewing what is the embeddings layer here. And for that, I will use the slides from my 00:07:05.480 |
previous video. If you remember my previous video, we introduced the embedding like this. 00:07:10.040 |
So we have a sentence that is made of six words. What we do is we tokenize the sentence, 00:07:15.880 |
so it converts into tokens. The tokenization usually is done not by space, but by the BPE 00:07:22.440 |
tokenizer. So actually, each word will be split into subwords also. But for clarity, for simplicity, 00:07:28.760 |
we just tokenize our sentence by using the whitespace as separator. So each token is 00:07:35.880 |
separated by whitespace from other tokens. And each token is then mapped into its position 00:07:42.680 |
into the vocabulary. So the vocabulary is the list of the words that our model recognizes. 00:07:51.480 |
They don't have to be words, of course. They could be anything. They are just tokens. 00:07:56.920 |
So each token occupies a position in this vocabulary, and the input IDs indicate the number 00:08:02.520 |
occupied by each token in the vocabulary. Then we map each input ID into a vector of size 512 00:08:11.320 |
in the original transformer. But in LLaMA, it becomes 4096. And these embeddings are vectors 00:08:19.880 |
that are learnable. So they are parameters for the model. And while the model will be trained, 00:08:24.920 |
this embedding will change in such a way that they will capture the meaning of the word they 00:08:29.720 |
are mapping. So we hope that, for example, the word "cat" and "dog" will have similar embedding, 00:08:37.320 |
because kind of they map similar-- at least they are in the same semantic group. And also, 00:08:43.880 |
the word "house" and "building," they will be very close to each other if we check the two vectors. 00:08:49.880 |
And this is the idea behind the embedding. Now let's check what is normalization. Because this 00:08:57.480 |
is the layer right after the embeddings. And for that, let's introduce some review of the 00:09:03.880 |
neural networks and how they work. So suppose we have a feed-forward neural network with an input, 00:09:10.760 |
a hidden layer made of neurons, another hidden layer made of another five neurons, 00:09:18.680 |
which then maps to an output. We usually have a target. And comparing the output with the target, 00:09:24.680 |
we produce a loss. The loss is then propagated back to the two hidden layers by means of back 00:09:31.400 |
propagation. So what we do is we calculate the gradient of the loss with respect to each weight 00:09:37.240 |
of these two hidden layers. And we modify these weights of the hidden layer accordingly, 00:09:43.800 |
also according to the learning rate that we have set. To check why we need to normalize and what is 00:09:50.760 |
the need of normalization, I will make a simplification of the neural network. So let's 00:09:56.680 |
suppose our neural network is actually a factory, a factory that makes phones. So to make a phone, 00:10:02.600 |
we start with some raw material that are given to a hardware team that will take the raw material 00:10:08.760 |
and produce some hardware. For example, they may select the Bluetooth device, they may select the 00:10:14.760 |
display, they may select the microphone, the camera, etc, etc. And they make up the hardware 00:10:20.920 |
of this phone. The hardware team then gives this prototype to the software team, which then creates 00:10:26.840 |
the software for this hardware. And then the output of the software team is the complete phone 00:10:32.120 |
with hardware and software and is given as the output. The output is then compared with what was 00:10:38.760 |
the original design of the phone. And then we compute a loss. So what is the difference between 00:10:44.440 |
the target we had for our phone and what we actually produced? So suppose the loss is our CEO. 00:10:50.680 |
And the loss is quite big, suppose. So our CEO will talk with the hardware team and with the 00:10:57.880 |
software team and will tell them to adjust their strategy so as to go closer to the target next 00:11:03.720 |
time. So suppose that the hardware was too expensive. So the CEO will tell the hardware 00:11:08.600 |
team to use maybe a smaller display, to use a cheaper camera, to change the Bluetooth to a 00:11:14.200 |
lower range one, or to change the Wi-Fi to a low energy one, to change the battery, etc, etc. 00:11:19.320 |
And we'll also talk with the software team to adjust their strategy and maybe tell the software 00:11:25.160 |
team to concentrate less on refactoring, to concentrate less on training, to hire more 00:11:32.120 |
interns and not care too much about the employees because the costs are too high, etc. And he will 00:11:39.880 |
adjust the strategy of the software and the hardware team. So the next time we start with the 00:11:45.560 |
raw material again. So let's go back. We start with the raw material again. And the hardware 00:11:53.800 |
team, according to the new strategy set by the CEO, will produce a new hardware. Now the problem 00:12:00.520 |
arises. The software team now will receive a hardware that the software team has never seen 00:12:05.800 |
before because the display has been changed, the Bluetooth has been changed, the Wi-Fi has been 00:12:11.160 |
changed, everything has been changed. So the software team needs to redo a lot of work and 00:12:17.480 |
especially they need to adjust their strategy a lot because they are dealing with something they 00:12:23.400 |
have never seen before. So the output of the software team will be much different compared 00:12:28.840 |
to what they previously output. And maybe it will be even further from the target because 00:12:35.960 |
the software team was not ready to make all these adjustments. So maybe they wasted a lot of time, 00:12:40.280 |
so maybe they wasted a lot of resources, so they maybe could not even reach the target, 00:12:45.320 |
even get closer to the target. So this time maybe the loss is even higher. So as you can see, 00:12:50.920 |
the problem arises by the fact that the loss function modifies the weights of the hardware 00:12:57.000 |
team and the software team. But then the software team at the next iteration receives an input that 00:13:04.520 |
it has never seen before and this input makes it produce an output that is much divergent compared 00:13:12.520 |
to the one it used to produce before. This will make the loss of the model kind of oscillate 00:13:18.760 |
and will make the training much slower. Now let's look at what happens at the math level to understand 00:13:25.160 |
how the normalization works. So let's review some maths. Suppose that we have a linear layer 00:13:31.320 |
defined as nn.Linear with three input features and five output features, with bias. This is the 00:13:39.000 |
linear layer as defined in PyTorch. The linear layer will create two matrices, one called W, 00:13:45.960 |
the weight, and one called B, the bias. Suppose we have an input of shape 10 rows by 3 columns, 00:13:53.240 |
the output of this linear layer with this input x will be 10 rows by 5 columns. But how does this 00:14:00.920 |
happen mathematically? Let's review it. So imagine we have our input which is 10 by 3, 00:14:06.440 |
which means that we have 10 items and each item has 3 features. The W matrix created by the 00:14:13.400 |
linear layer will be 5 by 3, so the output features by the 3 input features. And we can think of each 00:14:21.880 |
of this row as one neuron, each of them having three weights, one weight for each of the input 00:14:29.560 |
features of the x input. Then we have the bias vector and the bias vector is one weight for each 00:14:39.240 |
neuron because the bias is one for every neuron. And this will produce an output which is 10 by 5, 00:14:46.920 |
which means we have 10 items with 5 features. Let's try to understand what is the flow of 00:14:54.360 |
information in these matrices. The flow of information is governed by this expression, 00:15:01.800 |
so the output is equal to the x multiplied by the transpose of the W matrix plus B. 00:15:09.880 |
So let's suppose we have this input x and we have one item and the item 1 has three features, 00:15:17.640 |
A1, A2 and A3. The transpose of Wt is this matrix here, so in which we swap the row with the 00:15:25.240 |
columns because according to the formula we need to make the transpose of that matrix. 00:15:29.240 |
So we have neuron 1 with the three weights W1, W2, W3. We multiply the two and we obtain this 00:15:36.120 |
matrix, so x multiplied by the transpose of W produces this matrix here, in which this row 1 00:15:43.640 |
is the dot product of this row vector with this column vector. Then we add the B row vector. 00:15:55.160 |
As you can see, to add two matrices they need to have the same dimension, but in PyTorch, 00:16:01.960 |
because of broadcasting, this row will be added to this row here and then, independently, to 00:16:08.360 |
this row and to this row etc etc because of the broadcasting. And then we will have this output. 00:16:15.000 |
And the first item here will be Z1. What is Z1? Well, Z1 is equal to R1 plus B1. But what is R1? 00:16:25.480 |
R1 is the dot product of this column with this row or this row with this column. So it's this 00:16:31.320 |
expression here. So the output of the neuron 1 for the item 1 only depends on the features of 00:16:38.040 |
the item 1. Usually after this output we also apply a non-linearity like the ReLU function, 00:16:44.280 |
and the argument of the ReLU function is referred to as the activation of neuron 1. 00:16:51.000 |
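As a quick check of the computation just described, here is a small sketch (illustrative only, not from the video) that builds the same linear layer in PyTorch and verifies that its output matches x multiplied by the transpose of W, plus the broadcast bias.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

linear = nn.Linear(in_features=3, out_features=5, bias=True)  # W has shape (5, 3), b has shape (5,)
x = torch.randn(10, 3)                                        # 10 items, each with 3 features

out = linear(x)                                               # shape: (10, 5)

# Same computation done by hand: x @ W^T + b (the bias row is broadcast to every item).
manual = x @ linear.weight.T + linear.bias
print(out.shape)                                              # torch.Size([10, 5])
print(torch.allclose(out, manual))                            # True
```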
Now, as we can see, the output of the neuron 1 only depends on the input features of each item. 00:16:59.480 |
So the output of a neuron for a data item depends on the features of the input data 00:17:04.200 |
item and the neuron's parameter. We can think of the input to a neuron as the output of a previous 00:17:09.960 |
layer. So, for example, that input that we saw before, the X, it may as well be the output of 00:17:14.840 |
the previous layer. If the previous layer, after its weight are updated because of the gradient 00:17:21.320 |
descent, changes drastically the output, like we did before, for example, because the CEO realigned 00:17:27.640 |
the strategy of the hardware team, so the previous layer, the hardware team, will produce an output 00:17:31.720 |
that is drastically different compared to what it used to produce, the next layer will have its 00:17:38.040 |
output changed also drastically. So, because it will be forced to readjust its weight drastically 00:17:45.320 |
at the next step of the gradient descent. So what we don't like is the fact that 00:17:50.760 |
the output of the previous layer changes too much, so that the next layer also has to change 00:17:56.440 |
its output a lot, in order to adhere to the strategy defined by the loss function. 00:18:03.720 |
So this phenomenon, by which the distribution of the internal nodes of the network changes, is referred 00:18:09.960 |
to as internal covariate shift. And we want to avoid it, because it makes training the network 00:18:15.240 |
slower, as the neurons are forced to readjust drastically their weights in one direction 00:18:20.280 |
or another, because of drastic changes in the output of the previous layers. 00:18:25.000 |
So what do we do? We do layer normalization, at least in the vanilla transformer. So let's 00:18:30.440 |
review how the layer normalization works. Imagine we still have our input x defined with 10 rows by 00:18:37.640 |
3 columns, and for each of these items, independently, we calculate two statistics. 00:18:46.760 |
One is the mu, so the mean, and one is the sigma, so the variance. And then we normalize the values 00:18:55.960 |
in this matrix according to this formula. So we take basically x minus its mu, so each item minus 00:19:04.120 |
the mu, divided by the square root of the variance plus epsilon, where epsilon is a very small number, 00:19:11.160 |
so that we never divide by zero in this way, even if the variance is very small. 00:19:15.160 |
And each of these numbers is then multiplied with the two parameters, 00:19:20.840 |
one is gamma, and one is beta. They are both learnable by the model, and they are useful, 00:19:26.760 |
because the model can adjust this gamma and beta to amplify the values that it needs. 00:19:32.040 |
So before we had layer normalization, we used to normalize with batch normalization, 00:19:40.600 |
and with batch normalization, the only difference is that instead of calculating the statistics by 00:19:45.400 |
rows, we calculated them by columns. So the feature 1, feature 2, and feature 3. 00:19:51.000 |
With layer normalization, we do it by row. So each row will have its own mu and sigma. 00:19:57.000 |
So by using the layer normalization, basically, we transform the initial distribution of features, 00:20:03.320 |
no matter what they are, into normalized numbers that are distributed with 0 mean and 1 variance. 00:20:10.840 |
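Here is a minimal sketch of layer normalization done by hand (an illustrative example, not taken from any real model), checked against PyTorch's nn.LayerNorm; since gamma starts at 1 and beta at 0, the two outputs match.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(10, 3)          # 10 items, 3 features each
eps = 1e-5

# Layer normalization by hand: one mean and one variance per row (per item).
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
x_norm = (x - mean) / torch.sqrt(var + eps)     # gamma = 1, beta = 0 at initialization

layer_norm = nn.LayerNorm(normalized_shape=3, eps=eps)
print(torch.allclose(layer_norm(x), x_norm, atol=1e-6))   # True
```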
So this formula actually comes from probability statistics, and if you remember, 00:20:14.360 |
let me use the pen, okay, if you remember, basically, if we have a variable x, which is 00:20:23.160 |
distributed like a normal variable with a mean, let's say 5, and a variance of 36, 00:20:31.000 |
if we do x minus its mean, so 5, divided by the square root of the variance, so the square root of 36, 00:20:41.800 |
this new variable, let's call it z, will be distributed like N(0, 1). 00:20:50.600 |
So it will become a standard Gaussian, and this is exactly what we are doing here. So we are 00:20:56.600 |
transforming them into standard Gaussians, so that this value, most of the times will be close to 0, 00:21:03.000 |
I mean, will be distributed around 0. Now let's talk about root-mean-square 00:21:09.080 |
normalization, the one used by LLaMA. The root-mean-square normalization was introduced 00:21:17.240 |
in this paper, root-mean-square layer normalization, from these two researchers, 00:21:23.320 |
and let's read the paper together. A well-known explanation of the success of layer norm is its 00:21:29.800 |
re-centering and re-scaling invariance property. So what do they mean? What is the re-centering 00:21:36.040 |
and the re-scaling invariance? The fact that the features, no matter what they are, they will be 00:21:41.080 |
re-centered around the zero mean, and re-scaled to have a variance of 1. The former enables the 00:21:47.800 |
model to be insensitive to shift noises on both input and weights, and the latter keeps the output 00:21:54.120 |
representations intact when both input and weight are randomly scaled. In this paper, we hypothesize 00:22:00.520 |
that the re-scaling invariance is the reason for success of layer norm, rather than the re-centering 00:22:06.680 |
invariance. So what they claim in this paper is that, basically, the success of layer norm is not 00:22:14.360 |
because of the re-centering and the re-scaling, but mostly because of the re-scaling, so this 00:22:21.240 |
division by the variance, basically, so to have a variance of 1. And what they do is, basically, 00:22:27.960 |
they said, okay, can we find another statistic that doesn't depend on the mean because we believe 00:22:33.720 |
that it's not necessary? Well, yes. They use this root-mean-square statistic, 00:22:41.480 |
RMS(x) = sqrt((1/n) * sum of x_i^2), and as you can see from the expression of this statistic, 00:22:51.960 |
we don't use the mean to calculate it anymore. The previous statistic, the 00:22:56.840 |
variance, needs the mean to be calculated, because if you remember, the variance 00:23:03.320 |
is equal to the summation of (x minus mu) to the power of 2 00:23:11.240 |
divided by n. So we need the mean to calculate the variance. So what the authors wanted to do 00:23:18.040 |
in this paper, they said, okay, because we don't need to re-center, because we believe, we 00:23:23.400 |
hypothesize that the re-centering is not needed to obtain the effect of the layer normalization, 00:23:28.680 |
we want to find a statistic that doesn't depend on the mean, and the RMS statistic doesn't depend 00:23:33.400 |
on the mean. So they do exactly the same thing that they did in the layer normalization, so they 00:23:39.800 |
calculate the RMS statistic by rows, so one for each row, and then they normalize according to 00:23:46.840 |
this formula here, so they just divide by the RMS statistic, and then multiply by 00:23:52.120 |
this gamma parameter, which is learnable, as shown in the sketch below. 00:24:00.200 |
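Here is a minimal sketch of RMS normalization written from this formula (not the official LLaMA source code); the embedding size and the eps value are just illustrative.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))   # the learnable re-scaling parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS(x) = sqrt(mean(x_i^2)); note that no mean (re-centering) is involved.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x / rms)

x = torch.randn(10, 4096)          # 10 tokens with embedding size 4096
print(RMSNorm(4096)(x).shape)      # torch.Size([10, 4096])
```

Now, why root-mean-square normalization? Well,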
it requires less computation compared to layer normalization, because we are not computing 00:24:05.240 |
two statistics, so we are not computing the mean and the sigma, we are only computing one, 00:24:10.360 |
so it gives you a computational advantage. And it works well in practice, so actually what the 00:24:17.480 |
authors of the paper hypothesized is actually true: we only need the re-scaling invariance to obtain the 00:24:23.720 |
effect of layer normalization, we don't need the re-centering. At least, this is what 00:24:29.080 |
happens with LLaMA. The next topic we will be talking about is the positional encodings, 00:24:34.200 |
but before we introduce the rotary positional encodings, let's review the positional encodings 00:24:39.320 |
in the vanilla transformer. As you remember, after we transform our tokens into embeddings, 00:24:45.320 |
so vectors of size 512, in the vanilla transformer, then we sum another vector to these 00:24:51.320 |
embeddings, that indicate the position of each token inside the sentence, 00:24:59.000 |
and these positional embeddings are fixed, so they are not learned by the model, they are computed 00:25:04.600 |
once and then they are reused for every sentence during training and inference, and each word gets 00:25:12.280 |
its own vector of size 512. In LLaMA we have a new kind of positional encoding, called rotary positional 00:25:18.760 |
encoding, so absolute positional encodings are fixed vectors that are added to the embedding 00:25:24.680 |
of a token to represent its absolute position in the sentence, so the token number 1 gets its own 00:25:30.600 |
vector, the token number 2 gets its own vector, the token number 3 gets its own vector, so the 00:25:36.600 |
absolute positional encoding deals with one token at a time. You can think of it as the pair latitude 00:25:42.680 |
and longitude on a map, each point on the earth will have its own unique latitude and longitude, 00:25:49.000 |
so that's an absolute indication of the position of each point on the earth, and this is the same 00:25:54.520 |
what happens with absolute positional encoding in the vanilla transformer. We have one vector 00:25:59.000 |
that represents exactly that position, which is added to that particular token in that position. 00:26:04.280 |
With relative positional encodings, on the other hand, it deals with two tokens at a time, 00:26:10.760 |
and it is involved when we calculate the attention. Since the attention mechanism 00:26:15.320 |
captures the intensity of how much two words are related to each other, relative positional 00:26:20.600 |
encodings tell the attention mechanism the distance between the two words involved in 00:26:25.880 |
this attention mechanism. So, given two tokens, we create a vector that represents their distance. 00:26:33.000 |
This is why it's called relative, because it's relative to the distance between two tokens. 00:26:38.440 |
Relative positional encodings were first introduced in the following paper 00:26:42.520 |
from Google, and you can notice that Vaswani, I think, is one of the authors of the transformer paper. 00:26:50.200 |
So, now, with absolute positional encoding, so from the attention is all you need, 00:26:56.520 |
when we calculate the dot product in the attention mechanism, so if you remember the 00:27:04.200 |
attention mechanism, the formula, let me write it: the attention is equal to the query multiplied by 00:27:18.040 |
the transpose of the key, divided by the square root of d_model, then we 00:27:28.040 |
apply the softmax, and then we multiply it by V, etc., but we only concentrate on the Q 00:27:33.160 |
multiplied by the k transposed in this case, and this is what we see here. So, when we calculate 00:27:41.080 |
this dot product, the attention mechanism is calculating the dot product between two tokens, 00:27:47.800 |
that already have the absolute position encoded into them, because we already added the absolute 00:27:54.440 |
positional encoding to each token. So, in this attention mechanism from the vanilla transformer, 00:27:59.480 |
we have two tokens and the attention mechanism, while in relative positional encodings, 00:28:03.880 |
we have three vectors. We have the token one, the token two, and then we have this vector here, 00:28:15.960 |
we have this vector here, that represents the distance between these two tokens, 00:28:21.480 |
and so we have three vectors involved in this attention mechanism, and we want the attention 00:28:28.200 |
mechanism to actually match this token differently based on this vector here. So, this vector will 00:28:35.160 |
indicate to the attention mechanism, so to the dot product, how to relate these two words that 00:28:40.840 |
are at this particular distance. With rotary positional embeddings, we do a similar job, 00:28:48.120 |
and they were introduced with this paper, RoFormer, whose authors are from a Chinese company. 00:28:55.160 |
So, the dot product used in the attention mechanism is a type of inner product. So, 00:29:01.720 |
if you remember from linear algebra, the dot product is a kind of operation that has some 00:29:07.480 |
properties, and these properties are the kind of properties that every inner product must have. 00:29:13.400 |
So, the inner product can be thought of as a generalization of the dot product. 00:29:17.400 |
What the authors of the paper wanted to do is, can we find an inner product over the two-vector 00:29:25.640 |
query and key used in the attention mechanism that only depends on the two vectors themselves 00:29:32.280 |
and the relative distance of the token they represent. That is, given two vectors, query 00:29:39.480 |
and key, that only contain the embedding of the word that they represent, and their position 00:29:46.120 |
inside of the sentence, so this m is actually an absolute number, so it's a scalar, it represents 00:29:52.440 |
the position of the word inside of the sentence, and this n represents the position of the second 00:29:57.240 |
word inside of the sentence. What they wanted to say is, can we find an inner product, so this 00:30:03.080 |
particular parenthesis we see here is an inner product between these two vectors, 00:30:09.000 |
that behaves like this function g, that only depends on the embedding of x_m, so the first 00:30:17.080 |
token, on x_n, the second token, and the relative distance between them, m minus n, and no other information. 00:30:24.760 |
So this function will be given only the embedding of the first token, the embedding of the second 00:30:29.640 |
token, and a number that represents the relative position of these two tokens, relative distance 00:30:35.960 |
of these two tokens. Yes, we can find such a function, and the function is the one defined 00:30:43.320 |
here. So we can define a function g, like the following, that only needs, only depends on the 00:30:49.640 |
two embedding vectors q and k, and the relative distance. And this function is defined in the 00:30:56.840 |
complex number space, and it can be converted by using the Euler formula into this form. 00:31:03.160 |
And another thing to notice is that this function here, the one we are watching, is defined for 00:31:09.720 |
vectors of dimension 2. Of course later we will see what happens when the dimension is bigger. 00:31:17.560 |
And when we convert this expression here, which is in the complex number space, 00:31:22.680 |
into its matrix form, through Euler's formula, we can recognize this matrix here 00:31:29.160 |
as the rotation matrix. So this matrix here basically represents the rotation of a vector. 00:31:34.760 |
For example, this one here, so this product here, will be a vector, and this rotation matrix will 00:31:42.920 |
rotate this vector into the space by the amount described by m theta, so the angle m theta. 00:31:50.840 |
Let's see an example. So imagine we have a vector v0, and we want to rotate it by theta, 00:31:58.840 |
by an angle theta here, to arrive to the vector v prime. So what we do is, we multiply the vector 00:32:05.320 |
v0 with this matrix, exactly this one, in which the values are calculated like this, 00:32:11.000 |
cosine of theta, minus sine of theta, sine of theta, and cosine of theta. 00:32:15.080 |
And the resulting vector will be the same vector, so the same length, but rotated by this angle. 00:32:22.040 |
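To make the rotation concrete, here is a tiny numeric example (illustrative only) that rotates a 2D vector with exactly this matrix and checks that its length stays the same.

```python
import torch
import math

theta = math.pi / 4                                   # rotate by 45 degrees
rotation = torch.tensor([[math.cos(theta), -math.sin(theta)],
                         [math.sin(theta),  math.cos(theta)]])

v0 = torch.tensor([1.0, 0.0])
v_prime = rotation @ v0                               # the rotated vector

print(v_prime)                                        # tensor([0.7071, 0.7071])
print(torch.isclose(v0.norm(), v_prime.norm()))       # True: same length, only the angle changed
```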
And this is why they are called rotary positional embeddings. 00:32:29.160 |
Now, when the vector is not two-dimensional, but we have n dimensions, for example in the 00:32:36.680 |
original transformer model our embedding size is 512, and in LLaMA it's 4096, we need to use this 00:32:44.680 |
form. Now, I want you to notice not what are the numbers in this matrix, but the fact that this 00:32:51.480 |
matrix is sparse, so it is not convenient to use it to compute the positional embeddings, 00:32:57.080 |
because if we multiply by this matrix, our tensor library, our GPU, our computer will do a lot 00:33:02.680 |
of operations that are useless, because we already know that most of the products will be zero. 00:33:07.160 |
So, is there a better way, a more computationally efficient way to do this computation? 00:33:12.200 |
Well, there is, this form here. So, given a token with the embedding vector x, 00:33:19.160 |
and the position m of the token inside the sentence, this is how we compute the position 00:33:24.760 |
embedding for the token. We take the dimensions of the token, we multiply by this matrix here, 00:33:31.640 |
computed like the following, where the theta are fixed, m is the position of the token, x1, x2, 00:33:38.680 |
x3 are the dimensions of the embedding, so the first dimension of the embedding, the second 00:33:42.360 |
dimension of the embedding, etc. To this, we add a second vector, built from the same 00:33:50.120 |
dimensions rearranged and negated, so minus x2, x1, minus x4, x3, and so on, 00:33:57.400 |
multiplied element-wise by this second matrix here, the one made of sines. 00:34:03.240 |
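Here is a minimal sketch of this computation for a single token (a simplified version written from the formula above; the actual LLaMA code works with complex numbers and whole batches of sequences at once).

```python
import torch

def rope(x: torch.Tensor, m: int, base: float = 10000.0) -> torch.Tensor:
    """Rotary positional embedding for a single token embedding x at position m.

    A simplified sketch of the RoFormer formula, not LLaMA's real implementation.
    """
    d = x.shape[-1]                                            # embedding size, must be even
    # theta_i = base^(-2(i-1)/d) for i = 1 .. d/2, each one repeated twice.
    theta = base ** (-torch.arange(0, d, 2).float() / d)       # shape: (d/2,)
    angles = m * theta.repeat_interleave(2)                    # m*theta_1, m*theta_1, m*theta_2, ...

    # Build (-x2, x1, -x4, x3, ...) from (x1, x2, x3, x4, ...).
    x_pairs = x.view(-1, 2)                                    # (d/2, 2)
    rotated = torch.stack((-x_pairs[:, 1], x_pairs[:, 0]), dim=-1).reshape(-1)

    return x * torch.cos(angles) + rotated * torch.sin(angles)

x = torch.randn(4096)                  # one token embedding
print(rope(x, m=3).shape)              # torch.Size([4096])
print(torch.allclose(rope(x, m=0), x)) # True: position 0 means no rotation
```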
So, there is nothing we have to learn in this matrix, everything is fixed, because if we watch 00:34:07.240 |
the previous slide, we can see that this theta actually is computed like this, one for each 00:34:13.240 |
dimension, and so there is nothing to learn. So, basically, they are just like the absolute 00:34:20.040 |
positional encoding, so we compute them once, and then we can reuse them for all the sentences that 00:34:27.080 |
we will train the model upon. Another interesting property of the 00:34:30.760 |
rotary positional embeddings is the long-term decay. So, what the authors did, they calculated 00:34:37.000 |
an upper bound for the inner product that we saw before, so the g function, by varying the distance 00:34:42.760 |
between the two tokens, and then they proved that no matter what the two tokens are, there is an 00:34:48.840 |
upper bound that decreases as the distance between the two tokens grows. And if you remember that the 00:34:57.560 |
inner product or the dot product that we are computing is for the calculation of the attention, 00:35:02.760 |
this dot product represents the intensity of relationship between the two tokens for which 00:35:07.640 |
we are computing the attention. And what these rotary positional embeddings do, they will 00:35:12.920 |
basically decay this relationship, the strength of this relationship between the two tokens, 00:35:19.560 |
if the two tokens that we are matching are distant from each other. And this is actually 00:35:28.120 |
what we want. So, we want two words that are very far from each other to have a less strong 00:35:33.240 |
relationship, and two words that are close to each other to have a stronger relationship. 00:35:37.480 |
And this is a desired property that we want from these rotary positional embeddings. 00:35:42.200 |
Now, the rotary positional embeddings are only applied to the query and the keys, 00:35:48.200 |
but not to the values. Let's see why. Well, the first consideration is that they basically, 00:35:54.440 |
they come into play when we are calculating the attention. So, when we calculate the attention, 00:35:59.080 |
it's the attention mechanism that will change the score. So, as you remember, 00:36:05.400 |
the attention mechanism is kind of a score that tells how much strong is the relationship between 00:36:10.920 |
two tokens. So, this relationship will be stronger or less stronger or will change according to also 00:36:18.920 |
the position of these two tokens inside of the sentence and the relative distance between these 00:36:24.760 |
two tokens. Another thing is that the rotation, rotary positional embeddings are applied after 00:36:30.040 |
the vector Q and K have been multiplied by the W matrix in the attention mechanism, 00:36:35.480 |
while in the vanilla transformer, they are applied before. So, in the vanilla transformer, 00:36:40.200 |
the position embeddings are applied right after we transform the tokens into embeddings. 00:36:47.240 |
But with the rotary positional embeddings, so in LLaMA, we don't do this. We basically apply them 00:36:53.240 |
right after we multiply by the W matrix in the attention mechanism. So, the W matrix, 00:36:59.080 |
if you remember, is the matrix of parameters that each head has, each attention head has. 00:37:07.160 |
And so, in LLaMA, basically, we apply the rotary positional encoding after we multiply the 00:37:07.160 |
vectors Q and K by the W matrix. Now comes the interesting part, in which we will watch how the 00:37:21.880 |
self-attention works in LLaMA. But before we can talk about the self-attention as used in LLaMA, 00:37:27.560 |
we need to review, at least briefly, the self-attention in the vanilla transformer. 00:37:33.000 |
So, if you remember the self-attention in the vanilla transformer, we start with the matrix Q, 00:37:38.760 |
which is a matrix of sequence by the model, which means that we have on the rows, the tokens, 00:37:45.480 |
and on the columns, the dimensions of the embedding vector. So, we can think of it like the following. 00:37:51.400 |
Let me. Okay. So, we can think of it like having six rows, one, and each of these rows is a vector 00:38:01.240 |
of dimension 512 that represents the embedding of that token. And now, let me delete. 00:38:08.040 |
And then, we multiply according to this formula. So, Q multiplied by the transpose of the K. So, 00:38:16.280 |
transpose of the K divided by the square root of 512, which is the dimension of the embedding vector, 00:38:22.040 |
where the K is equal to Q and V is also equal to Q, because this is a self-attention. So, 00:38:29.000 |
the three matrices are actually the same sequence. Then, we apply the softmax and we obtain this 00:38:36.120 |
matrix. So, we had the matrix that was 6 by 512 multiplied by another one that is 512 by 6. We 00:38:43.400 |
will obtain a matrix that is 6 by 6, where each item in this matrix represents the dot product 00:38:50.440 |
of the first token with itself, then the first token with the second token, the first token with 00:38:56.760 |
the third token, the first token with the fourth token, etc. So, this matrix captures the intensity 00:39:03.720 |
of relationship between two tokens. Then, the output of this softmax is multiplied by the V 00:39:13.640 |
matrix to obtain the attention sequence. So, the output of the self-attention is another matrix 00:39:21.080 |
that has the same dimensions as the initial matrix. So, it will produce a sequence where 00:39:28.280 |
the embeddings now not only capture the meaning of each token, not only they capture the position 00:39:34.600 |
of each token, but they also capture kind of the relationship between that token and every other 00:39:40.760 |
token. If you didn't understand this concept, please go back and watch my previous video about 00:39:45.800 |
the transformer where I explain it very carefully and in much more detail. Now, let's have a look 00:39:51.320 |
at the multi-head attention very briefly. So, the multi-head attention basically means that we have 00:39:57.720 |
an input sequence, we take it, we copy it into Q, K, and V, so they are the same matrix, we multiply 00:40:05.400 |
by parameter matrices, and then we split into multiple smaller matrices, one for each head, 00:40:12.200 |
and we calculate the attention between these heads. So, head 1, head 2, head 3, head 4. 00:40:17.160 |
Then, we concatenate the output of these heads, we multiply by the output matrix W_O, 00:40:22.920 |
and finally we have the output of the multi-head attention. Next, let's look at the KV 00:40:28.520 |
cache. So, before we introduce the KV cache, we need to understand how LLaMA was trained, 00:40:35.320 |
and we need to understand what the next token prediction task is. So, LLaMA, just like most of the 00:40:42.120 |
large language models, have been trained on the next token prediction task, which means that given 00:40:47.880 |
a sequence, it will try to predict what is the next token, the most likely next token, to continue 00:40:54.760 |
the prompt. So, for example, if we give it a poem without the last word, 00:41:01.880 |
it will probably come up with the last word that is missing from that poem. In this case, 00:41:07.640 |
I will be using one very famous passage from Dante Alighieri, and I will not use the original 00:41:13.640 |
Italian, but the English translation. So, I will only deal with 00:41:17.640 |
the first line you can see here, "Love that can quickly seize the gentle heart". 00:41:21.480 |
So, let's train LLaMA on this sentence. How does the training work? Well, we give the input to 00:41:29.480 |
the model, and the input is built in such a way that we first prepend the start of sentence token, 00:41:35.400 |
and then the target is built such that we append an end of sentence token. Why? Because the 00:41:43.160 |
model, this transformer model, is a sequence-to-sequence model, which maps each 00:41:50.040 |
position in the input sequence into another position in the output sequence. So, basically, 00:41:58.280 |
the first token of the input sequence will be mapped to the first token of the output sequence, 00:42:04.040 |
and the second token of the input sequence will be mapped to the second token of the output 00:42:08.440 |
sequence, etc., etc., etc. This also means that if we give our model the input "sos", 00:42:14.920 |
it will produce the first token as output, so "love", then if we give the first two tokens, 00:42:21.320 |
it will produce the second token as output, so "love that", and if we give the first three tokens, 00:42:29.400 |
it will produce the output, the third token as output. Of course, the model will also produce 00:42:35.960 |
the output for the previous two tokens, but let's see it with an example. So, if you remember from 00:42:42.440 |
my previous video, also in which I do the inferencing, when we train the model, we only do 00:42:47.080 |
it in one step, so we give the input, we give the target, we calculate the loss, and we don't have 00:42:53.320 |
any for loop to train the model for one single sentence, but for the inference, we need to do it 00:43:01.240 |
token by token. So, in this inferencing, we start with a time step, time step one, in which we only 00:43:10.200 |
give the input "sos", so start of sentence, and the output is "love". Then, we take the output token 00:43:17.720 |
here, "love", and we append it to the input, and we give it again to the model, and the model will 00:43:23.960 |
produce the next token, "love that". Then, we take the last token output by the model, "that", 00:43:31.080 |
we append it again to the input, and the model will produce the next token. And then, we again 00:43:36.920 |
take the next token, so "can", we append it to the input, and we feed it again to the model, 00:43:43.160 |
and the model will output the next token quickly. And we do it for all the steps that are necessary 00:43:49.880 |
until we reach the end of sentence token. Then, that's when we know that the model has finished 00:43:56.120 |
outputting its output. Now, this is not how LLaMA was trained, actually, but this is a good example 00:44:04.840 |
to show you how the next token prediction task works. Now, there is a problem with this approach. 00:44:13.320 |
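Here is a minimal sketch of this token-by-token inference loop. The model, the token IDs and the eos_id used here are placeholders, not the real LLaMA interface; the point is only to show that the whole sequence generated so far is fed back to the model at every step.

```python
import torch

@torch.no_grad()
def greedy_generate(model, input_ids: torch.Tensor, eos_id: int, max_new_tokens: int = 50):
    """Naive autoregressive decoding: re-feed the whole sequence at every step.

    `model` is any callable mapping token IDs (1, seq_len) to logits (1, seq_len, vocab_size).
    """
    for _ in range(max_new_tokens):
        logits = model(input_ids)                     # (1, seq_len, vocab_size)
        next_id = logits[:, -1, :].argmax(dim=-1)     # we only care about the last position
        input_ids = torch.cat([input_ids, next_id.unsqueeze(0)], dim=1)
        if next_id.item() == eos_id:
            break
    return input_ids

# Dummy "model" just to show the shapes involved (obviously not a trained LLM).
dummy = lambda ids: torch.randn(ids.shape[0], ids.shape[1], 32000)
out = greedy_generate(dummy, input_ids=torch.tensor([[1]]), eos_id=2, max_new_tokens=5)
print(out.shape)   # torch.Size([1, N]) with N <= 6
```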
Let's see why. At every step of the inference, we are only interested in the last token output by 00:44:20.920 |
the model, because we already have the previous ones. However, the model needs to access all the 00:44:27.480 |
previous tokens to decide on which token to output, since they constitute its context, or the prompt. 00:44:33.320 |
So, what I mean by this is that to output, for example, the word "the", the model has to see all 00:44:40.280 |
the input here. We cannot just give it "seize". The model needs to see all the input to output this 00:44:45.960 |
last token, "the". But, the point is, this is a sequence-to-sequence model, so it will produce 00:44:52.600 |
this sequence as output, even if we only care about the last token. So, there is a lot of 00:44:58.040 |
unnecessary computation we are doing to calculate these tokens, again, that we already actually have 00:45:03.480 |
from the previous time steps. So, let's find a way to not do this useless computation. 00:45:08.760 |
And this is what we do with the KV cache. So, the KV cache is a way to do less computation 00:45:16.120 |
on the tokens that we have already seen during inferencing. So, it's only applied during 00:45:22.680 |
inferencing in a transformer model, and it not only applies to a transformer like the one in 00:45:31.160 |
LLaMA, but to all transformer models, because all transformer models work in this way. This is a 00:45:36.520 |
description, it's a picture of how the self-attention works during the next token prediction 00:45:42.360 |
task. So, as you saw also in my previous slides, we have a query matrix here with N tokens, then we 00:45:50.120 |
have the transpose of the keys, so the query can be thought of as rows of vectors, where the first 00:45:56.680 |
vector represents the first token, the second token, etc. Then the transposed of the keys is 00:46:01.240 |
the same tokens but transposed, so the rows become columns. This produces a matrix that is N by N, 00:46:08.360 |
so if the initial input matrix has 9 tokens, the output matrix will be 9 by 9. Then we multiply it by 00:46:15.560 |
the V matrix, and this will produce the attention. The attention is then fed to the linear layer of 00:46:23.080 |
the transformer, then the linear layer will produce the logits, and the logits are fed to 00:46:29.000 |
the softmax, and the softmax allows us to decide which token to pick from our vocabulary. Again, 00:46:36.040 |
if you are not familiar with this, please watch my previous video of the transformer about the 00:46:40.840 |
inferencing of the transformer, and you will see this clearly. So, this is a description of what 00:46:48.120 |
happens at a general level in the self-attention. Now, let's watch it step by step. So, imagine at 00:46:54.840 |
inference step 1, we only have the first token. If you remember before, we were only using the 00:47:00.840 |
start of sentence token. So, we take the start of sentence token, we multiply it by itself, 00:47:06.120 |
so the transposed, it will produce a matrix that is 1 by 1, so this matrix is 1 by 4096, 00:47:12.680 |
multiplied by another matrix that is 4096 by 1, it will produce a 1 by 1 matrix. 00:47:17.640 |
Why 4096? Because the embedding vector in LLaMA is 4096. Then the output, so this 1 by 1, 00:47:25.560 |
is multiplied by the V, and it will produce the output token here, and this will be our first 00:47:32.040 |
token of the output. And then we take the output token, this one, and we append it to the input 00:47:40.040 |
at the next step. So, now we have two tokens as input. They are multiplied by itself, but with 00:47:46.600 |
the transposed version of itself, and it will produce a 2 by 2 matrix, which is then multiplied 00:47:52.280 |
by the V matrix, and it will produce two output tokens. But we are only interested in the last 00:47:57.240 |
token's output by the model, so this one, attention 2, which is then appended to the input 00:48:03.720 |
matrix at the time step 3. So, now we have three tokens in the time step 3, which are multiplied 00:48:10.200 |
by the transposed version of itself, and it will produce a 3 by 3 matrix, which is then multiplied 00:48:16.200 |
by the V matrix, and we have these three tokens as output. But we are only interested in the last 00:48:23.880 |
token output by the model, so we append it again as input to the Q matrix, which is now four tokens, 00:48:30.440 |
which is multiplied by the transposed version of itself, and it will produce a 4 by 4 matrix 00:48:37.080 |
as output, which is then multiplied by this matrix here, and it will produce this attention matrix 00:48:43.000 |
here. But we are only interested in the last attention, which will be then added again to the 00:48:48.520 |
input of the next step. But we notice already something. First of all, we already here in this 00:48:55.720 |
matrix, where we compute the dot product between this token and this, this token and this, this 00:49:01.240 |
token and this. So this matrix is all the dot products between these two matrices. We can see 00:49:07.000 |
something. The first thing is that we already computed these dot products in the previous step. 00:49:13.160 |
Can we cache them? So let's go back. As you can see, this matrix is growing. Two, three, 00:49:19.880 |
four. See, there is a lot of repeated computation, because every time we are inferencing the transformer, 00:49:26.360 |
we are giving the transformer some input, so it's re-computing all these dot products, 00:49:32.280 |
which is inconvenient, because we actually already computed them in the previous time step. So 00:49:37.080 |
is there a way to not compute them again? Can we kind of cache them? Yes, we can. And then, 00:49:44.360 |
since the model is causal, we don't care about the attention of a token with its successors, 00:49:50.520 |
but only with the tokens before it. So as you remember, in the self-attention, we apply a mask, 00:49:56.600 |
right? So the mask is basically, we don't want the dot product of one word with the word that 00:50:02.280 |
comes after it, but only the one that comes before it. So basically, we don't want all the numbers 00:50:08.760 |
above the principal diagonal of this matrix. And that's why we applied the mask in the 00:50:15.400 |
self-attention. But okay, the point is, we don't need to compute all these dot products. The only 00:50:21.800 |
dot products that we are interested in is this last row. So because we added the token 4 as input 00:50:29.880 |
compared to the last time step, so we only have this new token, token 4, and we want this token 4 00:50:35.800 |
how it is interacting with all the other tokens. So basically, we are only interested in this last 00:50:42.440 |
row here. And also, as we only care about the attention of the last token, because we want 00:50:49.960 |
to select the word from the vocabulary, so we only care about the last row, we don't care about 00:50:54.680 |
producing these first attention scores here in the output sequence of the self-attention, 00:51:02.280 |
we only care about the last one. So is there a way to remove all these redundant calculations? 00:51:07.880 |
Yes, we can do it with the KV cache. Let's see how. So with the KV cache, basically, what we do 00:51:16.280 |
is we cache the keys and the values. And every time we have a new token, 00:51:25.560 |
we append it to the key and the values, while the query is only the output of the previous step. 00:51:32.360 |
So at the beginning, we don't have any output from the previous step, so we only use the first token. 00:51:37.880 |
So the first, the time step one of the inference is the same as without the cache. So we have the 00:51:44.360 |
token one with itself, will produce a matrix one by one, multiplied with one token, and it will 00:51:50.040 |
produce one attention. However, at the time step two, we don't append it to the previous query, 00:51:59.240 |
we just replace the previous token with the new token we have here. However, we keep the cache 00:52:05.240 |
of the keys. So we keep the previous token in the keys, and we append the last output to the keys 00:52:11.880 |
here, and also to the values. And if you do this multiplication, it will produce a matrix that is 00:52:18.680 |
one by two, where the first item is the dot product of the token two with the token one and 00:52:24.840 |
the token two with the token two. This is actually what we want. And if we then multiply with the V 00:52:30.760 |
matrix, it will only produce one attention score, which is exactly the one we want. And we do again, 00:52:36.920 |
so we take this attention two, and this will become the input of the next inference step. 00:52:43.720 |
So this token three, we append it to the previously cached K matrix and also to the 00:52:49.560 |
previously cached V matrix. This multiplication will produce an output matrix that we can see 00:52:55.960 |
here. The multiplication of this output matrix with this V matrix will produce one token in the 00:53:03.000 |
output, which is this one, and we know which token to select using this one. Then we use it 00:53:08.680 |
as an input for the next inferencing step by appending it to the cached keys and appending 00:53:14.280 |
to the cached V matrix. We do this multiplication, and we will get this matrix, which is 00:53:23.160 |
one by four, which is the dot product of the token four with the token one, the token four 00:53:28.600 |
with the token two, token four with the token three, and the token four with itself. We multiply 00:53:33.560 |
by the V matrix, and this will only produce one attention, which is exactly what we want to select 00:53:39.080 |
the output token. This is the reason why it's called the KV cache, because we are keeping a 00:53:44.920 |
cache of the keys and the values. As you can see, the KV cache allows us to save a lot of computation, 00:53:52.120 |
because we are not doing a lot of dot products that we used to do before, and this makes the 00:53:59.720 |
inferencing faster. A minimal sketch of such a cache is shown below. 00:54:04.440 |
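This sketch implements the idea for a single attention head (an illustrative simplification, not LLaMA's actual implementation): at every step we feed only the newest token as the query, while the keys and values of all the previous tokens are kept in a growing cache.

```python
import torch
import math

class SingleHeadKVCacheAttention:
    """Toy single-head self-attention with a KV cache (illustrative, not LLaMA's real code)."""

    def __init__(self):
        self.k_cache = None   # (seq_len_so_far, head_dim)
        self.v_cache = None

    def step(self, q_new: torch.Tensor, k_new: torch.Tensor, v_new: torch.Tensor) -> torch.Tensor:
        # q_new, k_new, v_new: (1, head_dim) for the single newest token.
        if self.k_cache is None:
            self.k_cache, self.v_cache = k_new, v_new
        else:
            self.k_cache = torch.cat([self.k_cache, k_new], dim=0)   # append the new key
            self.v_cache = torch.cat([self.v_cache, v_new], dim=0)   # append the new value

        # (1, head_dim) @ (head_dim, t) -> (1, t): dot products of the new token with all cached tokens.
        scores = (q_new @ self.k_cache.T) / math.sqrt(q_new.shape[-1])
        attn = torch.softmax(scores, dim=-1)
        return attn @ self.v_cache                                    # (1, head_dim): one output token

head_dim = 128
cache = SingleHeadKVCacheAttention()
for t in range(4):   # 4 inference steps, one new token each
    q = k = v = torch.randn(1, head_dim)   # in a real model these come from x @ Wq, x @ Wk, x @ Wv
    out = cache.step(q, k, v)
print(out.shape, cache.k_cache.shape)      # torch.Size([1, 128]) torch.Size([4, 128])
```

The next layer that we will be talking about is the grouped multi-query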
attention, but before we talk about the grouped multi-query attention, we need to introduce its 00:54:09.560 |
predecessor, the multi-query attention. Let's see. So let's start with the problem. The problem is 00:54:16.680 |
that the GPUs are too fast. If you look at this datasheet, this is from the A100 GPU from NVIDIA, 00:54:26.280 |
we can see that the GPU is very fast at computing, at performing calculations, 00:54:31.160 |
but not so much, not so fast at transferring data from its memory. That means, for example, 00:54:39.880 |
that the A100 can do 19.5 tera floating point operations per second by using a 32-bit precision, 00:54:50.840 |
while it can only transfer about 1,900 gigabytes per second. It's nearly 10 times 00:55:00.920 |
slower at transferring data than it is at performing calculations, and this means that 00:55:10.840 |
sometimes the bottleneck is not how many operations we perform, but how much data transfer our 00:55:16.840 |
operations need, and that depends on the size and the quantity of the tensors involved in our 00:55:22.680 |
calculations. For example, if we compute the same operations on the same tensor n times, 00:55:29.240 |
it may be faster than computing the same operations on n different tensors, even if they 00:55:34.840 |
have the same size. This is because the GPU may need to move these tensors around. So this means 00:55:41.560 |
that our goal should not only be to optimize the number of operations we do with our algorithms, 00:55:48.040 |
but also minimize the memory access and the memory transfers that our algorithms perform, 00:55:53.960 |
because the memory access and memory transfer are more expensive in terms of time compared to the 00:56:02.440 |
computations. And this also happens with software when we do I/O: for 00:56:08.840 |
example, if we do some multiplications on the CPU and also read some data from the hard disk, reading from 00:56:14.920 |
the hard disk is much slower than doing a lot of computations on the CPU. And this is a problem. 00:56:21.880 |
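As a quick back-of-the-envelope check on the figures quoted above (19.5 TFLOPS of FP32 compute versus roughly 1,900 GB/s of memory bandwidth), the gap is about an order of magnitude:

```python
# Rough arithmetic using the A100 figures quoted above (not a precise model of the hardware).
fp32_ops_per_s = 19.5e12     # 19.5 TFLOPS of FP32 compute
bytes_per_s = 1.9e12         # ~1,900 GB/s of memory bandwidth

print(fp32_ops_per_s / bytes_per_s)        # ~10 operations per byte moved
print(fp32_ops_per_s / (bytes_per_s / 4))  # ~41 operations per FP32 value (4 bytes) moved
```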
Now, this is the paper that introduced the multi-query attention. This paper is from Noam Shazeer, 00:56:27.960 |
who is also one of the authors of the attention paper, "Attention Is All You Need". And in this 00:56:34.360 |
paper, he introduced the problem. He said, well, let's look at the multi-head attention. So the 00:56:41.720 |
batched multi-head attention. This is the multi-head attention as presented in the original 00:56:46.840 |
paper, "Attention Is All You Need". Let's look at the algorithm and calculate the number of 00:56:51.880 |
arithmetic operations performed and also the total memory involved in these operations. 00:56:58.440 |
So he calculated that the number of arithmetic operations performed is O(bnd^2), 00:57:05.320 |
where b is the batch size, n is the sequence length, and d is the size of the embedding vector. 00:57:11.960 |
While the total memory involved in the operations, given by the sum of the sizes of all the tensors involved in 00:57:19.160 |
the calculations, including the derived ones, is O(bnd + bhn^2 + d^2), where h is the 00:57:29.640 |
number of heads in this multi-head attention. Now, if we compute the ratio between 00:57:36.920 |
the total memory and the number of arithmetic operations, we get this expression here: 00:57:43.240 |
O(1/k + 1/(bn)), where k = d/h is the size of each head. In this case, the ratio is much smaller than 1, which means that the number of memory 00:57:50.920 |
accesses that we perform is much less than the number of arithmetic operations. So the memory 00:57:56.040 |
access in this case is not the bottleneck. So what I mean to say is that the bottleneck of 00:58:04.120 |
this algorithm is not the memory access, it is actually the number of computations. And as you 00:58:09.240 |
saw before, when we introduced the KV cache, the problem we were trying to solve is the number of 00:58:13.880 |
computations, but by introducing the KV cache, we created a new bottleneck, and it's not the 00:58:27.560 |
computation anymore. So this algorithm here is the multi-head self-attention, but using the KV cache, 00:58:35.000 |
and this reduces the number of operations performed. So if we look at the number of 00:58:40.120 |
arithmetic operations performed, it's O(bnd^2). The total memory involved in the operations is O(bn^2d 00:58:48.280 |
+ nd^2), and the ratio between the two is O(n/d + 1/b), so the ratio between the total memory 00:58:58.840 |
and the number of arithmetic operations. This means that when n is comparable to d, 00:59:06.600 |
or when b is close to 1, in the limit where the batch size is 1, 00:59:13.960 |
this ratio approaches 1. And this is a problem, because when this condition holds, 00:59:20.360 |
the memory access becomes the bottleneck of the algorithm. And this also means 00:59:28.040 |
that either we keep the dimension of the embedding vector much bigger than the sequence length, 00:59:36.760 |
or, if we increase the sequence length without making the dimension of the embedding vector 00:59:42.120 |
much bigger, the memory access will become the bottleneck. So we need to 00:59:49.800 |
find a better way. To solve the problem of the previous algorithm, in which the memory became 00:59:54.920 |
the bottleneck, we introduced the multi-query attention. So what the author did was to remove 01:00:01.080 |
the h dimension from the k and the v, while keeping it for the q. So it's still a multi-head 01:00:09.880 |
attention, but only with respect to q, that's why it's called multi-query attention. So we will 01:00:15.800 |
have multiple heads only for the q, but the k and v will be shared by all the heads. And if we use 01:00:22.680 |
this algorithm, the ratio becomes O(1/d + n/(dh) + 1/b). If we compare it to the previous one, 01:00:33.960 |
the term that was n/d is now n/(dh). So we reduced that term by a factor of h, 01:00:46.040 |
because we removed the separate heads for the k and v. So the performance gains are 01:00:53.400 |
significant, because it is now much less likely that this ratio will approach 1. 01:01:01.320 |
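To see the effect of dividing the n/d term by h, here is a small numeric sketch with made-up but plausible sizes (these are not the Lama hyperparameters, just an illustration):

```python
# Memory-to-compute ratios for incremental decoding, using the expressions discussed above.
# b = batch size, n = sequence length, d = embedding dimension, h = number of heads (illustrative values).
b, n, d, h = 32, 2048, 4096, 32

ratio_mha = n / d + 1 / b                 # multi-head attention with KV cache
ratio_mqa = 1 / d + n / (d * h) + 1 / b   # multi-query attention with KV cache

print(ratio_mha)  # ~0.53: memory traffic is a large fraction of the cost
print(ratio_mqa)  # ~0.05: the n/d term is divided by h, so memory access is far from being the bottleneck
```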
But of course, by removing the heads from the k and v, our model will also have fewer parameters, 01:01:10.840 |
it will also have fewer degrees of freedom and complexity, which may degrade the quality 01:01:17.320 |
of the model. And it actually does degrade the quality of the model, but only slightly, 01:01:22.120 |
and we will see. So if we compare, for example, the BLEU score on a translation task from English 01:01:27.320 |
to German, we can see that the multi-head attention, so the attention that was in the 01:01:32.040 |
original attention paper, has a BLEU score of 26.7, while the multi-query has a BLEU score of 26.5. 01:01:41.560 |
The author also compared it with the multi-head local and multi-query local, where local means 01:01:49.720 |
that they restrict the attention calculation only to the previous 31 positions of each token. 01:01:58.600 |
And we can see it here. But the performance gain from reducing the heads of the k and v is great, 01:02:07.160 |
because you can see the inference time, for example, on the original multi-head attention 01:02:11.560 |
and the multi-query attention. The inference time went from 1.7 microseconds plus 46 microseconds 01:02:20.200 |
for the decoder to 1.5 microseconds plus 3.8 microseconds for the decoder. So in total here, 01:02:28.040 |
more or less, we took 48 microseconds, while here we take roughly 5 microseconds for 01:02:36.360 |
the multi-query. So it's a great benefit from a performance point of view during inference. 01:02:45.160 |
Let's talk about grouped multi-query attention, because now we just introduced the KV cache 01:02:51.560 |
and the multi-query attention. But the next step of the multi-query attention 01:02:56.520 |
is the grouped multi-query attention, which is the one that is used in llama. 01:03:00.360 |
So let's have a look at it. With multi-query, we only have multiple heads for the queries, 01:03:06.440 |
but only one head for the key and the values. With grouped multi-query attention, basically, 01:03:12.520 |
we divide the queries into groups. So for example, this is the group 1, this is the group 2, 01:03:18.760 |
group 3 and group 4. And for each group, we have one different head of k and v. 01:03:26.600 |
This is a good compromise between the multi-head, in which there is a one-to-one correspondence, 01:03:32.280 |
and the multi-query, where there is a n-to-one correspondence. 01:03:36.680 |
So in this case, we still have multiple heads for the keys and values, but there are 01:03:42.600 |
fewer of them compared to the number of heads of the queries. 01:03:46.040 |
And this is a good compromise between the quality of the model and the speed of the model, 01:03:52.840 |
because here we still get the computational benefit of the reduction in 01:04:01.320 |
the number of heads of keys and values, but we don't sacrifice too much on the quality side. 01:04:07.560 |
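As a sketch of how the bookkeeping works (hypothetical sizes, and omitting the causal mask and the KV cache for brevity), each key/value head is simply shared by a group of query heads:

```python
# A minimal sketch of grouped multi-query attention: n_heads query heads share n_kv_heads K/V heads.
import torch

batch, seq_len, n_heads, n_kv_heads, head_dim = 1, 10, 8, 2, 16
group = n_heads // n_kv_heads            # 4 query heads share each K/V head

q = torch.randn(batch, n_heads,    seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)
v = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# Repeat each K/V head so that it is seen by its whole group of query heads.
k = k.repeat_interleave(group, dim=1)    # (batch, n_heads, seq_len, head_dim)
v = v.repeat_interleave(group, dim=1)

scores = q @ k.transpose(-2, -1) / head_dim ** 0.5   # (batch, n_heads, seq_len, seq_len)
out = torch.softmax(scores, dim=-1) @ v              # (batch, n_heads, seq_len, head_dim)
```

With n_kv_heads equal to n_heads this reduces to ordinary multi-head attention, and with n_kv_heads equal to 1 it reduces to multi-query attention.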
And now the last part of the model. As you can see here, the feedforward layer in the Lama model 01:04:13.880 |
has had its activation function changed to the SwiGLU function. 01:04:20.760 |
Let's have a look at how it works. So the SwiGLU function was analyzed 01:04:25.640 |
in this famous paper from Noam Shazeer, who is also one of the authors of the attention paper, 01:04:30.920 |
and one of the authors of the multi-query attention paper that we saw before. 01:04:36.200 |
So let's have a look at this paper. So the author compared the performance of 01:04:42.120 |
the transformer model by using different activation functions in the feedforward 01:04:46.680 |
layer of the transformer architecture. And the one we are interested in is this 01:04:51.800 |
SwiGLU here, which is basically FFN_SwiGLU(x, W, V, W2) = (Swish_1(xW) ⊗ xV) W2: 01:04:58.600 |
the Swish function with beta equal to one, applied to x multiplied by a parameter matrix W, element-wise multiplied 01:05:05.080 |
with x multiplied by another parameter matrix V, and the result multiplied by W2, a third parameter matrix. 01:05:11.480 |
So compare this with the original feedforward network and here we have three parameter matrices, 01:05:19.000 |
while in the original feedforward network, we only had two. So to make the comparison fair, 01:05:26.040 |
the author reduced the size of these matrices (the hidden dimension of the feedforward layer) so that the 01:05:33.720 |
model's total number of parameters remains the same as in the vanilla transformer. 01:05:38.120 |
In the vanilla transformer, we had this feedforward network, which used the ReLU function: 01:05:42.840 |
FFN(x) = max(0, xW1 + b1)W2 + b2. So this max(0, ...) is the ReLU, and we only had two parameter matrices. 01:05:48.760 |
Actually, some later versions of the transformer don't have the bias. 01:05:54.200 |
I took this formula from the paper, but there are many implementations 01:05:57.640 |
without the bias, actually. In Lama, instead, we use this 01:06:02.840 |
computation for the feedforward network. And this is the code I took from the repository from Lama. 01:06:09.560 |
And as you can see, it's just what the formula says: it's the SiLU function. Why the SiLU function? 01:06:15.400 |
Because it's the Swish function with beta equal to one. The Swish function 01:06:20.040 |
has the expression swish_beta(x) = x * sigmoid(beta * x); when we set beta equal to one, it's called the sigmoid linear unit, 01:06:26.360 |
silu(x) = x * sigmoid(x), which has this graph. So we take the SiLU of w1(x), 01:06:35.080 |
multiply it element-wise by w3(x), and then apply w2 to the result. So we have three matrices. 01:06:41.960 |
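A sketch of this feed-forward block in plain PyTorch (the real Lama code uses parallelized linear layers and a particular rule for choosing the hidden dimension; the sizes here are arbitrary):

```python
import torch
import torch.nn.functional as F
from torch import nn

class FeedForwardSwiGLU(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x))."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # linear projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

ffn = FeedForwardSwiGLU(dim=512, hidden_dim=1376)  # arbitrary example sizes
y = ffn(torch.randn(2, 10, 512))                   # (batch, seq_len, dim) -> same shape
```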
And these three matrices are basically linear layers. Now, they use a parallelized version 01:06:47.720 |
of these linear layers, but they are still linear layers. And if we look at the graph of this SiLU function, 01:06:54.360 |
we can see that it's kind of like a ReLU, but just before zero we don't cancel out 01:07:04.760 |
the activation immediately. We keep a little tail here, so that even values that are very 01:07:10.920 |
close to zero from the negative side are not automatically canceled out by the function. 01:07:16.840 |
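You can see this "little tail" directly by comparing SiLU with ReLU on a few negative values (a tiny check, nothing more):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, -1.0, -0.1, 0.0, 1.0])
print(F.relu(x))  # tensor([0., 0., 0., 0., 1.]): everything below zero is cancelled
print(F.silu(x))  # x * sigmoid(x): small negative inputs keep a small non-zero output
```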
So let's see how it performs. This SwiGLU function actually performs very well. 01:07:22.200 |
Here they evaluate the log perplexity of the model when we use this 01:07:30.200 |
particular function, and we can see that the perplexity here is the lowest. The perplexity 01:07:36.360 |
basically measures how unsure the model is about its choices. And the SwiGLU function is performing well. 01:07:45.880 |
Then they also ran the comparison on many benchmarks, and we see that this SwiGLU function 01:07:53.160 |
is performing quite well on a lot of them. So why is this SwiGLU activation function working so well? 01:08:01.080 |
If we look at the conclusion of this paper, we see: "We offer no explanation as to why these 01:08:07.000 |
architectures seem to work; we attribute their success, as all else, to divine benevolence." 01:08:13.640 |
Actually, this is kind of funny, but it's also kind of true. Because in most of the 01:08:18.600 |
deep learning research, we do not know why things work in the way they do. Because imagine you have 01:08:25.640 |
a model of 70 billion parameters. How can you prove what is happening to each one of them 01:08:32.440 |
after you modify one activation function? It's not easy to come up with a model that 01:08:38.840 |
can explain why the model is reacting in a particular way. What we usually do instead is 01:08:45.880 |
either simplify the model, so we can work with a very small model, and then make some assumptions 01:08:51.320 |
on why things work the way they do. Or we can just do it on a practical level. So we take a model, 01:08:58.440 |
we modify it a little bit, we do some ablation study, and we check which one is performing better. 01:09:03.560 |
And this also happens in a lot of areas of machine learning. For example, 01:09:07.560 |
we do a lot of grid search to find the right parameters for a model, because we cannot know 01:09:13.000 |
beforehand which one will work well, or which one to increase, or which one to decrease. Because it 01:09:17.960 |
depends on a lot of factors, not only on the algorithm used, but also on the data, also on 01:09:23.480 |
the particular computations used, also on the normalization used. So there are a lot of factors, 01:09:28.120 |
and there is no single formula that explains everything. So this is why researchers need 01:09:35.240 |
to do a lot of studies on variants of models, to come up with something that works maybe in 01:09:40.920 |
one domain and doesn't work well in other domains. So in this case, we use SwiGLU, 01:09:45.160 |
mostly because in practice it works well with this kind of model. 01:09:48.920 |
Thank you guys for watching this long video. I hope that you learned at a deeper level what 01:09:56.200 |
happens in Lama, and why it is different from a standard transformer model. I know that the 01:10:01.240 |
video has been quite long, and I know that it has been hard on some parts to follow, 01:10:05.880 |
so I actually suggest re-watching it multiple times, especially the parts that you 01:10:10.920 |
are less familiar with, and integrating this video with my previous video about the transformer. 01:10:15.560 |
I will put the chapters so you can easily find the part that you want, but this is 01:10:22.360 |
what you need to do: you need to watch the same concept multiple times to actually master it. 01:10:27.960 |
And I hope to make another video in which we code the Lama model from scratch, so we can 01:10:32.120 |
put all this theory into practice. But as you know, I am doing this in my free time, 01:10:39.320 |
and I don't have much of it. So thank you guys for watching my video, and please subscribe 01:10:44.200 |
to my channel, because this is the best motivation for me to keep posting amazing content on AI and 01:10:50.360 |
machine learning. Thank you for watching, and have an amazing rest of the day.