
Coding LLaMA 2 from scratch in PyTorch - KV Cache, Grouped Query Attention, Rotary PE, RMSNorm


Chapters

0:00 Introduction
1:20 LLaMA Architecture
3:14 Embeddings
5:22 Coding the Transformer
19:55 Rotary Positional Embedding
63:50 RMS Normalization
71:13 Encoder Layer
76:50 Self Attention with KV Cache
89:12 Grouped Query Attention
94:14 Coding the Self Attention
121:40 Feed Forward Layer with SwiGLU
128:50 Model weights loading
141:26 Inference strategies
145:15 Greedy Strategy
147:28 Beam Search
151:13 Temperature
152:52 Random Sampling
154:27 Top K
157:03 Top P
158:59 Coding the Inference


00:00:00.000 | Hello guys, welcome to my new coding video.
00:00:02.640 | In this video we will be coding Llama 2 from scratch.
00:00:05.640 | And just like the previous video in which I was coding the transformer model from zero,
00:00:11.640 | while coding I will also explain all the aspects of Llama, so each building block of the Llama architecture.
00:00:19.400 | And I will also explain the math behind the rotary positional encoding.
00:00:23.360 | I will also explain grouped query attention, the KV cache.
00:00:27.280 | So we will not only have a theoretical view of these concepts, but also a practical one.
00:00:33.120 | If you are not familiar with the transformer model, I highly recommend you watch my previous video on the transformer model.
00:00:38.640 | And then if you want, you can also watch the previous video on how to code a transformer model from zero,
00:00:43.440 | because it will really help you a lot in understanding what we are doing in this current video.
00:00:49.120 | If you already watched my previous video on the architecture of Llama, in which I explained the concepts,
00:00:55.080 | it will also be really helpful. If you didn't, I will try to explain all the concepts,
00:01:01.880 | but not in as much detail as in the last video.
00:01:06.960 | So if you have time, please watch my previous video on the Llama architecture,
00:01:11.520 | and then watch this video, because this will really give you a deeper understanding of what is happening.
00:01:16.920 | And let's review the Llama architecture.
00:01:21.280 | Here we have a comparison between the architecture of the standard transformer,
00:01:25.040 | as introduced in the "Attention is all you need" paper, and the architecture of Llama.
00:01:30.320 | The first thing we notice is that the transformer was an encoder-decoder model.
00:01:36.000 | And in the previous video, we actually trained it on a translation task,
00:01:40.120 | on how to translate, for example, from English to Italian, while Llama is a large language model.
00:01:46.160 | So the goal of a large language model is actually to work with what is called the next token prediction task.
00:01:53.720 | So given a prompt, the model tries to come up with the next token that completes this prompt in the most coherent way.
00:02:02.080 | So in a way that the answer makes sense.
00:02:05.720 | And we keep asking the model for the successive tokens based on the previous tokens.
00:02:12.040 | So this is why it's called a causal model.
00:02:14.560 | So each output depends on the previous tokens, which is also called the prompt.
00:02:20.800 | Contrary to what I have done in my previous video on coding the transformer model,
00:02:25.840 | in this video, we will not start by coding the single building blocks of Lama and then come up with a bigger picture.
00:02:34.120 | But we will start from the bigger picture.
00:02:35.840 | So we will first make the skeleton of the architecture and then we will build each block.
00:02:42.240 | I find that this is a better way of explaining Llama, also because the model is simpler,
00:02:48.040 | even if the single building blocks are much more complex.
00:02:51.320 | And this is why it's better to first look at how they interact with each other and then zoom in into their inner workings.
00:02:59.600 | Let's start our journey with the embeddings.
00:03:03.160 | So this block here, let me use the laser.
00:03:06.200 | So this block here, so we are given an input and we want to convert it into embeddings.
00:03:12.680 | Let's also review what are embeddings.
00:03:14.760 | These are my slides from my previous video on the transformer.
00:03:19.560 | And as you can see, we start with an input sentence.
00:03:23.200 | So we are given, for example, the sentence, your cat is a lovely cat.
00:03:26.640 | We tokenize it, so we split it into single tokens.
00:03:29.480 | We map each token to its position in the vocabulary.
00:03:33.880 | The vocabulary is the list of all the words that our model can recognize.
00:03:39.680 | These tokens actually most of the time are not single words.
00:03:44.240 | What I mean is that the model doesn't just split the word by whitespace into single words.
00:03:50.200 | Usually the most commonly used tokenizer is the BPE tokenizer, which means byte pair encoding tokenizer,
00:03:56.960 | in which the single tokens can also be a sequence of letters that are not necessarily mapped to a single word.
00:04:04.440 | It may be a part of a word or maybe it may be a whitespace.
00:04:08.840 | It may be multiple words or it may be a single digit, etc.
00:04:13.440 | And the embedding is a mapping between the number.
00:04:17.840 | So the input IDs that represent the position of the token inside of the vocabulary to a vector.
00:04:24.400 | This vector in the original transformer was of size 512, while in Llama, in the base model,
00:04:31.720 | so the 7 billion model, it's 4096.
00:04:36.600 | The dimension is 4096, which means this vector will contain 4096 numbers.
00:04:43.200 | Each of these vectors represents somehow the meaning of the word,
00:04:48.400 | because each of these vectors are actually parameter vectors that are trained along the model
00:04:54.400 | and somehow they capture the meaning of each word.
00:04:57.680 | So for example, if we take the word "cat" and "dog",
00:05:01.560 | they will have an embedding that is more similar compared to "cat" and "tree",
00:05:06.880 | if we compare the distance of these two vectors, so we could say the Euclidean distance of these two vectors.
00:05:14.960 | And this is why they are called embedding, this is why we say they capture the meaning of the word.
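As a quick illustration of this mapping, here is a minimal sketch (the vocabulary size is the Llama 2 value of 32,000, but the token IDs in the example are made up, not from the real tokenizer):

```python
import torch
import torch.nn as nn

vocab_size, dim = 32000, 4096            # Llama 2 7B values; dim is 512 in the original transformer
tok_embeddings = nn.Embedding(vocab_size, dim)  # a learnable (vocab_size, dim) lookup table

tokens = torch.tensor([[1, 306, 4966]])  # (batch, seq_len) of hypothetical token IDs
embeds = tok_embeddings(tokens)          # each ID is mapped to its 4096-dimensional vector
print(embeds.shape)                      # torch.Size([1, 3, 4096])
```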
00:05:21.640 | So let's start coding the model.
00:05:24.600 | I will open Visual Studio Code and first of all, we make sure that we have the necessary libraries.
00:05:32.800 | I will also share the repository of this project.
00:05:35.800 | So the only things we need are torch, SentencePiece, which is the tokenizer that we use in Llama, and tqdm.
00:05:43.200 | The second thing is from the official repository of Llama, you should download the download script.
00:05:47.440 | It's this file, download.sh, that allows you to download the weights of the Llama model.
00:05:52.400 | In my case, I have downloaded the Llama 2 7 billion, along with the tokenizer.
00:05:57.920 | And this is the smallest model, actually.
00:06:01.120 | And I will not even be able to use the GPU because my GPU is not powerful enough.
00:06:06.000 | And so I will run the model on the CPU.
00:06:08.400 | I think most of you guys will do the same because it's actually a little big for a normal computer unless you have a powerful GPU.
00:06:15.840 | So let's start coding it.
00:06:17.920 | Let's create a new file, model.py, and let's start our journey.
00:06:26.080 | Import the necessary things. So we import torch, and also torch.nn.
00:06:40.880 | Okay, these are the basic things that we always import. I remember we also need math and then also data classes.
00:06:57.840 | And this is all the imports we need.
00:07:00.240 | Most of the code is actually based on the original Llama code, so don't be surprised if you see a lot of similarities.
00:07:06.480 | But I simplified a lot of parts to remove things that we don't need, especially, for example, the parallelization.
00:07:12.320 | And I also tried to add a lot of comments to show you how the shape of each tensor changes.
00:07:20.560 | And that's it. So let's start.
00:07:22.800 | So the first thing I want to create is the class that represents the parameters of the model.
00:07:29.360 | So...
00:07:57.600 | Here we already see that we have two types of heads.
00:08:00.160 | One is the number of heads for the queries.
00:08:02.960 | So number of heads for the queries.
00:08:06.320 | And here we have the number of heads for the k and the v.
00:08:11.520 | So the keys and the values.
00:08:13.600 | Because we will see later, with the grouped query attention,
00:08:17.040 | we don't have to necessarily have the same number of heads for the query key and values like in the original transformer.
00:08:22.720 | But we can have multiple number of heads.
00:08:25.440 | And we will see why and how they work.
00:08:36.160 | This will be set when we load the tokenizer.
00:09:03.680 | These two parameters indicate the dimension, the hidden dimension of the FFN layer.
00:09:07.600 | So the feedforward layer.
00:09:09.360 | The basic idea is that when they introduced the grouped query attention,
00:09:13.680 | they tried to keep the number of parameters the same.
00:09:16.240 | Because with the grouped query attention, we reduce the number of heads of the k and v,
00:09:20.160 | they incremented the number of parameters of the feedforward layer,
00:09:23.680 | so that the total number of parameters of the model remains the same.
00:09:27.440 | This allows comparing the base transformer,
00:09:32.480 | so with all the heads for the query, the key and values,
00:09:35.520 | with the one they use in Llama,
00:09:37.600 | which has a reduced number of heads for the k and v.
00:09:40.160 | But this is just an architectural decision.
00:09:43.360 | Then we have some EPS.
00:09:46.560 | This is a number that is very small.
00:09:49.200 | And we will see why we need it.
00:10:01.600 | Oh my god.
00:10:02.160 | And these are all the parameters we need.
00:10:14.800 | Also here we have two parameters that we will later use for the kv cache.
00:10:18.400 | And I will explain to you later what it is and how it works.
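A sketch of what this parameters class might look like, following the fields just described (the default values shown here are the Llama 2 7B settings as I recall them, so treat them as indicative, not authoritative):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    dim: int = 4096                     # embedding dimension
    n_layers: int = 32                  # number of encoder blocks
    n_heads: int = 32                   # number of heads for the queries
    n_kv_heads: Optional[int] = None    # number of heads for the keys and values (grouped query attention)
    vocab_size: int = -1                # set later, when the tokenizer is loaded
    multiple_of: int = 256              # used to round the hidden size of the feedforward layer
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5              # small epsilon for RMS normalization

    # needed for the KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    device: Optional[str] = None
```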
00:10:20.720 | Let's start, as I said before, let's start with implementing the skeleton of the entire model.
00:10:26.880 | And then we implement each single part.
00:10:29.440 | And while implementing each single part, I will also review the background.
00:10:33.760 | And how it works and the maths behind it.
00:10:35.680 | This is the main class that will represent the entire model.
00:10:44.640 | So all the model we can see here.
00:10:48.400 | So all this model here, except for the softmax.
00:11:08.560 | we make sure that we have set the vocabulary size.
00:11:18.720 | We save the values.
00:11:37.680 | This is the number of layers of the model.
00:11:41.920 | So this represents this block here.
00:11:46.640 | It's repeated many times, one after another.
00:11:49.040 | Just like in the transformer, the base transformer, if you remember correctly.
00:11:52.560 | This block here and this block here were repeated one after another many times.
00:11:56.080 | And here it's repeated 32 times.
00:11:59.600 | And the output of the last layer is then sent to this rms norm, then to the linear, etc.
00:12:14.880 | So I'm using the same names as used in the code from the Llama repository.
00:12:28.400 | Because when we will load the weights of the model, the names must match.
00:12:32.400 | Otherwise, torch doesn't know where to load the weights.
00:12:36.240 | And this is the reason I'm trying to keep the same name.
00:12:40.080 | I only changed some names to make them more clear.
00:12:43.520 | But most of the other names are the same.
00:12:45.200 | This is the ModuleList.
00:12:55.360 | So this is the list of the layers.
00:12:56.960 | We will create later the encoder block.
00:13:13.680 | Which is each of these blocks here.
00:13:17.760 | This is the encoder block.
00:13:19.120 | For now, we just create the skeleton.
00:13:21.760 | So we have a list of these blocks.
00:13:23.440 | Then we have a normalization.
00:13:24.720 | And the normalization is the rms normalization.
00:13:29.760 | We will implement it later.
00:13:30.960 | We need to tell it the size of the features.
00:13:35.040 | And the EPS is a very small number that is needed for the normalization calculation.
00:13:41.120 | So that we never divide it by zero, basically.
00:13:44.160 | And then we have the output layer.
00:13:58.160 | Okay, then we need to pre-compute the frequencies of the rotary positional encodings.
00:14:09.280 | So let's do it.
00:14:10.640 | I created this method and then we go to implement it and I will show you how it works.
00:14:38.160 | let me check.
00:14:55.520 | I think we have a parenthesis.
00:14:57.440 | Okay, this is the base transformer model.
00:15:00.880 | So first of all, we have n layers.
00:15:03.360 | We have, first of all, the input embeddings.
00:15:05.920 | So we convert the input into embeddings.
00:15:09.920 | Then we pass it through a list of layers.
00:15:11.920 | The last layer output is sent to the normalization, then to the output.
00:15:15.840 | So this logic will be more clear in the forward method.
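Putting the skeleton just described into code, a sketch might look roughly like this (EncoderBlock, RMSNorm and precompute_theta_pos_frequencies are the pieces we still have to implement, so those names are just the ones assumed here, and the ModelArgs fields follow the sketch above):

```python
import torch
import torch.nn as nn

class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size != -1, "Vocab size must be set"

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers

        # input IDs -> embedding vectors of size dim
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        # the encoder block repeated n_layers (32) times, one after another
        self.layers = nn.ModuleList(
            [EncoderBlock(args) for _ in range(args.n_layers)]
        )

        # final RMS normalization and output projection to the vocabulary
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

        # pre-computed frequencies for the rotary positional embeddings
        self.freqs_complex = precompute_theta_pos_frequencies(
            args.dim // args.n_heads, args.max_seq_len * 2, device=args.device
        )
```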
00:15:22.720 | Okay, here you will see one thing that is different from the previous transformer,
00:15:37.840 | which is that the sequence length that we want is always one.
00:15:40.880 | And this is because we are using the KV cache.
00:15:43.920 | And we will see why.
00:15:46.480 | So in the previous, let's review here.
00:15:51.360 | Here, when we give the input, we want to give the prompt
00:15:55.680 | and the model will give us a softmax of the next token.
00:15:59.200 | But with the KV cache, we don't need to give all the previous tokens
00:16:03.360 | because maybe we already computed them in the previous iteration.
00:16:06.960 | So we only need to give the latest token and then the model will output the next token.
00:16:11.440 | While the intermediate tokens, so the cache of the previous tokens,
00:16:14.640 | will be kept by the model in its cache because here we have a KV cache.
00:16:18.000 | But we will see this mechanism later.
00:16:19.520 | So for now, just remember that the input here that we will get is one token at a time.
00:16:24.640 | So we will get a batch with sequence length.
00:16:32.800 | Okay, so batch size, sequence length is tokens.shape.
00:16:42.560 | And we make sure that the sequence length is actually one.
00:16:45.840 | And second thing is this model is only good for inferencing, not for training.
00:16:59.120 | Because for training, of course, we need to not have the KV cache
00:17:02.160 | and we need to be able to process multiple tokens.
00:17:06.000 | But our goal is actually to use the pre-trained Llama weights.
00:17:19.360 | So we convert the tokens into embeddings.
00:17:23.760 | As you can see, we add the dim dimension.
00:17:30.160 | So the dimension of the embeddings, which is 4096 for the base model.
00:17:35.520 | But depending on the model size, it can be different.
00:17:48.000 | Okay, I promise I will explain this line here and this line here
00:18:14.480 | much more in detail just in two minutes.
00:18:16.720 | Let me finish it and I will explain everything together.
00:18:19.360 | So this is basically, we pre-compute something about the positional encoding
00:18:24.400 | that we then give to the successive layers.
00:18:26.960 | But let's finish writing it and then I explain this method and this one.
00:18:33.680 | And everything will be clear.
00:18:35.040 | So suppose with this one, we retrieve something that is needed
00:18:39.440 | for computing the positional encoding, which then we feed to the next layers.
00:18:43.120 | So we consecutively apply all the encoder blocks.
00:19:01.360 | And finally, we apply the normalization, just like here.
00:19:09.920 | So we apply these blocks one after another many times.
00:19:12.960 | Then we apply the normalization.
00:19:14.800 | And then we calculate the output using the linear layer.
00:19:18.560 | And finally, we return the output.
00:19:23.440 | So this is the skeleton of the model.
00:19:26.720 | So we take the input, we convert it into embeddings.
00:19:29.600 | This part I will explain later.
00:19:31.920 | We give this input embeddings with something about the positional encodings
00:19:36.640 | to these blocks one after another.
00:19:38.400 | We take the output of the last one and give it to the RMS norm.
00:19:41.440 | We take the output of the RMS norm and give it to the linear layer.
00:19:44.320 | And then during the inference, we will apply the softmax.
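As a sketch, the forward pass just described (one token at a time, relying on the KV cache inside the blocks) could look like this, continuing the Transformer class sketched earlier:

```python
    def forward(self, tokens: torch.Tensor, start_pos: int):
        # (B, seq_len) -- with the KV cache we only ever process the latest token
        batch_size, seq_len = tokens.shape
        assert seq_len == 1, "Only one token at a time can be processed"

        # (B, seq_len) -> (B, seq_len, dim)
        h = self.tok_embeddings(tokens)

        # retrieve the pre-computed rotary frequencies for positions [start_pos, start_pos + seq_len)
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]

        # consecutively apply all the encoder blocks
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)

        h = self.norm(h)                  # final RMS normalization
        output = self.output(h).float()   # (B, seq_len, vocab_size); softmax is applied at inference time
        return output
```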
00:19:47.680 | Now, let's concentrate on the positional encodings.
00:19:51.760 | Let's first review how they worked in the original transformer.
00:19:54.800 | As you remember, in the original transformer,
00:19:57.520 | so the transformer in the attention is all you need.
00:19:59.920 | We first take the sentence, we convert into embedding vectors.
00:20:04.320 | So vectors of size 512.
00:20:07.200 | We then add another vector which has the same size, so 512,
00:20:11.440 | that represents the position of that token inside the sentence.
00:20:15.760 | So every sentence, every token in the first position of a sentence
00:20:20.400 | will receive this vector.
00:20:22.480 | Every token in the second position of a sentence
00:20:25.600 | will have this vector added to it.
00:20:29.200 | And every token in the third position of a sentence
00:20:32.160 | will have this vector added to it.
00:20:34.640 | These vectors are pre-computed because they only depend on the position,
00:20:39.440 | not on the word they are applied to.
00:20:42.080 | And this is why they are called absolute positional encoding,
00:20:44.640 | because they only strictly depend on the position of the word
00:20:47.840 | inside of the sentence.
00:20:49.040 | While in the contrary, in the rotary positional embeddings,
00:20:53.360 | they are a little different.
00:20:54.320 | Let's go have a look.
00:20:55.440 | First of all, the rotary positional encodings or embeddings,
00:20:59.440 | they are computed right before the calculation of the attention.
00:21:03.440 | And they are only applied to the Q and the K matrices,
00:21:07.200 | not to the V.
00:21:07.920 | Let's see why.
00:21:09.840 | The first thing we need to understand is the difference between
00:21:13.120 | absolute positional encodings and relative ones.
00:21:15.280 | The absolute positional encodings are the one we just saw,
00:21:17.680 | so they deal with one token at a time,
00:21:20.240 | and each token gets its own embedding.
00:21:22.720 | While the relative positional embeddings,
00:21:24.800 | they come into play during the calculation of the attention.
00:21:28.240 | And the calculation of the attention is basically done using the dot product.
00:21:33.600 | Because it's query multiplied by the transpose of the keys,
00:21:36.560 | divided by the square root of d_model.
00:21:38.640 | So there is a dot product in between.
00:21:40.640 | And that we want with the relative positional encodings,
00:21:43.600 | we change this dot product so that we introduce a new vector
00:21:47.280 | that indicates the distance between the two tokens involved in the dot product.
00:21:52.160 | So for example, in the original transformer,
00:21:54.720 | we have this formula here.
00:21:57.760 | So query multiplied by the transpose of the keys,
00:22:00.240 | divided by the square root of d_model.
00:22:02.720 | While in the relative positional encodings,
00:22:04.640 | which are not the ones used in Llama,
00:22:06.720 | so this is just an introduction.
00:22:08.400 | The relative positional encodings,
00:22:10.960 | we have another vector here that represents the distance between these two tokens.
00:22:15.200 | And we compute the attention mechanism like this.
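To make the two formulas concrete: the first is the standard scaled dot-product attention score, and the second is one common relative formulation (in the spirit of Shaw et al.), shown here only schematically and not as the exact formula used in any particular model:

```latex
\text{score}(m, n) = \frac{q_m^{\top} k_n}{\sqrt{d}}
\qquad \text{(absolute / vanilla attention)}

\text{score}(m, n) = \frac{q_m^{\top} \left(k_n + a_{m-n}\right)}{\sqrt{d}}
\qquad \text{(relative: } a_{m-n} \text{ is a learned vector that encodes the distance } m-n\text{)}
```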
00:22:18.240 | The rotary positional embeddings,
00:22:20.720 | the ones that are used in Llama,
00:22:22.560 | are something in between the absolute and the relative.
00:22:26.080 | Absolute because each token will get its own embedding,
00:22:31.680 | but relative because the attention mechanism
00:22:35.360 | will be evaluated using the relative distance between two tokens.
00:22:39.600 | Let's see.
00:22:40.100 | The rotary positional embeddings were introduced in the paper from this company, Zhuiyi Technology.
00:22:50.160 | And the authors of this paper,
00:22:54.400 | they wanted to find an inner product that works like this.
00:22:59.440 | So first of all, what is an inner product?
00:23:01.920 | We are all familiar with the dot product.
00:23:03.760 | The inner product can be thought of as a generalization of the dot product.
00:23:08.720 | So it's an operation that has some properties that reflect what is the dot product.
00:23:14.960 | So the authors of this paper wanted to find an inner product
00:23:19.520 | between the two vectors, query and key,
00:23:23.040 | such that they only depend on the...
00:23:27.840 | So this is the symbol for inner product.
00:23:29.680 | So this inner product only depends on the embedding of the two tokens involved,
00:23:34.240 | so XM and XN,
00:23:36.560 | and the relative distance of these two tokens.
00:23:40.000 | So the distance between them.
00:23:42.320 | For example, if the first token is in position two,
00:23:46.320 | and the second token is in position five,
00:23:49.120 | so M equal two, N equal five,
00:23:51.440 | the distance between two will be three,
00:23:53.440 | or minus three according to the order.
00:23:58.000 | And they wanted to find a dot product that has this property,
00:24:01.680 | that only depends on the embedding of the first token,
00:24:03.920 | on the embedding of the second token,
00:24:06.480 | and the relative distance between them.
00:24:08.720 | Then they saw that if this function G is built in this way,
00:24:13.760 | then we achieve that objective.
00:24:15.920 | That is, we take the first token, so the query for example,
00:24:20.000 | we multiply it by the W matrix.
00:24:22.880 | This is actually done also in the vanilla transformer,
00:24:24.960 | but okay, suppose there is no W matrix here.
00:24:27.920 | We convert it into a complex number in this form,
00:24:31.840 | we take the key vector,
00:24:35.360 | we transform into a complex number into this form,
00:24:38.640 | and we define the inner product in this way.
00:24:41.840 | This inner product will basically depend only on the distance
00:24:47.600 | between these two tokens.
00:24:49.440 | So they wanted to find an encoding mechanism
00:24:52.560 | such that the attention mechanism,
00:24:54.480 | which is based on a dot product,
00:24:56.080 | which is an inner product, behaves like this.
00:24:59.840 | So it only depends on the embeddings of the vector
00:25:02.160 | and the distance between them.
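Written out, the property the authors are looking for is an inner product that can be expressed as a function g of the two embeddings and their relative distance only:

```latex
\bigl\langle\, f_q(x_m, m),\; f_k(x_n, n) \,\bigr\rangle \;=\; g(x_m,\, x_n,\, m - n)
```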
00:25:03.840 | And if we, for example, this formulation here,
00:25:08.640 | if we apply it on a vector of dimension two,
00:25:11.200 | so we think of embedding with only two dimensions,
00:25:14.640 | it becomes in this form here.
00:25:17.200 | This is due to the Euler's formula.
00:25:19.920 | So each complex number, thanks to the Euler's formula,
00:25:22.640 | can be written as the cosine plus a sine.
00:25:28.160 | And this matrix here reminds us of the rotation matrix.
00:25:35.600 | Let me give you an example.
00:25:37.040 | Suppose our original vector is here,
00:25:39.520 | and if we multiply this vector v zero by this matrix here,
00:25:46.480 | the resulting vector will be rotated by the angle theta.
00:25:51.040 | So this is why they're called rotary positional embeddings,
00:25:53.680 | because the matrix here represents a rotation of the vector.
00:25:58.320 | So when we want to visualize
00:26:01.440 | how the rotary positional embedding work,
00:26:03.280 | we have to think that they will map it into a vector space,
00:26:06.640 | and they will rotate each word to an angle
00:26:10.320 | that is a multiple of a base angle, so theta,
00:26:13.440 | and proportional to the theta angle,
00:26:16.800 | proportional according to its position.
00:26:19.120 | So that two tokens that occupy similar positions
00:26:23.680 | will have similar inclination,
00:26:25.920 | and the two tokens have different positions
00:26:28.560 | will have different inclinations.
00:26:30.160 | And this is the idea behind the rotary positional embeddings.
00:26:34.080 | But how do we actually compute them in the code,
00:26:38.640 | in the PyTorch?
00:26:40.000 | Well, to compute them, we need to build a matrix like this.
00:26:44.000 | And as you can see, this matrix is actually full of zeros.
00:26:49.120 | And so when we calculate the embedding in this way,
00:26:52.080 | we will do, if we do a matrix multiplication,
00:26:54.400 | we will be doing a lot of operations that are useless,
00:26:57.120 | because most of these items are zero.
00:26:59.680 | So the authors of the paper proposed another form
00:27:02.480 | that is more computationally efficient.
00:27:04.400 | And this form basically says that
00:27:06.960 | we take the embedding of the vector
00:27:09.360 | to which we want to apply the positional encodings.
00:27:12.000 | So for example, this one, this is a vector.
00:27:14.400 | So the first dimension, the second dimension,
00:27:16.160 | the third dimension, and the last dimension.
00:27:18.480 | So if this is, for example, the vanilla transformer,
00:27:20.560 | this would be XD, which should be 512.
00:27:24.320 | We multiply it element-wise by this matrix,
00:27:29.040 | plus another vector that is actually based on this vector,
00:27:34.160 | but with the positions and the signs changed.
00:27:37.040 | So in the first position,
00:27:38.800 | we actually have the second dimension with its sign changed.
00:27:41.680 | In the second position, we actually have the first dimension.
00:27:44.560 | In the third position, we have the fourth dimension,
00:27:47.120 | but with its sign changed.
00:27:48.400 | So it depends only on the embedding of the word,
00:27:51.760 | but with the signs and positions changed.
00:27:55.840 | And then we multiply this element-wise
00:27:58.720 | with another matrix that you can see here, this vector.
00:28:02.160 | Then this will be the encoding of the token
00:28:06.320 | we are talking about.
00:28:07.920 | And now what we can pre-compute is this matrix here,
00:28:14.000 | because it doesn't depend on the token to which we apply it to,
00:28:17.040 | and this matrix here, because it doesn't depend
00:28:19.040 | on the token we apply it to.
00:28:20.480 | And they depend on m, so it's the position of the word,
00:28:25.360 | and theta.
00:28:26.400 | What is theta?
00:28:27.520 | Theta is a series of numbers defined like this.
00:28:31.280 | And so let's first build the code
00:28:35.840 | to pre-compute this and this here.
00:28:39.200 | Let's do it.
00:28:40.400 | I will first write the code,
00:28:42.480 | and later I will show you how it works.
00:29:12.000 | This theta parameter, 10,000, comes from the paper.
00:29:14.640 | It's written here, 10,000.
00:29:17.200 | We first need to make sure that the dimension of the word
00:29:25.280 | to which we are applying the embedding is actually even,
00:29:28.240 | because in the paper it's written
00:29:30.000 | that this rotary positional encoding cannot be applied
00:29:33.520 | to an embedding which has an odd dimension.
00:29:38.320 | So it cannot be 513, it must be 512 or 514
00:29:42.400 | or any other even number.
00:29:43.760 | And this is as written in the paper.
00:29:54.240 | Even, okay.
00:30:02.960 | Now we build the theta parameters,
00:30:04.800 | which is a sequence.
00:30:13.120 | And the shape of this theta will be head dimension divided by 2.
00:30:24.080 | Because we will apply these embeddings to each head,
00:30:27.840 | so not right after the embedding,
00:30:30.000 | but after we have split them into multi-head,
00:30:32.880 | so each token, the token of each head,
00:30:35.600 | we check the size of the dimension of each head
00:30:40.880 | and we divide it by 2.
00:30:42.000 | Because, why divide it by 2?
00:30:43.680 | Because in the paper they also divide it by 2.
00:30:46.480 | So D divided by 2 here.
00:30:53.200 | Okay, so what's the formula here?
00:31:01.760 | The formula is theta of i is equal to 10,000
00:31:07.360 | to the power of minus 2 multiplied by i minus 1
00:31:13.520 | divided by dimension for i equal to 1, 2, etc.
00:31:20.400 | up to dimension divided by 2.
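In symbols, the series of thetas being described is:

```latex
\theta_i = 10000^{-\,2(i-1)/d}, \qquad i = 1, 2, \dots, d/2
```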
00:31:23.200 | So now we are computing this part here.
00:31:27.280 | So this part here, which is a series.
00:31:31.760 | So i, here it starts from 1, we will start from 0,
00:31:34.800 | so we don't have to do i minus 1.
00:31:36.880 | And theta is equal to 1 over the theta.
00:31:44.640 | So 10,000 to the power of theta numerator
00:31:49.440 | divided by head dimension.
00:31:56.080 | Why do we do 1 over theta?
00:31:58.080 | Well, because this is to the power of minus 2.
00:32:01.600 | So something to the power of a negative number is 1 over
00:32:05.920 | that something to the power of the positive exponent.
00:32:09.680 | And then this will result in a matrix with shape,
00:32:17.680 | head dimension divided by 2.
00:32:20.480 | So shape is head dimension divided by 2.
00:32:26.240 | Now we construct the positions.
00:32:28.000 | So what are the positions?
00:32:29.360 | Because we want to build these two matrices,
00:32:31.600 | they depend on theta.
00:32:32.960 | So the series of theta that goes from theta 1
00:32:35.520 | to theta dimension divided by 2.
00:32:37.840 | And that we already have.
00:32:39.440 | Now we need to build the m's.
00:32:41.040 | Because the m's, the possible positions of a token can be many.
00:32:44.720 | We basically give as input to this function
00:32:48.640 | the maximum sequence length that we can afford
00:32:52.880 | multiplied by 2, because we have also the prompt,
00:32:56.480 | which may be long.
00:32:57.600 | So we say, OK, let's pre-compute all the possible theta and m
00:33:02.560 | for all the possible positions that our model will see.
00:33:05.840 | And all the possible positions is given by this parameter,
00:33:09.040 | sequence length.
00:33:09.840 | So now construct the positions.
00:33:20.480 | And the shape is sequence length, which is m.
00:33:24.880 | Now we need to multiply m by all the sequence of thetas.
00:33:37.760 | But each m with all the thetas.
00:33:41.120 | So for example, if we have m equal to 1,
00:33:43.200 | we need m1 theta 1, m1 theta 2, m1 theta d divided by 2.
00:33:49.360 | Then we need m2 theta 1, m2 theta 2, m2 theta 3.
00:33:52.800 | So for that, we will use a outer product.
00:33:55.440 | The outer product, I will show you later,
00:33:58.400 | basically means multiply all the elements of the first vector
00:34:03.680 | with all the elements of the second vector,
00:34:05.680 | all the possible combinations.
00:34:08.080 | So for example,
00:34:19.120 | here we have a frequency is equal to torch outer product m and theta.
00:34:29.680 | OK, so what we are doing is we are doing the outer product within m,
00:34:36.080 | which is the positions, multiplied by theta.
00:34:38.400 | This will basically take the first element of the first vector
00:34:42.080 | and multiply with all the elements of the second vector.
00:34:45.280 | Then take the second element of the first vector
00:34:47.520 | and multiply it with all the elements of the second vector, etc, etc.
00:34:50.720 | So if we start with a shape, let me say shape of m is sequence length.
00:34:58.960 | Let's say outer product with head dimension divided by 2.
00:35:08.560 | This will result in a tensor of sequence length by head dimension divided by 2.
00:35:16.000 | So for each position, we will have all the theta.
00:35:18.160 | Then for the second position, we will have all the theta.
00:35:20.400 | For the third position, we will have all the theta, and so on.
00:35:23.840 | Now, we want to write these numbers into a complex form, and I will show you why.
00:35:44.640 | So we write each number as R times e to the power of i multiplied by m multiplied by theta, where R is equal to 1, as follows.
00:35:56.720 | So we compute.
00:36:05.280 | Let me also write the shape, and then I'll explain to you how it works.
00:36:20.320 | This is here, too.
00:36:23.360 | Okay, let's write some formulas.
00:36:30.880 | I could also, you know, not explain all the proofs.
00:36:34.320 | So I know the next few minutes will be a little boring
00:36:37.280 | because I will be explaining all the math behind it.
00:36:39.280 | But of course, I don't think you, just like me,
00:36:42.640 | you like to watch just some code and say, "Okay, this is how it's done."
00:36:47.280 | No, I like to actually give a motivation behind every operation we do,
00:36:51.840 | and that's, I think, one of the reasons you are watching this video
00:36:54.080 | and not just reading the code from the Meta repository.
00:36:57.200 | So let's do some math.
00:37:00.880 | The first thing we need to review is how complex numbers work.
00:37:06.720 | Okay, a complex number is a number in the form a plus i multiplied by b,
00:37:14.240 | where a is called the real part and b is called the imaginary part.
00:37:19.040 | And i is a number such that i to the power of 2 is equal to minus 1.
00:37:25.680 | So the complex numbers were introduced to represent all the numbers
00:37:30.080 | that involve somehow the square root of a negative number.
00:37:33.200 | As you know from school, the square root of a negative number cannot be calculated,
00:37:37.280 | so that's why we introduce this constant i,
00:37:39.920 | which is the square root of minus 1.
00:37:43.520 | And so we can represent the square root of negative numbers.
00:37:46.160 | And they can also be helpful in vector calculations, and we will see how.
00:37:52.400 | Because we have the Euler's formula.
00:37:57.280 | The Euler's formula says that e to the power of i multiplied by x
00:38:02.480 | is equal to cosine of x plus i multiplied by sine of x.
00:38:11.600 | So it allows us to represent a complex number in the exponential form
00:38:16.960 | into a sum of two trigonometric functions, the cosine and the sine.
00:38:22.000 | And this will be very helpful later.
00:38:24.160 | Because our goal is to calculate these matrices here,
00:38:31.920 | the cosine of m theta and the sine of m theta.
00:38:35.520 | And the first thing we did is we calculated all the theta one,
00:38:40.400 | then we calculated all the positions,
00:38:42.160 | then we calculated all the possible combinations of positions and thetas.
00:38:47.360 | So what we did is we calculated a vector that represents the theta.
00:38:53.760 | So theta 1, theta 2, up to theta d divided by 2.
00:39:02.320 | Then we calculated all the possible m's.
00:39:04.720 | m can be 1, can be 2, can be whatever.
00:39:09.360 | So sequence length, let's say.
00:39:10.800 | Then we calculated the product of each of them for all the possible thetas.
00:39:19.360 | So for example, we created a new matrix that has m1 theta 1, m1 theta 2, m1 theta 3,
00:39:30.880 | up to m1 theta d divided by 2.
00:39:39.360 | And then m2 theta 1, m2 theta 2, m2 theta 3, etc, etc.
00:39:49.920 | Until m2 theta d divided by 2.
00:39:54.640 | These numbers are still not complex numbers.
00:39:59.120 | They're just real numbers because theta is a real number,
00:40:01.760 | m is a real number, but they are not complex numbers.
00:40:04.000 | Then we convert them into complex numbers.
00:40:06.160 | So what we do with the last operation here, this one here,
00:40:09.680 | we convert each of these numbers into polar, into its polar form.
00:40:14.560 | A number in polar form is a number that can be written as
00:40:18.080 | r multiplied by e to the power of i theta,
00:40:22.480 | which can be written as r cosine of theta plus i sine of theta.
00:40:29.360 | Why? Because it can be represented in the graphical, let's say, graphical plane xy.
00:40:38.640 | As you know, complex numbers can be represented into the 2D plane xy,
00:40:45.120 | where the real part is on the x and the imaginary part is on the y.
00:40:50.640 | So we are actually representing a vector of size, let's say, r with an inclination of theta.
00:40:58.640 | Because, as you know, the projection of this vector on the real part is
00:41:03.200 | r cos theta plus i, the projection on the y-axis is sine of theta.
00:41:11.200 | And here I forgot r, yeah, I've forgotten r here, r sine of theta.
00:41:18.240 | So this is another way of representing complex numbers.
00:41:21.760 | And what we are doing is we are calculating this matrix
00:41:25.120 | and then converting all these numbers into their complex form.
00:41:28.800 | So we are converting it into another matrix that has r equal to 1.
00:41:34.000 | And this number here, for example, this item here will become
00:41:40.480 | cosine of m1 theta 1 plus i sine of m1 theta 1.
00:41:51.440 | This number here will become another number.
00:41:55.280 | So this is only one number.
00:41:57.920 | This has become another complex number that is the cosine of m1 theta 2 plus i sine of m1 theta 2.
00:42:11.040 | Etc, etc, for all that.
00:42:12.560 | Because we are not increasing the numbers, the total numbers,
00:42:17.040 | this shape of the tensor also doesn't change.
00:42:19.440 | Each element just becomes a complex number.
00:42:21.200 | So instead of having m theta 1, it becomes cosine of m theta 1 plus i sine of m theta 1.
00:42:27.440 | Why do we need this form here?
00:42:29.200 | Because we need sines and cosines.
00:42:33.520 | And later we will see how we will use them.
00:42:36.320 | Now, the point is, imagine we are given a vector,
00:42:40.560 | because we want to apply these positional encodings to a vector.
00:42:43.600 | So how to apply them?
00:42:45.600 | Because the vector will be given us as a list of dimensions,
00:42:50.080 | from x1 to the last dimension.
00:42:52.080 | Just like in the original transformer, we have a vector of size 512.
00:42:56.560 | In this case, it will be much smaller because it's the dimension of each head.
00:43:00.640 | And as you remember, each head doesn't watch the full dimension of the embedding vector,
00:43:06.080 | but a part of it.
00:43:07.840 | So, but for us, okay, imagine it's only one head.
00:43:10.480 | So if it's only one head, we will watch the full dimension.
00:43:13.040 | So for now, don't consider the multi-head.
00:43:15.840 | Just suppose that we only are working with one head.
00:43:18.240 | So imagine we have a token with its full dimensions.
00:43:21.520 | So in the case of Llama, 4096 dimensions.
00:43:25.360 | And in the case of the vanilla transformer, 512.
00:43:28.560 | How to apply it?
00:43:29.440 | Let's do some math.
00:43:31.200 | Actually, let's do some more math.
00:43:34.560 | So we are given,
00:43:35.920 | suppose a smaller embedding vector,
00:43:42.720 | because we want to do the calculation and not go crazy.
00:43:46.880 | Otherwise, 4096 is a little difficult to prove.
00:43:50.800 | I want to make a list of operations on this vector until we arrive to this form here.
00:43:59.840 | So let's start.
00:44:01.440 | Suppose our embedding vector is only made of four dimensions.
00:44:04.640 | X1, X2, X3, and X4.
00:44:10.880 | Okay, the first thing we do is I will do some transformations
00:44:19.360 | and I will later translate them into code.
00:44:21.680 | So for now, just follow the transformations I'm doing.
00:44:24.000 | This is the transformation number one.
00:44:27.760 | I want to group successive tokens, successive dimensions.
00:44:32.640 | So into another dimension.
00:44:35.920 | So X1 and X2 become another dimension in this tensor.
00:44:41.040 | And X3 and X4 become, oops, very badly written.
00:44:46.880 | X4 become another dimension in this tensor.
00:44:52.480 | The total number of items is still four, but I added another dimension.
00:44:57.360 | And this has size four by one, right?
00:45:04.240 | This one has two by two by one.
00:45:08.000 | So I split it into multiple tensors.
00:45:11.440 | And okay, now this next thing I do,
00:45:16.080 | I consider this first number of this part to be the real part of the complex number.
00:45:22.640 | And this one to be the imaginary part of the complex number.
00:45:26.000 | And the same for the second vector here.
00:45:28.080 | So I do another transformation that we will call two,
00:45:32.480 | in which X1 plus IX2, and then X3 plus IX2.
00:45:45.520 | This vector has less items because now two numbers became one complex number.
00:45:54.160 | Now I multiply this element wise with the vector that we pre-computed before.
00:46:01.360 | As you remember before, we pre-computed this one.
00:46:05.440 | Cosine of M1 theta 1 plus I of M1 theta 1, cosine of M1 theta 2 plus I.
00:46:12.560 | Because they suppose this position, this token here,
00:46:16.960 | suppose his position is M1, because we need also the M.
00:46:21.520 | So suppose this token here, his position is M1.
00:46:25.600 | So we take all this row here, M1, and this will become our new matrix here.
00:46:34.160 | We only have four dimensions.
00:46:37.680 | So four dimensions means we have a theta 1 and theta 2.
00:46:42.000 | Because D divided by 2 until D divided by 2.
00:46:44.960 | So element wise with the cosine of M1 theta 1 plus I sine of M1 theta 1.
00:47:07.840 | And then we have cosine of M1 theta 2 plus I of sine of M1 theta 2.
00:47:23.520 | Now we have an element wise product between the first item of this matrix
00:47:31.200 | and the first item of this matrix.
00:47:33.440 | Actually they are two vectors.
00:47:35.200 | And then we have the product of two complex numbers.
00:47:39.360 | This complex number here and this complex number here.
00:47:42.960 | So let's see how to compute the product of two complex numbers.
00:47:47.520 | Because I don't want to write very long expressions, I will call this one F1.
00:47:54.320 | So F1 is the cosine of M1 theta 1 and F2 is the sine of M1 theta 1.
00:48:04.240 | And for the same reason I will call this one F3 and F4.
00:48:08.720 | Now let's compute the product of the first item of this vector
00:48:17.840 | and the first item of this vector.
00:48:19.440 | So X1 plus IX2 multiplied by F1 plus IF2.
00:48:33.840 | This is equal to X1 F1 plus IX1 F2 plus IX2 F1.
00:49:00.800 | Then we have this product iX2 multiplied by iF2.
00:49:05.040 | And it will become i squared X2 F2.
00:49:08.720 | I squared we know it's equal to minus 1.
00:49:11.040 | So it will become minus X2 F2.
00:49:16.240 | This one can then be written as real part.
00:49:20.160 | So all the terms that don't have I, X1 F1 minus X2 F2 plus I that multiplies X1 F2 plus X2 F1.
00:49:45.760 | Okay, this is how to compute the product of two complex numbers.
00:49:50.160 | Let's do it.
00:49:51.120 | So the first number here in the resulting matrix from this element-wise multiplication will be
00:49:58.400 | X1 F1 minus X2 F2 plus I of X1 F2 plus X2 F1.
00:50:25.680 | The second element, we don't need to do this multiplication
00:50:29.120 | because they have the similar structure as the first one.
00:50:31.600 | So we just change the X1 with X3, X2 with X4.
00:50:35.280 | This is X4 and F1 with F3 and F2 with F4.
00:50:43.120 | So the resulting matrix will be X3 F3 minus X4 F4 plus I X3 F4 plus X4 F3.
00:51:11.840 | This one can then be split back.
00:51:15.920 | So this complex number, we can split the real part and the imaginary part.
00:51:19.920 | And this we will call it transformation number three.
00:51:23.680 | So we can split it in a tensor of two dimensions.
00:51:30.400 | One is the real part and one is the complex part.
00:51:33.200 | Where is X1 F1 minus X2 F2 then X1 F2 plus X2 F1.
00:51:56.240 | This is the first tensor. The second tensor will be X3 F3 minus X4 F4.
00:52:08.400 | And the second will be X3 F4 plus X4 F3.
00:52:20.960 | I'm really sorry for the bad handwriting, but it's my touchpad is not so good.
00:52:26.320 | Then we do another transformation in which we flatten all these values.
00:52:32.480 | So this will be the first item.
00:52:35.600 | This is the second, the third and the fourth.
00:52:37.440 | So we remove this dimension, the inner dimension.
00:52:40.720 | So we flatten this matrix and it will become X1 F1 minus X2 F2.
00:52:51.760 | The second items will be X1 F2 plus X2 F1.
00:53:11.200 | Then we have X3 F3 minus X4 F4 then we have X3 F4 plus X4 F3.
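In other words, for the first pair of dimensions (and identically for every other pair), the complex multiplication we just expanded is:

```latex
(x_1 + i\,x_2)\,\bigl(\cos m\theta_1 + i \sin m\theta_1\bigr)
= \underbrace{\bigl(x_1 \cos m\theta_1 - x_2 \sin m\theta_1\bigr)}_{\text{real part}}
+ i\,\underbrace{\bigl(x_1 \sin m\theta_1 + x_2 \cos m\theta_1\bigr)}_{\text{imaginary part}}
```

which is exactly a rotation of the pair (x1, x2) by the angle m theta 1.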
00:53:32.240 | Let's compare this resulting matrix with what is in the paper.
00:53:37.840 | So let's compare it with this one.
00:53:42.800 | Let me zoom a little bit.
00:53:46.160 | OK, the resulting matrix is exactly the same as what is in the paper.
00:53:53.040 | So X1 multiplied by F1, which is the cosine, as you remember, M1 theta 1.
00:54:00.160 | So X1 multiplied by F1 plus minus X2.
00:54:06.160 | So minus X2 multiplied by F2, which is the sine, as you can see here, sine of M1 theta 1.
00:54:13.120 | Here it's not M1 because it's for the generic M, but we set M equal to M1.
00:54:18.080 | And the second dimension is also correct.
00:54:21.600 | So it's X1 F2.
00:54:23.200 | So X1 with X1 here, because here is the sum, so the order doesn't matter.
00:54:29.520 | So X1 F2, so X1 multiplied by the sine plus X2 F1, X2 F1.
00:54:35.680 | And let me check if we can use...
00:54:39.840 | OK, the third dimension is X3 F3.
00:54:44.800 | So X3 multiplied by the cosine minus X4 sine minus X4 sine.
00:54:50.960 | Then we have X3 F4.
00:54:53.520 | So X3 F4, F4 is the sine of theta 2.
00:54:57.360 | And plus X4 F3, X4 F3, F3 is the cosine of theta 2.
00:55:04.160 | Also in this case, because we have the sum here inside, the order doesn't matter.
00:55:10.160 | So as you can see, we started with a vector of dimension 4, but it could be of dimension N.
00:55:17.280 | And we did some transformation.
00:55:20.480 | We then multiplied with the matrix that we pre-computed here.
00:55:23.760 | Then we did some other transformation, and the end result is exactly as doing this operation.
00:55:29.920 | So now let's translate this into code.
00:55:32.720 | Because this is actually what we need to apply the embedding vector to this vector here.
00:55:38.560 | So to this token, how to apply the embeddings,
00:55:40.960 | the rotary position embeddings through this series of transformations.
00:55:44.320 | So I could have also written the code and not tell you anything,
00:55:47.600 | but I like to give proof to what I do.
00:55:50.720 | So that you know that what I'm doing is actually described in the paper,
00:55:56.480 | and we are actually doing it according to the paper.
00:55:59.280 | There is also a visualization in the paper that is really helpful.
00:56:02.560 | So what we did here, for example, is we transform the embedding vector into,
00:56:09.200 | split it into a new tensor, which has half dimension.
00:56:15.040 | But by grouping two consecutive dimensions.
00:56:18.160 | So the two consecutive dimension x1 and x2 and x3 and x4.
00:56:22.160 | Then we multiply, we transform it with a complex number.
00:56:27.120 | We multiply it with M theta that we pre-computed.
00:56:30.960 | And this visualization of why we do this is present in the paper.
00:56:35.920 | It's in particular at this figure here.
00:56:41.760 | So here they say, if you have a word with d dimensions,
00:56:47.120 | then, of course,
00:56:51.360 | you will have d half thetas.
00:56:55.440 | Because we have theta 1, theta 2, up to theta d half.
00:56:59.840 | We group successive dimensions into a new complex number,
00:57:06.880 | that if we project it on the complex plane, it will result into this vector.
00:57:11.200 | So the x1, x2 vector you can see here.
00:57:13.760 | And then we multiply it with the complex number M theta 1.
00:57:18.320 | This will result in the number being rotated by the angle indicated by M theta 1.
00:57:25.360 | And this is the encoded number, is the encoded token.
00:57:29.840 | And this is exactly what we are doing with our matrix transformations
00:57:33.520 | that I show you right now.
00:57:35.200 | Now let's translate this into code.
00:57:38.400 | Apply Rotary Embeddings.
00:57:41.200 | X is the token to which we want to apply the Rotary Embeddings.
00:57:49.200 | FreqsComplex is the output of this function, but only for the position of this token.
00:58:00.000 | So only for all the positions of this token.
00:58:02.880 | Because this will have all the theta for all the possible positions,
00:58:06.320 | but we only need the positions for this particular token.
00:58:09.600 | And then we need the device.
00:58:12.320 | The first thing we do is the transformation number 1, I think I call it.
00:58:22.800 | Yeah, this one here.
00:58:24.080 | And number 1 and number 2.
00:58:26.960 | So the first thing we do is we apply transformation number 1
00:58:30.960 | and number 2. So we group
00:58:34.320 | the two consecutive dimensions into a new tensor.
00:58:38.320 | And then we visualize it as a complex number.
00:58:41.200 | These operations are supported by PyTorch.
00:58:44.480 | So we do them.
00:58:47.360 | So we create XComplex is equal to.
00:59:03.680 | Okay, this operation here is basically saying take two consecutive dimensions and group them.
00:59:15.920 | And then we transform this intermediate tensor into a complex tensor
00:59:19.920 | by using the view as complex operation from Torch.
00:59:23.600 | Let me write some comments.
00:59:26.800 | So we are starting from B, sequence length, H, head dimension.
00:59:33.680 | Because I saw before this X is actually not the original vector,
00:59:38.160 | but it's already the one divided with its head dimension.
00:59:42.640 | Because we will have a multi head attention.
00:59:47.600 | But if there is no multi head attention,
00:59:49.360 | then this head dimension is actually the full dimension of the token.
00:59:52.640 | So 4096.
00:59:53.920 | Then we have this tensor here.
01:00:01.360 | But this tensor has two dimensions less than this one.
01:00:04.480 | It doesn't have the batch dimension.
01:00:06.400 | And it doesn't have the head dimension.
01:00:08.160 | So we need to add it.
01:00:09.120 | So take the XComplex and we add the two dimensions that it's missing.
01:00:13.840 | And here we are doing.
01:00:23.440 | Okay, let me write all the transformations.
01:00:25.360 | Here we are going from here to divide by two.
01:00:35.280 | Why divide by two?
01:00:36.080 | Because every two consecutive pairs are becoming one complex number.
01:00:40.160 | And here we go from sequence length to head dimension divide by two.
01:00:46.560 | We are mapping it to one because this is the batch dimension sequence length,
01:00:53.920 | then the head dimension one, and then head dimension divide by two.
01:00:59.600 | Now we multiply them together.
01:01:02.160 | So we do this operation here.
01:01:03.920 | So element wise multiplication,
01:01:06.080 | which will result in a rotation as we saw in the figure before.
01:01:08.880 | So that's why I call it X rotated is equal to X complex
01:01:13.120 | multiplied by the complex number of the frequencies.
01:01:17.040 | In this case, we are doing sequence length.
01:01:22.720 | H dimension divide by two.
01:01:36.000 | Then we multiply it, we obtain this result.
01:01:45.600 | And then we first transform it into...
01:01:52.080 | We transform the complex number into a tensor
01:01:55.520 | in which the first item is the real part of the complex number
01:01:58.400 | and then the complex part, the imaginary part, and then we flatten it.
01:02:04.480 | So let's do it.
01:02:14.640 | This operation view as real will transform the tensor like this.
01:02:21.040 | So it's this transformation here
01:02:40.960 | in which we transform the complex number into a tensor of two dimensions.
01:02:45.280 | Because that's why you can see this additional dimension here.
01:02:48.080 | And then we flatten it.
01:02:50.000 | You can just say to flatten it with the shape of the original...
01:03:00.160 | with the original tensor we gave it.
01:03:05.360 | So become...
01:03:15.200 | And this is how we calculate the embedding.
01:03:30.240 | So given a tensor of representing a token
01:03:34.800 | or a list of tokens, because we have the batch dimensions,
01:03:37.920 | we can apply the embeddings like this
01:03:41.040 | doing all these transformations that we have done here
01:03:44.160 | and that are represented in this code.
01:03:46.080 | And they are all equivalent to doing this operation as written on the paper.
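Collected in one place, a sketch of this function (following the view_as_complex / view_as_real route described above) might look like this:

```python
import torch

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    # (B, seq_len, H, head_dim) -> (B, seq_len, H, head_dim / 2) as complex numbers:
    # two consecutive dimensions become the real and imaginary part of one complex number
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # (seq_len, head_dim / 2) -> (1, seq_len, 1, head_dim / 2) to broadcast over batch and heads
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # element-wise product = rotation of each pair by its angle m * theta_i
    x_rotated = x_complex * freqs_complex
    # back to real numbers: (B, seq_len, H, head_dim / 2, 2), then flatten the last two dimensions
    x_out = torch.view_as_real(x_rotated)
    x_out = x_out.reshape(*x.shape)
    return x_out.type_as(x).to(device)
```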
01:03:50.000 | Now we need to go forward with our transformer by implementing the rest.
01:03:58.000 | The next thing that we can implement is this RMS norm
01:04:02.000 | because it's present at the output of the transformer
01:04:06.880 | but also at the input.
01:04:08.480 | So let's go review again the architecture.
01:04:11.360 | We can see that we have the normalization, the RMS normalization here
01:04:17.360 | but we also have it here and here.
01:04:19.520 | So let's implement it.
01:04:20.720 | Let's also visualize how the RMS norm works.
01:04:24.640 | If you want to have a deep understanding of how normalization works,
01:04:28.400 | in my previous video about Llama,
01:04:30.160 | I actually described why we need normalization,
01:04:32.720 | how it was historically done and how it works,
01:04:35.280 | also at the autograd level.
01:04:37.840 | So I will not repeat the same lecture here.
01:04:41.280 | I will just briefly introduce how it works.
01:04:44.000 | But if you want to have a better understanding,
01:04:45.760 | please watch my previous video.
01:04:47.760 | So as you remember, in the original transformer,
01:04:50.000 | we used layer normalization.
01:04:51.680 | And layer normalization worked like this.
01:04:54.000 | We have an input where we have some items,
01:04:57.280 | suppose item 1, item 2, up to item 10.
01:05:00.160 | Each item has three features, so A1, A2, A3.
01:05:04.480 | What we did with layer normalization,
01:05:06.400 | we computed two statistics, one for each item,
01:05:10.080 | so mu and sigma, so the mean and the sigma.
01:05:13.680 | And we standardize each item,
01:05:16.320 | normalize each element of this input matrix,
01:05:19.760 | using this formula here,
01:05:21.760 | which transforms it into a distribution
01:05:24.960 | with zero mean and variance of one.
01:05:27.680 | And this formula comes from probability statistics.
01:05:31.520 | So as you know, if you have any random variable
01:05:34.240 | with its mu and sigma,
01:05:36.160 | if you do the variable minus its mean,
01:05:39.600 | divided by the standard deviation,
01:05:42.160 | so the square root of the variance,
01:05:43.760 | it will result in a Gaussian of mean zero
01:05:47.520 | and a variance of one.
01:05:50.320 | We then multiply this with the gamma parameter
01:05:54.080 | and we also add a beta parameter here.
01:05:56.800 | But this was done in the layer normalization.
01:05:58.960 | In LLAMA, we use RMS normalization
01:06:05.120 | and let's see the difference.
01:06:06.320 | In RMS normalization,
01:06:09.040 | the paper of the RMS normalization
01:06:11.120 | claims that we don't need to obtain
01:06:14.160 | the effect of layer normalization,
01:06:16.240 | we don't need to compute two statistics,
01:06:18.480 | that is the mean and the variance.
01:06:20.560 | And actually they claim that
01:06:22.800 | the effect given by layer normalization
01:06:25.760 | can be obtained without recentering the values.
01:06:28.720 | So without recentering them around the mean of zero.
01:06:32.480 | But just by scaling.
01:06:35.040 | However, the variance in the layer normalization
01:06:37.840 | was computed using the mean,
01:06:39.440 | because if you remember the formula of the variance
01:06:41.680 | is X minus the mean of the distribution
01:06:46.000 | to the power of two divided by N.
01:06:48.800 | So to compute the variance, we needed the mean,
01:06:53.600 | but we wanted to avoid computing the mean
01:06:55.920 | because we don't need it.
01:06:56.960 | This is, at least, what the RMS paper claims.
01:06:59.600 | RMS paper claims that we don't need the mean
01:07:01.760 | and we don't need to recenter.
01:07:03.680 | So we need to compute a statistic
01:07:05.760 | that doesn't depend on the mean.
01:07:07.360 | That's why they introduced this statistic here,
01:07:10.160 | which is the root mean square,
01:07:11.520 | that doesn't depend on the mean.
01:07:13.360 | And in practice it gives the same normalization effect
01:07:17.040 | as layer normalization.
01:07:19.040 | And we also have a gamma parameter here
01:07:21.840 | that is learnable and that is multiplied.
01:07:24.240 | So as you can see, the only difference
01:07:25.760 | between layer normalization and RMS normalization
01:07:28.640 | is that we don't recenter the values.
01:07:30.400 | And it looks like that recentering was not necessary
01:07:34.480 | as written in the paper,
01:07:35.440 | because they say in this paper,
01:07:36.640 | we hypothesize that the rescaling invariance
01:07:39.760 | is the reason for the success of layer norm
01:07:42.560 | rather than the recentering invariance.
01:07:45.280 | So they just rescale the values
01:07:48.320 | according to the RMS statistic.
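For reference, here is a compact restatement of the two formulas discussed above, with a_i the features of one item, gamma and beta the learnable parameters, and epsilon a small constant for numerical stability:

```latex
% Layer normalization: two statistics per item, then recenter and rescale
\mu = \frac{1}{n}\sum_{i=1}^{n} a_i,\qquad
\sigma^2 = \frac{1}{n}\sum_{i=1}^{n} (a_i - \mu)^2,\qquad
\bar{a}_i = \frac{a_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\,\gamma_i + \beta_i

% RMS normalization: no mean, no recentering, only rescaling by the RMS statistic
\mathrm{RMS}(a) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2},\qquad
\bar{a}_i = \frac{a_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} a_j^2 + \epsilon}}\,\gamma_i
```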
01:07:50.560 | And this is what we will do in our code.
01:07:53.120 | So let's build this block.
01:07:57.120 | (keyboard clicking)
01:08:19.840 | So the eps value you can see here
01:08:22.160 | is used in the denominator.
01:08:27.600 | Let me go back here.
01:08:29.280 | It's added to the denominator here,
01:08:32.160 | to avoid a division by zero.
01:08:34.160 | And then we have the gamma parameter.
01:08:42.080 | (keyboard clicking)
01:08:52.640 | And this is it.
01:08:53.920 | Then we define the function norm.
01:08:56.480 | (keyboard clicking)
01:09:03.840 | Where x is batch sequence length dimension.
01:09:10.880 | Okay.
01:09:12.820 | So we return x multiplied by torch.rsqrt,
01:09:18.160 | where rsqrt stands for one over the square root.
01:09:25.040 | (keyboard clicking)
01:09:44.960 | And that's it.
01:09:45.760 | (keyboard clicking)
01:10:04.240 | We multiply by gamma.
01:10:05.520 | (keyboard clicking)
01:10:16.480 | So we have, as you can see,
01:10:18.720 | weight is actually a tensor of ones
01:10:20.480 | with the dimension dim.
01:10:22.480 | So (dim) multiplied by (B, seq_len, dim)
01:10:28.800 | results in (B, seq_len, dim),
01:10:32.720 | where B is the batch dimension.
01:10:35.120 | And here, what we are doing is,
01:10:36.960 | rsqrt is equal to one over sqrt of x,
01:10:43.440 | just as a reminder.
01:10:44.400 | And the dimensions here are (B, seq_len, dim) multiplied by (B, seq_len, 1),
01:10:54.080 | which results in (B, seq_len, dim).
01:10:58.880 | So what we are doing is exactly just this formula here.
01:11:03.360 | So just multiply it by one over rms
01:11:07.760 | and then multiply it with gamma here.
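As a reference for what was just typed, here is a minimal sketch of the RMSNorm module; the comments restate the shapes mentioned above, and the exact default for eps is an assumption:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps                                # added to the denominator to avoid division by zero
        self.weight = nn.Parameter(torch.ones(dim))   # the learnable gamma parameter

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt(y) = 1 / sqrt(y)
        # (B, seq_len, dim) * (B, seq_len, 1) -> (B, seq_len, dim)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (dim) * (B, seq_len, dim) -> (B, seq_len, dim)
        return self.weight * self._norm(x.float()).type_as(x)
```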
01:11:11.200 | Now that we have also built the rms norm,
01:11:15.760 | let's go check our next building block,
01:11:18.080 | which is this encoder block.
01:11:19.600 | So what is the encoder block?
01:11:20.880 | Let's go back to the transformer.
01:11:23.360 | Here we have the encoder block is all this block here
01:11:28.160 | that contains a normalization.
01:11:31.200 | It contains a self-attention here.
01:11:34.000 | It contains skip connections.
01:11:35.920 | You can see here another normalization,
01:11:38.560 | another skip connection and a feed forward layer here.
01:11:42.160 | I think the easiest one to start with is the feed forward,
01:11:49.040 | but okay,
01:11:52.160 | let's first build the encoder block,
01:11:54.800 | then we will build the attention,
01:11:57.760 | and finally the feed forward.
01:11:59.440 | So we first build the skeleton of this,
01:12:01.840 | then the attention and then this, let's go.
01:12:14.960 | (keyboard clicking)
01:12:38.960 | It receives some parameters.
01:12:41.280 | (keyboard clicking)
01:12:52.080 | The head dimension is the dimension of the embedding vector
01:12:55.360 | divided by the number of heads.
01:12:56.720 | So 4,096 divided by 32,
01:13:01.920 | because as we can see here,
01:13:04.960 | we have the dimension of the vector,
01:13:06.640 | of the embedding vector, is 4,096,
01:13:08.400 | but we have 32 heads.
01:13:09.600 | So each head will see 4,096 / 32 = 128 dimensions
01:13:13.680 | of each token.
01:13:14.960 | (keyboard clicking)
01:13:24.880 | Then we have a self-attention block.
01:13:27.120 | I define it, but don't build it right now.
01:13:29.520 | Just define the skeleton.
01:13:30.960 | Then we have the feed forward.
01:13:37.040 | (keyboard clicking)
01:13:39.760 | Then we have the normalization before the self-attention.
01:13:44.720 | So self-attention, this is our RMS norm.
01:13:52.160 | (keyboard clicking)
01:13:57.120 | And this is the motivation behind this argument, norm_eps.
01:14:00.560 | (keyboard clicking)
01:14:04.160 | Then we have another one, after the feed forward.
01:14:06.480 | (keyboard clicking)
01:14:12.800 | Is it after?
01:14:13.680 | It's after the attention, not after the feed forward.
01:14:17.680 | (keyboard clicking)
01:14:19.200 | So before the feed forward block.
01:14:21.760 | (keyboard clicking)
01:14:35.840 | And then we have norm_eps.
01:14:38.000 | (keyboard clicking)
01:14:39.600 | Okay, okay.
01:14:41.040 | Now let's implement the forward method.
01:14:43.440 | (keyboard clicking)
01:14:52.320 | start_pos indicates the position of the token.
01:14:54.960 | I kept the same variable name as in the original code.
01:14:58.240 | It's actually the position of the token.
01:14:59.920 | Because as you remember,
01:15:01.200 | we will be dealing with only one token at a time.
01:15:03.760 | So start_pos indicates the position of the token
01:15:06.720 | we are dealing with.
01:15:07.520 | (keyboard clicking)
01:15:11.200 | These are the pre-computed frequencies.
01:15:14.160 | (keyboard clicking)
01:15:25.280 | So we need to do the skip connection.
01:15:27.840 | And yeah, okay.
01:15:29.920 | The hidden is equal to x plus the attention.
01:15:36.800 | So we calculated the attention of what?
01:15:39.840 | Of the normalized version of this input.
01:15:43.040 | So we first apply the normalization.
01:15:45.440 | (keyboard clicking)
01:15:49.840 | And then we calculate this attention.
01:15:52.080 | And to the attention, we also give the frequencies.
01:15:55.360 | Because as you remember,
01:15:56.640 | the rotary positional encodings are kind of,
01:16:02.080 | they come into play when we calculate the attention.
01:16:04.720 | And these operations involve tensors of size
01:16:08.960 | (B, seq_len, dim):
01:16:10.320 | x, which is the skip connection,
01:16:12.880 | plus the output of the attention,
01:16:15.040 | which is (B, seq_len, dim),
01:16:17.920 | and this results in (B, seq_len, dim).
01:16:23.280 | Then we have another one:
01:16:26.880 | the application of the feedforward
01:16:28.720 | with its skip connection.
01:16:30.080 | So out is equal to h plus.
01:16:36.240 | (keyboard clicking)
01:16:40.400 | And before we send it to the feedforward,
01:16:42.320 | we apply the normalization.
01:16:45.920 | (keyboard clicking)
01:16:49.280 | And this is the output.
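Putting these pieces together, the skeleton of the encoder block looks roughly like this; it is a sketch, where SelfAttention and FeedForward are the modules built next, and args is assumed to carry dim, n_heads and norm_eps:

```python
import torch.nn as nn

class EncoderBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads   # e.g. 4096 / 32 = 128

        self.attention = SelfAttention(args)       # built later in the video
        self.feed_forward = FeedForward(args)      # built later in the video
        # normalization BEFORE the self-attention
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        # normalization AFTER the attention, BEFORE the feed-forward block
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, start_pos, freqs):
        # (B, seq_len, dim) + (B, seq_len, dim) -> (B, seq_len, dim)
        h = x + self.attention(self.attention_norm(x), start_pos, freqs)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
```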
01:16:50.320 | Now we need to build the self-attention and the feedforward.
01:16:54.240 | Let's start with the harder part first.
01:16:56.640 | So the self-attention,
01:16:57.840 | because I think it's more interesting.
01:16:59.760 | Before we build the self-attention,
01:17:01.920 | let's review how self-attention worked
01:17:04.480 | in the original transformer
01:17:06.000 | and how it will work here.
01:17:07.920 | So, okay, this is the original paper
01:17:11.840 | from the original paper of the transformer.
01:17:14.560 | So attention is all you need.
01:17:16.000 | Let's review the self-attention mechanism
01:17:17.840 | in the original transformer.
01:17:19.040 | And then we will see how it works in Llama.
01:17:21.120 | In the attention is all you need.
01:17:24.240 | We have an input, which is (seq, d_model).
01:17:26.560 | So a sequence of tokens,
01:17:27.920 | each token modeled by a vector of size d_model.
01:17:31.040 | We transform them into query key and values,
01:17:34.400 | which are the same input.
01:17:35.920 | We multiply by a W matrix,
01:17:38.560 | which is a parameter matrix,
01:17:40.080 | which results in a new matrix,
01:17:42.240 | which has this dimension of sequence by D model.
01:17:44.640 | And we then split them into the number of heads
01:17:49.280 | that we have,
01:17:50.400 | such that each vector that represents the token
01:17:54.000 | is split into, suppose we have four heads.
01:17:56.320 | So each vector, each head will see a part
01:18:00.400 | of the embedding of each token.
01:18:02.320 | So if the token was 512 in size, for example,
01:18:05.840 | the embedding vector,
01:18:07.280 | the first head will watch 128 dimensions of this vector.
01:18:12.880 | The second head will watch the next 128 dimensions.
01:18:16.880 | The next head will watch the next 128 dimensions,
01:18:19.760 | et cetera, et cetera, et cetera.
01:18:20.960 | We then calculate the attention
01:18:23.280 | between all these smaller matrices.
01:18:25.360 | So Q, K and V.
01:18:26.640 | This results in head 1, head 2, head 3 and head 4.
01:18:30.800 | We then concatenate them together.
01:18:33.120 | We multiply with the W matrix.
01:18:36.640 | And this is the output of the multi-head attention.
01:18:41.680 | In this case, it's called self-attention
01:18:44.320 | because it's the same input that acts as a query,
01:18:47.920 | as key and values.
01:18:49.680 | In case the query comes from one place
01:18:53.280 | and the key and the values come from another place,
01:18:55.360 | in that case, it's called cross-attention.
01:18:57.440 | And that kind of attention is used
01:18:59.200 | in multi-modal architectures, for example,
01:19:03.440 | when you want to combine, for example,
01:19:05.040 | pictures with captions or music with text,
01:19:09.280 | or you want to translate from one language to another.
01:19:11.840 | So you have kind of multi-modality
01:19:13.600 | and you want to connect the two together.
01:19:15.520 | But in our case, we are modeling a language.
01:19:18.560 | So self-attention is what we need.
01:19:21.920 | Actually, attention is all we need.
01:19:23.440 | So let's watch how it works in LLAMA.
01:19:28.480 | Okay, in LLAMA, we need to talk about a lot of things
01:19:32.880 | before we build the self-attention.
01:19:34.640 | We need to review how the self-attention works in LLAMA,
01:19:37.600 | how is the key, what is the KV cache,
01:19:39.840 | what is the grouped query attention,
01:19:41.680 | and actually how the inference works.
01:19:44.640 | So we need to review all this stuff
01:19:45.840 | before we proceed with the code.
01:19:47.040 | Otherwise, it will be very hard to follow the code.
01:19:50.000 | So let's first talk about the inferencing.
01:19:54.240 | Given, suppose we have a model,
01:19:57.440 | so that has been trained on this particular line.
01:20:01.520 | So the line is love that can quickly seize the gentle heart.
01:20:05.840 | And this is a line from Dante Alighieri.
01:20:09.120 | You can see, this is
01:20:11.120 | from the Inferno, Fifth Canto.
01:20:14.000 | It's not the first line actually,
01:20:15.920 | but this is Paolo and Francesca, by the way.
01:20:18.800 | And we have a model that has been trained on this line,
01:20:23.280 | love that can quickly seize the gentle heart.
01:20:25.120 | Now, a model that has been trained on this particular line
01:20:30.480 | using the next token prediction
01:20:32.080 | should have an input that is built in this way.
01:20:34.320 | So the start of sentence,
01:20:36.400 | and then the tokens that represented the sentence,
01:20:40.000 | then the target should be the same sentence
01:20:42.320 | with the end of sentence.
01:20:43.600 | Because the transformer is a sequence-to-sequence model,
01:20:47.440 | it maps one input sequence
01:20:49.280 | into an output sequence of the same size.
01:20:51.680 | This means that the first token
01:20:54.320 | will be mapped to the first token of the output.
01:20:56.720 | The second token of the input
01:20:58.080 | will be mapped to the second token of the output.
01:21:02.080 | The third token of the input
01:21:04.720 | will be mapped to the third token of the output.
01:21:08.560 | But it's not a one-to-one correspondence.
01:21:12.160 | Because of the mask, of the causal mask
01:21:14.880 | that we apply during the self-attention,
01:21:17.360 | to predict this particular token, "can", for example,
01:21:22.160 | the model doesn't only watch the corresponding token in the input,
01:21:26.480 | so "that", but also all the previous tokens.
01:21:29.520 | So the model, to predict "can",
01:21:32.480 | needs to access not only "that",
01:21:34.800 | but also SOS, "love", "that".
01:21:37.760 | And the self-attention mechanism with its causal mask
01:21:40.880 | will access all the previous tokens,
01:21:43.280 | but not the next ones.
01:21:44.640 | This means that when we do the inferencing,
01:21:47.600 | we should do it like this.
01:21:48.720 | We start with the start of sentence
01:21:50.320 | and the model will output the first word.
01:21:52.640 | To output the next token,
01:21:54.240 | we need to give the previous output token as input also.
01:21:58.000 | So we always append the last token of the output
01:22:02.080 | to the input to predict the successive tokens.
01:22:04.960 | So for example, to output that,
01:22:06.640 | we need to give SOS, love.
01:22:08.560 | To output the next token,
01:22:09.760 | we take this that and we put it in the input
01:22:13.440 | so that we can get the next word.
01:22:15.520 | To output the next token,
01:22:18.320 | we need to append the previous output to the input
01:22:21.440 | to get the new output quickly.
01:22:23.200 | Now, when we do this job,
01:22:26.000 | the model is actually,
01:22:28.480 | we are giving this input,
01:22:30.880 | which is a sequence of four tokens,
01:22:32.640 | and the model will produce a sequence of four tokens as output.
01:22:37.600 | But this is not really convenient when we do the inferencing,
01:22:40.720 | because the model is doing a lot of dot products
01:22:44.080 | that are not necessary,
01:22:46.240 | that have already been computed.
01:22:47.760 | For example, what I want to say is that
01:22:51.360 | in order to get this last token quickly,
01:22:53.760 | we need to access all the previous context here.
01:22:58.720 | But we don't need to output love that can
01:23:01.600 | because we don't care.
01:23:02.560 | We already have these tokens.
01:23:03.920 | We only care about the last one.
01:23:05.760 | However, we can't just tell the transformer model
01:23:08.640 | to not output the previous tokens.
01:23:10.640 | We need to change the calculations in such a way
01:23:13.520 | that we only receive at the output of the transformer
01:23:16.800 | only one token
01:23:17.840 | so that all the other tokens are not even calculated.
01:23:20.480 | And this will make the inferencing fast.
01:23:22.960 | And this is the job of the KVCache.
01:23:25.200 | Let me show you with some diagrams.
01:23:26.960 | As you can see, at every step of the token,
01:23:33.600 | we are only interested in the last token output by the model
01:23:36.720 | because we already have the previous ones.
01:23:38.640 | However, the model needs to access all the previous tokens
01:23:42.240 | to decide which token to output
01:23:44.800 | because the model needs to access all the prompt
01:23:48.320 | to output the next token.
01:23:50.400 | And we do this using the KVCache
01:23:52.480 | to reduce the amount of computation.
01:23:54.320 | So let's do with some examples.
01:23:57.040 | Suppose we do the same job that we did before.
01:23:59.680 | So the inferencing of that model.
01:24:02.560 | We give the first token, so the SOS.
01:24:04.800 | This will be multiplied.
01:24:06.560 | This is the self-attention.
01:24:07.600 | So it will be multiplied by the transposed of the keys.
01:24:10.400 | This will produce this matrix here, which is one by one.
01:24:13.200 | So you can check the dimensions.
01:24:14.480 | One by 4,096 multiplied by 4,096 by one
01:24:19.120 | will output a matrix that is one by one.
01:24:21.360 | This will be multiplied by the values
01:24:25.040 | and this will result in the output token.
01:24:27.120 | So this is the inferencing step one
01:24:29.200 | in which the only token we give is the start of sentence.
01:24:32.320 | Then we take this token output, the token at the output.
01:24:37.040 | This is actually not the token
01:24:39.840 | because this has to be mapped to the linear layer, etc, etc.
01:24:42.480 | But suppose this is already the token.
01:24:44.400 | And we append it to the input.
01:24:48.880 | So it becomes the second input of the input.
01:24:52.880 | So this is SOS and this is the last output.
01:24:55.920 | We multiply it by the transposed of the keys.
01:24:58.640 | We get this matrix here.
01:25:00.960 | We multiply it by the values
01:25:02.480 | and we get two output tokens as output
01:25:05.200 | because it's a sequence to sequence model.
01:25:09.280 | Then we append the output of the previous as the input
01:25:14.160 | and we multiply it by the transposed of the keys.
01:25:16.800 | We get this matrix here.
01:25:18.160 | We then multiply it by the Vs
01:25:19.600 | and we get three tokens as output.
01:25:21.440 | We then append the output of the last one at the Q.
01:25:25.280 | We multiply it by the transposed of the keys.
01:25:27.440 | We get this matrix here.
01:25:28.720 | We multiply it by the V and we get this sequence as output.
01:25:32.000 | But we see some problems.
01:25:33.600 | And the one that I told you before.
01:25:35.120 | We are doing a lot of computations that we don't need.
01:25:37.680 | First of all, these dot products that we are computing here
01:25:44.080 | because this is the self-attention.
01:25:45.520 | So Q multiplied by the transposed of the keys
01:25:48.320 | will result in a lot of dot products that result in this matrix.
01:25:51.600 | These dot products that you see here highlighted in violet
01:25:55.520 | have been already computed at the previous steps
01:25:57.920 | because we are at the step number four
01:25:59.760 | but these have already been computed at the previous step.
01:26:02.240 | Plus, not only they have been computed already,
01:26:05.120 | we don't need them because we only are interested in
01:26:08.720 | what the latest token that we added as the input,
01:26:11.680 | so to the prompt,
01:26:12.880 | what is this tokens dot product with all the other tokens
01:26:19.120 | because this tokens dot product with all the other tokens
01:26:22.080 | will result in the output of the last token,
01:26:25.920 | the one we are interested in.
01:26:27.760 | So if there is a way to not do all these computations again
01:26:33.120 | and also to not output all the previous tokens
01:26:36.160 | that we actually don't need
01:26:37.440 | because we always access the latest token,
01:26:40.320 | yes, we just use the KVCache.
01:26:42.720 | In the KVCache, what we do is
01:26:46.160 | we always take the last token and we use it as input.
01:26:52.080 | So we don't append it to the query.
01:26:54.160 | We just use it directly as query.
01:26:56.480 | But because the query needs to access all the previous tokens,
01:27:01.600 | we keep the keys and the values.
01:27:03.600 | So we append the last input to the keys and the values
01:27:07.280 | but we don't append it to the queries.
01:27:09.360 | It entirely replaces the queries.
01:27:12.800 | Let's see with an example.
01:27:14.800 | For example, this is our first step of inferencing.
01:27:18.320 | So this is just the start of sentence token.
01:27:21.280 | So we just have one token.
01:27:22.720 | We multiply it by the transpose of the keys.
01:27:25.520 | It will result in one by one.
01:27:26.960 | So we only have one token as output.
01:27:28.480 | This token, in the previous case, was appended to the queries.
01:27:33.520 | So in the next step, it became a matrix of dimension 2 by 4096.
01:27:39.680 | But in our case, at the time step 2, we don't append it.
01:27:43.360 | We only append it to the end of the keys and the values.
01:27:46.560 | And we only keep the queries here.
01:27:50.240 | If we do this product again now,
01:27:52.080 | we will see that this row here is the only one we are interested in.
01:27:55.520 | So the one that was not violet in the previous diagram.
01:27:58.720 | And if we do this dot product,
01:28:00.960 | it will result in only the last token, the one we are interested in.
01:28:04.560 | And every time we keep doing this job,
01:28:06.400 | we will see the key and the values grow.
01:28:10.160 | The queries will be always the last token.
01:28:14.880 | But the number of dot products that we are doing
01:28:17.920 | during the inferencing is much less.
01:28:19.680 | We don't need to do all those dot products that we did before.
01:28:22.560 | So compare this is time step 4.
01:28:24.400 | This is 4 dot products.
01:28:27.120 | Compare it with the previous time step 4.
01:28:29.360 | So here we have 16 dot products.
01:28:33.200 | So we reduce it by a factor of 4.
01:28:35.520 | And so that's why it's much faster to do inferencing with the KV cache.
01:28:42.160 | And let's review again.
01:28:44.480 | So here, as you can see, the matrix QK.
01:28:47.280 | And every time we add the token to the queue grows, right?
01:28:50.720 | But all these previous values with the KV cache, we are not computing it again.
01:28:54.640 | So that's why this is much faster.
01:28:56.240 | We only compute the one we need.
01:28:58.240 | And we only get one token as output.
01:29:01.600 | If this mechanism is not clear, please watch my previous video about LAMA,
01:29:04.880 | in which I describe it in much more detail and also with much more visualizations.
01:29:10.960 | Now let's go build this.
01:29:15.440 | There is another thing actually I want to show you before we go to build it,
01:29:18.800 | which is the grouped query attention.
01:29:21.040 | This one here.
01:29:23.120 | So I call it grouped multi-query attention,
01:29:26.400 | because it's actually the successive version of the multi-query attention.
01:29:30.240 | It's something in between the multi-head attention and the multi-query attention.
01:29:33.520 | But actually, the real name in the paper
01:29:36.880 | is grouped query attention.
01:29:38.560 | Now, the reason we introduced the grouped query attention:
01:29:44.400 | first of all, we had the multi-query attention.
01:29:47.680 | The multi-query attention was basically introduced to solve one problem.
01:29:51.840 | That is, we first had the multi-head attention.
01:29:55.920 | We introduced the KV cache with the multi-head attention.
01:29:59.920 | Just the one we just saw.
01:30:02.240 | The problem was that with the multi-head attention,
01:30:04.480 | we were doing too many dot products.
01:30:06.400 | With multi-head attention plus the KV cache, we do fewer dot products.
01:30:09.600 | This resulted in a lot less computation.
01:30:13.120 | But it also resulted in a new bottleneck for the algorithm.
01:30:16.560 | So the bottleneck was no longer the number of computations,
01:30:19.680 | but how many memory accesses we were performing to access these tensors.
01:30:24.560 | Because in the GPU,
01:30:26.480 | the GPU is much faster at doing computations
01:30:32.800 | than it is at moving tensors around in its memory.
01:30:36.320 | So when we optimize an algorithm,
01:30:38.480 | we not only need to consider how many operations we are doing,
01:30:42.560 | but also how many tensors we are accessing,
01:30:44.960 | and where are these tensors located.
01:30:46.880 | So it's not a good idea to keep copying tensor from one place to another,
01:30:51.280 | because the GPU is much slower at copying memory from one place to another
01:30:54.960 | than it is at computing operations.
01:30:57.040 | And this can be visualized on the datasheet of the GPU.
01:31:02.560 | You can see, for example, that computing operations
01:31:04.640 | is 19.5 Tera floating point operations per second.
01:31:08.640 | And while the memory bandwidth, so how fast it can move memory,
01:31:12.320 | is 40 times slower.
01:31:14.000 | So we need to optimize algorithms also for managing
01:31:20.480 | how many tensors we access and how we move them around the memory.
01:31:24.480 | This is why we introduced the multi-query attention.
01:31:28.000 | The multi-query attention basically means that
01:31:29.920 | we have many heads for the queries,
01:31:34.080 | but we only have one head for the key and the values.
01:31:38.000 | This resulted in a new algorithm that was much more efficient
01:31:42.960 | than the algorithm just with the KVCache.
01:31:47.120 | Because the KVCache, yeah, it reduced the number of dot products,
01:31:50.480 | but it had a new bottleneck, that is the number of memory access.
01:31:53.600 | With this algorithm, we also may optimize the memory access,
01:31:56.560 | but we lose some quality,
01:31:59.440 | because we are reducing the number of heads for the key and the values.
01:32:02.960 | So we are reducing the number of parameters in the model.
01:32:06.400 | And this way, because we are reducing
01:32:11.440 | the number of parameters involved in the attention mechanism,
01:32:14.240 | of course the model will degrade in quality.
01:32:17.680 | But we saw that practically it degraded the quality not so much.
01:32:22.560 | So actually the quality was not bad.
01:32:24.560 | And this was in this paper.
01:32:26.800 | So they show that the quality degradation was very little,
01:32:30.400 | so from 26.7 to 26.5,
01:32:33.920 | but the performance gains were very important.
01:32:37.200 | We went from 48 microseconds per token
01:32:41.440 | to 5 microseconds or 6 microseconds per token,
01:32:44.560 | so a lot faster.
01:32:45.600 | Now, let's introduce the grouped query attention
01:32:51.520 | or the grouped multi-query attention.
01:32:52.960 | In the multi-head attention,
01:32:55.360 | we had n heads for the queries,
01:32:59.520 | n heads for the keys and n heads for the values.
01:33:02.720 | In the multi-query attention, we have n heads for the queries,
01:33:07.520 | but only one head for the keys and the values.
01:33:10.000 | In the grouped multi-query attention or the grouped query attention,
01:33:13.920 | we have fewer heads for the keys and values.
01:33:20.720 | So every two heads for the queries, in this case, for example,
01:33:24.160 | we will have one head for the keys and the values.
01:33:28.960 | And this is a good balance between quality and speeds,
01:33:33.120 | because, of course, the fastest one is this one,
01:33:35.920 | because you have less heads.
01:33:37.280 | But, of course, the best one from a quality point of view is this one,
01:33:42.560 | but this is a good compromise between the two.
01:33:44.720 | So you don't lose quality, but at the same time,
01:33:47.520 | you also optimize the speed compared to the multi-head attention.
01:33:50.880 | So now that we have reviewed all this concept, let's go build it.
01:33:54.960 | So please, again, if you didn't understand
01:33:59.280 | this part very much in detail,
01:34:00.640 | it's better you go review my other video about Llama,
01:34:02.960 | in which I explain all this part much better.
01:34:04.960 | Otherwise, if I had to repeat the same content of the previous video,
01:34:08.960 | the current video would become 10 hours.
01:34:12.400 | So let's go build it.
01:34:21.680 | Okay, we need to save some things.
01:34:36.480 | Compared to the original code from Facebook, from Meta,
01:34:42.000 | I actually removed the parallelization.
01:34:44.320 | First of all, because I cannot test it.
01:34:46.080 | I don't have multiple GPUs.
01:34:47.520 | I don't have a very powerful GPU, actually.
01:34:49.520 | And so I simplified the code a lot.
01:35:03.280 | And KVHeads indicates the number of heads for the keys and the values,
01:35:17.200 | because they can be different than the number of heads for the queries.
01:35:21.680 | And this is why we also have an n_heads_q.
01:35:49.600 | This value here represents the ratio between the number of heads for the query
01:35:54.560 | and the number of heads for the keys and the values.
01:35:57.520 | We will use it later when we calculate the attention.
01:35:59.760 | So let me write some comments.
01:36:02.000 | So this is…
01:36:21.360 | So…
01:36:22.800 | And then we have a self.headDimension, which is…
01:36:49.120 | Which is…
01:36:54.720 | This indicates the part of the embedding that will be visualized by each head.
01:36:59.920 | Because, as you know, the embedding is split into multiple heads.
01:37:02.960 | So each head will watch the full sentence, but a part of the embedding of each word.
01:37:17.760 | Then we have the W matrices.
01:37:19.520 | WQ, WK, WV, and WO, just like in the normal vanilla transformer.
01:37:24.880 | And they don't have any bias.
01:37:41.680 | Oops, why did I write true?
01:38:11.520 | And then we create a cache.
01:38:30.800 | We will see later how it's used.
01:38:32.560 | I just now created one for the keys and one for the values.
01:38:39.600 | So…
01:38:50.560 | Okay, finally, we implement the forward method, which is the salient part here.
01:39:19.520 | So, self x is…
01:39:24.640 | To simplify the code for you, I will write the…
01:39:30.000 | For each operation, I will write the dimensions of the tensor that is involved in the operation,
01:39:35.040 | and also the resulting tensor from each operation.
01:39:38.000 | The start position indicates just the position of the token inside of the sentence.
01:39:48.000 | And these are the frequencies that we have computed.
01:39:50.640 | Okay, let's start by extracting the shapes:
01:40:02.400 | B, sequence length, and dimension.
01:40:16.320 | But the sequence length, we know it's one.
01:40:18.320 | So, dimension, yeah.
01:40:21.360 | Then what we do is, just like in the original transformer,
01:40:26.960 | we take the query, the key, and the values,
01:40:28.640 | and we multiply them by the WQ, WK, and WV matrices.
01:40:32.320 | So, xq is equal to self.wq.
01:40:39.680 | This means going from (B, 1, dim) to (B, 1, H_Q * head_dim).
01:40:49.120 | So, the number of heads for the query multiplied by the head dimension,
01:40:52.640 | because, in this case,
01:40:53.280 | this is actually equal to dim.
01:40:56.560 | So, the number of heads multiplied by the head dimension,
01:40:58.080 | as you can see from here.
01:41:05.600 | So, we are not changing the shape.
01:41:15.040 | In this case, however, we may change the shape of the keys,
01:41:28.560 | because the number of heads for the KV may be smaller than for Q.
01:41:34.080 | So, this matrix may have a last dimension that is smaller than xq's.
01:41:38.960 | And the same is true for xv.
01:41:42.640 | So, here, let me write some comment.
01:41:48.640 | Apply the WQ, WK, and WV matrix to queries, keys, and values,
01:42:02.320 | which are the same, because it's a self-attention.
01:42:04.320 | So, the query, key, and value is always x.
01:42:07.440 | We then divide them into their corresponding number of heads.
01:42:13.840 | So, xq is equal to xq.view.
01:42:18.480 | Batch size, we keep it like this.
01:42:21.200 | Sequence length is one.
01:42:30.000 | So, we reshape (B, 1, H_Q * head_dim) into (B, 1, H_Q, head_dim).
01:42:46.160 | So, we divide them into the H_Q heads for the query.
01:42:52.720 | And then we do the same for the keys and the values.
01:43:00.240 | And the same for the v.
01:43:49.360 | Now, okay, we have the x input,
01:43:54.000 | and we multiply it by WQ, WK, and WV.
01:43:57.280 | Let's go check the code here.
01:43:59.040 | As you remember, we take the input, we multiply it by WQ, WK, and WV.
01:44:04.560 | This will result in these matrices here.
01:44:06.880 | We then divide them into the number of heads.
01:44:09.440 | But in the case of grouped query attention, they may be different.
01:44:12.160 | So, this may be four heads, and this may be two heads, and this may be two heads.
01:44:16.800 | So, they are not the same number.
01:44:18.960 | The next thing we are going to do, and you can see it here,
01:44:24.400 | we need to apply the rotary positional encodings to the query
01:44:30.320 | and the keys, but not the values.
01:44:32.240 | Let's do it.
01:44:37.680 | And this is how we apply the positional encodings.
01:44:49.440 | This will not change the size of the vectors.
01:44:51.680 | You can see that here.
01:44:53.760 | Because at the end, we have the same shape as the original input vector.
01:45:11.360 | Okay.
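To summarize the forward method so far, here is a sketch of the projection, head split, and rotary step; apply_rope stands in for the rotary-embedding helper built earlier in the video, so its exact name and signature are assumptions:

```python
import torch
import torch.nn as nn

def project_and_split(x, wq, wk, wv, n_heads_q, n_kv_heads, head_dim, apply_rope, freqs):
    """x: (B, 1, dim). wq/wk/wv: nn.Linear layers without bias.
    apply_rope: the rotary-embedding helper built earlier (assumed signature)."""
    batch_size, seq_len, _ = x.shape                           # seq_len is 1 during inference

    xq = wq(x)                                                 # (B, 1, H_Q * head_dim)
    xk = wk(x)                                                 # (B, 1, H_KV * head_dim) -- may be smaller than xq
    xv = wv(x)                                                 # (B, 1, H_KV * head_dim)

    # split into heads
    xq = xq.view(batch_size, seq_len, n_heads_q, head_dim)     # (B, 1, H_Q, head_dim)
    xk = xk.view(batch_size, seq_len, n_kv_heads, head_dim)    # (B, 1, H_KV, head_dim)
    xv = xv.view(batch_size, seq_len, n_kv_heads, head_dim)    # (B, 1, H_KV, head_dim)

    # rotary positional encodings are applied to queries and keys only (shapes unchanged)
    xq = apply_rope(xq, freqs)
    xk = apply_rope(xk, freqs)
    return xq, xk, xv
```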
01:45:26.480 | Now, now comes the KVCache part.
01:45:30.960 | Let's watch again the slides.
01:45:32.640 | As we can see here, every time we have an input-output token,
01:45:40.160 | so for example the "attention 2" here, suppose it is token number 2,
01:45:44.320 | we append it at the end of the keys and the values.
01:45:48.320 | And this is exactly what we are going to do.
01:45:51.200 | So, what we do here, we keep a cache of the keys and the values,
01:45:56.800 | because they will be used for the next iterations.
01:45:59.840 | Because at every iteration, in X, we only receive the latest token
01:46:04.640 | that was output from the previous iteration.
01:46:07.840 | We append it to the K and the V, and then we compute the attention
01:46:14.080 | between all the K, all the V, but only the single token as query.
01:46:18.800 | So, let's do it.
01:46:22.320 | So, first, replace.
01:46:36.640 | This is the position of the token.
01:46:38.080 | This should be 1, because sequence length is actually 1, always.
01:46:44.080 | But I try to keep this code the same as the one from Lama, from Meta.
01:46:51.840 | Basically, it means that we have one token for every batch,
01:47:01.920 | I mean, one token from each batch, and we replace them,
01:47:05.600 | because we can process multiple batches.
01:47:10.720 | So, we replace the entry at this particular position for every batch.
01:47:15.520 | Okay.
01:47:25.860 | Now, we replace it only for this position here.
01:47:31.360 | But when we compute the attention using the KVCache, let's go watch again,
01:47:36.320 | we need to calculate the dot product between the only one token but all the keys.
01:47:46.960 | And then we will need to multiply with all the values,
01:47:49.840 | and this will result in only one token as output.
01:47:52.320 | So, we need to extract from this cache all the tokens as keys
01:47:56.800 | and all the tokens as values up to this position here.
01:48:01.200 | The one we are passing.
01:48:02.480 | So, keys is equal to all the cached keys,
01:48:21.040 | starting from 0 up to start_pos plus sequence length,
01:48:27.840 | and the same for the values.
01:48:46.800 | Now, what happens is that, let me write also some sizes here.
01:48:57.600 | We have b, sequenceLength of K and V,
01:49:01.760 | because the sequenceLength of the input is always 1, we know that.
01:49:05.280 | But the sequenceLength of the cache means all the cached keys and values,
01:49:14.000 | which are up to startPosition.
01:49:16.320 | So, this sequenceLength is actually equal to startPosition.
01:49:21.760 | And actually, startPosition plus 1.
01:49:24.800 | The next dimension is the number of heads for the K and V,
01:49:32.880 | and then the dimension of each head.
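Here is a small, self-contained sketch of the cache update and retrieval just described, with toy sizes instead of the real LLaMA ones:

```python
import torch

# toy sizes for illustration (not the real LLaMA sizes)
max_batch, max_seq_len, n_kv_heads, head_dim = 2, 8, 4, 16
cache_k = torch.zeros(max_batch, max_seq_len, n_kv_heads, head_dim)
cache_v = torch.zeros(max_batch, max_seq_len, n_kv_heads, head_dim)

def update_and_read(cache_k, cache_v, xk, xv, start_pos):
    batch_size, seq_len = xk.shape[0], xk.shape[1]        # seq_len is 1 during inference
    # write the new token's keys/values at its position, for every batch
    cache_k[:batch_size, start_pos:start_pos + seq_len] = xk
    cache_v[:batch_size, start_pos:start_pos + seq_len] = xv
    # read back ALL the cached keys/values up to and including this position
    keys = cache_k[:batch_size, :start_pos + seq_len]     # (B, seq_len_kv, H_KV, head_dim)
    values = cache_v[:batch_size, :start_pos + seq_len]   # (B, seq_len_kv, H_KV, head_dim)
    return keys, values

# usage: at step start_pos=3 we pass in only the single new token's K and V
xk = torch.randn(2, 1, n_kv_heads, head_dim)
xv = torch.randn(2, 1, n_kv_heads, head_dim)
keys, values = update_and_read(cache_k, cache_v, xk, xv, start_pos=3)
print(keys.shape)   # torch.Size([2, 4, 4, 16]) -> 4 cached positions
```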
01:49:34.640 | Now, the number of heads for the keys and values
01:49:41.200 | may not correspond to the number of heads of the queries.
01:49:46.400 | So, how do we compute?
01:49:48.000 | In the original code from Lama, what they did was basically,
01:49:51.120 | let's go check the code for here.
01:49:54.640 | So, in the grouped query, attention,
01:49:58.720 | we have that the number of heads for the keys and the values
01:50:01.840 | is not the same as the number of heads for the queries.
01:50:04.240 | So, there are two ways.
01:50:06.080 | One is to make an optimized algorithm
01:50:08.080 | that actually takes this into consideration.
01:50:10.160 | The other way is to just copy this single head into multiple heads,
01:50:19.040 | such that we arrive to this situation here,
01:50:21.520 | and then we just compute it just like a multi-head.
01:50:24.320 | This is not an optimized solution, but it's the one used by the code by Lama.
01:50:28.720 | And it's also the one I will be sticking to,
01:50:31.680 | because I don't have any way of testing other codes,
01:50:35.760 | because the only model that supports the grouped query attention
01:50:40.480 | is the biggest one from Lama, so with 70 billion parameters,
01:50:44.240 | but my computer will never be able to load that model.
01:50:47.600 | And so, I don't have any way of testing it,
01:50:49.600 | so that's why I also didn't optimize the code
01:50:52.240 | for actually computing the grouped query attention,
01:50:54.160 | but I will just replicate this single head multiple times,
01:50:57.680 | such that we arrive to this situation here.
01:51:01.520 | So, I will also repeat.
01:51:16.320 | Okay, this function here, repeat_kv,
01:51:38.000 | just repeats the keys (or the values)
01:51:43.280 | for this number of times, so n_rep.
01:51:45.920 | What is this?
01:51:46.720 | It's the ratio of the number of heads of the queries
01:51:49.600 | by the number of heads of the keys.
01:51:51.600 | So, if the number of heads of the keys is four,
01:51:55.040 | and the number of heads for the queries is eight,
01:51:57.200 | that means we need to repeat twice each head.
01:51:59.840 | So, let's build also this method, since we are here.
01:52:17.200 | Okay, if n_rep is 1, we don't need to repeat it,
01:52:31.040 | so there is only one repetition.
01:52:32.480 | We just return the tensor as it is.
01:52:36.080 | Otherwise, we repeat it n times.
01:52:37.680 | So, the first thing we do is we add a new dimension,
01:52:57.360 | and we can do it like this:
01:52:58.640 | batch, sequence length, number of heads, then None,
01:53:08.880 | and this will add the new dimension in this position.
01:53:14.640 | Then we expand it.
01:53:33.360 | then we reshape it.
01:53:46.880 | Basically, we introduce a new dimension.
01:53:51.760 | We repeat the whole tensor
01:53:56.240 | along this new dimension n_rep number of times,
01:54:00.560 | and then we just flatten it.
01:54:02.640 | So, we remove this dimension again.
01:54:07.840 | And this is how we repeat the keys and also the values.
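As a reference, here is a sketch of the repeat_kv helper as just described: add a dimension, expand, flatten.

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: (B, seq_len_kv, n_kv_heads, head_dim)
    batch_size, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x  # nothing to repeat
    return (
        # add a new dimension: (B, seq_len_kv, n_kv_heads, 1, head_dim)
        x[:, :, :, None, :]
        # repeat each KV head n_rep times along the new dimension
        .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
        # flatten it back: (B, seq_len_kv, n_kv_heads * n_rep, head_dim)
        .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
    )

# e.g. 4 KV heads repeated twice become 8 heads, matching 8 query heads
x = torch.randn(2, 5, 4, 16)
print(repeat_kv(x, 2).shape)   # torch.Size([2, 5, 8, 16])
```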
01:54:14.240 | Now we can repeat.
01:54:23.600 | Okay, now we just proceed
01:54:29.120 | with the standard calculation for the multi-head attention.
01:54:34.400 | That is, we first move the head dimension before the sequence dimension,
01:54:42.080 | because each head will watch all the sequence,
01:54:44.720 | but a part of the embedding of each token.
01:54:47.440 | So, what we are doing is going from (B, 1, H_Q, head_dim),
01:54:55.680 | because 1 is the sequence length of the queries,
01:54:58.000 | H_Q is the number of heads of the queries, and then head dimension,
01:55:05.680 | to (B, H_Q, 1, head_dim), so batch, heads, sequence length, head dimension.
01:55:10.720 | We do the same for the keys and the values.
01:55:22.000 | Then we do the standard formula for queries multiplied by the transpose of the keys,
01:55:34.160 | divided by the square root of the dimension of each head.
01:55:37.040 | So, xq, so the queries multiplied by the transpose of the keys,
01:55:48.960 | all of this divided by the square root of the dimension of each head.
01:55:57.200 | Then we apply the softmax,
01:56:03.360 | and this one will result in a shape of
01:56:15.520 | (B, H_Q, 1, seq_len_kv).
01:56:34.720 | The softmax doesn't change the dimension.
01:56:38.800 | Then we multiply it by the values.
01:56:44.960 | So, the formula is queries multiplied by the transpose of the keys,
01:56:50.960 | and then we do the softmax, then the output is multiplied by the values.
01:57:00.400 | So, this will result in (B, H_Q, 1, head_dim),
01:57:28.160 | and then we multiply it by the output matrix,
01:57:32.800 | but before we remove all the heads, so we concatenate again.
01:57:36.160 | This is what we did also here.
01:57:39.040 | So, here we take the output of all the heads,
01:57:42.000 | then we concatenate them together, and then we multiply it by the wo matrix.
01:58:00.480 | and this will result in (B, 1, dim):
01:58:17.200 | (B, H_Q, 1, head_dim) becomes (B, 1, H_Q, head_dim)
01:58:34.720 | because of the transposition,
01:58:38.480 | and then we remove the dimension for the heads, so (B, 1, dim).
01:58:47.200 | And this is our self-attention with kvcache.
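And here is a sketch of that last part of the attention computation, written as a standalone function so the shapes are explicit; wo stands in for the output projection WO, and no causal mask is needed because the query is a single token:

```python
import math
import torch
import torch.nn.functional as F

def attention_with_kv(xq, keys, values, wo):
    """xq: (B, 1, H_Q, head_dim).
    keys/values (already passed through repeat_kv): (B, seq_len_kv, H_Q, head_dim).
    wo: an nn.Linear(H_Q * head_dim, dim, bias=False)."""
    batch_size, _, n_heads_q, head_dim = xq.shape

    # move the head dimension before the sequence dimension
    xq = xq.transpose(1, 2)           # (B, H_Q, 1, head_dim)
    keys = keys.transpose(1, 2)       # (B, H_Q, seq_len_kv, head_dim)
    values = values.transpose(1, 2)   # (B, H_Q, seq_len_kv, head_dim)

    # (B, H_Q, 1, head_dim) @ (B, H_Q, head_dim, seq_len_kv) -> (B, H_Q, 1, seq_len_kv)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)

    # (B, H_Q, 1, seq_len_kv) @ (B, H_Q, seq_len_kv, head_dim) -> (B, H_Q, 1, head_dim)
    out = torch.matmul(scores, values)

    # concatenate the heads: (B, 1, H_Q * head_dim) == (B, 1, dim), then multiply by WO
    out = out.transpose(1, 2).contiguous().view(batch_size, 1, -1)
    return wo(out)                    # (B, 1, dim)
```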
01:58:51.280 | So, let's review what we have done.
01:58:52.720 | Here, I think I made some mistake, because self, that's why it's colored differently.
01:59:00.240 | Okay, let's review what we have done.
01:59:02.080 | When we calculated the self-attention, because we are inferencing,
01:59:05.760 | so this code will only work for inferencing, we can use the kvcache.
01:59:09.840 | The kvcache allow us to save a number of dot products that we don't need.
01:59:15.120 | Why? Because in the original transformer,
01:59:18.880 | at every step we were computing a lot of dot products
01:59:22.480 | for output tokens that we don't care about.
01:59:24.800 | In this case, we simplified the mechanism to output only one token.
01:59:29.600 | As you can see, the output of the self-attention is b, so batch,
01:59:33.440 | one token only with its embedding size, which is 4096.
01:59:40.160 | So, we are only outputting one token, not many tokens.
01:59:44.880 | We input only one token, and we output one token.
01:59:47.680 | But because we need to relate that single token with all the previous tokens,
01:59:53.040 | we keep a cache of the keys and the values.
01:59:55.680 | Every time we have a token, we put it into the cache, like here,
02:00:00.400 | then we retrieve all the previous saved tokens from the cache,
02:00:04.480 | and then we calculate the attention between all the previous tokens,
02:00:07.600 | so the keys and the values, and the single token as input of, as queries.
02:00:12.960 | The output is the only token we care about.
02:00:16.480 | This is the idea behind kvcache.
02:00:20.880 | And the grouped query attention is the fact that we have a different number of heads
02:00:26.240 | for the keys and values than for the queries.
02:00:29.680 | In our case we do have a different number of heads for the keys and queries,
02:00:37.680 | but we just repeat the heads that we are missing to calculate the attention.
02:00:42.080 | So the attention is calculated just like the previous transformer,
02:00:46.480 | like a normal multi-head attention,
02:00:48.560 | but by repeating the missing keys and values heads,
02:00:52.880 | instead of actually optimizing the algorithm.
02:00:55.360 | This has also been done by Meta in its official implementation,
02:00:59.920 | and I also did it here.
02:01:01.040 | The biggest reason is because I cannot test any other modification.
02:01:05.120 | I cannot test another algorithm that actually tries to optimize this calculation.
02:01:10.800 | So if I find another implementation that I know is working,
02:01:15.440 | I will share it with you guys.
02:01:16.640 | Otherwise, I will try to run it on Colab and see if I can come up with a better solution.
02:01:21.840 | But for now, we just repeat it.
02:01:23.440 | But at least we got the concept of the grouped query attention.
02:01:27.200 | That is, we have less number of heads,
02:01:29.440 | and it's something that is in between the multi-query attention and the multi-head attention.
02:01:34.880 | That doesn't sacrifice quality, but improves speed.
02:01:38.080 | Now, the last thing that we didn't implement is the feedforward layer.
02:01:43.120 | For the feedforward layer, the only thing that we need to review
02:01:45.680 | is the SwiGLU activation function that we can see here.
02:01:48.320 | And this activation function has been changed compared to the previous
02:01:53.360 | activation function used in the vanilla transformer, which was the RELU function.
02:01:57.680 | And the only reason we replaced it is because this one performs better.
02:02:02.480 | And as I showed in my previous video, we cannot prove why it works better.
02:02:08.400 | Because in such a big model with 70 billion parameters,
02:02:11.520 | it's difficult to explain why a little modification works better than another.
02:02:15.840 | We just know that some things work better in practice for that kind of model
02:02:19.440 | or for that kind of application.
02:02:21.040 | And this is actually not my opinion.
02:02:23.120 | This is actually written in the paper.
02:02:24.480 | So as you can see here in the conclusion of the paper,
02:02:27.520 | they say that we offer no explanation as to why this architecture seems to work.
02:02:31.280 | We attribute their success as all else to divine benevolence.
02:02:34.720 | So it means that when you have such a big model
02:02:36.960 | and you change a little thing and it works better,
02:02:39.120 | you cannot always come up with a pattern to describe why it is working better.
02:02:43.120 | You just take it for granted that it works better
02:02:46.560 | and you use it because it works better in practice.
02:02:49.520 | So to implement the SwiGLU function, we need to apply...
02:02:53.760 | This is the formula from the original transformer.
02:02:57.040 | So we have two matrices here.
02:03:00.880 | So this is the RELU function of the first linear layer and the second linear layer.
02:03:05.520 | In LLaMA, we use the SwiGLU function, which involves the three matrices here.
02:03:10.800 | Because this increments the number of parameters here,
02:03:13.600 | and also because they were experimenting with the grouped query attention,
02:03:17.840 | the architecture of LLaMA has some more hyperparameters
02:03:23.760 | to adjust the number of parameters of this feedforward layer,
02:03:27.920 | so that it respects some constraints.
02:03:30.720 | And this is actually used in deep learning research.
02:03:35.360 | Whenever we modify the transformer model
02:03:37.360 | and this reduces the number of parameters or increases the number of parameters,
02:03:41.280 | the first thing the researchers do,
02:03:42.880 | they adjust the numbers of parameters of the feedforward layer
02:03:46.720 | so that when they make comparison between two models,
02:03:49.760 | they have the same number of parameters.
02:03:51.920 | So I will also, of course, use the same structure
02:03:56.160 | because otherwise I cannot load the weight from the pre-trained model.
02:04:00.560 | So let's do it.
02:04:02.240 | The hidden size is calculated like this.
02:04:06.240 | So four times the dimension.
02:04:07.680 | Then they take two-thirds of this dimension.
02:04:10.880 | And then they also have a multiplier if it's specified.
02:04:21.760 | Then they say round the hidden... oops.
02:04:36.800 | By using this modification to calculating the hidden dimension like this,
02:04:41.600 | it may not be the case that this hidden dimension is a multiple of this number here.
02:05:03.440 | So maybe they want the size of the hidden layer to be a multiple of 256.
02:05:08.880 | But by calculating it like this, it may not be.
02:05:11.680 | So what they do is they make it round up to the next multiple of the multiple of parameter.
02:05:30.960 | So this is a way to do it.
02:05:36.720 | Okay, let me give you an example.
02:05:38.640 | It's easier to show with an example than to actually write it.
02:05:42.400 | So suppose you have the hidden size is equal to, let's say, 7.
02:05:48.800 | But you want it to multiple of is equal to 5.
02:05:52.240 | So you want the hidden size to be a multiple of 5.
02:05:54.800 | So how do we do?
02:05:56.240 | Well, what we do is, basically, we do hidden plus 4 in this case.
02:06:01.840 | So we do 7 plus 4, which is 11.
02:06:04.880 | We divide it by 5, which is equal to 2.
02:06:10.800 | And then we multiply this 2 by 5.
02:06:15.200 | So it will result in 2 by 5 is equal to 10.
02:06:18.320 | It will result in the first multiple that is bigger or equal to this number here.
02:06:23.280 | That's the idea.
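As a sketch, the rounding-up trick and the full hidden-dimension computation can be written like this; the exact way LLaMA reads the optional multiplier from the config (here called ffn_dim_multiplier) is an assumption:

```python
def round_up_to_multiple(hidden_size: int, multiple_of: int) -> int:
    # e.g. hidden_size=7, multiple_of=5: (7 + 4) // 5 = 2, then 2 * 5 = 10
    return multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)

def ffn_hidden_dim(dim: int, multiple_of: int = 256, ffn_dim_multiplier=None) -> int:
    hidden = 4 * dim
    hidden = int(2 * hidden / 3)               # two-thirds of 4 * dim
    if ffn_dim_multiplier is not None:         # optional multiplier from the config
        hidden = int(ffn_dim_multiplier * hidden)
    return round_up_to_multiple(hidden, multiple_of)

print(round_up_to_multiple(7, 5))    # 10
print(ffn_hidden_dim(4096))          # 11008
```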
02:06:25.680 | And then we have these matrices for the SwiGLU function.
02:06:30.720 | It's very easy.
02:06:31.440 | We just follow the formula for the SwiGLU function, which is here.
02:06:38.080 | So W1, then the Swish of... what is Swish?
02:06:41.920 | The Swish is the SiLU function,
02:06:45.520 | because the Swish with beta equal to 1 is actually the SiLU function,
02:06:50.480 | which has this graph here.
02:06:52.800 | And then we multiply it with another parameter matrix here.
02:06:56.800 | And then we apply it to another linear layer, w2.
02:06:59.920 | So in total, we have three matrices: w1, w2, and w3.
02:07:13.920 | And they don't have bias.
02:07:18.880 | Oops.
02:07:19.920 | This is the hidden dimension.
02:07:31.040 | Okay, now we implement the forward method.
02:07:50.400 | The first thing we do is we calculate the Swish function.
02:08:00.400 | Then we calculate, so we are calculating, let me show you.
02:08:12.320 | We are calculating this one: the Swish of x times W1.
02:08:18.800 | Then we calculate this x times V.
02:08:20.640 | Then we multiply them together, just like in the formula.
02:08:32.000 | So the Swish multiplied by xV.
02:08:34.720 | And then we apply the last linear layer, which is w2.
02:08:40.240 | Which results in a multiplication by the w2 matrix, by the way.
02:08:47.840 | And then we return x.
02:08:48.800 | And this is the feed forward layer.
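For reference, a minimal sketch of the feed-forward block with SwiGLU; the constructor arguments are simplified here, with hidden_dim being the value computed above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # three matrices, no bias, as in the SwiGLU formulation
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        swish = F.silu(self.w1(x))    # Swish with beta=1 is the SiLU function
        x_v = self.w3(x)              # the extra "gate" projection
        return self.w2(swish * x_v)   # elementwise product, then the last linear layer
```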
02:08:53.200 | Now that we have all the building blocks, we need to go to the inferencing.
02:09:00.560 | Let's start building the inference code.
02:09:04.000 | So, inference.py: first we will build the code to load the model,
02:09:10.960 | and then we will build the code to inference the model.
02:09:14.640 | I will actually also show all the inference techniques that are out there,
02:09:18.640 | and which one we will apply and why.
02:09:20.560 | So let's start by building first the code for loading the model.
02:09:25.440 | So first we import the stuff we need.
02:09:42.240 | We need the JSON to load the parameters.
02:09:44.640 | And then we need the sentence piece to load the tokenizer.
02:09:49.040 | Because the sentence piece is the tokenizer that has been used,
02:09:52.320 | and it's a library from Google.
02:09:53.600 | Okay.
02:10:00.100 | From model we import ModelArgs and the Transformer class.
02:10:05.920 | We define the class Lama, which is our model.
02:10:11.840 | It takes a transformer, a tokenizer, which is a sentence piece processor.
02:10:18.640 | And then the model arguments.
02:10:32.640 | Oops.
02:10:39.060 | Oops.
02:10:39.560 | Args.
02:10:44.120 | Model-args.
02:10:46.980 | Yeah.
02:10:47.700 | Model-args.
02:10:48.200 | Okay.
02:10:50.680 | Now we build a static method.
02:10:54.740 | Static method.
02:10:56.020 | And we call it build, just like in the original code from Lama.
02:10:59.060 | In which we pass the directory where the checkpoints are saved.
02:11:05.540 | In this case, the directory name is llama-2-7b, in my case.
02:11:08.740 | But it depends on which size of the model you have downloaded.
02:11:12.260 | Then the tokenizer path, which is the path to the tokenizer.
02:11:17.460 | This is the file of the tokenizer that I downloaded.
02:11:19.780 | Then we have a load_model flag, the max sequence length,
02:11:32.260 | the max batch size,
02:11:33.620 | and we have the device.
02:11:37.460 | Okay.
02:11:42.100 | This is only for displaying how much time it takes to load the model.
02:11:45.620 | If we want to load the model, we will also load the checkpoints.
02:11:53.140 | So checkpoints is equal to sorted.
02:11:59.780 | The glob method allows you to find all the files that match this filter.
02:12:13.780 | Okay, we see that we are loading checkpoint this one.
02:12:35.700 | And then we actually load it.
02:12:43.540 | And we save it on the CPU.
02:12:53.140 | We can show how much time it takes to load the model.
02:13:01.060 | In my computer, usually it takes 10 to 20 seconds.
02:13:13.140 | Then we overwrite the previous time variable.
02:13:23.620 | So we can also show how much time it takes to load all the parameters of the model.
02:13:29.380 | Then we load the parameters, so the JSON file.
02:13:39.460 | We read it, open it as read-only file.
02:13:54.500 | And okay, then we build the arguments.
02:14:04.580 | Maximum sequence length is the one we have specified.
02:14:12.580 | And then the max batch size is the max batch size we have specified.
02:14:24.180 | The device is the one we have specified.
02:14:27.060 | And then all the parameters loaded from the JSON file.
02:14:30.820 | Then we loaded the tokenizer.
02:14:35.620 | Then we, by using the tokenizer, we can populate the vocab size of the model args.
02:14:50.740 | The vocabulary size is actually the number of tokens inside the tokenizer.
02:14:55.060 | Now, this is also the default tensor type for PyTorch.
02:15:12.500 | So whenever PyTorch wants to create a new tensor, this defines what kind of type it should use.
02:15:17.140 | This is defined by Meta: for CUDA they want to use this type that I show you here,
02:15:22.980 | the default tensor type torch.cuda.HalfTensor.
02:15:28.980 | This changes the precision that the tensor supports.
02:15:34.100 | So how much space it occupies in memory.
02:15:39.060 | Otherwise, the BFloat16 tensor. Then we create the actual model.
02:15:52.020 | Okay, when we load a checkpoint,
02:16:07.620 | actually the checkpoint is a list of key and values.
02:16:12.020 | Each key is a matrix in the model.
02:16:15.540 | So the weight, for example, of a linear layer,
02:16:17.780 | or the bias of a linear layer, or something like this.
02:16:20.420 | And the names that we have used for the variables and the matrices here,
02:16:29.460 | for example wq and wk, actually match the names that are present in the checkpoint,
02:16:35.140 | except for one name.
02:16:36.740 | So to make sure that I have used the right names,
02:16:39.300 | I will load the checkpoint with strict equal true.
02:16:45.140 | Strict equal true means that if there is at least one name that doesn't match,
02:16:49.300 | it will throw an error.
02:16:50.340 | So: if load_model, then model.load_state_dict(checkpoint, strict=True).
02:17:03.460 | So if there is at least one name in the loaded file that doesn't match the name
02:17:08.340 | in the classes that I have created here in the model, it will throw an error.
02:17:11.940 | But I know that there is one key that we don't need,
02:17:15.620 | which are the frequencies for the rotary positional embeddings,
02:17:19.540 | which we actually are computing every time we create the tensor.
02:17:23.140 | So we are creating them here by using this function.
02:17:26.660 | So we don't need to load them from the model.
02:17:29.540 | So we can remove it from the model, from the checkpoint.
02:17:33.780 | So because the checkpoint is a dictionary, we can just remove this.
02:17:37.460 | It's called rope.freqs.
02:17:40.100 | And then we can print how much time it took to load the model.
02:17:57.860 | And then we return llama.
02:18:02.820 | Model tokenizer.
02:18:06.260 | And model args.
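Here is a condensed sketch of the loading code just described. It assumes the checkpoint folder contains a params.json file and a *.pth checkpoint as in the official download, and that ModelArgs accepts the fields from params.json; the class and method names mirror the ones used in the video.

```python
import json
import time
from pathlib import Path

import torch
from sentencepiece import SentencePieceProcessor

from model import ModelArgs, Transformer   # the classes built above

class LLaMA:
    def __init__(self, model, tokenizer, model_args):
        self.model = model
        self.tokenizer = tokenizer
        self.args = model_args

    @staticmethod
    def build(checkpoints_dir, tokenizer_path, load_model, max_seq_len, max_batch_size, device):
        prev_time = time.time()
        checkpoint = None
        if load_model:
            checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
            checkpoint = torch.load(checkpoints[0], map_location="cpu")
            print(f"Loaded checkpoint in {time.time() - prev_time:.2f}s")
            prev_time = time.time()

        with open(Path(checkpoints_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args = ModelArgs(max_seq_len=max_seq_len, max_batch_size=max_batch_size,
                               device=device, **params)

        tokenizer = SentencePieceProcessor()
        tokenizer.load(tokenizer_path)
        model_args.vocab_size = tokenizer.vocab_size()

        # half precision on CUDA, bfloat16 on CPU (as discussed above)
        if device == "cuda":
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        else:
            torch.set_default_tensor_type(torch.BFloat16Tensor)

        model = Transformer(model_args).to(device)
        if load_model:
            del checkpoint["rope.freqs"]   # we recompute the rotary frequencies ourselves
            model.load_state_dict(checkpoint, strict=True)
            print(f"Loaded state dict in {time.time() - prev_time:.2f}s")

        return LLaMA(model, tokenizer, model_args)
```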
02:18:09.380 | Now, before we proceed further, let me test if the model can be successfully loaded.
02:18:15.140 | So let's do it.
02:18:18.660 | If name...
02:18:24.340 | First, I will set the manual seed to zero.
02:18:27.540 | So later we use it for inferencing.
02:18:30.340 | Then I don't want to use CUDA because my GPU doesn't support it.
02:18:35.540 | So I say allow_cuda = False.
02:18:37.220 | Then device is equal to "cuda" if torch.cuda.is_available() and allow_cuda, else "cpu".
02:18:47.620 | Next time if you want to load the model with CUDA, just set this variable to true.
02:18:52.900 | But in my case, I will always leave it to false because I don't want to load CUDA.
02:19:06.500 | Sequence length, I set it to 1024.
02:19:25.060 | Max batch size, let's say 3.
02:19:31.300 | And device now.
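As a sketch, this quick loading test might look like the following; the checkpoint directory and tokenizer path are just placeholders.

```python
# Quick test that the model loads (paths are placeholders, not the exact ones on screen).
if __name__ == "__main__":
    torch.manual_seed(0)

    allow_cuda = False
    device = "cuda" if torch.cuda.is_available() and allow_cuda else "cpu"

    model = LLaMA.build(
        checkpoints_dir="llama-2-7b/",
        tokenizer_path="tokenizer.model",
        load_model=True,
        max_seq_len=1024,
        max_batch_size=3,
        device=device,
    )
    print("All OK")
```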
02:19:37.940 | Let's run it and hopefully it will not crash.
02:19:48.580 | Wow, already.
02:19:51.300 | Not tensore, but tensor.
02:19:55.220 | So let's run it again.
02:20:00.100 | There are always a lot of typos when you write code.
02:20:02.580 | Another problem here.
02:20:05.700 | Ah, not Storage, but Tensor.
02:20:11.220 | This should be a tensor type:
02:20:13.460 | torch.BFloat16Tensor.
02:20:17.460 | Yeah, let's try again.
02:20:19.300 | Hidden... hidden what?
02:20:25.540 | Hidden dimension, of course.
02:20:27.060 | And let's try again.
02:20:39.060 | Yeah, all okay.
02:20:41.140 | Okay, wonderful.
02:20:41.940 | It means that at least it's doing something and it's not crashing, which is always
02:20:46.900 | good news.
02:20:47.460 | So let's run it again.
02:20:55.940 | So our next step is actually to build the inferencing code.
02:20:59.460 | So what we want to do is actually we want to be able to give some prompts to the model
02:21:05.620 | and then check the output for this prompt.
02:21:07.540 | So let's define some prompts.
02:21:09.460 | We will define some prompts here.
02:21:13.380 | And here we pass, for example, the size of the prompts.
02:21:17.940 | And then we want to, you know, we want to inference the model.
02:21:26.100 | So before we start inferencing the model, we need to build the code for inferencing
02:21:30.500 | the model, because we need to find a strategy for selecting the next token, etc, etc.
02:21:34.980 | So let's review how the inferencing works and what are the various strategies for inferencing.
02:21:39.700 | Okay, so when we are dealing with the next token prediction task, when we want to inference,
02:21:45.460 | we usually give the prompt and then we want to predict the tokens.
02:21:49.860 | But we give one token at a time.
02:21:52.260 | And every time we give one more token, the model will output one more token as output.
02:21:56.820 | And we only keep the last one.
02:21:58.340 | But with the KVCache, actually, we always give one token at a time.
02:22:02.660 | The KVCache will keep the cache for the keys and the values, and the model will only output one token.
02:22:08.820 | Okay, the point is, we need to find strategies for selecting this token.
02:22:14.020 | Among all the tokens that we have in the vocabulary.
02:22:17.060 | And this is the job of the logits and the softmax.
02:22:20.340 | So let's review how they work.
02:22:22.180 | Now imagine I give you the following task as human.
02:22:27.140 | So complete the following sentence.
02:22:29.700 | I think nuclear power is and then you have to choose a word.
02:22:33.860 | Now you as human may have thought of the possible next tokens, which may be clean, dangerous,
02:22:40.100 | cheap, expensive, safe, difficult, or something else.
02:22:43.540 | The choice of the next token in your head depends on your education,
02:22:47.780 | on your experience with nuclear power, and your opinion on the matter.
02:22:51.460 | Large language models also face the same problem.
02:22:55.220 | When we give them a prompt, then the model has to choose the next word.
02:22:59.300 | For the model, the uncertainty of the choice derives entirely from its training process
02:23:05.220 | and the strategy that we use to select the next token.
02:23:09.380 | There are many strategies.
02:23:10.820 | For example, we have the greedy strategy, the beam search, temperature is a parameter,
02:23:15.140 | random sampling, top k, top p.
02:23:17.060 | In this video, we will review all these strategies and how they work.
02:23:21.140 | But first, we need to understand what are the logits.
02:23:23.700 | Let's look at the transformer model from Lama.
02:23:28.260 | So the output of the self-attention is a sequence.
02:23:33.540 | In the case of the KVCache is only one token.
02:23:36.900 | We then run it through a linear layer.
02:23:38.980 | So after normalization, we run it through a linear layer.
02:23:41.780 | The linear layer will transform the embedding that is output from the self-attention here
02:23:47.060 | into a list of numbers that represent a kind of probability,
02:23:54.820 | they are not really probabilities, but we can think of them as a score,
02:23:58.500 | one for each token in the vocabulary.
02:24:01.540 | So if our vocabulary is made of, let's say, 100 tokens,
02:24:06.100 | this linear layer will output 100 numbers.
02:24:08.820 | And after we apply the softmax,
02:24:11.940 | these 100 numbers will become the probability of each token being the next token
02:24:19.860 | for the prompt given as input.
02:24:22.500 | So given an input, a prompt, the model comes up with probabilities.
02:24:27.620 | Probabilities for which token to choose next.
02:24:30.660 | And so what is the job of the linear layer and what is the job of the softmax?
02:24:35.700 | The linear layer converts the embedding of a token into a list of numbers
02:24:41.060 | such that each number represents a score that later with the softmax
02:24:46.020 | represents the probability of that particular token in the vocabulary.
02:24:50.260 | The softmax job is just to scale the logits in such a way that they sum up to one.
02:24:55.380 | So that's why we can talk about probabilities with the softmax, but not with the logits.
02:25:00.100 | So the output of the softmax is thus a probability distribution
02:25:04.340 | over all the words in the vocabulary.
02:25:06.580 | That is, each word in the vocabulary will have a probability associated with it.
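To make the logits-versus-probabilities distinction concrete, here is a tiny toy example; the sizes are made up and have nothing to do with Lama's real dimensions.

```python
# Toy illustration of logits vs. probabilities (vocab_size=100 is an arbitrary example).
import torch
import torch.nn as nn

dim, vocab_size = 16, 100
output_proj = nn.Linear(dim, vocab_size)   # the final linear layer of the model

token_embedding = torch.randn(1, dim)      # embedding of the last output token
logits = output_proj(token_embedding)      # (1, vocab_size): raw scores, not probabilities
probs = torch.softmax(logits, dim=-1)      # (1, vocab_size): sums to 1, a distribution

print(probs.sum())  # tensor(1.0000, grad_fn=...)
```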
02:25:10.580 | But now, given these words, each one with their probability,
02:25:14.340 | how do we choose the next token?
02:25:16.260 | There are many strategies.
02:25:17.300 | The easiest one is the greedy.
02:25:20.420 | The greedy strategy basically says we just select the token with the maximum probability.
02:25:25.540 | So imagine we are inferencing and the time step is the first time step in the greedy strategy.
02:25:31.380 | The prompt is Celia, you're breaking my heart.
02:25:34.420 | You're shaking my.
02:25:35.620 | OK, this is a line from a very famous song by Simon & Garfunkel.
02:25:41.220 | And the next word, for those who know, will be confidence.
02:25:47.220 | So Celia, you're breaking my heart.
02:25:48.980 | You are shaking my confidence.
02:25:50.500 | Suppose the output of the softmax is this distribution here.
02:25:55.620 | So we have 40% probability for this word.
02:25:58.660 | 20% for this word.
02:26:00.260 | 15% for this word.
02:26:01.700 | And 10% for this word.
02:26:03.220 | With a greedy strategy, we always choose the token with the maximum probability.
02:26:09.300 | Then we append it to the input.
02:26:11.780 | So the input at the next inference step becomes Celia, you're breaking my heart.
02:26:16.260 | You're shaking my confidence.
02:26:17.620 | And then the model has to come up with the next word, which, if you know the song, is daily.
02:26:22.420 | If we use the greedy strategy, we select the one with the highest probability.
02:26:27.860 | So in this case, it's daily, and it's also the correct one.
02:26:30.980 | So this is how the greedy strategy works.
02:26:33.380 | At every step, we choose the token with the maximum probability,
02:26:36.900 | which is then appended to the input to generate the next token, and so on.
02:26:40.900 | But if the initial token happens to be the wrong one,
02:26:45.140 | so not only the initial, but the initial two, three tokens happen to be the wrong ones,
02:26:49.300 | it's very likely that all the next tokens will also be wrong,
02:26:52.740 | because we are giving a wrong prompt to the model.
02:26:55.220 | So imagine at the time step one, we don't choose confidence,
02:26:59.860 | but somehow the model came up with a high score for liver.
02:27:03.540 | So you're shaking my liver, but then the next word,
02:27:06.660 | the model will not be able to come up with a reasonable next word,
02:27:10.420 | because there is no song that says you're shaking my liver.
02:27:14.580 | So if we make a mistake in the early stage of the greedy,
02:27:17.780 | all the next token very probably will also be wrong.
02:27:20.500 | But it's very easy to implement.
02:27:23.060 | However, it performs poorly in practice, so it's not used very much.
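In code, the greedy strategy is just an argmax over the probabilities; this toy example uses a made-up distribution.

```python
# Greedy decoding in one line: always pick the highest-probability token.
import torch

probs = torch.tensor([0.40, 0.20, 0.15, 0.10, 0.15])  # made-up distribution over 5 tokens
next_token = torch.argmax(probs, dim=-1)
print(next_token)  # tensor(0): the 40% token is always chosen
```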
02:27:27.940 | Another strategy is the BeamSearch.
02:27:30.980 | In BeamSearch, we have a parameter, which is called K,
02:27:35.060 | which means that at every step, we don't choose only the top one,
02:27:38.660 | but the top K at every step.
02:27:41.060 | And with K equal to two, we always keep the two best performing paths.
02:27:45.380 | So in this case, for example, imagine we are time step one.
02:27:49.300 | So Celia, you're breaking my heart.
02:27:50.980 | You are shaking my...
02:27:51.940 | And the top two words are pizza and confidence.
02:27:55.620 | Pizza somehow has a higher probability,
02:28:00.260 | because maybe the model has never seen this song before.
02:28:03.700 | So it doesn't know that the next word is confidence.
02:28:07.060 | So maybe the model outputs these probabilities.
02:28:10.180 | But we keep the two tokens with the highest probabilities.
02:28:17.460 | At the next time step, we make two prompts,
02:28:22.020 | one in case we choose the first one, so the first token,
02:28:25.700 | and one in case we choose the second token.
02:28:27.780 | And then we see what are the next possible choices if we use the first token.
02:28:32.500 | And what are the next choices if we use the second token?
02:28:35.540 | So we check the model output for the first prompt and for the second prompt.
02:28:39.940 | And in case we use, for example, the first prompt,
02:28:43.380 | the model will output these probabilities.
02:28:47.380 | And if we use the second prompt, the model will output these probabilities.
02:28:51.700 | What we do then is we calculate the cumulative score for each possible path.
02:28:57.460 | So for pizza, for example, the probability was 40%.
02:29:00.340 | But after pizza, the probability the model produced for margarita,
02:29:06.100 | for example, is 0.01.
02:29:07.700 | So for this path, pizza-margarita, the cumulative score is 0.4 × 0.01 = 0.004.
02:29:12.660 | The probability is 0.4%.
02:29:14.580 | Pizza, anchovies, it's going to be 0.2% or 0.002.
02:29:21.860 | However, with confidence, we get a new next token that can be either daily or monthly.
02:29:29.380 | With daily, we get a cumulative score of 0.16 and with monthly of 0.02.
02:29:35.620 | So as we can see, at the time step two,
02:29:38.260 | even if at the time step one, pizza was the most probable word,
02:29:43.060 | because we kept the second choice alive,
02:29:45.860 | so we didn't kill it, like we would have done with greedy.
02:29:48.820 | Let me use the laser.
02:29:51.300 | We can see that the confidence then produces a next token that is very probable,
02:29:57.380 | because now the model has more prompt.
02:30:00.260 | And so it can come up with more specific choices for the next tokens
02:30:05.220 | with a very high confidence.
02:30:07.780 | So we compute the cumulative score of all these paths,
02:30:14.180 | and we keep the two paths that have the top choices.
02:30:16.820 | So now the pizza path has been killed,
02:30:19.700 | because we chose pizza at the beginning
02:30:22.420 | only because somehow the model thought it was pizza,
02:30:24.420 | but then it couldn't find a good continuation.
02:30:25.860 | The model was not so confident about the next words.
02:30:28.420 | But in the case of this token here,
02:30:31.300 | the model was very confident about the second score.
02:30:33.540 | So we killed all this path here, and we kept this one
02:30:36.980 | until we arrived to the last token,
02:30:38.980 | in which we just selected the path with the highest score,
02:30:42.180 | and that's the output of our inferencing strategy with BeamSearch.
02:30:45.860 | And we repeat the steps of the last slide for all the successive tokens
02:30:52.180 | until we arrive to the last one.
02:30:53.620 | And with BeamSearch, at every step we keep alive the top k paths,
02:31:00.980 | and all the others are killed.
02:31:02.420 | It increases inferencing time,
02:31:04.820 | because at every step we must explore k possible options,
02:31:08.100 | but generally it performs better than the greedy strategy,
02:31:11.060 | for the reason that I have just shown.
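Here is a minimal, generic sketch of beam search, not Lama's official code. It works in log space, so summing log-probabilities plays the role of multiplying the probabilities on the slide; the step function and the toy scorer are invented just to make it runnable.

```python
# A minimal beam search sketch over a generic next-token scorer.
import torch

def beam_search(step_fn, bos_token, k, steps):
    # Each beam is (tokens, cumulative log-probability).
    beams = [([bos_token], 0.0)]
    for _ in range(steps):
        candidates = []
        for tokens, score in beams:
            log_probs = step_fn(tokens)                 # (vocab_size,) log-probs for the next token
            top_lp, top_idx = torch.topk(log_probs, k)  # expand each beam with its top-k tokens
            for lp, idx in zip(top_lp.tolist(), top_idx.tolist()):
                candidates.append((tokens + [idx], score + lp))
        # Keep only the k best cumulative paths; the rest are "killed".
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
    return beams[0][0]  # the path with the highest cumulative score

# Toy scorer: a fixed random "model" over a 10-token vocabulary, just to run the function.
torch.manual_seed(0)
fake_logits = torch.randn(10, 10)
step = lambda tokens: torch.log_softmax(fake_logits[tokens[-1]], dim=-1)
print(beam_search(step, bos_token=0, k=2, steps=3))
```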
02:31:12.580 | Another thing that is interesting in inferencing is the temperature,
02:31:18.180 | because the idea of the temperature is that
02:31:20.740 | we can make the model more confident or less confident.
02:31:26.100 | So for example, when we compute the logits,
02:31:29.540 | which are not the probabilities,
02:31:31.620 | so they are what will become the probabilities after we apply the Softmax.
02:31:36.100 | So before we apply the Softmax, we can scale the logits,
02:31:39.860 | for example, like this.
02:31:43.860 | So for example, we have these logits here.
02:31:45.700 | I choose the negative numbers,
02:31:48.020 | so the Softmax probabilities are reasonable numbers.
02:31:50.580 | And so these are the logits.
02:31:53.700 | And if we divide these logits before applying the Softmax
02:31:58.260 | by a number that is low, so low temperature,
02:32:01.060 | it's called, this number is called the temperature,
02:32:03.220 | it will make the model more confident,
02:32:06.100 | because it will make bigger probabilities bigger
02:32:09.060 | and smaller probabilities smaller.
02:32:10.900 | So the gap between the low and high probability increases.
02:32:15.220 | So for example, you can see that without applying any temperature,
02:32:18.900 | the highest logit gets 80% probability.
02:32:22.980 | But applying a 0.4 temperature,
02:32:25.380 | the highest logit becomes 98% probability.
02:32:29.300 | And if we apply a high temperature,
02:32:31.140 | it makes the model less confident.
02:32:32.980 | So the gap between the low and high probability reduces.
02:32:35.860 | The temperature is important if we want to increase
02:32:40.740 | the confidence of the model or not,
02:32:42.820 | because it can be used in conjunction with other strategies,
02:32:46.260 | like, for example, the greedy,
02:32:47.780 | or the top k or the top p that we will see later.
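A small sketch of how temperature reshapes the distribution; the logits are made-up negative numbers, like in the slide.

```python
# How temperature reshapes the distribution.
import torch

logits = torch.tensor([-1.0, -2.5, -3.0])  # made-up logits

for temperature in (1.0, 0.4, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probs)
# Low temperature (0.4): the top probability grows, the model looks more "confident".
# High temperature (2.0): the distribution flattens, the model looks less "confident".
```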
02:32:50.580 | Another strategy is the random sampling.
02:32:55.620 | So as we saw, the logits are not a probability distribution,
02:32:59.060 | but after we apply the softmax, they become a distribution.
02:33:02.100 | So what we do, because it's a distribution,
02:33:05.140 | we can also sample from this distribution.
02:33:07.700 | For example, in this distribution here,
02:33:10.260 | that comes from these logits here,
02:33:12.580 | we have one token that can be chosen with a 12% probability,
02:33:16.020 | one can be chosen with 7% probability,
02:33:18.420 | and one that can be chosen with 80% probability.
02:33:21.060 | If we flip a weighted coin, 80% of the time
02:33:23.460 | we will choose this token,
02:33:24.980 | and 12% of the time we will choose this token,
02:33:27.380 | and 7% of the time we will choose this token.
02:33:29.460 | So this means sample from this distribution.
02:33:32.500 | It means take a number from this distribution,
02:33:36.100 | according to its weight, to its probability.
02:33:39.620 | Now, there is a problem with this sampling strategy here,
02:33:44.820 | that with very little probability,
02:33:47.380 | it may happen that we choose tokens that are total crap.
02:33:51.220 | For example, in this scenario here,
02:33:56.020 | for example, before, with the greedy strategy,
02:33:59.380 | or with Beam search, for example,
02:34:00.740 | this token here, if we use a random sampling,
02:34:03.700 | we will choose the word pizza with 40% probability,
02:34:07.220 | the word confidence with 20% probability,
02:34:09.940 | but with a very little probability,
02:34:11.620 | it may happen that we will choose the word Pokemon
02:34:14.180 | with 10% probability.
02:34:15.540 | Of course, the probability is low,
02:34:18.100 | so the probability of us making a bad choice is low,
02:34:21.540 | but there is this probability.
02:34:24.020 | So this is a problem with random sampling.
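In code, random sampling is just a multinomial draw from the distribution; the probabilities below are made up.

```python
# Random sampling: draw the next token from the probability distribution itself.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.80, 0.12, 0.07, 0.01])       # made-up distribution over 4 tokens
next_token = torch.multinomial(probs, num_samples=1)
# Token 0 is picked ~80% of the time, but token 3 can still appear ~1% of the time.
print(next_token)
```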
02:34:26.020 | The next strategy is Top K.
02:34:29.540 | In Top K, what we do is,
02:34:31.780 | to avoid selecting the crappy tokens,
02:34:34.420 | we just remove them.
02:34:35.700 | So we take all the logits,
02:34:38.260 | we sort them, and then we just keep the highest K,
02:34:42.180 | so that the crappy one,
02:34:43.620 | we just remove them from this distribution,
02:34:48.260 | and then we calculate the distribution for the rest.
02:34:51.220 | So we apply the softmax only to the ones that survive.
02:34:55.060 | The problem is also here.
02:34:58.420 | Given the following two distributions,
02:35:00.740 | the low probability tokens can still make their way
02:35:04.020 | into the top K,
02:35:05.220 | because it all depends on the distribution
02:35:07.780 | to which we apply the top K.
02:35:09.140 | Let me give you a graphical example.
02:35:11.380 | Imagine we have a distribution that is very flat.
02:35:15.940 | Suppose this distribution here,
02:35:20.420 | so some words, this is our vocabulary.
02:35:25.540 | This is the probability of each word.
02:35:28.500 | So the word number one, word number two,
02:35:30.260 | word number three, word number four,
02:35:31.620 | et cetera, et cetera, et cetera.
02:35:32.740 | But more or less, all the words have the same probability.
02:35:36.740 | So imagine we take the top 10 words,
02:35:39.220 | so it will select all these tokens, right?
02:35:41.860 | Okay.
02:35:43.380 | So it will select the token number one,
02:35:44.820 | token number two, token number three,
02:35:46.020 | token number four,
02:35:46.980 | up to whatever token here is here.
02:35:49.940 | Imagine we have another distribution
02:35:51.700 | that is made like this.
02:35:52.900 | So we still have a vocabulary.
02:35:54.420 | Vocabulary.
02:35:56.500 | We still have a probability distribution.
02:35:58.340 | And the distribution is made like this.
02:36:02.500 | So because it's sorted,
02:36:03.620 | we have a distribution that is very skewed.
02:36:09.060 | Because we still keep the top 10,
02:36:12.500 | as you can see, we will select this token,
02:36:15.060 | this token, this token, this token, this token.
02:36:17.140 | But these tokens here are very crappy
02:36:19.620 | compared to this one here.
02:36:21.780 | So they will still make their way into our selection.
02:36:25.460 | And this is not something that we want.
02:36:27.300 | We want to avoid selecting crappy tokens,
02:36:29.620 | but we still want to have some randomness.
02:36:31.780 | So we don't want to be totally greedy
02:36:33.380 | because sometimes the tokens that are in the top N tokens,
02:36:38.980 | maybe they are all reasonable.
02:36:41.460 | Also, sometimes the prompt can be quite ambiguous.
02:36:45.060 | So we don't know which, even as humans,
02:36:47.140 | we may not know what is the next word to be chosen.
02:36:49.540 | So we want some randomness,
02:36:51.620 | but we also don't want the very low probability tokens.
02:36:54.980 | But with this top K strategy,
02:36:56.820 | the low probability tokens
02:36:59.140 | can still make their way into our selection.
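A minimal top-k sampling sketch, assuming made-up logits: keep only the k highest scores, apply the softmax to the survivors, sample, and map back to the vocabulary index.

```python
# Top-k sampling sketch.
import torch

def sample_top_k(logits, k):
    top_values, top_indices = torch.topk(logits, k)   # keep the k best scores
    probs = torch.softmax(top_values, dim=-1)         # softmax only over the survivors
    choice = torch.multinomial(probs, num_samples=1)  # sample among them
    return top_indices[choice]                        # map back to the vocabulary index

torch.manual_seed(0)
logits = torch.randn(100)          # made-up logits over a 100-token vocabulary
print(sample_top_k(logits, k=10))
```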
02:37:01.140 | And this problem is solved with top P.
02:37:04.580 | With top P, we only keep the tokens
02:37:07.940 | with the highest probability,
02:37:09.620 | such that the cumulative probability
02:37:12.020 | is greater than or equal to the parameter P.
02:37:15.540 | What does this mean?
02:37:16.660 | It means that if we have the previous distributions,
02:37:21.060 | so one that is quite flat, for example,
02:37:24.020 | and the one that has a mode,
02:37:28.900 | so for example, this one.
02:37:30.340 | So in this one the first token is nearly 90% and all the others are close to 0%,
02:37:36.580 | while in the flat one, more or less all of them are like 0.2%
02:37:40.580 | and then they go down.
02:37:43.940 | Now imagine P is equal to, let's say, 0.5.
02:37:48.340 | In this case, we will select all the tokens
02:37:51.540 | such that the area under the curve is equal to 0.5.
02:37:57.140 | But here, because this first token is already 0.9,
02:38:01.540 | we will actually only select one token
02:38:03.700 | and all the crappy ones will not be selected
02:38:05.700 | because this area under the curve is already 0.9.
02:38:09.060 | And this is the idea behind the top P.
02:38:12.340 | So when the model, when the distribution is more flat,
02:38:17.220 | we select more tokens
02:38:18.580 | because it means that we are more uncertain
02:38:21.460 | about which token to choose.
02:38:22.740 | But when we have a big mode, we select fewer tokens
02:38:27.300 | and this way we avoid getting the low probability ones.
02:38:33.540 | So now that we reviewed all the strategies
02:38:36.180 | for selecting the token, we will implement it.
02:38:39.860 | And in the case of Lama, also in the official code,
02:38:43.140 | they actually implement the top P strategy.
02:38:45.540 | In my case, I think that the BeamSearch is a reasonable choice.
02:38:51.540 | So in another video, maybe I will show
02:38:54.180 | how to implement the BeamSearch.
02:38:55.940 | But for now, I will implement the top P.
02:38:57.620 | So let's go build it.
02:38:59.860 | So we implement the method, let's call it text_completion,
02:39:03.620 | which is the same name that's used
02:39:05.860 | in the original code from Lama.
02:39:07.940 | It is given the prompts and a temperature that defaults to 0.6.
02:39:18.900 | And so 0.6 means that we want to make the model more confident.
02:39:25.060 | Top P means that we want all the tokens
02:39:33.300 | such that their cumulative probability is at least 0.9.
02:39:38.020 | So 90%.
02:39:38.500 | Okay, I think here should be lowercase.
02:39:56.020 | Okay, so if we didn't specify the max generation length,
02:40:01.940 | then we just generate the maximum token.
02:40:04.980 | Args.
02:40:08.040 | Just generate as many tokens as we can, up to the sequence length.
02:40:14.100 | And then we, first of all, convert each prompt.
02:40:20.980 | So each prompt, actually, is converted into tokens using the tokenizer.
02:40:28.740 | Mm-hmm.
02:40:29.240 | Then, as we saw before, we need to add the beginning of sentence
02:40:47.060 | when we pass the input to the model for inferencing.
02:40:56.260 | Mm-hmm.
02:40:56.760 | But not the end of sentence.
02:41:00.020 | Okay.
02:41:09.240 | Because we specified the max batch size for the model when we built it, for the KVCache,
02:41:20.020 | so we need to make sure that the batch size of the prompts is not too large.
02:41:23.860 | Mm-hmm.
02:41:27.860 | And then max prompt length
02:41:42.340 | is the length of the longest prompt in the batch.
02:41:47.700 | Mm-hmm.
02:42:16.020 | I'm not writing any assert message, even though you should, but okay,
02:42:19.220 | for us it's just basically debugging.
02:42:21.380 | Then the total length
02:42:24.980 | is how many tokens we want to get from the model.
02:42:40.420 | Mm-hmm.
02:42:45.620 | Okay, now we create the list that will contain the generated token.
02:43:31.700 | this means create a tensor of shape batch size by total length,
02:43:38.260 | in which each item is actually the padding token.
02:43:40.820 | And then we fill the initial tokens with the prompt tokens.
02:43:57.380 | Mm-hmm.
02:44:09.060 | Okay, we also need this variable that tells if we reach the end of sentence
02:44:27.460 | in any of the prompts.
02:44:38.980 | This indicates if the token in this position is a prompt token or a padding token: true
02:44:56.100 | if the token is a prompt token, false otherwise.
02:45:00.660 | And then we can finally make the for loop to generate the tokens.
02:45:12.900 | The iterator is tqdm, over a range starting from one.
02:45:18.580 | Okay, now we generate one token at a time.
02:45:32.580 | the logits come from the model, so self.model.forward.
02:45:42.180 | We need to pass one token at a time.
02:45:44.100 | So, which token?
02:45:46.500 | The one just before the one we currently want to output.
02:45:49.140 | So, the slice from the current position minus one
02:45:52.580 | up to the current position.
02:45:54.980 | So, only one token.
02:45:56.100 | And we also tell the model what is the position of this token, because for the KVCache.
02:46:00.900 | And if we use the temperature, we apply it.
02:46:06.500 | As you can see, every time when we want to inference, we always select the last token.
02:46:31.460 | But because we are using the KVCache, actually our model will only output one token at a time.
02:46:35.780 | So, the next token will be selected according to our topP strategy.
02:46:48.820 | So, we have the probabilities.
02:46:50.180 | Now we apply the topP.
02:46:52.580 | I just define it here.
02:46:54.580 | So, sample_top_p, and then we implement it.
02:46:58.420 | If we didn't specify any temperature, we just use the greedy.
02:47:25.380 | okay.
02:47:41.620 | Now we have the next token according to this strategy or this greedy.
02:47:47.060 | Then we only replace the token if it is a padding token.
02:47:54.260 | So, the problem is, we already have some tokens that come from the prompt.
02:47:58.500 | We still need to give the prompt to the model.
02:48:03.060 | But we are only giving one token at a time to the model to build the initial cache.
02:48:07.060 | So, the first prompt tokens will be given to the model,
02:48:11.060 | not because we care about what the model will output for those tokens.
02:48:14.900 | But only because we want the KV cache to be built for those positions.
02:48:21.780 | And after we give the last token of the prompt,
02:48:25.140 | then we care about what is the model outputting.
02:48:28.340 | So, only replace the next token if it is a padding token.
02:48:33.300 | And which one is a padding token?
02:48:34.820 | The one that was not an initial prompt token.
02:48:38.020 | Because here we build tokens full of paddings.
02:48:41.060 | But then, we replace the prompt tokens, the padding tokens with the prompt tokens
02:48:47.860 | for the one with the initial tokens.
02:48:52.900 | All the others have to be inferred by the model.
02:48:55.220 | So, token...
02:49:09.060 | This means, basically, check this mask.
02:49:23.780 | What is this mask?
02:49:24.740 | If it's true, if the token is a prompt token.
02:49:26.740 | So, if it is a prompt token, replace it with this one.
02:49:30.740 | And if it's not a prompt token, just keep it the current one.
02:49:44.020 | Okay, then...
02:49:53.220 | Since we do not care about what the model outputs for the initial prompt tokens,
02:50:02.020 | but only for the last prompt token,
02:50:04.180 | we don't care if we find an end-of-sentence position for those tokens.
02:50:09.380 | So, end-of-sentence is only reached if we find it for one of the tokens
02:50:13.220 | that we actually want to inference,
02:50:14.820 | not the one that we send to the model just to build a KV cache.
02:50:40.820 | okay,
02:50:45.620 | this basically means the end-of-sentence for a particular prompt is reached
02:51:06.900 | only if it was a padding token.
02:51:10.180 | So, only if the position was a padding position,
02:51:11.860 | which means it was not a prompt token.
02:51:14.180 | This operator here means not.
02:51:15.140 | And we actually found an end-of-sentence token from the model output.
02:51:22.580 | If all of the prompts have reached the end-of-sentence token,
02:51:26.900 | then we stop this for loop.
02:51:28.180 | We don't need to inference anymore.
02:51:29.540 | Now, we prepare the output.
02:52:21.380 | this means that if we found an end-of-sentence token for one of the prompts,
02:52:36.820 | we just cut the prompt output there.
02:52:39.380 | The model output at that particular token.
02:52:42.100 | We don't care about what it outputs next.
02:52:48.660 | This is the output text and then we output the tokens and the text.
02:53:05.540 | Hopefully, I didn't make too many typos and mistakes.
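For reference, here is a condensed sketch of the text_completion loop just described. It assumes the LLaMA class from this video, the SentencePiece tokenizer, and the sample_top_p function implemented right after this; the forward signature and some helper calls are my approximation, not the exact code on screen.

```python
# Condensed sketch of text_completion (a method of the LLaMA class built in this video).
import torch

def text_completion(self, prompts, temperature=0.6, top_p=0.9, max_gen_len=None):
    if max_gen_len is None:
        max_gen_len = self.args.max_seq_len - 1

    # Tokenize each prompt, adding BOS but not EOS.
    prompt_tokens = [self.tokenizer.encode(p, out_type=int, add_bos=True, add_eos=False)
                     for p in prompts]
    batch_size = len(prompt_tokens)
    assert batch_size <= self.args.max_batch_size
    max_prompt_len = max(len(p) for p in prompt_tokens)
    assert max_prompt_len <= self.args.max_seq_len
    total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len)

    # Start from a (batch, total_len) tensor full of padding, then copy in the prompts.
    pad_id = self.tokenizer.pad_id()
    tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=self.args.device)
    for k, t in enumerate(prompt_tokens):
        tokens[k, :len(t)] = torch.tensor(t, dtype=torch.long, device=self.args.device)

    eos_reached = torch.tensor([False] * batch_size, device=self.args.device)
    prompt_tokens_mask = tokens != pad_id  # True where the token comes from the prompt

    for cur_pos in range(1, total_len):
        with torch.no_grad():
            # One token at a time: the KV cache remembers all previous positions.
            logits = self.model.forward(tokens[:, cur_pos - 1:cur_pos], cur_pos - 1)
        if temperature > 0:
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)
        next_token = next_token.reshape(-1)
        # Only overwrite padding positions; prompt positions are fed just to build the cache.
        next_token = torch.where(prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token
        # EOS counts only for positions the model actually had to generate.
        eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == self.tokenizer.eos_id())
        if all(eos_reached):
            break

    out_tokens, out_text = [], []
    for row in tokens.tolist():
        if self.tokenizer.eos_id() in row:
            row = row[:row.index(self.tokenizer.eos_id())]  # cut at the first EOS
        out_tokens.append(row)
        out_text.append(self.tokenizer.decode(row))
    return out_tokens, out_text
```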
02:53:13.540 | So, now we need to build the sample_top_p.
02:53:17.700 | So, we have the logits that are the output of the model.
02:53:20.820 | We transform them into probabilities by using the softmax.
02:53:24.180 | But given these probabilities, we need to use the sample_top_p strategy
02:53:28.260 | to select all the tokens such that their cumulative probability is equal to top_p,
02:53:34.900 | which in our case is 0.9, so 90 percent.
02:53:38.500 | Okay, the first thing we do is we sort these probabilities in descending order.
02:54:04.100 | we then calculate the cumulative sum.
02:54:22.820 | Then we create the mask that says which tokens we want to keep
02:54:26.740 | and which one we don't want to keep.
02:54:28.260 | So, mask is equal to probability_sum minus probability_sort greater than p.
02:54:36.420 | Why do we do a minus probability_sort?
02:54:39.300 | Because we want to shift.
02:54:40.740 | Let me show you on the slides here.
02:54:46.740 | You can see here, for example, the cumulative probability.
02:54:52.500 | So, the probabilities are this one, 44 percent, 40 percent, 6 percent, 4 percent, and 3 percent.
02:54:59.780 | Then we calculated the cumulative.
02:55:01.300 | That means up to here it's 44 percent.
02:55:05.140 | Then this one plus this one is 85 percent.
02:55:07.540 | Then this one plus this one plus this one is 91 percent.
02:55:10.500 | This one plus this one plus this one plus this one is 96 percent, etc, etc.
02:55:15.300 | But imagine we have p equal to 0.90, or say 0.5.
02:55:24.660 | With p = 0.90 we need to keep up to this token here, because the previous one is not enough:
02:55:31.140 | its cumulative probability is only 0.85.
02:55:31.780 | So, we need to go up to this one.
02:55:33.300 | So, up to the first cumulative value that reaches p.
02:55:36.820 | And in this case it's this one.
02:55:39.700 | So, that's why we shift it.
02:55:40.820 | We want also this token inclusive.
02:55:42.900 | And this is why we do this minus probability sort.
02:55:48.900 | So, all the ones that we didn't select, we zero them out.
02:55:53.860 | Zero.
02:55:58.900 | And then we redistribute the probabilities because, of course,
02:56:04.980 | if we remove some items from here, they don't sum up to one anymore.
02:56:08.820 | So, we need to redistribute the probabilities.
02:56:10.900 | And this is very easy.
02:56:35.300 | Okay, then the next token is basically, suppose we keep the first two tokens.
02:56:41.700 | And then what we do is we want to sample from them.
02:56:44.500 | So, the first token has a probability of 0.44, so 44 percent.
02:56:47.460 | The second token has a probability of 0.40, so 40 percent.
02:56:50.420 | But after we redistribute their probabilities, actually,
02:56:53.380 | this one will be a little higher than 44 percent.
02:56:54.740 | And this one will be a little higher than 40 percent.
02:56:56.900 | And then we sample.
02:57:00.820 | It means that the first token will have slightly better chances of being chosen.
02:57:04.660 | And the second token will have slightly less chance of being selected.
02:57:07.860 | And we want one sample because we want one token.
02:57:21.540 | But this is not yet the next token.
02:57:24.500 | The next token still has to be recovered.
02:57:26.420 | Because this indicates which index to select,
02:57:31.140 | then we need to map that index to the actual number in the vocabulary.
02:57:36.500 | But because we already changed the order of these numbers,
02:57:40.180 | because we sorted it.
02:57:41.540 | So, initially, the logits were built in such a way
02:57:45.700 | that the first logit corresponded to the first number of the vocabulary.
02:57:49.220 | The second logit corresponded to the second number of the vocabulary.
02:57:52.900 | But because we sorted it in descending order, this order is gone.
02:57:56.820 | So, we don't know now, just given the token selected,
02:58:00.420 | we don't know which number it maps back into the vocabulary.
02:58:04.580 | That's why the sort method returns two arguments.
02:58:07.380 | One is the sorted numbers and one is the indexes that it changed.
02:58:11.140 | So, it will tell you for each position what was the original item in that position.
02:58:16.500 | So, this is why we actually query using gather.
02:58:19.620 | Gather allows us to retrieve what the original element was,
02:58:29.380 | given this index here.
02:58:30.740 | And then we return the next token.
02:58:33.380 | And this will map back into the vocabulary directly.
02:58:39.700 | And this should be it.
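Putting it together, here is a sketch of sample_top_p as just described: sort, cumulative sum, shifted mask, redistribute, sample, then map the sampled index back through the sort order. The example probabilities roughly match the slide.

```python
# Sketch of sample_top_p as described above.
import torch

def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Shifted by probs_sort so the first token that crosses p is still included.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))      # redistribute to sum to 1
    next_token = torch.multinomial(probs_sort, num_samples=1)  # index in the *sorted* order
    return torch.gather(probs_idx, -1, next_token)             # back to the vocabulary index

probs = torch.tensor([[0.44, 0.40, 0.06, 0.04, 0.03, 0.03]])
print(sample_top_p(probs, p=0.9))  # samples only among the tokens kept by top-p
```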
02:58:43.460 | So, now let's create some prompts and let's run the code.
02:58:47.540 | I have some prompts here that I copied and pasted.
02:58:50.580 | So, now let's build the inference code.
02:58:54.660 | So, out_tokens, out_text, we want to generate maximum 64 tokens.
02:59:09.460 | We assert that the len of the output text is actually equal equal to len of prompts.
02:59:17.540 | It should be.
02:59:18.900 | So, for i in range, hopefully the model will work.
02:59:26.660 | And then we print the output text for each prompt.
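And the driver might look like this; the prompts here are shortened approximations of the ones used in the video, not the exact ones.

```python
# Final inference driver, sketched (prompts are approximations of the ones in the video).
prompts = [
    "Simply put, the theory of relativity states that ",
    "If Google was an Italian company founded in Milan, it would",
]

out_tokens, out_text = model.text_completion(prompts, max_gen_len=64)
assert len(out_text) == len(prompts)
for i in range(len(out_text)):
    print(out_text[i])
    print("-" * 50)
```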
02:59:42.900 | So, let's run the code and let's hope for the best.
02:59:46.580 | Okay, self-attention is missing the required forward function.
02:59:52.260 | Let's see why.
02:59:58.820 | Oops, I misspelled forward.
03:00:04.820 | It should be forward.
03:00:06.340 | Let's run again.
03:00:11.700 | Cumsum received... this is wrong because the argument should be the dimension: not div, but dim.
03:00:23.460 | Let's run again.
03:00:26.660 | For bfloat, 16.
03:00:34.020 | So, let's see why.
03:00:41.140 | eos token, let me check.
03:00:45.300 | Okay, now it's running.
03:00:59.940 | I just changed this tensor from capital T to small t.
03:01:04.260 | I will investigate why.
03:01:09.860 | Wow, we have an output.
03:01:11.140 | So, let's check.
03:01:12.420 | First of all, let's check the prompt.
03:01:13.780 | Simply put, the theory of relativity states that time is relative to the observer.
03:01:22.020 | Mass is relative to the observer.
03:01:23.620 | Speed is relative to the observer.
03:01:25.220 | Energy is relative to the observer.
03:01:27.380 | So, it looks like it's not bad.
03:01:29.140 | Suppose the second prompt says if Google was an Italian company founded in Milan,
03:01:35.940 | it would be listed on the Milan Stock Exchange,
03:01:38.500 | as the Milan Stock Exchange is the largest in Italy.
03:01:41.060 | But since Google is a US company, it is listed on the Nasdaq Stock Exchange.
03:01:44.740 | So, it avoided actually answering the question.
03:01:46.740 | Let's try the few-shot prompt.
03:01:52.180 | So, this one I actually copied from the LAMA code.
03:01:52.180 | So, they ask to translate from English to French.
03:01:54.420 | And after cheese, we expect to find fromage, onion, etc.
03:02:01.220 | So, it looks correct.
03:02:02.660 | And we can also see that the spaces have been kept.
03:02:06.420 | So, these spaces were not added by me, but actually by the model.
03:02:09.700 | So, it keeps the output aligned with what was the prompt.
03:02:12.900 | And then I created a zero-shot prompt.
03:02:16.340 | So, tell me if the following person is actually Doraemon disguised as human.
03:02:20.100 | So, the name is Umar Jameel, and the decision is
03:02:22.180 | he's a hero in every sense of the word.
03:02:25.300 | He's a hero in every sense of the word.
03:02:26.900 | I'm very happy, LAMA.
03:02:29.140 | Actually, okay, this is the output of the model.
03:02:32.420 | With manual seed zero.
03:02:33.700 | I think if I change the seed to some other number and run the model again,
03:02:38.500 | the output will be totally different or maybe slightly different.
03:02:42.260 | I hope not, but it may be different.
03:02:44.420 | Anyway, thanks for watching my video, guys.
03:02:46.740 | I tried to convey the idea of what is the architecture inside LAMA.
03:02:54.420 | And even if I didn't build the training code,
03:02:56.980 | because actually to build the training code is rather complicated,
03:03:00.420 | we need a big corpus of text, we need to tokenize it,
03:03:05.220 | and it's going to take a long time.
03:03:06.900 | But I hope to make another video in the future on how to train a language model,
03:03:13.460 | maybe with a smaller dataset and with a lighter architecture.
03:03:16.420 | And I tried to convey all the math behind all the choices,
03:03:21.300 | and also the inner workings of the KV cache and the grouped query attention.
03:03:28.980 | If you have any questions, please write in the comments.
03:03:31.380 | I also will share the repository with the code that I have previously built for this,
03:03:36.500 | and which has many more comments than the one I have written here.
03:03:40.500 | It's much more in detail,
03:03:42.660 | so everyone can understand step by step all the dimensions involved.
03:03:46.420 | Here I tried to write the most important dimensions,
03:03:49.300 | but because of time, I didn't write all of them.
03:03:51.380 | So thank you again, guys, for watching.
03:03:54.580 | It was a long journey, but I can assure you that you learned a lot.
03:03:58.260 | Hopefully.
03:03:59.060 | And I hope you will visit again my channel for more videos about deep learning,
03:04:03.620 | about PyTorch, about coding, and about everything that we love in AI.
03:04:08.580 | Thank you for watching, guys.