Coding LLaMA 2 from scratch in PyTorch - KV Cache, Grouped Query Attention, Rotary PE, RMSNorm
Chapters
0:00 Introduction
1:20 LLaMA Architecture
3:14 Embeddings
5:22 Coding the Transformer
19:55 Rotary Positional Embedding
63:50 RMS Normalization
71:13 Encoder Layer
76:50 Self Attention with KV Cache
89:12 Grouped Query Attention
94:14 Coding the Self Attention
121:40 Feed Forward Layer with SwiGLU
128:50 Model weights loading
141:26 Inference strategies
145:15 Greedy Strategy
147:28 Beam Search
151:13 Temperature
152:52 Random Sampling
154:27 Top K
157:03 Top P
158:59 Coding the Inference
00:00:02.640 |
In this video we will be coding LLaMA 2 from scratch. 00:00:05.640 |
And just like the previous video in which I was coding the transformer model from zero, 00:00:11.640 |
while coding I will also explain all the aspects of LLaMA, so each building block of the LLaMA architecture. 00:00:19.400 |
And I will also explain the math behind the rotary positional encoding. 00:00:23.360 |
I will also explain grouped query attention, the KV cache. 00:00:27.280 |
So we will not only have a theoretical view of these concepts, but also a practical one. 00:00:33.120 |
If you are not familiar with the transformer model, I highly recommend you watch my previous video on the transformer model. 00:00:38.640 |
And then if you want, you can also watch the previous video on how to code a transformer model from zero, 00:00:43.440 |
because it will really help you a lot in understanding what we are doing in this current video. 00:00:49.120 |
If you already watched my previous video on the architecture of LLaMA, in which I explained the concepts, 00:00:55.080 |
it will also be really helpful. If you didn't, I will try to explain all the concepts, 00:01:01.880 |
but not in as much detail as in the last video. 00:01:06.960 |
So please, if you have time, if you want, please watch my previous video on the LLaMA architecture, 00:01:11.520 |
and then watch this video, because this will really give you a deeper understanding of what is happening. 00:01:21.280 |
Here we have a comparison between the architecture of the standard transformer, 00:01:25.040 |
as introduced in the "Attention is all you need" paper, and the architecture of LLaMA. 00:01:30.320 |
The first thing we notice is that the transformer was an encoder-decoder model. 00:01:36.000 |
And in the previous video, we actually trained it on a translation task, 00:01:40.120 |
on how to translate, for example, from English to Italian, while LLaMA is a large language model. 00:01:46.160 |
So the goal of a large language model is actually to work with what is called the next token prediction task. 00:01:53.720 |
So given a prompt, the model tries to come up with the next token that completes this prompt in the most coherent way. 00:02:05.720 |
And we keep asking the model for the successive tokens based on the previous tokens. 00:02:14.560 |
So each output depends on the previous tokens, which is also called the prompt. 00:02:20.800 |
Contrary to what I have done in my previous video on coding the transformer model, 00:02:25.840 |
in this video, we will not start by coding the single building blocks of LLaMA and then come up with a bigger picture. 00:02:35.840 |
So we will first make the skeleton of the architecture and then we will build each block. 00:02:42.240 |
I find that this is a better way to explain LLaMA, also because the model is simpler, 00:02:48.040 |
even if the single building blocks are much more complex. 00:02:51.320 |
And this is why it's better to first look at how they interact with each other and then zoom in into their inner workings. 00:03:06.200 |
So this block here, so we are given an input and we want to convert it into embeddings. 00:03:14.760 |
These are my slides from my previous video on the transformer. 00:03:19.560 |
And as you can see, we start with an input sentence. 00:03:23.200 |
So we are given, for example, the sentence, your cat is a lovely cat. 00:03:26.640 |
We tokenize it, so we split it into single tokens. 00:03:29.480 |
We map each token to its position in the vocabulary. 00:03:33.880 |
The vocabulary is the list of all the words that our model can recognize. 00:03:39.680 |
Most of the time, these tokens are actually not single words. 00:03:44.240 |
What I mean is that the tokenizer doesn't just split the sentence by whitespace into single words. 00:03:50.200 |
Usually the most commonly used tokenizer is the BPE tokenizer, which means byte pair encoding tokenizer, 00:03:56.960 |
in which the single tokens can also be a sequence of letters that are not necessarily mapped to a single word. 00:04:04.440 |
It may be a part of a word or maybe it may be a whitespace. 00:04:08.840 |
It may be multiple words or it may be a single digit, etc. 00:04:13.440 |
And the embedding is a mapping from a number, 00:04:17.840 |
the input ID that represents the position of the token inside the vocabulary, to a vector. 00:04:24.400 |
This vector in the original transformer was of size 512, while in LLaMA, in the base model, 00:04:36.600 |
the dimension is 4096, which means this vector will contain 4096 numbers. 00:04:43.200 |
Each of these vectors represents somehow the meaning of the word, 00:04:48.400 |
because each of these vectors are actually parameter vectors that are trained along the model 00:04:54.400 |
and somehow they capture the meaning of each word. 00:04:57.680 |
So for example, if we take the word "cat" and "dog", 00:05:01.560 |
they will have an embedding that is more similar compared to "cat" and "tree", 00:05:06.880 |
if we compare the distance of these two vectors, so we could say the Euclidean distance of these two vectors. 00:05:14.960 |
And this is why they are called embedding, this is why we say they capture the meaning of the word. 00:05:24.600 |
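As a tiny illustration of this mapping (the numbers below are toy values, not the real LLaMA vocabulary size or dimension), this is what an embedding lookup looks like in PyTorch:

```python
import torch
import torch.nn as nn

# Toy example: a vocabulary of 10 tokens, each mapped to a vector of 8 numbers.
# In LLaMA 2 7B the embedding dimension is 4096 and the vocabulary is much larger.
embedding = nn.Embedding(num_embeddings=10, embedding_dim=8)

token_ids = torch.tensor([[3, 7, 1]])  # (batch=1, seq_len=3) input IDs from the tokenizer
vectors = embedding(token_ids)         # (1, 3, 8): one learned vector per token
print(vectors.shape)                   # torch.Size([1, 3, 8])
```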
I will open Visual Studio Code and first of all, we make sure that we have the necessary libraries. 00:05:32.800 |
I will also share the repository of this project. 00:05:35.800 |
So the only things we need are torch, sentencepiece, which is the tokenizer that we use in LLaMA, and tqdm. 00:05:43.200 |
The second thing is that from the official repository of LLaMA you should download the download script. 00:05:47.440 |
It's this file, download.sh, that allows you to download the weights of the LLaMA model. 00:05:52.400 |
In my case, I have downloaded the LLaMA 2 7B model, along with the tokenizer. 00:06:01.120 |
And I will not even be able to use the GPU because my GPU is not powerful enough. 00:06:08.400 |
I think most of you guys will do the same because it's actually a little big for a normal computer unless you have a powerful GPU. 00:06:17.920 |
Let's create a new file, model.py, and let's start our journey. 00:06:26.080 |
Let's import the necessary things. So we import torch and also torch.nn. 00:06:40.880 |
Okay, these are the basic things that we always import. If I remember correctly, we also need math and also dataclasses. 00:07:00.240 |
Most of the code is actually based on the original LLaMA code, so don't be surprised if you see a lot of similarities. 00:07:06.480 |
But I simplified a lot of parts to remove things that we don't need, especially, for example, the parallelization. 00:07:12.320 |
And I also tried to add a lot of comments to show you how the shape of each tensor changes. 00:07:22.800 |
So the first thing I want to create is the class that represents the parameters of the model. 00:07:57.600 |
Here we already see that we have two types of heads. 00:08:06.320 |
And here we have the number of heads for the K and the V. 00:08:13.600 |
Because, as we will see later with the grouped query attention, 00:08:17.040 |
we don't necessarily have to have the same number of heads for the query, key and values like in the original transformer. 00:09:03.680 |
These two parameters indicate the hidden dimension of the feed-forward layer. 00:09:09.360 |
The basic idea is that, when they introduced the grouped query attention, 00:09:16.240 |
they reduced the number of heads for the K and the V, 00:09:20.160 |
but they increased the number of parameters of the feed-forward layer, 00:09:23.680 |
so that the total number of parameters of the model remains the same. 00:09:27.440 |
This allows comparing the base transformer, 00:09:32.480 |
so with all the heads for the query, the key and the values, 00:09:37.600 |
with the variant that has a reduced number of heads for the K and V. 00:09:40.160 |
But this is just an architectural decision. 00:10:14.800 |
Also here we have two parameters that we will later use for the KV cache. 00:10:18.400 |
And I will explain later what it is and how it works. 00:10:20.720 |
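Here is a sketch of this parameters class, following the field names used in Meta's code as described above; the defaults are the ones mentioned for the 7B configuration, and the KV-cache limits are just typical values:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    dim: int = 4096                      # embedding dimension (7B model)
    n_layers: int = 32                   # number of encoder blocks
    n_heads: int = 32                    # heads for the queries
    n_kv_heads: Optional[int] = None     # heads for the keys/values (grouped query attention)
    vocab_size: int = -1                 # set later, when the tokenizer is loaded
    multiple_of: int = 256               # used to size the hidden dim of the feed-forward layer
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5               # epsilon used by the RMS normalization

    # Needed for the KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048
```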
Let's start, as I said before, let's start with implementing the skeleton of the entire model. 00:10:29.440 |
And while implementing each single part, I will also review the background. 00:10:35.680 |
This is the main class that will represent the entire model. 00:10:48.400 |
So all this model here, except for the softmax. 00:11:08.560 |
we make sure that we have set the vocabulary size. 00:11:49.040 |
Just like in the transformer, the base transformer, if you remember correctly. 00:11:52.560 |
This block here and this block here were repeated one after another many times. 00:11:59.600 |
And the output of the last layer is then sent to this rms norm, then to the linear, etc. 00:12:14.880 |
So I'm using the same names as used in the code from the LLaMA repository. 00:12:28.400 |
Because when we load the weights of the model, the names must match. 00:12:32.400 |
Otherwise, PyTorch doesn't know where to load the weights. 00:12:36.240 |
And this is the reason I'm trying to keep the same names. 00:12:40.080 |
I only changed some names to make them more clear. 00:13:24.720 |
And the normalization is the RMS normalization. 00:13:30.960 |
We need to tell it the size of the features. 00:13:35.040 |
And the eps is a very small number that is needed for the normalization calculation, 00:13:41.120 |
so that we never divide by zero, basically. 00:13:58.160 |
Okay, then we need to pre-compute the frequencies of the rotary positional encodings. 00:14:10.640 |
I created this method and then we go to implement it and I will show you how it works. 00:15:11.920 |
The last layer output is sent to the normalization, then to the output. 00:15:15.840 |
So this logic will be more clear in the forward method. 00:15:22.720 |
Okay, here you will see one thing that is different from the previous transformer, 00:15:37.840 |
which is that the sequence length that we want is always one. 00:15:40.880 |
And this is because we are using the KV cache. 00:15:51.360 |
Here, when we give the input, we want to give the prompt 00:15:55.680 |
and the model will give us a softmax of the next token. 00:15:59.200 |
But with the KV cache, we don't need to give all the previous tokens 00:16:03.360 |
because maybe we already computed them in the previous iteration. 00:16:06.960 |
So we only need to give the latest token and then the model will output the next token. 00:16:11.440 |
While the intermediate tokens, so the cache of the previous tokens, 00:16:14.640 |
will be kept by the model in its cache because here we have a KV cache. 00:16:19.520 |
So for now, just remember that the input here that we will get is one token at a time. 00:16:32.800 |
Okay, so batch size, sequence length is tokens.shape. 00:16:42.560 |
And we make sure that the sequence length is actually one. 00:16:45.840 |
And the second thing is that this model is only good for inferencing, not for training. 00:16:59.120 |
Because for training, of course, we need to not have the KV cache 00:17:02.160 |
and we need to be able to process multiple tokens. 00:17:06.000 |
But our goal is actually to use the pre-trained LLaMA weights. 00:17:30.160 |
So the dimension of the embeddings, which is 4096 for the base model. 00:17:35.520 |
But depending on the model size, it can be different. 00:17:48.000 |
Okay, I promise I will explain this line here and this line here 00:18:16.720 |
Let me finish it and I will explain everything together. 00:18:19.360 |
So this is basically, we pre-compute something about the positional encoding 00:18:26.960 |
But let's finish writing it and then I explain this method and this one. 00:18:35.040 |
So suppose with this one, we retrieve something that is needed 00:18:39.440 |
for computing the positional encoding, which then we feed to the next layers. 00:18:43.120 |
So we consecutively apply all the encoder blocks. 00:19:01.360 |
And finally, we apply the normalization, just like here. 00:19:09.920 |
So we apply these blocks one after another many times. 00:19:14.800 |
And then we calculate the output using the linear layer. 00:19:26.720 |
So we take the input, we convert it into embeddings. 00:19:31.920 |
We give this input embeddings with something about the positional encodings 00:19:38.400 |
We take the output of the last one and give it to the RMS norm. 00:19:41.440 |
We take the output of the RMS norm and give it to the linear layer. 00:19:44.320 |
And then during the inference, we will apply the softmax. 00:19:47.680 |
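As a minimal sketch, the skeleton just described could look like this (EncoderBlock, RMSNorm and precompute_theta_pos_frequencies are the pieces we build in the rest of the video; device handling is left out):

```python
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size != -1, "Vocabulary size must be set"
        self.args = args
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList([EncoderBlock(args) for _ in range(args.n_layers)])
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        # Precompute the rotary positional encoding frequencies for all positions.
        self.freqs_complex = precompute_theta_pos_frequencies(
            args.dim // args.n_heads, args.max_seq_len * 2)

    def forward(self, tokens: torch.Tensor, start_pos: int):
        # (B, seq_len): with the KV cache we process one token at a time during inference.
        batch_size, seq_len = tokens.shape
        assert seq_len == 1, "Only one token at a time can be processed"
        h = self.tok_embeddings(tokens)                            # (B, 1, dim)
        freqs = self.freqs_complex[start_pos:start_pos + seq_len]  # frequencies for this position
        for layer in self.layers:                                  # apply the encoder blocks one after another
            h = layer(h, start_pos, freqs)
        h = self.norm(h)                                           # final RMS normalization
        return self.output(h).float()                              # (B, 1, vocab_size) logits
```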
Now, let's concentrate on the positional encodings. 00:19:51.760 |
Let's first review how they worked in the original transformer. 00:19:54.800 |
As you remember, in the original transformer, 00:19:57.520 |
so the transformer in the attention is all you need. 00:19:59.920 |
We first take the sentence, we convert into embedding vectors. 00:20:07.200 |
We then add another vector which has the same size, so 512, 00:20:11.440 |
that represents the position of that token inside the sentence. 00:20:15.760 |
So every token in the first position of a sentence will receive the same vector, 00:20:22.480 |
every token in the second position of a sentence will receive the same vector, 00:20:29.200 |
and every token in the third position of a sentence will receive the same vector, and so on. 00:20:34.640 |
These vectors are pre-computed because they only depend on the position, 00:20:42.080 |
And this is why they are called absolute positional encoding, 00:20:44.640 |
because they only strictly depend on the position of the word 00:20:49.040 |
While, on the contrary, in the rotary positional embeddings, 00:20:55.440 |
First of all, the rotary positional encodings or embeddings, 00:20:59.440 |
they are computed right before the calculation of the attention. 00:21:03.440 |
And they are only applied to the Q and the K matrices, 00:21:09.840 |
The first thing we need to understand is the difference between 00:21:13.120 |
absolute positional encodings and relative ones. 00:21:15.280 |
The absolute positional encodings are the one we just saw, 00:21:24.800 |
they come into play during the calculation of the attention. 00:21:28.240 |
And the calculation of the attention is basically done using the dot product. 00:21:33.600 |
Because it's query multiplied by the transpose of the keys, 00:21:40.640 |
And what we want with the relative positional encodings is that 00:21:43.600 |
we change this dot product so that we introduce a new vector 00:21:47.280 |
that indicates the distance between the two tokens involved in the dot product. 00:21:57.760 |
So query multiplied by the transpose of the keys, 00:22:10.960 |
we have another vector here that represents the distance between these two tokens. 00:22:15.200 |
And we compute the attention mechanism like this. 00:22:22.560 |
The rotary positional embeddings are something in between the absolute and the relative. 00:22:26.080 |
Absolute because each token will get its own embedding, 00:22:35.360 |
but relative because the attention will be evaluated using the relative distance between two tokens. 00:22:40.100 |
The rotary positional embeddings were introduced in the RoFormer paper from the company Zhuiyi Technology. 00:22:54.400 |
they wanted to find an inner product that works like this. 00:23:03.760 |
The inner product can be thought of as a generalization of the dot product. 00:23:08.720 |
So it's an operation that has some properties that reflect what is the dot product. 00:23:14.960 |
So the authors of this paper wanted to find an inner product 00:23:29.680 |
So this inner product only depends on the embedding of the two tokens involved, 00:23:36.560 |
and the relative distance of these two tokens. 00:23:42.320 |
For example, if the first token is in position two, 00:23:58.000 |
And they wanted to find a dot product that has this property, 00:24:01.680 |
that only depends on the embedding of the first token, 00:24:08.720 |
Then they saw that if this function G is built in this way, 00:24:15.920 |
That is, we take the first token, so the query for example, 00:24:22.880 |
This is actually done also in the vanilla transformer, 00:24:27.920 |
We convert it into a complex number in this form, 00:24:35.360 |
we transform into a complex number into this form, 00:24:41.840 |
This inner product will basically depend only on the distance 00:24:56.080 |
which is an inner product, behaves like this. 00:24:59.840 |
So it only depends on the embeddings of the vector 00:25:03.840 |
And if we, for example, this formulation here, 00:25:11.200 |
so we think of embedding with only two dimensions, 00:25:19.920 |
So each complex number, thanks to the Euler's formula, 00:25:28.160 |
And this matrix here reminds us of the rotation matrix. 00:25:39.520 |
and if we multiply this vector v zero by this matrix here, 00:25:46.480 |
the resulting vector will be rotated by the angle theta. 00:25:51.040 |
So this is why they're called rotary positional embeddings, 00:25:53.680 |
because the matrix here represents a rotation of the vector. 00:26:03.280 |
we have to think that they will map it into a vector space, 00:26:10.320 |
that is a multiple of a base angle, so theta, 00:26:19.120 |
So that two tokens that occupy similar positions 00:26:30.160 |
And this is the idea behind the rotary positional embeddings. 00:26:34.080 |
But how do we actually compute them in the code, 00:26:40.000 |
Well, to compute them, we need to build a matrix like this. 00:26:44.000 |
And as you can see, this matrix is actually full of zeros. 00:26:49.120 |
And so when we calculate the embedding in this way, 00:26:52.080 |
we will do, if we do a matrix multiplication, 00:26:54.400 |
we will be doing a lot of operations that are useless, 00:26:59.680 |
So the authors of the paper proposed another form 00:27:09.360 |
to which we want to apply the positional encodings. 00:27:14.400 |
So the first dimension, the second dimension, 00:27:18.480 |
So if this is, for example, the vanilla transformer, 00:27:29.040 |
plus another vector that is actually based on this vector, 00:27:34.160 |
but with the positions and the signs changed. 00:27:38.800 |
we have the second dimension with its sign change. 00:27:41.680 |
The second position, we have actually the first dimension. 00:27:44.560 |
In the third position, we have the fourth dimension, 00:27:48.400 |
So it depends only on the embedding of the word, 00:27:51.760 |
but it changes with the signs and position change. 00:27:58.720 |
with another matrix that you can see here, this vector. 00:28:07.920 |
And now what we can pre-compute is this matrix here, 00:28:14.000 |
because it doesn't depend on the token to which we apply it to, 00:28:17.040 |
and this matrix here, because it doesn't depend 00:28:20.480 |
And they depend on m, so it's the position of the word, 00:28:27.520 |
Theta is a series of numbers defined like this. 00:29:12.000 |
This theta parameter, 10,000, comes from the paper. 00:29:17.200 |
We first need to make sure that the dimension of the word 00:29:25.280 |
to which we are applying the embedding is actually even, 00:29:30.000 |
because this rotary positional encoding cannot be applied to an odd dimension. 00:30:13.120 |
And the shape of this theta will be head dimension divided by 2. 00:30:24.080 |
Because we will apply these embeddings to each head, 00:30:30.000 |
but after we have split them into multi-head, 00:30:35.600 |
we check the size of the dimension of each head 00:30:43.680 |
Because in the paper they also divide it by 2. 00:31:07.360 |
so theta_i is 10000 to the power of minus 2 multiplied by (i minus 1) 00:31:13.520 |
divided by the dimension, for i equal to 1, 2, etc., up to dim divided by 2. 00:31:31.760 |
So i, here it starts from 1, we will start from 0, 00:31:58.080 |
Well, because this is to the power of minus 2. 00:32:01.600 |
So something to the power of a negative number is 1 over 00:32:05.920 |
that something to the power of the positive exponent. 00:32:09.680 |
And then this will result in a matrix with shape, 00:32:32.960 |
So the series of theta that goes from theta 1 00:32:41.040 |
Because the m's, the possible positions of a token can be many. 00:32:48.640 |
the maximum sequence length that we can afford 00:32:52.880 |
multiplied by 2, because we have also the prompt, 00:32:57.600 |
So we say, OK, let's pre-compute all the possible theta and m 00:33:02.560 |
for all the possible positions that our model will see. 00:33:05.840 |
And all the possible positions is given by this parameter, 00:33:20.480 |
And the shape is sequence length, which is m. 00:33:24.880 |
Now we need to multiply m by all the sequence of thetas. 00:33:43.200 |
we need m1 theta 1, m1 theta 2, m1 theta d divided by 2. 00:33:49.360 |
Then we need m2 theta 1, m2 theta 2, m2 theta 3. 00:33:58.400 |
basically means multiply all the elements of the first vector 00:34:19.120 |
here we have: freqs is equal to the torch outer product of m and theta. 00:34:29.680 |
OK, so what we are doing is the outer product between m and theta, 00:34:38.400 |
This will basically take the first element of the first vector 00:34:42.080 |
and multiply with all the elements of the second vector. 00:34:45.280 |
Then take the second element of the first vector 00:34:47.520 |
and multiply it with all the elements of the second vector, etc, etc. 00:34:50.720 |
So if we start with a shape, let me say shape of m is sequence length. 00:34:58.960 |
Let's say outer product with head dimension divided by 2. 00:35:08.560 |
This will result in a tensor of sequence length by head dimension divided by 2. 00:35:16.000 |
So for each position, we will have all the theta. 00:35:18.160 |
Then for the second position, we will have all the theta. 00:35:20.400 |
For the third position, we will have all the theta, and so on. 00:35:23.840 |
Now, we want to write these numbers into a complex form, and I will show you why. 00:35:44.640 |
so r times e to the power of i multiplied by m multiplied by theta, where r is equal to 1, as follows. 00:36:05.280 |
Let me also write the shape, and then I'll explain to you how it works. 00:36:30.880 |
I could also, you know, not explain all the proofs. 00:36:34.320 |
So I know the next few minutes will be a little boring 00:36:37.280 |
because I will be explaining all the math behind it. 00:36:39.280 |
But of course, I don't think you, just like me, 00:36:42.640 |
you like to watch just some code and say, "Okay, this is how it's done." 00:36:47.280 |
No, I like to actually give a motivation behind every operation we do, 00:36:51.840 |
and that's, I think, one of the reasons you are watching this video 00:36:54.080 |
and not just reading the code from the Meta repository. 00:37:00.880 |
The first thing we need to review is how complex numbers work. 00:37:06.720 |
Okay, a complex number is a number in the form a plus i multiplied by b, 00:37:14.240 |
where a is called the real part and b is called the imaginary part. 00:37:19.040 |
And i is a number such that i to the power of 2 is equal to minus 1. 00:37:25.680 |
So the complex numbers were introduced to represent all the numbers 00:37:30.080 |
that involve somehow the square root of a negative number. 00:37:33.200 |
As you know from school, the square root of a negative number cannot be calculated, 00:37:37.280 |
and that's why we introduce this constant i, 00:37:39.920 |
which is defined as the square root of minus 1. 00:37:43.520 |
And so we can represent the square root of negative numbers. 00:37:46.160 |
And they can also be helpful in vector calculations, and we will see how. 00:37:57.280 |
The Euler's formula says that e to the power of i multiplied by x 00:38:02.480 |
is equal to cosine of x plus i multiplied by sine of x. 00:38:11.600 |
So it allows us to represent a complex number in the exponential form 00:38:16.960 |
into a sum of two trigonometric functions, the cosine and the sine. 00:38:24.160 |
Because our goal is to calculate these matrices here, 00:38:31.920 |
the cosine of m theta and the sine of m theta. 00:38:35.520 |
And the first thing we did is we calculated all the theta one, 00:38:42.160 |
then we calculated all the possible combinations of positions and thetas. 00:38:47.360 |
So what we did is we calculated a vector that represents the theta. 00:38:53.760 |
So theta 1, theta 2, up to theta d divided by 2. 00:39:10.800 |
Then we calculated the product of each of them for all the possible thetas. 00:39:19.360 |
So for example, we created a new matrix that has m1 theta 1, m1 theta 2, m1 theta 3, 00:39:39.360 |
And then m2 theta 1, m2 theta 2, m2 theta 3, etc, etc. 00:39:59.120 |
They're just real numbers because theta is a real number, 00:40:01.760 |
m is a real number, but they are not complex numbers. 00:40:06.160 |
So what we do with the last operation here, this one here, 00:40:09.680 |
we convert each of these numbers into polar, into its polar form. 00:40:14.560 |
A number in polar form is a number that can be written as r times e to the power of i theta, 00:40:22.480 |
which can be written as r cosine of theta plus i sine of theta. 00:40:29.360 |
Why? Because it can be represented in the graphical, let's say, graphical plane xy. 00:40:38.640 |
As you know, complex numbers can be represented into the 2D plane xy, 00:40:45.120 |
where the real part is on the x and the imaginary part is on the y. 00:40:50.640 |
So we are actually representing a vector of size, let's say, r with an inclination of theta. 00:40:58.640 |
Because, as you know, the projection of this vector on the real part is 00:41:03.200 |
r cos theta plus i, the projection on the y-axis is sine of theta. 00:41:11.200 |
And here I forgot r, yeah, I've forgotten r here, r sine of theta. 00:41:18.240 |
So this is another way of representing complex numbers. 00:41:21.760 |
And what we are doing is we are calculating this matrix 00:41:25.120 |
and then converting all these numbers into their complex form. 00:41:28.800 |
So we are converting it into another matrix that has r equal to 1. 00:41:34.000 |
And this number here, for example, this item here will become 00:41:40.480 |
cosine of m1 theta 1 plus i sine of m1 theta 1. 00:41:57.920 |
This has become another complex number that is the cosine of m1 theta 2 plus i sine of m1 theta 2. 00:42:12.560 |
Because we are not increasing the numbers, the total numbers, 00:42:17.040 |
this shape of the tensor also doesn't change. 00:42:21.200 |
So instead of having m theta 1, it becomes cosine of m theta 1 plus i sine of m theta 1. 00:42:36.320 |
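Putting together the theta series, the outer product with the positions m, and the conversion to polar (complex) form, here is a sketch of the precompute function described so far:

```python
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, theta: float = 10000.0):
    # The rotary embeddings can only be applied to an even dimension.
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"
    # theta_i = 10000 ^ (-2 * (i - 1) / dim) for i = 1, 2, ..., dim / 2
    # Shape: (head_dim / 2)
    theta_numerator = torch.arange(0, head_dim, 2).float()
    theta_values = 1.0 / (theta ** (theta_numerator / head_dim))
    # All the possible positions m. Shape: (seq_len)
    m = torch.arange(seq_len)
    # Outer product: every position multiplied by every theta. Shape: (seq_len, head_dim / 2)
    freqs = torch.outer(m, theta_values).float()
    # Convert to complex numbers in polar form with r = 1:
    # cos(m * theta) + i * sin(m * theta). Shape: (seq_len, head_dim / 2)
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_complex
```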
Now, the point is, imagine we are given a vector, 00:42:40.560 |
because we want to apply these positional encodings to a vector. 00:42:45.600 |
Because the vector will be given us as a list of dimensions, 00:42:52.080 |
Just like in the original transformer, we have a vector of size 512. 00:42:56.560 |
In this case, it will be much smaller because it's the dimension of each head. 00:43:00.640 |
And as you remember, each head doesn't watch the full dimension of the embedding vector, 00:43:07.840 |
So, but for us, okay, imagine it's only one head. 00:43:10.480 |
So if it's only one head, we will watch the full dimension. 00:43:15.840 |
Just suppose that we only are working with one head. 00:43:18.240 |
So imagine we have a token with its full dimensions. 00:43:25.360 |
And in the case of the vanilla transformer, 512. 00:43:42.720 |
because we want to do the calculation and not go crazy. 00:43:46.880 |
Otherwise, 4096 is a little difficult to prove. 00:43:50.800 |
I want to make a list of operations on this vector until we arrive to this form here. 00:44:01.440 |
Suppose our embedding vector is only made of four dimensions. 00:44:10.880 |
Okay, the first thing we do is I will do some transformations 00:44:21.680 |
So for now, just follow the transformations I'm doing. 00:44:27.760 |
I want to group successive tokens, successive dimensions. 00:44:35.920 |
So X1 and X2 become another dimension in this tensor. 00:44:41.040 |
And X3 and X4 become, oops, very badly written. 00:44:52.480 |
The total number of items is still four, but I added another dimension. 00:45:16.080 |
I consider this first number of this part to be the real part of the complex number. 00:45:22.640 |
And this one to be the imaginary part of the complex number. 00:45:28.080 |
So I do another transformation that we will call two, 00:45:45.520 |
This vector has less items because now two numbers became one complex number. 00:45:54.160 |
Now I multiply this element wise with the vector that we pre-computed before. 00:46:01.360 |
As you remember before, we pre-computed this one. 00:46:05.440 |
Cosine of M1 theta 1 plus i sine of M1 theta 1, cosine of M1 theta 2 plus i sine of M1 theta 2, etc. 00:46:12.560 |
Because they suppose this position, this token here, 00:46:16.960 |
suppose his position is M1, because we need also the M. 00:46:21.520 |
So suppose this token here, his position is M1. 00:46:25.600 |
So we take all this row here, M1, and this will become our new matrix here. 00:46:37.680 |
So four dimensions means we have a theta 1 and theta 2. 00:46:44.960 |
So element wise with the cosine of M1 theta 1 plus I sine of M1 theta 1. 00:47:07.840 |
And then we have cosine of M1 theta 2 plus I of sine of M1 theta 2. 00:47:23.520 |
Now we have an element wise product between the first item of this matrix 00:47:35.200 |
And then we have the product of two complex numbers. 00:47:39.360 |
This complex number here and this complex number here. 00:47:42.960 |
So let's see how to compute the product of two complex numbers. 00:47:47.520 |
Because I don't want to write very long expressions, I will call this one F1. 00:47:54.320 |
So F1 is the cosine of M1 theta 1 and F2 is the sine of M1 theta 1. 00:48:04.240 |
And for the same reason I will call this one F3 and F4. 00:48:08.720 |
Now let's compute the product of the first item of this vector 00:48:33.840 |
This is equal to X1 F1 plus IX1 F2 plus IX2 F1. 00:49:00.800 |
Then we have the product of i X2 multiplied by i F2, which is i squared X2 F2, that is, minus X2 F2. 00:49:20.160 |
So, grouping all the terms that don't have i: X1 F1 minus X2 F2, plus i that multiplies X1 F2 plus X2 F1. 00:49:45.760 |
Okay, this is how to compute the product of two complex numbers. 00:49:51.120 |
So the first number here in the resulting matrix from this element-wise multiplication will be 00:49:58.400 |
X1 F1 minus X2 F2 plus I of X1 F2 plus X2 F1. 00:50:25.680 |
The second element, we don't need to do this multiplication 00:50:29.120 |
because they have the similar structure as the first one. 00:50:31.600 |
So we just change the X1 with X3, X2 with X4. 00:50:43.120 |
So the resulting matrix will be X3 F3 minus X4 F4 plus I X3 F4 plus X4 F3. 00:51:15.920 |
So this complex number, we can split the real part and the imaginary part. 00:51:19.920 |
And this we will call it transformation number three. 00:51:23.680 |
So we can split it in a tensor of two dimensions. 00:51:30.400 |
One is the real part and one is the imaginary part. 00:51:33.200 |
Where is X1 F1 minus X2 F2 then X1 F2 plus X2 F1. 00:51:56.240 |
This is the first tensor. The second tensor will be X3 F3 minus X4 F4. 00:52:20.960 |
I'm really sorry for the bad handwriting, but it's my touchpad is not so good. 00:52:26.320 |
Then we do another transformation in which we flatten all these values. 00:52:35.600 |
This is the second, the third and the fourth. 00:52:37.440 |
So we remove this dimension, the inner dimension. 00:52:40.720 |
So we flatten this matrix and it will become X1 F1 minus X2 F2. 00:53:11.200 |
Then we have X3 F3 minus X4 F4 then we have X3 F4 plus X4 F3. 00:53:32.240 |
Let's compare this resulting matrix with what is in the paper. 00:53:46.160 |
OK, the resulting matrix is exactly the same as what is in the paper. 00:53:53.040 |
So X1 multiplied by F1, which is the cosine, as you remember, M1 theta 1. 00:54:06.160 |
So minus X2 multiplied by F2, which is the sine, as you can see here, sine of M1 theta 1. 00:54:13.120 |
Here it's not M1 because it's for the generic M, but we set M equal to M1. 00:54:23.200 |
So X1 with X1 here, because here is the sum, so the order doesn't matter. 00:54:29.520 |
So X1 F2, so X1 multiplied by the sine plus X2 F1, X2 F1. 00:54:44.800 |
So X3 multiplied by the cosine minus X4 sine minus X4 sine. 00:54:57.360 |
And plus X4 F3, X4 F3, F3 is the cosine of theta 2. 00:55:04.160 |
Also in this case, because we have the sum here inside, the order doesn't matter. 00:55:10.160 |
So as you can see, we started with a vector of dimension 4, but it could be of dimension N. 00:55:20.480 |
We then multiplied with the matrix that we pre-computed here. 00:55:23.760 |
Then we did some other transformation, and the end result is exactly as doing this operation. 00:55:32.720 |
Because this is actually what we need to apply the embedding vector to this vector here. 00:55:38.560 |
So to this token, how to apply the embeddings, 00:55:40.960 |
the rotary position embeddings through this series of transformations. 00:55:44.320 |
So I could have also written the code and not tell you anything, 00:55:50.720 |
So that you know that what I'm doing is actually described in the paper, 00:55:56.480 |
and we are actually doing it according to the paper. 00:55:59.280 |
There is also a visualization in the paper that is really helpful. 00:56:02.560 |
So what we did here, for example, is we transform the embedding vector into, 00:56:09.200 |
split it into a new tensor, which has half dimension. 00:56:18.160 |
So the two consecutive dimension x1 and x2 and x3 and x4. 00:56:22.160 |
Then we multiply, we transform it with a complex number. 00:56:27.120 |
We multiply it with M theta that we pre-computed. 00:56:30.960 |
And this visualization of why we do this is present in the paper. 00:56:41.760 |
So here they say, if you have a word with d dimensions, 00:56:51.360 |
then, of course, you will have d half thetas, 00:56:55.440 |
because we have theta 1, theta 2, up to theta d half. 00:56:59.840 |
We group successive dimensions into a new complex number, 00:57:06.880 |
that if we project it on the complex plane, it will result into this vector. 00:57:13.760 |
And then we multiply it with the complex number M theta 1. 00:57:18.320 |
This will result in the number being rotated by the angle indicated by M theta 1. 00:57:25.360 |
And this is the encoded number, is the encoded token. 00:57:29.840 |
And this is exactly what we are doing with our matrix transformations 00:57:41.200 |
X is the token to which we want to apply the Rotary Embeddings. 00:57:49.200 |
FreqsComplex is the output of this function, but only for the position of this token. 00:58:02.880 |
Because this will have all the theta for all the possible positions, 00:58:06.320 |
but we only need the positions for this particular token. 00:58:12.320 |
The first thing we do is the transformation number 1, I think I call it. 00:58:26.960 |
So the first thing we do is we call the transformation number 1. 00:58:30.960 |
And number 2. So the first thing we do is we transform 00:58:34.320 |
the two consecutive dimensions into a new tensor. 00:58:38.320 |
And then we visualize it as a complex number. 00:59:03.680 |
Okay, this operation here is basically saying take two consecutive dimensions and group them. 00:59:15.920 |
And then we transform this intermediate tensor into a complex tensor 00:59:19.920 |
by using the view as complex operation from Torch. 00:59:26.800 |
So we are starting from B, sequence length, H, head dimension. 00:59:33.680 |
Because, as I said before, this X is actually not the original vector, 00:59:38.160 |
but it's already the one split into heads, each with its head dimension. 00:59:49.360 |
If we only have one head, then this head dimension is actually the full dimension of the token. 01:00:01.360 |
But this tensor has two dimensions less than this one. 01:00:09.120 |
So we take the freqs_complex tensor and we add the two dimensions that it's missing. 01:00:25.360 |
Here we are going from here to divide by two. 01:00:36.080 |
Because every pair of two consecutive dimensions is becoming one complex number. 01:00:40.160 |
And here we go from sequence length to head dimension divide by two. 01:00:46.560 |
We are mapping it to one because this is the batch dimension sequence length, 01:00:53.920 |
then the head dimension one, and then head dimension divide by two. 01:01:06.080 |
which will result in a rotation as we saw in the figure before. 01:01:08.880 |
So that's why I call it X rotated is equal to X complex 01:01:13.120 |
multiplied by the complex number of the frequencies. 01:01:52.080 |
We transform the complex number into a tensor 01:01:55.520 |
in which the first item is the real part of the complex number 01:01:58.400 |
and then the complex part, the imaginary part, and then we flatten it. 01:02:14.640 |
This operation view as real will transform the tensor like this. 01:02:40.960 |
in which we transform the complex number into a tensor of two dimensions. 01:02:45.280 |
Because that's why you can see this additional dimension here. 01:02:50.000 |
You can just say to flatten it with the shape of the original... 01:03:34.800 |
or a list of tokens, because we have the batch dimensions, 01:03:41.040 |
doing all these transformations that we have done here 01:03:46.080 |
And they are all equivalent to doing this operation as written on the paper. 01:03:50.000 |
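As a sketch, the whole chain of transformations (group pairs, view as complex, multiply by the precomputed frequencies, view as real, flatten) can be written like this:

```python
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor):
    # x: (B, seq_len, H, head_dim). Group consecutive pairs and view them as complex numbers:
    # (B, seq_len, H, head_dim) -> (B, seq_len, H, head_dim / 2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # (seq_len, head_dim / 2) -> (1, seq_len, 1, head_dim / 2), so it broadcasts over batch and heads.
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # Element-wise product of complex numbers = rotation of each pair by the angle m * theta.
    x_rotated = x_complex * freqs_complex
    # Back to real pairs, (B, seq_len, H, head_dim / 2, 2), then flatten the last two dimensions.
    x_out = torch.view_as_real(x_rotated).reshape(*x.shape)
    return x_out.type_as(x)
```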
Now we need to go forward with our transformer by implementing the rest. 01:03:58.000 |
The next thing that we can implement is this RMS norm 01:04:02.000 |
because it's present at the output of the transformer 01:04:11.360 |
We can see that we have the normalization, the RMS normalization here 01:04:24.640 |
If you want to have a deep understanding of how normalization works, 01:04:30.160 |
I actually described why we need normalization, 01:04:32.720 |
how it was historically done and how it works, 01:04:44.000 |
But if you want to have a better understanding, 01:04:47.760 |
So as you remember, in the original transformer, 01:05:06.400 |
we computed two statistics, one for each item, 01:05:27.680 |
And this formula comes from probability statistics. 01:05:31.520 |
So as you know, if you have any random variable 01:05:50.320 |
We then multiply this with the gamma parameter 01:05:56.800 |
But this was done in the layer normalization. 01:06:25.760 |
can be obtained without recentering the values. 01:06:28.720 |
So without recentering them around the mean of zero. 01:06:35.040 |
However, the variance in the layer normalization 01:06:39.440 |
because if you remember the formula of the variance 01:06:48.800 |
So to compute the variance, we needed the mean, 01:07:07.360 |
That's why they introduced these statistics here, 01:07:13.360 |
And in practice gives the same normalization effect 01:07:25.760 |
between layer normalization and RMS normalization 01:07:30.400 |
And it looks like that recentering was not necessary 01:08:29.280 |
It's used here, added to the denominator. 01:09:12.820 |
So we return x multiplied by torch.rsqrt. 01:09:18.160 |
rsqrt stands for one over the square root. 01:10:36.960 |
So rsqrt(x) is equal to one over the square root of x, 01:10:44.400 |
And the dimensions here are (B, seq_len, dim) multiplied by (B, seq_len, 1), 01:10:54.080 |
which results in (B, seq_len, dim). 01:10:58.880 |
So what we are doing is exactly this formula here. 01:11:23.360 |
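Here is a sketch of the RMS normalization module as described, with the learnable gamma parameter and the eps in the denominator:

```python
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps                                # added to the denominator, avoids division by zero
        self.weight = nn.Parameter(torch.ones(dim))   # the learnable gamma parameter

    def _norm(self, x: torch.Tensor):
        # x / RMS(x), where RMS(x) = sqrt(mean(x^2)):
        # (B, seq_len, dim) * (B, seq_len, 1) -> (B, seq_len, dim)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        return self.weight * self._norm(x.float()).type_as(x)
```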
Here we have the encoder block is all this block here 01:11:38.560 |
another skip connection and a feed forward layer here. 01:11:42.160 |
I think the easiest one to start with is the feed forward, 01:12:52.080 |
The head dimension is the dimension of the vector that each head will see. 01:12:56.720 |
So 4,096 divided by the number of heads, here it's divided by 32. 01:13:09.600 |
So each head will see 4,096 divided by 32 = 128 items. 01:13:39.760 |
Then we have the normalization before the self-attention. 01:13:57.120 |
And this is the motivation behind this argument, norm_eps. 01:14:13.680 |
It's after the attention, not after the feed forward. 01:14:52.320 |
start_pos indicates the position of the token. 01:14:54.960 |
I kept the same variable name as in the original code. 01:15:01.200 |
we will be dealing with only one token at a time. 01:15:03.760 |
So start_pos indicates the position of the token 01:15:52.080 |
And to the attention, we also give the frequencies. 01:16:02.080 |
they come into play when we calculate the attention. 01:16:04.720 |
And these operations involve tensors of size (B, seq_len, dim), 01:16:17.920 |
which results in (B, seq_len, dim). 01:16:50.320 |
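As a sketch, the encoder block with its two normalizations and two skip connections looks like this (SelfAttention is built next; FeedForward is the SwiGLU feed forward built later in the video):

```python
class EncoderBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention = SelfAttention(args)
        self.feed_forward = FeedForward(args)                        # the SwiGLU feed forward, built later
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)   # normalization before the attention
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)         # normalization before the feed forward

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        # (B, seq_len, dim): skip connection around the self-attention...
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
        # ...and a second skip connection around the feed forward.
        return h + self.feed_forward(self.ffn_norm(h))
```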
Now we need to build the self-attention and the feedforward. 01:17:24.240 |
We have an input, which is sequence by d_model, 01:17:27.920 |
each token modeled by a vector of size d_model, 01:17:42.240 |
so it has this dimension of sequence by d_model. 01:17:44.640 |
And we then split them into the number of heads 01:17:50.400 |
such that each vector that represents the token 01:18:02.320 |
So if the token was 512 in size, for example, 01:18:07.280 |
the first head will watch the first 128 dimensions of this vector. 01:18:12.880 |
The second head will watch the next 128 dimensions. 01:18:16.880 |
The next head will watch the next 128 dimensions, 01:18:26.640 |
This results in head 1, head 2, head 3 and head 4. 01:18:36.640 |
And this is the output of the multi-head attention. 01:18:44.320 |
because it's the same input that acts as a query, 01:18:53.280 |
and the key and the values come from another place, 01:19:09.280 |
or you want to translate from one language to another. 01:19:28.480 |
Okay, in LLaMA, we need to talk about a lot of things 01:19:34.640 |
We need to review how the self-attention works in LLaMA, 01:19:47.040 |
Otherwise, it will be very hard to follow the code. 01:19:57.440 |
so that has been trained on this particular line. 01:20:01.520 |
So the line is love that can quickly seize the gentle heart. 01:20:18.800 |
And we have a model that has been trained on this line, 01:20:23.280 |
love that can quickly seize the gentle heart. 01:20:25.120 |
Now, a model that has been trained on this particular line 01:20:32.080 |
should have an input that is built in this way. 01:20:36.400 |
and then the tokens that represented the sentence, 01:20:43.600 |
Because the transformer is a sequence-to-sequence model, 01:20:54.320 |
will be mapped to the first token of the output. 01:20:58.080 |
will be mapped to the second token of the output. 01:21:04.720 |
will be mapped to the third token of the output. 01:21:17.360 |
to predict this particular token, "can" for example, 01:21:22.160 |
the model doesn't only watch the same token in the input, 01:21:26.480 |
so "that", but also watches all the previous tokens. 01:21:37.760 |
And the self-attention mechanism with its causal mask 01:21:54.240 |
we need to give the previous output token as input also. 01:21:58.000 |
So we always append the last token of the output 01:22:02.080 |
to the input to predict the successive tokens. 01:22:18.320 |
we need to append the previous output to the input 01:22:32.640 |
and the model will produce a sequence of four tokens as output. 01:22:37.600 |
But this is not really convenient when we do the inferencing 01:22:40.720 |
because the model is doing a lot of dot products 01:22:53.760 |
we need to access all the previous context here. 01:23:05.760 |
However, we can't just tell the transformer model 01:23:10.640 |
We need to change the calculations in such a way 01:23:13.520 |
that we only receive at the output of the transformer 01:23:17.840 |
so that all the other tokens are not even calculated. 01:23:33.600 |
we are only interested in the last token output by the model 01:23:38.640 |
However, the model needs to access all the previous tokens 01:23:44.800 |
because the model needs to access all the prompt 01:23:57.040 |
Suppose we do the same job that we did before. 01:24:07.600 |
So it will be multiplied by the transposed of the keys. 01:24:10.400 |
This will produce this matrix here, which is one by one. 01:24:29.200 |
in which the only token we give is the start of sentence. 01:24:32.320 |
Then we take this token output, the token at the output. 01:24:39.840 |
because this has to be mapped to the linear layer, etc, etc. 01:24:55.920 |
We multiply it by the transposed of the keys. 01:25:09.280 |
Then we append the output of the previous as the input 01:25:14.160 |
and we multiply it by the transposed of the keys. 01:25:21.440 |
We then append the output of the last one at the Q. 01:25:25.280 |
We multiply it by the transposed of the keys. 01:25:28.720 |
We multiply it by the V and we get this sequence as output. 01:25:35.120 |
We are doing a lot of computations that we don't need. 01:25:37.680 |
First of all, these dot products that we are computing here 01:25:45.520 |
So Q multiplied by the transposed of the keys 01:25:48.320 |
will result in a lot of dot products that result in this matrix. 01:25:51.600 |
These dot products that you see here highlighted in violet 01:25:55.520 |
have been already computed at the previous steps 01:25:59.760 |
but these have already been computed at the previous step. 01:26:02.240 |
Plus, not only they have been computed already, 01:26:05.120 |
we don't need them because we only are interested in 01:26:08.720 |
what the latest token that we added as the input, 01:26:12.880 |
what is this tokens dot product with all the other tokens 01:26:19.120 |
because this tokens dot product with all the other tokens 01:26:27.760 |
So if there is a way to not do all these computations again 01:26:33.120 |
and also to not output all the previous tokens 01:26:46.160 |
we always take the last token and we use it as input. 01:26:56.480 |
But because the query needs to access all the previous tokens, 01:27:03.600 |
So we append the last input to the keys and the values 01:27:14.800 |
For example, this is our first step of inferencing. 01:27:28.480 |
This token, in the previous case, was appended to the queries. 01:27:33.520 |
So in the next step, it became a matrix of dimension 2 by 4096. 01:27:39.680 |
But in our case, at the time step 2, we don't append it. 01:27:43.360 |
We only append it to the end of the keys and the values. 01:27:52.080 |
we will see that this row here is the only one we are interested in. 01:27:55.520 |
So the one that was not violet in the previous diagram. 01:28:00.960 |
it will result in only the last token, the one we are interested in. 01:28:14.880 |
But the number of dot products that we are doing 01:28:19.680 |
We don't need to do all those dot products that we did before. 01:28:35.520 |
And so that's why it's much faster to do inferencing with the KV cache. 01:28:47.280 |
And every time we add a token, the keys and the values grow, right? 01:28:50.720 |
But all these previous values, with the KV cache, we are not computing them again. 01:29:01.600 |
If this mechanism is not clear, please watch my previous video about LLaMA, 01:29:04.880 |
in which I describe it in much more detail and also with much more visualizations. 01:29:15.440 |
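To make the cache bookkeeping concrete, here is a small runnable toy example (the sizes are arbitrary, just to show the indexing; in the real model the cache tensors live inside the attention module):

```python
import torch

# Toy sizes, just to show the cache mechanics (not the real LLaMA sizes).
max_batch_size, max_seq_len, n_kv_heads, head_dim = 2, 16, 4, 8
cache_k = torch.zeros((max_batch_size, max_seq_len, n_kv_heads, head_dim))

batch_size, seq_len, start_pos = 1, 1, 5                       # generating the token at position 5
xk = torch.randn((batch_size, seq_len, n_kv_heads, head_dim))  # keys of the new token only

cache_k[:batch_size, start_pos:start_pos + seq_len] = xk       # store the new entry in the cache
keys = cache_k[:batch_size, :start_pos + seq_len]              # retrieve all the keys seen so far
print(keys.shape)                                              # torch.Size([1, 6, 4, 8])
```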
There is another thing actually I want to show you before we go to build it, 01:29:26.400 |
because it's actually the successive version of the multi-query attention. 01:29:30.240 |
It's something in between the multi-head attention and the multi-query attention. 01:29:33.520 |
But actually, the real name is grouped query attention in the paper. 01:29:38.560 |
Now, the reason we introduced the grouped query attention 01:29:44.400 |
is first of all, we had the multi-query attention. 01:29:47.680 |
The multi-query attention basically were introduced to solve one problem. 01:29:51.840 |
That is, we first had the multi-head attention. 01:29:55.920 |
We introduced the KV cache with the multi-head attention. 01:30:02.240 |
The problem was that with the multi-head attention, 01:30:06.400 |
With the multi-head with the KV cache, we do less dot products. 01:30:13.120 |
But it also resulted in a new bottleneck for the algorithm. 01:30:16.560 |
So the bottleneck was no longer the number of computations, 01:30:19.680 |
but how many memory accesses we were performing to access these tensors. 01:30:32.800 |
than it is at moving tensors around in its memory. 01:30:38.480 |
we not only need to consider how many operations we are doing, 01:30:46.880 |
So it's not a good idea to keep copying tensor from one place to another, 01:30:51.280 |
because the GPU is much slower at copying memory from one place to another 01:30:57.040 |
And this can be visualized on the datasheet of the GPU. 01:31:02.560 |
You can see, for example, that computing operations 01:31:04.640 |
is 19.5 Tera floating point operations per second. 01:31:08.640 |
And while the memory bandwidth, so how fast it can move memory, 01:31:14.000 |
So we need to optimize algorithms also for managing 01:31:20.480 |
how many tensors we access and how we move them around the memory. 01:31:24.480 |
This is why we introduced the multi-query attention. 01:31:28.000 |
The multi-query attention basically means that we keep many heads for the queries, 01:31:34.080 |
but we only have one head for the keys and the values. 01:31:38.000 |
This resulted in a new algorithm that was much more efficient 01:31:47.120 |
Because the KVCache, yeah, it reduced the number of dot products, 01:31:50.480 |
but it had a new bottleneck, that is the number of memory access. 01:31:53.600 |
With this algorithm, we also may optimize the memory access, 01:31:59.440 |
because we are reducing the number of heads for the key and the values. 01:32:02.960 |
So we are reducing the number of parameters in the model. 01:32:06.400 |
And this way, because we are reducing 01:32:11.440 |
the number of parameters involved in the attention mechanism, the quality of the model may degrade. 01:32:17.680 |
But we saw that practically it degraded the quality not so much. 01:32:26.800 |
So they show that the quality degradation was very little, 01:32:33.920 |
but the performance gains were very important. 01:32:41.440 |
to 5 microseconds or 6 microseconds per token, 01:32:45.600 |
Now, let's introduce the grouped query attention 01:32:59.520 |
n heads for the keys and n heads for the values. 01:33:02.720 |
In the multi-query attention, we have n heads for the queries, 01:33:07.520 |
but only one head for the keys and the values. 01:33:10.000 |
In the grouped multi-query attention, or the grouped query attention, 01:33:13.920 |
we have fewer heads for the keys and values than for the queries. 01:33:20.720 |
So every two heads for the queries, in this case, for example, 01:33:24.160 |
we will have one head for the keys and the values. 01:33:28.960 |
And this is a good balance between quality and speeds, 01:33:33.120 |
because, of course, the fastest one is this one, 01:33:37.280 |
But, of course, the best one from a quality point of view is this one, 01:33:42.560 |
but this is a good compromise between the two. 01:33:44.720 |
So you don't lose quality, but at the same time, 01:33:47.520 |
you also optimize the speed compared to the multi-head attention. 01:33:50.880 |
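As a rough illustration of the trade-off (toy numbers, not the real LLaMA configuration), here is how the per-layer KV cache size compares between the three schemes:

```python
# Per-layer KV cache entries for one sequence (toy numbers, not the real LLaMA configuration):
n_heads_q, head_dim, seq_len = 8, 64, 1024

multi_head_kv  = 2 * seq_len * n_heads_q * head_dim         # one K and one V head per query head
grouped_kv     = 2 * seq_len * (n_heads_q // 4) * head_dim  # e.g. one K/V head every 4 query heads
multi_query_kv = 2 * seq_len * 1 * head_dim                 # a single K/V head shared by all query heads

print(multi_head_kv, grouped_kv, multi_query_kv)  # 1048576 262144 131072
```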
So now that we have reviewed all this concept, let's go build it. 01:33:54.960 |
So please, again, if you didn't understand this very much in detail, watch my previous video on LLaMA, 01:34:02.960 |
in which I explain all this part much better. 01:34:04.960 |
Otherwise, if I had to repeat the same content of the previous video, 01:34:08.960 |
the current video would become 10 hours long. 01:34:36.480 |
Compared to the original code from Facebook, from Meta, 01:35:03.280 |
And KVHeads indicates the number of heads for the keys and the values, 01:35:17.200 |
because they can be different than the number of heads for the queries. 01:35:49.600 |
This value here represents the ratio between the number of heads for the query 01:35:54.560 |
and the number of heads for the keys and the values. 01:35:57.520 |
We will use it later when we calculate the attention. 01:36:22.800 |
And then we have a self.headDimension, which is… 01:36:54.720 |
This indicates the part of the embedding that will be seen by each head. 01:36:59.920 |
Because, as you know, the embedding is split into multiple heads. 01:37:02.960 |
So each head will watch the full sentence, but a part of the embedding of each word. 01:37:19.520 |
WQ, WK, WV, and WO, just like in the normal vanilla transformer. 01:38:32.560 |
I just now created one for the keys and one for the values. 01:38:50.560 |
Okay, finally, we implement the forward method, which is the salient part here. 01:39:24.640 |
To simplify the code for you, I will write the… 01:39:30.000 |
For each operation, I will write the dimensions of the tensor that is involved in the operation, 01:39:35.040 |
and also the resulting tensor from each operation. 01:39:38.000 |
The start position indicates just the position of the token inside of the sentence. 01:39:48.000 |
And these are the frequencies that we have computed. 01:40:21.360 |
Then what we do is we multiply, just like in the original transformer, 01:40:28.640 |
we multiply it by the WQ, WK, and WV matrices. 01:40:39.680 |
This means going from (B, 1, dim) to (B, 1, H_Q * head_dim). 01:40:49.120 |
So, the number of heads for the query multiplied by the head dimension, 01:40:58.080 |
so, the number of heads multiplied by the head dimension, as you can see from here. 01:41:15.040 |
In this case, however, we may change the shape of the… 01:41:28.560 |
Because the number of heads for the kv may be smaller than q. 01:41:34.080 |
So, this matrix may have a last dimension that is smaller than xq. 01:41:48.640 |
Apply the WQ, WK, and WV matrix to queries, keys, and values, 01:42:02.320 |
which are the same, because it's a self-attention. 01:42:07.440 |
We then divide them into their corresponding number of heads. 01:42:30.000 |
So, we reshape (B, 1, H_Q * head_dim) into (B, 1, H_Q, head_dim). 01:42:46.160 |
So, we divide them into the h heads for the query. 01:42:52.720 |
And then we do the same for the key and the values. 01:43:49.360 |
now, we have multiplied, okay, we have the x input, 01:43:59.040 |
As you remember, we take the input, we multiply it by WQ, WK, and WV. 01:44:06.880 |
We then divide them into the number of heads. 01:44:09.440 |
But in the case of grouped query attention, they may be different. 01:44:12.160 |
So, this may be four heads, and this may be two heads, and this may be two heads. 01:44:18.960 |
The next thing we are going to do, and this is present in the here, 01:44:24.400 |
we need to apply the rotary positional encodings to the query 01:44:37.680 |
And this is how we apply the positional encodings. 01:44:49.440 |
This will not change the size of the vectors. 01:44:53.760 |
Because at the end, we have the same shape as the original input vector. 01:45:32.640 |
As we can see here, every time we have an input-output token, 01:45:40.160 |
so for example, the attention 2 here, it supposes the token number 2, 01:45:44.320 |
we append it at the end of the keys and the values. 01:45:51.200 |
So, what we do here, we keep a cache of the keys and the values, 01:45:56.800 |
because they will be used for the next iterations. 01:45:59.840 |
Because at every iteration, in X, we only receive the latest token 01:46:07.840 |
We append it to the K and the V, and then we compute the attention 01:46:14.080 |
between all the K, all the V, but only the single token as query. 01:46:38.080 |
This should be 1, because sequence length is actually 1, always. 01:46:44.080 |
But I try to keep this code the same as the one from LLaMA, from Meta. 01:46:51.840 |
This is, basically, it means that if we have one token from many batches, 01:47:01.920 |
I mean, we have one token for every batch, we replace them, 01:47:10.720 |
So, we replace the entry for this particular position for every batch. 01:47:25.860 |
Now, we replace it only for this position here. 01:47:31.360 |
But when we compute the attention using the KVCache, let's go watch again, 01:47:36.320 |
we need to calculate the dot product between the only one token but all the keys. 01:47:46.960 |
And then we will need to multiply with all the values, 01:47:49.840 |
and this will result in only one token as output. 01:47:52.320 |
So, we need to extract from this cache all the tokens as keys 01:47:56.800 |
and all the tokens as values up to this position here. 01:48:21.040 |
So, starting from 0 up to startPos plus sequenceLength, 01:48:46.800 |
Now, what happens is that, let me write also some sizes here. 01:49:01.760 |
because the sequenceLength of the input is always 1, we know that. 01:49:05.280 |
But the sequenceLength of the cache means all the cached keys and values, 01:49:16.320 |
So, this sequenceLength is actually equal to startPosition. 01:49:24.800 |
The next dimension of the cached keys is the number of heads for the K and V, 01:49:34.640 |
Now, the number of heads for the keys and values 01:49:41.200 |
may not correspond to the number of heads of the queries. 01:49:48.000 |
In the original code from Lama, what they did was basically the following: 01:49:58.720 |
when the number of heads for the keys and the values 01:50:01.840 |
is not the same as the number of heads for the queries, 01:50:10.160 |
they just copy each KV head multiple times until the numbers match, 01:50:21.520 |
and then they compute it just like a normal multi-head attention. 01:50:24.320 |
This is not an optimized solution, but it's the one used by the code from Lama. 01:50:31.680 |
I kept it because I don't have any way of testing other code, 01:50:35.760 |
since the only model that uses grouped query attention 01:50:40.480 |
is the biggest one from Lama, the one with 70 billion parameters, 01:50:44.240 |
and my computer will never be able to load that model. 01:50:49.600 |
That's also why I didn't optimize the code 01:50:52.240 |
for actually computing the grouped query attention; 01:50:54.160 |
I will just replicate each KV head multiple times. 01:51:38.000 |
This method just repeats the keys until we reach the required number of heads. 01:51:46.720 |
The number of repetitions is the ratio of the number of heads of the queries 01:51:51.600 |
to the number of heads of the keys. So, if the number of heads of the keys is four, 01:51:55.040 |
and the number of heads for the queries is eight, 01:51:57.200 |
that means we need to repeat each head twice. 01:51:59.840 |
So, let's build also this method, since we are here. 01:52:37.680 |
So, the first thing we do is we add a new dimension: 01:52:58.640 |
we index with batch, sequence_length, number_of_heads, then None, 01:53:08.880 |
and this will add a new dimension in that position. 01:53:51.760 |
We then repeat each head along this new dimension the required number of times, 01:54:07.840 |
And this is how we repeat the keys and also the values. 01:54:23.600 |
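For reference, here is a small sketch of this repeat method; it follows the expand-and-reshape idea just described, assuming the (batch, seq_len, n_kv_heads, head_dim) shape convention.

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: (batch, seq_len, n_kv_heads, head_dim)
    batch, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]                                   # add a new dimension after the heads
        .expand(batch, seq_len, n_kv_heads, n_rep, head_dim)  # repeat each head n_rep times
        .reshape(batch, seq_len, n_kv_heads * n_rep, head_dim)
    )

# Example: 4 KV heads repeated twice to match 8 query heads
k = torch.randn(1, 10, 4, 16)
print(repeat_kv(k, 2).shape)  # torch.Size([1, 10, 8, 16])
```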
Okay, now we just proceed with the standard 01:54:29.120 |
calculation for the multi-head attention. 01:54:34.400 |
That is, we first move the head dimension before the sequence dimension, 01:54:42.080 |
because each head will watch the whole sequence. 01:54:55.680 |
So the queries go from (B, 1, H_Q, head_dim), where 1 is the sequence length of the queries, 01:54:58.000 |
H_Q is the number of heads of the queries, and head_dim is the head dimension, 01:55:05.680 |
to (B, H_Q, 1, head_dim), so batch, head, sequence length, and head dimension. 01:55:22.000 |
Then we do the standard formula for queries multiplied by the transpose of the keys, 01:55:34.160 |
divided by the square root of the dimension of each head. 01:55:37.040 |
So, xq, the queries, multiplied by the transpose of the keys, 01:55:48.960 |
all of this divided by the square root of the dimension of each head, 01:56:03.360 |
and this results in a shape of (B, H_Q, 1, seq_len_KV). 01:56:44.960 |
So, the formula is queries multiplied by the transpose of the keys, 01:56:50.960 |
and then we do the softmax, then the output is multiplied by the values. 01:57:28.160 |
and then we multiply it by the output matrix, 01:57:32.800 |
but before that we remove all the heads, so we concatenate them again. 01:57:39.040 |
So, here we take the output of all the heads, 01:57:42.000 |
then we concatenate them together, and then we multiply it by the wo matrix. 01:58:17.200 |
Shape-wise, (B, H_Q, 1, head_dim) is transposed into (B, 1, H_Q, head_dim), 01:58:38.480 |
and then we flatten the head dimension away, giving (B, 1, dim). 01:58:52.720 |
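Putting these steps together, here is a hedged sketch of the attention computation for a single query token. The shapes follow the discussion above; wo is assumed to be an nn.Linear projecting back to the model dimension.

```python
import math
import torch
import torch.nn.functional as F

# Sketch of the attention for one query token, assuming
# xq: (B, H_Q, 1, head_dim) and keys/values: (B, H_Q, seq_len_kv, head_dim),
# i.e. after repeat_kv and after moving the head dim before the sequence dim.
def attention(xq, keys, values, wo):
    head_dim = xq.shape[-1]
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)  # (B, H_Q, 1, seq_len_kv)
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    out = torch.matmul(scores, values)                                     # (B, H_Q, 1, head_dim)
    out = out.transpose(1, 2).contiguous().view(xq.shape[0], 1, -1)        # (B, 1, H_Q * head_dim)
    return wo(out)                                                         # (B, 1, dim)
```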
Here, I think I made some mistake with 'self'; that's why it's colored differently. 01:59:02.080 |
When we calculated the self-attention, because we are inferencing, 01:59:05.760 |
so this code will only work for inferencing, we can use the kvcache. 01:59:09.840 |
The KV cache allows us to avoid a number of dot products that we don't need. 01:59:15.120 |
Why? Because in the original transformer, 01:59:18.880 |
at every step we were computing a lot of dot products for tokens we had already processed. 01:59:24.800 |
In this case, we simplified the mechanism to output only one token. 01:59:29.600 |
As you can see, the output of the self-attention is b, so batch, 01:59:33.440 |
one token only with its embedding size, which is 4096. 01:59:40.160 |
So, we are only outputting one token, not many tokens. 01:59:44.880 |
We input only one token, and we output one token. 01:59:47.680 |
But because we need to relate that single token with all the previous tokens, 01:59:55.680 |
Every time we have a token, we put it into the cache, like here, 02:00:00.400 |
then we retrieve all the previous saved tokens from the cache, 02:00:04.480 |
and then we calculate the attention between all the previous tokens, 02:00:07.600 |
so the keys and the values, and the single input token as the query. 02:00:20.880 |
And grouped query attention means that we have a different number of heads 02:00:29.680 |
for the keys and values than for the queries. 02:00:37.680 |
But we just repeat the heads that are missing in order to calculate the attention. 02:00:42.080 |
So the attention is calculated just like the previous transformer, 02:00:48.560 |
but by repeating the missing keys and values heads, 02:00:52.880 |
instead of actually optimizing the algorithm. 02:00:55.360 |
This has also been done by Meta in its official implementation, 02:01:01.040 |
The biggest reason is because I cannot test any other modification. 02:01:05.120 |
I cannot test another algorithm that actually tries to optimize this calculation. 02:01:10.800 |
So if I find another implementation that I know is working, 02:01:16.640 |
Otherwise, I will try to run it on Colab and see if I can come up with a better solution. 02:01:23.440 |
But at least we got the concept of grouped query attention: 02:01:29.440 |
it's something in between multi-query attention and multi-head attention, 02:01:34.880 |
which doesn't sacrifice quality but improves speed. 02:01:38.080 |
Now, the last thing that we didn't implement is the feedforward layer. 02:01:43.120 |
For the feedforward layer, the only thing that we need to review 02:01:45.680 |
is the SwiGLU activation function that we can see here. 02:01:48.320 |
And this activation function has been changed compared to the previous 02:01:53.360 |
activation function used in the vanilla transformer, which was the ReLU function. 02:01:57.680 |
And the only reason we replaced it is because this one performs better. 02:02:02.480 |
And as I showed in my previous video, we cannot prove why it works better. 02:02:08.400 |
Because in such a big model with 70 billion parameters, 02:02:11.520 |
it's difficult to explain why a little modification works better than another. 02:02:15.840 |
We just know that some things work better in practice for that kind of model 02:02:24.480 |
So as you can see here in the conclusion of the paper, 02:02:27.520 |
they say that they offer no explanation as to why these architectures seem to work; 02:02:31.280 |
they attribute their success, as all else, to divine benevolence. 02:02:34.720 |
So it means that when you have such a big model 02:02:36.960 |
and you change a little thing and it works better, 02:02:39.120 |
you cannot always come up with a pattern to describe why it is working better. 02:02:43.120 |
You just take it for granted that it works better 02:02:46.560 |
and you use it because it works better in practice. 02:02:49.520 |
So to implement the SwiGLU function, we need to apply... 02:02:53.760 |
This is the formula from the original transformer. 02:03:00.880 |
So this is the ReLU applied to the first linear layer, followed by the second linear layer. 02:03:05.520 |
In Lama, we use the SwiGLU function, which involves the three matrices here. 02:03:10.800 |
Because they increased the number of parameters here, 02:03:13.600 |
and also because they were experimenting with grouped query attention, 02:03:17.840 |
the architecture of Lama has some extra hyperparameters 02:03:23.760 |
to adjust the number of parameters of this feedforward layer. 02:03:30.720 |
And this is a common practice in deep learning research: 02:03:37.360 |
when a modification reduces or increases the number of parameters, 02:03:42.880 |
they adjust the number of parameters of the feedforward layer 02:03:46.720 |
so that, when they make a comparison between two models, the total parameter count stays comparable. 02:03:51.920 |
So I will also, of course, use the same structure 02:03:56.160 |
because otherwise I cannot load the weight from the pre-trained model. 02:04:07.680 |
Then they take two-thirds of this dimension. 02:04:10.880 |
And then they also apply a multiplier, if it is specified. 02:04:36.800 |
By using this modification to calculating the hidden dimension like this, 02:04:41.600 |
it may not be the case that this hidden dimension is a multiple of this number here. 02:04:48.720 |
So maybe they want the size of the hidden layer to be a multiple of 256. 02:05:08.880 |
But by calculating it like this, it may not be. 02:05:11.680 |
So what they do is round it up to the next multiple of the multiple_of parameter. 02:05:38.640 |
It's easier to show with an example than to actually write it. 02:05:42.400 |
So suppose you have the hidden size is equal to, let's say, 7. 02:05:48.800 |
But the multiple_of parameter is equal to 5. 02:05:52.240 |
So you want the hidden size to be a multiple of 5. 02:05:56.240 |
Well, what we do is basically hidden plus (multiple_of minus 1), so hidden plus 4 in this case, then integer-divide by 5 and multiply by 5 again. 02:06:18.320 |
This results in the first multiple of 5 that is bigger than or equal to this number here. 02:06:25.680 |
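As a minimal sketch, assuming the standard round-up-to-multiple trick, the computation looks like this:

```python
def round_up(hidden: int, multiple_of: int) -> int:
    # First multiple of `multiple_of` that is >= hidden
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

print(round_up(7, 5))        # 10
print(round_up(11008, 256))  # 11008, already a multiple of 256
```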
And then we have these matrices for the SwiGLU function. 02:06:31.440 |
We just follow the formula for the SwiGLU function, which is here. 02:06:45.520 |
The Swish with beta equal to 1 is actually the SiLU function, 02:06:52.800 |
and then we multiply it with another parameter matrix here, 02:06:56.800 |
and then we apply it to another linear layer, w2. 02:06:59.920 |
So in total, we have three matrices, which we call w1, w2, and w3. 02:07:50.400 |
The first thing we do is we calculate the Swish function. 02:08:00.400 |
Then we calculate, let me show you, 02:08:12.320 |
we are calculating this one: Swish of xW1, and also xW3. 02:08:20.640 |
Then we multiply them together, just like in the formula. 02:08:34.720 |
And then we apply the last linear layer, which is w2, 02:08:40.240 |
which results in a multiplication by the w2 matrix. 02:08:53.200 |
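Here is a minimal sketch of this feedforward block with the three matrices; the hidden_dim is assumed to be already computed and rounded as discussed above, and the names mirror the ones used in this section.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    # Sketch of the SwiGLU feed-forward block: w2(SiLU(w1(x)) * w3(x)).
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        swish = F.silu(self.w1(x))          # Swish with beta = 1 is the SiLU function
        return self.w2(swish * self.w3(x))  # elementwise product, then the output projection
```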
Now that we have all the building blocks, we need to go to the inferencing. 02:09:04.000 |
So, in inference.py, first we will build the code to load the model. 02:09:10.960 |
And then we will build a code to inference the model. 02:09:14.640 |
I will actually also show all the inference techniques that are out there, 02:09:20.560 |
So let's start by building first the code for loading the model. 02:09:44.640 |
And then we need SentencePiece to load the tokenizer, 02:09:49.040 |
because SentencePiece is the tokenizer that has been used by Lama. 02:10:00.100 |
Then, from model, we import ModelArgs and the Transformer class. 02:10:05.920 |
We define the class Lama, which is our model. 02:10:11.840 |
It takes a Transformer and a tokenizer, which is a SentencePieceProcessor. 02:10:56.020 |
And we call it build, just like in the original code from Lama. 02:10:59.060 |
In which we pass the directory where the checkpoints are saved. 02:11:05.540 |
In this case, the directory name is lama27b, in my case. 02:11:08.740 |
But it depends on which size of the model you have downloaded. 02:11:12.260 |
Then the tokenizer path, which is the path to the tokenizer. 02:11:17.460 |
This is the file of the tokenizer that I downloaded. 02:11:19.780 |
Then we have a load_model flag, and the max sequence length. 02:11:42.100 |
This is only for displaying how much time it takes to load the model. 02:11:45.620 |
If we want to load the model, we will also load the checkpoints. 02:11:59.780 |
The glob method allows you to find all the files that match this filter. 02:12:13.780 |
Okay, we see that we are loading checkpoint this one. 02:12:53.140 |
We can show how much time it takes to load the model. 02:13:01.060 |
In my computer, usually it takes 10 to 20 seconds. 02:13:23.620 |
So we can also show how much time it takes to load all the parameters of the model. 02:13:29.380 |
Then we load the parameters, so the JSON file. 02:14:04.580 |
Maximum sequence length is the one we have specified. 02:14:12.580 |
And then the max batch size is the one we have specified. 02:14:27.060 |
And then all the parameters loaded from the JSON file. 02:14:35.620 |
Then we, by using the tokenizer, we can populate the vocab size of the model args. 02:14:50.740 |
The vocabulary size is actually the number of tokens inside the tokenizer. 02:14:55.060 |
Now we also set the default tensor type for PyTorch. 02:15:12.500 |
This defines what kind of type PyTorch should use whenever it wants to create a new tensor; 02:15:17.140 |
this is done by Meta as well, and for CUDA they want to use the half-precision type that I show you here. 02:15:28.980 |
This changes the precision that the tensor supports. 02:16:07.620 |
actually the checkpoint is a dictionary of keys and values. 02:16:07.620 |
So the weight, for example, of a linear layer, 02:16:17.780 |
or the bias of a linear layer, or something like this. 02:16:20.420 |
And the names that we have used for the variable names and the matrices here, 02:16:29.460 |
for example, wqwk, match actually the name that are present in the checkpoint here, 02:16:36.740 |
So to make sure that I have used the right names, 02:16:39.300 |
I will load the checkpoint with strict equal true. 02:16:45.140 |
Strict equal true means that if there is at least one name that doesn't match, 02:16:50.340 |
So: if load_model, then model.load_state_dict(checkpoint, strict=True). 02:17:03.460 |
So if there is at least one name in the loaded file that doesn't match the name 02:17:08.340 |
in the classes that I have created here in the model, it will throw an error. 02:17:11.940 |
But I know that there is one key that we don't need, 02:17:15.620 |
which are the frequencies for the rotary positional embeddings, 02:17:19.540 |
which we actually are computing every time we create the tensor. 02:17:23.140 |
So we are creating them here by using this function. 02:17:26.660 |
So we don't need to load them from the model. 02:17:29.540 |
So we can remove it from the model, from the checkpoint. 02:17:33.780 |
So because the checkpoint is a dictionary, we can just remove this. 02:17:40.100 |
And then we can print how much time it took to load the model. 02:18:09.380 |
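As a rough sketch of the loading logic just described (the directory name, file pattern, and the rope.freqs key reflect my setup and the checkpoint format as I understand it, and may differ on your machine):

```python
import json
import time
from pathlib import Path

import torch

checkpoints_dir, load_model = "llama-2-7b/", True

if load_model:
    checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
    assert len(checkpoints) > 0, "No checkpoint files found"
    start = time.time()
    checkpoint = torch.load(checkpoints[0], map_location="cpu")
    print(f"Loaded checkpoint in {time.time() - start:.2f}s")

with open(Path(checkpoints_dir) / "params.json", "r") as f:
    params = json.loads(f.read())

# ... build ModelArgs and the Transformer, then:
# del checkpoint["rope.freqs"]                 # computed on the fly, not needed
# model.load_state_dict(checkpoint, strict=True)
```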
Now, before we proceed further, let me test if the model can be successfully loaded. 02:18:30.340 |
Then I don't want to use CUDA because my GPU doesn't support it. 02:18:37.220 |
Then device = 'cuda' if torch.cuda.is_available() and allow_cuda, else 'cpu'. 02:18:47.620 |
Next time if you want to load the model with CUDA, just set this variable to true. 02:18:52.900 |
But in my case, I will always leave it to false because I don't want to use CUDA. 02:19:37.940 |
Let's run it and hopefully it will not crash. 02:20:00.100 |
There is always a lot of typos when you write code. 02:20:41.940 |
It means that at least it's doing something and it's not crashing, which is always a good sign. 02:20:55.940 |
So our next step is actually to build the inferencing code. 02:20:59.460 |
So what we want to do is actually we want to be able to give some prompts to the model 02:21:13.380 |
And here we pass, for example, the size of the prompts. 02:21:17.940 |
And then we want to, you know, we want to inference the model. 02:21:26.100 |
So before we start inferencing the model, we need to build the code for inferencing 02:21:30.500 |
the model, because we need to find a strategy for selecting the next token, etc, etc. 02:21:34.980 |
So let's review how the inferencing works and what are the various strategies for inferencing. 02:21:39.700 |
Okay, so when we are dealing with the next token prediction task, when we want to inference, 02:21:45.460 |
we usually give the prompt and then we want to predict the tokens. 02:21:52.260 |
And every time we give one more token, the model will output one more token as output. 02:21:58.340 |
But with the KVCache, actually, we always give one token at a time. 02:22:02.660 |
The KVCache will keep the cache for the keys and the values and with only output one token. 02:22:08.820 |
Okay, the point is, we need to find strategies for selecting this token. 02:22:14.020 |
Among all the tokens that we have in the vocabulary. 02:22:17.060 |
And this is the job of the logits and the softmax. 02:22:22.180 |
Now imagine I give you the following task as human. 02:22:29.700 |
"I think nuclear power is..." and then you have to choose a word. 02:22:33.860 |
Now you as human may have thought of the possible next tokens, which may be clean, dangerous, 02:22:40.100 |
cheap, expensive, safe, difficult, or something else. 02:22:43.540 |
The choice of the next token in your head depends on your education, 02:22:47.780 |
on your experience with nuclear power, and your opinion on the matter. 02:22:51.460 |
Large language models also face the same problem. 02:22:55.220 |
When we give them a prompt, then the model has to choose the next word. 02:22:59.300 |
For the model, the uncertainty of the choice derives entirely from its training process 02:23:05.220 |
and the strategy that we use to select the next token. 02:23:10.820 |
For example, we have the greedy strategy, beam search, the temperature parameter, random sampling, top-k, and top-p. 02:23:17.060 |
In this video, we will review all these strategies and how they work. 02:23:21.140 |
But first, we need to understand what are the logits. 02:23:23.700 |
Let's look at the transformer model from Lama. 02:23:28.260 |
So the output of the self-attention is a sequence. 02:23:33.540 |
In the case of the KVCache is only one token. 02:23:38.980 |
So after normalization, we run it through a linear layer. 02:23:41.780 |
The linear layer will transform the embedding that is output from the self-attention here 02:23:47.060 |
into a list of numbers that represent kind of a probability, 02:23:54.820 |
they are not really probabilities, but we can think of them as probabilities, one for each token of the vocabulary. 02:24:01.540 |
So if our vocabulary is made of, let's say, 100 tokens, 02:24:11.940 |
these 100 numbers will become the probabilities of each token being chosen as the next one. 02:24:22.500 |
So given an input, a prompt, the model comes up with probabilities. 02:24:27.620 |
Probabilities for which token to choose next. 02:24:30.660 |
And so what is the job of the linear layer and what is the job of the softmax? 02:24:35.700 |
The linear layer converts the embedding of a token into a list of numbers 02:24:41.060 |
such that each number represents a score that later with the softmax 02:24:46.020 |
represents the probability of that particular token in the vocabulary. 02:24:50.260 |
The softmax job is just to scale the logits in such a way that they sum up to one. 02:24:55.380 |
So that's why we can talk about probabilities with the softmax, but not with the logits. 02:25:00.100 |
So the output of the softmax is thus a probability distribution 02:25:06.580 |
That is, each word in the vocabulary will have a probability associated with it. 02:25:10.580 |
But now, given these words, each one with their probability, 02:25:20.420 |
The greedy strategy basically says we just select the token with the maximum probability. 02:25:25.540 |
So imagine we are inferencing and the time step is the first time step in the greedy strategy. 02:25:31.380 |
The prompt is "Cecilia, you're breaking my heart, you're shaking my...". 02:25:35.620 |
OK, this is a line from a very famous song by Simon & Garfunkel. 02:25:41.220 |
And the next word, for those who know the song, will be "confidence". 02:25:50.500 |
Suppose the output of the softmax is this distribution here. 02:26:03.220 |
With a greedy strategy, we always choose the token with the maximum probability. 02:26:11.780 |
So the input at the next inference step becomes "Cecilia, you're breaking my heart, you're shaking my confidence". 02:26:17.620 |
And then the model has to come up with the next word, which, if you know the song, is daily. 02:26:22.420 |
If we use the greedy strategy, we select the one with the highest probability. 02:26:27.860 |
So in this case, it's daily, and it's also the correct one. 02:26:33.380 |
At every step, we choose the token with the maximum probability, 02:26:36.900 |
which is then appended to the input to generate the next token, and so on. 02:26:40.900 |
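In code, the greedy strategy is just an argmax over the logits; the values below are made up for illustration.

```python
import torch

# Greedy strategy in one line: pick the token with the highest score.
logits = torch.tensor([[0.1, 2.5, 0.3, 1.2]])   # (batch, vocab_size), made-up values
next_token = torch.argmax(logits, dim=-1)        # tensor([1])
```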
But if the initial token happens to be the wrong one, 02:26:45.140 |
so not only the initial, but the initial two, three tokens happen to be the wrong ones, 02:26:49.300 |
it's very likely that all the next tokens will also be wrong, 02:26:52.740 |
because we are giving a wrong prompt to the model. 02:26:55.220 |
So imagine at the time step one, we don't choose confidence, 02:26:59.860 |
but somehow the model came up with a high score for liver. 02:27:03.540 |
So you're shaking my liver, but then the next word, 02:27:06.660 |
the model will not be able to come up with a reasonable next word, 02:27:10.420 |
because there is no song that says you're shaking my liver. 02:27:14.580 |
So if we make a mistake in the early stages with the greedy strategy, 02:27:17.780 |
all the next tokens will very probably also be wrong. 02:27:23.060 |
However, it performs poorly in practice, so it's not used that much. 02:27:30.980 |
In beam search, we have a parameter called K, 02:27:35.060 |
which means that at every step, we not only keep the top one token, but the top K. 02:27:41.060 |
In this example K is 2, so we always keep the two best-performing paths. 02:27:45.380 |
So in this case, for example, imagine we are time step one. 02:27:51.940 |
And the top two words are pizza and confidence. 02:27:55.620 |
Pizza somehow has a higher probability, 02:28:00.260 |
because maybe the model has never seen this song before. 02:28:03.700 |
So it doesn't know that the next word is confidence. 02:28:07.060 |
So maybe the model outputs these probabilities. 02:28:10.180 |
But we choose the two top most, the two tokens with the highest probabilities. 02:28:22.020 |
So we now have two possible prompts: one in case we choose the first token, and one in case we choose the second. 02:28:27.780 |
And then we see what are the next possible choices if we use the first token. 02:28:32.500 |
And what are the next choices if we use the second token? 02:28:35.540 |
So we check the model output for the first prompt and for the second prompt. 02:28:39.940 |
And in case we use, for example, the first prompt, the model will output these probabilities here. 02:28:47.380 |
And if we use the second prompt, the model will output these probabilities. 02:28:51.700 |
What we do then is we calculate the cumulative score for each possible path. 02:28:57.460 |
So for pizza, for example, the probability was 40%. 02:29:00.340 |
After pizza, the model produces a probability for "margherita", 02:29:07.700 |
so for the path pizza, margherita, the cumulative score is 0.004. 02:29:14.580 |
For pizza, anchovies, it's going to be 0.2%, or 0.002. 02:29:21.860 |
However, with confidence, we get a new next token that can be either daily or monthly. 02:29:29.380 |
With daily, we get a cumulative score of 0.16 and with monthly of 0.02. 02:29:38.260 |
So with beam search, even if at time step one pizza was the most probable word, 02:29:45.860 |
we didn't kill the "confidence" path, like we would have done with greedy. 02:29:51.300 |
We can see that "confidence" then produces a next token that is very probable, 02:30:00.260 |
And so it can come up with more specific choices for the next tokens 02:30:07.780 |
So we compute the cumulative score of all these paths, 02:30:14.180 |
and we keep the two paths that have the top choices. 02:30:19.700 |
Even though we chose pizza at the beginning, 02:30:22.420 |
because somehow the model thought it was the most probable word, 02:30:25.860 |
the model was not so confident about the next words along that path, 02:30:31.300 |
while it was very confident along the second path. 02:30:33.540 |
So we kill this whole path here, and we keep this one, 02:30:38.980 |
in which we just selected the path with the highest score, 02:30:42.180 |
and that's the output of our inferencing strategy with BeamSearch. 02:30:45.860 |
And we repeat the steps of the last slide for all the successive tokens. 02:30:53.620 |
So with beam search, at every step we keep alive the top k paths and kill all the others. 02:31:04.820 |
It increases the inference time, because at every step we must explore k possible options, 02:31:08.100 |
but generally it performs better than the greedy strategy. 02:31:12.580 |
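For illustration, here is a compact sketch of beam search over a hypothetical logprob_fn; this is not the code we will write for Lama, just the idea of keeping the top k cumulative paths.

```python
import torch

# `logprob_fn(tokens)` is assumed to return the log-probabilities of the next
# token as a tensor of shape (vocab_size,) given the token ids so far.
def beam_search(logprob_fn, prompt: list, k: int, steps: int) -> list:
    beams = [(prompt, 0.0)]                      # (token_ids, cumulative log-prob)
    for _ in range(steps):
        candidates = []
        for tokens, score in beams:
            logprobs = logprob_fn(tokens)
            top = torch.topk(logprobs, k)        # expand each beam with its k best tokens
            for lp, idx in zip(top.values, top.indices):
                candidates.append((tokens + [idx.item()], score + lp.item()))
        # Keep only the k highest-scoring paths, kill all the others.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
    return beams[0][0]                           # best-scoring path
```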
Another thing that is interesting in inferencing is the temperature: 02:31:20.740 |
with the temperature, we can make the model more confident or less confident. 02:31:31.620 |
The idea is to act on the logits, which are what will become the probabilities after we apply the softmax. 02:31:36.100 |
So before we apply the softmax, we can scale the logits, 02:31:48.020 |
so that the softmax produces probabilities with the shape we want. 02:31:53.700 |
If we divide these logits by a small number before applying the softmax, 02:32:01.060 |
and this number is called the temperature, 02:32:06.100 |
it will make the big probabilities bigger and the small ones smaller, 02:32:10.900 |
so the gap between the low and high probabilities increases. 02:32:15.220 |
So for example, you can see that without applying any temperature we get these probabilities here, 02:32:32.980 |
while if we divide by a bigger number, the gap between the low and high probabilities reduces. 02:32:35.860 |
The temperature is important if we want to increase or decrease the randomness of the output, 02:32:42.820 |
and it can be used in conjunction with other strategies, 02:32:47.780 |
like random sampling, top-k, or top-p, that we will see later. 02:32:55.620 |
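A tiny sketch of the effect of the temperature on made-up logits:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.1])
for temperature in (1.0, 0.5, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probs)
# temperature < 1 sharpens the distribution (the model looks more "confident"),
# temperature > 1 flattens it (more randomness when sampling).
```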
So as we saw, the logits are not a probability distribution, 02:32:59.060 |
but after we apply the softmax, they become a distribution. 02:33:12.580 |
For example, we may have one token that can be chosen with 12% probability, 02:33:18.420 |
one that can be chosen with 80% probability, and one with 7%. 02:33:24.980 |
With random sampling, 80% of the time we will choose the most probable token, 12% of the time this token, 02:33:27.380 |
and 7% of the time we will choose this token. 02:33:32.500 |
Sampling means drawing a token from this distribution according to its probability, 02:33:39.620 |
Now, there is a problem with this sampling strategy: 02:33:47.380 |
with a small probability, it may happen that we choose tokens that are total crap. 02:33:56.020 |
For example, in the distribution we saw before with the greedy strategy, 02:34:00.740 |
if we use random sampling, 02:34:03.700 |
we will choose the word pizza with 40% probability, 02:34:11.620 |
but it may also happen that we choose a word like "Pokemon" that makes no sense. 02:34:18.100 |
The probability of us making such a bad choice is low, but it can happen. 02:34:38.260 |
With top-k, we sort the logits and then we just keep the highest k of them, 02:34:48.260 |
and then we calculate the distribution over the ones that remain. 02:34:51.220 |
So we apply the softmax only to the ones that survive. 02:34:58.420 |
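A minimal sketch of top-k sampling, with made-up logit values:

```python
import torch

logits = torch.tensor([3.0, 2.5, 0.2, 0.1, -1.0])
k = 2
top_values, top_indices = torch.topk(logits, k)     # keep only the k highest logits
probs = torch.softmax(top_values, dim=-1)            # distribution over the k survivors
sampled = torch.multinomial(probs, num_samples=1)    # index into the top-k list
next_token = top_indices[sampled]                    # map back to the vocabulary index
```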
The problem with top-k is that, given these two distributions, 02:35:00.740 |
the low-probability tokens can still make their way into our selection. 02:35:11.380 |
Imagine we have a distribution that is very flat, 02:35:32.740 |
where more or less all the words have the same probability: 02:36:15.060 |
the top k will include this token, this token, this token, and this token, 02:36:21.780 |
so they will still make their way into our selection, even if their probability is low. 02:36:33.380 |
This is a problem, because sometimes the tokens that end up in the top k are not good choices. 02:36:41.460 |
Also, sometimes the prompt can be quite ambiguous, 02:36:47.140 |
so we may not know what the next word should be, and we want to keep more candidates, 02:36:51.620 |
but we also don't want the very low-probability tokens. 02:37:16.660 |
With top-p, it means that if we have the previous two distributions, 02:37:30.340 |
so in this one the first token is nearly 90% and the others are close to 0%, 02:37:36.580 |
while in this one more or less all of them are around 0.2%, 02:37:43.940 |
then, imagine p is equal to, let's say, 0.5: 02:37:51.540 |
in the flat distribution we keep many tokens, such that the area under the curve is equal to 0.5, 02:37:57.140 |
but here, because the first token is already 0.9, we only keep that one, 02:38:05.700 |
because its area under the curve is already 0.9, which is above 0.5. 02:38:12.340 |
So when the distribution is more flat, we select more tokens, 02:38:22.740 |
But when we have a big mode, we select fewer tokens 02:38:27.300 |
and this way we avoid getting the low probability ones. 02:38:36.180 |
This, the top-p, is the strategy for selecting the token that we will implement, 02:38:39.860 |
because in the case of Lama, the official code also uses it. 02:38:45.540 |
In my case, I think that the BeamSearch is a reasonable choice. 02:38:59.860 |
So we implement the method, let's call it text_completion, 02:39:18.900 |
And so 0.6 means that we want to make the model more confident. 02:39:33.300 |
such that their cumulative probability is at least 0.9. 02:39:56.020 |
Okay, so if we didn't specify the max generation length, 02:40:08.040 |
just generate as many tokens as we can, up to the sequence length. 02:40:14.100 |
And then we, first of all, convert each token of the prompt. 02:40:20.980 |
So each prompt, actually, into tokens using the tokenizer. 02:40:29.240 |
Then, as we saw before, we need to add the beginning of sentence 02:40:47.060 |
when we pass the input to the model for inferencing. 02:41:09.240 |
Because we specified the max batch also for the model when we built it for the KVCache, 02:41:20.020 |
so we need to make sure that the batch size of the prompts is not too large. 02:41:42.340 |
is the maximum prompt length that we have in the prompt. 02:42:16.020 |
I'm not writing any assert message, even if you should, but okay. 02:42:24.980 |
is how many tokens we want to get from the model. 02:42:45.620 |
Okay, now we create the list that will contain the generated token. 02:43:31.700 |
this means create a tensor of shape batch size by total length, 02:43:38.260 |
in which each item is actually the padding token. 02:43:40.820 |
And then we fill the initial tokens with the prompt tokens. 02:44:09.060 |
Okay, we also need this variable that tells if we reach the end of sentence 02:44:38.980 |
This indicates if the token in this position is a padding token or not, so true. 02:44:56.100 |
If the token is a prompt token, false, otherwise. 02:45:00.660 |
And then we can finally make the for loop to generate the tokens. 02:45:32.580 |
the logits come from the model, so self.model.forward. 02:45:56.100 |
And we also tell the model what is the position of this token, because for the KVCache. 02:46:06.500 |
As you can see, every time when we want to inference, we always select the last token. 02:46:31.460 |
But because we are using the KVCache, actually our model will only output one token at a time. 02:46:35.780 |
So, the next token will be selected according to our topP strategy. 02:46:58.420 |
If we didn't specify any temperature, we just use the greedy strategy. 02:47:41.620 |
Now we have the next token according to this strategy or this greedy. 02:47:47.060 |
Then we only replace the token if it is a padding token. 02:47:54.260 |
So, the problem is, we already have some tokens that come from the prompt. 02:47:58.500 |
We don't want to replace them, but we still need to give the prompt to the model. 02:48:03.060 |
But we are only giving one token at a time to the model to build the initial cache. 02:48:07.060 |
So, we will give, the first prompt tokens will be given to the model, 02:48:11.060 |
not because we care about what the model will output for those tokens. 02:48:14.900 |
But only because we want the KV cache to be built for those positions. 02:48:21.780 |
And after we give the last token of the prompt, 02:48:25.140 |
then we care about what is the model outputting. 02:48:28.340 |
So, only replace the next token if it is a padding token. 02:48:34.820 |
The one that was not an initial prompt token. 02:48:38.020 |
Because here we build tokens full of paddings. 02:48:41.060 |
But then, we replace the prompt tokens, the padding tokens with the prompt tokens 02:48:52.900 |
All the others have to be inferred by the model. 02:49:24.740 |
If it's true, if the token is a prompt token. 02:49:26.740 |
So, if it is a prompt token, replace it with this one. 02:49:30.740 |
And if it's not a prompt token, just keep it the current one. 02:49:53.220 |
Since we do not care about what the model outputs for the initial prompt tokens, 02:50:04.180 |
we don't care if we find an end-of-sentence position for those tokens. 02:50:09.380 |
So, end-of-sentence is only reached if we find it for one of the tokens 02:50:14.820 |
not the one that we send to the model just to build a KV cache. 02:50:45.620 |
this basically means that the end of sentence for a particular prompt is reached 02:51:15.140 |
only if the current position is not a prompt token and we actually found an end-of-sentence token in the model output. 02:51:22.580 |
If all of the prompts have reached the end-of-sentence token, we can stop generating. 02:52:21.380 |
So this check means that once we have found an end-of-sentence token for every one of the prompts, we break out of the loop. 02:52:48.660 |
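To summarize the bookkeeping inside the loop, here is a hedged sketch of one generation step; the names (tokens, prompt_tokens_mask, eos_reached, eos_id) mirror the ones discussed in this section but are my assumptions, not the exact code.

```python
import torch

# tokens: (batch, total_len) long tensor pre-filled with pad ids and the prompt tokens
# prompt_tokens_mask: (batch, total_len) bool tensor, True where a prompt token sits
# eos_reached: (batch,) bool tensor; next_token: (batch,) chosen by greedy or top-p
def step(tokens, prompt_tokens_mask, eos_reached, next_token, cur_pos, eos_id):
    # Keep the prompt token if this position belongs to the prompt,
    # otherwise write the token generated by the model.
    next_token = torch.where(prompt_tokens_mask[:, cur_pos],
                             tokens[:, cur_pos], next_token)
    tokens[:, cur_pos] = next_token
    # EOS counts only for generated positions, not for prompt positions.
    eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == eos_id)
    return eos_reached.all()   # True -> every prompt is done, we can stop the loop
```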
This is the output text and then we output the tokens and the text. 02:53:05.540 |
Hopefully, I didn't make too many typos and mistakes. 02:53:17.700 |
So, we have the logits that are the output of the model. 02:53:20.820 |
We transform them into probabilities by using the softmax. 02:53:24.180 |
But given these probabilities, we need to use the sample_top_p strategy 02:53:28.260 |
to select the tokens such that their cumulative probability is at least top_p, and then sample among them. 02:53:38.500 |
Okay, the first thing we do is we sort these probabilities in descending order. 02:54:22.820 |
Then we create the mask that says which tokens we want to keep 02:54:28.260 |
So, mask is equal to probability_sum minus probability_sort more than p. 02:54:46.740 |
You can see here, for example, the cumulative probability. 02:54:52.500 |
So, the probabilities are this one, 44 percent, 40 percent, 6 percent, 4 percent, and 3 percent. 02:55:07.540 |
Then this one plus this one plus this one is 91 percent. 02:55:10.500 |
This one plus this one plus this one plus this one is 96 percent, etc, etc. 02:55:15.300 |
But imagine we have p equal to 0.90, or p equal to 0.5: 02:55:24.660 |
we need to keep tokens up to this one here, because stopping before it would not be enough. 02:55:33.300 |
So we keep every token whose cumulative probability, excluding the token itself, is still less than or equal to p, 02:55:42.900 |
and this is why we subtract probs_sort from the cumulative sum. 02:55:48.900 |
So, all the ones that we didn't select, we zero them out. 02:55:58.900 |
And then we redistribute the probabilities because, of course, 02:56:04.980 |
if we remove some items from here, they don't sum up to one anymore. 02:56:08.820 |
So, we need to redistribute the probabilities. 02:56:35.300 |
Okay, then the next token is basically, suppose we keep the first two tokens. 02:56:41.700 |
And then what we do is we want to sample from them. 02:56:44.500 |
So, the first token has 44 percent probability, 02:56:47.460 |
and the second token has 40 percent probability. 02:56:50.420 |
But after we redistribute their probabilities, 02:56:54.740 |
the first one will be a little higher than 44 percent, and this one a little higher than 40 percent. 02:57:00.820 |
It means that the first token will have a slightly better chance of being chosen, 02:57:04.660 |
and the second token will have a slightly smaller chance of being selected. 02:57:07.860 |
And we want one sample because we want one token. 02:57:26.420 |
Because this indicates which index to select, 02:57:31.140 |
then we need to map that index to the actual number in the vocabulary. 02:57:36.500 |
But because we already changed the order of these numbers, 02:57:41.540 |
So, initially, the logits were built in such a way 02:57:45.700 |
that the first logit corresponded to the first number of the vocabulary. 02:57:49.220 |
The second logit corresponded to the second number of the vocabulary. 02:57:52.900 |
But because we sorted them in descending order, this original order is gone. 02:57:56.820 |
So, we don't know now, just given the token selected, 02:58:00.420 |
we don't know which number it maps back into the vocabulary. 02:58:04.580 |
That's why the sort method returns two tensors: 02:58:07.380 |
one with the sorted values and one with the indices they came from. 02:58:11.140 |
So, for each position, it tells you which index the item in that position originally had. 02:58:16.500 |
So, this is why we actually query using gather. 02:58:19.620 |
Gather allows us to retrieve, given a position in the sorted tensor, the original index it corresponds to, 02:58:33.380 |
And this will map back into the vocabulary directly. 02:58:43.460 |
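Putting it all together, here is a sketch of sample_top_p that mirrors the logic just described: sort, cumulative sum with the probs_sort shift, zero out, renormalize, sample, and gather back the original vocabulary index.

```python
import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # probs: (batch, vocab_size) probabilities after temperature and softmax.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens whose cumulative probability, excluding themselves,
    # already exceeds p, then renormalize the survivors so they sum to one.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)   # index in sorted order
    return torch.gather(probs_idx, -1, next_token)              # map back to the vocabulary index
```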
So, now let's create some prompts and let's run the code. 02:58:47.540 |
I have some prompts here that I copied and pasted. 02:58:54.660 |
So, out_tokens, out_text, we want to generate maximum 64 tokens. 02:59:09.460 |
We assert that the len of the output text is actually equal equal to len of prompts. 02:59:18.900 |
So, for i in range, hopefully the model will work. 02:59:26.660 |
And then we print the output text for each prompt. 02:59:42.900 |
So, let's run the code and let's hope for the best. 02:59:46.580 |
Okay, self-attention is missing the required forward function. 03:00:11.700 |
Cumsum received an invalid argument: this is wrong because it should be dim, not div. 03:00:59.940 |
I just changed this tensor from capital T to small t. 03:01:13.780 |
Simply put, the theory of relativity states that time is relative to the observer. 03:01:29.140 |
Suppose the second prompt says if Google was an Italian company founded in Milan, 03:01:35.940 |
it would be listed on the Milan Stock Exchange, 03:01:38.500 |
as the Milan Stock Exchange is the largest in Italy. 03:01:41.060 |
But since Google is a US company, it is listed on the Nasdaq Stock Exchange. 03:01:44.740 |
So, it avoided actually answering the question. 03:01:48.820 |
So, this prompt I actually copied from the Lama code. 03:01:52.180 |
So, they ask to translate from English to French. 03:01:54.420 |
And after cheese, we expect to find fromage, onion, etc. 03:02:02.660 |
And we can also see that the spaces have been kept. 03:02:06.420 |
So, these spaces were not added by me, but actually by the model. 03:02:09.700 |
So, it keeps the output aligned with what was the prompt. 03:02:16.340 |
So, tell me if the following person is actually Doraemon disguised as human. 03:02:20.100 |
So, the name is Umar Jameel, and the decision is 03:02:29.140 |
Actually, okay, this is the output of the model. 03:02:33.700 |
I think if I change the seed to some other number and run the model again, 03:02:38.500 |
the output will be totally different or maybe slightly different. 03:02:46.740 |
I tried to convey the idea of what is the architecture inside LAMA. 03:02:54.420 |
And even if I didn't build the training code, 03:02:56.980 |
because actually to build the training code is rather complicated, 03:03:00.420 |
we need a big corpus of text, we need to tokenize it, and so on. 03:03:06.900 |
But I hope to make another video in the future on how to train a language model, 03:03:13.460 |
maybe with a smaller dataset and with a lighter architecture. 03:03:16.420 |
And I tried to convey all the math behind all the choices, 03:03:21.300 |
and also the inner workings of the KV cache and the grouped query attention. 03:03:28.980 |
If you have any questions, please write in the comments. 03:03:31.380 |
I also will share the repository with the code that I have previously built for this, 03:03:36.500 |
and which has many more comments than the one I have written here, 03:03:42.660 |
so everyone can understand step by step all the dimensions involved. 03:03:46.420 |
Here I tried to write the most important dimensions, 03:03:49.300 |
but because of time, I didn't write all of them. 03:03:54.580 |
It was a long journey, but I can assure you that you learned a lot. 03:03:59.060 |
And I hope you will visit again my channel for more videos about deep learning, 03:04:03.620 |
about PyTorch, about coding, and about everything that we love in AI.