Flash Attention derived and coded from first principles with Triton (Python)
Chapters
0:00 Introduction
3:10 Multi-Head Attention
9:06 Why Flash Attention
12:50 Safe Softmax
27:03 Online Softmax
39:44 Online Softmax (Proof)
47:26 Block Matrix Multiplication
88:38 Flash Attention forward (by hand)
104:01 Flash Attention forward (paper)
110:53 Intro to CUDA with examples
146:28 Tensor Layouts
160:48 Intro to Triton with examples
174:26 Flash Attention forward (coding)
262:11 LogSumExp trick in Flash Attention 2
272:53 Derivatives, gradients, Jacobians
285:54 Autograd
300:00 Jacobian of the MatMul operation
316:14 Jacobian through the Softmax
347:33 Flash Attention backwards (paper)
373:11 Flash Attention backwards (coding)
441:10 Triton Autotuning
443:29 Triton tricks: software pipelining
453:38 Running the code
00:00:00.400 |
Hello guys, welcome back to my channel. Today we are going to explore FlashAttention. 00:00:04.640 |
Now, we are going to explore FlashAttention from first principles, which means that not only 00:00:09.600 |
will we code FlashAttention, we will actually derive it. So we pretend that the FlashAttention 00:00:14.640 |
paper, never existed and we look at the attention computation and we look at the problem it has and 00:00:20.240 |
we try to solve it step by step pretending that FlashAttention never existed. This will give us 00:00:25.120 |
a deep understanding of how it works and also we will combine theory with practice because we will 00:00:30.320 |
code it. Now, in order to code FlashAttention we will need to write a kernel for our GPU and in 00:00:37.120 |
our specific case I will be using an NVIDIA GPU so a CUDA kernel but instead of writing C++ code 00:00:42.560 |
we will use Triton, which is a way of converting Python directly into CUDA kernels that can run 00:00:50.000 |
directly on the GPU and Triton you can think of it as a compiler that takes in Python and 00:00:54.880 |
converts it into something that can run on the GPU. So let's look at the topics for today. First 00:01:00.880 |
of all I will give an introduction to multi-head attention because we need to look at what is 00:01:05.040 |
attention and how it's computed and what are the problems in computing this attention. 00:01:08.960 |
Then we will look at what is actually the most critical part of the attention computation, the Softmax, 00:01:14.080 |
and how it impacts the computation and complexity. We will look at what is online Softmax. Then we 00:01:20.240 |
will explore what is the GPU because we are going to write a kernel that will run on the GPU so we 00:01:24.480 |
need to understand what is the difference, for example, between the CPU and the GPU, and what is a kernel 00:01:28.480 |
and how it differs from a normal program that you write for the CPU. We will look at how tensors 00:01:34.560 |
are laid out in memory, so row-major layout, column-major layout, strides, etc. We are going to look 00:01:41.760 |
at block matrix multiplication, Triton, software pipeline and all the optimization that Triton does 00:01:46.640 |
to our code. Finally we will be able to code the FlashAttention forward pass but of course we are 00:01:52.320 |
not satisfied only by coding the forward pass. We also want to code the backward pass but in order 00:01:57.040 |
to code the backward pass we also need to understand how autograd works and the gradient descent works 00:02:02.400 |
in the case of custom operations so we need to understand what are derivatives, what are 00:02:06.880 |
gradients, what are Jacobians and then we calculate the gradient of the common operations that we use 00:02:11.680 |
in FlashAttention and finally we will have enough knowledge to code the backward pass. For this 00:02:16.800 |
reason this video is going to be super long but I hope you don't mind because we are going to learn 00:02:21.360 |
a lot. Of course you may be wondering all of this requires a lot of knowledge that you may not have 00:02:27.280 |
but that's not a problem because that's my problem because in this video I will make sure that if you 00:02:31.680 |
only have high school calculus so you know what are derivatives you have basics of linear algebra 00:02:36.320 |
like you know what is matrix multiplication or what is the transpose of a matrix and you have 00:02:40.880 |
a basic knowledge of attention mechanism so like for example you have watched my previous video on 00:02:45.040 |
the attention is all you need paper and you have a lot of patience that should be enough to understand 00:02:50.880 |
all of this video because all the topics that I will introduce I will always introduce them in 00:02:55.360 |
such a way that I pretend that you don't know anything about the topic so we try to derive 00:03:00.640 |
everything from first principle everything from scratch. Okay now that we have seen the introduction 00:03:06.000 |
let's go see the first part of the video which is the multi-head attention. All right let's talk 00:03:12.160 |
about multi-head attention. Now I am using the slides from my previous video attention is all 00:03:17.040 |
you need so we can look at very fast at what multi-head attention is and how it works. I hope 00:03:23.120 |
you remember the formula: softmax of the query multiplied by the transpose of the key, divided by 00:03:27.440 |
the square root of dk, all multiplied by V, because we will be using that a lot throughout the video. Now multi-head 00:03:33.760 |
attention starts from an input sequence or two input sequence in case we are talking about cross 00:03:38.400 |
attention. In the simple case of self-attention we have one input sequence which is a sequence of in 00:03:44.240 |
the case of a language model a sequence of tokens, where we have seq number of tokens and each token 00:03:51.440 |
is represented by an embedding so a vector with d model dimensions. The first thing that we do 00:03:58.480 |
is we convert this input sequence into query key and values through three linear projections one 00:04:06.240 |
called wq one called wk one called wv which in pytorch are represented through linear layers 00:04:13.520 |
and these linear layers are of d model by d model so they do not change the shape of the input 00:04:19.520 |
tensor and then after we do this job of projecting them they become three different sequences one 00:04:29.200 |
called query one called key and one called value so here i'm calling them q prime k prime and v 00:04:34.960 |
prime then we divide them into smaller embeddings so each of this token which is made up of d model 00:04:43.920 |
dimensions we divide it into smaller tokens each one suppose we have four heads each one will have 00:04:51.680 |
d model divided by four dimensions so this one is a sequence of tokens where each token is not 00:04:58.400 |
the entire token but a part of the embedding of each token and this one is a another part of the 00:05:03.920 |
embedding of the tokens and this one is another part of the embedding of the token etc and we 00:05:08.720 |
do this job for the query key and value sequence then we compute the attention as follows so the 00:05:15.520 |
softmax of the query multiplied by the transpose of the key divided by the square root of dk 00:05:20.960 |
where dk is the dimension of each head so how many dimensions each head is working with and then we 00:05:29.360 |
do the multiplication with v and this will give us the output of the attention mechanism for each 00:05:34.480 |
head and this job is done independently for each head this should be clear to you if it's not please 00:05:41.120 |
watch my previous video on the attention mechanism because we will be working with this scenario a 00:05:47.120 |
lot now then we take this o the output of each head and then we concatenate it back in order to 00:05:57.200 |
get the representation of each token as a full embedding so before we split this embedding into 00:06:05.360 |
smaller embeddings this one here is called the q1 q2 q3 q4 then after we compute the attention we 00:06:12.640 |
get back the output of each head and we concatenate it together to get back the full embedding 00:06:17.760 |
dimension, which is this h here. We run it through another linear projection called wo, which 00:06:23.120 |
will be the output of the multi head attention now flash attention is not concerned with all of 00:06:29.440 |
these operations actually flash attention is only concerned with the operation that require 00:06:34.240 |
optimization and the operations that require optimizations are this one so the softmax of 00:06:40.400 |
the query multiplied by the transpose of the key divided by the square root of dk multiplied by v 00:06:44.800 |
which means that the projection of the input sequence through wq wk and wv is not something 00:06:51.440 |
that flash attention is concerned about because that's a matrix multiplication so when you use a 00:06:56.960 |
linear layer it's just a matrix multiplication of the input with the weight matrix of the linear 00:07:01.760 |
layer, and this kind of operation, the matrix multiplication, is one of the most optimized 00:07:08.560 |
operations that we have on the gpu, because the manufacturer of the gpu usually also releases 00:07:14.480 |
the necessary libraries for computing the matrix multiplication, so actually these are quite 00:07:20.400 |
fast and they do not require any optimization so flash attention will pretend that the query 00:07:25.520 |
has already been passed through wq, and the key has already passed through wk, and the v has 00:07:31.600 |
already passed through wv. Moreover, flash attention will not be concerned with the projection with 00:07:37.840 |
wo because that's also a matrix multiplication because the wo is always represented in pytorch 00:07:42.960 |
as a linear layer so it's a matrix multiplication and matrix multiplication as we have seen are very 00:07:49.680 |
optimized so there is nothing to optimize there but what we need to optimize in terms of speed 00:07:55.840 |
is this operation here softmax of the query multiplied by the transpose of the keys divided 00:07:59.440 |
by the square root of dk, multiplied by V. All right guys, so now we have rehearsed what is 00:08:05.440 |
multi-head attention i also want to give you a lot of visualization which is basically 00:08:10.240 |
here in the paper of the multi-head attention we can see that we have the input that is v 00:08:18.160 |
k and q so q k and v each of them runs through a linear layer which is the w q w k and w v 00:08:26.080 |
then we do the scaled dot product attention which is done independently for each head so each head 00:08:33.760 |
will do query multiplied by the transpose of the key divided by the square root of dk where each 00:08:39.040 |
query and each key is not the full embedding of each token but a part of the embedding of the 00:08:43.680 |
token because we split them into smaller embeddings and eventually we take all the output of each of 00:08:49.440 |
this head, which are computed in parallel, so that's why you see this dimension h in the depth; we 00:08:55.200 |
concatenate them and then we run them through w o what are we concerned with we are concerned 00:09:00.560 |
with optimizing this particular block here the scaled dot product attention so let's start our 00:09:05.840 |
journey one thing that is very important to understand is why do we even need a better 00:09:12.640 |
implementation of the attention mechanism and if you look at the flash attention paper you will 00:09:17.360 |
notice the following part this is the paper flash attention one and in the flash attention one paper 00:09:23.680 |
they describe the attention implementation as it's done naively when using pytorch so first 00:09:29.760 |
we do the multiplication of the query multiplied by the transpose of the keys then we apply the 00:09:35.680 |
softmax to the output of this operation and finally we multiply the output of the softmax 00:09:40.240 |
with the v matrix to obtain the output of the attention the way this implementation is done 00:09:46.160 |
by pytorch without any optimization is as follows. First of all, these tensors are 00:09:55.120 |
residing in the gpu. The gpu is made up of two main memories: one is called the hbm, which is the dram, 00:10:03.200 |
so the ram of the gpu, the 40 gigabytes of the a100 for example, so it's the 00:10:10.880 |
biggest memory that we have in the gpu, and then there is the shared memory. The problem 00:10:19.360 |
of the gpu is that accessing this hbm so the global it's also called the global memory it's very very 00:10:25.680 |
slow compared to the shared memory however the shared memory it's much much smaller compared to 00:10:30.560 |
the hbm and what they claim in the flash attention paper is that the operation of the attention is 00:10:37.200 |
i/o bound, meaning that the overall operation of computing 00:10:47.120 |
the attention is slow not because computing all these operations is slow, but because we keep accessing 00:10:52.880 |
the global memory, which is slow. So we call this kind of operation i/o bound, and the only way to 00:11:00.080 |
improve this situation is to compute the attention inside the shared memory of the gpu which is much 00:11:06.480 |
smaller which is much closer to the cores that actually do the computation so we will need to 00:11:13.120 |
kind of also split the attention computation into smaller blocks that can reside in the shared 00:11:19.600 |
memory and we will see later in how this is possible through block matrix multiplication 00:11:25.760 |
and this is, in the paper here, what they call tiling, and it's a very, how to say, widely used technique 00:11:32.640 |
when writing kernels for the gpu, which usually involve some kind of matrix multiplication. 00:11:40.640 |
so now we know what problem the flash attention is trying to solve it's trying to make sure that 00:11:47.120 |
we do not need to access the hbm so the high bandwidth memory when computing the attention 00:11:53.680 |
but copying only a part of each matrix inside the local memory so the shared memory of the gpu that 00:12:01.440 |
is closer to the cores and computing a part of the output matrix there then copying that part 00:12:09.120 |
to the output in that is residing in the hbm and keep doing it for all the blocks in which we can 00:12:15.120 |
divide this query key and value matrices and later we will see how this blocked computation is done 00:12:21.920 |
but also we will see that the biggest problem in computing this block computation is the softmax 00:12:27.280 |
because the softmax needs to access the entire row of the s matrix to apply the softmax, because 00:12:34.720 |
the softmax needs to have a normalization factor, which is the sum of all the exponentials 00:12:42.400 |
of all the values to which it is applied row wise and we will see later how we will solve this 00:12:47.680 |
problem. So let's move on. All right guys, and when I say guys I mean guys and girls, because 00:12:55.840 |
usually I just say guys, you know, but please girls don't feel excluded. So we 00:13:03.280 |
saw that, first of all, flash attention is only concerned with optimizing this softmax of the 00:13:08.960 |
query multiplied by the transpose of the keys, divided by the square root of dk, multiplied 00:13:14.480 |
by V, and we need to introduce a little bit of notation so that we don't get lost in the future 00:13:20.480 |
slides first of all this is the formulas i took from the flash attention paper but for now we 00:13:25.920 |
let's pretend flash attention never existed so we are trying to solve the problem step by step 00:13:30.080 |
now, we should treat this q as the sequence that is the output 00:13:37.760 |
of the input sequence that has already passed through wq, the k as something that has already 00:13:43.280 |
passed through wk and v as something that has already passed through wv because we don't want 00:13:49.040 |
to optimize the matrix multiplication because it's already fast enough another thing is let's 00:13:55.040 |
talk about what are the dimensions of these matrices so we can then understand what is the 00:13:59.520 |
the dimensions of the output of this operation. So we will treat q as a sequence of tokens 00:14:07.280 |
with n tokens, where each token has d dimensions, so lowercase d dimensions. 00:14:16.160 |
Why? Because usually we take the queries and then we split them into multiple heads, so we 00:14:22.480 |
pretend we have already done this splitting: we pretend we have already taken our input sequence, we have 00:14:27.040 |
already run it through wq and then we have already split it into multiple heads and each of this head 00:14:33.280 |
will do the following operation so the the one we already saw and so the usual formula query 00:14:40.400 |
multiply the transpose of the keys and each of this head will work with these dimensions 00:14:45.920 |
for the query for the key and for the value sequence so now let's look at the the dimensions 00:14:52.560 |
of the output so the first operation that we will do is the query multiplied by the transpose of the 00:14:56.800 |
keys, where the transpose of the keys is a matrix that originally is n by d but with the 00:15:03.040 |
transpose will be d by n so d by n and the result will be a matrix that is n by n because in the 00:15:10.480 |
matrix multiplication the outer dimensions become the dimension of the output matrix 00:15:14.480 |
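For reference, here is a minimal PyTorch sketch of the naive three-step computation described earlier. The sizes N and d are example values of my own, there is no batching, no multiple heads and no causal masking; this is only a shape check, not the kernel we will write later.

```python
import math
import torch

N, d = 8, 128                    # example sequence length and head dimension
Q = torch.randn(N, d)            # queries, assumed already projected through wq and split per head
K = torch.randn(N, d)            # keys, assumed already projected through wk
V = torch.randn(N, d)            # values, assumed already projected through wv

S = (Q @ K.T) / math.sqrt(d)     # (N, d) @ (d, N) -> (N, N)
P = torch.softmax(S, dim=-1)     # row-wise softmax, shape stays (N, N)
O = P @ V                        # (N, N) @ (N, d) -> (N, d)
print(S.shape, P.shape, O.shape)
```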
what is the next operation that we do? We take the output of this operation, so the query 00:15:20.400 |
multiply by transpose of the keys and we run it through a softmax operation and we will see what 00:15:24.320 |
is the softmax operation which preserves the shape of the input so it doesn't change the shape of the 00:15:30.720 |
input matrix it just changes the values of it and then we take the output of the softmax and 00:15:36.480 |
we multiply it by v, which will of course change the shape, because the 00:15:43.840 |
p matrix is n by n so this one is n by n and v is n by d so this one the output will be n by d the 00:15:54.800 |
outer dimensions of this matrix multiplication now let's look at the details of each of these 00:16:00.080 |
operations so when we do query multiply by transpose of the keys we will get a matrix that 00:16:04.560 |
is n by n where each value in this matrix is a dot product of a row of q and a column of k 00:16:13.840 |
in particular the first element of this matrix will be the dot product of the first query with 00:16:19.840 |
the first key vector the second element will be the dot product of the first query with the second 00:16:26.080 |
key vector and the third element will be the first query with the third key etc etc and the 00:16:31.920 |
let's say the the last row of this matrix will be the dot product of the last query with the first 00:16:39.680 |
key then the last query with the second key the last query with the third key etc etc until the 00:16:44.960 |
last query with the last key you may also notice that here i have written query transpose the key 00:16:53.440 |
because when we what is q1 first of all q1 is the first row of the query matrix 00:17:00.640 |
so a little bit of background on matrix multiplication so we know that when we do 00:17:06.800 |
matrix multiplication each output element is one row of the first matrix with one column of the 00:17:12.560 |
second matrix but we are doing the product of the first matrix with the transpose of the second so 00:17:18.480 |
it will be the dot product of the one row of the query matrix with one row of the key matrix 00:17:24.960 |
because we are doing the multiplication with key k transposed when you take a vector from a matrix 00:17:34.240 |
the usual notation, so, as in mathematics, in linear algebra, 00:17:42.240 |
we always pretend that a vector is a column vector, so we cannot just write q multiplied 00:17:48.240 |
by k, because that would mean we are doing the 00:17:55.920 |
matrix multiplication of one column matrix with one column matrix, which is not possible 00:18:02.480 |
because the shapes do not match. So as a notation we write that we do the dot product of the first 00:18:08.160 |
vector transposed, which is a column vector but we transpose it so it becomes a row vector, 00:18:13.520 |
with the second vector this is just because of notation guys so you just need to pretend that 00:18:20.320 |
this is the first query with the first key then the first query with the second key the first 00:18:24.640 |
query with the third key etc etc etc so we are doing dot products of vectors then we apply this 00:18:31.840 |
softmax operation the softmax operation what it will do it will transform each of these dot products 00:18:38.960 |
which are scalars so the output of a dot product is a scalar and it will transform each of these 00:18:45.280 |
numbers in such a way that they become kind of a probability distribution row wise which means that 00:18:52.000 |
each of these numbers is between 0 and 1, and when we sum up these numbers together they sum up 00:18:58.880 |
to 1 and this condition this property will be valid for each row so this row also will sum up to 1 00:19:05.760 |
this row will sum up to 1 and this row will sum up to 1 etc etc etc let's see what is the softmax 00:19:12.720 |
operation now given a vector so let's call it x which is made up of n dimensions the softmax is 00:19:21.760 |
defined as follows: the softmax basically transforms this vector 00:19:29.680 |
into another vector with the same dimension where each item of the output vector is calculated as 00:19:35.280 |
follows: the i-th element of the output vector is the exponential of the input element 00:19:42.400 |
divided by the summation of all the exponentials of all the dimensions of the vector 00:19:48.240 |
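Written out, the definition just described is, for a vector $x$ with $N$ dimensions:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}$$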
basically this is called the normalization factor: to make all these numbers between 0 and 1 we 00:19:55.600 |
usually normalize that's why it's called the normalization factor and we use the softmax 00:20:01.520 |
because we want each of these numbers to be positive; we don't want the output 00:20:06.480 |
of this operation to be negative so that's why we use the exponential but there is a problem 00:20:11.840 |
the problem is imagine our input vector is made up of many numbers that are maybe large so for 00:20:18.000 |
example let's say x1 is equal to 100, x2 is equal to 200, x3 is equal to 300, which can happen. 00:20:26.160 |
if we do the exponential of these numbers so the exponential of 100 that is going to be a 00:20:31.600 |
huge number, it's going to be very close to infinity, at least compared to what we can store in a 00:20:37.120 |
computer so the output of exponential of 100 may not fit into a floating point 32 or a floating 00:20:44.240 |
point 16 number or even an integer of 32 bit so we cannot compute it because it will overflow our 00:20:53.440 |
our variable our integer that is storing this value this output so we talk in this case about 00:21:00.080 |
numerical instability so every time you hear the term numerical instability in computer science 00:21:05.840 |
it means that the number cannot be represented within a fixed representation with the bits we 00:21:11.520 |
have available which are usually 32 bit or 16 bit we have also 64 bit but that will be too expensive 00:21:18.720 |
to use so let's try to find a solution to make this stuff here computable and numerically stable 00:21:25.440 |
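To make the overflow concrete, here is a tiny check (an illustration of my own; the exact thresholds depend on the data type, here float32):

```python
import torch

x = torch.tensor([100.0, 200.0, 300.0], dtype=torch.float32)
print(torch.exp(x))                       # tensor([inf, inf, inf]): e^100 already exceeds float32's largest finite value (~3.4e38)
print(torch.exp(x) / torch.exp(x).sum())  # tensor([nan, nan, nan]): inf / inf, so the naive softmax is unusable
```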
in order to make this softmax operation numerically stable which means that we want 00:21:31.600 |
these numbers to not explode or to become too small that they are not representable 00:21:36.880 |
we need to find a solution and luckily it's quite easy so the softmax as we have seen before it is 00:21:42.960 |
the following formula so each number is exponentiated and then we divide it by this 00:21:47.040 |
normalization factor which is just the sum of the exponential of each input dimension of the 00:21:51.600 |
input vector if we multiply the numerator and the denominator of a fraction with a constant 00:21:58.000 |
with a number then the fraction will not change so that's what we are going to do we are multiplying 00:22:01.760 |
the numerator and the denominator with this factor c as long as c is not equal to zero of course 00:22:08.080 |
then we can take this c and by using the distributive property of the product with 00:22:16.320 |
respect to the sum we can bring this c inside of the summation as you can see here 00:22:20.320 |
then we can also write every number as the exponential of the log of itself because the 00:22:27.520 |
exponential and the log will cancel out and then we can by using the properties of the exponentials 00:22:35.520 |
we know that the product of two exponentials is equal to the 00:22:40.400 |
exponential of the sum of the arguments of each exponential and we do it on the numerator and in 00:22:46.000 |
the denominator. Then we just call this quantity, minus log c, equal to k; or, minus k 00:22:54.960 |
is equal to log c, so we can replace this quantity with k. We can do that because this is a constant 00:23:01.600 |
that we have chosen and we just are assigning it to another constant so basically by doing 00:23:08.480 |
this derivation we can see that we can sneak in a value inside of this exponential that if chosen 00:23:15.440 |
carefully can reduce the argument of this exponential and we will choose this k equal 00:23:21.360 |
to the maximum element inside of the input vector that we are applying the softmax to 00:23:26.480 |
so that each of this argument will be either zero in case xi is equal to the maximum element 00:23:33.840 |
that we are processing of the vector or it will be less than zero and we know that the exponential 00:23:40.320 |
when its argument is equal to zero, the output of the exponential will be one; so when the 00:23:45.200 |
argument is zero, it will be equal to one, and when it's smaller than zero, so it's in the 00:23:50.160 |
negative range it will be between zero and one so which is easily representable with floating point 00:23:56.080 |
32 for example so this exponential will not explode anymore so basically to apply the softmax 00:24:04.400 |
to a vector in a numerically safe way we need to find a k constant which is the maximum value of 00:24:12.800 |
this vector and when we apply it we need to subtract each element minus this constant that 00:24:18.400 |
we have chosen so let's look at the algorithm to compute the softmax so first of all given a vector 00:24:25.760 |
or given an n by n matrix because we want to apply the softmax to this matrix here which is n by n 00:24:32.400 |
we need to go through each row of this matrix and for each row we need to find the maximum value 00:24:38.800 |
among the elements which takes time complexity linear with respect to the size of the vector 00:24:44.720 |
to the size of the row to which we are applying the softmax then we need to compute the normalization 00:24:49.760 |
factor, which is this stuff here, and we cannot compute it before step number one because we 00:24:56.720 |
need to have the maximum element to compute this summation here and after we have calculated the 00:25:02.960 |
normalization factor we can then divide each element's exponential by the normalization factor 00:25:08.640 |
and we cannot do the step number three before calculating the normalization factor because 00:25:13.200 |
we need to divide each number by the normalization factor so if you like pseudocode for algorithms 00:25:20.080 |
this is an algorithm for computing the softmax that we have seen right now so first we find the 00:25:25.120 |
maximum of the row to which we are applying the softmax then we compute the normalization factor 00:25:31.200 |
and then we apply the softmax to each element which means that we calculate compute the 00:25:35.360 |
exponential of each element minus the maximum value of the vector divided by the normalization 00:25:41.840 |
factor now this pseudocode is an algorithm that is quite slow because look at a practical example 00:25:50.240 |
imagine we have this vector here first we need to do step one find the maximum value in this 00:25:55.520 |
vector which is number five and this takes linear time computation then we need to calculate the 00:26:00.960 |
normalization constant which is the sum of the exponential of each element minus the maximum 00:26:06.960 |
value so e to the power of 3 minus 5 plus e to the power of 2 minus 5 etc etc this we will call it 00:26:14.000 |
l and then each we need to go again through this vector again and take the exponential of each 00:26:20.160 |
element minus the maximum divided by the normalization factor so to apply the softmax to 00:26:27.040 |
an n by n matrix we need to go through each element of this matrix three times and these 00:26:33.760 |
operations must be done sequentially so we cannot start operation two until we have done operation 00:26:38.880 |
one and we cannot start operation three until we have done one and two so this is quite slow 00:26:45.440 |
only to apply an operation that doesn't even change the shape of the matrix it's just 00:26:51.280 |
normalizing the values. So there must be a better way that does not involve 00:26:57.760 |
three sequential operations in which we need to go through this matrix three times let's see 00:27:02.480 |
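As a small sketch, the three-pass algorithm just described looks like this in plain Python, applied to one row at a time (the function name and the list-of-floats input are my own choices, mirroring the pseudocode on the slide):

```python
import math

def safe_softmax(x):
    # Pass 1: find the maximum of the row.
    m = -math.inf
    for xi in x:
        m = max(m, xi)
    # Pass 2: compute the normalization factor, with the maximum subtracted from each argument.
    l = 0.0
    for xi in x:
        l += math.exp(xi - m)
    # Pass 3: normalize each exponentiated element.
    return [math.exp(xi - m) / l for xi in x]

print(safe_softmax([3.0, 2.0, 5.0, 1.0]))
```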
all right guys let's rehearse what is the problem that we are trying to solve the problem statement 00:27:08.640 |
is the following can we find a better way to compute the softmax that does not involve going 00:27:15.120 |
through the vector three times because let's look at the pseudocode of the algorithm for computing 00:27:20.480 |
the softmax that we have found so far. Imagine we have a vector made up of four elements: 00:27:26.080 |
the first thing that we need to do is to compute the maximum element in this vector which means 00:27:30.960 |
going through this for loop here that allow us to compute the maximum element in this vector which 00:27:36.800 |
means that we start from the left side of the vector and iteratively go to the right side so 00:27:41.520 |
we start from the first element arrive to the end and we compare the previously found maximum with 00:27:47.680 |
the current element to find the global maximum. Now, I know that this 00:27:53.360 |
is very simple, and I'm pretty sure that you don't need this example, but making this example will 00:27:58.720 |
help us understand what we will do next so please bear with me even if it's super simple what i'm 00:28:03.520 |
doing okay we at the beginning m0 is equal to minus infinity m1 is basically the for loop at 00:28:14.000 |
the iteration number one, which means that m1 will be equal to the maximum of the previous 00:28:19.920 |
estimate of the m which is minus infinity with the current element which is three so it will become 00:28:26.960 |
equal to three then m2 will be equal to the maximum of the previously computed maximum so m1 00:28:33.840 |
so three with the current element which is two so it will be equal to three m3 will be equal to 00:28:40.400 |
the maximum of the previously computed maximum so three with the current three with the current 00:28:46.880 |
element which is five so it will be equal to five and m4 will be equal to the maximum of the 00:28:53.200 |
previously computed maximum and the current element so it will be equal to five so this allow 00:28:57.920 |
us to compute the maximum element so at the fourth iteration we will have the maximum the global 00:29:02.320 |
maximum independently of what the input array is. Okay, after we have computed the maximum 00:29:10.080 |
which we know is five we can compute the normalization factor so let's start with the 00:29:15.520 |
l0 l0 is equal to zero l1 will be equal to the exponential of l0 so actually sorry it will be 00:29:23.120 |
l0 plus the exponential of the current element so three minus the maximum element we have found in 00:29:29.680 |
the previous for loop so five then l2 will be equal to l1 plus the exponential of the 00:29:35.920 |
the current element so it's two minus the maximum then l3 will be equal to l2 plus the exponential 00:29:45.360 |
of the current element five minus five then l4 will be equal to the l3 plus exponential of one 00:29:54.080 |
minus five if you expand this l this will be basically equal to e to the power of three minus 00:30:02.160 |
five plus e to the power of two minus five plus e to the power of five minus five plus e to the 00:30:07.760 |
power of one minus five. After we have computed this normalization factor we can use it 00:30:16.320 |
to normalize the each element in the input vector which means that the x new x1 so x1 prime let's 00:30:23.040 |
see will be equal to e to the power of what's the first element three minus five divided by l that 00:30:34.240 |
we computed in the previous for loop so the l at the fourth iteration the new x2 so x2 prime will 00:30:43.360 |
be equal to the e to the power of two minus five divided by l4 and x3 prime will be equal to the e 00:30:50.640 |
to the power of five minus five divided by l4 etc etc for all the elements i know this is super 00:30:58.080 |
simple but it will help us later so in this for loop we have that we need to go through the vector 00:31:05.600 |
three times because first we need to compute this for loop then we need to compute this for loop 00:31:10.560 |
and then we need to compute another for loop we cannot do them not in this sequence because in 00:31:15.680 |
order to compute this for loop we need to have the maximum element because we need it here 00:31:19.680 |
and we cannot compute this for loop until we have computed the previous one because we need to have 00:31:24.240 |
the normalization factor however we are stubborn and let's try to fuse these two operations into 00:31:30.800 |
one for loop which means that we go through the array and simultaneously compute mi and in the 00:31:37.440 |
same iteration we also try to compute lj of course we will not be able to compute lj because we don't 00:31:43.440 |
have the global maximum because we didn't go through the whole array yet. However, let's try to 00:31:50.080 |
use the local, whatever estimate we have of the maximum so far; so let's try to use, instead of 00:31:56.240 |
mn let's try to use mi so the local maximum that we have computed so far so if we apply the softmax 00:32:02.720 |
in this way in this fused way to this vector we will have the following iterations so this is our 00:32:10.160 |
array or vector and the first step is mi so m1 will be equal to the previous maximum which is 00:32:17.440 |
minus infinity with the current element so the maximum minus infinity and the current element 00:32:22.480 |
is equal to 3 and l1 will be equal to the previous l so l0 which is starts from 0 plus e to the power 00:32:31.600 |
of the current element minus we should be using the global maximum but we don't have the global 00:32:36.240 |
maximum so let's use the whatever maximum we have so far so we can use 3 now at the second iteration 00:32:42.560 |
we are at this element of the vector and we compute the maximum so far so the maximum so far 00:32:48.480 |
is the previous maximum and the current element so the maximum of the previous maximum and the 00:32:52.640 |
current element which is the maximum between 3 and 2 which is 3 and the normalization factor 00:32:59.120 |
is the previous normalization factor plus exponential of 2 minus 3 which is the current 00:33:04.720 |
element minus whatever maximum we have so far now if our array were made only of these two elements 00:33:12.640 |
so 3 and 2 then whatever we have computed is actually correct because the maximum that we 00:33:19.120 |
have found is a 3 and it's actually the global maximum and the normalization factor that we 00:33:24.800 |
have computed is actually correct because each of the exponential has been computed with the global 00:33:29.680 |
maximum, because the first element was computed with the argument minus 3, and 00:33:37.200 |
also the second element was computed with the argument having minus 3 00:33:42.720 |
in it, which is the global maximum of the vector. However, when we arrive at the third 00:33:48.560 |
iteration so let me delete this vector so let me arrive here at the third iteration the maximum 00:33:54.400 |
will change which will also cause our normalization factor to get to to be wrong because 00:34:00.640 |
we arrive at the element number 3 so the number 5 here and we compute the maximum 00:34:07.760 |
so the maximum is the comparison of the previous maximum and the current element so the new maximum 00:34:13.520 |
becomes 5 and the normalization factor is the previous normalization factor so l2 plus the 00:34:19.760 |
exponential of the current element minus the current estimate of the maximum which is 5 00:34:26.240 |
however if you look at this l3 this is wrong why because l3 is equal to if you expand this 00:34:34.240 |
summation it will be equal to e to the power of 3 minus 3 plus e to the power of 2 minus 3 00:34:42.720 |
plus e to the power of 5 minus 5 this exponential here is using 5 as the global maximum this 00:34:50.320 |
exponential here is using 3 as the global maximum and this one is using 3 as the global maximum 00:34:55.440 |
so the first two elements have been computed thinking that the global maximum is 3 but 00:35:00.560 |
actually we later we found a better global maximum which is 5 so which makes this normalization 00:35:05.920 |
factor wrong however can we fix at the third iteration whatever normalization we have computed 00:35:13.280 |
so far up to the second iteration actually we can because if we expand this so as we have 00:35:20.800 |
here we have expanded it what we need here is here to have a minus 5 because that's actually 00:35:27.360 |
the global maximum that we have found so far not the minus 3 that we had at the previous iteration 00:35:32.880 |
so and here we also need to fix this replace this minus 3 with minus 5 how can we do that well if 00:35:39.040 |
we multiply this one here and this one here with a correction factor that will sneak in a new maximum 00:35:47.520 |
inside of this exponential then we solve the problem and actually this correction factor 00:35:52.160 |
is very easy to calculate because at the third iteration if we multiply l2 so the previously 00:35:57.840 |
computed normalization factor with this factor here which is the exponential of the previous 00:36:03.040 |
estimate of the maximum minus the current estimate of the maximum so 5 we will see that 00:36:08.960 |
by the properties of the exponentials, this one here will become e to the power of 3 minus 3 plus 00:36:16.800 |
3 minus 5, so this minus 3 will cancel out with this 3, and also in the second factor this 00:36:23.200 |
minus 3 will cancel out with this 3, and they will become e to the 00:36:28.400 |
power of 3 minus 5 and e to the power of 2 minus 5, which is actually correct 00:36:34.640 |
because at the third iteration we should actually be using minus 5 as the 00:36:40.400 |
maximum of the array so far so basically what we have found is a way to fix whatever normalization 00:36:48.960 |
factor we have computed so far while iterating through the array when we found we when we find 00:36:55.680 |
a better maximum compared to what we have so far and when we don't need to fix anything then the 00:37:02.160 |
formula still stands because what we did here as a multiplication as a correction factor so this is 00:37:08.080 |
the correction factor. This correction factor is nothing more than the exponential of the previous maximum, so the 00:37:15.840 |
previous estimate of the maximum, minus the current estimate of the maximum at the current iteration, 00:37:21.440 |
so the current max so this is basically m of i minus 1 and this is m of i so the current maximum 00:37:30.400 |
at the current iteration and let me delete it otherwise it remains forever in my slides 00:37:35.440 |
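Written compactly, the update that has just been described is, with $m_i$ the running maximum and $l_i$ the running normalization factor after the first $i$ elements, and with $m_0 = -\infty$, $l_0 = 0$:

$$m_i = \max(m_{i-1}, x_i), \qquad l_i = l_{i-1}\, e^{\,m_{i-1} - m_i} + e^{\,x_i - m_i}$$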
so basically when we arrive to the last element we will see that the maximum doesn't change because 00:37:42.480 |
we compare the previous maximum with the current element which is less than the previous maximum so 00:37:47.760 |
the maximum doesn't change and we don't need to fix anything because the previous l3, so 00:37:55.600 |
the previously computed normalization factor is correct because they have all been using the minus 00:38:01.200 |
5 so when we don't need to fix anything we just multiply by e to the power of the previous maximum 00:38:08.080 |
minus the current maximum which is e to the power of zero in this case so it's not fixing anything 00:38:13.440 |
so we have found a way to fix the previously computed normalization factor while going 00:38:20.240 |
through the array even if at the current iteration we don't have the global maximum yet so that every 00:38:27.200 |
time the maximum changes we can fix and every time it doesn't change we just multiply with e to the 00:38:32.400 |
power of zero which is like multiplying with one so the new algorithm that we have found for the 00:38:38.320 |
softmax is the following so we start with m0 equal to minus infinity we start with l0 equal to zero 00:38:44.560 |
we go through the array, we compute the local maximum, so the maximum so far, 00:38:52.640 |
from the zeroth element to the ith element so to the element at which we are doing the iteration 00:38:59.680 |
and the previously computed li can be fixed by using this correction factor which is e to the 00:39:06.160 |
power of the previous maximum minus the current maximum plus the exponential of the current 00:39:12.560 |
element minus the current estimate of the maximum in this way we go through the array only once 00:39:19.120 |
and we obtain two values: the global maximum at the end and, at the same time, the 00:39:27.120 |
normalization factor, and then we can use it to compute the softmax. So we transformed 00:39:33.520 |
three passes through the array into two passes through the array, and this is very important. 00:39:40.080 |
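As a sketch, the fused version looks like this in plain Python (the names are my own and this is only an illustration of the two-pass algorithm, not the Triton code we will write later):

```python
import math

def online_softmax(x):
    # Single fused pass: running maximum m and running normalizer l.
    # Whenever the maximum changes, the previous l is corrected by exp(m_old - m_new).
    m, l = -math.inf, 0.0
    for xi in x:
        m_new = max(m, xi)
        l = l * math.exp(m - m_new) + math.exp(xi - m_new)
        m = m_new
    # Second (and last) pass: normalize with the final m and l.
    return [math.exp(xi - m) / l for xi in x]

print(online_softmax([3.0, 2.0, 5.0, 1.0]))  # same output as the three-pass version
```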
and we will see how we actually use it to derive flash attention the example that i have given you 00:39:46.880 |
so far is not really a proof that our algorithm will work in every case because we made a very 00:39:52.800 |
simple example by using a vector made up of four elements but does our new algorithm work in every 00:40:00.400 |
single case with whatever the numbers are we need to prove that so we will prove that by induction 00:40:06.960 |
so what first of all what are we trying to prove we have fused the first two for loops into one 00:40:13.280 |
for loop as you can see here what we expect is that at the end of this for loop this mn so the 00:40:21.200 |
m at the last iteration will be actually the global maximum in the vector and this ln so the 00:40:27.840 |
l at the last iteration will be equal to the sum of all the exponential of all the elements 00:40:34.240 |
minus the maximum element of the vector so the global maximum of the vector and we need to prove 00:40:41.440 |
that because what i did before was an example and that was not really a rigorous proof and the way 00:40:47.520 |
we will prove it is by induction which is a typical way of proving this kind of theorems 00:40:53.120 |
now proof by induction basically works in the following way we need to prove that 00:40:59.200 |
our algorithm works for a base case, for example with n equal to one, and then we assume 00:41:08.800 |
that the algorithm works on n and we need to prove that it also works for n plus one if this holds 00:41:17.120 |
then we have proven our algorithm for every possible n because it will work for the base 00:41:22.720 |
case, so for example n equal to one, and then by using the induction step we say: if it 00:41:29.200 |
works for n then it also works for n plus one, so it means that it will also work for two, but 00:41:34.080 |
then if it works for two then it should also work for three because of the induction step that we 00:41:38.240 |
will prove and if it works for three then it will also work for four etc etc up to infinity so let's 00:41:44.560 |
prove it for the base case which is n equal to one it's very simple so at n equal to one this 00:41:51.680 |
for loop will only have one iteration, so m1 and l1. m1 will be the maximum of the previous m, which 00:41:59.120 |
is minus infinity because we initialize m0 equal to minus infinity so it will be equal to the 00:42:07.600 |
maximum of the previous m and the current element which is x1 so it will be equal to x1 whatever x1 00:42:13.200 |
is; x1 can never be equal to minus infinity because it's a 00:42:20.960 |
number in a fixed representation, so it cannot be minus infinity. So m1, at the end, 00:42:29.040 |
because we have only one element, n equal to one, this m1 is also the last m of 00:42:37.280 |
this for loop, and it will be equal to the global maximum of the vector made up of only one 00:42:42.560 |
element and l1 will be equal to the previous l which we start from zero so l0 multiplied by a 00:42:50.480 |
correction factor which will be in this case e to the power of minus infinity because the correction 00:42:54.960 |
factor is the previous estimate of the max minus the current estimate of the max, but the 00:43:01.360 |
previous estimate of the max is minus infinity, and minus infinity minus x1 is equal to minus infinity, so this 00:43:07.120 |
term will be cancelled out, and then plus e to the power of x1 minus the current maximum, 00:43:14.640 |
which is x1, so m1. And this will be equal to the sum of the exponentials of all the elements of the vector, which 00:43:24.000 |
is made up of only one element minus the maximum element in the array which is x1 so we have proven 00:43:31.280 |
that it works for n equal to one now we assume that it works for n does it also work for an array 00:43:39.200 |
or a vector of size n plus one? So let's see what happens at the n plus one iteration. 00:43:49.200 |
at the n plus one iteration we will be doing the maximum of the previous estimate of m which is 00:43:54.880 |
the m at the nth iteration, and the current element, so x of n plus one. This, by the properties of the 00:44:03.120 |
max function, will actually be equal to the global maximum of the vector up to n plus one, 00:44:10.400 |
because the maximum will choose whatever is the maximum between the previous estimate and the 00:44:16.320 |
current estimate and ln plus one which is the normalization factor at the n plus one iteration 00:44:23.520 |
will be equal to the ln, so not the previous estimate but the previous 00:44:27.920 |
normalization factor at the nth iteration multiplied by the correction factor which is the 00:44:35.040 |
previous maximum minus the current maximum plus the exponential of x the current element 00:44:43.600 |
minus the current estimate of the maximum. But for ln we have assumed that this property, so this 00:44:54.160 |
algorithm, works up to n, so ln is for sure equal to the sum of all the exponentials 00:45:03.600 |
of the elements of the vector up to n, minus the local maximum of the vector up to the nth element, 00:45:12.800 |
which is mn we multiply by the correction factor if there is something to correct which will be the 00:45:21.360 |
previous maximum minus the current maximum plus the exponential of the current element minus the 00:45:26.880 |
current estimate of the maximum now by the properties of the exponentials so we can bring 00:45:35.440 |
this one inside of the summation and we will see that this mn and this mn will cancel out because 00:45:42.240 |
it will be exponential of xj minus mn plus mn minus mn plus one so this mn and this mn will 00:45:49.520 |
cancel out and we obtain this one plus this factor here that remains unchanged however you can see 00:45:55.920 |
that this stuff here is exactly the argument of this summation at the iteration n plus one: 00:46:04.800 |
it is e to the power of xj, where j is going from one to n, minus mn plus one, plus 00:46:13.280 |
e to the power of xn plus one minus mn plus one. So the j only goes up 00:46:21.360 |
to n, and this last term is like the term for j equal to n plus one, so we can increase the upper index of this summation 00:46:29.200 |
by one and it will be the same and it will result in the same summation so we have proven 00:46:36.240 |
that also at the n plus one iteration we will have that the l will be equal to the sum of all 00:46:43.920 |
the elements of the array the exponential of all the elements of the array up to the n plus one 00:46:49.760 |
element minus the maximum up to the n plus one element so we have proven that if it works and 00:46:57.840 |
then it also works for n plus one this is enough to prove that it works for all size of arrays 00:47:04.880 |
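For reference, the inductive step that was just walked through can be written in one line, using the assumption that $l_N = \sum_{j=1}^{N} e^{x_j - m_N}$:

$$l_{N+1} = l_N\, e^{\,m_N - m_{N+1}} + e^{\,x_{N+1} - m_{N+1}} = \Big(\sum_{j=1}^{N} e^{\,x_j - m_N}\Big)\, e^{\,m_N - m_{N+1}} + e^{\,x_{N+1} - m_{N+1}} = \sum_{j=1}^{N+1} e^{\,x_j - m_{N+1}}$$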
don't worry if you didn't get the proof by induction; if it's the first time you are 00:47:11.680 |
seeing this kind of proof it may take a little bit to get it. If you want to learn a little 00:47:17.520 |
bit more about proof by induction i recommend watching some other proof it's very simple it's 00:47:22.000 |
just you need to get into the right mindset anyway let's move forward all right let's talk about 00:47:29.680 |
block matrix multiplication i know that you want to jump to the code immediately and we will go 00:47:34.720 |
there we just need a little more theory actually so imagine we are doing a matrix multiplication 00:47:40.800 |
so we have a matrix a we want to multiply it with a matrix b and it will produce an output matrix 00:47:47.200 |
c imagine the dimensions of the first matrix are m by k the second matrix is a k by n it will 00:47:54.800 |
produce an output matrix that is m by n now imagine we want to parallelize the computation of this 00:48:01.600 |
output matrix i know that i didn't talk about gpus yet so we will not talk about gpus we will 00:48:08.800 |
talk about parallelization in the case of a multi-core cpu with which you are very probably 00:48:14.720 |
familiar with because right now in nowadays when you buy a computer you have a cpu and usually you 00:48:21.120 |
can buy a single core cpu or multi-core like a two core four core eight core etc etc each of the 00:48:27.760 |
these cores are actually kind of small cpus inside your cpu that can execute operations in parallel 00:48:33.520 |
how to parallelize the matrix multiplication imagine you have this matrix multiplication 00:48:39.520 |
to parallelize each of the output element in this c matrix is a dot product of a row of the 00:48:46.320 |
a matrix with a column of the b matrix for example this element on the top left is the dot product 00:48:53.280 |
of the first row of a and the first column of b this element on the top right of c is the dot 00:49:00.240 |
product of the first row of a and the last column of b this element on the bottom left is the dot 00:49:07.040 |
product of the last row of a and the first column of b etc etc for all the other elements now to 00:49:13.600 |
parallelize this computation we need as many cores as there are elements in c if we want to 00:49:20.080 |
parallelize it so if m and n are very small then maybe we have enough cores but imagine m and n are 00:49:28.080 |
quite big, imagine like 100 by 100; we don't have 10,000 cores right now in cpus, so how can we 00:49:36.720 |
parallelize a matrix operation by using less cores than there are elements in the matrix itself 00:49:44.640 |
that's when we talk about block matrix multiplication basically block matrix 00:49:50.160 |
multiplication means that you can divide the original matrix into smaller blocks of elements 00:49:57.360 |
and then the operations of matrix multiplication can be computed between these blocks for example 00:50:04.800 |
imagine we have a matrix that is 8 by 4 it means that it has 8 rows and 4 columns which means that 00:50:13.680 |
it has 32 elements and then we are multiplying it with another matrix that is 4 by 8 so it has 4 00:50:22.560 |
rows and 8 columns so it also has 32 elements the output matrix will should have 64 elements 00:50:31.680 |
we don't have 64 cores so how can we parallelize it imagine we only have 8 cores now with 8 cores 00:50:39.600 |
we can divide this original matrix a into 4 blocks where the first block is this top left block of 00:50:47.520 |
4 by 2 elements, so, let's say, 8 elements on the top left, and then 8 elements 00:50:58.960 |
on the top right of this matrix then 8 elements on the bottom left and 8 elements in the bottom 00:51:04.240 |
right of this matrix these are 4 blocks then we divide also the b matrix into um 8 blocks 00:51:11.200 |
where each block is made up of 4 elements so this b11 is the top left 4 elements in the original 00:51:19.680 |
matrix, this b14 is the top right 4 elements in the original matrix, this b21 is the 00:51:28.080 |
bottom left 4 elements in the original matrix etc etc etc how do we do this block matrix 00:51:33.600 |
multiplication? We can view these matrices as made only of their blocks, so we can view this 00:51:40.720 |
matrix here as made up only by its blocks we can view this matrix here as made up only by its blocks 00:51:48.320 |
and the output of this multiplication will be a matrix that is computed in the same way as the 00:51:55.600 |
original matrix but where the output of each dot product will not be a single element of the output 00:52:02.880 |
matrix but it will be a block of elements of the output matrix for example the top left block here 00:52:10.720 |
is the dot product of the first row of this matrix with the first column of this matrix 00:52:18.480 |
and it will be computed as follows so it will be a11 multiplied by b11 plus a12 multiplied by b21 00:52:25.920 |
and this output will not be a single scalar but it will be, well, let me count, it should be 00:52:33.600 |
a block of eight elements, 00:52:40.320 |
let me count actually; we can see that 00:52:47.600 |
here how to find the dimensions of this output 00:52:55.520 |
block well we can check what is a11 a11 is four by two so it's eight elements in a smaller matrix 00:53:05.920 |
made up of eight elements where the elements are distributed in four rows and two columns 00:53:10.960 |
we are multiplying it by b11 which is a smaller matrix compared to the original made up of 00:53:16.560 |
two by two elements so four elements so when we multiply four by two multiplied by two by two it 00:53:23.280 |
will produce a four by two output block. So if we do this computation here block by 00:53:32.240 |
block it will produce a block of output elements of the original matrix, so not a single scalar 00:53:38.400 |
but a block of outputs which makes it very easy to parallelize because if we have only eight cores 00:53:44.320 |
we can assign each output block to one core and each core will not produce one output element of 00:53:50.480 |
the original matrix but it will produce eight elements of the original matrix as a four by two 00:53:56.400 |
matrix. So basically block matrices allow us to do the matrix multiplication either element 00:54:05.360 |
by element so like in the original matrix so each row with each column or blocks by blocks in the 00:54:10.320 |
same way like we do normal matrix multiplication, because the matrix multiplication that we are 00:54:15.280 |
doing between blocks is the same way as we do matrix multiplication with the original matrix 00:54:21.120 |
and it will produce not a scalar but a block and now let's see why this is very important for us 00:54:26.320 |
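To check that computing block by block really gives the same result as the ordinary product, here is a small sketch (using the 8 by 4 and 4 by 8 example above; the block sizes and names are my own choices):

```python
import torch

A = torch.randn(8, 4)                      # viewed as a 2x2 grid of (4, 2) blocks
B = torch.randn(4, 8)                      # viewed as a 2x4 grid of (2, 2) blocks
BLOCK_M, BLOCK_K, BLOCK_N = 4, 2, 2

C = torch.zeros(8, 8)
for i in range(0, 8, BLOCK_M):             # block-row of A (and of C)
    for j in range(0, 8, BLOCK_N):         # block-column of B (and of C)
        # Each (i, j) output block could be assigned to a different core:
        acc = torch.zeros(BLOCK_M, BLOCK_N)
        for k in range(0, 4, BLOCK_K):     # sum over the shared (inner) dimension, block by block
            acc += A[i:i+BLOCK_M, k:k+BLOCK_K] @ B[k:k+BLOCK_K, j:j+BLOCK_N]
        C[i:i+BLOCK_M, j:j+BLOCK_N] = acc

print(torch.allclose(C, A @ B))            # True: block-by-block equals the ordinary product
```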
so why should we care about block matrix multiplication because we are trying to 00:54:33.920 |
compute the following operation so the query multiplied by the transpose of the keys 00:54:38.640 |
and then we should apply the softmax to the output of this operation, and then we should 00:54:42.240 |
multiply the output of the softmax with v for now let's ignore the softmax let's pretend that 00:54:48.960 |
we are not going to apply any softmax so we take the output of the query multiplied by the transpose 00:54:54.400 |
of the keys and we just multiply it by v to obtain the output of the attention, which is wrong of course, 00:54:58.560 |
but it simplifies our treatment of what we are going to do next. So for this moment let's 00:55:04.800 |
pretend that we are not going to apply any softmax so we just do the query multiplied by 00:55:08.400 |
transpose of the keys and directly we multiply the result of this operation with v this will 00:55:12.720 |
result in a matrix that is n by d so n tokens each made up of an embedding of d dimensions 00:55:19.680 |
so lowercase d dimensions and we know that query key and values are themselves matrices of n by d 00:55:28.320 |
dimensions, so n tokens each made up of an embedding of d dimensions. So imagine we have a 00:55:36.400 |
query matrix and the key and the value matrix that are 8 by 128 so we have 8 tokens each token is 00:55:44.000 |
made up of 128 dimensions. As we have seen, when we compute a matrix multiplication 00:55:51.600 |
we can divide our matrices into blocks, and how we choose the blocks is up to us as long as 00:55:59.760 |
the shapes of the blocks match when doing the matrix multiplication. So for example 00:56:05.760 |
in the previous case we divided our matrix a into blocks such that the shape of the block matrix 00:56:13.360 |
so the matrix that is made up only of the blocks is compatible with the block matrix b 00:56:20.080 |
so that this operation is possible so this is the only requirement that we need to be 00:56:25.040 |
aware of when doing the block matrix multiplication: the shapes of the blocked matrices, so the matrices 00:56:30.000 |
that are made only of the blocks, should match in the matrix multiplication. For the rest, it doesn't 00:56:35.280 |
matter how we divide it so imagine that we choose to divide this query matrix into blocks of rows 00:56:42.960 |
and we can do that we don't have to necessarily divide also the columns we can just divide the 00:56:47.200 |
rows so that each q is not a single row but it's a group of two rows so q1 is a group of the first 00:56:54.640 |
two rows of the q matrix of the q sequence q2 is the group of the second two rows of the q sequence 00:57:01.120 |
etc etc and we do the same also for v for k we don't do it because we are actually going to 00:57:07.760 |
multiply with k transposed, so we do the subdivision directly on k transposed. So we have the q 00:57:15.680 |
which has been divided into groups of rows and then we have a k transposed which is a matrix 00:57:21.920 |
that is 128 by 8, because it's the transpose of the keys which is 8 by 128, and we decide to divide 00:57:31.120 |
each group of columns of k transposed into a single block, so the k1 is the first two columns 00:57:38.960 |
of k transposed k2 is the second group of two columns in k transposed etc etc until k4 which 00:57:46.400 |
is the last two columns in k transposed the first operation that we do is the multiplication query 00:57:51.840 |
multiplied by the transpose of the keys which basically means that we need to multiply each 00:57:56.640 |
query with all the keys then the second query with all the keys etc etc now each query is not a single 00:58:04.960 |
row of the q sequence it's a group of two rows of this q sequence and each k is not a single column 00:58:12.080 |
of k transposed it's a group of two columns of k transposed but doesn't matter because we have 00:58:17.200 |
seen that the matrix multiplication if we write the matrices as made up of blocks we just compute 00:58:22.880 |
it in the same way when we do a normal matrix multiplication so we are multiplying this matrix 00:58:28.720 |
by this matrix and for what we know this matrix here is made up of four rows with some dimensions 00:58:36.640 |
which is 128 dimensions and this one here is made up of how many rows 128 rows and four columns 00:58:46.960 |
i didn't draw the columns because it's too many to draw here but you need to pretend it's a lot of 00:58:53.760 |
dimensions, 128 for each vector, and here you need to pretend that this is 128 00:59:00.080 |
rows when we do the matrix multiplication we apply the normal matrix multiplication procedure 00:59:05.760 |
which is, first of all, the output shape of this matrix 00:59:10.400 |
multiplication will be four by four because it's the outer dimensions of the two matrices that you 00:59:14.560 |
are multiplying the first element of the output will be the dot product of this vector here 00:59:21.280 |
with this vector here the second element so this one here will be the dot product of this 00:59:27.120 |
vector here with this vector here. However, these are not really vectors, so it's 00:59:33.840 |
actually a matrix multiplication in this case this element here is not a scalar it is a group of 00:59:40.320 |
elements of the output matrix because we are doing block matrix multiplication and how many elements 00:59:45.760 |
it will be? Well, we know that the original q1 is 2 by 128, the k1 is 128 by 2, so it will be a group 00:59:54.640 |
of 2 by 2 elements of the output matrix so we are doing the matrix multiplication of the q1 01:00:01.840 |
with k1 then q1 with k2 then q1 with k3 q1 with k4 etc etc for the first row and then the second 01:00:10.240 |
row will be q2 with all the k's and the q3 with all the k's and q4 with all the k's so as you can 01:00:16.000 |
see when we do matrix multiplication we don't even care if what is underlying is a block or a vector 01:00:22.880 |
or a scalar, we just apply the same procedure: first row of the first block matrix with 01:00:30.320 |
the first column of the second matrix, and then the first row with the second 01:00:36.400 |
column the first row with the third column etc etc let's then multiply because the formula says 01:00:43.920 |
that we need to multiply query with the transpose of the keys and then multiply by v. All of these are 01:00:49.280 |
block matrices now as you can see from my using of colors every time i refer to the original matrix i 01:00:57.680 |
use the blue color and every time i refer to the block matrix i use the pink color so we need to 01:01:04.160 |
multiply the output of the query multiplied by the transpose of the keys then by v, because we are 01:01:08.720 |
skipping for now the softmax and later we will see why so if we want to do this multiplication we 01:01:14.560 |
need to do the following so it will be uh this matrix is made up of blocks and block matrix 01:01:22.080 |
multiplication just ignores this fact and just does the matrix multiplication like it is a normal 01:01:26.560 |
matrix multiplication so we do the first row with the first column then the first row with the second 01:01:33.200 |
column, then the first row with the third column, etc. So how is the first block row 01:01:40.320 |
of the output matrix of this matrix multiplication going to be calculated? 01:01:46.080 |
Well, it will be the dot product of the first row (the dot product, so to speak, because it's not 01:01:55.280 |
really a dot product, it's actually the matrix multiplication of the first row but in a dot 01:02:00.320 |
product way, let's say) with the first column, which is made up of v1 v2 v3 and v4. So it will be this 01:02:09.360 |
element with v1 plus this element with v2 plus this element with v3 plus this element with v4 01:02:16.640 |
and this will be the first output element the second output block will be this row 01:02:26.080 |
with this column, which will be this element with v1 plus this element with 01:02:33.360 |
v2 plus this element with v3 plus this element with v4 and this will produce the second output 01:02:39.280 |
block etc etc also for the third and the fourth block output let's look at what is each block 01:02:47.360 |
made up of so each block is made up of the um the first element so query one multiplied by key one 01:02:54.240 |
because um it's the result of the query multiplied by the keys with the v1 of the second matrix 01:03:01.600 |
plus the this element with this one plus this element with this one plus this element with 01:03:07.760 |
this one so the pseudocode for generating this output of this attention mechanism which is not 01:03:14.560 |
really attention mechanism because we skip the softmax but i just want you to get into the 01:03:19.120 |
habit of thinking in terms of blocks is the following so we take each query block 01:03:25.920 |
we go through each query and as you can see let's look at actually what this output is made up of 01:03:34.080 |
it is made up of the query one multiplied by key one and the result multiplied by v1 then the query 01:03:40.400 |
one with k2 then the result multiplied by v2, then the query one with k3 and the result multiplied 01:03:47.280 |
by v3 plus the query one with the k4 and result multiplied by v4 this is basically what we are 01:03:53.280 |
doing: the dot product of this row with this column, made up of blocks. So the pseudocode 01:04:02.160 |
for generating this first row is: the query is query number one, and then we iterate through the 01:04:09.280 |
keys and the values from one to four and we sum iteratively, for each block, to generate 01:04:16.720 |
this output matrix, and for each row we will see that it's a different query with all the keys 01:04:22.800 |
and values and then this will be the the query number three with all the keys and values and 01:04:27.600 |
this will be the query four with all the keys and values so to generate this output matrix we need 01:04:33.040 |
to iterate through the queries, and this will be one row of this output matrix, and then we need 01:04:39.120 |
to do this iterative sum of the query i that we are iterating through multiplied by the jth k 01:04:46.640 |
and v, and we keep summing them iteratively, and that will produce the output matrix, 01:04:53.040 |
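The pseudocode above can be sketched in plain Python/NumPy as follows (an illustration added for clarity, with assumed sizes: 8 tokens, 128 dimensions, blocks of 2 rows); it computes (Q K^T) V block by block, with the softmax intentionally skipped, and checks the result against the direct product:

```python
import numpy as np

N, D, B = 8, 128, 2                      # tokens, head dim, block size (assumed)
Q, K, V = (np.random.randn(N, D) for _ in range(3))

# Split Q and V into blocks of rows; K^T is split into blocks of columns,
# which is the same as splitting K into blocks of rows.
Q_blocks = [Q[i:i + B] for i in range(0, N, B)]
K_blocks = [K[j:j + B] for j in range(0, N, B)]
V_blocks = [V[j:j + B] for j in range(0, N, B)]

O = np.zeros((N, D))
for i, Qi in enumerate(Q_blocks):        # one output block of rows per query block
    Oi = np.zeros((B, D))
    for Kj, Vj in zip(K_blocks, V_blocks):
        Sij = Qi @ Kj.T                  # (B, B) block of Q K^T
        Oi += Sij @ Vj                   # accumulate the (B, D) contribution
    O[i * B:(i + 1) * B] = Oi

assert np.allclose(O, (Q @ K.T) @ V)     # same as the full (unnormalized) product
```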
as you can see here. I know that what we have done so far is not yet directly useful for flash 01:04:59.280 |
attention but it's useful for us to get into the mindset of computing this product by blocks 01:05:04.800 |
because later we will use it also with the softmax. All right guys, I know that what we have 01:05:11.680 |
computed so far is not 01:05:16.240 |
really the attention mechanism, because we have skipped the softmax, so somehow we need to 01:05:21.680 |
restore it, and the following, I think, 10-20 minutes are going to be really 01:05:28.320 |
challenging, because I am going to do a lot of operations that will involve a lot of different 01:05:34.240 |
blocks, a lot of different matrix multiplications and variants of the softmax, so it may be 01:05:40.880 |
difficult to follow. However, don't give up: you can watch this part twice, three times, and every time 01:05:47.600 |
you will have a better understanding. I also recommend watching it until we reach the flash 01:05:53.440 |
attention algorithm before going back to re-watch it, because once 01:06:00.480 |
we reach the flash attention algorithm it will give you a better understanding of 01:06:05.920 |
what has happened so far and then you can re-watch it to deepen your understanding another thing that 01:06:11.760 |
i recommend is take pen and paper and write exactly the operations that you are seeing 01:06:17.440 |
and write the shapes of each of these blocks, of these elements that take part 01:06:24.400 |
in these matrix multiplications, so that you better understand what is happening and you better 01:06:32.480 |
remember what I mean when I refer to a particular element or a particular block. Okay, after giving this small 01:06:39.760 |
motivational speech let's start so what we have done so far was query multiplied by the transpose 01:06:46.880 |
of the keys however each query is not a single row of the query sequence but it's a block of 01:06:54.000 |
queries it's a block of rows in our particular case this q1 is not one row of the query sequence 01:07:02.000 |
it's two rows of the query sequence because we have chosen as a block size a group of two rows 01:07:07.280 |
and this k transposed one is not one column of the k transposed matrix, it is two columns of the 01:07:16.240 |
k transposed matrix because we have chosen it like this and if you don't remember let's go back to 01:07:21.040 |
see it here we have chosen k1 is two columns and q1 is two rows of the query original matrix 01:07:30.000 |
and every time i use the blue color i am referring to the original shape and every time i'm using the 01:07:36.560 |
pink or violet whatever it is i am referring to the block matrix so it's a block of elements of 01:07:43.920 |
the original matrix okay now the first thing that we have done was a query multiplied by the 01:07:50.880 |
transpose of the keys and this produces a block matrix as output that we will call s where each 01:07:58.560 |
element sij so the s11 element of this matrix will be the query one with the k transposed one 01:08:06.720 |
this s12 will be query one with k transposed two, s13 will be query one with k transposed 01:08:14.720 |
three etc etc for all the rows and for all the columns then we should be applying the softmax 01:08:20.400 |
because if you remember the formula is softmax of the query multiplied by the transpose of the keys 01:08:24.800 |
however i want to restore the softmax operation but with a twist which means that we will apply 01:08:31.520 |
the simplified version of the softmax and we will call it softmax star which is just the softmax 01:08:38.160 |
without the normalization so let me write it for you what it means let's do it with the same color 01:08:45.520 |
that I chose for the softmax, which is orange. So the softmax, if you remember, 01:08:52.800 |
is applied to a vector element wise, so each element is modified according to the 01:09:03.760 |
following formula so the ith element of the output vector to which we are applying the softmax 01:09:08.640 |
is equal to the exponential of the ith element of the input vector minus the maximum element 01:09:20.080 |
in the input vector divided by a normalization factor that is calculated according to this 01:09:26.320 |
summation that goes from j equal to 1 up to n of the exponential of xj minus x max. So basically 01:09:37.920 |
we are doing the exponential of each element minus this x max. And why are we subtracting 01:09:44.320 |
this x max? To make this exponential numerically stable and computable, because 01:09:50.480 |
otherwise it will explode and because we are applying it to the numerator we also need to 01:09:54.880 |
apply to the denominator okay the softmax star operation is exactly like the softmax but without 01:10:01.520 |
the normalization part which means that it's just the numerator of the softmax so we will modify 01:10:06.640 |
each element of the vector to which we apply the softmax star according to this formula 01:10:11.680 |
let me move it so it's more aligned, like this. So we just do an element-wise operation that is the 01:10:19.040 |
exponential of each element minus the maximum of the vector to which we are applying softmax star 01:10:24.480 |
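As a quick illustration (added for clarity, not from the slides), here are the safe softmax and the softmax star applied row-wise, the way we will apply them to each block:

```python
import numpy as np

def softmax(x):
    # Safe (numerically stable) softmax, applied row-wise.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def softmax_star(x):
    # Softmax "star": same numerator, but without the normalization by the row sum.
    return np.exp(x - x.max(axis=-1, keepdims=True))

S11 = np.array([[3.0, 1.0],
                [0.5, 2.0]])
print(softmax_star(S11))   # exp(a-a), exp(b-a) on the top row; exp(c-d), exp(d-d) below
print(softmax(S11))        # same values divided by each row's sum
```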
okay now why did i introduce this softmax star operation because we will be applying it to the 01:10:32.240 |
matrix that we have computed so far which is this s matrix so we apply the softmax star to each 01:10:39.200 |
element of this s matrix but each element of this s matrix is itself a matrix because it's a block 01:10:47.040 |
matrix and each element of this s matrix so for example the element s11 is a two by two matrix 01:10:53.600 |
because it is coming from the product of two matrices which are a group of rows and a group 01:10:58.880 |
of columns from the q and the k. So for example, what is this s11? Let's draw it: this 01:11:07.040 |
s11 will be, for example, made up of four elements; let's call them, I don't know, 01:11:15.120 |
a, 01:11:21.120 |
b, c and d, just generic elements. When we apply the softmax star to this s11 01:11:33.520 |
so let's apply the softmax star: it will result in a matrix 01:11:43.360 |
where each element is the exponential of that element minus the maximum of its row. Now, we 01:11:52.480 |
don't know which is the maximum so let's choose one suppose that the maximum for this row is a 01:11:56.880 |
and the maximum for this row is d the first element of the output of this softmax star applied 01:12:03.760 |
to this block s11 will be the exponential of a minus a because that's what we chose as the maximum 01:12:13.440 |
for this row the second element will be the exponential of b minus a because it's the maximum 01:12:19.760 |
for that row then in the bottom row it will be the exponential of c minus d because that's the 01:12:28.160 |
maximum for the bottom row, and this will be the exponential of d minus d. 01:12:34.400 |
that's how the softmax star will modify each block in this block matrix let me delete this stuff 01:12:41.920 |
otherwise it will remain in my slides forever and later i want to share the slides with you guys so 01:12:46.720 |
you can use my same slides, so delete, delete. Okay, after we have applied the softmax star to each of the 01:12:56.080 |
elements in this s matrix we will call it the p matrix and each element p11 will again be a block 01:13:03.920 |
of two by two elements. So p11 will be the softmax star applied to s11, 01:13:15.200 |
where s11 is query one times k transposed one, and p12 will be the softmax star applied to 01:13:21.760 |
s12, where s12 is query one multiplied by k transposed two, etc. for all the elements 01:13:29.760 |
of s okay now that we have applied this softmax star operation the next operation that we should 01:13:36.480 |
be doing according to the formula of the attention is the softmax of the query multiplied by the 01:13:41.600 |
transpose of the keys then the result of the softmax multiplied by v i know that we didn't 01:13:46.880 |
apply the real softmax we apply the softmax star which is softmax without the normalization 01:13:52.560 |
later we will see how to compensate this lack of normalization because we will do it at the end 01:13:58.560 |
and it's something that we can do okay so we take this p matrix which is the result of the softmax 01:14:06.480 |
star applied to this s matrix and we multiply it by v. How do we do it? Well, 01:14:14.640 |
it's a matrix made up of blocks of matrices so p11 is actually not a scalar but it's a matrix 01:14:21.920 |
of two by two elements and we need to multiply it by v but we don't multiply with the original 01:14:27.440 |
sequence v but with the blocked sequence v just like before where each v is not one row of v but 01:14:34.320 |
it's a group of rows of v, and how many rows is it? It is two rows of v. For now please ignore 01:14:42.800 |
completely whatever i have written here because we will use it later so we need to do this product 01:14:47.440 |
of this matrix here which is made up of blocks remember with this matrix here which is made up 01:14:52.880 |
of blocks it is made up of four rows where each row is not really a row it is a block of rows 01:15:00.800 |
and this one it is made up of four by four elements where each element is not really a 01:15:04.720 |
scalar but it's a matrix. As you remember, in block matrix multiplication the algorithm for 01:15:11.760 |
computing the matrix multiplication is the same as the normal matrix multiplication except that 01:15:16.240 |
we use blocks so what i am doing is guys the following operation so let's write it somewhere 01:15:22.960 |
let's say o is equal to p multiplied by v. Okay, so the first output row, a row because it's not 01:15:33.520 |
really a row but it's a block row, will be computed as follows: the first row of this block matrix 01:15:41.200 |
with the first column of this v matrix, and we are treating it like a block matrix, 01:15:50.800 |
so it will be p11 multiplied by v1 plus p12 multiplied by v2 plus p13 multiplied by v3 plus 01:16:02.320 |
p14 multiplied by v4 this will produce the first output row of o but it's not really a row because 01:16:11.920 |
it's made up of two rows, so this stuff here is not one row, it is two rows, and we can prove that 01:16:19.360 |
because what is p11 p11 is let's write it somewhere so p11 is a 2x2 matrix yeah 2x2 and 01:16:30.160 |
we are multiplying it with v1 which is a block of two rows of v so it is a two rows by 128 01:16:38.720 |
dimensions so it is equal to 2 by 128 so this stuff here is 2 by 128 so this block here the 01:16:50.080 |
output block that we're computing is a block of two rows of the output matrix that we are computing 01:16:55.120 |
i know this is really difficult to follow because we are involving blocks so we need 01:17:03.120 |
to visualize at the same time matrix as blocks and as the original matrix that's why i highly 01:17:09.360 |
recommend you to pause the video think it through write down whatever you need to write down because 01:17:16.240 |
it's not easy to follow it just by memorizing the shape so you you actually need to write down 01:17:21.600 |
things. Anyway, we are computing the first output block of the output o matrix. Now, if you 01:17:31.120 |
remember, this output here should be the output of the softmax 01:17:42.560 |
multiplied by v. Now, this softmax has not been applied to the entire row of this matrix here; 01:17:52.560 |
basically, to compute this softmax star what we did was to compute the softmax star 01:18:00.880 |
at each block independently from the other blocks which means that the maximum that we are using to 01:18:07.360 |
compute each softmax star is not the global maximum for the row of this s matrix but the 01:18:15.680 |
local maximum of each block and this is wrong actually because when we compute the softmax 01:18:22.640 |
we apply the softmax, we should be using the global maximum of the row. I want to give you an example 01:18:30.080 |
without using blocks because otherwise i think it's not easy to follow so when we do the normal 01:18:36.320 |
attention so we have a query multiplied by the transpose of the keys this produces a matrix 01:18:43.120 |
that is n by n, so sequence by sequence. Let's say the sequence length 01:18:43.120 |
is six, so this matrix 01:18:50.640 |
is six by six. Okay, this one here should be the dot product of the first 01:18:59.600 |
query with the first key; let me write it as query one transposed times key one. This is because, 01:19:16.880 |
as i said before when we do the product of two vectors we always treat them as column vectors 01:19:23.120 |
so when you want to write the dot product you cannot multiply two column vectors you need to 01:19:27.840 |
multiply one row vector with one column vector that's why we transpose this one if it confuses 01:19:32.400 |
you, you can also write q1 k1, that's totally fine, it's just slightly wrong from a notation point of view. 01:19:38.880 |
anyway the first one will be the dot product of the query one with the k1 the second element will 01:19:45.760 |
be the dot product of the query one with the k2 the third will be the query one with the k3 etc etc 01:19:52.480 |
etc. So this is q1 with k1, q1 with k2, q1 with k3, q1 with k4, and so on. Anyway, 01:20:11.120 |
when we do the softmax we actually calculate the maximum on this entire row however what we are 01:20:18.400 |
doing is we are actually doing a block matrix multiplication and as you remember um when we do 01:20:26.800 |
by blocks we are grouping together rows of queries and rows of keys and in this particular case we 01:20:35.920 |
are grouping two queries together to create one uh one group of queries and two keys together to 01:20:43.920 |
create one block of keys. So we need another row of this one, so let me write it: 01:20:51.200 |
this should be query 2 k1, query 2 k2, query 2 k3, query 2 k4, 01:21:04.000 |
query 2 k5 and query 2 k6. Each of these blocks here is computing 01:21:16.160 |
two by two elements of the original matrix, the one we would get if we had never applied 01:21:23.600 |
the blocks, so it is computing these four elements here, and if we apply the softmax 01:21:32.960 |
star to each of these blocks we are not using the maximum element in this row we are only using the 01:21:39.680 |
maximum element in each block which means that when we will use it in the downstream product with 01:21:46.320 |
the v matrix, we will be summing values that are wrong, because each of these values here will be based 01:21:54.240 |
on a maximum that is not the global maximum for this row it is the local maximum of this block 01:22:01.120 |
here, and this block here will use the local maximum of 01:22:08.320 |
this block here and this block here will use the local maximum of this block here etc etc etc so 01:22:15.520 |
what I'm trying to say is that when you sum p11 v1 with the other terms, p11 may have a local maximum 01:22:24.480 |
that is different from the local maximum of p12, and p13 may have a different 01:22:30.800 |
local maximum than that of p11 and p12, so we need to find a way to fix the maximum that was used 01:22:40.720 |
to compute the exponential here with the maximum found here in case the maximum here is higher than 01:22:49.600 |
the one local to p11 so if we have found for example here a maximum that is higher than the 01:22:55.760 |
maximum used here, then we need to fix this one and this one, because the maximum in the 01:23:00.960 |
softmax should be the maximum for the whole row, not the one belonging to each block. And this leads 01:23:07.440 |
to our next step how to fix this first of all let me introduce a little pseudo code for computing 01:23:15.040 |
this output matrix here which is an output block matrix and later we will use this pseudo code to 01:23:23.120 |
adjust the error that we have made in some blocks, in case a future block, say p13, has a better 01:23:30.640 |
maximum than p11 or p12 so to compute this output matrix o we go through so for example to compute 01:23:39.280 |
the first row we choose, well, what is p11? Let's go back; let me delete also this one, 01:23:46.880 |
it's not needed anymore. p11 is the softmax star of q1 k1, p12 is the softmax star of q1 k2, p13 is 01:23:58.960 |
the softmax star of q1 k3 p14 is the softmax star of q1 k4 which means that to compute this block 01:24:10.960 |
here, we first need to compute p11. What is p11? Well, p11 is the softmax star of a block of q 01:24:19.520 |
and another block of k which in the case of the first row of the output matrix means that 01:24:26.320 |
we need the softmax star of the query 1 with k1, the softmax star of the query 1 01:24:34.160 |
with k2 the softmax star of the query 1 with k3 the softmax star of the query 1 with k4 which 01:24:39.920 |
means that we need to go we need to make a for loop through all the keys while keeping the query 01:24:46.400 |
fixed. So to compute the first output row: to produce p11 we need to 01:24:53.680 |
do the softmax star of query 1 k1, multiply it by v1, and sum it to the output, initially zeros, because we need to 01:25:02.640 |
initialize our output somehow and we initialize it with zeros; then we sum the next term, p12 v2, where p12 is 01:25:10.320 |
the softmax star of query 1 with k2, and then we sum the next term with p13, which is query 1 with k3, etc. That's 01:25:16.960 |
why we have this inner loop here all right so however this output that we are computing is 01:25:23.280 |
wrong, because, as I told you, we have computed the softmax star using statistics, the maximum value, 01:25:30.000 |
belonging to each block, and not the one of the overall row of the original matrix. 01:25:35.360 |
How to fix that? We have a tool: we have actually seen before an algorithm called the 01:25:41.520 |
online softmax, I don't know if I referred to it before as the online softmax but it's called the 01:25:45.680 |
online softmax, that allows us to fix previous iterations while we are computing the current 01:25:52.400 |
iteration. How? Well, let's review the online softmax. Imagine we are working with one 01:26:00.560 |
single vector, so a vector made up of n elements. What we do is a for loop where 01:26:07.600 |
we compute iteratively the maximum up to the i-th element, and we fix the normalization factor 01:26:15.920 |
computed in the previous iteration in case we found a better maximum at the current element. 01:26:22.480 |
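For reference, here is a minimal Python sketch of that online softmax for a single vector (an illustration of the idea, not the exact pseudocode from the earlier section):

```python
import numpy as np

def online_softmax(x):
    # One pass: keep a running max m and a running normalizer l,
    # correcting l whenever a better maximum is found.
    m, l = -np.inf, 0.0
    for xi in x:
        m_new = max(m, xi)
        l = l * np.exp(m - m_new) + np.exp(xi - m_new)  # fix the old sum, add the new term
        m = m_new
    return np.exp(x - m) / l

x = np.random.randn(10)
safe = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), safe)
```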
if this is not clear, guys, go back and watch the online softmax part, because this is very important, 01:26:27.360 |
because this is what we are going to use to fix these p11, p12 blocks in case we find a better 01:26:33.280 |
maximum in p13 or p14 etc so let's see how to apply this online softmax to this case here 01:26:42.240 |
You may be wondering why we are going through all this trouble. I mean, 01:26:49.360 |
the real reason is, first of all, why did we introduce block matrix multiplication? 01:26:55.920 |
Because we want to compute the matrix multiplication in parallel. You can think that each of these p11, p12 blocks, 01:27:02.000 |
because they are independent from each other and because each of them are using the maximum 01:27:06.720 |
belonging to each block they can be computed independently from each other then however we 01:27:11.760 |
need to somehow aggregate their value and to aggregate the value we need to fix the values 01:27:17.920 |
that have been calculated independently, because when computing values independently we 01:27:22.480 |
don't have a global view we have a local view so we compute local blocks p11 p12 p13 etc etc 01:27:29.440 |
and then when we aggregate these values we need to fix them so that's why we are trying to come 01:27:36.240 |
up with this system of fixing values that have been calculated independently so how to fix this 01:27:44.000 |
let's look at the following algorithm first of all this o block here as i said before it is a block 01:27:52.400 |
of two rows where each row is made up of 128 dimensions and we have seen that before by 01:27:59.600 |
checking the dimensions of p11 and v1 the result of p11 v1 which means that for each output block 01:28:06.880 |
we need to take care of two maximums and two normalization factors so up to now i didn't use 01:28:13.840 |
the normalization factor we said that we are applying the softmax star which is the softmax 01:28:17.600 |
without the normalization but eventually we will need to compute this normalization so we want to 01:28:23.760 |
create an algorithm that fixes the maximum used to compute each of this p11 and also computes 01:28:30.640 |
simultaneously the normalization factor and at the end we will apply this normalization factor 01:28:36.240 |
and the way we will do it is as follows we start with initializing the maximum to minus infinity 01:28:43.040 |
one for each row that we are computing so our output block is made up of two rows so we need 01:28:47.840 |
one maximum for the top row and one maximum for the bottom row and also the normalization factor 01:28:53.600 |
which we initialize with zero because we didn't sum anything for now and the output we initialize 01:28:58.880 |
it with all zeros because we didn't sum anything to this output for now. To 01:29:07.040 |
compute the output row, so this output block here, we need to go through 01:29:14.320 |
all the keys to produce p11, p12, p13, p14, while the query is the query number one, the query 01:29:21.840 |
block number one. So the first step that we do is we compute the maximum of the first block, 01:29:30.880 |
which is the row max, so the maximum for each row, of the block q1 k1. This is not p11, it's s11, sorry 01:29:42.400 |
guys, or actually s1 as you can see here. So we compute the maximum of this block, 01:29:50.800 |
and then we can calculate p11, which is the softmax star, which is the exponential of 01:30:00.080 |
query one k1, so s1, minus the maximum of the local group s1, and we add it to our output. 01:30:11.920 |
For now the output is initialized with zeros, so for now ignore this part here, I will explain it later; 01:30:19.440 |
so for now o1 should be equal only to p11 v1. Now at step number two we may find, in the 01:30:29.760 |
local group s12 (so this one is s12), a better maximum for the top row and the bottom row, 01:30:39.040 |
and this maximum is the m2, which may be better than the previous maximum for each of these two 01:30:46.320 |
rows, but may also not be, so we need to find a way to fix in case it's better and to not fix anything 01:30:52.720 |
in case it's not better, and the way we do it is this: we compute the new maximum of the current 01:30:59.200 |
local block, then we calculate p12, which is the softmax star of s2, so the exponential of s2 minus m2, which 01:31:11.200 |
is the local maximum and then we need to add it to the output however in this case we may have found 01:31:18.800 |
a better maximum so how to fix the o1 which only used the maximum that was local to s1 well we know 01:31:28.400 |
that we can fix that by using exponentials because each of this element of o1 is just an exponential 01:31:35.920 |
without the normalization because we are applying softmax star so how to fix an exponential with 01:31:41.920 |
another exponential so basically we are saying that we multiply o1 which is a matrix so let me 01:31:49.920 |
show you what is this matrix so o1 is a matrix made up of two rows so as you can see here i 01:31:58.160 |
have the shape of o1 it's a 2x128 matrix so this is the top row so o11 o12 blah blah until o1 128 01:32:11.680 |
then o21, o22, blah blah, and o2,128. How do we fix these values? Basically just using the 01:32:25.120 |
exponential that we have used in the online softmax that we have seen before so if we multiply this 01:32:31.120 |
matrix here by a diagonal matrix that is made as follows it's a diagonal matrix made up of two 01:32:39.280 |
elements, because m1 minus m2 will be a vector of two elements and its 01:32:46.160 |
element-wise exponential is another vector of two elements. And this diag basically means a 01:32:53.280 |
diagonal matrix where in the diagonal we have the elements of the vector to which we are applying 01:32:58.800 |
this diag operation which means that this value here will be the exponential of the first element 01:33:06.000 |
of m1 minus m2; so let me show you how to write it: exponential of m1 minus m2, the 01:33:21.840 |
first element, let's call it one. Here is a zero, here will be a zero, and let's delete this one 01:33:28.720 |
and we write another one here: exponential of m1 minus m2, but the second element of this vector. 01:33:38.320 |
So basically this notation here, diag, means: just take the vector and distribute 01:33:44.560 |
it over the diagonal of an n by n matrix, where n is the size of the vector to which it is applied, and all the other 01:33:50.720 |
elements of this matrix should be zeros this is what this diag means if we do this operation here 01:33:57.040 |
we will see that the output of this multiplication will fix each element of the top row using this 01:34:06.400 |
exponential and the bottom row with this exponential which will basically cancel out 01:34:12.880 |
this m1 that was computed in the previous iteration and introduce the m2 that we have computed in the 01:34:18.720 |
current iteration in each of these elements in this o block matrix okay so this output will be 01:34:32.240 |
this element multiplied by this one, so it will fix o11 with this factor here, and o21 will 01:34:41.520 |
be multiplied by zero, so it will not contribute to this first output element, 01:34:47.920 |
so this element here will only depend on o11 fixed by the exponential of m1 minus m2 but the 01:34:56.400 |
first element of this vector, and then o12 will also be fixed by this 01:35:04.240 |
exponential here but not by this one and all the dimensions of the first row will be fixed by this 01:35:08.800 |
exponential and all the dimensions of the second row here will be fixed by this exponential here 01:35:15.920 |
this scalar here, which is the second element of the vector exp of m1 minus m2. 01:35:22.240 |
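Here is a tiny NumPy check of this correction trick (an illustration with made-up numbers): left-multiplying by diag(exp(m_old minus m_new)) rescales each row of O, replacing the old maximum with the new one inside every exponential:

```python
import numpy as np

m_old = np.array([3.0, 2.0])          # per-row maxima used so far (assumed values)
m_new = np.array([5.0, 2.5])          # better maxima found in the current block
O1 = np.exp(np.random.randn(2, 4) - m_old[:, None])   # stand-in for the running output

# Left-multiplying by the diagonal correction rescales row i by exp(m_old[i] - m_new[i]).
fixed = np.diag(np.exp(m_old - m_new)) @ O1
assert np.allclose(fixed, O1 * np.exp(m_old - m_new)[:, None])
```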
okay, this one was really challenging. So what we are doing is we compute p12 and we fix 01:35:32.400 |
all the elements of o1 by multiplying by this diagonal 01:35:39.520 |
matrix factor here, and when we compute step 3 we will fix step 2, etc. Now let's talk 01:35:48.240 |
about the normalization factor because for now we have been ignoring it the normalization factor 01:35:54.160 |
is something that we can compute while computing these maximums because it is provided in the 01:36:00.560 |
pseudocode of the online algorithm that we have seen before for the softmax so while computing 01:36:05.680 |
the maximum we can actually compute the normalization factor by fixing the normalization 01:36:11.040 |
factor of the previous iteration and this is exactly what we are doing here so at the first 01:36:16.000 |
iteration we compute the normalization factor using the local maximum, and at the second iteration, 01:36:22.000 |
well, for now you can ignore this term, because we are not fixing l0 with anything, because l0 will be 0, 01:36:27.840 |
so we are basically just computing this summation here; l0 will be 0, so this 01:36:35.520 |
factor here will be 0. And when computing l2, so the normalization factor at the second iteration, 01:36:43.120 |
we will fix l1 with an exponential which guess what it's exactly the same exponential that fixes 01:36:50.800 |
the output o1: it is the exponential of the previous estimate of the maximum minus the current 01:36:57.920 |
estimate of the maximum, plus the new normalization term using the local maximum, 01:37:04.320 |
and we keep doing this. At the end we will obtain a correct output 01:37:13.760 |
for this block here but without the normalization how to apply the normalization well the 01:37:20.400 |
normalization is something that is we need to divide each element of this o by the normalization 01:37:27.200 |
factor. But because, while iterating through this for loop, we also calculate the 01:37:34.400 |
normalization factor we keep accumulating it until we reach the end of the iteration and then 01:37:40.240 |
we apply the normalization factor so we take the last output and we just divide it by l4 which is 01:37:46.240 |
the normalization factor calculated at the fourth iteration, and that will fix the softmax. 01:37:53.760 |
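Putting the whole derivation together, here is a compact Python sketch of the algorithm just described for one block of queries (an illustration with assumed block sizes, not the code we will write in Triton later): iterate over the key/value blocks, keep a running row maximum m and a running normalizer l, rescale the accumulated output with the diag-of-exponentials trick, and divide by l only at the very end.

```python
import numpy as np

def attention_block(Qi, K_blocks, V_blocks):
    # Qi: (Br, d) block of queries; K_blocks/V_blocks: lists of (Bc, d) blocks.
    Br, d = Qi.shape
    m = np.full(Br, -np.inf)          # running row maxima
    l = np.zeros(Br)                  # running normalizers
    O = np.zeros((Br, d))             # running (unnormalized) output
    for Kj, Vj in zip(K_blocks, V_blocks):
        S = Qi @ Kj.T                             # (Br, Bc) scores for this block
        m_new = np.maximum(m, S.max(axis=1))      # better row maxima, if any
        P = np.exp(S - m_new[:, None])            # softmax* with the new local maxima
        alpha = np.exp(m - m_new)                 # correction for previous steps
        l = alpha * l + P.sum(axis=1)             # fix old normalizer, add new terms
        O = np.diag(alpha) @ O + P @ Vj           # fix old output, add new contribution
        m = m_new
    return O / l[:, None]                         # apply the normalization at the end

# Check against the exact attention (no 1/sqrt(d) scaling, to match the derivation).
N, d, B = 8, 128, 2
Q, K, V = (np.random.randn(N, d) for _ in range(3))
Kb = [K[j:j + B] for j in range(0, N, B)]
Vb = [V[j:j + B] for j in range(0, N, B)]
O = np.vstack([attention_block(Q[i:i + B], Kb, Vb) for i in range(0, N, B)])
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(O, ref)
```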
All right guys, so now that we have derived the algorithm for how to compute this output of the attention 01:38:00.000 |
blockwise while also fixing the softmax which is done independently in each single block 01:38:05.600 |
we know that the normalization is done at the end, and I want to also prove it. So what we did, when 01:38:12.560 |
we introduced this algorithm that computes the softmax in an online way we proved by induction 01:38:19.120 |
that this algorithm is correct so at the end of this algorithm this l of the last iteration will 01:38:25.760 |
actually be the normalization factor that we can apply to get the softmax so we don't apply the 01:38:33.360 |
normalization while computing this output in an online, iterative way by multiplying the 01:38:39.600 |
query with all the blocks of keys; we apply it at the end of this for loop, and at the end of 01:38:47.840 |
this for loop we will have the last output, and we also know that the last l will contain the 01:38:55.280 |
exact normalization factor that we need to apply to each row because this o of four is a block 01:39:02.160 |
of output rows, and, if you remember from the attention mechanism, the output 01:39:09.920 |
of the attention has the same shape as the input query vector which is a sequence of tokens so this 01:39:17.280 |
o is a sequence of tokens that we need to apply the normalization to and we know that the correct 01:39:23.600 |
factor is l4. So let's prove this simple formula: l4 is a vector that contains as many 01:39:31.200 |
elements as there are rows in o4 so in this o block of rows suppose that it contains two rows 01:39:42.080 |
like in the algorithm that i have described so far in which we pretend that we are grouping 01:39:46.480 |
two rows of queries with two columns of keys together so the output o the block o will contain 01:39:54.960 |
two rows of the output so we will have two normalization factor in this l4 vector here 01:40:01.600 |
what we are doing with this formula is we are taking this l4 vector and we are creating a 01:40:08.480 |
diagonal matrix with it and then we are computing the inverse of this diagonal matrix so l4 is a 01:40:16.240 |
vector that contains two normalization factors, so let's call them l4 element 1 01:40:24.720 |
and l4 element 2; this is our l4 vector. Then we have o4, which is a matrix, as you can see from the 01:40:38.160 |
shape, a 2 by 128 matrix. So o4 is a matrix that 01:40:48.720 |
is two rows with 128 elements so the first row with 128 dimensions and the second row with 128 01:40:58.400 |
dimensions the first thing that we are doing with this l4 is we are converting it into a diagonal 01:41:04.080 |
matrix which will be a diagonal matrix 2 by 2 because it contains two elements so it will become 01:41:10.240 |
something like this so it will be l4 the first element of l4 0 and then 0 l4 the second element 01:41:19.520 |
of this vector then we are computing the inverse of this matrix the inverse of a diagonal matrix 01:41:26.640 |
is just the diagonal matrix with each element on the diagonal that becomes its reciprocal 01:41:33.760 |
this is from linear algebra, I'm not making this up, so 01:41:38.320 |
the inverse of this matrix here is equal to the same diagonal matrix but where each element is 01:41:48.720 |
1 over l4 the first element of l4 0 0 and 1 over l4 the second element of l4 and then we are 01:42:00.000 |
multiplying this stuff here so let me delete some stuff so this stuff here is getting multiplied 01:42:06.880 |
by o which is a matrix that is a 2 by 128 so we are doing this multiplication now 01:42:15.440 |
The output of this, so this is, let me write it, 2 by 2 multiplied by 2 by 128, 01:42:27.920 |
will be a matrix that is 2 by 128, where the first element of the first row of the output of this 01:42:40.640 |
operation will be the dot product of this row here with the first column, so basically 01:42:49.600 |
we are dividing this element here by l4, the first element of l4. The second output element here will 01:42:56.240 |
be the dot product of this row with the second column, so we are dividing 01:43:02.160 |
the second element here of this input row by l4, the first element of l4, because 01:43:09.680 |
all the elements of the second row will be multiplied by 0 so they will not contribute 01:43:13.440 |
to this output row. While for the second output row, this element here will be the 01:43:18.480 |
dot product of this row with the first column; the first element here is multiplied by 0 so it will 01:43:25.280 |
not contribute to this output, so it's only the second element: the 01:43:30.800 |
first element of the second row of the input matrix here will be divided by l4 element 2. So basically 01:43:37.040 |
this will divide all the elements in the second row and this will divide all the 01:43:42.720 |
elements in the first row, producing this one here, which is exactly what we need to do when 01:43:47.040 |
we want to normalize we need to apply this normalization factor and this should help you 01:43:52.080 |
better visualize why this operation is normalizing the vectors of the output at the end and still 01:43:57.920 |
obtaining the same result. Now let's proceed further. All right guys, finally we are ready to 01:44:04.400 |
see the flash attention forward pass by also comparing it with what we have derived so far 01:44:11.120 |
so if you look at the flash attention paper first of all this is the flash attention 2 forward pass 01:44:16.720 |
and later i will explain what are the differences between the flash attention 1 and the flash 01:44:20.240 |
attention 2 i didn't want to jump directly to this forward pass because i believe that even if the 01:44:27.440 |
derivation was a little difficult to follow, I believe that it gave you 01:44:33.280 |
some intuition into what is happening, so even if you understand 50 percent of it that's enough, 01:44:37.760 |
because later we will also code it and you should reach like 90 percent understanding, so every 01:44:43.200 |
time we introduce some new information it should improve your understanding. So basically in 01:44:49.600 |
flash attention, flash attention 2 especially, as input we have our 01:44:56.400 |
query key and values which are a sequence of tokens each token is made up of a vector 01:45:01.760 |
of lowercase d dimensions, and we divide this query, guess what, into blocks. 01:45:09.920 |
Into how many blocks? Well, depending on this parameter br, which is the size of the 01:45:16.800 |
query block that we want to choose, so how many rows of query we want to group together into one 01:45:22.400 |
block. And we also do it with the k and v, and we divide them into blocks depending on this 01:45:32.080 |
parameter bc then we also initialize the output which is the output that we want to produce so 01:45:38.480 |
what is the flash attention computing well the flash attention is computing the following so 01:45:43.760 |
it's computing the softmax of the query multiplied by the transpose of the keys, divided 01:45:51.280 |
by some normalization factor, and multiply that by v, so that's what it's going to compute, 01:45:59.920 |
and it's going to compute it this way first of all there is an outer loop through the queries 01:46:05.520 |
which corresponds to the same pseudo code that we have seen before because we want to compute 01:46:11.520 |
each block of the output matrix in parallel with respect to the others. So basically we 01:46:19.120 |
want to compute this output block and this output block independently. This output block here 01:46:26.480 |
depends on the query one and all the keys, this output block here depends on the query two and 01:46:33.040 |
all the keys, this output block here depends on the query three and all the keys, where query one 01:46:38.560 |
is not the first query but it's the first group of queries or first block of queries, query two is 01:46:44.320 |
not the second row of the query matrix but it's the second block of the 01:46:49.360 |
query matrix etc etc and so that's why we have this outer iteration among all the blocks because 01:46:59.760 |
we want to compute all those blocks of the output matrix in parallel but to compute each of this 01:47:05.360 |
output block we need to go to an iteration among all the keys that's why we have an inner loop on 01:47:11.520 |
the keys and we do exactly the same operation that we have done so far by hand so first we compute 01:47:18.560 |
the s matrix, which is each block of queries multiplied by the corresponding block of the keys, 01:47:24.080 |
then we compute the maximum local to the current s block, this is the local maximum, and we 01:47:32.160 |
compare it with the maximum of the previous iteration because that's what we do in the 01:47:36.400 |
online softmax then we compute the p block which is the softmax star of the s block 01:47:45.200 |
minus the local maximum of the s block then we compute the normalization factor what is the 01:47:53.440 |
normalization factor it is the summation of all the exponential of the softmax star but 01:48:01.200 |
by fixing the normalization factor of the previous step and we know how to fix the 01:48:07.760 |
normalization factor because we just multiply by an exponential which is the previous maximum 01:48:13.680 |
minus the current maximum that's what this factor is and then we compute the output exactly using 01:48:19.520 |
the same correction factor that we have seen before which is the diagonal matrix made up of 01:48:24.720 |
the diagonal where on the diagonal you have the elements of this vector here which is the 01:48:30.800 |
exponential of the previous maximum minus the current maximum multiplied by the output of the 01:48:36.320 |
previous step because we want to fix the previous step because it was based on the previous p which 01:48:40.640 |
was using the local maximum of the previous p, plus the current p times v, which is based on the current 01:48:47.920 |
local maximum and it will be fixed by the next iteration okay and at the end after we have 01:48:55.520 |
gone through all the keys, we have computed the whole output block, but we didn't apply the 01:49:02.000 |
normalization factor and it's applied at the end because while going through each key we are 01:49:07.200 |
calculating the l normalization factor for the softmax because inside of this for loop we are 01:49:13.360 |
just computing the softmax star so we are not normalizing each value so at the end someone has 01:49:18.720 |
to normalize it and it will be this instruction here which is use the normalization factor that 01:49:28.640 |
we have computed over all the iterations and apply it to each element of O because the difference 01:49:34.240 |
between the softmax star and the actual softmax is just the division by the normalization factor 01:49:42.720 |
and this instruction here is actually dividing each O with the corresponding normalization 01:49:48.080 |
factor one for each row of the block each row in the output block that we are computing 01:49:54.800 |
later we will also see what this SRAM is and what the HBM is; 01:50:03.040 |
for now i just want you to concentrate on the operations that we are doing and they are exactly 01:50:08.320 |
the same operations that we have done so far. Later we will also see why we need to save this stuff 01:50:14.160 |
here and etc., but for now you should have enough knowledge to be able to follow what is 01:50:20.880 |
written in the flash attention paper with respect to the forward pass algorithm, and what 01:50:27.360 |
we are doing basically is just block matrix multiplication and while computing this block 01:50:32.160 |
we fix the previous blocks by using tricks of the exponential. All right, now that we have seen the 01:50:39.920 |
forward pass of flash attention, before we can implement it we still lack a little bit of 01:50:44.560 |
knowledge because we don't know anything about the GPUs and we don't know anything about CUDA 01:50:49.360 |
and we don't know anything about Triton so that's what we are going to see next all right guys it's 01:50:55.280 |
time for us to explore finally the GPU and the CUDA programming model well let's start by comparing 01:51:02.240 |
the CPU and the GPU and this will let us understand how CUDA works then so first of all what is the 01:51:07.440 |
CUDA and what is the GPU? The GPU is the hardware unit that we buy, and CUDA is a software 01:51:13.600 |
stack made by NVIDIA to write software for the GPUs that they sell. AMD has its own software 01:51:21.760 |
stack and other manufacturers have their own. In this particular video we will be seeing examples 01:51:27.280 |
of CUDA kernels but the knowledge that you will get can apply also to other GPUs now the first 01:51:33.680 |
difference between a CPU and a GPU is its purpose. Your computer is right now running 01:51:40.240 |
on a CPU and your operating system is interfacing with the CPU using the so-called scheduler, 01:51:48.080 |
so right now probably you are running a browser, you are also running some other software on your 01:51:52.240 |
computer, and the scheduler is tasked with switching between them very fast on 01:51:58.720 |
your CPU in such a way that it looks like to you that the processes are running concurrently 01:52:03.200 |
this actually is a fake kind of parallelism unless your CPU also has multiple cores 01:52:10.560 |
which nowadays CPUs do have so a CPU usually has one or multiple cores but not so many of them so 01:52:17.120 |
usually have a dual core or quad core or eight core CPU and each of these cores can execute 01:52:23.680 |
instructions in parallel. The main purpose of the CPU is to execute 01:52:32.320 |
many different tasks, switching between them very fast. So maybe you have a browser that is 01:52:38.240 |
running, a small game, and then you have a movie player, but then you have a word processor, 01:52:43.440 |
and then maybe you have some utility to download files, etc. So most of these 01:52:50.640 |
programs are not compute intensive, they are actually I/O bound, meaning that most of the time 01:52:55.120 |
they are either waiting for the network or they are waiting for the disk, and they are very different 01:53:00.640 |
from each other in their purpose: the browser is completely different from a movie player 01:53:04.640 |
and it's completely different from a word processor. So the job of the CPU is to actually 01:53:10.720 |
reduce the latency of processing all these operations, and it's highly optimized 01:53:17.040 |
in each of these execution units called cores, which means that each core has a part 01:53:24.000 |
that is tasked to understand first of all what is the next instruction to run or to predict the 01:53:31.200 |
branch, so what the next operation may be based on the conditions that you are evaluating; for 01:53:37.920 |
example, if you have an if condition, the branch predictor can understand what is the most 01:53:42.560 |
likely next instruction and can do some optimizations. Also, the CPU has a lot of 01:53:48.560 |
caches to reduce the latency in loading data from all the devices it can interface with: it can 01:53:54.000 |
interface with the RAM for sure, but it can also interface with the disk, it can also interface 01:53:59.680 |
with some peripherals like the printer like the mouse like the keyboard etc etc on the other hand 01:54:05.680 |
the GPU is not tasked to do many different things at the same time but it's tasked to do 01:54:10.960 |
one thing or a few things but on a massive amount of data so the operations that we do on the GPU 01:54:18.160 |
require a lot of computation, and for this reason most of the area, so the physical area, 01:54:26.000 |
of the GPU is dedicated to compute units so this green stuff that you can see here and these are 01:54:32.080 |
called cores and you can see that the part that is dedicated to the control area so the part that is 01:54:38.720 |
tasked with understanding what is the next instruction to run or to do some optimization 01:54:44.000 |
of the program, is very little. You may be thinking, well, does that make the 01:54:50.800 |
GPU slower compared to the CPU? Well, not really, because we have many more cores that 01:54:56.720 |
can compensate for these higher latencies. Okay, I could give you a lot of knowledge about the GPU 01:55:04.800 |
from a theoretical point of view, but I think the best way to understand the CUDA programming model is 01:55:09.040 |
just to jump into the code so we don't get bored okay imagine we have a very simple task and we 01:55:15.280 |
have two vectors a and b and we want to calculate the sum of these two vectors 01:55:22.640 |
and save the result into another vector c, where each item is the element wise sum of the 01:55:29.120 |
corresponding item of a and b how would you proceed with this task on the CPU well you would 01:55:36.560 |
do a for loop, for example: you would make a for loop that starts from the first index, 01:55:43.120 |
so the index is zero, and c of zero is equal to a of zero plus b of zero, then c of one is equal to 01:55:50.560 |
a of one plus b of one, etc., and you loop over all the elements of this vector. On the GPU we 01:55:59.600 |
want to do the same operation but in parallel because we have a lot of compute units called 01:56:04.480 |
cores and we want all of them to work in parallel so the first thing that we need to understand is 01:56:09.200 |
how to divide the work that we are going to do into sub units of work and dedicate each core 01:56:16.640 |
to one subunit one simple subdivision would be okay the first core should do this summation 01:56:22.640 |
the second core should do this summation the third core should do the summation etc etc so 01:56:28.720 |
imagine we have an eight-element vector: we need eight cores to do this element-wise summation. 01:56:36.640 |
We will call these cores threads, because it should also remind you of the multi-threading 01:56:42.000 |
that we already use in operating systems, so multiple threads work concurrently 01:56:46.800 |
on the same or similar job in the GPU let's look at the code now the code that i am going to show 01:56:53.520 |
you is a CUDA kernel and it's written in c but you don't have to understand c and you don't have to 01:56:59.040 |
understand this code what i want you to understand is the intuition behind it because later we will 01:57:03.760 |
need this knowledge and convert it into triton which is python and you should already be familiar 01:57:08.000 |
with python so let's go to the code and i have a very simple vector addition we can see it here 01:57:16.240 |
okay first of all how to do a vector summation usually the gpu is interfaced with a cpu 01:57:26.000 |
and the cpu has to first of all tell the gpu what is the data it is going to work with so the cpu 01:57:33.600 |
needs to have these vectors it needs to transfer them to the gpu then the gpu needs to do this 01:57:39.920 |
vector summation then the cpu has to copy back the information from the output from the gpu to 01:57:45.920 |
the cpu and then make it available to the program this is what we are going to do here so we are 01:57:52.080 |
going to allocate three vectors of size n, one called a, one called b, and one output vector. 01:58:00.560 |
We initialize their items randomly, so a of i is a random number between 0 and 100 excluded. 01:58:09.440 |
Then we allocate memory on the gpu to hold these vectors and then we copy them to the gpu: we 01:58:16.800 |
copy the a vector to the gpu and the b vector to the gpu. Of course we don't copy the result, because 01:58:21.760 |
that's what we want the gpu to populate with the output, so we just allocate it on the gpu 01:58:28.000 |
without copying its (random) initial values. Then what we do 01:58:36.400 |
is we launch the kernel the launching the kernel means that we launch a program that the gpu should 01:58:41.680 |
execute in parallel on multiple threads or multiple cores each of these threads should do a unit of 01:58:48.000 |
operation a unit of work that is independent from the others actually they can be dependent on the 01:58:52.800 |
others and but we will not be talking about synchronization so we launch this kernel and 01:58:58.720 |
what we are saying in this line is launch one block of threads and later we will see what are 01:59:04.960 |
blocks but you can think of you can ignore this one for now what we are saying here is launch n 01:59:10.320 |
threads so n parallel operations on with the following arguments so the output where we want 01:59:18.160 |
to save data the input array a and the b input b and the number of elements let's see what happens 01:59:25.760 |
inside of this method. This method follows a particular syntax that is CUDA 01:59:33.360 |
specific: this __global__ keyword is an addition, because CUDA C is like a superset of the C language where we have 01:59:39.120 |
some additional keywords that belong to CUDA, so it's not really C, it's CUDA C. It's a very 01:59:48.080 |
simple method, as you can see. The first thing to understand is that CUDA cannot know what each 01:59:55.760 |
thread should do: we have to tell each thread what to do, so the mapping between the data and 02:00:02.400 |
what each thread should do is up to us as software engineers. What CUDA will do, when we 02:00:08.880 |
ask it to launch n threads in parallel, is allocate n threads and assign a unique identifier 02:00:15.600 |
to each of these threads in our simple case we can see it like this so it will assign the first 02:00:23.120 |
thread the index zero. So, for example, imagine we have a vector of eight elements: it will 02:00:29.760 |
assign the first thread index zero (here in the drawing I called it one, but that's wrong, so let me fix the 02:00:36.480 |
numbering): this will actually be thread zero, this will be thread one, this will be thread two, 02:00:41.120 |
thread three thread four thread five thread six and thread seven so let me delete this one 02:00:51.680 |
and what we are saying is that the item that each thread should process is equal to its thread 02:01:08.160 |
index so this is the thread zero so it should process the item with index zero this is the 02:01:14.240 |
thread one and it should process the item with index one this is the thread number two and it 02:01:19.040 |
should process the item with index two and this is what we are doing in this line of code we are 02:01:24.480 |
saying which item each thread should process which is exactly the thread identifier so the thread id 02:01:33.120 |
later we will see why why do we have this dot x but that's for later next thing that you should 02:01:40.080 |
see is what we are doing: the output at the i-th position is equal to the a vector at the i-th 02:01:48.240 |
position plus the b vector at the i-th position, so it's a very simple element-wise summation. 02:01:52.480 |
You may have noticed this if statement. Why do we need an if statement if we already know that we 02:01:58.640 |
are going to launch eight threads? We already know that we are 02:02:07.040 |
going to launch n threads, so i should of course be less than n, because each thread id will be 02:02:12.160 |
between zero and n minus one. So why do we need this if condition? This is needed because when 02:02:18.640 |
CUDA launches a number of threads, this number of threads is always a multiple of a 02:02:26.720 |
unit, which is 32 in the case of CUDA. So if we have, say, 34 elements in a vector 02:02:34.080 |
and we ask CUDA to launch 34 threads, CUDA will not launch exactly 34, it will launch 64 threads, 02:02:40.960 |
so a multiple of 32, which is the warp size by the way. What we need to do is 02:02:49.840 |
ask only the threads that have a corresponding element 02:02:58.240 |
to do the work, and all the others that don't have a corresponding element, because the vector 02:03:03.120 |
is not large enough for all of them, should do nothing, so they should not enter this if statement. 02:03:08.400 |
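To make the intuition concrete, here is a rough Python sketch of what each thread of this kernel does (illustrative only: the function and variable names are my own, not CUDA syntax):

    def vector_add_thread(thread_id, a, b, c, n):
        i = thread_id          # each thread handles the element whose index equals its thread id
        if i < n:              # guard: the extra threads (rounded up to a multiple of 32) do nothing
            c[i] = a[i] + b[i]

    n = 8
    a = [float(x) for x in range(n)]
    b = [float(2 * x) for x in range(n)]
    c = [0.0] * n
    num_threads = 32 * ((n + 31) // 32)   # CUDA would round the thread count up to the warp size
    for tid in range(num_threads):        # on the GPU these "threads" run in parallel, not in a loop
        vector_add_thread(tid, a, b, c, n)
    print(c)                              # [0.0, 3.0, 6.0, ...]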
There is another thing that we should learn about the threads: 02:03:16.320 |
in the CUDA programming model, but I believe 02:03:21.360 |
also in other GPUs, a group of 32 threads is called a warp, and these 32 threads 02:03:30.480 |
will share the same control unit. So let's go back to the slide: as you can see here, we 02:03:39.600 |
have this yellow unit here in the GPU, and a group of threads will share the same control unit. 02:03:45.280 |
What is this control unit? It's a part of the hardware of the GPU that is tasked with 02:03:50.960 |
understanding what the next instruction to run is. Now, if the group of threads is sharing the same 02:03:57.120 |
unit, it means that this group of threads will always execute the same statement at any time: 02:04:03.520 |
they will always work in synchrony, always on the same instruction. It cannot 02:04:10.000 |
be that this thread is working on one instruction and that one is working on another instruction. 02:04:14.480 |
What does this mean on a programming level? It means that, when we launch a group of threads, 02:04:21.200 |
of course CUDA will spawn more threads than we need if the number of elements of our 02:04:26.720 |
vector is not a multiple of 32. This means that when we launch these threads they will first execute 02:04:33.520 |
this operation, and each of them will have its own value of this thread id, so they will execute the 02:04:40.640 |
same instruction but the data at each instruction may be different, because each of them has its 02:04:45.840 |
own registers, which means that they will, for example, all reach this statement here 02:04:52.560 |
and the first thread will have i equal to zero the second thread will have i equal to one etc 02:04:57.200 |
etc even if they are executing the same instruction this programming model is called the single 02:05:01.760 |
instruction multiple data CUDA likes to call it a single instruction multiple thread doesn't matter 02:05:07.120 |
for us it just means that they will always execute the same instruction but the value of the 02:05:12.480 |
variables may be different. Then, after executing this statement, they will reach this statement 02:05:19.040 |
here, the if statement, and of course some of them will evaluate this condition to true and some of 02:05:24.640 |
them will evaluate it to false, which also means that some of them should enter this 02:05:30.480 |
if statement and some of them should not enter it. However, because the control 02:05:35.920 |
unit is the same for all of them, they will be forced to step through this if statement even if they 02:05:41.040 |
should not. So how does CUDA manage this control divergence? It basically works like 02:05:47.200 |
this: all the threads for which this if condition is true will enter this if and will 02:05:52.960 |
execute the instructions inside of it, and all the threads that have this condition equal to false 02:06:00.800 |
will also step through the body of the if, because they cannot not 02:06:06.080 |
go through it (they should always be executing the same instruction at any time), but they will 02:06:10.720 |
just not do any operations inside of this if block, they will just sit idle. This is called 02:06:18.000 |
control divergence and it can reduce the throughput of your program, so you want to 02:06:23.920 |
minimize it. But you may be wondering: why doesn't the gpu dedicate a control unit to each core so 02:06:31.360 |
that they can work independently from each other? Because the control unit is expensive in terms of 02:06:35.600 |
chip area of the gpu: it's much more efficient to add more workers instead of adding a 02:06:42.000 |
control unit for each worker. So this is a design choice of the gpu, and it works fine. 02:06:48.560 |
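As a toy illustration of this lockstep execution, here is a pure-Python sketch (this is just a mental model, not how the hardware actually works; the lane and predicate names are mine):

    n = 34                       # vector length, not a multiple of 32
    WARP_SIZE = 32
    a = list(range(n)); b = list(range(n)); c = [0] * n
    # the number of launched threads is rounded up to a multiple of the warp size (64 here)
    num_threads = ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE
    for warp_start in range(0, num_threads, WARP_SIZE):
        lane_indices = [warp_start + lane for lane in range(WARP_SIZE)]
        predicate = [i < n for i in lane_indices]        # the "if (i < n)" condition, per lane
        # every lane of the warp steps through the body together,
        # but only the lanes whose predicate is True actually do the work
        for i, active in zip(lane_indices, predicate):
            if active:
                c[i] = a[i] + b[i]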
okay now that we have seen how a kernel works let's move forward to another example 02:06:54.880 |
all right the next example that we are going to see is the following is the same as the as before 02:07:02.000 |
so we are going to do a vector addition but imagine that we have a very large vector so imagine that 02:07:07.360 |
we have a vector with 1 million elements of course we could do like before so we launch a kernel with 02:07:13.760 |
1 million threads. The problem is CUDA will reject it, because it doesn't have 1 million threads 02:07:20.240 |
to run in parallel. So how can we proceed in this case? Because usually we are working with very 02:07:25.040 |
big matrices or very big vectors, so we need to process a massive amount of data. How do we manage 02:07:31.680 |
a parallel computation when we do not have enough compute cores? 02:07:40.400 |
one way is to divide the input vector into blocks of elements for example we may decide 02:07:48.000 |
for example imagine our gpu only has 32 cores in total we may divide our input vector into 02:07:56.480 |
blocks of size 32 such that the first 32 element are the first block the next 32 element are the 02:08:02.800 |
second block the third 32 element the third block and the last 32 element are the last block 02:08:08.880 |
in this way we can ask the gpu to work on one block at a time so we can say okay work on the 02:08:16.560 |
first block and after it has processed the first block it can work on the second block and then 02:08:22.640 |
the third block and the fourth block this also allows the gpu itself to manage a subunit of 02:08:29.200 |
work because imagine now we have blocks of 32 elements but we have a gpu of 64 cores 02:08:36.320 |
the gpu we can also schedule two blocks at the same time because it has enough cores 02:08:42.480 |
So we need to increase the granularity of our 02:08:49.040 |
data to let the gpu decide how many blocks to schedule: this is the reason we introduce blocks 02:08:56.080 |
inside of CUDA so let me make a concrete example but with a very simple assumption imagine our 02:09:02.400 |
gpu only has four cores, and we have n equal to eight elements, 02:09:10.480 |
so eight elements and four cores in total. What we can do, for example, is divide this 02:09:20.320 |
vector into groups of four elements, or even fewer, let's say two elements at a time. So this 02:09:27.120 |
is the block number one this is the block number two this is the block number three and this is 02:09:34.000 |
the block number four we can ask CUDA to launch a kernel that is made up of four blocks and where 02:09:43.120 |
each block is made up of two threads. So when we launch the CUDA kernel (let me show the code now), 02:09:51.040 |
the first argument inside these triple angle brackets tells CUDA how many blocks 02:10:00.240 |
we have, and the second part of this syntax tells how many threads we have for each 02:10:09.920 |
block. In our case we want n divided by the block size number of blocks, where the block size in my 02:10:18.560 |
picture is two. So how many blocks will we have? The number of blocks 02:10:30.640 |
is n divided by two, where two is the block size, and this will be equal 02:10:40.800 |
to four blocks, each of size two. And this is what we are doing here: we are saying that 02:10:48.880 |
the number of blocks is the ceiling (because n may not be a multiple of the block size) 02:10:55.200 |
of n divided by the block size. This tells how many blocks we have, and 02:11:01.280 |
this will define our grid: the grid basically tells how many blocks we have, 02:11:05.920 |
and then each block is made up of block size number of threads then the problem is how do 02:11:12.240 |
we assign the work to do to each of these threads when we launch a kernel like this 02:11:18.800 |
with this configuration so the number of blocks and the number of threads per block CUDA will do 02:11:24.720 |
the following job: it will assign each block an index called the block id, where the block 02:11:34.880 |
id of the first block is zero. So let me write here: the first block will have a 02:11:40.560 |
block id equal to zero and in each block it will assign a thread id and the thread id of the first 02:11:49.920 |
thread of each block will be the thread zero and the second thread will be the thread number one 02:11:55.440 |
the second block will have a block id block id equal to one and the first thread of this block 02:12:03.600 |
will be the thread number zero and the second thread of this block will be the thread number one 02:12:07.840 |
the third block will have a block id block id equal to two and the first thread will be the 02:12:16.160 |
thread number zero and the second thread will be thread number one etc until the last block which 02:12:21.760 |
will be equal to three this will be thread number zero and thread number one the problem is now 02:12:27.280 |
based only on the index of the block and the index of the thread how can we map it to what element of 02:12:34.480 |
the vector each thread should work with? One simple assignment would be the following: you can see 02:12:41.920 |
that in this case we need this thread here to work with element zero, this one 02:12:49.200 |
should work with element one this one should work with element number two this one to the element 02:12:54.560 |
number three this one four this one five six and seven this five is so ugly so let me write it 02:13:05.280 |
again. How can we find the mapping given only the block id and the thread id? How can we find which 02:13:13.200 |
element it should correspond to? Well, it's a very simple formula: the element index, 02:13:18.080 |
let's call it the element id (which in the code I call i), is equal to the block id 02:13:28.640 |
multiplied by the size of each block, which is the block size, 02:13:33.200 |
plus the thread id: i = block_id * block_size + thread_id. 02:13:40.400 |
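A tiny Python sketch of this mapping, just to check it on the eight-element example (the names here are mine, not CUDA's):

    BLOCK_SIZE = 2

    def element_index(block_id, thread_id, block_size=BLOCK_SIZE):
        # the element this (block_id, thread_id) pair should work on
        return block_id * block_size + thread_id

    for block_id in range(4):
        for thread_id in range(BLOCK_SIZE):
            print(block_id, thread_id, "->", element_index(block_id, thread_id))
    # prints the indices 0..7, exactly one element per thread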
because in the case of the first thread this will be equal to zero multiplied by two plus zero which 02:13:49.280 |
is zero in this case it will be equal to zero multiplied by two which is zero plus one and 02:13:54.640 |
it will be equal to one in this case it will be equal to one because block id is equal to one one 02:13:59.600 |
multiplied by two is equal to two, plus zero, which is equal to two, etc. You can see that this 02:14:04.400 |
formula works for all the threads. So, regarding the mapping: when we launch a CUDA kernel we are telling the 02:14:10.800 |
gpu how many blocks we want and how many threads there are in each block, but CUDA 02:14:17.760 |
has no way of knowing how to map each thread to the element it should work 02:14:28.720 |
with: that's up to us, and that's what we are doing here when we are creating this kernel. So we 02:14:36.560 |
are telling that each thread should work with the i-th element of the vector, where i 02:14:43.360 |
is calculated as follows the block id to which this thread belongs multiplied by the block size 02:14:48.960 |
so how many threads there are in each block plus the thread id and this will tell the ith element 02:14:56.400 |
this particular thread should work with. Let's go back to the slides: by choosing the 02:15:06.240 |
block size equal to two and having four cores the gpu can choose to run one block or two block 02:15:14.000 |
concurrently if it has enough free cores so that's why we want to work with by block by block because 02:15:20.080 |
it allows the gpu to choose how it want to parallelize the operations if it has enough 02:15:25.120 |
cores and we don't need to have n cores for n element vector we can divide it into smaller 02:15:32.240 |
blocks and let the gpu manage the scheduling let's see one last example and then we move on 02:15:37.360 |
to triton imagine now we want to do a matrix addition instead of doing a vector addition 02:15:43.360 |
now in a matrix addition we have data that we can see on two axes one is the rows and one is 02:15:51.200 |
the columns it's usually we represent the vertical axis as the y-axis and the horizontal axis as the 02:16:00.080 |
x-axis by using the same blocked intuition that we used before so dividing the data input data 02:16:09.760 |
into blocks this is how we can divide the labor of our matrix addition into blocks for example 02:16:16.880 |
we can divide our rows into blocks and call this one the block zero and this one in the block one 02:16:23.120 |
and this one is the block two the same we can do on the x-axis so we can choose this one as the 02:16:30.080 |
block zero this one as the block one and this one as the block two on the x-axis with x is the column 02:16:36.320 |
axis and the y is the row axis we don't even have to choose the same block size for the rows and the 02:16:43.440 |
columns we can even choose the to group together three columns and two rows instead of doing two 02:16:50.400 |
and two in this case we need to find because as we said before when we launch a CUDA kernel CUDA 02:16:56.800 |
will just assign ids to the blocks and to the threads in each block; then it's up to us to understand 02:17:03.760 |
how to map the id of the block and its corresponding thread id to the data element 02:17:09.360 |
that this particular thread should work with. So in the case of matrix addition we 02:17:15.120 |
could say that each thread should work with one output element of the output matrix c so it will 02:17:21.600 |
become the sum of the a element plus the b element and it should map it to the c matrix output matrix 02:17:30.640 |
so how to do it imagine we have six rows and we have six columns one easy way would be to divide 02:17:38.800 |
these rows into three blocks each made up of two rows and each column into three blocks 02:17:45.920 |
each block made up of two columns CUDA will launch as many blocks as there are the combinations of 02:17:55.200 |
the rows and column blocks so in this case we have three blocks for the columns and three 02:18:00.720 |
blocks for the rows, so it will launch nine blocks. This is the block number 00, because 02:18:10.000 |
CUDA will identify the dimensions of the block based on the axes along which we have divided it: 02:18:17.920 |
we will call the columns the x dimension and the rows the y dimension, 02:18:22.960 |
so it will launch as many blocks as there are combinations of x and y in this case we have nine 02:18:28.240 |
so this will be the block 00 this will be the block 01 this will be the block 02:18:33.040 |
02 this one will be the block 10 11 and 12 etc etc inside of each block we will also divide the 02:18:42.480 |
threads into x threads and y threads along the two dimensions so this will be the thread 0 and 02:18:48.720 |
the thread 1 along the x axis in the x block, and this will be the thread 0 and the thread 1 along the 02:18:56.080 |
y axis in block 0 of the y axis, and each block will have two threads and they will be 02:19:03.840 |
identified as thread 0 and the thread 1 so let's look at how the launch grid works in this case 02:19:10.560 |
So imagine we have a matrix with num rows number of rows and num cols number 02:19:19.360 |
of columns, and we want to divide the rows into blocks of block size rows and the columns into blocks of 02:19:27.600 |
block size columns. Basically, the number of blocks that we need is defined as this one: 02:19:35.440 |
this is just a fancy way of writing the ceiling of num rows divided by the rows block size, 02:19:41.440 |
and this is just a fancy way of writing the ceiling of the number of columns divided by the 02:19:46.800 |
cols block size. This tells us how many blocks we will have on the rows and how many we will 02:19:50.800 |
have on the columns the grid you can see here which tells us how many blocks we have is a tuple 02:19:57.040 |
that accepts three values which tells how many blocks we want on the x dimension how many we 02:20:02.960 |
want on the y dimension and how many we want on the z dimension we are not going to use the z 02:20:08.480 |
dimension because we only have a matrix then inside of each block how many threads we want 02:20:14.240 |
for the x dimension and for the y dimension as the x dimension we have chosen the columns so we 02:20:20.720 |
are saying how many blocks we want the columns and how many blocks we want for the rows and then 02:20:25.440 |
inside of each block how many threads we want for the column block and how many threads we want for 02:20:30.720 |
the row block. This will define our launch grid, and what CUDA will do is launch the 02:20:37.280 |
following configuration: it will launch as many blocks as there are combinations of x and y, and 02:20:42.400 |
inside of each block it will assign thread ids along the x-axis and the y-axis, 02:20:50.720 |
so there will be two threads on the x-axis and two threads on the y-axis in each block. 02:20:55.840 |
now let's try to understand how to map just based on the block id on the x-axis just based on the 02:21:02.800 |
block id on the y-axis and the thread id on the x and y-axis how to map it to the one element of 02:21:09.840 |
the output matrix. Let's look at the code. First, we can use the following formula to identify which 02:21:17.680 |
row this thread should work with, because each element of a matrix is identified by 02:21:26.160 |
two indices: one is the row identifier and one is the column identifier. The row identifier we can 02:21:31.680 |
compute as the block id multiplied by the block size plus the thread id. Let's see why it 02:21:38.400 |
makes sense so in this case for example this thread will work with the row zero because the 02:21:44.640 |
block id is on the y-axis is zero and the thread id zero so it's a block id multiplied by the block 02:21:51.600 |
size so zero plus zero it will be zero so this element will be working with the row number zero 02:21:58.160 |
and which column it will be working with well it will be working with the block id zero multiplied 02:22:04.000 |
by the block size on the column which is again zero i mean this block size is two but multiplied 02:22:09.760 |
by zero it will be zero, plus thread zero, so it will be zero. This element here: 02:22:16.560 |
it will be the block id on the y-axis multiplied by the block size plus the thread id, so it will be 02:22:24.720 |
element zero for the row, and for the column it will be element one. Let's see another one, 02:22:31.200 |
for example this element here. Which element will this thread work with? 02:22:40.000 |
Well, for the row it will be the block id on 02:22:45.760 |
the y-axis multiplied by the block size, so one multiplied by two, plus the thread id zero, so that will be our row, 02:22:51.840 |
the row number two, which makes sense because this is the row zero, this is the row 02:22:58.160 |
one and this is the row two and the column will be the block id on the x-axis which in this case 02:23:07.040 |
is equal to one multiplied by the block size which is equal to two so two plus one is equal to three 02:23:13.120 |
so this thread here will work with the element number two three and this formula now makes sense 02:23:20.400 |
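Here is a small Python sketch of this two-dimensional mapping for the 6x6 example (purely illustrative; the names are mine):

    NUM_ROWS, NUM_COLS = 6, 6
    BLOCK_ROWS, BLOCK_COLS = 2, 2

    def element_of(block_id_y, thread_id_y, block_id_x, thread_id_x):
        row = block_id_y * BLOCK_ROWS + thread_id_y   # which row of the matrix
        col = block_id_x * BLOCK_COLS + thread_id_x   # which column of the matrix
        return row, col

    # the thread (block_y=1, thread_y=0, block_x=1, thread_x=1) from the example:
    print(element_of(1, 0, 1, 1))   # (2, 3)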
so this is how we use the block id and the thread id inside of each block to map it to which element 02:23:26.720 |
this particular thread should work with so as i said before cuda has no notion of knowing which 02:23:33.200 |
element this particular thread should work with this is up to us just based on the block id and 02:23:38.640 |
the thread id that cuda assigns then we make sure that the row index is less than the number of row 02:23:45.440 |
and the column index is less than the number of columns. Why? Because, as I said before, when we launch 02:23:50.480 |
blocks and threads, cuda will round up the number of threads to a multiple of 32, 02:23:56.960 |
which means that some of these threads should not work with any data. So we make sure that all 02:24:01.520 |
the threads that don't have a corresponding element to work with just sit idle 02:24:06.960 |
at this if statement, but the ones that have one should enter and do some work. So we 02:24:13.680 |
calculate the index of the element of the matrix that this particular thread should work with as 02:24:20.960 |
follows which is the row index multiplied by the number of columns plus the column index 02:24:26.080 |
This is just another way of writing, for example, 02:24:36.160 |
a[row index][column index], but the way we allocate arrays in C or C++ is as a flattened array where all 02:24:49.680 |
the rows are one after another, so we need to identify the element inside of the array based 02:24:55.360 |
on its row index and column index, and this is the formula that we use to identify it. 02:24:59.760 |
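For intuition, a minimal Python sketch of this row-major indexing (illustrative only):

    NUM_COLS = 6

    def flat_index(row, col, num_cols=NUM_COLS):
        # row-major layout: all the columns of row 0, then all the columns of row 1, ...
        return row * num_cols + col

    # a 6x6 matrix stored as one flat list, like a C array allocated with malloc
    matrix = list(range(NUM_COLS * NUM_COLS))
    print(matrix[flat_index(2, 3)])   # the element at row 2, column 3 -> 15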
if you have never worked with um arrays in c++ or c then it doesn't matter because later we will see 02:25:08.800 |
tensor layouts and this will be much more clear but if you have already worked with then you 02:25:13.280 |
already know how to index an element inside of a multi-dimensional array in c++ and then we compute 02:25:20.400 |
the output as usual. I know that this has been a lot of information, so what should we 02:25:26.560 |
remember from this? The first thing is that we decide how to divide 02:25:31.040 |
the work on whatever matrix or whatever vector 02:25:36.720 |
we are working with: we tell cuda how many blocks we want and we tell cuda how many threads we want 02:25:42.000 |
in each block based on the identifier of the block id and the thread id we should come up with a 02:25:48.080 |
strategy on how to map it to a subunit of work so which part of the matrix or which part of the 02:25:53.760 |
vector that particular thread should work with um now the next step for us is to understand the 02:26:00.880 |
tensor layouts because we are going to work with the tensors and we need to understand how the 02:26:06.240 |
tensors are laid out in the memory of the gpu, or in the cpu as well actually. So we need to understand 02:26:13.120 |
what the row-major layout and the column-major layout are, what the stride is, etc., 02:26:18.800 |
and convert all the knowledge that we have about cuda into triton so that we can then code with 02:26:24.400 |
triton our kernel so let's go all right guys finally it's time for us to explore tensor layouts 02:26:32.640 |
Now, why do we need to explore tensor layouts? Because before we have seen some examples 02:26:38.480 |
of cuda kernels, and when you give a matrix or a vector to a cuda 02:26:46.160 |
kernel, cuda will not give you the entire matrix like in python, where you 02:26:51.440 |
can access each element by its index; cuda will just give you a pointer to the starting 02:26:57.600 |
element of that particular matrix or vector. Then it's up to 02:27:03.440 |
you to calculate the memory address of all the remaining elements so suppose that we have a 02:27:09.120 |
simple vector in pytorch this simple vector could be the following which is a vector of shape 7 02:27:16.480 |
because it's a tensor with only one dimension with shape 7 which is the number of elements in the 02:27:21.200 |
first dimension for now ignore this property called the stride and later i will explain it 02:27:26.880 |
what is it how this tensor will be saved in the memory of the cpu or in the gpu it will be saved 02:27:34.080 |
as follows suppose that the starting address of the first element is the address 100 and suppose 02:27:40.560 |
that each element is made up of a floating point of 16 bit so it means that each element will occupy 02:27:46.640 |
two bytes so the start address of the second element will be the address 102 and the third 02:27:52.480 |
element will be 104 and the fourth element will be 106 etc etc etc so this is exactly what you get 02:28:00.720 |
when you allocate a vector or a matrix with malloc in C: when you allocate 02:28:08.240 |
memory with malloc, C (or rather the memory allocator) will just allocate enough memory to 02:28:15.520 |
store all the elements and it will give you a pointer to the start address of this memory, 02:28:20.400 |
then it's up to you to understand where each of these elements is stored in that block of memory. 02:28:25.200 |
To do this we introduce a property called the stride. The stride tells us how many 02:28:32.640 |
elements we need to skip to arrive to the next element in the particular dimension in this case 02:28:38.640 |
for example in the case of a vector we only have one dimension which is the x dimension 02:28:43.760 |
or the columns dimension you can think of it so this is the first column this is the second the 02:28:49.120 |
third the fourth fifth etc etc um so in order to arrive from one element to the next we just need 02:28:55.040 |
to skip one element so to go from here we need to just increase our pointer by one element and 02:29:00.320 |
then to go here we need to increase again pointer by one element etc this allow us to do a for loop 02:29:06.320 |
on this tensor let's look at a more complicated case like the matrix so the matrix is a two 02:29:13.040 |
dimensional and suppose we have the following matrix which is made up of six elements with 02:29:18.800 |
two rows and three columns so the shape of this tensor will be two by three because if we have 02:29:24.400 |
two rows and three columns how this matrix will be saved in the memory in the memory it will be 02:29:31.840 |
just a flattened matrix it means and this is called the row major layout but there is also 02:29:38.400 |
another one called column major layout that we will not be discussing so how it will be stored 02:29:44.720 |
in the memory is as follows it will be the first elements of the first row so the elements of the 02:29:50.000 |
first row, followed immediately by the elements of the second row. Imagine that the memory address 02:29:56.400 |
of the first element is 62: to go to the next element 02:30:01.200 |
we need to increase the memory address by the number of bytes that each element occupies which 02:30:06.080 |
is two bytes so the the address of the second element will be 64 the third element will be 66 02:30:12.960 |
and the next row will start immediately after the end of the first row let's introduce this 02:30:19.200 |
property, the stride. So what is the stride? The stride tells us how many elements you need to skip in 02:30:24.560 |
each dimension to arrive at the next element of that dimension. For example, imagine we want to 02:30:32.320 |
get all the elements of the first row. 02:30:36.640 |
Let's call this tensor t, so t of zero: this indexing 02:30:49.040 |
says give me all the elements of the first row, so select only the first 02:30:54.400 |
row and give me all the elements of that row. How does this indexing work? Well, by starting 02:31:01.040 |
from the pointer to the first element it will select only the first row and then it will move 02:31:08.560 |
the index here one element after another so it will select the first one the second one the third 02:31:15.760 |
one how does it know that it needs to move one element by one element because in this dimension 02:31:20.560 |
the stride is one so the stride tells us how many elements you need to skip to arrive to the next 02:31:25.920 |
element in that dimension. Imagine now that we want to get, 02:31:36.240 |
let's say, t of one, and all the elements of that row. First of all, 02:31:44.560 |
it needs to skip some elements of the first dimension: it needs to skip the element zero 02:31:49.680 |
because we are not selecting it, we only want to select the element one of the first 02:31:54.080 |
dimension, which basically means the row with index one. So, because it will start from the 02:32:00.640 |
pointer to the first element, it needs to know how many elements to skip, and how many 02:32:06.160 |
elements to skip is given by the stride. The stride tells us how many elements you need to 02:32:10.480 |
skip to arrive to the next element of the first dimension so in this case it will take the pointer 02:32:15.520 |
to the first element skip three elements and it will be starting with the second row and then 02:32:20.640 |
inside this row it will go through the index of the second dimension, in which 02:32:26.160 |
the stride is one so it will just go one after another and it will return only this part of the 02:32:31.040 |
memory so to rehearse the stride is just a a number that tells us how many elements you need to skip 02:32:40.880 |
in each dimension to arrive to the next index in that dimension so it means that to go from one 02:32:46.880 |
row to the other we need to skip three elements to go from one column to the other we need to skip 02:32:51.200 |
one element why is the stride useful well the stride is useful because it allows us to reshape 02:32:59.360 |
tensors very easily and without doing any computation let's see okay imagine we want to 02:33:06.080 |
reshape a matrix imagine initially the shape of this matrix is a two by three so we have a two 02:33:11.680 |
row by three column matrix, and we have a stride calculated as follows: it means that to go from one row to the 02:33:16.800 |
other you need to 02:33:20.960 |
skip three elements, and to go from one column to the next you need to skip one element, so you need 02:33:27.040 |
to jump by one element if we want to reshape it into this shape so three by two basically we want 02:33:33.600 |
to have three rows and two columns we can reshape it without actually changing its memory layout 02:33:45.760 |
just by changing the stride because look at this physical configuration of the tensor and we can 02:33:53.280 |
access this same tensor as this shape or as this shape exactly by using the same physical view 02:34:00.160 |
because to go from one row to the next here the stride is a three so we need to skip three 02:34:06.400 |
elements it means that the starting address the starting element of the second row is given by 02:34:12.720 |
the start pointer plus three elements so exactly here the second row will start and each element 02:34:19.760 |
of the second row is one after another because the stride of the second dimension is one so you can 02:34:26.000 |
see that to get the second row we can just start from here and then go one after another and get 02:34:32.080 |
all these elements which is exactly the second row suppose we want to obtain the second row of 02:34:37.120 |
this view here of this shape of this reshaped matrix how to do that let's look at the stride 02:34:43.360 |
the stride now is a two in the row it means that to go from one row to the next we need to skip 02:34:48.640 |
two elements so if we want to select this row here we go from the starting point of the memory 02:34:55.360 |
so this start pointer we skip the first two elements because the stride says that to go 02:35:01.920 |
from one row to the next you need to skip two elements so we arrive here and then we select 02:35:06.480 |
exactly two elements which are one after another because the stride in the second dimension is one 02:35:11.360 |
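As a quick Python sketch of this stride arithmetic (illustrative only: I'm assuming the six stored values are 1 through 6, and using plain lists rather than real tensors):

    # the 2x3 tensor from the example, stored flat in row-major order
    storage = [1, 2, 3, 4, 5, 6]

    def element(storage, strides, i, j):
        # the position of element (i, j) is: start + i * stride[0] + j * stride[1]
        return storage[i * strides[0] + j * strides[1]]

    print(element(storage, (3, 1), 1, 0))   # shape (2, 3): row 1, col 0 -> 4
    print(element(storage, (2, 1), 2, 0))   # same storage seen as shape (3, 2): row 2, col 0 -> 5
    print(element(storage, (1, 3), 0, 1))   # transpose of the (2, 3) view: row 0, col 1 -> 4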
So the stride allows us to reshape the tensor without changing the physical layout of how it 02:35:21.440 |
is stored in the memory. Moreover, the stride also allows us to get the transpose of a matrix without 02:35:29.280 |
changing the way it is stored in the memory, so without changing the arrangement of 02:35:33.600 |
the elements in the memory and this is very cool because we can view the same matrix as without 02:35:39.760 |
the transpose and also the transpose version of the matrix without changing anything in the memory 02:35:44.400 |
so it comes for free just by working with the index and the stride so to transpose the matrix 02:35:50.320 |
along two dimensions we just need to swap the stride along these two dimensions that we want 02:35:54.480 |
to transpose so in this case for example imagine we want to get the transpose of this matrix 02:35:59.440 |
we just need to swap the strides so if we want to get the second row of the transpose matrix 02:36:04.880 |
how to get that well you we always have the pointer to the first element where the tensor 02:36:11.280 |
is stored so at the beginning of where the tensor is stored in the memory and it says that in order 02:36:18.240 |
to go to from one row to the next we need to skip one element which is correct because as you can 02:36:25.680 |
see the second element is exactly the second element also in the memory so we just skip by one 02:36:31.280 |
and we get the starting point of the second row and then to go from one element to the next in 02:36:38.080 |
within the same row we need to skip three elements so the second element of the second row will be 02:36:44.480 |
after three elements after the first element of the second row so after two we need to skip three 02:36:50.880 |
elements so we skip this one we skip this one and we arrive to this one eight which is exactly the 02:36:55.600 |
second column of the second row. So basically the stride, as you can see, allows us 02:37:02.160 |
to do two things: one, it allows us to reshape the tensor without having to reallocate it in 02:37:08.880 |
another configuration in memory; secondly, it allows us to transpose a matrix without having to 02:37:14.400 |
rearrange the elements in memory, which is great because moving memory around is expensive 02:37:19.040 |
and rearranging memory is expensive, so it's great that this stuff comes for free. 02:37:25.360 |
Another thing: you know that in pytorch there are two 02:37:33.760 |
methods to reshape a tensor, one is called the reshape method and one is called the view method. 02:37:39.280 |
After transposing a matrix by swapping the strides of the two dimensions that you 02:37:47.760 |
want to transpose, you cannot reshape the tensor for free anymore. Why? Well, how 02:37:56.240 |
is the stride computed? Let me show you with a 02:38:02.880 |
concrete example: the stride of a dimension is just the product of the shape of all the following dimensions, 02:38:11.840 |
so the stride of the zeroth dimension is just the product of the elements in the shape of 02:38:18.400 |
the following dimensions, i.e. the product of the shape starting from 02:38:23.520 |
index number one. It's not easy to see with the 2d matrix because we don't have enough elements, 02:38:28.880 |
so let's do it with a 3d matrix so this is a tensor with the three dimensions so it is a shape 02:38:35.360 |
of two four three which means that we have two matrices each matrix is made up of four rows and 02:38:40.960 |
three columns. The stride is calculated as follows: the zeroth dimension stride is just 02:38:49.280 |
the product of four by three; this three here (the stride of the next dimension) is just three times 02:38:55.360 |
one, because after the last dimension there is nothing left, so its stride is one. So when we transpose, 02:39:00.960 |
this stride property is lost, and after transposing this matrix by swapping the strides we 02:39:09.840 |
cannot do further reshaping operations for free: basically, the tensor is no longer contiguous. 02:39:16.560 |
This is a quite advanced property, and it doesn't matter if you know it or not, but if you 02:39:22.080 |
are curious: basically, in pytorch you cannot view a tensor after it has been transposed, 02:39:29.600 |
because to transpose a tensor pytorch will just swap the two strides, but then it loses the stride 02:39:35.440 |
property, meaning the stride will no longer be the product of the following shapes: this 02:39:42.560 |
is not two anymore (it should be two, for example, and this should be one), but after transposing this 02:39:50.160 |
property is lost, so you actually need to reallocate the tensor if you want to reshape it after it has 02:39:55.360 |
been transposed. It doesn't matter if you remember this, it's just a curiosity. Anyway, what is the 02:40:01.200 |
stride used for? The stride is used for three things. 02:40:05.680 |
First of all, it is used to understand how to index this tensor: just by having a pointer to 02:40:12.400 |
the starting address of this tensor we can index it however we like, so we can 02:40:19.440 |
access any row, any column. Moreover, it allows us to reshape this tensor for free, so without 02:40:26.880 |
rearranging the elements in memory. And third, it allows us to transpose the tensor however 02:40:33.280 |
we like, just by swapping the strides of the two dimensions that we want to transpose. 02:40:38.400 |
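To see all of this directly in PyTorch, here is a small example matching the 2x3 case discussed above (the printed values are what PyTorch reports for this tensor):

    import torch

    t = torch.tensor([[1, 2, 3],
                      [4, 5, 6]])          # shape (2, 3), row-major storage
    print(t.shape, t.stride())             # torch.Size([2, 3]) (3, 1)

    v = t.view(3, 2)                       # reshape for free: same storage, new shape and strides
    print(v.shape, v.stride())             # torch.Size([3, 2]) (2, 1)

    tt = t.t()                             # transpose: just swaps the strides
    print(tt.shape, tt.stride())           # torch.Size([3, 2]) (1, 3)
    print(tt.is_contiguous())              # False

    # view() on the transposed tensor fails because it is no longer contiguous;
    # reshape() (or .contiguous().view(...)) has to copy the data instead.
    r = tt.contiguous().view(2, 3)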
now that we have seen also how the tensor is stored in the memory we can finally go to see 02:40:44.880 |
triton and see some examples. All right guys, now that we have seen how tensors work, 02:40:53.440 |
how tensor layouts work, and how CUDA works, we can see some examples of triton kernels to see how triton 02:40:59.040 |
differs from CUDA now if you go on the triton website you will find some tutorials like in 02:41:07.600 |
this section here and let's do let's work one tutorial together to understand how triton is 02:41:13.280 |
different from CUDA so if you go to the tutorial there are many examples so first of all the code 02:41:19.200 |
that i will be coding for flash attention is based on this tutorial here fused attention 02:41:23.040 |
that you can see here but with some modifications because i simplified the code a lot i removed for 02:41:28.160 |
example the fp8 implementation. Also, this code here on the fused attention 02:41:34.800 |
supports, in the backward pass, only the causal attention, while my code will work for the causal 02:41:39.360 |
and non-causal attention. Another modification I did is that, instead of using the 02:41:44.720 |
base-2 exponential that they use here to make things faster (because the base-2 exponential is 02:41:49.360 |
implemented with a faster hardware unit), I use the original formulation of flash attention, which 02:41:56.800 |
uses the exponential with base e, etc. So I simplified my code as much as possible to make 02:42:02.800 |
it simple to follow instead of making it optimized so for sure my code will be slower than the the 02:42:08.240 |
fused attention that you see here but mine should be more comprehensible more easy to follow anyway 02:42:14.880 |
let's go to the vector addition tutorial and if you go to the vector addition tutorial there are 02:42:20.080 |
some examples on how to do a vector addition with triton this should allow you to get into the 02:42:25.280 |
mindset of how to write kernels with triton instead of writing first the kernel and then calling it 02:42:31.920 |
let's do the opposite so let's see how to call this kernel and let's explore how it works so i 02:42:37.360 |
have already copied the tutorial vector addition from the website so let's look at first of all 02:42:42.960 |
what we want to achieve we have an input vector x and an input vector y and we want to compute 02:42:49.920 |
the vector addition which means that with the torch we want to do the following operation 02:42:54.160 |
and also we want to do the same operation also with the triton by calling this method add and 02:42:59.200 |
then we want to compare the two vectors output and they should be equal or at least the difference 02:43:04.480 |
should be very very small because of course there is always some rounding error in case you are 02:43:08.160 |
working with floating point numbers the size of this vector is 98 000 elements and we want to 02:43:15.680 |
work in a blocked way so as you remember before with the cuda you can do vector addition by 02:43:21.920 |
spawning a lot of number of threads each doing one operation but when the number of threads 02:43:27.360 |
that you have is not enough then you need to divide the input vector into blocks and this 02:43:32.000 |
is what we are going to do here so let's look at this add method so this add method basically 02:43:37.600 |
will first of all allocate the necessary memory for the output vector then it will compute the 02:43:45.360 |
launch grid. The launch grid tells triton, just like in cuda, how 02:43:52.240 |
many blocks of threads we want to launch. If you remember, in the 02:43:58.960 |
cuda kernel we specify how many blocks we want and then how many threads we want for each block 02:44:05.360 |
in the case of triton we tell how many blocks we want and then we don't force how many threads to 02:44:15.360 |
launch it will be triton that will choose how many threads to launch we just tell what each 02:44:22.320 |
group of threads should do so in this case for example we divide our number of elements so n 02:44:29.600 |
so which is 98 000 into blocks of size block size which is initialized as 1024 this is basically 02:44:39.280 |
saying that to calculate the grid size you do the ceiling division, so basically this means 02:44:45.840 |
the ceiling of n elements divided by block size: this is the meaning of this one, so how many 02:44:56.160 |
blocks we want now what each block should do is inside of the kernel so let's go to the kernel 02:45:03.280 |
and when we launch the the kernel we we can specify the launch grid in this square parentheses and 02:45:09.120 |
then in the round parentheses we specify the arguments of this kernel so let's go to the kernel 02:45:14.960 |
we see that triton will not give us access to the tensor x, it will give us a pointer to the 02:45:25.040 |
first element of this tensor, and this takes us back to the tensor layouts. The reason we studied 02:45:30.560 |
the tensor layouts and the strides and all that stuff is because this add kernel 02:45:37.840 |
will run on the gpu, and the gpu does not index tensors like pytorch, with all the 02:45:46.400 |
dimensions and the broadcasting and all this fancy stuff; the gpu will just give you the pointer 02:45:52.800 |
to the first element of this tensor in the memory and then it's up to you to compute all the indexes 02:45:58.960 |
of all the elements that you want to access so this x ptr is the pointer to the first element 02:46:05.280 |
of the x vector this y pointer is the first the pointer to the first element of the y 02:46:10.960 |
vector then we have the pointer to the output vector where we want to store the result of this 02:46:17.120 |
matrix addition we specify how many elements our vectors have and what is the block size so 02:46:23.120 |
how many items each block should process, which may not correspond to how many threads 02:46:30.480 |
each kernel will have. You may be confused, because in CUDA we specified how many 02:46:39.360 |
threads each block should have, so the granularity that we manage is the thread level; 02:46:46.480 |
here we are saying that a group of threads should work with this quantity of data, and then it's 02:46:52.480 |
up to triton to optimize the number of threads that it will actually use. Actually, there are 02:46:57.600 |
ways to say how many threads we actually want, by specifying the number of warps, 02:47:01.840 |
but we will see that later. For now just remember that this kernel here will process a 02:47:08.720 |
number of elements of the input vectors. How many elements? Block size number of 02:47:14.800 |
elements. First of all we need to identify which block we are. In CUDA we use the variable 02:47:24.320 |
called blockIdx.x to identify the identifier of the block, which tells us which group of elements 02:47:30.320 |
we should be working with; in triton you do the same by using the program id. In CUDA the block 02:47:38.640 |
id can be along the x, y and z axes; in triton these are called the dimensions 0, 1 and 2. Here we have 02:47:47.200 |
one-dimensional data, so we only use one axis to specify the block index. So we get the block index, 02:47:54.160 |
which is the pid; in triton this is called the program id. It's more intuitive to think 02:48:00.800 |
of it as the program like this is a kind of a program that is running in parallel with other 02:48:05.360 |
programs that will have different program id and based on the program id we can understand 02:48:11.120 |
what is the starting element this program should work with, so which block of elements this group of threads should 02:48:16.240 |
work with, and that is just the pid multiplied by the block size. So the pid 0 should 02:48:22.080 |
be working with the elements that start from element 0, the pid 1 should start with element 02:48:27.680 |
1024 and the p id 2 should start from the element 2048 so it should skip the first 2048 elements 02:48:34.960 |
and start with the element with index 2048 next we define how to load these elements 02:48:42.400 |
based on the pointers to the x and the y vectors. To do that, we specify a list of offsets 02:48:53.280 |
with respect to the starting address that we want to load so because each program in triton works 02:48:59.680 |
with a group of data so not one single element but a block of elements we mean we need to understand 02:49:08.480 |
which elements to load so the offset of these elements in the case of the program id 0 it will 02:49:13.920 |
load the block start, so 0, plus the elements from index 0 to 1024 excluded. For the program 02:49:27.920 |
with pid equal to 1, this will result in a vector of offsets that starts 02:49:35.600 |
at 1024, then 1025, 1026, 1027, etc., until 2047. For the program number 2, these 02:49:50.960 |
offsets will be the elements 2048, 2049, and so on until 3071. As you 02:50:03.520 |
remember, when we launch a grid, the number of threads is not always based on the number 02:50:13.520 |
of elements in the block or the number of elements in your vector it is always a multiple of a base 02:50:19.120 |
number which is usually 32 which means that the grid this program may have more threads that it 02:50:26.320 |
needs, so some threads should not be doing anything: they should not be loading any data and should not 02:50:31.920 |
be computing any summation. This is why we need this mask: it means that 02:50:37.600 |
all these offsets that we are loading should be less than n. Because imagine you 02:50:45.440 |
don't have a multiple of 1024: imagine you have a vector of 2060 elements, which means that the offsets for 02:50:56.240 |
the third program of this kernel will go from 2048, 2049, up to 2059, and 02:51:07.360 |
then also 2060, 2061, etc., but we said that we only have 2060 elements, so all the 02:51:15.280 |
offsets from 2060 up to 3071 don't exist. So we need to tell somehow that 02:51:23.360 |
all the threads that are working with these elements should not load anything that's why 02:51:27.920 |
we need this mask this mask tells load among all the offsets that this block should work with 02:51:34.720 |
only those elements that actually exist for which this mask is true then we load the elements of 02:51:42.800 |
this current program which is a group of elements defined by these offsets and only the one that 02:51:50.960 |
for which this mask is true so only the one that actually exists all the others should be ignored 02:51:55.680 |
and we can also specify what it should load in case this the mask is false with another parameter 02:52:04.160 |
but we will not see that here we also load the group of elements of the y vector and then we 02:52:09.680 |
compute the output x plus y so if you remember previously in CUDA we we did something like this 02:52:16.800 |
like the output of i is equal to the x of i plus the y of i so we did it one element at a time 02:52:26.080 |
because each thread was working with one index here we are working with a group of elements so 02:52:31.040 |
this x is a block of elements 02:52:38.720 |
of size block size, and this y is a block of elements from the y vector, and we are computing 02:52:46.720 |
the output group by group so this this is summing a group of elements of x with the corresponding 02:52:54.800 |
group in y and writing it in the output. Then we need to store this output in the 02:53:02.160 |
output tensor, output ptr, that you can see here, which is a pointer to the first element of the 02:53:07.440 |
output vector. And where should we store this output block, whose shape 02:53:14.240 |
is block size? Well, at the same offsets from which we 02:53:21.840 |
loaded x so if this program worked with the index 2048 2049 etc etc then all this output should be 02:53:30.720 |
written in the same offset 2048 2049 etc up to 3000 and something using the mask as well because 02:53:39.360 |
we don't want to write all the values of this block size because maybe we don't have enough 02:53:43.600 |
elements so only write the one that are actually present in the vector so the reason we need the 02:53:49.520 |
mask is because CUDA will launch a number of thread that is always a multiple of a base unit 02:53:53.920 |
that may not be a multiple of the vector size that we are working with so we need to find a way to 02:54:00.720 |
tell some threads to not do anything when their data is not available. So let's rehearse 02:54:07.360 |
what we have seen so far: in CUDA the program that we write is at the thread level, so we say what each thread 02:54:12.400 |
should do; in Triton we work with blocks of data, so we say what data 02:54:21.360 |
each block of threads should work with (a sketch of such a kernel follows below). 02:54:29.920 |
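For reference, here is a minimal sketch of the kind of Triton vector-addition kernel described above (the names add_kernel, BLOCK_SIZE and the launch details are illustrative, not necessarily identical to the code shown in the video):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                             # which block of data this program handles
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)   # element indices this program is responsible for
    mask = offsets < n_elements                             # guard: some offsets may fall past the end of the vector
    x = tl.load(x_ptr + offsets, mask=mask)                 # load a block of x (only the existing elements)
    y = tl.load(y_ptr + offsets, mask=mask)                 # load the corresponding block of y
    tl.store(out_ptr + offsets, x + y, mask=mask)           # write the block-wise sum back at the same offsets


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)                 # one program per block of 1024 elements
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out
```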
All right guys, the moment has finally come: we are going to code the FlashAttention forward pass right now in Triton, but let's rehearse the 02:54:36.480 |
algorithm. The goal of the attention mechanism, and specifically of FlashAttention, is 02:54:43.200 |
to compute the attention output, that is, the output of the following formula: 02:54:48.240 |
the query multiplied by the transpose of the key, divided by the square root of the head dimension, 02:54:52.000 |
then we apply the softmax, and then multiply everything by V. 02:55:01.680 |
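Written out, this is the output we want to compute:

```latex
O = \mathrm{softmax}\!\left(\frac{Q K^{T}}{\sqrt{d_{\text{head}}}}\right) V
```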
Now, in this video we will be coding the forward pass and also the backward pass, but before coding the backward pass we need 02:55:07.200 |
to understand how autograd works; we need to understand what a gradient is, what a 02:55:12.080 |
Jacobian is, how to derive the gradient of the softmax operation, how to derive the gradient 02:55:16.960 |
of the matrix multiplication operation, etc. So that is going to be another part of the video. 02:55:21.840 |
for now let's concentrate on the forward pass right now we have some tools so we know that 02:55:26.880 |
we have this thing called the gpu that can parallelize operation among multiple cores 02:55:31.200 |
we know that in cuda we can parallelize operations by telling by writing a program that is the 02:55:36.960 |
definition of what each thread should do, or we can follow the Triton programming model, which is 02:55:43.440 |
telling in Python what each group of threads should do. The mapping between what each thread 02:55:51.600 |
should do and which elements it should work with is up to us, the programmers, and 02:55:58.080 |
the same happens in Triton: we tell it how many blocks of threads we want and how much data each 02:56:05.840 |
block of threads should process, so that's the block size that we saw in the vector 02:56:11.440 |
addition. But the mapping between the elements of the vector and the identity of each group 02:56:20.160 |
of threads, so the program id that we saw, is up to us, and the same will happen when we code 02:56:26.080 |
flash attention let's see what can we parallelize in this flash attention so first of all this code 02:56:33.360 |
that you see in the forward pass of FlashAttention takes as input query, key and value, 02:56:40.160 |
which are matrices of N by D; however, usually in a transformer network we don't 02:56:48.640 |
have only one sequence made up of d dimensions we have many sequences made up of d dimensions 02:56:55.520 |
and this d is the lowercase d which is the the number of dimensions dedicated for each head but 02:57:02.720 |
we don't have only one head, we have multiple heads. So the algorithm that you see here is what each 02:57:09.360 |
head of each batch should do. Moreover, we have seen before, when talking about 02:57:19.680 |
block matrix multiplication that we can parallelize the computation of the output because this output 02:57:25.600 |
block here depends on the query block one and all the keys, this one here depends on the 02:57:31.040 |
query block two with all the keys, and this one here on the query block three with all the keys, etc. 02:57:37.440 |
So because this one only depends on the query block one and this one only 02:57:42.960 |
depends on the query block two, they can work independently from each other, while of course 02:57:47.920 |
sharing the keys. Another thing that we need to understand about Triton is the shared memory. So, 02:57:55.920 |
in the GPU we have the high-bandwidth memory, which is kind of the RAM of the GPU: 02:58:04.720 |
when you buy an A100 they tell you that it has 40 gigabytes, and that's the amount of memory in the 02:58:10.640 |
high-bandwidth memory, so the DRAM. So let's look at the actual structure of the GPU, 02:58:15.440 |
which is here we have this dram which is the big memory that we that the gpu has and then each 02:58:24.640 |
streaming multiprocessor, the unit that runs the blocks of threads, also has a shared 02:58:32.160 |
memory. So inside of the GPU we have these streaming multiprocessors, and 02:58:37.280 |
these streaming multiprocessors have a part of memory called the shared memory which is much 02:58:42.160 |
smaller than the dram like much much much smaller what changes between these two memories the access 02:58:49.360 |
to the dram is very slow and the access to the shared memory is very very very fast so one thing 02:58:56.080 |
that is different between cuda and triton is that whenever you load some information in cuda you are 02:59:02.080 |
loading that information directly from the global memory because when we launch a cuda kernel first 02:59:07.680 |
of all as you remember in my c++ code we first copy the tensors from or the vectors from the 02:59:14.480 |
cpu to the gpu and they reside in the global memory of the gpu then we load these elements 02:59:21.600 |
directly from the global memory but the access to the global memory usually it's much much much 02:59:27.120 |
slower. So what happens is that 02:59:32.160 |
the attention computation, in its naive version, the one that we can do with torch, is very 02:59:37.280 |
slow, because the access to the global memory is very slow, so we want to use as much as possible 02:59:43.120 |
the shared memory so we want to reuse the elements loaded from the global memory into the shared 02:59:48.480 |
memory so that we don't need to access the global memory every time to load elements from the 02:59:52.640 |
vectors or the matrices and this is what happens also in triton so in triton whenever you load 02:59:58.800 |
some data you are copying the information from the global memory to the shared memory 03:00:03.120 |
then whatever operations that you are doing is done on the shared memory and then when 03:00:08.240 |
you store the information you are copying the data from the shared memory to the global memory 03:00:12.400 |
this makes it much faster so we always work with the elements that have been loaded in the shared 03:00:18.640 |
memory and this shared memory basically it's shared for all the threads that belong to the same 03:00:25.440 |
block in triton we have an abstraction level that doesn't make us work directly with the threads 03:00:32.160 |
so we always work with a group of threads that belong to the same block that share this shared 03:00:36.560 |
memory so in triton we are copying information from the global memory to the shared memory we 03:00:41.280 |
do some operation with it and then we store back to the global memory and this is what we are going 03:00:45.360 |
to do with FlashAttention. Now let's review the algorithm of FlashAttention. So in FlashAttention 03:00:50.880 |
we have an outer for loop that goes over 03:00:57.760 |
all the query blocks and then an inner loop that goes through all the key blocks. 03:01:04.080 |
In the original FlashAttention algorithm, FlashAttention 1, the outer loop was on the 03:01:10.880 |
keys and the inner loop was on the queries, which made it less parallelizable. Here, instead, the outer loop 03:01:17.840 |
is on the queries, and we have seen before that the output of this attention can be computed 03:01:24.720 |
independently for each block of queries, so it's much easier to parallelize. So this outer for loop 03:01:30.240 |
actually we don't have to run a for loop we just spawn many kernels each working with one iteration 03:01:35.760 |
of this outer for loop so each working with a different query block of this outer for loop 03:01:41.120 |
and the inner for loop is something that we have to iterate through so each triton kernel will 03:01:47.920 |
work with one query block and then iterate through all the key blocks 03:01:52.720 |
and inside of this key block we have already seen the operations that we are going to do, which 03:01:59.200 |
we explored before, and at the end of this for loop we need to store the output back 03:02:05.840 |
in the high-bandwidth memory. This is how we are going to work (a plain PyTorch sketch of this structure is shown below). 03:02:13.760 |
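To make the structure concrete, here is a plain PyTorch sketch of this two-loop structure with the online softmax, for the non-causal case (the function name and block sizes are mine; this is a reference for understanding, not the Triton kernel itself):

```python
import torch


def blockwise_attention_reference(Q, K, V, softmax_scale, block_q=64, block_kv=64):
    """Non-causal attention computed block by block with an online softmax."""
    *batch_dims, seq_len, head_dim = Q.shape
    O = torch.zeros_like(Q, dtype=torch.float32)
    M = torch.empty((*batch_dims, seq_len), dtype=torch.float32, device=Q.device)

    for qs in range(0, seq_len, block_q):              # outer loop: one Triton program per query block
        q = Q[..., qs:qs + block_q, :].float()
        m = torch.full(q.shape[:-1], float("-inf"), device=Q.device)  # running row maxima
        l = torch.zeros_like(m)                                       # running normalizers
        o = torch.zeros_like(q)                                       # running output block
        for ks in range(0, seq_len, block_kv):          # inner loop: iterated inside each program
            k = K[..., ks:ks + block_kv, :].float()
            v = V[..., ks:ks + block_kv, :].float()
            s = torch.matmul(q, k.transpose(-1, -2)) * softmax_scale
            m_new = torch.maximum(m, s.amax(dim=-1))
            alpha = torch.exp(m - m_new)                # rescales the previously accumulated results
            p = torch.exp(s - m_new.unsqueeze(-1))
            l = l * alpha + p.sum(dim=-1)
            o = o * alpha.unsqueeze(-1) + torch.matmul(p, v)
            m = m_new
        O[..., qs:qs + block_q, :] = o / l.unsqueeze(-1)
        M[..., qs:qs + block_q] = m + torch.log(l)      # log-sum-exp, saved for the backward pass
    return O.to(Q.dtype), M
```

In Triton the outer loop disappears: each query block becomes its own program, and only the inner loop over key/value blocks remains inside the kernel.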
Another thing that we should notice is that this query, key and value are N by D, as I said before, but usually in a 03:02:21.600 |
transformer model we don't have only one sequence we have many sequences so we can also parallelize 03:02:28.800 |
on the number of sequences that we have in the batch because each batch can work independently 03:02:33.120 |
from each other and inside each and each head each sequence has multiple heads so each head 03:02:40.720 |
also can work independently from each other, because we know from the Attention Is All 03:02:44.320 |
You Need paper that that's the meaning of a head, that's the meaning of multi-head attention: 03:02:48.640 |
so that each head can compute the attention independently from each other so we will also 03:02:52.560 |
parallelize along the head dimension and moreover if you look at this definition of the query block 03:02:59.520 |
we can also split the query into blocks and each query block can work independently from the other 03:03:04.320 |
query blocks by in producing one output block this is how we are going to parallelize so we are going 03:03:10.240 |
to parallelize each sequence in the batch but inside of each sequence we are going to parallelize 03:03:15.600 |
each head, and inside of each head we are going to parallelize each query block. So how many programs 03:03:21.920 |
will we have working in parallel, at most? It will be the number of sequences in the batch, so 03:03:28.080 |
the batch size, multiplied by the number of heads, 03:03:37.600 |
multiplied by the number of blocks that we will divide the query sequence into, 03:03:47.200 |
where the number of queries per block is what we will call 03:03:50.480 |
the block size Q. All right, now that we have seen this one, let's actually go and code it. So, 03:04:02.640 |
i have already introduced a little bit the differences between my implementation of the 03:04:08.160 |
flash attention and the one that you can find on the triton documentation which is first of all i 03:04:12.480 |
don't work with fp8 because i believe this is unnecessary for our explanation it's of course 03:04:18.240 |
much faster because the recent gpus also support fp8 second difference is that in the um in the 03:04:25.920 |
flash attention on the triton website the backward pass is only implemented for the 03:04:31.840 |
causal attention but in my case i implement it for the causal and the non-causal attention even if 03:04:36.400 |
it's slower and later i actually i want to give you an exercise on how to improve it 03:04:41.680 |
and the third difference main difference is that i made make explicit use of the 03:04:48.080 |
softmax scale, so I actually use the scale when needed. Another difference is that in the online 03:04:56.480 |
Triton implementation of FlashAttention the exponential is not really e to the power of x but 2 to 03:05:03.120 |
the power of x, compensated with a logarithm (a factor of log2(e)), probably because 03:05:10.160 |
the implementation of 2 to the power of x is faster than e to the power of x; but in my 03:05:15.600 |
case I retain the original exponential, because I want to follow the original algorithm, to make it 03:05:20.480 |
simpler to visualize the code along with the algorithm as in the flash attention paper 03:05:25.120 |
so i know i have created a lot of hype so let's do it let's start by creating a new 03:05:34.720 |
file let's call it a program.py just like before when i introduced triton i will start by coding 03:05:41.120 |
first the code that will use our kernel and then we code the kernel and we will only be coding the 03:05:45.920 |
forward pass of the kernel so let's start by importing what we need to import which is just 03:05:53.280 |
the torch and the triton and secondly let's start by let me check okay the copilot is already off 03:06:00.400 |
so i don't have to worry about that let's start to implement the code that will test our 03:06:05.840 |
implementation of the triton and compare it with the naive implementation of the attention mechanism 03:06:10.240 |
so we create our query key and value sequence for testing 03:06:17.760 |
which, if you remember, means the query has a batch size dimension, 03:06:24.160 |
because we have multiple sequences, each sequence has a number of heads, and it's made up of 03:06:30.080 |
SEQ_LEN tokens, and each token is identified by HEAD_DIM dimensions. 03:06:37.440 |
This is because we have already split each token into smaller tokens, 03:06:43.440 |
each with its own head dimension: if you remove the NUM_HEADS dimension then you 03:06:50.000 |
concatenate back all the dimensions of this HEAD_DIM. We initialize the query, key and value sequences 03:06:57.040 |
by using a normal distribution this code i already took from the tutorial of triton so it's nothing 03:07:02.960 |
different and we require the gradient because we want to compute the gradient with respect 03:07:08.080 |
to query, key and value, and we will see later why: because we 03:07:12.960 |
want to test also the backward pass even though we will not be coding it now so the first thing 03:07:18.160 |
that we do is we define our softmax scale which is as you remember the formula is a query multiplied 03:07:25.840 |
by the transpose of the keys and then divided by the square root of head dimension 03:07:33.040 |
so d_k, or d_head as it's sometimes called. We can already compute this 03:07:42.240 |
one: this is the one over the square root of the head dimension, 03:07:46.880 |
and then we also define dO; later we will see what this is, but basically it will be 03:07:55.200 |
needed for the backward pass, so don't worry if you don't understand what dO is yet (the test setup is sketched below). 03:08:04.400 |
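A sketch of the test setup described so far (shapes and the normal initialization follow the Triton tutorial style; the concrete sizes are only an example):

```python
import torch

BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = 8, 16, 1024, 64
dtype = torch.float16

q = torch.empty((BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()
k = torch.empty_like(q).normal_(mean=0.0, std=0.5).requires_grad_()
v = torch.empty_like(q).normal_(mean=0.0, std=0.5).requires_grad_()

softmax_scale = 1.0 / (HEAD_DIM ** 0.5)   # 1 / sqrt(d_head)
dO = torch.randn_like(q)                  # upstream gradient, used later when testing the backward pass
```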
Now let's do the naive implementation of the attention, which is very simple: first we 03:08:10.960 |
define the mask and we use this mask only if the attention we are computing is causal so as you can 03:08:16.720 |
see we pass this parameter called the causal that tells if we want to compute the causal attention 03:08:22.480 |
or the not causal attention and the d type which is a float 16 because we want to work directly 03:08:28.160 |
with 16-bit floating point numbers; we will not be working with fp8. 03:08:34.880 |
My implementation is actually not as fast as the one in the tutorial 03:08:39.840 |
on the Triton website, but I believe it's much easier to comprehend. So we define the mask, 03:08:48.320 |
we compute the the product the query multiplied by the transpose of the key divided by the square 03:08:53.680 |
root of the head dimension so that's why we are multiplying by softmax scale if the attention 03:08:58.160 |
we are computing is causal then we use this mask that we have computed so we replace all the points 03:09:04.160 |
all the dot products where this mask is equal to zero with minus infinities and then the softmax 03:09:09.920 |
will replace this minus infinities with zeros because then we are applying the softmax and 03:09:14.800 |
the softmax is applied by rows, just like in the normal attention we compute. Okay, the second thing 03:09:21.120 |
that we do is compute the output, which is the product of the output of the softmax with the V; 03:09:28.240 |
this is the reference output of the naive implementation of the attention mechanism (a sketch of it is shown below). 03:09:34.640 |
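A sketch of this naive reference (the helper name is mine; the masking and row-wise softmax follow what was just described):

```python
import torch


def naive_attention(q, k, v, causal: bool, softmax_scale: float) -> torch.Tensor:
    SEQ_LEN = q.shape[-2]
    # (BATCH_SIZE, NUM_HEADS, SEQ_LEN, SEQ_LEN) scores: Q K^T / sqrt(d_head)
    p = torch.matmul(q, k.transpose(2, 3)) * softmax_scale
    if causal:
        mask = torch.tril(torch.ones((SEQ_LEN, SEQ_LEN), device=q.device))
        p = p.masked_fill(mask == 0, float("-inf"))   # the softmax will turn these -inf into zeros
    p = torch.softmax(p.float(), dim=-1).to(q.dtype)  # softmax applied row by row
    return torch.matmul(p, v)                         # reference output of the attention
```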
Then we also want to derive the gradients of the output with respect to 03:09:42.320 |
the inputs, in this case the V, the K and the Q; later we will see what we are 03:09:51.600 |
doing here. Then we also want to compare this reference implementation with our Triton 03:09:58.480 |
implementation so let's do it so our triton implementation will be implemented as a class 03:10:03.680 |
called triton attention that we will call using this method called apply and later we will see 03:10:09.200 |
what is this method in which we pass the query key and value if we want to compute the causal 03:10:13.520 |
attention the softmax scale that it should be using and it should produce some output which 03:10:18.560 |
is the output of the softmax multiplied by V. Then we can also run the 03:10:23.680 |
backward, and this is the same backward call that we will run on the 03:10:28.560 |
Triton attention output. And then we can compare the result of our 03:10:38.000 |
implementation, so this TritonAttention.apply, with the reference implementation, which is this 03:10:43.440 |
one here. For the comparison we use the function allclose, which basically compares 03:10:49.280 |
the elements of two tensors and makes sure that their absolute difference is no more than this 03:10:54.560 |
tolerance; we are not using the relative distance, we are just using the absolute distance between the 03:10:59.440 |
corresponding elements of the two tensors (this comparison is sketched below). 03:11:04.720 |
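A sketch of this comparison (TritonAttention is the class we are about to implement; the tolerance values are illustrative):

```python
# reference (naive) implementation and its gradients
ref_out = naive_attention(q, k, v, causal=True, softmax_scale=softmax_scale)
ref_out.backward(dO)
ref_dq, ref_dk, ref_dv = q.grad.clone(), k.grad.clone(), v.grad.clone()
q.grad = k.grad = v.grad = None

# our Triton implementation and its gradients
tri_out = TritonAttention.apply(q, k, v, True, softmax_scale)
tri_out.backward(dO)
tri_dq, tri_dk, tri_dv = q.grad.clone(), k.grad.clone(), v.grad.clone()

# compare using the absolute distance only (rtol=0.0)
assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0.0)
assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0.0)
assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0.0)
assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0.0)
```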
The implementation that we will build will work with causal attention and also with non-causal attention, 03:11:08.480 |
while for the one that we saw on the website of Triton, the forward 03:11:15.120 |
pass works with both causal and non-causal, while the backward pass only works in the case of 03:11:19.200 |
causal attention. Okay, but the one online is highly optimized, so if you want to 03:11:25.120 |
learn a little more tricks on how to optimize triton kernels there is a lot of knowledge there 03:11:30.320 |
anyway guys now let's try to uh implement this triton attention at least the forward pass so 03:11:36.080 |
let's go to implement this triton attention class 03:11:39.600 |
Okay. Every time you want to introduce a new operation into torch, you need to 03:11:50.560 |
implement your operation by deriving from this torch.autograd.Function class. So every 03:11:56.400 |
operation in torch, whether it's the softmax or, I don't know, the ReLU or the 03:12:03.360 |
SwiGLU or whatever, is always implemented as a class that 03:12:08.240 |
derives from this Function class, and it should provide two methods, one called forward and one 03:12:13.200 |
called the backward pass the forward should produce the output of this operation and the 03:12:17.280 |
backward should compute the gradient of the loss with respect to 03:12:22.960 |
the inputs of that function, and later we will see how that works. For now let's concentrate 03:12:28.480 |
on the forward pass to implement the forward pass we need to create a static method that 03:12:32.880 |
is called forward which takes as input one thing called the context so as you know in autograd in 03:12:41.600 |
when training neural networks we have the forward pass and the backward when computing the 03:12:46.240 |
backward pass we need to reuse the activations of each of the computation nodes during the forward 03:12:51.600 |
pass and this context basically allow us to save the information to uh for the necessary activations 03:12:57.520 |
that we will need during the backward pass and later we will see in the triton um in the flash 03:13:02.400 |
attention algorithm what information we need to save in order to compute the backward pass for 03:13:07.680 |
example of what we will need to save: during the backward pass we will need to recompute on the fly 03:13:13.040 |
the query multiplied by the transpose of the keys for each block, but we don't want to 03:13:18.800 |
recompute the normalization factor or the maximum value for each row so we will save those two 03:13:23.520 |
values; actually we will not save two values, we will save one value, using a trick called the 03:13:28.720 |
log-sum-exp trick, which we will see later. Anyway, this context is just a kind of storage 03:13:35.920 |
area where we can save some stuff that will be necessary for us to compute the backward pass, and 03:13:41.040 |
we can save whatever we like (a toy example of such a custom autograd function is sketched below). 03:13:47.360 |
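Before the real thing, here is a toy example of the pattern (not FlashAttention, just a scaling operation) showing the forward and backward static methods and the ctx storage area:

```python
import torch


class Scale(torch.autograd.Function):
    """Toy custom operation: y = x * alpha, to illustrate forward/backward and ctx."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha          # save what the backward pass will need
        return x * alpha

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient per forward input; None for the non-tensor argument alpha
        return grad_output * ctx.alpha, None


x = torch.randn(4, requires_grad=True)
y = Scale.apply(x, 2.0)
y.sum().backward()
print(x.grad)                      # every element equals 2.0
```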
Then we have the inputs of this operation, which are the query, key and value (three tensors), the causal flag telling if we are going to compute the causal attention, 03:13:52.720 |
and the softmax scale that we should apply, which is the one over the square root of the 03:13:56.800 |
head dimension which we could also compute it on the fly actually by the way by by checking the 03:14:04.240 |
shape of this but okay it doesn't matter anyway so um the first thing that we are going to do 03:14:09.520 |
is to extract the shapes of these objects and make sure all the shapes are what we expect them to be 03:14:14.560 |
so the shape of the query key and value is a batch size by number of heads by sequence length 03:14:19.840 |
by head dimension we make sure that the head dimension matches for the query key and value 03:14:25.920 |
they should match because each vector should should be of the same size 03:14:30.560 |
and then we declare what we pre-allocate the output vector so where we should save our output 03:14:38.640 |
so as you remember the output in the attention mechanism has the same same shape as the query 03:14:44.720 |
key and value sequence where the query key and value sequence i want to remind you is not the 03:14:49.440 |
query key and value of the input of the attention which is a sequence of tokens but it's the output 03:14:54.880 |
already of Wq, Wk and Wv, because FlashAttention is not concerned with optimizing those matrix 03:15:01.040 |
multiplications, but only with what comes after Wq, Wk and Wv. So we pre-allocate the output tensor where 03:15:08.640 |
we will store this output, which has the same shape as the query, key and value matrices... 03:15:16.240 |
actually no, not true: it has the same shape as the query, but it may not be the same 03:15:23.600 |
as the key and value. Why? Because there is this thing called cross attention, where the query, key 03:15:29.920 |
and value are different projections, through Wq, Wk, Wv, not of the same 03:15:36.560 |
input sequence but of two sequences so cross attention happens when we have a query that 03:15:41.440 |
comes from one uh sequence and the key and value come from another sequence and they pass through 03:15:47.440 |
their own wk wv and they may not have the same sequence length so the shapes of the output of 03:15:53.280 |
the attention only depends on the shape of the query sequence not of the key and value sequence 03:15:58.880 |
this is happens during cross attention but usually in language models we always work 03:16:03.360 |
with the self-attention so that should not happen at least in the causal language models 03:16:08.480 |
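A small sketch of the shape handling and pre-allocation just described, assuming self-attention so that Q, K and V share the same head dimension (names are illustrative):

```python
import torch


def preallocate_output(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM_Q = Q.shape
    HEAD_DIM_K, HEAD_DIM_V = K.shape[-1], V.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    # the output has the same shape as the query sequence (not necessarily as K and V)
    return torch.empty_like(Q)
```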
then we have the stage and later we will see what is this stage basically the stage it's just a 03:16:16.400 |
number that tells if the operation that we are going to do later is for the causal attention 03:16:22.160 |
or for the not causal attention and then we need to define our launch grid the launch grid tells 03:16:28.400 |
us how many parallel process we need to be launched by triton actually they will be launched by cuda 03:16:34.800 |
but by we always work with the triton as an interface to cuda so by triton so in triton 03:16:41.840 |
as i said before we want to parallelize along the batch dimension so each batch each sequence 03:16:48.560 |
in the batch should work independently from each other not only each inside of each sequence in 03:16:54.160 |
the batch each head should work independently from each other so at least we have a batch size 03:16:58.720 |
multiplied by number of heads programs and for each of this program we have another 03:17:05.840 |
dimension called the we divide the query into blocks of queries so as you remember when talking 03:17:15.200 |
about a block matrix multiplication we don't work with the query as the original matrix query matrix 03:17:20.560 |
so where each query is one vector or one token we work with group of queries so each block of 03:17:27.840 |
queries is a group of tokens in the query sequence so we are saying that we want to launch at a 03:17:34.560 |
number of kernels or blocks of threads or a group of threads along two dimensions so just like the 03:17:44.480 |
cuda kernel can be launched along two dimension x and y here we are launching programs along two 03:17:50.000 |
dimensions one dimension that tells us which batch which head of which batch we are going 03:17:56.240 |
to work with so which head of which batch element are we going to work with and inside this we are 03:18:07.040 |
going to say okay this is a sequence which group of queries are we going to work with 03:18:16.400 |
So, overall, the number of groups of queries is the sequence 03:18:25.840 |
length divided by the number of queries that we want to group together, and the block size Q 03:18:30.880 |
tells us how many queries there are in each block of queries. This cdiv is just the ceiling 03:18:37.200 |
division so it is equal to let me write it here this is equal to ceiling of sequence length 03:18:44.960 |
divided by the block size Q; this tells us how many blocks of Q we have. So let's rehearse: we 03:18:44.960 |
have a tensor Q that is batch size by number of heads, and each FlashAttention program 03:18:57.840 |
will work with the remaining dimensions, sequence length by head dimension. Moreover, we have seen that 03:19:05.680 |
flash attention has two loops one is the outer loop among all the query blocks one is the inner 03:19:17.200 |
loop along all the key block we have seen that the query block can work independently from each 03:19:22.960 |
other so we can spawn as many programs in parallel as there are number of blocks of q because they 03:19:28.880 |
can work in parallel so this grid tells us how many programs there are that can work in parallel 03:19:34.080 |
then it will be the gpu that based on its resources will decide how many program actually to work in 03:19:39.360 |
parallel if it has enough resources to make them all work in parallel wonderful if it doesn't have 03:19:45.200 |
enough resources to make them work in parallel it will launch them sequentially one after another 03:19:48.880 |
and the last dimension is this is like the z dimension in the cuda in the cuda launch grid 03:19:57.760 |
and we don't want to use it because we don't want an additional level of parallelism. All right, this 03:20:05.040 |
is our launch grid: we will launch a number of parallel programs, 03:20:12.240 |
or parallel kernels, where each kernel in Triton works on a group of threads, and this number is 03:20:18.800 |
the batch size multiplied by the number of heads multiplied by the number of blocks of Q, 03:20:30.000 |
that is, how many blocks we have divided the Q sequence into (this grid is sketched below). 03:20:38.320 |
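A sketch of this launch grid (BLOCK_SIZE_Q is assumed here to be a meta-parameter of the kernel, which is why the grid is a lambda over the kernel arguments):

```python
grid = lambda args: (
    triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]),  # axis 0: which block of queries
    BATCH_SIZE * NUM_HEADS,                      # axis 1: which head of which sequence in the batch
    1,                                           # axis 2: unused, no additional level of parallelism
)
```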
Okay, let's continue. Then we will see what this one is: this M is another tensor that we will need, and it's the log-sum- 03:20:44.400 |
exp for the backward pass, and we will see, not at the end of this 03:20:49.280 |
video but at the end of the forward pass, what it's needed for. Basically you can think of 03:20:54.960 |
it as the maximum for each row. To recompute the query multiplied by the key in the 03:21:02.560 |
backward pass, if we don't want to recompute the maximum for each row and 03:21:07.120 |
the normalization factor of the softmax we should save two things one is the maximum for each row 03:21:12.000 |
and one is the the normalization factor however by using the log sum exp trick we can only save 03:21:19.680 |
one value which is the as you can see in the algorithm of flash attention it's this stuff here 03:21:26.800 |
which is, let's see here, this stuff here: this L_i, which is the maximum for each row 03:21:36.480 |
plus the logarithm of the normalization factor. Basically, when computing the 03:21:46.320 |
backward pass we need to recompute on the fly this block here so this query multiplied by the 03:21:51.360 |
transpose but to apply the softmax as you remember we need to have the maximum for each row and the 03:21:55.120 |
normalization factor so we don't we don't recompute them during the backward because we have already 03:22:01.440 |
computed them during the forward so we save this information but we don't need to save these two 03:22:05.600 |
information separately; we can aggregate it into one single value called L_i, and later we will see 03:22:05.600 |
how we can use it (the aggregation is written out below). All right, so we have defined this one too, and we can proceed further. 03:22:10.720 |
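Written out, the single value saved per row is (with $m_i$ the row maximum and $\ell_i$ the softmax normalization factor):

```latex
L_i = m_i + \log \ell_i
\qquad\Longrightarrow\qquad
P_{ij} = \frac{e^{S_{ij} - m_i}}{\ell_i} = e^{S_{ij} - L_i}
```

so during the backward pass the softmax of the recomputed scores can be reconstructed from L_i alone.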
So now we launch our kernel on this grid; don't be scared, it's going to be a little long. So here we are 03:22:29.920 |
launching the kernel for the forward pass, by defining what the launch grid is, so how many of 03:22:36.720 |
this program should run in parallel at most we are passing the query we are passing the key we 03:22:42.160 |
are passing the values we are passing the softmax scale the m which is the information that we need 03:22:47.120 |
to save for the backward pass it's actually the l in the code of the pseudo code of the flash 03:22:53.760 |
attention algorithm here i call it m i think because also in the original code it was called m 03:23:00.640 |
the o where the our kernel should save its output and then as you remember 03:23:07.280 |
we don't get all the nice access by indexing tensor like we are used to in torch we only get 03:23:18.080 |
a pointer to the starting element of q a pointer to the starting element of k and to the starting 03:23:23.680 |
element of v and then we have to figure out all the index in the memory of the other elements 03:23:30.080 |
how to calculate the index we need the stride because the stride tells us how many elements 03:23:35.440 |
to skip to go from one dimension to the other and that's why we are passing the stride for 03:23:40.400 |
each dimension of each tensor actually in our case we are only working with q k and v that are 03:23:48.400 |
actually of the same d type and of the same shape so we should not need actually to pass all all the 03:23:55.360 |
strides for each of these tensors because they should have the same strides however in the 03:24:01.520 |
original code I believe they were passing them, so I kept it. So the strides will allow us to 03:24:07.440 |
index through this pointer, that is, to access the elements of this tensor just by using the 03:24:16.160 |
pointer to the starting element and the strides; we will be able to index any 03:24:21.360 |
element we want in the tensor then we pass the information of these shapes so the batch size 03:24:28.320 |
the number of heads the sequence length and the head dimension and which is the same for all of 03:24:34.720 |
them, and then the stage: the stage indicates if we are going to compute the causal attention or the non- 03:24:34.720 |
causal attention (the launch call is sketched below). So let's not implement the kernel yet, and let's continue writing this method. 03:24:46.560 |
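A hedged sketch of this launch call (the kernel name _attn_fwd and the parameter names are placeholders for the kernel we are about to write):

```python
M = torch.empty((BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32)  # log-sum-exp per row

_attn_fwd[grid](
    Q=Q, K=K, V=V,
    softmax_scale=softmax_scale,
    M=M,                                 # saved for the backward pass
    O=O,                                 # the kernel writes its output here
    # strides let the kernel turn (batch, head, seq, dim) indices into pointer offsets
    stride_Q_batch=Q.stride(0), stride_Q_head=Q.stride(1),
    stride_Q_seq=Q.stride(2),   stride_Q_dim=Q.stride(3),
    stride_K_batch=K.stride(0), stride_K_head=K.stride(1),
    stride_K_seq=K.stride(2),   stride_K_dim=K.stride(3),
    stride_V_batch=V.stride(0), stride_V_head=V.stride(1),
    stride_V_seq=V.stride(2),   stride_V_dim=V.stride(3),
    stride_O_batch=O.stride(0), stride_O_head=O.stride(1),
    stride_O_seq=O.stride(2),   stride_O_dim=O.stride(3),
    BATCH_SIZE=Q.shape[0], NUM_HEADS=Q.shape[1],
    SEQ_LEN=Q.shape[2], HEAD_DIM=Q.shape[3],
    STAGE=3 if causal else 1,            # 3 = causal, 1 = non-causal
)
```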
Then we need to save some information that will be needed for the backward pass, which is this 03:24:50.400 |
context variable that i told you before so we save some information for the backward pass which is 03:24:55.840 |
the query key and value which are the tensor for which we want to compute the gradient during the 03:25:01.200 |
backward pass and then we need to store also this m tensor and this o tensor and then we can all we 03:25:13.280 |
need to also store the causal variable so because if we computed the causal attention during the 03:25:18.640 |
forward forward pass then during the backward pass we need to have this information because 03:25:25.440 |
we need to mask out the things that we don't want to contribute to the gradient but we will 03:25:31.040 |
see that later when computing the backward pass for now let's concentrate on this attention forward 03:25:35.600 |
so we need to implement this forward kernel that you can see so underscore attention underscore 03:25:41.600 |
forward method. Now, a Triton kernel is just a Python method with a particular decorator called 03:25:48.400 |
triton.jit, so we copy and paste the signature; this is what makes a method become a Triton kernel. 03:25:56.800 |
And as you can see here, we pass the query, key and value matrices along with other information, 03:26:03.360 |
like the M matrix; please don't confuse the M matrix with the mask, which we will generate 03:26:07.680 |
on the fly, because we are only concerned in this case with causal 03:26:13.760 |
attention or not causal attention we do not accept custom masks here and then we pass the strides of 03:26:21.520 |
all these tensors the batch size the number number of heads the sequence length the head dimension 03:26:26.160 |
which is the shape of each of these tensors and the block size q and the block size kv the block 03:26:33.600 |
size q indicates how many queries we want to group together to make one block of the q matrix 03:26:39.920 |
and how the kv indicates how many keys and values we want to put together to make one block of the 03:26:46.160 |
k and v matrix which is what we do when we do block matrix multiplication this stage is a 03:26:52.720 |
number that indicates if it's causal or non-causal attention we are doing, so it will be 3 03:26:58.640 |
in case it's causal and 1 in case it's not causal (the signature is sketched below). 03:27:04.000 |
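A hedged sketch of the kernel signature just described (parameter names follow the description in the video but are not guaranteed to match the final code exactly):

```python
import triton
import triton.language as tl


@triton.jit
def _attn_fwd(
    Q, K, V,                       # pointers to the first element of each tensor
    softmax_scale,
    M,                             # pointer to the log-sum-exp buffer (one value per query row)
    O,                             # pointer to the output tensor
    stride_Q_batch, stride_Q_head, stride_Q_seq, stride_Q_dim,
    stride_K_batch, stride_K_head, stride_K_seq, stride_K_dim,
    stride_V_batch, stride_V_head, stride_V_seq, stride_V_dim,
    stride_O_batch, stride_O_head, stride_O_seq, stride_O_dim,
    BATCH_SIZE, NUM_HEADS, SEQ_LEN,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,    # how many queries per Q block
    BLOCK_SIZE_KV: tl.constexpr,   # how many keys/values per K/V block
    STAGE: tl.constexpr,           # 3 = causal, 1 = non-causal
):
    ...                            # kernel body, built step by step below
```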
Okay, the first thing that we do is to verify some information: we verify that the block size of the KV is less than or equal to 03:27:12.560 |
the head dimension to be honest i don't think we need it with my code because i removed most of the 03:27:18.640 |
constraints. This check was also present in the original code, so I kept it, but it all 03:27:24.160 |
depends on how we configure things: later we will see what the auto-tuning process is, and we will see 03:27:29.680 |
what variables we are going to auto tune for and how many stages we will choose how many warps we 03:27:34.560 |
will choose etc etc so let's leave it for later you can comment it or you can keep it it shouldn't 03:27:39.600 |
matter um the first thing that we do as i said before we launch a grid so a grid is a series 03:27:46.880 |
of programs where we will have some identifiers like in the cuda we had an identifier for the 03:27:52.240 |
blocks on the x-axis and on the y-axis; in Triton we get this identifier for the programs. We launched 03:27:59.600 |
sequence length divided by block size Q programs along the zeroth axis, and 03:28:09.600 |
batch size multiplied by number of heads along the first axis of the launch 03:28:16.880 |
grid which will help us identify which um part of the query we are going to work with in this program 03:28:25.200 |
in this kernel and also in which batch and on which head this program should work with so that's 03:28:33.040 |
what we are going to do now we are trying to understand what part of the input we should 03:28:37.680 |
work with based on the ids of the program which corresponds to the block id in cuda 03:28:43.680 |
so let me copy so the program id zero indicates it's this stuff here tells us which part of the 03:28:53.760 |
queries so which block of the queries we are going to work with why do we have a block on the query 03:28:59.520 |
because as we saw before the output can be computed independently for each block of the 03:29:03.760 |
queries while each block of the query has to iterate through all the key and values 03:29:08.960 |
so this is what will tell us what is the index of the block of the queries we are going to work 03:29:15.040 |
with in this particular program then we can understand also which index which batch and 03:29:20.800 |
which head with this program is associated with the program id number one is the product of the 03:29:26.880 |
batch size and the number of heads it means that we will have as many programs on the axis number 03:29:32.960 |
one as there are indicated by this product so this product lets us understand this product will tell 03:29:41.360 |
us which batch and which head this particular program is associated with. So to get the index of 03:29:48.400 |
the batch we just divide this number by the number of heads, which gives us the batch index, and 03:29:54.160 |
to get the head index inside this batch we just take this number modulo the number of heads (as sketched below). 03:30:00.960 |
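A small sketch of this index arithmetic at the top of the kernel body (variable names are illustrative):

```python
# inside the @triton.jit kernel
block_index_q = tl.program_id(0)              # which block of queries this program works on
index_batch_head = tl.program_id(1)           # ranges over BATCH_SIZE * NUM_HEADS
index_batch = index_batch_head // NUM_HEADS   # which sequence in the batch
index_head = index_batch_head % NUM_HEADS     # which head inside that sequence
```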
okay the next thing that we need to do we need to okay first of all when we pass a tensor because 03:30:12.720 |
as you can see here the q parameter to this attention forward method is a tensor because 03:30:18.480 |
it's the input of this function forward function and this forward function is called here when we 03:30:24.320 |
do attention dot apply and it's this q stuff here and this q stuff here has been created as a tensor 03:30:31.040 |
so when we pass a tensor to a triton kernel it's not really a tensor it is a pointer to the first 03:30:36.720 |
element of that tensor in the memory now we need to understand because now we know which batch we 03:30:42.800 |
are going to work with and which head we are going to work with we need to index this tensor to 03:30:47.680 |
select the right batch and the right head inside of the right batch which means that basically 03:30:53.840 |
we have this Q tensor, so we need to do something like this: Q 03:31:01.440 |
indexed by the index of the batch, and then by the index of the head we are going 03:31:11.200 |
to work with, and we need to select everything that is inside 03:31:17.040 |
these indices so we are we need to enter the tensor at the right location where the particular 03:31:23.840 |
sequence length and head dimension for this batch and for this head starts for that we need to 03:31:29.520 |
generate an offset by which we need to move this pointer (sorry, 03:31:36.160 |
not tensor, pointer), because this pointer is pointing at the beginning of the entire 03:31:41.040 |
tensor so we need to move in the batch size dimension and in the number of heads dimension 03:31:45.840 |
to do that we generate the following offset which will tell us where this uh where this 03:31:52.640 |
particular batch and where this particular head starts in this tensor and to do that we need to 03:31:58.640 |
use the strides. So what we are going to do is we're going 03:32:03.520 |
to create the qkv offset, which involves the stride of the batch dimension: 03:32:11.520 |
the stride for the batch dimension tells us how many elements we need to 03:32:17.680 |
skip to get to the next batch, and we multiply it by the index of the batch that we 03:32:23.600 |
want so for the zeroth batch we don't skip anything because we are already pointing to the first 03:32:27.760 |
element of that batch but if we are at the batch one we will skip that many number of elements 03:32:32.720 |
plus we also need to skip the some heads how many head we need to skip based on which head 03:32:38.320 |
we are going to work with. And what tells us how to go from one head to the next? The stride 03:32:44.080 |
of the head dimension. So we multiply the index of the head that we should be working with 03:32:49.280 |
by the head stride of Q (this offset computation is sketched below). 03:33:02.560 |
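A sketch of this offset computation inside the kernel (in self-attention Q, K and V have the same shape and strides, so one offset works for all three; names are illustrative):

```python
# move from the start of the whole tensor to the start of the (index_batch, index_head)
# slice, which is a (SEQ_LEN, HEAD_DIM) tensor
qkv_offset = (
    index_batch.to(tl.int64) * stride_Q_batch   # skip index_batch whole batch elements
    + index_head.to(tl.int64) * stride_Q_head   # skip index_head whole heads inside that batch element
)
```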
All right. Then Triton helps us with a function, which I think is quite recent, that helps us index elements inside of a tensor without having to 03:33:08.880 |
deal with all the complex indexing math that can be confusing for beginners, so I will be using a 03:33:16.960 |
few methods to help us with this indexing, and this function is called make_block_ptr, 03:33:25.760 |
and it's the following: basically, this make_block_ptr takes as input 03:33:33.600 |
a pointer (not a vector, sorry), and in this case we are saying 03:33:41.920 |
create a block that has the following shape that is sequence length by head dimension so let me 03:33:48.720 |
do it one by one actually i don't want to confuse you guys with all this stuff altogether okay so 03:33:54.400 |
it takes a base, which is a pointer that is right now pointing at Q plus qkv offset, so right now it is 03:34:02.320 |
not pointing at the first batch but it's pointing exactly to our batch so the the batch that this 03:34:08.160 |
particular program should be working with and inside this batch to the particular head that 03:34:13.200 |
this program should be working with which is basically saying that we have we are pointing 03:34:19.760 |
to a tensor that is as follows so we are pointing to the following tensors which is the right head 03:34:26.480 |
the right sorry the right batch and the right head and then we are selecting everything inside 03:34:34.160 |
so it's pointing to the first element of this particular tensor this tensor particular tensor 03:34:41.040 |
because we have already selected the batch and the head it is a two-dimensional tensor with this the 03:34:46.560 |
following shape because the following dimensions are sequence length and head dim so we are saying 03:34:52.480 |
take this pointer which contains a tensor of the following shape sequence length and head dimension 03:35:00.640 |
and i'm also giving you the strides of these dimensions that are in this pointer so the the 03:35:07.280 |
two dimensions that are that we need are the sequence dimension and the head dim dimension 03:35:12.320 |
which is this one for the q tensor and um and in this um in this query tensor we want to select 03:35:27.040 |
a block of queries based on the query on the block of queries that this program should be 03:35:33.040 |
working with so i think i need to maybe probably use the ipad otherwise it can be very confusing 03:35:39.840 |
to visualize so uh let's do it actually let me see if i can create another here and let's use the ipad 03:35:53.840 |
all right okay so we have a q vector q tensor because this construct we will be using it for 03:36:02.960 |
all the other tests so if you understand it for one tensor you understand it for all the others 03:36:07.040 |
we have a q tensor that is a batch by number of heads 03:36:14.800 |
number of heads then the sequence length and then the head dimension 03:36:19.840 |
with the following line so the this line here so when we create a q plus qkv offset 03:36:32.400 |
we are already selecting the right batch dimension and already the right head dimension which means 03:36:40.000 |
that we have already forwarded our q to not point to the first batch and the first head but to point 03:36:48.800 |
to the exact batch that this program is working with and the exact head that this program is 03:36:53.040 |
working with which basically means that right now it is pointing at a tensor that is made up of these 03:36:58.960 |
two dimensions now inside of this tensor we also need to select the right block of query that this 03:37:07.840 |
program should work with and this dimension here so the sequence dimension is all the queries so 03:37:14.880 |
we need to select the right queries so we need to skip some queries how to skip some queries well 03:37:21.200 |
we say that we need to skip block index multiplied by block size q number of queries because they 03:37:28.560 |
will be processed by another by another program that will have this number here the program id 03:37:36.480 |
will be different so we are selecting with this line not only inside of the q the right index 03:37:44.880 |
and the head but also the right position in this dimension in the sequence length dimension that 03:37:50.000 |
will point to the exact to the starting point of the exact query block that this particular program 03:37:56.160 |
should be working with this is what is happening and we are also creating this block basically 03:38:02.080 |
later we will see how it can be used to to create a block of the shape we are telling what is the 03:38:09.840 |
the size of this tensor so this tensor has two dimensions because we are pointing to the 03:38:16.960 |
beginning of the right query sequence so it has only two dimensions the sequence dimension and 03:38:22.400 |
the head dim dimension so it's the last dimension and we are already pointing to the right beginning 03:38:29.280 |
of the sequence dimension because we have already skipped some queries why we are skipping some 03:38:35.440 |
queries because these queries will be handled by another program that will have a block index q to 03:38:40.000 |
some other value. And this order... actually, I don't know exactly what this order is; you can try to put (0, 1) 03:38:48.400 |
or (1, 0), I think it's some optimization hint that Triton uses. I have read the online documentation and I 03:38:52.880 |
couldn't find anything about it so this is something that i will investigate but actually 03:38:58.160 |
even if you put 0 1 it doesn't matter so i think it's something that you tell triton if this you 03:39:05.600 |
want the transposed of this block or you want the not transposed version of this block and later we 03:39:10.080 |
will see actually how we can transpose the key block without doing any transpose operation 03:39:15.280 |
actually we will just change the strides like we have seen before so um now this make block pointer 03:39:23.280 |
is not something that is necessary but it makes our life easier when we will index this particular 03:39:30.000 |
pointer so we can treat this pointer nearly as nearly in the same way when we work with the 03:39:36.320 |
tensor in pytorch we will be able to skip one increase one index in one dimension without 03:39:43.600 |
having to do the computation of the strides later when doing the backward pass i will not use this 03:39:49.360 |
one and do all the pointer indexing by hand so you can check the differences of indexing a tensor 03:39:54.560 |
by using make block pointer and not by using it anyway to rehearse what are we creating we are 03:40:00.800 |
creating a pointer to the right index in the batch to the right index in the head dimension and we 03:40:07.920 |
are already skipping some queries based on the block index queue so this pointer is already 03:40:15.040 |
pointing to the right block of queries that this particular program should be working with 03:40:18.800 |
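A hedged sketch of this Q block pointer (following the tl.make_block_ptr signature; exact variable names are illustrative):

```python
Q_block_ptr = tl.make_block_ptr(
    base=Q + qkv_offset,                        # already points at the right (batch, head) slice
    shape=(SEQ_LEN, HEAD_DIM),                  # the 2D tensor left after selecting batch and head
    strides=(stride_Q_seq, stride_Q_dim),
    offsets=(block_index_q * BLOCK_SIZE_Q, 0),  # skip the query blocks handled by other programs
    block_shape=(BLOCK_SIZE_Q, HEAD_DIM),       # one block of queries
    order=(1, 0),
)
```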
let's look instead at the v and the k block now so let's copy the v block now which is 03:40:28.960 |
similar to the query but we are not going inside we are only indexing by the index batch and the 03:40:35.360 |
index head. So while the previous one, let me write it here, was already skipping this amount of 03:40:43.200 |
queries (this is what we were indexing with that make_block_ptr: we were in the right batch, 03:40:50.400 |
in the right head, and we were skipping some queries), here we are just indexing by batch and by head, so 03:40:58.880 |
we are doing v of index batch index head and we are not selecting we are not skipping anything 03:41:07.680 |
because you see this offset is equal to zero in the first dimension in the second dimension 03:41:12.000 |
so we are not skipping anything on the sequence length and we are not skipping anything in the 03:41:16.160 |
head dim dimension. All right, so let's look at the K block pointer, 03:41:25.520 |
and this is different because as you know when computing the flash attention algorithm we need 03:41:30.880 |
to have access to the block of queries and all the block of the key transposed so when accessing the 03:41:38.720 |
key we shouldn't access it like we are accessing q we should invert the two dimensions that we want 03:41:46.000 |
to transpose for and that's very simple with make block ptr and you can see it here we say that we 03:41:53.040 |
want to point to the right index and to the right head and the tensor inside of it so let's let me 03:42:00.640 |
write it here so later i can explain in line by line so what we are doing here is go to the k 03:42:05.440 |
tensor select the right batch select the right head select everything that is inside so it's a 03:42:13.120 |
tensor of two dimensions with the sequence length and the head dim because we you can see here 03:42:18.400 |
here sequence length and head dims etc but we don't want first sequence length and then head 03:42:27.680 |
dim we want first head dim and then sequence length so we want to transpose it how to transpose it 03:42:32.400 |
we just say that you need to read this tensor with the two strides swapped, so we are saying: 03:42:37.840 |
first use the stride of the head dim dimension and then use the stride of the sequence dimension, 03:42:44.400 |
and the shape of this tensor is not sequence head dim it's a head dim sequence and it's a block 03:42:55.360 |
of kvs why we are not putting directly the sequence dimension here because we want to 03:43:04.000 |
skip block by block later so we are not selecting all the sequence length in the sequence dimension 03:43:09.200 |
we are just selecting a block of kvs and later we will use another method to go to the next block 03:43:14.960 |
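Hedged sketches of the V and transposed-K block pointers just described (names illustrative); note how K is "transposed" purely by swapping the shape and the strides, without moving any data:

```python
V_block_ptr = tl.make_block_ptr(
    base=V + qkv_offset,
    shape=(SEQ_LEN, HEAD_DIM),
    strides=(stride_V_seq, stride_V_dim),
    offsets=(0, 0),                        # start at the first K/V block; advanced inside the inner loop
    block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
    order=(1, 0),
)

K_block_ptr = tl.make_block_ptr(
    base=K + qkv_offset,
    shape=(HEAD_DIM, SEQ_LEN),             # transposed shape: head_dim first, then sequence
    strides=(stride_K_dim, stride_K_seq),  # swapped strides give K^T for free
    offsets=(0, 0),
    block_shape=(HEAD_DIM, BLOCK_SIZE_KV), # one block of K^T
    order=(0, 1),
)
```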
so i hope that by showing you the indexing like this it's a little easier to follow the indexing 03:43:22.000 |
so for each tensor we are going in the right batch in the right head dimension and for the 03:43:27.280 |
query we are skipping some query blocks because each each program will work with a small different 03:43:33.120 |
query block but for the key and value each program needs to iterate through all the key and value 03:43:40.000 |
so we just point it to the first key and value block and then we will advance by one block by 03:43:45.600 |
we will advance one block by one during the for loop that we will do later 03:43:51.120 |
then in the output also we need we can make a tensor block tensor this basically creates a 03:44:00.480 |
pointer just like in the query key and value case in which we select the right index batch 03:44:06.080 |
so what we are doing is we are indexing by batch we are indexing by head and we are selecting 03:44:12.560 |
everything that is inside... ah no, we are not selecting everything inside, we are skipping 03:44:17.040 |
also in this case some blocks of queries because as i said before the output has the same shape 03:44:24.240 |
as the query so um this particular block this particular program that we that will have this 03:44:32.080 |
particular block index queue will only work with one block of the queries which will produce only 03:44:37.840 |
one block of the output matrix and we need to select exactly that one so we we can point this 03:44:44.320 |
pointer exactly to the point where we should start writing so let's skip also in this case 03:44:49.680 |
block index q multiplied by block size q rows so we select exactly the block that our 03:45:01.120 |
our program this particular program will produce when i speak about this particular program i mean 03:45:08.960 |
the program that is identified by this program id in the x0 axis and this program id in the 03:45:15.200 |
first axis because each of this program will run in parallel hopefully and each of them will have 03:45:20.480 |
a different value for the block index q and index batch head okay now we have pointed our pointers 03:45:28.560 |
to the right position where they should either read some information or they should either write 03:45:32.800 |
some information by using make block pointer these pointers can also be treated directly as tensors 03:45:39.040 |
That's why we specify the shapes of this tensor: because Triton right now provides 03:45:44.160 |
some methods to work directly with these block pointers as if 03:45:52.880 |
we were accessing tensors, so we can index them like tensors. All right, so basically Triton is just 03:46:01.440 |
doing some calculations for you based on the strides, so you don't have to do them by hand; but 03:46:05.760 |
later, when we do the backward pass, we will avoid using make_block_ptr, and we will see the 03:46:10.400 |
indexing done by hand. All right, as you know, we are processing a single block of queries, so 03:46:20.800 |
let's go back to the algorithm, otherwise we lose sight of what we are doing. 03:46:26.640 |
so let's go here and let's show my ipad all right so as you know each program we will parallelize 03:46:36.720 |
along the query block dimension so each program will work with a different query block 03:46:40.560 |
and then we need to do a for loop on all the key and value blocks right now we just 03:46:49.200 |
moved our pointers to the right position to select the right query block that we should work with 03:46:54.320 |
and to the beginning of the keys and values block that we should work with based on which 03:46:59.840 |
index and which head this particular program should be working with all right now that we 03:47:06.240 |
have pointed our pointers to the right position in which our program should be working, inside 03:47:11.040 |
of the big tensors that have the batch dimension, 03:47:19.360 |
the number of heads dimension, the sequence length dimension and the head dimension. Because 03:47:25.920 |
we are pointing to the right batch and we are pointing to the right head, these tensors have 03:47:29.840 |
become two-dimensional tensors, so they are only tensors on the sequence length 03:47:35.760 |
and on the head dimension now we need some more information that we will use later 03:47:42.320 |
the first information that we need is the offsets of each query inside of the current block of 03:47:50.640 |
queries that this particular program should be working with and that is given by the following 03:47:56.320 |
line so let me copy and paste which is this one so the offsets of the queries are the first of all 03:48:05.840 |
they are how many of them block size q because each block of queries is made up of a block size 03:48:12.320 |
q number of queries. What is each query? It's a token, and along the head dim 03:48:19.920 |
dimension it has not the whole embedding of the token but a part of the embedding of each token, 03:48:26.160 |
which part the part corresponding to the head that this particular program is going to work with 03:48:32.000 |
so we are generating the offsets that will load this particular number of 03:48:36.560 |
this particular queries from the big tensor that contains all queries and we know that 03:48:41.840 |
our queries start at the block index q multiplied by block size q position so if this is the program 03:48:48.880 |
number zero they will the imagine block size is equal to four they will be the query with index 03:48:55.120 |
zero, one, two and three. But imagine we are the program number three, which means that we need 03:49:00.880 |
to skip three multiplied by four, so 12, so it will point to the queries with index 12, 13, 14 and 15, etc. 03:49:10.400 |
All right, and we do something similar for the keys and values: initially this is the range of 03:49:20.160 |
keys and values that we need at the current iteration, and at the beginning, because our pointer for the K 03:49:26.800 |
and v is pointing to the beginning of the sequence of key and value for this particular batch and for 03:49:32.480 |
this particular head we are pointing to the first block of key and value so we are not skipping 03:49:37.760 |
anything in the query case we are skipping because our program will only work with one single block 03:49:42.560 |
of queries in this case we don't skip anything because we need to iterate through this key and 03:49:47.280 |
values so we are pointing to the first block of key values so imagine block size kv is equal to 03:49:53.600 |
four so this stuff here will be equal to zero one two and three all right now we need as you remember 03:50:02.400 |
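As a reference, here is a minimal sketch of the two offset ranges just described. These lines live inside the @triton.jit kernel being discussed (with tl being triton.language); names like block_index_q, BLOCK_SIZE_Q and BLOCK_SIZE_KV follow the naming used in this walkthrough and may differ slightly from the code on screen.

```python
# Offsets of the queries owned by this program, and of the first KV block.
offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)  # e.g. program 3, block size 4 -> 12, 13, 14, 15
offs_kv = tl.arange(0, BLOCK_SIZE_KV)                               # starts at 0; the inner loop moves through the KV blocks
```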
Now, as you remember, inside the Flash Attention algorithm we need to compute a block of queries multiplied by the 03:50:07.760 |
transpose of the keys, and to each of these blocks we need to apply the softmax star. If you remember 03:50:13.200 |
what the softmax star is: it is the softmax without the normalization, so while computing 03:50:20.000 |
the softmax star we also actually compute the normalization factor without applying it, and we 03:50:25.040 |
apply the normalization factor at the end. So for each block of query multiplied by transpose of 03:50:30.480 |
the keys we need to have the maximum for each row in this particular block and the normalization 03:50:36.720 |
factor for each row, and that's why we need the two following statistics, which is this one: 03:50:44.080 |
this is basically a block of numbers, one for each of the queries we 03:50:52.000 |
have in our block of queries, each one initialized with minus infinity, just like in my algorithm 03:50:57.600 |
that I showed before. Let me go back to the slides in case we forgot, or actually you can 03:51:03.680 |
also check the Flash Attention algorithm: we initialize it with minus infinities. So far 03:51:07.520 |
we are creating this stuff here: we are initializing the mi, we will be initializing 03:51:11.680 |
the li and we will be initializing the o, and then we will show the inner loop here, and this is exactly 03:51:19.440 |
the algorithm that we have seen before. So we initialize m with minus infinities; now we initialize 03:51:25.200 |
also the l's, so let me go back to the code. The l's are initialized with this number 03:51:37.040 |
here. The o block, as we can see from the Flash Attention algorithm, 03:51:44.240 |
is initialized with zeros, and that's why we initialize a block here: this is the output block that 03:51:49.840 |
this particular program will compute, which is based on the position in the batch and on 03:51:55.520 |
the head index. It is one block of size block size q, so how many queries there are in this block, 03:52:01.200 |
by head dimension, which, if you want to visualize it (let's go back to the slides), is equal to 03:52:09.440 |
one block of this matrix here, so it's one block of the output matrix, one 03:52:19.600 |
block of rows. Okay, so let's go back to the code now. 03:52:29.520 |
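A minimal sketch of these initializations, assuming the tutorial-style names used so far (m_i, l_i, O_block); the exact names and dtypes on screen may differ.

```python
m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf")  # running maximum per query row, starts at -inf
l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0           # running normalizer; the first correction factor wipes the 1 out
O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)   # output block accumulator, initialized with zeros
```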
So far we have initialized the output, the mi and the li, where mi is the maximum for each row in this particular 03:52:36.480 |
query block and li is the normalization factor for each of the 03:52:44.720 |
rows in our query block. Now we need to do the inner for loop of the Flash Attention 03:52:53.200 |
algorithm. We will create a separate method that will run the inner loop, so let me copy it 03:53:01.040 |
here and i am following the same structure of the code that you see in the tutorial of 03:53:06.560 |
the triton website so basically if we are running the causal attention or even if we are not running 03:53:14.480 |
the causal attention we make this for loop and then we will make another for loop and i will 03:53:20.160 |
show you why so let me first write it and then we will see so this function here will be the 03:53:27.280 |
inner loop. This inner loop needs to go through all key and value blocks one by one, and for each 03:53:35.440 |
key and value block it needs to fix the previously calculated softmax star 03:53:43.440 |
block. So basically what we are doing here: we will need to create a function as follows, where 03:53:50.720 |
we are going to iterate on all the key value block we will need to compute the query multiplied by 03:53:56.000 |
the transpose of the keys using the query block that is fixed for this program and the key is 03:54:01.360 |
block is the one that we are iterating it through and for each of these queries we need to calculate 03:54:06.400 |
what is the maximum for each row we need to compute the softmax star so the softmax without 03:54:13.040 |
the normalization factor we need to keep this statistics l which is the normalization factor 03:54:18.400 |
that we will apply at the end of the iteration of the for loop and at the same time we need to 03:54:26.080 |
update the output. As you remember, the output is P11 multiplied by V1 plus P12 multiplied by V2, 03:54:34.080 |
but we need to fix the previous P11, so every time we sum to 03:54:39.920 |
the output we need to fix the output of the previous iteration 03:54:44.960 |
and then we introduce the P and V block of the current iteration. Now, here the author of the 03:54:57.040 |
code, the one that you see on the Triton website, decided to split this for loop into two 03:55:03.520 |
steps. Why? Because in the causal attention 03:55:10.560 |
we don't want the query to attend to the keys that come after it, 03:55:17.360 |
while in the non-causal attention we let all the queries attend to all the keys, which also means 03:55:23.920 |
that we would need to have some kind of if statement inside of 03:55:29.120 |
this for loop through all the keys and values, in which we need to check if this particular query 03:55:35.120 |
that we are working with comes before or after the key and value, in case we are doing the causal 03:55:41.360 |
attention so instead of iterating through all the key and values also in the case of the causal 03:55:48.160 |
attention by splitting it into two steps we are saying first let's iterate through all the key 03:55:56.000 |
and values for which the index is smaller than the current queries block and for this we need 03:56:02.960 |
to compute the attention in the case of the causal and non-causal case then for all the elements on 03:56:09.280 |
the right of this block so for which the key index is more than the q index in the case of 03:56:15.760 |
causal attention we don't need to compute anything because it will be masked out because in the soft 03:56:20.800 |
max it will become zeros so it will not contribute to the output so we don't even have to compute it 03:56:26.320 |
this is why we split this for loop into two steps. So first we iterate over all the parts that 03:56:34.400 |
are to the left of the diagonal of the query multiplied by the key matrix, so all the values for which 03:56:41.200 |
the key index is less than the query index, and then we skip all the parts to the right 03:56:49.120 |
of this diagonal in case we are working with the causal mask but in case of the non-causal mask we 03:56:53.760 |
compute the left part and the right part of this diagonal all right don't worry when we record 03:56:59.600 |
this for loop it will be more clear so i just wanted to give a little introduction so let's go 03:57:04.800 |
code this inner loop. What will this inner loop do? It will work with this particular query block 03:57:10.400 |
that we have found, so this Q block... right, I don't see the Q block because I didn't 03:57:17.520 |
load it yet, so we need to load the query block, we actually forgot to load it. So, 03:57:25.120 |
as you remember in triton we load data from the high bandwidth memory to the sram so to the shared 03:57:31.680 |
memory by using the load statement and we are telling load the query block that we should be 03:57:37.280 |
working with because this pointer q block ptr is already pointing to the right block that we should 03:57:43.360 |
be working with so it's already skipping all the blocks that other programs should be working with 03:57:48.560 |
and it will load a tensor of size block size q by head dim, so the right block of queries, 03:57:57.840 |
and we pass it to this inner loop to which we pass the output so where it should write this output 03:58:05.200 |
the li and mi which are the statistics for the rows and for the maximum for each row of each 03:58:12.960 |
query and the li which is the normalization factor for each query and the query block this program 03:58:19.360 |
should be working with the beginning of the key and value block pointer because we need to iterate 03:58:25.200 |
through them so we just point it to the beginning and then inside the for inner for loop we will 03:58:30.000 |
iterate through them then the softmax scale that we should use when computing query multiplied by 03:58:34.880 |
the transpose of the keys the block size so how many queries we have in each block of q 03:58:40.560 |
and how many key and value we have in each block of kv this is a stage that tells us what if we are 03:58:48.480 |
on the left side of the diagonal or on the right side of the diagonal so it will tell us if we need 03:58:52.800 |
to apply the causal mask or not based on where we are and if we are need to apply the causal mask 03:59:00.160 |
the offset q and the offset kv are just the offsets of the query and key inside of each q 03:59:06.480 |
and kv block which is a list of indices that tells us how many queries we have 03:59:14.160 |
and then the sequence length the entire sequence length because in the for loop we need to iterate 03:59:20.400 |
to all the sequence length block by block so block of kv block of kv block of kv all right 03:59:26.800 |
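Putting the last two steps together, this is roughly what the calling kernel does at this point — a sketch following the structure of the Triton tutorial this code is based on; _attn_fwd_inner and the exact argument order are assumptions.

```python
Q_block = tl.load(Q_block_ptr)          # (BLOCK_SIZE_Q, HEAD_DIM): the one query block this program owns

O_block, l_i, m_i = _attn_fwd_inner(
    O_block, l_i, m_i, Q_block,
    K_block_ptr, V_block_ptr,           # pointing at the first K/V block; advanced inside the inner loop
    softmax_scale,
    BLOCK_SIZE_Q, BLOCK_SIZE_KV,
    4 - STAGE,                          # 3 for non-causal (full range), 1 for the blocks left of the diagonal
    offs_q, offs_kv,
    SEQ_LEN,
)
```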
let's write this method, and later we will actually need to continue this method again. 03:59:40.960 |
We have already seen the signature of this method: it's just another kernel, so it can be called 03:59:50.640 |
by the first kernel, and this is something you can also do in CUDA, you can actually 03:59:55.920 |
call one CUDA kernel from another CUDA kernel. Then, based on the stage of this inner 04:00:03.680 |
loop, we decide what we need to do. When we are using causal attention, we only want to 04:00:12.560 |
apply the attention to the keys whose index is less than or equal to that of the query: we do not 04:00:20.000 |
want the query to attend to keys and values that come after it. In that case we pass the value three 04:00:28.240 |
for the stage parameter, so in the causal case this will become four minus three, which is equal 04:00:35.760 |
to one. So what will happen is that we will only work with the range of keys and values that goes 04:00:44.240 |
from zero up to the current block of Q, so all the keys whose index is less than 04:00:51.600 |
the index of the queries we are working with, i.e. the left part of the causal mask. Let me draw 04:00:58.720 |
it otherwise i think it's going to be very difficult to follow so let's do it actually 04:01:03.120 |
so let's open a new one and let's go here all right so we have been using this one before so 04:01:11.040 |
we can do it again clear page all right in this now i i want you to think of the following 04:01:19.840 |
matrix as a block matrix so let's draw it in pink because i have been drawing it all in pink 04:01:24.800 |
we know that in the rows of this query multiplied by the transpose of the keys we have a uh the 04:01:31.520 |
queries blocks of queries so we are not watching one single block we are watching all the blocks 04:01:37.120 |
right now so this is the query block one this is the query block two this is the query block 04:01:41.840 |
three this is the query block four each of this query block is made up of multiple tokens of 04:01:46.960 |
queries and then we have the key the key blocks let's do it like this very ugly but okay key one 04:01:57.520 |
key block two key block three key block four when apply calculating the attention when you calculate 04:02:04.000 |
the causal attention so like with the causal mask you want only the query to attend to keys that 04:02:12.480 |
come before it so when we apply the causal mask this stuff here will be made up of zeros this 04:02:18.240 |
stuff here will be made up of zeros this stuff here will be made up of zeros and this stuff here 04:02:22.480 |
and this stuff here, and this stuff here, all made up of zeros. We never have to mask out anything 04:02:28.720 |
when we are in this case, because in this particular scenario 04:02:34.160 |
we don't need to mask out anything for sure. Why? Because all the 04:02:39.520 |
keys in this block of keys will have an index that is smaller than the index of 04:02:46.800 |
the corresponding queries, in case the block size of the query and the key matches. 04:02:52.880 |
So imagine each block of queries is made up of 04:02:57.920 |
three queries: this is the query number 0, 1 and 2, this is the query number 3, 4, 5, 04:03:07.600 |
this will be the number 6 7 and 8 and this will be the query number 9 10 and 11 in total we have 04:03:15.200 |
12 queries we will have the same indices also for the keys in case we choose the same size for the 04:03:21.920 |
blocks so this key block here will be the key number 0 1 and 2 this will be the key number 3 04:03:32.400 |
4 5 this will be the 6 6 7 and 8 etc etc now what happens is that in this case as you can see the 04:03:42.720 |
key indices of the keys are always smaller than the indices of the queries so we don't need to 04:03:48.960 |
mask out anything even in the case of the causal mask because we are sure that in this case all 04:03:54.240 |
of these dot products will never be masked out also in this case all these dot products will 04:03:59.360 |
never be masked out and also in this case will never be masked out will never be masked out 04:04:02.960 |
and will never be masked out. In this case, however, along the diagonal, some of the queries 04:04:09.920 |
will have an index that is bigger than that of the keys and some of them 04:04:16.000 |
will not have an index that is bigger than that of the keys, because these are blocks of 04:04:22.160 |
queries and blocks of keys: some of them need to be masked out and some of them don't need to be 04:04:26.720 |
masked out so we are dividing our for loop into multiple steps the first step that we are doing 04:04:32.720 |
is all to the left of this diagonal in which we don't need to mask out anything then we will see 04:04:38.240 |
another step here in which we need to mask out, and then everything to the right of this 04:04:38.240 |
we will not even compute, in the case of causal attention, because we already know it's made up 04:04:46.800 |
of zeros, so it will not contribute: the product of the query multiplied by the transpose of the keys, 04:04:50.400 |
after the softmax will be made up of zeros so if you look at the flash attention algorithm so 04:04:59.840 |
this stuff here the contribution will be zero because we are multiplying zero with v it will 04:05:05.920 |
be zero so we don't need to change the output so why even compute this part of the matrix if we 04:05:11.200 |
already know it's not going to contribute to the output so we just skip all those iterations 04:05:15.280 |
and this is why we are splitting the for loop i hope now it's much more clear all right so let's 04:05:22.560 |
go back. Okay, so in the case of stage number one we are on the left part of the diagonal; 04:05:29.520 |
in the case of stage number two it's the part exactly on the diagonal, in which we need to 04:05:35.680 |
do some dot products and some other dot products we don't need to do and then for the non-causal 04:05:40.560 |
attention we just go through from zero to the sequence length without doing this multi-step 04:05:45.120 |
because we don't need to mask out anything so this is why we have this stage this tells us 04:05:51.920 |
what is the lower and higher index of the key block that this particular stage should be working with 04:05:59.440 |
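A sketch of that stage logic, mirroring the Triton tutorial this code follows (lo and hi bound the KV positions this call will loop over; the exact names are assumptions):

```python
if STAGE == 1:
    # causal, first call: only the KV blocks strictly to the left of this query block's diagonal
    lo, hi = 0, block_index_q * BLOCK_SIZE_Q
elif STAGE == 2:
    # causal, second call: only the KV blocks sitting on the diagonal (partial masking needed)
    lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q
    lo = tl.multiple_of(lo, BLOCK_SIZE_Q)  # compiler hint
else:
    # non-causal: iterate over all the KV blocks of the sequence
    lo, hi = 0, SEQ_LEN
```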
All right, now this function here, multiple_of, is just telling Triton that this number here is a 04:06:07.920 |
multiple of this number, so Triton can make some optimizations. So stage one happens when 04:06:14.400 |
we are doing causal attention: we pass stage number three to this function and four minus three will 04:06:21.120 |
become one. So imagine we are in the causal attention: we will go through the key and value 04:06:26.400 |
block that are to the left of the diagonal with respect to the query block that we are working 04:06:32.320 |
with. In the case we are doing non-causal attention, in this first call to the inner 04:06:40.320 |
function the stage will be one, so four minus the stage will be equal to three, so we will 04:06:46.960 |
execute this part of the if statement and we will go through all the keys and values. For the 04:06:55.120 |
causal attention only, as you can see here, we will do another call that will only be done 04:07:00.960 |
along the diagonal in which we need to mask out something and we don't need to mask out something 04:07:05.120 |
because inside of each blocks there will be some keys that have the index below the index of the 04:07:11.280 |
query and some that have above the index of the query so only in the causal attention we will 04:07:16.240 |
call this function twice the first time with the stage equal to one and the second time with the 04:07:22.160 |
stage equal to two and the second time we will only iterate through the group of key v blocks 04:07:29.360 |
that are exactly on the diagonal of the matrix query multiply by transpose of the keys the big 04:07:36.160 |
matrix that is made up of all the blocks all right now that this should be clear let's proceed 04:07:42.320 |
further so let's um because we need to do the for loop the inner for loop of the flash attention 04:07:49.120 |
let's go and load the first blocks of key and values which is exactly the one that the key 04:07:55.920 |
and V block pointers are currently pointing at, which is the (0, 0) block. So we define the pointers: 04:08:05.200 |
basically, we point the key and value block pointers to the first key and value block that 04:08:12.720 |
this for loop should be working with, which will be based on the stage: if it's the first call to 04:08:19.040 |
this function, they will be pointing to the first block, both in the causal and in the non-causal case; 04:08:25.600 |
if it's the second call to this function, which only happens in the case of the causal attention, 04:08:30.800 |
they will be pointing exactly to the key and value block on the diagonal. 04:08:34.400 |
All right, then we need to write the for loop over the key and value blocks, so let's do it. 04:08:44.640 |
The first thing we do is we let the compiler know that 04:08:52.960 |
this number here the start kv will always be a multiple of the block size kv because we will be 04:08:58.400 |
moving from one kv block to the next kv block block by block so we let the compiler know that 04:09:04.320 |
this number here start kv is a multiple of block size kv it doesn't change anything from a logic 04:09:08.880 |
point of view we are just telling giving some hint to the compiler so it can do some other 04:09:13.280 |
optimization that triton does now the first thing that we see in the flash attention algorithm is 04:09:20.880 |
we need to compute the product of the query so this is the particular block of the query that 04:09:25.600 |
we are working with, with the current KV block in this iteration. So let's do it: we compute Q times 04:09:33.200 |
K. The query has already been loaded by the caller of this function, we have loaded it 04:09:40.480 |
here, so we have already loaded the query, but we need to load the current block of K. 04:09:46.880 |
So we load the current block of K indicated by the K pointer, and we do the matrix 04:09:54.720 |
multiplication of the current block of queries with the current block of K, 04:10:00.480 |
which is already transposed, because when we defined the K block pointer we 04:10:08.640 |
defined it already with the strides exchanged, so we are reading the tensor already transposed: we 04:10:14.800 |
are doing the query multiplied by the transpose of the keys, basically. Okay, now let's look at 04:10:22.560 |
this part here, which says: if the stage is two. When the stage is two we are 04:10:30.560 |
exactly on the diagonal, and we know that some of the queries will have an index that is bigger than 04:10:35.760 |
that of the keys and some of them will have an index that is smaller than that of the keys, 04:10:40.080 |
so we need to apply the causal mask only in this case. So basically what we do is we define 04:10:46.960 |
the mask that we should be applying: the mask will mask out all the values for which this 04:10:53.600 |
mask is not true, and the mask is true when the index of the query is greater than or equal to the index of 04:11:00.960 |
the keys and values. Then we apply the softmax scale: as you remember, here we only computed 04:11:09.440 |
query multiplied by transpose of the keys, but we also need to divide by the square root of the 04:11:13.600 |
head dimension, and we do it here. Then, because we already computed 04:11:22.960 |
the product, we can calculate the maximum for each row, and then we subtract it, because 04:11:31.440 |
later in the flash attention algorithm we have another operation, 04:11:36.080 |
which I call the softmax star, and as you remember the softmax star needs to take each 04:11:46.560 |
element of the S matrix, so the query multiplied by the transpose of the keys, minus the maximum 04:11:51.600 |
for each row. So we can already compute the maximum for each row, and before 04:11:58.640 |
computing the maximum for each row we need to mask out all the elements that will be masked out 04:12:04.000 |
in the stage number two, which is along the diagonal. And how do we mask out? We just replace 04:12:10.960 |
with minus infinity, before applying the softmax, all the values for which the mask is false. 04:12:15.840 |
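Here is a sketch of what the loop body looks like up to this point, inside `for start_kv in range(lo, hi, BLOCK_SIZE_KV):` — a large negative constant stands in for minus infinity, as the Triton tutorial does, and the variable names are assumptions.

```python
K_block = tl.load(K_block_ptr)                     # (HEAD_DIM, BLOCK_SIZE_KV), already "transposed" via its strides
QK_block = tl.dot(Q_block, K_block)                # (BLOCK_SIZE_Q, BLOCK_SIZE_KV)

if STAGE == 2:
    # on the diagonal: causal mask, true where query index >= key index
    mask = offs_q[:, None] >= (start_kv + offs_kv[None, :])
    QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6)
    m_ij = tl.maximum(m_i, tl.max(QK_block, 1))    # new running row maximum
    QK_block -= m_ij[:, None]
else:
    # off-diagonal (or non-causal): no masking needed
    m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale)
    QK_block = QK_block * softmax_scale - m_ij[:, None]

P_block = tl.math.exp(QK_block)                    # softmax*: exponentials without the normalization
```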
So right now, what have we computed? We have computed the query multiplied by the 04:12:20.960 |
transpose of the keys, we have masked out in case we need to mask (and we need to mask only when 04:12:25.840 |
we are along the diagonal; in all the other cases we don't need to mask out anything), we multiplied 04:12:30.480 |
by the softmax scale, and then we subtracted the mij, where mij is the maximum value for each row, 04:12:39.120 |
because we need to compute the softmax star operation which is the softmax without the 04:12:43.280 |
normalization which in the flash attention algorithm is exactly this operation which 04:12:47.520 |
will produce the Pij. Okay, so let's go here: now we can compute the Pij block, which is this stuff 04:12:56.080 |
here, which is the exponential of the QK block variable, from which we have already subtracted 04:13:03.920 |
the m (we already subtracted this mi in the previous instruction), so now we can just 04:13:11.680 |
apply the exponential, and this is what we are doing here. Okay, then we need to compute the 04:13:20.000 |
sum of the rows, for the normalization factor: for the current block we 04:13:29.760 |
have the Pij block of the current KV block, and to compute the normalization 04:13:38.960 |
factor for the softmax we need to keep summing up these exponentials. Later we will fix the 04:13:45.200 |
normalization factor that we computed at the previous step, but we will do that 04:13:52.080 |
later so now we just computed the normalization factor for the current block which is just the 04:13:56.560 |
sum of all the values on a single row which is the same as what we did before here as you can see 04:14:04.400 |
here when i show you the algorithm so for each block we do the row sum as you can see here 04:14:13.440 |
of the p matrix what is the p matrix is the exponential of the s minus m and for now we 04:14:21.280 |
didn't apply the the correction to the previous block that's it so we computed the lij for the 04:14:28.240 |
current k and v block and then we compute the correction factor for the previous block so the 04:14:34.400 |
correction factor for the previous block if you remember the formula from the paper is this one 04:14:39.200 |
is the exponential of the previous estimate of the maximum minus the current estimate of the maximum 04:14:44.160 |
which is exactly this one so the previous estimate of the maximum minus the current 04:14:47.760 |
estimates of the maximum we will see later why mi is the previous estimate of the maximum and 04:14:54.000 |
mij is the current estimate of the maximum because it is coming from the current block that we are 04:14:58.640 |
computing, while mi is, let's say, the one of the previous iteration, because 04:15:06.480 |
later we will override mi with mij but i'm just following the flash attention algorithm so far 04:15:13.840 |
so i am computing the correction factor of the previous li which in the flash attention algorithm 04:15:19.600 |
is let me show you this stuff here so it is this stuff here this one here 04:15:27.440 |
Okay, and then we apply the correction factor: we take the previous li 04:15:39.440 |
multiplied by the correction factor, plus the current lij, which is the one coming from the current P block, 04:15:44.480 |
the one computed with the current K and V of the current iteration. And right now we 04:15:50.320 |
are doing this operation: so li is equal to the previous li multiplied by the correction factor, plus the current lij. 04:15:55.520 |
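In code, the running-statistics update just described looks roughly like this (a sketch; the names are assumptions):

```python
l_ij = tl.sum(P_block, 1)          # row sums of the current exponentials (normalizer of this KV block)
alpha = tl.math.exp(m_i - m_ij)    # correction factor for everything accumulated so far
l_i = l_i * alpha + l_ij           # fix the previous normalizer and add the current one
```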
all right and then what we need to do okay we need to as you remember the formula is 04:16:02.240 |
we calculate the p block and then we need to multiply by the v block 04:16:07.520 |
so we need to load the v block so let's load it 04:16:13.680 |
we load the V block based on the V block pointer, which, 04:16:20.240 |
at the beginning of this iteration, in case we are in stage number 04:16:26.400 |
three (so in case we are doing, for example, non-causal attention), will be pointing to the 04:16:30.400 |
first V block. Then, okay, here there is just a type conversion, so we make sure this is 04:16:39.760 |
in floating point 16, and then we compute the output block. So we are computing the following: 04:16:48.160 |
we just take P multiplied by V and we add it to the output, and this is what we are doing here: 04:16:54.960 |
we take p we multiply it by v and add it to the o block let's go actually to this line one by one 04:17:02.800 |
so first of all we need to fix the previous output block with the correction factor correction factor 04:17:07.920 |
that we have here so we can fix the previous block with this alpha term here which is the 04:17:13.600 |
correction factor for the previous block and so we just fix the previous block for now but we 04:17:19.760 |
didn't add the new PV. To add the new PV we do the dot product of P and V, and this third argument 04:17:26.160 |
tells the dot function (it's not really a dot product, it's actually a matrix 04:17:31.840 |
multiplication) to use this element here as the accumulator. So this is exactly the same as doing 04:17:38.080 |
O block plus-equals P block multiplied by the V block: 04:17:47.920 |
this is just optimized, because anyway this dot function here needs some place where to store 04:17:55.040 |
the intermediate results, so why not just store them where they should actually go? And because 04:18:01.600 |
the matrix multiplication is made of dot products, and a dot product is just a repeated 04:18:07.120 |
sum, this dot will keep summing the result into this block here, which 04:18:16.000 |
gives exactly the same result as this instruction, as if we had done the matrix multiplication separately 04:18:22.960 |
and added it to the O block. And that's why this argument is called the accumulator. 04:18:30.080 |
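So the accumulation step for one KV block looks roughly like this (a sketch; the names are assumptions):

```python
V_block = tl.load(V_block_ptr)               # (BLOCK_SIZE_KV, HEAD_DIM)
P_block = P_block.to(tl.float16)             # type conversion before the matmul
O_block = O_block * alpha[:, None]           # fix the contribution of the previous KV blocks
O_block = tl.dot(P_block, V_block, O_block)  # accumulate P @ V directly into O_block (the "accumulator" argument)
```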
okay all right so we have also computed the output and then we save the new estimation of 04:18:37.440 |
the maximum for the current iteration and it becomes mi so at the next iteration we can use 04:18:43.440 |
it to calculate the correction factor and then we have finished for the current block and then we 04:18:49.760 |
can move on to the next block so we advance our k and v pointers by one block of k and v 04:18:57.840 |
We advance them differently because we know that the V block pointer points to a tensor of shape, 04:19:05.280 |
let me write it here, sequence length by head dim, so we need to advance 04:19:15.040 |
along the sequence length by one block size kv, while the K block is actually the K-transposed 04:19:23.200 |
block — it is transposed because we have exchanged the strides and the shape, which is 04:19:31.120 |
head dimension by sequence length — so we don't change the head dimension, we just advance the sequence 04:19:39.360 |
length by block size kv. So basically we are just going to point to the next block of K and 04:19:45.920 |
to the next block of V. I hope you were able to follow the algorithm of Flash Attention; I tried to 04:19:52.720 |
use the same names and more or less the same logic, always writing the formula that 04:19:57.280 |
i'm referring to so hopefully you didn't get lost i think the only difference that there is between 04:20:02.400 |
the flash attention algorithm as written on the paper and this code is probably this alpha which 04:20:06.800 |
is the correction factor but i hope it's easily understandable anyway then we just return the o 04:20:13.440 |
block: so the O block, and li, which is the normalization factor for each row in the current output block, 04:20:24.480 |
which is also a q block because we are working with one q block independently from the other 04:20:30.000 |
programs and mi is the maximum value for each row which will be needed for the backward pass 04:20:37.280 |
because in the backward pass we will compute the query multiplied by the transpose of the key 04:20:42.480 |
block on the fly we need to also apply the softmax but instead of re-computing the stuff which we 04:20:47.680 |
already computed during the forward pass we just save them and reuse them during the backward pass 04:20:52.480 |
which will save us some computation. Now it's time to talk about the logsumexp trick, because 04:20:59.200 |
we are going to use it so let's go back to the old method so let's go here all right so we have 04:21:06.640 |
done two calls of this function in case we are working with causal attention: when 04:21:12.880 |
we are computing causal attention we call this function once to work with all the key blocks 04:21:17.520 |
that are to the left side of the diagonal of the query-key matrix, then we do another call of this 04:21:23.280 |
function to work only with those blocks of keys that exactly lie on the diagonal of the query key 04:21:30.800 |
matrix because in this case some of the values need to be masked out and some of them do not 04:21:37.520 |
need to be masked out moreover by doing this we can avoid computing the dot products for all those 04:21:44.400 |
values, in the causal case, for which the index of the key is higher than 04:21:51.760 |
the index of the query, saving some computation, because anyway they would result, after the 04:21:56.560 |
softmax in zeros and they will not contribute to the output so it should be faster okay now let's 04:22:03.520 |
go back to the this method here so calling method and there is one last thing that we need to do 04:22:08.480 |
which is we need to compute the log sum exp and now i will show you what is it so in order for 04:22:16.320 |
the backward pass to recompute the softmax without having to recalculate the normalization factor and 04:22:21.760 |
the maximum value for each row we should be actually saving two different stuff one is the 04:22:26.640 |
maximum for each row in the query block and one is the normalization factor for each query in the 04:22:32.400 |
query block however there is a trick and the trick is okay it's not really called log sum exp trick 04:22:38.960 |
because the log sum exp trick is used for another purpose but let's call it log sum exp trick number 04:22:44.720 |
two so the log sum exp trick number two is something like this so let me open the slides 04:22:52.640 |
so when we do query multiply that transpose of the keys we get a matrix that is made up of dot 04:22:59.680 |
products so something like this like this is one dot product so let's call it query one transpose 04:23:05.920 |
the key one query one transpose the key two this is a query two transpose the key one and this is 04:23:13.200 |
a query two transpose the key two then we need to apply the softmax right so the softmax is what 04:23:20.480 |
is the let's write the formula of the softmaxes for each of these vectors so this is a vector 04:23:25.040 |
and this is a vector because we applied it by rows for each of these vectors this will modify 04:23:30.080 |
element wise each element as follows so the softmax of x i is equal to the exponential 04:23:37.040 |
of x i minus oh my god i didn't leave enough space so let's move this stuff here back 04:23:44.000 |
and this stuff here please left all right it will be the softmax of the exponential of each element 04:23:52.960 |
minus the maximum for the current vector to which we are applying the softmax 04:23:57.600 |
divided by the normalization factor which is the summation over all possible j's 04:24:05.360 |
where n in this case is equal to 2, because each vector is made up of two elements, 04:24:09.280 |
of the exponential of x j minus x max. Now imagine we already have x max and we already have this 04:24:21.200 |
summation in the flash attention algorithm in the forward pass this stuff here is called l i 04:24:26.240 |
and this stuff here is called m i what we are going to save in the code you can see here 04:24:32.880 |
we are saving actually not m i and l i separately we will be saving m i plus the logarithm of l i 04:24:41.680 |
so we are going to save m i plus the log of l i. What will happen is that when we 04:24:54.240 |
compute the backward pass we need to recreate this matrix here on the fly, which means that 04:24:59.680 |
we need to recompute the query multiplied by the transpose of the keys, and then we should 04:25:06.080 |
apply the softmax; to apply the softmax we would need this stuff and this stuff here, but we have 04:25:11.200 |
only this stuff here, which is the m i plus the logarithm of l i. So when we're computing the 04:25:17.040 |
softmax we will compute the following so we will compute the softmax as follows we will define 04:25:23.600 |
let's call it a new softmax so let me use another color here we will apply the softmax as follows so 04:25:33.680 |
softmax of x i — let's call it softmax 2, because I don't want to confuse it with the standard softmax — 04:25:46.880 |
is equal to the exponential of each element minus we will subtract this value here 04:25:54.400 |
the one corresponding to the current row to which we are applying the softmax 04:25:58.400 |
so it will be the exponential of x i minus m i minus the log of l i if we expand this expression 04:26:09.040 |
this will become — because the exponential of a sum 04:26:16.800 |
is equal to the product of the two exponentials, and the exponential of a difference is the quotient — we can also write it like this: it will be 04:26:21.040 |
the exponential of x i minus m i, divided by the exponential of the log of l i, which, guess what, 04:26:36.800 |
it is equal to the exponential of x i minus m i divided by l i which is exactly the normalization 04:26:45.440 |
factor and we also have m i so instead of saving two values we save only one value and when we 04:26:51.040 |
apply it the exponential's properties will take care of actually also normalizing each value to 04:26:56.080 |
which we apply it if you don't remember the properties of the exponential it is very simple 04:27:01.040 |
so the exponential of a plus b is equal to the exponential of a multiplied by the exponential 04:27:10.320 |
of b, and the exponential of 04:27:18.080 |
a minus b is equal to the exponential of a divided by the exponential of b. 04:27:28.080 |
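Written out, the trick is the following (with m the row maximum and l the row normalization factor, as in the forward pass):

```latex
\[
\mathrm{softmax}(x_i) \;=\; \frac{e^{\,x_i - m}}{l},
\qquad m = \max_j x_j,\quad l = \sum_{j=1}^{N} e^{\,x_j - m},
\]
\[
e^{\,x_i - (m + \log l)} \;=\; \frac{e^{\,x_i - m}}{e^{\log l}} \;=\; \frac{e^{\,x_i - m}}{l} \;=\; \mathrm{softmax}(x_i),
\]
```

so saving the single value m + log(l) per row is enough to rebuild the softmax during the backward pass.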
And this is the trick that we're using; that's why we don't need to save two different 04:27:32.000 |
values, we just need to save one value, and then when we apply it, normalization will automatically be taken 04:27:36.800 |
care of because of the properties of the exponential. All right, let's move 04:27:42.560 |
forward so we have also created this value that we will use during the backward pass now as you 04:27:51.280 |
remember, in the Flash Attention algorithm we don't normalize each block while computing it, we normalize 04:27:57.200 |
the output at the end, and this is exactly what we are going to do here: we normalize the block at 04:28:04.000 |
the end, after we have computed all the normalization factors that we need for all the rows that belong 04:28:08.560 |
to the current output block. Then we save this m i, which now encodes both the normalization 04:28:17.920 |
factor and the maximum for each row, and which we will need for the backward pass. So we need to save it 04:28:24.640 |
in a tensor that we will use during the backward pass so we need to understand which tensor is this 04:28:29.920 |
and it's the tensor that we called m which is a tensor of a batch size num heads and sequence 04:28:35.440 |
length dimensions so we need to select the right point in this tensor to select to where we should 04:28:42.160 |
save this m i values so we need to select the right batch size index and the right number of head 04:28:49.120 |
index. So we advance this pointer by the following offset, which is M plus the index batch head: 04:29:01.520 |
the index batch head is the index of the current program, and it 04:29:09.520 |
includes information about which head we are working with and which batch we are working with, 04:29:14.640 |
because for each batch and for each head we have a sequence length, so we can skip 04:29:23.440 |
a number of sequence lengths based on this index. What we are doing is basically this: 04:29:32.240 |
for each batch and for each head we will have a sequence length, because each token in the 04:29:42.800 |
sequence has a maximum value and each token in the sequence will have a normalization value, 04:29:47.520 |
so based on the current combination of batch and head we can skip the number of sequence lengths that 04:29:54.240 |
other programs will process. Because in this tensor we have the sequence length as the last 04:30:01.440 |
dimension, and we have the combined index of the batch and the head, we can 04:30:08.160 |
skip a number of sequence lengths based on this combined index, which is given by the program index 04:30:13.760 |
number one which is the index batch head that we have here and this is why we skip here a sequence 04:30:20.240 |
length number multiplied by the index batch head this m is pointing to the first element of the 04:30:29.600 |
entire tensor so we are skipping the heads and the batch based on the combined index index batch head 04:30:37.920 |
that this particular program is working with. And then we have offs_q: offs_q is there because each 04:30:43.920 |
of these kernels, the attention forward method, will work with one query block; 04:30:52.240 |
each query block has some indices for the exact queries it includes, and this is given by the offs_q 04:31:01.600 |
variable that you can see here, which is how many queries we need to skip (because they 04:31:06.560 |
will be processed by other programs) plus the range of queries that this 04:31:12.800 |
particular block of queries has. So imagine this particular program is working with the queries 04:31:20.080 |
that go from, I don't know, 12 to 16: then this will be 12, 13, 14, 15, so the normalization factor 04:31:29.280 |
and the maximum value for each row — we only have them for these indices of queries, 04:31:36.960 |
so 12, 13, 14 and 15, and that's why we need to also skip the number of queries that this particular 04:31:42.880 |
program works with, which is already included in this offs_q variable. 04:31:50.480 |
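The epilogue described in this section therefore looks roughly like this — a sketch; M is the (BATCH_SIZE, NUM_HEADS, SEQ_LEN) tensor for the logsumexp values, and the names are assumptions.

```python
O_block = O_block / l_i[:, None]       # apply the normalization factor once, at the very end
m_i += tl.math.log(l_i)                # store m_i + log(l_i) instead of two separate statistics
m_ptrs = M + index_batch_head * SEQ_LEN + offs_q
tl.store(m_ptrs, m_i)
tl.store(O_block_ptr, O_block.to(tl.float16))  # cast back to the output dtype (assumed fp16 here)
```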
So now we can store the mi, because we have the pointer to where it should be saved, and we can also 04:31:56.400 |
store the output, which was computed by our inner for loop. And this, guys, is the forward step 04:32:03.520 |
of Flash Attention. Now we should move forward, which means we should compute the backward 04:32:11.600 |
pass. We also have all the ingredients for computing the backward pass, because we have 04:32:15.520 |
already seen this trick, which is the logsumexp trick, so we already know how to use it 04:32:22.000 |
to compute the query-key block during the backward pass on the fly. What do we miss to understand the 04:32:27.840 |
backward pass? Well, first of all we need to understand what the backward pass is and why we 04:32:32.160 |
even need a backward pass; we need to understand what the autograd of PyTorch is and how it works; 04:32:37.680 |
what a gradient is and how to compute it; what a Jacobian is and whether, 04:32:43.200 |
when computing the gradient in the backward pass, we even need to compute it. And we need to 04:32:47.520 |
derive all the formulas of the backward pass by hand so if you are in for the challenge let's 04:32:52.240 |
continue all right so now before looking at the flash attentions backward pass at the algorithm 04:32:59.920 |
we need to understand why we even need a backward pass and to understand why we even need a backward 04:33:04.320 |
pass so before looking at the autograd of pytorch we should be looking at what is what are derivatives 04:33:10.480 |
what are gradients what are jacobians so that when we talk about derivatives gradients and jacobians 04:33:14.480 |
we don't feel lost so i will do a very fast let's say rehearsal of what these topics are 04:33:22.240 |
now what is the derivative when you have a function that takes as input a real value and 04:33:27.840 |
outputs a real value we talk about derivatives which is defined as follows the derivative of 04:33:33.680 |
the function with respect to its variable x is defined as the limit for a step size that 04:33:39.520 |
goes to zero of the function evaluated at x plus h so x plus the step size minus f evaluated at the x 04:33:48.240 |
at x, divided by the step size. So intuitively we are saying it is the ratio of how much the output 04:33:56.160 |
changes for a small change of the input of the function. This also gives 04:34:02.400 |
you the intuition of why the derivative also tells you the 04:34:11.600 |
inclination of the tangent line to the function at the point in which it's evaluated. 04:34:16.880 |
i will use also the following notation to denote the derivative so the derivative i am used to 04:34:24.960 |
write it as like this so f prime of x but it's also possible to write it as a d of f of x with 04:34:30.720 |
respect to the x or d of y where y is the output of the function with respect to x 04:34:35.600 |
and they are all equal to the same thing which is the definition above if we invert this formula 04:34:41.760 |
here and we take h to the left side, we can also write it as follows: if we want to evaluate the 04:34:49.200 |
function at the position x plus h, we can evaluate it as f prime of x — the derivative 04:35:00.000 |
of the function at the point x — multiplied by h, which is the step size, plus f of x. This is 04:35:07.280 |
actually also how we derive the Euler rule for computing the differential equations but that's 04:35:12.960 |
not the topic of today so this h we can also call it delta x so f of x plus delta x is more or less 04:35:21.200 |
because here we have a limit that says when this only happens when h is very very very small so 04:35:25.600 |
that's why we put this more or less approximately so f of x plus delta x is more or less equal to 04:35:33.280 |
f prime of x multiplied by delta x plus f of x this you can also read it as follows that if by 04:35:41.440 |
inverting this formula if x changes by a little amount and this little amount is delta x how much 04:35:52.320 |
y will change? y will change by this exact amount which is the derivative of y with respect to x so 04:36:00.240 |
dy with respect to dx multiplied by how much x has changed so this dy dx tells us how much 04:36:08.320 |
y will change with a small change of x if we multiply with the actual change of x it will 04:36:14.560 |
tell us exactly how y will be affected. I don't want to stay too long on this, but I would 04:36:23.520 |
like to use this intuition to introduce the chain rule because imagine we have a function of a 04:36:31.600 |
function so imagine we have z is equal to f of g of x we can think of x being mapped into a variable 04:36:39.600 |
y through the function g and then y being mapped into a variable z through the function f if x 04:36:46.880 |
changes by a little bit and by a little bit i mean delta x how much y will change? well y will change 04:36:53.200 |
by delta y what is delta y? delta y is the derivative of y with respect to x multiplied 04:36:58.400 |
by the step size of x but if y changes it will also affect z because there is a direct mapping 04:37:05.920 |
between y and z so how much z will change for a small change in y? let's see so if y changes 04:37:13.360 |
from the old y by a small step delta y then z will also change by some delta z and this delta z 04:37:21.520 |
is dz on dy multiplied by delta y. If we replace this delta y with the delta y that we 04:37:29.840 |
have computed in the expression above, we arrive at the chain rule: it tells us how z will be 04:37:36.800 |
affected — so this is delta z, the effect on z for a small change of x — and it's the product 04:37:45.600 |
of the two derivatives: the one of y with respect to x and the one of z with respect to y. 04:37:52.240 |
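In symbols, what has been said so far is:

```latex
\[
f'(x) = \lim_{h \to 0} \frac{f(x+h) - f(x)}{h},
\qquad
f(x + \Delta x) \approx f(x) + f'(x)\,\Delta x,
\]
\[
z = f(g(x)):\quad
\Delta z \approx \frac{dz}{dy}\,\Delta y \approx \frac{dz}{dy}\,\frac{dy}{dx}\,\Delta x,
\qquad
\frac{dz}{dx} = \frac{dz}{dy}\,\frac{dy}{dx}.
\]
```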
and this is the chain rule that we study in high school so it is if you want to compute dz on dx 04:37:59.440 |
it is dz on dy multiplied by dy dx which is very intuitive if you think about the following 04:38:07.360 |
example so you can think of z as the price of cars and x as the price of the oil how much will 04:38:16.880 |
a small change in the price of oil affect the price of a car? well this small change in the 04:38:23.200 |
price of the oil will affect for example a variable y which could be the price of electricity 04:38:30.000 |
and how much the price of electricity will affect the price of a car is given by the 04:38:37.360 |
derivative of the price of the car with respect 04:38:42.080 |
to the price of electricity. So to get the effect of the price of oil on the price of the car we just 04:38:49.280 |
multiply the two effects and this is the intuition behind the chain rule anyway let's talk about 04:38:55.200 |
gradients so when we have a function that as input takes a vector and produces a scalar we talk not 04:39:02.320 |
anymore about derivatives we talk about gradients so imagine we have a function that takes as input 04:39:09.280 |
a vector made up of two dimensions but n dimension in general and it produces a scalar when do we 04:39:15.680 |
have to deal with this kind of function for example loss functions loss functions are something that 04:39:21.200 |
are always a scalar as output and as input they take tensors so for example imagine the cross 04:39:27.760 |
entropy loss it will take a sequence of tokens each tokens with its own logics and it will compute 04:39:34.800 |
one single number which is the loss so how to view the effect on the output with respect to the input 04:39:44.160 |
in this case well if x changes by a little amount and this little amount is not anymore a number but 04:39:50.400 |
it's a vector: so if x changes, the old x plus delta x is a vector sum, and then y will also be 04:39:50.400 |
affected. By what? y will be affected by dy on dx multiplied by delta x; however, this delta x is not 04:40:00.400 |
a number anymore it's a vector because x1 may change by a little bit x2 will change by a little 04:40:14.880 |
bit x3 will change by a little bit x4 until xn will change by a little bit so this is actually 04:40:20.880 |
a dot product of this vector multiplied by this vector why a dot product because y will be affected 04:40:28.880 |
by the change in x1, it will be affected by the change in x2, it will be affected by the 04:40:34.880 |
change in x3, up to xn, and the contribution of x1 will be the partial 04:40:42.080 |
derivative of y with respect to x1 multiplied by how much x1 has changed plus the contribution of 04:40:49.520 |
x2 will be the partial derivative of y with respect to x2 multiplied by how much x2 has 04:40:55.280 |
changed blah blah blah until the last contribution of xn so and the chain rule in this case also 04:41:02.320 |
applies in the same way as in the scalar case so the formula does not change also for the chain 04:41:06.960 |
rule. Here I just want to remind you that in this case we are talking about a 04:41:13.280 |
gradient, and the gradient is just a vector made up of all the partial derivatives of the output 04:41:20.160 |
with respect to each of the input variables that are in the input vector. 04:41:27.760 |
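In symbols:

```latex
\[
\nabla_x\, y = \left[\frac{\partial y}{\partial x_1},\;\dots,\;\frac{\partial y}{\partial x_n}\right],
\qquad
\Delta y \approx \frac{\partial y}{\partial x_1}\,\Delta x_1 + \dots + \frac{\partial y}{\partial x_n}\,\Delta x_n
= \nabla_x\, y \cdot \Delta x .
\]
```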
When we talk about a function that takes a vector as input and produces a vector, then we don't talk about gradients anymore, we talk 04:41:34.000 |
about Jacobians. If the input x of this function changes by a little amount, and this 04:41:41.840 |
delta x is a vector, then the output y will also change, and this output y will change by a delta y 04:41:49.520 |
that is not a number anymore, it is a vector, and this vector is the result of the quantity dy on 04:41:56.400 |
dx multiplied by delta x. Delta x is a vector, so this one has to be a vector, and this one 04:42:04.560 |
here has to be a matrix; this matrix is called the Jacobian. It is a matrix that has 04:42:11.120 |
(we will talk about the notation later) as many rows as there are output variables 04:42:16.880 |
and as many columns as there are input variables. The first row will be the partial derivative of 04:42:24.080 |
the first output variable with respect to all the input variables the second row will be the 04:42:29.440 |
partial derivative of the second output variable with respect to all the input variables and the 04:42:35.200 |
last row will be the partial derivatives of the last output variable with respect to all the input 04:42:41.200 |
variable in the input vector now let's talk about the notations the jacobian that i have written 04:42:48.640 |
here is a is written according to the numerator layout this is called the numerator layout 04:42:56.240 |
and there is another convention called the not layout sorry guys it's called the numerator 04:43:02.000 |
convention and there is another convention called denominator convention or notation 04:43:06.480 |
in which the rows are not the the number of rows is not the equivalent to the number of 04:43:15.440 |
output variables but equal to the number of input variables so the fact that i have we we choose to 04:43:22.640 |
write the jacobian as follows is based on a convention you can also write the the jacobian 04:43:29.840 |
according to the denominator convention just by transposing this jacobian here and also the formula 04:43:35.200 |
for the chain rule changes accordingly for now i want to keep the formula for the chain rule just 04:43:40.480 |
like the one for the scalar case so that's why i am using this notation here but later we can change 04:43:46.000 |
between one notation to the other just by doing a transposition okay now that we have reviewed what 04:43:52.000 |
is derivative what is a gradient and what is a jacobian let's talk about what happens when we 04:43:59.360 |
take derivatives with respect to tensors of a tensor with respect to another tensor in this case 04:44:04.400 |
we talk about the jacobian but it's called the generalized jacobian so if we have the function 04:44:10.400 |
that as input takes a tensor of dx dimensions — this is kind of the shape 04:44:19.120 |
of the tensor so the first element of the shape is n1 the second element of the shape of the input 04:44:24.320 |
vector is n2 etc etc until n dx and it produces an output tensor that has this shape so m1 m2 04:44:34.000 |
mdy in this case the formula for the chain rule doesn't change and if x changes by a little amount 04:44:44.000 |
so by delta x which is a tensor y will also be affected by how much by dy on dx multiplied by 04:44:53.840 |
delta x and this is a tensor product it will be a jacobian this is called generalized jacobian 04:45:01.520 |
with the following shape so all the dimensions of the output multiplied by all the dimensions of 04:45:07.280 |
the input all right this is very abstract for now we will see actually a concrete case of this one 04:45:14.320 |
because we will be deriving the gradient of the output of a matrix multiplication 04:45:20.240 |
the gradient of the loss when computing backward pass with respect to each of the input of the 04:45:26.800 |
matrix multiplication operation and we will do it also for the softmax and we will do it also for 04:45:30.800 |
the attention so i don't want to jump to too many topics i just wanted us to get into the right 04:45:35.440 |
mindset so we know that derivatives when we have scalar functions gradients when the output is a 04:45:40.880 |
scalar input is a vector jacobian when the input and output are both vectors generalized jacobian 04:45:46.640 |
when the input and output are tensors the chain rule always works in the same way all right let's 04:45:54.560 |
talk about autogradient i will do the scalar case and then we will extend it to the tensor case 04:46:01.040 |
so imagine we have a very simple computation graph why we have computation graph because 04:46:05.120 |
we are talking about neural networks and neural networks are nothing more than computation graphs 04:46:09.680 |
where we have some input we have some parameters and we do some operations with this input and 04:46:13.440 |
parameters suppose that you have an input a and this input a is multiplied by a weight 04:46:19.440 |
it's a parameter weight it's just a scalar and it produces an output y1 this y1 is then summed up 04:46:26.800 |
with another number called b1 and it produces y2; this y2 is then raised to the power of 2, so this 04:46:26.800 |
is just the power of 2 of the input, and it produces y3, and this y3 becomes 04:46:33.040 |
our loss function so it's a scalar now what we want to do to apply gradient descent is we want 04:46:47.040 |
to compute the gradient of the loss function with respect to each of the input of this computation 04:46:53.920 |
graph so each of the leaves of this computation graphs what are the leaves it's this node here 04:46:58.800 |
so the parameter nodes and input nodes and to do that there are two ways one is if you have access 04:47:07.440 |
to the expression that relates directly the input to the output so the to the loss then you can 04:47:15.760 |
directly compute the gradient — the derivative in this case, because it's not a gradient, it's a 04:47:20.640 |
scalar with respect to a scalar. So in this case, imagine you want to compute the derivative of the loss with 04:47:25.760 |
respect to w1, and imagine we have access to the exact expression that relates w1 to phi, which 04:47:35.200 |
is our loss: we can compute it as follows, so we just derive this expression with respect to w1, 04:47:40.960 |
which is two times because this is the power of two of a function so it is two multiplied by the 04:47:46.880 |
function multiplied by the derivative of the content of this function with respect to the 04:47:51.840 |
variable that we are deriving so it will become the following expression there is another way 04:47:57.600 |
which is by using the chain rule: the derivative of phi with respect to w1 is the 04:48:04.800 |
derivative of phi with respect to y3, which is the output of the last node, then the 04:48:11.760 |
derivative of y3 with respect to the output of the previous node, and then 04:48:17.600 |
the derivative of y2 with respect to the output of the previous node, and then the 04:48:22.720 |
derivative of y1 with respect to w1 if we do all this chain of multiplication we will obtain the 04:48:29.040 |
same result, and you can see that this stuff here is exactly equal to this stuff here. 04:48:35.920 |
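As a small sanity check of this chain of multiplications, here is a PyTorch sketch of the same scalar graph (the numbers a = 2, w1 = 3, b1 = 1 are my own, not from the video):

# y1 = a * w1, y2 = y1 + b1, phi = y2 ** 2 (the loss)
import torch

a  = torch.tensor(2.0)
w1 = torch.tensor(3.0, requires_grad=True)
b1 = torch.tensor(1.0, requires_grad=True)

y1 = a * w1
y2 = y1 + b1
phi = y2 ** 2
phi.backward()

# chain rule by hand: dphi/dw1 = 2 * y2 * a, dphi/db1 = 2 * y2
print(w1.grad, 2 * y2.detach() * a)   # both are 28.0
print(b1.grad, 2 * y2.detach())       # both are 14.0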
By doing this procedure, we will notice something (let me zoom out a little bit). Okay, to compute 04:48:44.240 |
the derivative of phi with respect to w1 we are doing all this chain of multiplication but what 04:48:53.120 |
is each factor in this sequence of multiplications? Well, this stuff 04:49:01.120 |
here is nothing more than the derivative of phi with respect to y2, these multiplications here are 04:49:08.080 |
nothing more than the derivative of phi with respect to y1, and all of them 04:49:14.960 |
combined are the derivative of phi with respect to w1 what pytorch will do it will do the following 04:49:22.560 |
pytorch will do the backward pass because pytorch knows what is the computation graph that relates 04:49:30.000 |
the output so the loss function in this case and the variable for which we want to compute the 04:49:36.080 |
gradient right now we are talking about derivatives so it's not gradient but the mechanism is exactly 04:49:42.080 |
the same. So pytorch is like a person that knocks on the door of this 04:49:52.400 |
operation and says: hey, power-of-two operation, if i give you the gradient of the 04:50:03.360 |
loss with respect to y3 which is one because the loss and y3 are actually the same can you give me 04:50:09.840 |
the gradient of the loss with respect to y2 because pytorch actually does not implement 04:50:15.440 |
an autograd system in the sense that it does not know the symbolic operations that led to the 04:50:21.280 |
output it just knows what are the functions that computed the output and each function has a 04:50:27.760 |
function each function is a class in python that implements two methods one is the forward step 04:50:33.760 |
and one is the backward step the forward step takes the input so in this case y2 and computes 04:50:38.560 |
the output y3 the backward step will take the gradient of the loss with respect to its output 04:50:45.920 |
and needs to compute the gradient of the loss with respect to its input. 04:50:51.680 |
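Here is a minimal sketch of such a class (a made-up Square operation, not the Triton attention kernel of this video), just to show the forward/backward contract that PyTorch relies on:

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y2):
        ctx.save_for_backward(y2)   # keep what backward will need
        return y2 ** 2              # y3 = y2^2

    @staticmethod
    def backward(ctx, grad_y3):     # grad_y3 = dphi/dy3, given by PyTorch
        (y2,) = ctx.saved_tensors
        return grad_y3 * 2 * y2     # dphi/dy2 = dphi/dy3 * dy3/dy2

y2 = torch.tensor(3.0, requires_grad=True)
y3 = Square.apply(y2)
y3.backward()
print(y2.grad)  # 6.0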
How can the operator do that? Well, it's very simple. Let me copy this expression and paste it here, 04:50:58.560 |
otherwise it's not easy to go back and forth. Okay, so pytorch will knock 04:51:06.320 |
on the door of this function here and will say: hey, if i give you the gradient of the loss 04:51:14.480 |
function with respect to your output can you give me the gradient of the loss function with respect 04:51:19.200 |
to your input yes the function can do it why because of the chain rule this operator here 04:51:25.120 |
this function here can just do take the loss the gradient of the loss function with respect to its 04:51:30.480 |
output multiply it by the jacobian or in this case the derivative of its output with respect to its 04:51:38.640 |
input and it will be equal to the gradient of the loss with respect to its input then pytorch will 04:51:44.080 |
take this one and knock the door at the next operator which is this one this summation 04:51:48.400 |
and we'll say hey if i give you the gradient of the loss with respect to your output can you give 04:51:54.720 |
me the gradient of the loss with respect to your input yes this operator can do it because this 04:51:59.920 |
operator just needs to apply the chain rule so it will take the gradient of the loss with respect 04:52:03.920 |
to y2, which is provided by pytorch, and by multiplying it with the jacobian (in this 04:52:11.600 |
case it's just the derivative of its output with respect to its input) it can compute 04:52:16.480 |
the gradient of the loss with respect to its input. Then pytorch will take the output of this 04:52:23.360 |
backward pass and will knock the door of the next operator which is this product and will ask again 04:52:28.800 |
the same question hey if i give you the gradient of the loss with respect to your output can you 04:52:34.080 |
give me the gradient of the loss with respect to your input yes this will do the same exact job it 04:52:39.200 |
will take the gradient of the loss with respect to the output multiplied by the jacobian of the 04:52:44.400 |
output with respect to the input and obtain the gradient of the loss with respect to the 04:52:48.640 |
input and this is how pytorch runs the backward step it runs one operator at a time backwards 04:52:56.800 |
in the computation graph knocking the door of each operator and asking always the same question if i 04:53:03.040 |
give you the output the gradient of the loss with respect to your output can you give me the gradient 04:53:07.680 |
of the loss with respect to your input and each operator will just apply the chain rule to 04:53:12.240 |
to to get this to get this gradient to calculate this gradient that pytorch needs 04:53:17.520 |
why pytorch cannot do it by itself because pytorch does not do symbolic mathematics it does not have 04:53:25.200 |
access to the exact expression that each function is computing it just uses the function as a black 04:53:30.800 |
box that computes forward and backward however with the jacobian we have a problem and let's 04:53:37.280 |
see what is the problem all right so up to now we have been working with a computation graph that is 04:53:43.680 |
made up of scalars but the things that we have said they work in the scalar case but also in 04:53:48.640 |
the tensor case so let's go back see what is our computation graph we have seen that pytorch will 04:53:55.600 |
go operator by operator asking always the same question if i give you the gradient of the loss 04:54:00.400 |
with respect to your output can you compute me the gradient of the loss with respect to your input 04:54:05.120 |
and each operator can just apply the chain rule to compute that imagine now that all of these 04:54:11.840 |
operators are working not with scalars but are working with tensors which means that the derivative 04:54:17.600 |
of the output with respect to the input of each operator is not a derivative it will be a jacobian 04:54:23.600 |
because the output will be a tensor a generalized jacobian and input will be a tensor which means 04:54:30.000 |
also that this quantity here so the derivative of the loss with respect to the input in this case 04:54:35.120 |
will not be a derivative it will be a gradient because the output the loss is a number always 04:54:39.920 |
while the input, in this case y1, will be a tensor; so when the output is a number and the input is a tensor, we talk 04:54:49.120 |
about gradients so this will be a gradient and we will call it the downstream gradient that the 04:54:54.880 |
operator needs to compute this will be the upstream gradient that pytorch will give to the 04:54:59.920 |
each of these operators so the gradient of the loss with respect to the output of each operator 04:55:05.520 |
and each operator needs to come up with this downstream gradient by using the jacobian 04:55:11.760 |
however the jacobian has a problem let's see so imagine we are implementing a simple operation 04:55:18.640 |
that is the matrix multiplication and the matrix multiplication is takes as input a x tensor 04:55:26.000 |
it multiplies it by a w matrix made up of parameters and produces a y matrix as output 04:55:31.920 |
suppose that x is let's call it n by d matrix w is let's say d by m matrix and so y will be a 04:55:45.920 |
n by m matrix. Usually the input x is a sequence of, let's say, vectors, each with 04:55:58.880 |
d dimensions so you can think of it as a sequence of tokens each token is a vector made up of d 04:56:04.880 |
dimensions usually we have many tokens so suppose that n usually is at least 1024 at least in the 04:56:13.440 |
most recent language models we even have millions of tokens actually so and d is also actually quite 04:56:20.880 |
big: it is usually at least 1024, so let's say this one is also 1024, and m is also at least 04:56:31.600 |
1024, so let's make it 2048 (i like powers of two, by the way). So the problem 04:56:39.920 |
of the jacobian is this if we compute want to compute this downstream gradient by multiplying 04:56:45.040 |
the upstream gradient with the jacobian this jacobian matrix is huge because look at the 04:56:51.520 |
dimensions here: it will be a 04:57:02.400 |
generalized jacobian, so it will be a tensor that has shape n by m (the output shape) followed by the shape of the input x, which is 04:57:10.800 |
n by d. So how many elements will it have? Well, it will have n, which is 1024, multiplied by m, which is 2048, 04:57:20.320 |
multiplied by n, which is 1024, multiplied by d, which is 1024, so it has 04:57:31.200 |
trillions of elements. So it is actually impossible to materialize this matrix here in the memory, 04:57:39.440 |
in the ram of the gpu, because it will be too big. But we need to compute this downstream 04:57:46.640 |
gradient because pytorch needs it to continue calculating the gradient of the loss function 04:57:51.760 |
with respect to each of the nodes in the computation graph so how can we proceed 04:57:56.640 |
the first thing that we should notice is that this jacobian is actually a sparse matrix, 04:58:04.080 |
and i want to show you why it is actually a super, super sparse matrix. Because if you 04:58:10.240 |
look at the input what is the effect of the input on the output the input is a sequence of tokens 04:58:17.120 |
so this is the token number one it's a vector of some dimensions 1024 dimension then we have 04:58:24.320 |
another token as input then we have another tokens as input then we have another tokens 04:58:29.280 |
as input and we multiply by the w matrix which is made up of some columns some columns so this 04:58:37.680 |
one is n by d right yes and w is d by m so d by m this will produce a matrix that is n by m 04:58:50.080 |
so it will be also a sequence of tokens each made up of m dimensions so it will be a matrix like 04:58:56.480 |
this so this will be the first output token this will be the second output token this will be the 04:59:02.560 |
third output token and this will be the fourth output token now this output row here is the dot 04:59:11.200 |
product of this input row with all the columns so the derivative of each of these dimensions 04:59:19.280 |
with respect to the dimensions of all the other tokens will be zero because they do not contribute 04:59:24.320 |
to this output. So the jacobian will have zeros every time we are calculating the derivative 04:59:32.960 |
of this first dimension with respect to any element of the other tokens. That's why we can always 04:59:40.960 |
come up with a better formula for computing this downstream gradient that does not involve the 04:59:46.000 |
materialization of the jacobian, because the jacobian itself is sparse. So let's see how we 04:59:51.920 |
can optimize this computation without materializing the jacobian in the case of matrix multiplication, because we need it for flash attention. 04:59:57.280 |
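Here is a quick numerical sketch of this sparsity (the sizes are tiny and my own, just for illustration): the Jacobian of Y = X W with respect to X is mostly zeros, because row i of Y depends only on row i of X.

import torch

N, D, M = 4, 3, 5
X = torch.randn(N, D)
W = torch.randn(D, M)

# generalized Jacobian of Y = X @ W with respect to X, shape (N, M, N, D)
J = torch.autograd.functional.jacobian(lambda x: x @ W, X)
print(J.shape)                  # torch.Size([4, 5, 4, 3])
print((J == 0).float().mean())  # about 0.75: entries with mismatched row indices are zero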
All right guys, so before proceeding to 05:00:04.800 |
the formulas of the backward path of the flash attention let's look at how to compute the gradient 05:00:10.640 |
of the matrix multiplication operation with respect to its inputs. 05:00:15.920 |
Okay, pytorch actually already knows how to compute 05:00:23.040 |
the gradient of the loss with respect to the inputs of the matrix 05:00:26.720 |
multiplication operation, but in flash attention we are creating a custom kernel, which means that 05:00:32.480 |
the custom kernel is fusing multiple operations into one operation so when pytorch will knock 05:00:39.360 |
the door of our operator it will ask the our operator which is the triton attention operator 05:00:45.280 |
that we have built what is the gradient of the loss function with respect to q k and v because 05:00:50.080 |
that's the input of our function so if we look at the code that we have built so far you can see 05:00:55.280 |
that our triton attention will be a node in the computation graph that takes as input q k and v 05:01:02.560 |
and produces an output then pytorch will give us the gradient of the loss with respect to that 05:01:09.040 |
output, so it will give us dO, that is, the gradient of the loss with 05:01:14.160 |
respect to o and then we'll ask this class here so triton attention to compute the gradient 05:01:21.440 |
of the loss with respect to q, k and v. Because we are fusing multiple operations together, we are 05:01:26.720 |
computing on the fly the softmax of query multiply by the transpose of the key and then multiplying 05:01:32.000 |
doing the softmax and multiplying it by v to compute the output, we need to compute these 05:01:38.640 |
gradients internally to get the gradient of the inputs. So, because in the operations that 05:01:45.280 |
we are fusing together there is a matrix multiplication, we need to derive by hand 05:01:50.240 |
the gradient of the loss function with respect to the inputs of the 05:01:56.480 |
matrix multiplication operation so that we can provide it to pytorch that's why we need to derive 05:02:03.760 |
this formula i will derive it in the simple in a very simple way and and then we will do it for 05:02:12.000 |
the softmax as well because these are the two things that we need to derive by hand to derive 05:02:15.840 |
the formula of the flash attention's backward pass so let's start imagine we have a computation 05:02:23.680 |
graph a node in the computation graph called the matrix multiplication and this node in the 05:02:29.200 |
computation graph is doing a matrix multiplication so it is computing the following operation 05:02:33.360 |
y is equal to x multiplied by w now what pytorch will give us as input when computing the backward 05:02:43.600 |
pass of this node pytorch will give us the gradient of the loss so it will give us d phi 05:02:49.840 |
with respect to dy so the output of this node and will ask us to compute the gradient of the 05:02:58.400 |
loss function so the gradient of the loss function with respect to dx and the gradient of the loss 05:03:04.400 |
function with respect to dw the easiest one to work with and the one that i will be showing and 05:03:10.560 |
the other one i will not show in the video but i will attach the pdf slide on how it is computed 05:03:14.960 |
because they are very similar in the way they are computed so i don't want to make the video 05:03:19.440 |
too long for unnecessary reasons let's compute the gradient of the loss function with respect 05:03:28.240 |
to the input so with respect to x all right so how to do that by hand without materializing the 05:03:37.120 |
jacobian because as we have seen we cannot just use the chain rule by materializing the jacobian 05:03:42.160 |
which would be the easiest way because the jacobian is very big matrix that cannot even 05:03:46.960 |
fit in the memory of the gpu so we need to find a smarter way we exploit the fact that the jacobian 05:03:53.120 |
is sparse so hopefully we will get formula that does not involve the materialization of a very 05:03:58.560 |
big sparse jacobian. Let's see. When dealing with these kinds of derivations 05:04:07.360 |
i always recommend to make some example tensors so suppose that that x is a tensor of size let's say 05:04:15.920 |
n by d and where n let's say n is equal to one and d is equal to let's say three and 05:04:25.200 |
w is a tensor also or a matrix with the shape let's say d by m 05:04:36.800 |
where m is equal to let's say four and y will have as a consequence the shape n by m 05:04:47.440 |
so it will have the shape well one by four what pytorch will give us and pytorch will give us 05:04:58.320 |
the following quantity so it will give us this stuff here so the gradient of the loss function 05:05:03.840 |
with respect to the output of this operator which is y so it will give us a vector or a tensor 05:05:10.240 |
actually with the following dimension which is n by m and we need to compute the gradient of the 05:05:19.280 |
loss function with respect to x which should be a tensor of shape n by d because when dealing 05:05:26.240 |
with the gradient it always has the shape of the input variable because it's the output which is a 05:05:33.200 |
scalar with respect to each element in the input so it has the same shape as the denominator 05:05:38.160 |
all right so when dealing with this kind of problems i always recommend to create example 05:05:44.160 |
matrices and then work out what happens to the output and then try to work out the 05:05:49.280 |
the gradient matrix so let's do it so let's see that what is how is the output computed well 05:05:57.440 |
the output will be a matrix that is a one by four computed as follows it will be the input 05:06:05.920 |
so one by three so let's call the input x one one x one two x one three it will be multiplied 05:06:16.160 |
by another matrix w that it has dimension three by four so it will be three rows by four columns 05:06:25.280 |
so it will be w 1 1 w 1 2 w 1 3 w 1 4 then w 2 1 w 2 2 w 2 3 w 2 4 w 3 1 w 3 2 w 3 3 w 3 4 05:06:48.560 |
if we do this matrix multiplication it will be well it will produce the following matrix 05:06:55.040 |
that is okay this is one row by three columns this is three column three rows by four columns 05:07:01.040 |
so the output will be a matrix that is one by four so one row by four columns so it will be 05:07:08.560 |
let me write it with a smaller because otherwise it will never fit here so 05:07:15.520 |
let's do it like this it will be x 1 1 multiplied by w 1 1 plus x 1 2 multiplied by w 2 1 plus x 05:07:29.280 |
1 3 multiplied by w 3 1 and this will be the first element of the output the second element 05:07:37.120 |
of the output will be x 1 1 with w 1 2 plus x 1 2 with w 2 2 05:07:53.360 |
plus x 1 3 with w 3 2 this will be the second element of the output matrix the third element 05:08:04.800 |
of the output matrix will be let me move this stuff on the left otherwise it will never fit 05:08:08.720 |
so okay i think now it can fit. This will be 05:08:20.720 |
x 1 1 with w 1 3 plus x 1 2 with w 2 3 plus x 1 3 with w 3 3 and then we multiply the same 05:08:35.440 |
row with the last column so it will be x 1 1 w 1 4 plus x 1 2 w 2 4 plus x 1 3 w 3 4 05:08:49.760 |
this will be the output y if we do the matrix multiplication what pytorch will give us 05:08:56.880 |
it will give us the gradient of the loss so it will give us delta phi with respect to delta y 05:09:04.000 |
because it's a gradient it has the same shape as the denominator so it has a shape that is 1 by 4 05:09:10.640 |
let's call it because we don't know what this value will be they will be provided to us by 05:09:15.600 |
pytorch let's just give them generic name like d y 1 1 d y 1 2 d y 1 3 and d y 1 4 like this 05:09:28.880 |
now, to compute the downstream gradient that we need to provide to pytorch, we should 05:09:37.360 |
be materializing the jacobian. Okay, let's write the 05:09:46.560 |
chain rule formula: we need to provide delta phi with respect to delta x, which is 05:09:54.880 |
equal to delta phi with respect to delta y this is provided by pytorch multiplied by the jacobian 05:10:02.800 |
which is delta y with respect to delta x. Now, before looking for a way to avoid materializing this jacobian, let's 05:10:10.320 |
materialize it for this small example and do the multiplication of these two quantities 05:10:16.640 |
to see if something simplifies so this stuff here will be dy with respect to dx which means the 05:10:24.400 |
derivative of every output y with respect to every input x how many output we have we have 05:10:32.080 |
four elements as the output which is this stuff here and we have three element as input in the x 05:10:39.280 |
matrix. So it will be as follows; let me copy it because my screen is 05:10:47.600 |
not big enough, and remember that x is x 1 1, x 1 2, x 1 3. So delta y with respect to delta x 05:10:58.400 |
will have the following entries so the y1 with respect to x 1 1 and as you can see y1 only has 05:11:09.760 |
one x 1 1 appearing as multiplied by w 1 1 so the derivative with respect to x 1 1 will be w 1 1 05:11:16.560 |
then y 1 1 so this stuff with respect to x 1 2 it will be w 2 1 05:11:27.360 |
then y 1 1 with respect to x 1 3 will be w 3 1. The second row of this matrix will be the 05:11:37.920 |
partial derivative of the second output, so y 1 2, with respect to all the x 05:11:45.680 |
inputs which will be the derivative partial derivatives of this stuff here with respect to 05:11:51.520 |
every x which is w 1 2 w 2 2 i guess and w 3 2 now let me check if it's what i'm doing is correct 05:12:05.200 |
yes because i've already done it so i can always double check and then we have w the partial 05:12:13.840 |
derivatives of this stuff here with respect to all the x which is w 1 3 w 2 3 and w 3 3 05:12:25.280 |
then the partial derivatives of the last output so y 4 with respect to all the x which will be 05:12:33.600 |
w 1 4, w 2 4 and w 3 4. We obtain the following jacobian, but this jacobian, as you can see, 05:12:50.160 |
is just equal to w transposed so we don't need to materialize the jacobian we can just do the 05:12:57.520 |
multiplication of whatever gradient pytorch is giving us multiply it by w transposed and 05:13:05.120 |
we will get the downstream gradient so let me rewrite so we know have we know what we are 05:13:10.720 |
doing: so d phi on dx is equal to d phi with respect to dy multiplied by dy on dx, but we have 05:13:23.760 |
seen that dy on dx is just equal to w transposed, so this is equal to d phi on dy multiplied by 05:13:33.840 |
w transposed and this gives us the downstream gradient so in order to provide the downstream 05:13:38.960 |
gradient that pytorch need we just need to take whatever gradient pytorch will give us multiplied 05:13:43.280 |
by w transposed and it will give us the gradient of the loss function with respect to the input 05:13:48.880 |
x of the matrix multiplication in the same way we can also write the formula for the gradient 05:13:55.520 |
of the loss function with respect to w, and it is equal to x transposed multiplied by d phi 05:14:03.360 |
with respect to dy. How to remember these formulas? There is a mnemonic rule, 05:14:14.720 |
which is these are the only possible ways for this to have the shape of x and this to have the 05:14:24.480 |
shape of w because this one's this stuff here will have the same shape of y so it will be 05:14:30.880 |
n by m this stuff here will have shape of w transposed w is d by m so w transpose should be 05:14:42.400 |
m by d and the resulting operation of this matrix multiplication or tensor 05:14:48.960 |
multiplication will be n by d which is exactly the same shape as x 05:14:56.320 |
in this case we will have that x transposed is the transpose of x; x is n by d, so x transposed is d by n, 05:15:09.360 |
multiplied by d phi with respect to dy which is a gradient so it has the same shape as the 05:15:15.040 |
denominator so it has n by m and the output will have shape d by m which is exactly the 05:15:29.520 |
the shape of w. So, to remember them: this is the only way the shapes work out, 05:15:37.280 |
otherwise they don't. So this is a mnemonic rule on how to remember how to 05:15:40.880 |
compute the gradient of the inputs of a matrix multiplication given the gradient of the loss 05:15:46.160 |
with respect to the output of the matrix multiplication and the inputs to the 05:15:49.600 |
matrix multiplication are the input matrix x and the parameter matrix w. 05:15:55.600 |
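These two formulas are easy to check numerically. Here is a small sketch (my own verification, with made-up shapes) comparing them against PyTorch autograd:

# dphi/dX = dphi/dY @ W.T    and    dphi/dW = X.T @ dphi/dY
import torch

N, D, M = 8, 16, 32
X = torch.randn(N, D, requires_grad=True)
W = torch.randn(D, M, requires_grad=True)

Y = X @ W
dY = torch.randn(N, M)   # pretend this is the upstream gradient given by PyTorch
Y.backward(dY)

print(torch.allclose(X.grad, dY @ W.T, atol=1e-5))  # True
print(torch.allclose(W.grad, X.T @ dY, atol=1e-5))  # True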
Now we need to derive the gradient of the output of the softmax with respect to the input of the softmax, because that's another 05:16:01.360 |
operation that we do in our fused attention because we are fusing many operations together 05:16:05.280 |
which is matrix multiplication and the softmax so this is the second ingredient that we need 05:16:09.680 |
to understand the backward pass of flash attention so let's do it i will use to make 05:16:15.440 |
this derivation i will use the same notation as in the flash attention paper so first of all 05:16:21.040 |
let's write the title of this stuff which is the gradient through the softmax 05:16:33.840 |
the first operation that we do in during computation of the attention is we compute 05:16:44.400 |
the product of the query multiplied by the transpose of the keys we do in a blockwise 05:16:48.320 |
way it means that we do it block by block but it doesn't matter because the end result is the same 05:16:52.800 |
so we can also we can write s equal to q multiplied by the transpose of the keys 05:16:58.240 |
and then we apply the softmax to this operation to the result of this operation and we call this 05:17:05.040 |
output p which is the softmax of s and after the uh we have applied the softmax we take the output 05:17:15.440 |
of the softmax we multiply it by v to obtain the output so the output is equal to p multiplied by v 05:17:21.520 |
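Here is a naive reference sketch of this naming, just to keep S, P and O straight (the shapes are toy values of my own, and I omit the 1/sqrt(d) scaling since it does not matter for this derivation):

import torch

N, d = 16, 8
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)

S = Q @ K.T                   # (N, N) scores
P = torch.softmax(S, dim=-1)  # row-wise softmax
O = P @ V                     # (N, d) output
print(S.shape, P.shape, O.shape)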
now we need to understand how to compute the backward pass, because, as i said before, pytorch autograd works in the 05:17:31.280 |
following way pytorch will treat our attention computation as a black box so we will have a 05:17:37.040 |
computation graph like the following we will have a query input a key input and a value input which 05:17:44.560 |
are sequences of tokens each one with some embedding dimension these are fed to some 05:17:50.640 |
black box called the attention which is our implementation of the attention which is the 05:17:56.800 |
function that we started coding before this will be fed as input to this node in the computation 05:18:03.200 |
graph and the computation graph will output a an output tensor o what pytorch will give us 05:18:10.480 |
pytorch will give us the gradient of the loss with respect to the output so as you remember 05:18:17.680 |
pytorch knocks the door knocks the door at each operator and says if i give you the gradient 05:18:23.920 |
of the loss with respect to your output can you give me the gradient of the loss with respect to 05:18:28.560 |
your inputs and this is what we need to figure out so given the gradient of the loss with respect 05:18:34.400 |
to the output we need to understand how to compute the gradient of the loss with respect to the 05:18:40.400 |
q, the gradient of the loss with respect to k, and the gradient of the loss with respect to v. However, 05:18:48.480 |
there is no direct connection between q and o or k and o because there are two intermediate 05:18:54.560 |
operations so one there is a first a matrix multiplication then there is a softmax then 05:18:59.120 |
there is an additional matrix multiplication however we have tools that allow us to understand 05:19:04.080 |
how the gradient propagates through multiple operations when they are applied in sequence 05:19:08.880 |
and that's called the chain rule however we have seen that applying the chain rule in its 05:19:14.000 |
naive way by materializing the jacobian is infeasible so we need to understand 05:19:19.120 |
how to apply the chain rule without materializing the jacobian and that's what we are going to 05:19:24.480 |
figure out for one of the operations inside of this attention computation which is the softmax 05:19:29.840 |
and that's why we are going to do this derivation which i promise is the last one that we will do 05:19:34.320 |
and then we will finally go to code the backward pass of flash attention we cannot proceed directly 05:19:40.000 |
to coding the backward pass of the flash attention because if we look at the formulas on how it is 05:19:43.680 |
computed we will not understand how the the derivation comes out okay now we can start 05:19:51.440 |
so let me delete this stuff delete and imagine for simplicity now we apply the softmax to a 05:20:01.680 |
row wise to this s matrix so each row is softmaxed independently from the others 05:20:08.640 |
so let's see what happens to one single row of this matrix and for simplicity i will call it s 05:20:17.040 |
so s is a single row of the s matrix. I could also call it s of i, but if i do it like this we will 05:20:26.160 |
have to carry over the index. Okay guys, let's just do it, we will carry over the index. All right, so 05:20:33.280 |
let's call s_i one row of the s matrix. So s_i is equal to, let's say, in tensor notation, 05:20:40.720 |
pytorch tensor notation it will be like this so from the matrix s from the tensor s we take the 05:20:47.840 |
ith row and all the columns this is the definition of si i know it's very ugly notation but it helps 05:20:53.440 |
you understand, and this is a vector of n dimensions. We apply the softmax to this 05:21:02.000 |
vector and we will obtain an output vector and we call it pi pi is equal to the softmax softmax 05:21:12.800 |
of si so as we have seen the softmax operation does not change the shape of the input it just 05:21:20.240 |
changes each number element-wise, so the output will also be a vector in R^n, so of size n. 05:21:28.240 |
now what is the softmax so the softmax is defined as follows the softmax 05:21:38.480 |
of well p i j so the jth element of the p ith vector is equal to the exponential of the jth 05:21:51.760 |
element of the s_i vector, divided by a normalization factor that is computed as follows: 05:22:01.600 |
a sum over an index (let's not use j, and not even k; let's use l) 05:22:08.160 |
that goes from one up to n, of e to the power of s_il. All right, so first of all you may be wondering: 05:22:21.840 |
the softmax that we are that we apply during the forward pass of the computation of the attention 05:22:28.080 |
is not really this softmax because in if you remember what we applied before we were applying 05:22:33.520 |
the softmax where each of the argument of the exponential is reduced by the maximum element 05:22:39.760 |
in the vector to which we apply the softmax so it was more or less like this so s i j minus s i max 05:22:48.240 |
so the maximum element in the s_i vector, and also the argument of the denominator was 05:22:55.760 |
reduced by s i max however we also proved that this stuff here is equivalent to the standard 05:23:05.680 |
softmax without this reduction in the argument because this reduction in the argument is only 05:23:11.760 |
added because we want to make it numerically safe to compute but there is it's equivalent 05:23:18.080 |
to do it without but from a mathematical point of view on the computer of course it will become 05:23:23.920 |
numerically unstable but from a mathematical point of view it is the same thing which also 05:23:29.600 |
means that doesn't doesn't matter how you compute the forward pass if it's equivalent to another 05:23:36.320 |
mathematical definition you can always use the other mathematical definition to compute the 05:23:40.160 |
backward pass it will result in the same value if you didn't understand what i said let me give you 05:23:46.080 |
a more simple example, which is: do you remember the formula from high school, 05:23:54.320 |
the one that says cosine squared of x plus sine squared of x is equal to one? Now imagine 05:24:04.480 |
we compute an output y is equal to cosine squared of x and then we need to compute the derivative 05:24:14.000 |
of y with respect to x it doesn't matter if you compute it as the derivative of cosine squared 05:24:24.880 |
of x with respect to x or if you compute it as the derivative of one minus sine squared of x 05:24:37.440 |
with respect to x because they will result in exactly the same result because the two definitions 05:24:44.320 |
are equivalent and this is why we don't need to add this this factor in the exponential 05:24:49.360 |
because the two definitions are equivalent mathematically we just use the numerically 05:24:54.560 |
safe one because, when computed on the computer, we need something that is numerically 05:25:00.080 |
stable, that will not overflow. 05:25:07.280 |
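Here is a quick numerical sketch (my own check) that the safe softmax, where we subtract the maximum, and the plain softmax give the same values and the same gradients:

import torch

x = torch.randn(8, requires_grad=True)
w = torch.randn(8)  # arbitrary weights so the loss depends on every output

plain = torch.exp(x) / torch.exp(x).sum()
safe  = torch.exp(x - x.max()) / torch.exp(x - x.max()).sum()
print(torch.allclose(plain, safe, atol=1e-6))  # True

g_plain, = torch.autograd.grad((plain * w).sum(), x, retain_graph=True)
g_safe,  = torch.autograd.grad((safe * w).sum(), x)
print(torch.allclose(g_plain, g_safe, atol=1e-6))  # True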
All right, now what do we want to obtain? We want to obtain the gradient of the loss with respect to the input vector of the softmax, 05:25:16.080 |
which is the s_i vector given the gradient of the loss with respect to the output of the softmax 05:25:24.160 |
which is the p_i vector and we can obtain that with the chain rule multiply that by the jacobian 05:25:33.200 |
p_i with respect to s_i now we the chain rule is always valid let's see what does this jacobian 05:25:47.520 |
look like all right so this jacobian will be the p_i with respect to delta s_i well we need to do it 05:26:00.800 |
let's look at what each element in this jacobian will look like so the jth element with respect to 05:26:08.240 |
the, let's say, kth element. So we are looking at what each 05:26:21.840 |
element in this jacobian will look like, which is: each element in 05:26:28.320 |
the numerator of the jacobian derived with respect to each element in the denominator of the 05:26:33.840 |
jacobian in this fraction here so we are saying for each element in the output vector derived 05:26:42.480 |
with respect to each element in the input vector this is what we are writing here so what is how 05:26:48.880 |
is the output vector obtained well p_ij we know that it is equal to by the definition of the 05:26:55.040 |
softmax is obtained as follows so e to the power of s_ij divided by the normalization factor let's 05:27:07.120 |
call it l is equal to one to n e to the power of s_il all derived with respect to s_ik 05:27:24.160 |
So what we are trying to do is: we know that the p vector, suppose it's a vector with 05:27:32.400 |
three elements, is p_11, p_12 and p_13, and the s vector will be a vector 05:27:45.440 |
also with the three elements so it will be the s_11 s_12 and s_13 what we are trying to do is 05:27:54.400 |
the calculate what the jacobian will be the derivative of this one with respect to all 05:27:58.480 |
the input vector then then the second row of the jacobian will be the derivative of this one with 05:28:03.600 |
respect to each of this input element then the third row of the jacobian will be this stuff here 05:28:09.200 |
with respect to the derived with respect to each of the input element of the s vector we are trying 05:28:14.720 |
to understand what does the generic element in this jacobian look like based on the j date element 05:28:20.800 |
of the output vector so this j index refers to the output vector and the kth element in the input 05:28:26.000 |
vector all right so what can happen when we do this jacobian is that we have a this one here is the 05:28:37.200 |
derivative of a fraction of two functions and we know from high school that the derivative of the 05:28:44.400 |
fraction of two functions is as follows: the derivative, let me write it like 05:28:52.160 |
this, of f of x divided by g of x, with respect to x, is equal to 05:29:04.240 |
f prime of x multiplied by g of x, minus g prime of x multiplied by f of x, 05:29:18.880 |
all divided by the g of x to the power of two like this now let's apply it here so this will 05:29:28.160 |
become here we will have two cases either the variable that we are deriving with respect to 05:29:33.840 |
so this sik has the same index as the variable being derived so either we are doing a p11 05:29:41.840 |
with respect to s11 or we are doing a p11 with respect to something else that has not the same 05:29:47.440 |
index so like p11 with respect to s12 or s13 so there are two cases that we need to consider 05:29:53.280 |
suppose that we are deriving p11 with respect to s11 or we are deriving p12 with respect to s12 05:29:59.360 |
or we are deriving p13 with respect to s13 so we are deriving the element of the output with 05:30:05.360 |
respect to the same element in the input with the same index so in this case the this this 05:30:14.880 |
derivative will look like the following so it's the derivative of f so the numerator 05:30:21.920 |
with respect to the denominator that has the same index so we are saying that in this case 05:30:27.360 |
j is equal to k. So the derivative of the numerator, e to the power of s_ij, 05:30:40.720 |
with respect to s_ij will be e to the power of s_ij, because the derivative of e to the power of x1 with respect to 05:30:48.960 |
x1 is e to the power of x1. So this is equal to (i am reducing the font size now) 05:30:54.480 |
e to the power of sij then we need to multiply that by the denominator of the fraction 05:31:05.360 |
which is this summation here so the summation over all possible l of e to the power of sil 05:31:14.320 |
minus the derivative of the denominator with respect to the variable being derived so this 05:31:24.000 |
denominator is the sum of all the exponentials of all the input elements if we derive it with 05:31:30.400 |
respect to one particular input element there will be at least one term that contains that 05:31:35.040 |
input element and so the all the other terms will result in zero so the only 05:31:40.400 |
derivative that will survive will be the e to the power of sik with respect to sik 05:31:57.280 |
multiplied by the numerator which is e to the power of sij 05:32:00.480 |
all this divided by the denominator to the power of two which is this summation here so l equal to 05:32:12.080 |
one up to n e to the power of sil all to the power of two and this stuff here will be equal to well 05:32:22.400 |
we can see that these two terms, this one and this one, have a factor in common, which 05:32:28.480 |
is e to the power of s_ij, so we can collect it: e to the power of s_ij multiplied by (the summation minus e to the power of s_ik), 05:32:51.040 |
all this divided by the denominator which is the power of two of this stuff here so let me just 05:32:57.440 |
copy and paste it which is let me rotate it also because i don't know why i always write little 05:33:02.000 |
little yeah all right and this stuff here is equal to well we can separate the two terms so 05:33:14.000 |
we can separate this term here and this term here because the denominator is to the power of two 05:33:20.560 |
so we can write it also as e to the power of sij divided by the denominator so which is 05:33:28.960 |
summation of l equal one to n e to the power of sil multiplied by this stuff here so this stuff 05:33:40.960 |
here divided by the same denominator so there's summation of l equal one up to n 05:33:49.520 |
e to the power of s_il, minus e to the power of s_ik divided by the same denominator, the sum over l of e to the power of 05:34:08.400 |
s_il. Now, this stuff here is nothing more than the output element p_ij, 05:34:18.480 |
because this one is just the softmax applied to the sij element which we know that the output of 05:34:25.120 |
the softmax applied to the sij element is called pij because it's one element of the output vector 05:34:30.480 |
which we call the p so this stuff here is equal to pij multiplied by this stuff here will be equal 05:34:39.440 |
to one minus this stuff here what is this stuff here is the output of the softmax applied to the 05:34:46.560 |
sik element so it will be pik so it is equal to one minus pik okay and this is in the case the 05:34:59.360 |
variable with respect to which we derive has the same index as the numerator in this fraction here 05:35:09.520 |
in this derivative here the other case is when the two variables so the output the index of the 05:35:17.680 |
output with respect to the index of the input are not the same in this case we will have another 05:35:23.680 |
case so we will have that j let me write it again so this stuff here i hope i can copy it all without 05:35:35.840 |
in the other case, in which the indices are not equal, 05:35:39.120 |
that is, j is not equal to k, what happens? 05:35:47.840 |
in this case it will be well the derivative of the numerator because we need to apply again 05:35:55.760 |
this formula here so derivative of the numerator with respect to something that is not the same 05:36:00.640 |
variable it will be zero because it's like computing the derivative e to the power of x1 05:36:06.400 |
with respect to x2 it will be zero so it will be zero so all the first term here will become zero 05:36:14.000 |
no matter what is g of x minus the derivative of the denominator of this fraction here with respect 05:36:21.680 |
to the variable sik g prime of sik so this is all the variable in the input and we are deriving it 05:36:31.920 |
with respect to one particular variable of the input so only one item in the summation will 05:36:37.360 |
survive, which is the item with index k, so it will be e to the power of s_ik, multiplied by f of x, which 05:36:50.160 |
is the numerator in this fraction, which is e to the power of s_ij (and don't forget the minus sign in front) 05:36:56.240 |
let me see if i forgot something all divided by the denominator 05:37:02.800 |
of this fraction here to the power of two so it is equal to the summation 05:37:11.040 |
l equal one up to n of e to the power of sil all to the power of two 05:37:18.160 |
i believe i didn't forget anything so let's continue so here also we can see that this 05:37:26.640 |
one here is because uh okay let's separate it minus e to the power of sik divided by 05:37:36.400 |
the summation l equal one up to n of e to the power of sil multiplied by 05:37:43.840 |
e to the power of sij divided by the summation l equal one up to n of e to the power of sil 05:37:56.160 |
this stuff here is nothing more than the softmax applied to the kth element of the si vector 05:38:02.880 |
this one here is nothing more than the softmax applied to the jth element of the si vector 05:38:08.480 |
so we know what these are we know that we call them p minus pik pij 05:38:17.920 |
so in the end we have two cases one is the derivative of this stuff here looks like the following 05:38:30.160 |
each item in the jacobian looks like the following when the numerator and the denominator have the 05:38:35.280 |
same index, so j equal to k, this stuff here is equal to (this notation is a bit sloppy, i 05:38:43.760 |
shouldn't really be writing it with the equal sign, but it doesn't matter guys, we are being a little informal) 05:38:50.560 |
p_ij multiplied by one minus p_ik. Let me check... yes. The other case is when j 05:39:02.480 |
is not equal to k; then this stuff here, let me write it like this, will be equal to 05:39:09.920 |
minus p_ik multiplied by p_ij. 05:39:21.760 |
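These two cases are easy to verify numerically. Here is a tiny sketch (my own check, with a toy 4-element vector) comparing them against the Jacobian computed by autograd:

import torch

s = torch.randn(4)
p = torch.softmax(s, dim=0)
J = torch.autograd.functional.jacobian(lambda t: torch.softmax(t, dim=0), s)

for j in range(4):
    for k in range(4):
        expected = p[j] * (1 - p[j]) if j == k else -p[k] * p[j]
        assert torch.allclose(J[j, k], expected, atol=1e-6)
print("both cases match the autograd Jacobian")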
Now that we know what the two typical cases of this jacobian look like, let's actually look at what this jacobian looks like in matrix form. So this jacobian will 05:39:28.880 |
look like the following it will be a matrix that is more or less like the following 05:39:36.320 |
it will be an n by n matrix where n is the size of the input vector and the output vector 05:39:42.080 |
at here the first element of the jacobian as you saw as you remember the first row of the jacobian 05:39:50.640 |
in the numerator convention is the derivative of the first output with respect to all the input 05:39:58.160 |
so this first term here will be the derivative of p11 with respect to s11 so in this case j and k 05:40:08.080 |
match so we know that it will be equal to p11 multiplied by 1 minus p11 the second element 05:40:16.400 |
to the right of this one so the element one two will be the derivative of p12 with respect to 05:40:23.600 |
sorry the p11 with respect to s12 the j and k do not match so we will be in this case here so it 05:40:30.240 |
will be minus p11 p12 the third element you can check it by yourself it will be minus p11 p13 05:40:42.000 |
blah blah blah until the end which will be minus p11 p1n the second row of this jacobian will be 05:40:52.000 |
will look like this so it will be the derivative of p12 with respect to s11 the j and k do not 05:40:58.880 |
match so we are in this case here so it will be minus p12 p11 then the next element it will be 05:41:09.440 |
the derivative of p12 with respect to s12 so j and k match so we are in the first case so it will be 05:41:16.880 |
p12 multiplied by 1 minus p12 then this stuff here will be equal to then the third element 05:41:25.840 |
will be minus p12 multiplied by p13, blah blah blah, until we arrive to the last one, which 05:41:33.680 |
is minus p12 multiplied by p1n, and all the elements like 05:41:41.440 |
this until the last row the last row will be the the first element of the last row will be the 05:41:46.160 |
derivative of the last output element with respect to the first input element so it will be the 05:41:53.520 |
derivative of p1n with respect to s11 so the two indices do not match so we are in the second case 05:42:02.880 |
so it will be minus p1n p11 this will be minus p1n p12 etc etc etc now let me do also the 05:42:16.000 |
third element since we are here so minus p1n p13 etc etc etc until the last element of the 05:42:24.560 |
last row which will be minus p1n p1n i guess oh oh no that's wrong guys because the two indices 05:42:35.440 |
match so it should be p1n multiplied by 1 minus p1n this is what the jacobian will look like 05:42:45.360 |
let's see if we can find a better how to generate this jacobian with some pattern recognition 05:42:51.920 |
let's write it in a different way first of all the thing first thing that we can notice is that 05:42:57.600 |
this jacobian is symmetric so you can see that this element is equal to this element if you 05:43:02.000 |
expand the third row you will see that it's equal to this element this one on the top right corner 05:43:07.040 |
is equal to the one in the bottom left corner, so this matrix is symmetric. 05:43:14.720 |
the second thing that we can notice is that only the element in the diagonal are different 05:43:20.880 |
they have an additional term because you can look at this element here so let me write this element 05:43:28.320 |
here can also be written as p11 minus p11 multiplied by p11 the second element here 05:43:37.600 |
in the second row so the second diagonal element of this matrix is p12 minus p12 multiplied by p12 05:43:48.320 |
so this element on the diagonal actually look like just like the other elements they just have 05:43:53.920 |
an additional term which is p11 in the first diagonal element p12 in the second diagonal 05:44:02.560 |
element so we can also say that this matrix here is the product of all the possible combinations of 05:44:11.200 |
p_ij with p_ik which we can obtain with an outer product or even with the product of one column 05:44:20.320 |
with the transpose of the same column so if you do one column vector for example imagine p is a column 05:44:27.840 |
vector and you do p multiplied by p transposed, you obtain all the possible combinations of products of 05:44:34.640 |
its elements. I can do a simple case: take the column vector p1, p2, p3 05:44:43.040 |
multiplied by the row vector p1, p2, p3. This will generate all the possible combinations of products 05:44:53.440 |
between the elements of the first vector and the second vector, because this will be a three by one 05:44:59.680 |
times a one by three, so it will generate a three by three matrix, and it will be equal to p1 p1, 05:45:06.960 |
p1 p2, p1 p3, etc. etc. Moreover, we can see that in the diagonal 05:45:18.480 |
of the matrix we have this additional term this additional term p1 in the first diagonal element 05:45:26.880 |
p12 in the second diagonal element, p13 in the third diagonal element. Actually, calling these p1 is 05:45:34.720 |
wrong, because i should call them p_i1, p_i2, and so on; that's why i didn't want to bring along the i index. So it's not 05:45:40.000 |
really p1j, it should be p_ij, because we are doing it for the generic ith row, the p_i vector, 05:45:49.040 |
so let me fix the indices on the board: 05:46:05.440 |
every p1 becomes p_i. Okay, so we can write 05:46:21.440 |
this jacobian here also as the diagonal matrix that in the diagonal has all the 05:46:30.240 |
element of the p_i vector minus the p vector multiplied by the transpose of itself so with 05:46:38.800 |
itself but transposed because we need all the elements to be kind of a combination of one 05:46:44.720 |
element of p with another element of p, plus, only on the diagonal, this 05:46:50.320 |
additional term, which are the elements of p; and all the elements of the output of this p 05:46:56.800 |
multiplied by p transposed are negated that's why we need this minus sign so if you look at the 05:47:01.840 |
flash attention paper they give you this formula here they say that if y is equal to the softmax 05:47:08.160 |
of x, then the jacobian will look like the following: it will be the diagonal matrix 05:47:21.360 |
of y minus y y transposed, where y is a column vector. 05:47:32.080 |
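Here is a short sketch (my own check) of this formula, comparing diag(y) - y y^T against the Jacobian computed by autograd:

import torch

x = torch.randn(5)
y = torch.softmax(x, dim=0)

J_autograd = torch.autograd.functional.jacobian(lambda t: torch.softmax(t, dim=0), x)
J_formula  = torch.diag(y) - torch.outer(y, y)
print(torch.allclose(J_autograd, J_formula, atol=1e-6))  # True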
all right guys i know this has been long so let's take a pause and we are going to now 05:47:38.320 |
code finally first of all let's check the mathematics of the backward path of flash 05:47:44.160 |
attention we will see it briefly i will not do any more derivation but i will explain it 05:47:49.920 |
and then we finally switch to coding it so let's go all right guys now finally we can see the 05:47:57.120 |
the backward path of the flash attention so we will be looking at the algorithm and if you look 05:48:03.760 |
at the the the appendix of the flash attention paper you will see this part b.2 where they derive 05:48:09.680 |
the backward path step by step now i don't want to do all the same all the steps of this derivation 05:48:17.120 |
because it's going to be too long but i want to give you all the tools necessary to understand it 05:48:21.920 |
now let's start from the conventions and 05:48:30.240 |
notations they are using in this paper so the first thing that we need to rehearse is the 05:48:36.640 |
naming of what is what is the name of each matrix as you know in the forward attention in the 05:48:44.080 |
forward pass we do the query multiply by the transpose of the key and the output of this we 05:48:48.400 |
call it s then we apply the softmax to this s matrix and it becomes the p matrix the softmax 05:48:55.840 |
is applied by rows then we talk take this p matrix and we multiply by a v matrix to obtain the output 05:49:02.560 |
of the attention. Let's look at, for example, how the ith row of the output 05:49:12.080 |
is computed based on the p matrix and the v matrix, so we can understand this kind of notation that 05:49:17.120 |
they are using here in the paper. Because the way i read this formula here is: the ith 05:49:22.320 |
row of the output which is a column vector because in when we write in in mathematics in 05:49:28.560 |
linear algebra whenever we write the name of a vector it is always by convention a column vector 05:49:35.440 |
but the origin of this particular vector is actually a row of the output matrix let's try 05:49:42.320 |
to understand what is the output row of a matrix in a matrix multiplication now so that we can 05:49:50.960 |
understand how to go from here to here so let's write a generic matrix multiplication for example 05:49:58.400 |
an a matrix let's say that it is the following and we only write one row actually let me zoom again 05:50:07.440 |
and i want to write smaller so we have enough space so we make a matrix that has a row let's 05:50:14.000 |
call it a 1 a 2 a 3 and then we multiply this will be a matrix with many rows like the this one 05:50:23.760 |
because we want to study the effect only of one row and we multiply it by another matrix let's 05:50:30.000 |
call it this one is the matrix a and it has i don't know let's say n rows by three columns then 05:50:39.200 |
we should have another matrix b with three rows and some number of 05:50:47.280 |
columns, let's say four columns. So we call the first row, let me zoom a bit more, b11, 05:50:58.640 |
b12 b13 b14 then this one should be b21 b22 b23 b24 this should be b31 b32 05:51:17.120 |
b33 b34 etc i know i am not very rigorous in my notation i should have called all these elements 05:51:26.960 |
with the capital letter a and the capital letter b so this is the notation that you use when 05:51:31.920 |
referring to single item of a matrix but please forgive me for this so the output of this matrix 05:51:40.080 |
multiplication will be another matrix that is n by 4 so it will be n by 4 so we will have four 05:51:48.640 |
columns for each row of the output i want to write the output in a different way so i want to write 05:51:58.400 |
it as follows as a vector only so the first output row as a vector and want to understand what is 05:52:06.560 |
each dimension of this vector so because otherwise i don't have enough space to write it here 05:52:11.680 |
so the first let's write it so let's call it o i want to write what is o of one which is the first 05:52:22.720 |
row of the output but written as a column vector so o of one will be here we should use the small 05:52:32.400 |
letter o of one should be a vector where the first dimension is the dot product of this stuff here so 05:52:40.320 |
the first row of the a matrix with the first column of the b matrix so the first let's say 05:52:47.760 |
dimension will be a1 with b11. Actually, i should call these elements a11, a12 05:53:00.160 |
and a13, because we have many rows in the a matrix, so let me use the correct naming: this 05:53:10.240 |
will be a11 b11 plus a12 multiplied by b21 plus a13 with b31, and this will be the 05:53:28.000 |
first dimension of the first row of the output matrix o the second dimension of the first row 05:53:36.000 |
of the output matrix o will be the dot product of this row of the a matrix with the second column 05:53:43.040 |
of the b matrix and let me write here b so it will be a11 b12 plus a12 b22 plus a13 b32 05:54:03.040 |
the third dimension will be a11 b13 plus a12 b23 plus a13 b33 the fourth dimension will be a11 05:54:24.080 |
b14 plus a12 b24 plus a13 b34 now this is the output the first output row of the o matrix and 05:54:40.160 |
it's a vector called o1 and these are this is the first dimension of this vector this 05:54:45.120 |
is the second this was the third and this is the fourth dimension and each of this stuff here is 05:54:49.280 |
one scalar um so the output o1 which is the first row of the output matrix can also be written as 05:55:01.520 |
a sum of many vectors, where, as you can see, the first element is a11 05:55:13.040 |
multiplied by... let me try to write this one smaller, i can't change the size here, 05:55:19.760 |
okay it doesn't matter. So as you can see here there is a11 multiplying a different b number 05:55:26.880 |
every time so this is a b11 b12 b13 b14 what is b11 b12 b13 b14 it is the first row of the b 05:55:36.400 |
matrix so it is equal to b1 and all the dimensions of the first row then plus then we have the 05:55:46.720 |
element a12 multiplied by b21 b22 b23 etc etc and this is the second row of the b matrix so we use 05:55:57.440 |
the tensor notation of pytorch to describe this row which is a b2 and all the dimensions of b2 05:56:07.280 |
so it looks this is a vector scalar product and plus a13 multiplied by b3 05:56:22.320 |
and all the dimensions of b3 this one can also be written as the summation 05:56:29.280 |
over an index that goes from 1 to 3, where 3 is how many columns there are in the a matrix, 05:56:42.400 |
of a_ij; let's call this summation index j, 05:56:53.360 |
and the generic ith row of the output matrix will be a_i1, a_i2 and a_i3, 05:57:07.840 |
each one multiplied by the corresponding row in the b matrix so we can write it as a i j 05:57:16.560 |
multiplied by b j where b j is the a row of b we can also write it like this to indicate that this 05:57:29.920 |
is a vector and this is exactly what they do here so the output in the output matrix when we do the 05:57:37.360 |
multiplication p multiplied by v the ith row of the output matrix we call it o i which is a vector 05:57:46.320 |
but by notation it is a column vector where the elements of this column vector are actually the 05:57:52.160 |
elements of the ith row of o this is only by notation guys is equal to the ith row of p so 05:58:02.160 |
the ith row of the matrix that is on the left in the matrix multiplication multiplied by all the 05:58:07.440 |
columns of the v matrix which can also be written as the summation over all the elements of the 05:58:13.680 |
ith row of p so all the elements of the ith row of the first matrix the one on the left in the 05:58:19.280 |
matrix multiplication, multiplied by each vector in the V matrix, where the j-th vector here in V 05:58:27.200 |
is each row of the V matrix. And P_ij can also be written explicitly: P_ij is the output of 05:58:38.480 |
the softmax, so as you know the output of the softmax is e to the power of the input element 05:58:44.640 |
of the softmax, and the input element of the softmax is the query multiplied by the transpose 05:58:49.280 |
of the keys, so it's a dot product between one query and one key, and that's why you have this 05:58:54.080 |
stuff here in the exponential. So this is the first step in understanding this derivation. 05:59:00.160 |
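To summarize this first step in symbols (a compact restatement of what was just derived on the board, with l_i denoting the softmax normalization factor of row i):

```latex
O_i = \sum_{j} A_{ij}\, B_j \quad\text{(generic: } B_j \text{ is the } j\text{-th row of } B\text{)}
\qquad\Longrightarrow\qquad
O_i = \sum_{j} P_{ij}\, V_j, \quad P_{ij} = \frac{e^{\,q_i \cdot k_j}}{l_i}
```

So each output row is a weighted sum of the rows of V, weighted by one row of the attention matrix P.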
Another thing that we have studied so far is how to derive the backward pass of the matrix multiplication 05:59:06.560 |
and of the softmax. So now let's use it. For the matrix multiplication, let's rehearse the formula: 05:59:13.600 |
so if given a matrix multiplication that is y equal to x multiplied by w we know that given 05:59:21.600 |
the gradient of the loss function with respect to y so the output of this operation we know how 05:59:27.600 |
to derive the gradient of the loss with respect to one of the input of this function which is the 05:59:33.200 |
x or w. To get the gradient with respect to x we need to take the upstream gradient, so 05:59:40.080 |
the gradient with respect to the output, multiplied by the transpose of w, and to get the gradient 05:59:47.520 |
with respect to w we need to do x transposed, so the input transposed, multiplied by the upstream gradient. 05:59:55.600 |
This one is the formula that we didn't derive and this one is the formula that 05:59:59.440 |
we derived, but how to derive them is exactly the same procedure. 06:00:03.040 |
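As a quick sanity check of these two reference formulas (a minimal sketch in PyTorch, not part of the video's code; names and shapes are arbitrary):

```python
import torch

# y = x @ w, with an arbitrary upstream gradient dL/dy
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 5, requires_grad=True)
y = x @ w
dy = torch.randn_like(y)      # upstream gradient dL/dy
y.backward(dy)

# dL/dx = dL/dy @ w^T   and   dL/dw = x^T @ dL/dy
assert torch.allclose(x.grad, dy @ w.T, atol=1e-5)
assert torch.allclose(w.grad, x.T @ dy, atol=1e-5)
```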
in attention we are doing the last product that we are doing is o equal to p multiplied by v 06:00:12.880 |
what PyTorch will give us as input during the backward pass is the gradient of the loss with 06:00:18.880 |
respect to the output and we need to use this gradient of the loss with respect to the output 06:00:23.680 |
of the attention to derive the gradient of the loss with respect to q with respect to k and with 06:00:29.600 |
respect to V, so that it can then be used by the operators in the backward pass, in the 06:00:35.920 |
computation graph in the operations before okay so but in order to arrive to the gradient with 06:00:43.040 |
respect to query key and value we need to derive the gradient with respect to each intermediate 06:00:48.480 |
operation so the last operation that we do is o equal to p multiplied by v so the gradient with 06:00:55.840 |
respect to o of the loss with respect to v given the gradient of the loss with respect to o 06:01:04.160 |
it is exactly like computing the gradient of the of the loss with respect to x in a matrix 06:01:10.400 |
multiplication and we know that it is equal to pt so just by analogy guys so this is our 06:01:17.280 |
reference point and i am just changing the names here and you should understand what is the analogy 06:01:21.920 |
here so the gradient of the loss with respect to v which is the matrix on the right which is like 06:01:28.960 |
computing it with respect to w it is equal to just like this formula here so the transpose of 06:01:34.960 |
the matrix on the left multiplied by the upstream gradient which in the paper they write it as this 06:01:40.960 |
so dV is equal to P^T multiplied by dO, and it's the formula that you can see here. The other 06:01:48.080 |
derivation is how to derive the gradient with respect to dp dp is just like deriving the 06:01:54.400 |
gradient of the loss with respect to the matrix that is on the left side of the matrix multiplication 06:01:59.120 |
so it is just like deriving the gradient of the loss with respect to x in the reference 06:02:04.400 |
formulas which is equal to the upstream gradient multiplied by the transpose of the other matrix 06:02:11.840 |
which in the notation of the paper they write it as dp is equal to do multiplied by v transposed 06:02:18.000 |
and it's this formula here how they compute this stuff here is exactly as above so as this 06:02:25.520 |
derivation here they call vj the jth row of the v matrix and they write it as pij multiplied by do 06:02:37.680 |
how to arrive to this formula here well let's do it so let me write let's see okay 06:02:45.680 |
theoretically we know that from this derivation here so from this derivation here or from this 06:02:51.440 |
derivation here we know that the i-th row of the output in a matrix multiplication 06:02:57.360 |
first of all let's simplify our life every time you see a transpose and you don't like work with 06:03:02.000 |
the transpose in a matrix multiplication just give it a different name and then work with the 06:03:06.640 |
different name and after when you have derived the formula you resubstitute the transpose 06:03:12.480 |
operation in this case we are doing dv is equal to p transpose multiplied by do let's call p 06:03:19.920 |
transposed let's give it a name that we are we didn't use so far so let's call it f i always 06:03:24.800 |
use f when it's available so we call dv is equal to f do we know from above here from this derivation 06:03:38.720 |
here or this derivation here is equivalent that the output of a matrix multiplication so the out 06:03:45.920 |
i-th row of the let's know the j-th row let's call it the j-th row dvj is equal to a summation 06:03:55.680 |
of each element of the j-th row of f of the first matrix so we do the let's see here we do the sum 06:04:06.720 |
by i so let's do it by i it's the sum over all possible i of the i-th element in the j-th 06:04:16.880 |
row of the first matrix so fji multiplied dot product not dot product this is a scalar vector 06:04:32.800 |
multiplication multiplied by a vector that is let me check what was the formula so it was the j-th 06:04:40.000 |
row of the other matrix so in this case it should be the i-th row of the other matrix 06:04:45.760 |
dO of i, where this is the i-th row of dO, and dV_j is the j-th row of the dV matrix. 06:05:00.480 |
But also, we know that F is not a matrix that we actually have: it's the transpose 06:05:06.400 |
of P, which means that F_ji will be equal to P_ij, because in a matrix transposition you invert the 06:05:13.440 |
two indices. So this is the summation over all possible i's of P, not ji but ij, multiplied by dO_i, 06:05:24.320 |
and this should be equal to the same formula that you see on the right here; this allows you to compute one output row of the dV matrix. 06:05:28.960 |
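In one line, the result of this derivation is (same notation as above, rows treated as vectors):

```latex
dV = P^{T}\, dO
\quad\Longrightarrow\quad
dV_j = \sum_i \left(P^{T}\right)_{ji} dO_i = \sum_i P_{ij}\, dO_i
```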
Okay, and we know that P_ij is just the output of the softmax: 06:05:39.440 |
the output of the softmax is the exponential of the input of the 06:05:45.200 |
softmax divided by the normalization factor associated with that row. Because we are 06:05:52.960 |
iterating through the rows i, it will be the normalization factor associated with that 06:05:59.600 |
row, the row of O_i. So we know that the formula for P is equal to the softmax of S; now the i-th 06:06:11.840 |
row of p will be the softmax of the i-th row of s and this is what is written here we know from 06:06:18.560 |
our derivation that the Jacobian of the softmax operation — so if we have an input x and 06:06:26.640 |
the output y of the softmax operation — the Jacobian of y with respect to x 06:06:33.440 |
is equal to diag(y), a diagonal matrix of the elements of the vector y, minus y multiplied 06:06:41.120 |
by y transposed, and we have also seen before that this matrix is symmetric. However, you may not 06:06:48.560 |
understand this formula here, because in the chain rule we always write it like 06:06:55.200 |
this: we always write that the downstream gradient, so dφ/dx, should be equal to 06:07:07.200 |
the upstream gradient, so dφ/dy, multiplied by dy/dx, 06:07:17.440 |
and this only works if you write this matrix here in the numerator convention. The numerator 06:07:26.160 |
convention is one of the two conventions in which you can create a Jacobian; so far we have always 06:07:31.760 |
written it as the numerator convention if you use the numerator convention this is a row vector and 06:07:38.480 |
this is a row vector however if you want to treat this stuff here as a column vector then you need 06:07:45.920 |
to take the transposed or you need to make the jacobian in the denominator convention 06:07:50.720 |
how to get this formula here because this formula here is basically doing the 06:07:55.520 |
jacobian multiplied by the upstream gradient not the gradient upstream gradient multiplied by the 06:08:02.160 |
jacobian and it's only because here we treat it as a column vector and when you do the you want to 06:08:08.240 |
transform a row vector into a column vector you take the transpose of both sides of the equation 06:08:12.000 |
and let's do it actually so we apply the transpose to the both side of the equation 06:08:20.560 |
okay in a matrix multiplication if you do a b transposed it become b transposed multiplied 06:08:30.480 |
by a transposed so the transposed is applied independently to each input of the matrix 06:08:35.840 |
multiplication but we invert the matrix multiplication and if you remember the 06:08:39.440 |
matrix multiplication is not commutative so what we do here is that we say okay it will be the 06:08:48.960 |
here they call it dsi so it will basically just become d phi on dx if you treat this 06:09:01.200 |
one as a column vector so this one as a column vector will be equal to dy on dx as a column 06:09:09.120 |
vector as a jacobian in the denominator layout in this case multiplied by d phi on dy as a column 06:09:19.760 |
vector. This one is a column vector, this is a column vector, and this is what you see here: that's 06:09:24.000 |
why the Jacobian is on the left side of the upstream gradient. 06:09:31.120 |
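Written out, this transposition step is (a compact restatement; φ is the loss, y = softmax(x), and the Jacobian diag(y) − y yᵀ is symmetric, so transposing it changes nothing):

```latex
\left(\frac{\partial \phi}{\partial x}\right)^{T}
= \left(\frac{\partial \phi}{\partial y}\,\frac{\partial y}{\partial x}\right)^{T}
= \left(\frac{\partial y}{\partial x}\right)^{T} \left(\frac{\partial \phi}{\partial y}\right)^{T}
= \bigl(\operatorname{diag}(y) - y\, y^{T}\bigr)\,\left(\frac{\partial \phi}{\partial y}\right)^{T}
```

which is exactly the "Jacobian on the left, upstream gradient on the right" form used for dS_i, with everything treated as column vectors.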
What else do we need? Well, I know that there are a lot of things in this derivation, but I prefer going directly 06:09:35.840 |
to the code, otherwise I think it's going to be too boring. So let's go to the code, and while 06:09:41.920 |
writing the code I will go back to the formulas, so we can find the association between what we are 06:09:47.360 |
doing and the formula in the paper. I think this is the best way, so let's proceed further. All right 06:09:54.640 |
guys now we can finally code the backward pass before we code the backward pass let's look at 06:09:59.040 |
the algorithm of the backward pass as written in the paper this is the paper flash attention 06:10:03.920 |
one and i will be because we will follow the structure of the code that is present on the 06:10:10.560 |
triton website so it's not my idea to split it like this but i simplified it in such i simplified 06:10:17.840 |
it so it's different than the one that you can find online because mine is a simplified version 06:10:22.960 |
and mine works with the causal and non-causal attention um so first if you look at this 06:10:29.360 |
algorithm you need to you can see that we have an outer loop through all the k and v blocks 06:10:36.240 |
and an inner loop through all the query blocks however as you can see to compute the dq which is 06:10:44.080 |
the downstream gradient of the the loss with respect to the q matrix we need to have an 06:10:51.920 |
iteration through all the k's and to compute each dk block we need to have an iteration through all 06:10:58.960 |
the queues so if we follow the loop like it is it would involve writing to the high bandwidth memory 06:11:06.880 |
so to the dram of the gpu at every inner iteration and that could be also that is not so efficient 06:11:13.120 |
and also if we don't want to write it would require some sort of 06:11:18.080 |
some sort of synchronization between blocks which is also not very efficient 06:11:24.080 |
so we split we will split this four into two parts because we can see that each dq depends 06:11:30.800 |
on a loop over the k's and each dk depends on a loop over all the queues so to compute dk we will 06:11:41.200 |
fix the kth block and iterate through all the q blocks then we will do another iteration in which 06:11:46.640 |
we fix the q block and iterate through all the kv blocks to compute the dq this is what we are 06:11:52.320 |
going to follow and this is an idea that i took from the original implementation that is present 06:11:56.560 |
on the Triton website. Another thing that we can notice here is — where is it — here: to compute the 06:12:05.360 |
dQ and dK, so a dQ vector and a dK vector, we need this piece of information here called D_i, 06:12:15.680 |
and it's shared between the two, so we can pre-compute it and then we can reuse it 06:12:21.840 |
to compute the dQ_i vector and the dK vector. What is this D_i? D_i is introduced 06:12:32.320 |
here and it's the dot product of the dO_i vector with the O_i vector. So the first 06:12:41.040 |
thing that we will do is do a loop over all the vectors in o and do and do their dot products 06:12:48.000 |
to compute this D_i element. Then we will use this D_i element 06:12:56.160 |
to compute dQ and dK, and we will also have another 06:13:02.960 |
two loops: one in which we fix the Q block and iterate through all the keys, and one in which we fix the keys 06:13:08.400 |
and iterate through all the queries, as sketched below. 06:13:14.800 |
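Here is a high-level sketch of that structure before we write any Triton (illustrative pseudocode only, not the real kernels; names mirror the ones we will use below):

```python
# Backward pass structure (illustrative pseudocode)

# 1) Preprocess: D_i = rowsum(dO_i * O_i), one scalar per row of O.
#    D has shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN), the same as M from the forward pass.

# 2) Fix a K/V block, iterate over all Q blocks -> accumulate dK and dV.
# for each (batch, head, kv_block):              # one program per kv_block
#     for q_block in all_q_blocks:
#         recompute P^T from K, Q^T and M        # S/P were never saved to HBM
#         dV += P^T @ dO_block
#         dP^T = V_block @ dO_block^T
#         dS^T = P^T * (dP^T - D)                # element-wise, D broadcast per query row
#         dK += softmax_scale * dS^T @ Q_block

# 3) Fix a Q block, iterate over all K/V blocks -> accumulate dQ.
# for each (batch, head, q_block):
#     for kv_block in all_kv_blocks:
#         recompute P from Q, K^T and M
#         dP = dO_block @ V_block^T
#         dS = P * (dP - D)
#         dQ += softmax_scale * dS @ K_block
```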
So let's start. Now that we know more or less the structure of the code, we start by writing this backward function here. 06:13:23.920 |
uh let me check yeah okay so do you remember this is saved tensor these are all the information that 06:13:35.360 |
we save during the forward pass uh to compute the backward pass now to to optimize the memory 06:13:43.760 |
utilization in flash attention we don't save the query multiplied by the transpose of the 06:13:50.160 |
key matrix because that would be a sequence by sequence matrix that is too big to save into the 06:13:55.680 |
hbm in the dram during the forward pass and then i re get it back from the hbm into the local memory 06:14:02.480 |
because i want to remind you that in triton uh compared to cuda in triton what we do is we load 06:14:09.040 |
stuff from the high bandwidth memory in the shared memory so the sram we do all the operations there 06:14:15.920 |
and then after when we call the store method we save the element from the shared memory into the 06:14:22.000 |
high bandwidth memory so in order to not materialize this s matrix in its entirety save it to the hbm 06:14:30.160 |
and then reget it back which could be very slow and secondly actually it is very expensive because 06:14:35.840 |
usually right now we are computing attention on thousands and thousands of tokens so imagine 06:14:41.200 |
saving a matrix that is 5000 by 5000 that's a big matrix to save for each batch uh for b each batch 06:14:49.280 |
and for each head so that would be really too expensive to save so the idea in flash attention 06:14:56.480 |
is to recompute what we can compute on the fly during the backward pass because any way if we 06:15:01.760 |
were to load it it would be memory i/o bound so it's faster to recompute than to save it and restore 06:15:09.200 |
it from the memory this is the idea of flash attention okay so we saved some stuff during the 06:15:16.080 |
forward pass and now we can access it back during the backward pass and this stuff is saved in the 06:15:21.440 |
context and this it's a it's a kind of a dictionary that is made available by by pytorch all right so 06:15:29.600 |
we get back the query key and values and as you know pytorch during the autograd will just give 06:15:35.600 |
us the gradient of the loss with respect to the output of our implementation of the attention 06:15:41.440 |
of our attention so this is triton attention and then we need to compute dq dk and dv by using only 06:15:49.200 |
the gradient of the output with respect to the the loss with respect to the output um we do for 06:15:54.800 |
some checks so here i know i could optimize this code and make it even smaller by for example 06:16:02.080 |
checking that here the stride that i am using i actually inside of the code i always uh pretend 06:16:07.840 |
that the stride is the same but uh doesn't matter i just take the code from triton and uh try to 06:16:14.000 |
simplify it my goal was to simplify it not optimize it so all right we create the um the 06:16:21.520 |
vectors the tensors in which we will store the result of this backward pass which is the dq dk 06:16:28.320 |
and dv and as you know from what we have seen of the definition of the gradient the size of the 06:16:34.960 |
output of the gradient vector is the size of the vector with respect to which we calculate the 06:16:41.760 |
gradient, because in the numerator there is always a scalar and we compute the gradient with respect 06:16:46.080 |
to all the elements in the input vector, so the gradient itself is a tensor of the 06:16:51.200 |
same size as the tensor with respect to which we compute the gradient. 06:16:58.480 |
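As an outline, the entry point we are writing looks roughly like this (a minimal sketch; the exact tensors saved during the forward pass and the class name are taken from this walkthrough, so treat names and order as assumptions):

```python
@staticmethod
def backward(ctx, dO):
    # Inside our TritonAttention(torch.autograd.Function) class.
    # Tensors saved during the forward pass (note: S = Q K^T / P were NOT saved,
    # they will be recomputed on the fly inside the kernels).
    Q, K, V, O, M = ctx.saved_tensors

    # Each gradient has the same shape as the tensor it is the gradient of
    dQ = torch.empty_like(Q)
    dK = torch.empty_like(K)
    dV = torch.empty_like(V)

    BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.shape
    # ... launch the preprocess kernel and the two backward kernels (shown below) ...
```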
Then we get some information on the batch size, and later we will see what this number of warps 06:17:04.400 |
and number of stages is; I will not explain it now. The number of warps is an 06:17:11.360 |
indication of how many threads we want to launch in our grid, and the number of stages is the 06:17:15.600 |
number of stages used in software pipelining; we will see later what software 06:17:19.360 |
pipelining when we talk about the auto tuning then we define some uh blocks uh in the original um 06:17:29.120 |
in the original code i think they call it a block kv1 kv2 q1 and q2 i think it was confusing i call 06:17:37.920 |
it a block macro and block micro because the thing that we will fix and the things that we will 06:17:42.480 |
iterate from will be once it's the query so we fix the query block and we iterate through all the 06:17:48.800 |
keys and then we will fix the keys and values block and we iterate through the queries the one 06:17:55.200 |
that we iterate on is the micro one and the one that we fix is the macro one this is my uh the 06:18:01.360 |
naming that i am using um then we as i said before we need to pre-compute the di elements that we saw 06:18:09.440 |
in the paper before so that's the first kernel that we are going to launch and this kernel will 06:18:14.800 |
have its own launch grid because later we want to optimize the the tuning of this kernel later 06:18:21.520 |
we will talk about the tuning with respect to its own parameters so let me see what are we going to 06:18:29.280 |
do so here so the first kernel that we are going to launch is this pre-process kernel this pre- 06:18:35.680 |
process kernel will pre-compute all the D_i elements that we need, if I 06:18:43.120 |
remember correctly, for dQ and dK, and this D_i element depends only on O and dO. So let's do it and let's create 06:18:55.360 |
another function called the backward preprocessor what is the process preprocess grid this is the 06:19:01.440 |
launch grid of this function of this kernel and this will be launched on a independently for each 06:19:10.000 |
batch and for each head and moreover it will be work with a block size of vectors of o what is 06:19:17.440 |
this block what is this number of vectors of o it will be the block size macro so on 128 vectors of 06:19:25.760 |
o so uh let me copy the signature of this function this is here so let's write it here i think it's 06:19:34.960 |
fine yeah okay this function takes a the matrix o so it's a pointer to the matrix o it's a pointer 06:19:43.280 |
to the d o and it's a pointer to the matrix d where we will store this di elements and we have 06:19:49.600 |
one for each vector in the output that's why the shape of this d is a batch size number head 06:19:57.840 |
sequence length it means it's one for each of the output element in the output of the attention 06:20:03.040 |
this di where is it actually it's not this one it's this one yeah like m so it has the same shape 06:20:11.760 |
as m which is as you can see it is this size here so batch size number heads and sequence length m 06:20:18.320 |
if you remember is the matrix that we saved during the forward pass which includes the 06:20:22.160 |
normalization factor of the softmax and also the maximum element but in log sum exp format so that 06:20:29.920 |
when we apply it will automatically apply the maximum element for each row and also normalize 06:20:34.560 |
at the same time which i think i proved previously so let me do it so we write it like this so we 06:20:44.560 |
extract the index of this program; this program has two indices as identifiers. This is 06:20:55.600 |
equivalent to the CUDA identifiers, and this one is along the axis 0. So let's see what 06:21:01.520 |
we launched on the axis 0: on the axis 0 of this launch grid we defined what is the block 06:21:08.400 |
of vectors of O that this particular program will work with, and the second axis is 06:21:17.280 |
which batch and which head inside of each batch this particular program will work with so this 06:21:23.760 |
identifies the block index of q so which group of vectors in the o matrix this particular program 06:21:30.400 |
will work with here is called q i believe because i copied it from the original code where they call 06:21:36.160 |
it q but i could have eventually also call it o um so we define uh so basically this means that we 06:21:45.440 |
are for this program we need to skip some query vectors that have been already or that will be or 06:21:52.160 |
have been already processed by other programs in parallel so we will only block with a number of 06:21:58.560 |
query vectors inside of o that have the following indices so imagine that the query block size is 06:22:06.480 |
i think it's 128 the way we have defined it but suppose it's a 4 for simplicity so this one will 06:22:14.480 |
be and the query vectors are how many are sequence length number of query vectors we have so some of 06:22:22.960 |
imagine the query vectors are in total they are i don't know let's say uh 64 and 32 will be managed 06:22:30.960 |
by other programs so this particular of skew will be equal to 33 34 35 and 36 this tells me which 06:22:40.560 |
query vectors or which vectors in the output o matrix among all the vectors in the o matrix this 06:22:46.880 |
particular program is going to work with okay so then we extract also the index of the batch which 06:22:54.800 |
tells us which batch and which head in each batch this particular program is going to work with 06:23:01.040 |
which is the dimension one of our launch grid and then we define the offset of the dimension 06:23:07.440 |
because we need to load all the dimensions of each vector so these are the it's a vector that tells 06:23:14.320 |
which dimensions we need to load from each vector and we will load all of them so we don't divide on 06:23:18.400 |
the head dimension dimension we just divide on the sequence length dimension the the load among 06:23:26.080 |
multiple programs um you will see in this part of the the video so when we are writing the backward 06:23:33.280 |
pass that we will not be using the make block pointer like we did during the forward pass 06:23:39.040 |
so this function here we will work with directly with indexing by using the strides so let's do it 06:23:47.200 |
so let's load a single block of rows of o which i want to remind you has the same shape as q and 06:23:57.040 |
that's why we can call it block size q so the o block that we are loading is o so uh the load 06:24:03.840 |
function accepts a pointer to what it should load actually not a pointer it accepts a array of 06:24:11.360 |
pointers or a multi-dimensional array of pointer in case you want to load a multi-dimensional data 06:24:16.720 |
so actually load also allows you to load two-dimensional data in this case we are going 06:24:23.360 |
to load two-dimensional data which is a block of rows of o which should be a block a tensor of the 06:24:31.520 |
shape block size q in this case multiplied by the other dimension being head dimension 06:24:39.120 |
but we don't we need to tell it where in this o matrix it needs to find this one first of all we 06:24:46.160 |
need to skip some batches and some heads based on what the head and the batch that will be processed 06:24:52.800 |
by other programs so based on the index that this um program will process of the batch and the head 06:24:59.920 |
we need to skip all the other batches and heads let's write the shape of this tensor so the o 06:25:07.840 |
tensor has a shape block size not block size batch size a number of heads then sequence length 06:25:17.360 |
and then head dimension each block and each head will have a sequence length multiplied by dim 06:25:25.360 |
head dim number of items so based on our index we skip how many items our index multiplied by head 06:25:32.960 |
dimension multiplied by sequence length so what i mean is this the batch zero and the head zero 06:25:40.480 |
will have a sequence length multiplied by head dimension items the batch zero and the head one 06:25:47.360 |
will also have the same number of items and the batch zero and head two will also have the same 06:25:53.120 |
number of items so how many items sequence length multiplied by head dimension do we need to skip 06:25:58.160 |
from the starting of the o tensor it is equal to the index of the current batch and head indicator 06:26:04.320 |
so because this index indicates both the head in the batch and the head inside of each batch because 06:26:11.360 |
it's already the product of the head and the batch so how many we skip indicated by the this index 06:26:20.640 |
and after we point to this starting point of the current batch and the current head 06:26:25.760 |
we need to select a two-dimensional tensor where the offsets are indicated for the rows by off skew 06:26:33.840 |
and that's why we have this one um the i don't know what this is called this is uh 06:26:41.600 |
the the index uh semi-colon index that tells all the all these vectors in off skew will with an 06:26:50.880 |
additional dimension for the columns and these columns will be the off dim so basically this 06:26:56.160 |
will select a tensor of the following shape inside of this big tensor that includes head size and 06:27:03.920 |
number of heads this is what we are doing so we are saying select a tensor of this size 06:27:11.280 |
inside of one that is made up of four dimensions by skipping the elements of all the batch and 06:27:18.080 |
heads that will be processed by other programs i always talk in terms of programs because in 06:27:24.400 |
Triton these are called programs; in CUDA you would think of each of them as a block of threads. 06:27:28.160 |
all right so this one is done i hope it is decently clear um all right so then we also load 06:27:38.560 |
a single block of d o in the same way because we are going to load a group of vectors from all the 06:27:46.880 |
sequence length also from d o and the d o has the same shape as o which has the same shape as q and 06:27:54.320 |
that's why we can use the um the the block index we call it q because it's equivalent because they 06:28:00.080 |
have the same shape okay and how to compute this d i element well it's written in the paper so if 06:28:08.160 |
we go in the in the what is it man if we go here it shows you how to compute the d i of each given 06:28:17.280 |
a block of d o and a block of o it tells you how to compute d i which is the row sum which means 06:28:25.200 |
the sum by rows: for each row, so for each vector in the O matrix, we will have 06:28:32.080 |
one sum of the element-wise product. So this stuff here is the element-wise product of dO_i multiplied 06:28:39.360 |
by O_i — it's not a matrix multiplication, it's an element-wise product, which means each element of 06:28:46.480 |
one matrix multiplied with the corresponding element of the second matrix, and the output shape will be the 06:28:51.280 |
same as the two matrices, which must have the same shape. 06:29:00.960 |
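In formula form, with d the head dimension:

```latex
D_i = \operatorname{rowsum}\bigl(dO \circ O\bigr)_i = \sum_{k=1}^{d} dO_{ik}\, O_{ik} = dO_i \cdot O_i
```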
Okay, so we compute this D_i block, which will have shape BLOCK_SIZE_Q, because we will have one sum for each vector. 06:29:06.560 |
then well we need to store it somewhere so we need to calculate where to store it 06:29:12.800 |
inside of the d matrix well the d matrix is i remember correctly has the same shape as m so 06:29:20.240 |
it should be batch size a number of heads and sequence length so we need to select the right 06:29:29.120 |
batch and the right head and also the right position inside of the sequence length based 06:29:33.440 |
on the block index q that we have okay so let me index okay 06:29:43.120 |
all right because we already um so we skip um again just like before we know that the d is 06:29:53.680 |
of this size each batch and each head will have sequence length number of elements so how many 06:29:59.520 |
number of elements we need to skip from the starting of the tensor is sequence length 06:30:04.960 |
multiplied by the combined index batch size head number and plus we need to also skip some 06:30:12.640 |
queries based on our block index q and it's already this skipping is already done inside 06:30:18.480 |
of offs_q, so we add offs_q. And then, once we have computed the index where we should store this 06:30:24.720 |
D_i block, let's store it. (I didn't call it d_block myself; I 06:30:34.320 |
think it was already in the original code.) This is D_i, and this big matrix D is actually 06:30:40.240 |
the matrix that includes all the D_i, one for each token in the sequence length. 06:30:45.600 |
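Putting these pieces together, the whole preprocess kernel is roughly the following (a condensed sketch of what we just wrote, assuming a contiguous layout so that one (batch, head) slice of O/dO holds SEQ_LEN * HEAD_DIM elements; the real code may differ in details such as dtype casts):

```python
import triton
import triton.language as tl


@triton.jit
def _attn_bwd_preprocess(
    O, dO, D,                 # O, dO: (BATCH, NUM_HEADS, SEQ_LEN, HEAD_DIM); D: (BATCH, NUM_HEADS, SEQ_LEN)
    SEQ_LEN,
    BLOCK_SIZE_Q: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    block_index_q = tl.program_id(0)        # which block of rows of O / dO
    index_batch_head = tl.program_id(1)     # which (batch, head) pair

    offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    offs_dim = tl.arange(0, HEAD_DIM)

    # Skip the (batch, head) slices handled by other programs, then select a
    # (BLOCK_SIZE_Q, HEAD_DIM) tile of O and of dO via broadcasting of the offsets
    O_block = tl.load(
        O + index_batch_head * HEAD_DIM * SEQ_LEN
          + offs_q[:, None] * HEAD_DIM + offs_dim[None, :]
    )
    dO_block = tl.load(
        dO + index_batch_head * HEAD_DIM * SEQ_LEN
           + offs_q[:, None] * HEAD_DIM + offs_dim[None, :]
    )

    # D_i = rowsum(dO_i * O_i): element-wise product, summed over the head dimension
    D_block = tl.sum(dO_block * O_block, axis=1)    # shape: (BLOCK_SIZE_Q,)

    # One D_i per row, stored at the right (batch, head, row) position
    tl.store(D + index_batch_head * SEQ_LEN + offs_q, D_block)
```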
all right so the pre-processing has been done now we need to do prepare the two for loops as you 06:30:54.480 |
remember i said before we will be doing two for loops one in which we fix the query and we iterate 06:31:00.880 |
through all the keys and values and one in which we fix the key and value block and we iterate 06:31:05.120 |
through all the queries and while coding it i will always show you the formula from the paper so 06:31:10.720 |
don't worry let's start with the next iteration so first we create the launch grid for the next 06:31:16.720 |
iteration as the launch grid is always the same so we first because we we need to keep one block 06:31:24.000 |
fixed and iterate through all the other blocks the block that we keep fixed will define how many 06:31:29.600 |
programs we have that run in parallel and the block that is fixed has a block size macro number 06:31:35.760 |
of elements that's why we create a sequence length divided by block size macro number of blocks 06:31:41.120 |
thread blocks or programs in this axis the axis two in this grid is i could have used 06:31:50.240 |
also the axis one indifferently i think it was already done here in the original code 06:31:55.520 |
it's we will indicate which batch and which head inside of each batch we are going to work with 06:32:01.920 |
so and just like the forward pass we will also use a variable called the stage that if the 06:32:10.560 |
attention that we are computing is causal it will be equal to three and if we are computing 06:32:14.880 |
a non-causal attention then it will be equal to one the first iteration we will fix k and v blocks 06:32:22.960 |
and we will iterate through all the q blocks in size of block size micro number of query vectors 06:32:32.000 |
so let's look at the signature so we pass we launch it as a launch grid because 06:32:38.800 |
and we have defined how many programs we have so we have how many kv blocks we will have 06:32:47.040 |
it's a sequence length divided by the block size macro because that's the the block that we will 06:32:51.920 |
keep fixed in this uh for loop in this function and then we go through all the query blocks in 06:32:59.200 |
size of block size micro which i defined it as 32 and later we will talk about auto tuning and 06:33:05.600 |
how to tune these values all right so i pass the query vector the key vector and the v vector uh 06:33:12.560 |
sorry not vector tensors now the query tensor k tensor and v tensor and they are pointing to the 06:33:18.800 |
beginning of the tensor which means that they are beginning to the first batch and the first 06:33:23.520 |
head and the first token and the first dimension of the tensors then we pass the softmax scale we 06:33:30.400 |
pass dO, dQ, dK and dV. M is the one that is needed because, as you remember from what we said 06:33:39.040 |
before, we did not save the P matrix in the HBM, because we want to recompute it on the fly during 06:33:45.760 |
the backward pass: the query multiplied by the transpose of the keys is a very big matrix 06:33:49.760 |
to save in the HBM and restore. So we want to compute it on the fly, but we don't need to 06:33:54.560 |
recompute the normalization factor and the maximum element of each row to apply the softmax: that was 06:34:00.480 |
already computed during the forward pass and saved into this matrix M, which includes 06:34:06.800 |
the maximum of each row plus the logarithm of the normalization factor. With the log-sum- 06:34:14.080 |
exp trick we can just apply it and it will also normalize each value. Then we have the D 06:34:21.120 |
tensor that we computed here with all the di values one for each vector in the o 06:34:27.120 |
tensor then we need to pass some the number of heads the sequence length the block size that 06:34:34.320 |
we want to use for the kv which is the macro block size and the micros block size is always 06:34:38.480 |
the one that we iterate on i think using this name it should be easier to understand which one we are 06:34:43.440 |
iterating and which we want to keep fixed so the fixed one is macro and the iterating one is the 06:34:48.160 |
micro head dimension later we will see why we use a different block size to iterate from because 06:34:56.320 |
this is related to the number of stages that triton can divide your for loop into thanks to 06:35:03.200 |
software pipelining. Then we have the head dimension; the stage indicates whether the attention that we 06:35:10.000 |
computed in the forward pass was causal or not causal; and then the number of warps and the number of 06:35:15.840 |
stages, which we defined as fixed, but later we will talk about auto-tuning — sometimes I repeat the same stuff over and over, so I should change that. 06:35:22.480 |
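For reference, the host-side launch we just described looks roughly like this (a simplified sketch: the kernel and parameter names are illustrative, the strides of Q are passed for all tensors as the video notes, and the batch-head index sits on the last grid axis as mentioned above):

```python
# One program per fixed K/V block on axis 0, one per (batch, head) on axis 2
grid = (SEQ_LEN // BLOCK_SIZE_MACRO, 1, BATCH_SIZE * NUM_HEADS)

stage = 3 if causal else 1      # same convention as the forward pass

_attn_bwd_dk_dv[grid](
    Q=Q, K=K, V=V, softmax_scale=softmax_scale,
    dO=dO, dQ=dQ, dK=dK, dV=dV,
    M=M, D=D,
    stride_batch=Q.stride(0), stride_head=Q.stride(1),
    stride_seq=Q.stride(2), stride_dim=Q.stride(3),
    NUM_HEADS=NUM_HEADS, SEQ_LEN=SEQ_LEN,
    BLOCK_Q=BLOCK_SIZE_MICRO,    # the block we iterate over (queries)
    BLOCK_KV=BLOCK_SIZE_MACRO,   # the block we keep fixed (keys/values)
    HEAD_DIM=HEAD_DIM,
    STAGE=stage,
    num_warps=NUM_WARPS,
    num_stages=NUM_STAGES,
)
```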
Okay, let's write the signature of this function 06:35:33.840 |
and put it here. We already described what the signature of this function is, so let's go directly 06:35:41.680 |
to the meat so the first thing that we need to do is understand the offset by which we need to move 06:35:47.200 |
this query, key and value, and the offset is given as follows: first of all we need to enter the right batch 06:35:53.840 |
and the right head inside of each batch. We compute the index of the batch just like during the 06:35:58.720 |
forward pass, by dividing the program index, which is a combination of the index of 06:36:05.040 |
the head and of the batch: we divide it by the number of heads to get which batch this program 06:36:11.920 |
is working with, and to get the head we just do the modulus, just like in the forward pass. 06:36:17.040 |
the offset batch head indicates let me check what is it for okay it enters the right batch and the 06:36:24.640 |
right head so what is the stride if you remember correctly the stride tells us how many items you 06:36:30.000 |
need to skip in that dimension to arrive to the next index in the same dimension so if we want to 06:36:34.480 |
skip index number of batch we multiply it by the stride batch which is how many elements you need 06:36:41.200 |
to skip to arrive to the next batch plus we also need to enter the right head so we multiply the 06:36:48.160 |
index of the head multiplied by the stride of the head to enter exactly in that head in the tensor 06:36:54.560 |
for each of the Q, K and V matrices. Plus we also have this other offset that, if I remember correctly, will be used for 06:37:02.640 |
M and D, because M and D don't have the head dimension, so they are only 06:37:10.480 |
batch size, number of heads, sequence length. So we just use the batch-head index multiplied by the sequence 06:37:15.760 |
length, because for each batch and each head we will have sequence length number of items, so you 06:37:19.680 |
can think of it as the stride to move from one batch-head to the next batch-head. 06:37:33.040 |
So we move the pointers Q, K and V by the offset batch head, because we want to enter 06:37:43.360 |
the right um batch and the right head inside of these big tensors and we do it also for d o d q 06:37:50.960 |
d k and d v because they have the same shape as a q k and v and d o also has the same shape as 06:37:56.720 |
Q, so they have the same shape, so we move them by the same offset. All right, so 06:38:03.680 |
then we move M and D to the right starting point, where the sequence of the 06:38:11.440 |
current batch and the current head starts, so they are pointing to the first 06:38:18.000 |
element of the sequence dedicated to the current batch and the current head, 06:38:22.400 |
and the same is true for Q, K and V and for dO, dQ, dK and dV. 06:38:32.400 |
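In code, the pointer arithmetic at the top of this kernel body is roughly the following (a sketch consistent with the explanation above; the strides are the batch/head/sequence strides passed in from the host, assumed equal for Q, K and V):

```python
# Top of the kernel body (sketch): enter the right (batch, head) slice of every tensor
index_batch_head = tl.program_id(2)            # combined (batch, head) index
index_batch = index_batch_head // NUM_HEADS
index_head = index_batch_head % NUM_HEADS

# Q, K, V (and dO, dQ, dK, dV, which share the same shape) need both strides
offset_batch_head = index_batch * stride_batch + index_head * stride_head

# M and D have shape (BATCH, NUM_HEADS, SEQ_LEN): each (batch, head) owns exactly
# SEQ_LEN scalars, so SEQ_LEN acts as the stride from one batch-head to the next
offset_batch_head_seq = index_batch_head * SEQ_LEN

Q += offset_batch_head
K += offset_batch_head
V += offset_batch_head
dO += offset_batch_head
dQ += offset_batch_head
dK += offset_batch_head
dV += offset_batch_head

M += offset_batch_head_seq
D += offset_batch_head_seq
```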
Okay, then we load some other stuff, because in this method we are going to do a for loop in which 06:38:40.480 |
we fix K and V and we iterate through Q, so we first need to load this block of K and V, 06:38:47.040 |
and we do it as follows as follows so we know we need to load a 2d tensor so we need to define 06:38:56.160 |
what are the ranges in the second dimension of each vector k and v that we need to load 06:39:04.320 |
and it's defined by this by this vector then we want to understand which kv block this particular 06:39:14.960 |
program is going to work with so this particular program is going to skip some kvs that will 06:39:20.720 |
already be managed by other programs that may be running in parallel and how to understand 06:39:25.440 |
what this program should be working with in based on the index of the program zero which is defined 06:39:32.560 |
on the sequence length divided by the block size macro and if you remember block size macro is the 06:39:37.920 |
thing that we fix so it's telling us this program id zero will tell us how many block size macro 06:39:45.680 |
kv are already being managed by other programs so we shouldn't care about them so we skip them 06:39:52.080 |
so let's go back here and this is the number of vectors that we need to skip 06:39:57.280 |
so our kv start from start kv and how many we need to load them well depends on what is the 06:40:04.160 |
block kv this block kv is equal to block size macro so it will be 128 vectors so we define 06:40:14.480 |
our tensors two-dimensional tensors that we will store in the sram because in triton every time 06:40:22.080 |
you load something you load it from the hbm into the sram so we define where they should be saved 06:40:28.080 |
in the sram and they are initially zeros and now we load them so we load them as follows 06:40:34.000 |
we say that okay in the k in the k tensor pointer which is already pointing to the right 06:40:46.320 |
index to the right batch and to the right head because that's something that we did here 06:40:52.160 |
we say we should need we need to load the right sequence of keys which should start from 06:41:00.320 |
offski because this already includes how many we should skip in the sequence length dimension 06:41:06.000 |
and for each of these vectors we need to load all the dimensions in the 06:41:12.640 |
in the head dimension dimension because the k if i want to remind you is batch number of heads 06:41:21.120 |
sequence length and head dim now by using this line we are skipping to the right b and to the 06:41:32.800 |
right head so it's like we already indexed here and here we already selected an index so right 06:41:38.720 |
now this k is pointing to the beginning of a tensor of two dimension and we tell okay we don't 06:41:45.120 |
want all the sequence we want some part of this sequence which part the one that is indicated by 06:41:50.960 |
this start kv and how many of in the sequence length we want well we want uh all right i think 06:42:00.320 |
it's easy to write it like this so we can write it that from start kv to start kv plus block kv 06:42:08.480 |
uh so we want this number of tensor exactly at this location and for head dimension what do 06:42:16.480 |
we want to select we want to select all the dimensions so we say that we want from zero 06:42:21.280 |
to head dimension which is exactly this offskdim okay uh we do it for the k block and we do it for 06:42:31.760 |
the v block here i think i didn't change the comment this should be block kv and this should 06:42:40.320 |
be block kv before it was called the block kv1 right like in the original code i simplified 06:42:46.960 |
a little bit the naming i think this one is better easier to follow because in the original code they 06:42:51.360 |
also do for two for loops but in the second for loop they will do it backward just to not change 06:42:56.400 |
the structure of the loops but i think mine is more verbose but easier to understand and probably 06:43:03.120 |
less efficient mine is much less efficient um then we have offsq because we need to understand for 06:43:10.240 |
each block of queries how many vectors we need to load and it's indicated by this offsq and how many 06:43:18.000 |
are they? It's BLOCK_Q, and BLOCK_Q in the caller of this method was the micro block size, so it is 32 06:43:26.080 |
vectors. Okay, now we need to access the Q vectors — but already 06:43:38.640 |
transposed — and the dO vectors; we need to access them because we are going to iterate 06:43:46.000 |
through the queries and the dO vectors. Why? Let's look at the 06:43:54.720 |
formulas in the paper to compute vj so to compute the dvj that's what we are trying to compute here 06:44:03.360 |
we need to iterate through all the do vectors and to compute dk we need to iterate through all the 06:44:10.240 |
qi vectors because the qi is a block of vectors so that's why we need um and why do we need to 06:44:21.840 |
access Q as transposed? Because we need to compute — let me show you here — P_ij transposed. To 06:44:30.000 |
compute P_ij transposed we need Q transposed, because P_ij would be the softmax of 06:44:35.920 |
the query multiplied by the transpose of the keys; after we apply the softmax it becomes P, but if you 06:44:41.840 |
want the transpose of P then you need to do K multiplied by Q transposed. So 06:44:49.680 |
that's why we access the query transposed instead of the queries, and the way we access the query 06:44:55.600 |
transposed is just by playing with the stride so let's do it like this and i have also written 06:45:02.880 |
the comment on why we can do it so this is equivalent to accessing the query uh how many 06:45:12.080 |
first okay what is this um what is this operation uh what is this operation here 06:45:18.160 |
this is saying go to the query starting point starting um pointer to the query which is 06:45:26.560 |
already pointing to the right batch and to the right head for which this particular program 06:45:31.680 |
should work with and select a two-dimensional vector where you repeat the query starting point 06:45:38.720 |
along the in this case along the columns but we should be repeating it along the rows because we 06:45:44.160 |
want to select rows of queries however if we want to select the query transposed we just invert the 06:45:51.280 |
two dimensions so this is a let me actually show you without doing the query transposed so let's 06:45:56.720 |
do it simplified like this so to access the query um the query pointers without transposition we can 06:46:06.000 |
just do like this go to the query tensor and create a 2d tensor where in the rows you put 06:46:14.720 |
the starting point of each query that you want to get, and replicate each of these pointers 06:46:22.240 |
also along the columns: that's the meaning of adding this dimension None. This is equivalent to when 06:46:28.000 |
you do, in PyTorch, an unsqueeze, like calling offs_q.unsqueeze(1): 06:46:40.160 |
it is equivalent to adding the column dimension to this tensor and repeating all the 06:46:47.360 |
values that are on the rows over all the columns. How many columns will there be? 06:46:53.840 |
it will be broadcasted when we sum it with this tensor here this is a combination of 06:47:00.720 |
unsqueezing and broadcasting so we are taking the query vectors indicated by off skew 06:47:06.800 |
and then we are for each query vector we are selecting all the head dimensions indicated by 06:47:16.720 |
dim if you invert this broadcasting it will create the transposed of the the the query vector that 06:47:24.480 |
you are trying to access so this stuff here is equivalent to the these two lines so accessing 06:47:31.200 |
query and then transposing and uh it's something that you can do uh i could write down what is 06:47:39.680 |
happening at the pointer level so basically you need to think of off skew as being a vector of 06:47:46.640 |
pointers we multiplied by the sequence stride which tells us how many element we need to skip 06:47:54.880 |
to go from one query vector to the next because each stride q will be the stride will will be 06:48:02.240 |
equal to in the case the head dimension is 128 the stride of the sequence dimension will be 128 06:48:09.040 |
it means that to go from one query vector to the next you need to um you need to uh go forward by 06:48:17.040 |
128 elements because i want to remind you that in the memory the tensors are always stored like 06:48:23.440 |
flattened like each dimension is flattened with the next dimension so imagine you have three rows 06:48:31.040 |
and four columns but the first you will have the first three rows then the sorry the first row so 06:48:36.320 |
the first four columns then the next four columns then the next four columns row after row 06:48:41.040 |
it's difficult to visualize until you write it down so how to write it down take um create a 06:48:50.640 |
vector of offs_q. What is offs_q? At the beginning it is a range that goes from 06:48:58.720 |
0 to 32: 0, 1, 2, 3, 4, 5, 6, 7, etc. We are multiplying each one with the stride of the 06:49:11.280 |
sequence so this will not skip any element this will skip exactly 128 elements this will skip 06:49:18.000 |
exactly implying that the head dimension is 128 this will skip two times 128 elements this will 06:49:25.760 |
skip three times 128 elements and then we are adding also another dimension to this vector 06:49:33.920 |
so this will be a vector then you broadcast it on head dimension number of columns and to each of 06:49:40.400 |
them you add one number so it will become a vector like for okay let me just do it guys otherwise i 06:49:46.960 |
think it's too confusing okay so we have a vector that is as follows so zero then we have 128 then 06:49:56.000 |
we have two times 128 then we have three times 128 etc etc we are adding how many columns 06:50:03.120 |
indicated by off dim so off dim has how many columns so it has a head dim number of columns 06:50:08.480 |
please for simplicity let's pretend it's not 128 dimensions let's pretend it's four dimensions so 06:50:14.800 |
this will be four this will be two times four this will be three times four we are adding another 06:50:22.320 |
dimension that is the dim dimension each one multiplied by the stride of dim which will be 06:50:28.720 |
one because it's the last dimension stride dim so we are adding how many columns four 06:50:38.480 |
so we are adding 0, 1, 2, 3 to this one, 06:50:45.920 |
we are adding 0, 1, 2, 3 also to this one, 06:50:56.720 |
and then also to this one we are adding 0, 1, 2, 3. So what will this select? 06:51:05.280 |
this will select from the starting point of the pointer q it will select the element zero 06:51:12.160 |
then the element one then the element two and then the element three which is exactly the 06:51:18.240 |
head dimension of the first vector that we should be selecting then it will select 06:51:23.280 |
the element four from the starting point of the vector — sorry, let me write 06:51:29.760 |
the result of this operation. So this one will be 0, 1, 2, 3, then it will select the elements 06:51:34.720 |
4, 5, 6, 7, then it will select the elements 8, 9, 10, 11, 06:51:41.600 |
and then it will select the element 12 13 14 15 so from the starting point of where this q is 06:51:52.480 |
pointing it will select the first element right after this q the second element right after this 06:51:57.840 |
q the third element right after this q etc etc and this will be the you can see that this will be the 06:52:03.600 |
first query vector this will be the second query vector this will be the third query vector this is 06:52:07.920 |
the fourth query vector because in the memory they are stored one after another they are flattened 06:52:13.200 |
so in the memory they are stored like this they are stored like the following they are stored 06:52:19.600 |
like this one after another so it will select all of them and it also create a virtual tensor with 06:52:25.920 |
the right shape that we want to visualize it into so as you saw as we saw before when you work with 06:52:32.480 |
a tensor layout in memory you can always view it as whatever shape you like based on the shape that 06:52:38.400 |
you want and the reshaping is always free doesn't involve changing the arrangement of the elements 06:52:44.960 |
in the memory i hope now it's more clear so now we can proceed further oh my god it was quite 06:52:51.200 |
complicated so whenever i get stuck i just draw things and i think you should do it too because 06:52:56.320 |
that's the only way to learn if you try to imagine everything in your head it's always difficult 06:53:02.080 |
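Side by side, the two ways of building the pointer block look like this (a small sketch; offs_q, offs_dim and the strides are the same quantities used in the drawing above):

```python
# Non-transposed access: a (BLOCK_Q, HEAD_DIM) tile of Q
# (rows indexed by offs_q, columns indexed by offs_dim)
q_ptrs  = Q + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim

# Transposed access: swap which offset lives on rows and which on columns,
# giving a (HEAD_DIM, BLOCK_Q) tile, i.e. Q^T, without moving any data in memory
qT_ptrs = Q + offs_q[None, :] * stride_seq + offs_dim[:, None] * stride_dim

qT_block = tl.load(qT_ptrs)
```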
and we do the same job for the o vectors so in the o vectors we don't access it as 06:53:08.320 |
access it as a transpose because we don't need it in transpose only the q we need it in transposed 06:53:14.240 |
okay it traced through the sequence dimension of the query so we start from the query number zero 06:53:21.280 |
in the current um well in the query we need to go through the all the sequence length dimension 06:53:27.520 |
because only the key we select the right key that we want to work with so i want to remind you here 06:53:32.720 |
we fix the key and we go through all the queries but the query we need to start from zero until 06:53:38.080 |
sequence length so the number of steps of this for loop will be sequence length divided by block q 06:53:44.800 |
so if we have 1000 elements in the sequence and BLOCK_Q is 32, it will be 1000 divided by 32 — 06:53:54.320 |
1000 is a bad choice, it should be 1024, otherwise it's not divisible. Then we go through each 06:54:01.440 |
block in this for loop and we load a block of q the first one indicated by our pointer 06:54:07.360 |
and at the end of the iteration we will move it to the next to the next block of q 06:54:12.160 |
okay we'll add also the log sum exp values that are stored in the m matrix 06:54:20.800 |
because we want to compute on the fly pt pt is the transposed of the softmax of query multiplied 06:54:29.520 |
by the keys but we want to not take a query multiply by the transpose of the key and then 06:54:34.160 |
do the transpose we just already access q as a transposed so we can already compute the pt instead 06:54:40.400 |
of computing p and then transposing it um so we load the offsets of the elements that we need 06:54:49.200 |
from this log sum exp matrix which is the m matrix that we computed during the forward pass 06:54:57.040 |
and we access a block of q at a time the one we are currently working with in the iteration 06:55:05.440 |
then we access a query key transposed already so we do the if you want to get the pt 06:55:17.440 |
(this is actually not P, because we didn't apply the softmax yet — it's actually S^T, but okay): 06:55:24.800 |
if you want to get P^T you need to take the softmax of S^T. What is S^T? It's 06:55:34.560 |
the transpose of S, and S is the query multiplied by the transpose of the key, so to 06:55:39.440 |
get S^T you need to do the key multiplied by the query transposed. As you remember, 06:55:45.200 |
in matrix multiplication, if you transpose a matrix multiplication you need to also invert the 06:55:51.520 |
two operands of the matrix multiplication, so that's why we are doing key multiplied by query 06:55:56.720 |
transposed this will give us s transposed we are also scaling it with the softmax scale 06:56:03.440 |
before we apply the to apply the softmax we just need to do the exponential of each element minus 06:56:10.640 |
its maximum divide by the normalization value but with the log sum extract we just need to 06:56:16.000 |
each element subtracted by the m value which already includes the normalization factor 06:56:22.640 |
i think i already did the derivation of this so we don't need to go through that again 06:56:28.000 |
okay so now we have the pt block actually so in this formula i should have written st actually 06:56:35.920 |
okay then when doing the causal attention we also need to mask out some values 06:56:47.760 |
so as you can see here so in this case the causal mask is applied after the softmax has 06:56:56.480 |
been computed. Normally you are used to applying the 06:57:03.200 |
causal mask before computing the softmax, but that is during the forward pass, 06:57:08.960 |
because you don't want the normalization factor to be affected by the elements that should be zero. 06:57:14.240 |
Here we already computed the normalization factor, so it cannot be affected anymore, so 06:57:20.880 |
we can mask out after applying the softmax, because the normalization factor has already 06:57:25.200 |
been calculated taking the mask into account, and that's why we can apply it after 06:57:30.880 |
applying the softmax. So the mask is always the same: if the index of the query is greater than or equal to the 06:57:38.000 |
index of the key, the mask is true. It is true, in this case, for all the values that do not need to be 06:57:44.800 |
masked, so all the values that do not need to be masked are these ones here, and all the other 06:57:50.640 |
values will be replaced with zeros. 06:58:00.640 |
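Concretely, the part of the inner loop we have written so far looks roughly like this (a sketch; variable names mirror the walkthrough, curr_q and the pointers are advanced at the end of each iteration, and the mask branch only exists in the causal case, STAGE == 3):

```python
for _ in range(num_steps):
    # Current block of Q, loaded already transposed: shape (HEAD_DIM, BLOCK_Q)
    qT_block = tl.load(qT_ptrs)

    # Log-sum-exp values (max + log of normalizer) of these Q rows, from the forward pass
    offs_q = curr_q + tl.arange(0, BLOCK_Q)
    m = tl.load(M + offs_q)

    # S^T = (Q K^T)^T = K Q^T, then P^T via the log-sum-exp trick:
    # exp(S^T - m) subtracts the row max and divides by the normalizer in one shot
    qkT_block = softmax_scale * tl.dot(K_block, qT_block)   # (BLOCK_KV, BLOCK_Q)
    pT_block = tl.exp(qkT_block - m[None, :])

    if STAGE == 3:
        # Causal attention: zero out positions where key index > query index.
        # The normalizer inside m already accounts for the mask (it was applied
        # during the forward pass), so masking after the exponential is fine here.
        mask_block = offs_q[None, :] >= offs_kv[:, None]
        pT_block = tl.where(mask_block, pT_block, 0.0)
    # ... dV and dK updates follow, then curr_q and the pointers are advanced ...
```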
All right, so after we have the P^T block already masked, we can calculate dV, and I will point to the right formula in the paper. Why do we load a 06:58:07.680 |
block of dO? Let's look at the paper, at how to compute the dV block: 06:58:14.480 |
so the dv block is computed as the old dv plus so a repeated sum as you can see as you can see 06:58:23.520 |
it's here plus equal the old dv plus pt so here pt dropped indicates the pij after applying the 06:58:34.960 |
dropout in this implementation we don't support the dropout and also very few models actually 06:58:40.320 |
use the dropout in the attention. So P^T multiplied by dO_i, so a block of dO_i — and dO_i 06:58:50.720 |
and Q_i are always referring to the same block of rows in 06:58:58.960 |
the respective tensors that's why because this inner iteration i indicates a block of q and a 06:59:08.320 |
block of o but we are always referring to the same positions in the tensors because do has the same 06:59:15.520 |
shape as dq so we go through the blocks of query and the do simultaneously because one is needed 06:59:23.920 |
for dv so for dv we need do and for dk we need q and that's why we compute the dv as follows 06:59:32.960 |
just like from the paper so pt block multiplied by do as you can see it's a p transpose multiplied 06:59:38.800 |
by the o block so we have computed computed the do block then we need to load the di element that 06:59:47.600 |
we computed pre-computed initially with the first call to the function called the attention 06:59:56.000 |
backward pre-process because we will need it for dk so let's see and how many of them we are loading 07:00:06.800 |
exactly the same number of query that we load because they are we load always the same number 07:00:13.200 |
of block size micro number of vectors okay i will copy some stuff and explain it step by step so 07:00:21.520 |
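As a quick sanity check of the dV formula we just used (plain PyTorch, toy sizes, no dropout): with O = P V, autograd gives exactly dV = PT dO.

```python
import torch

torch.manual_seed(0)
N, d = 8, 4
P = torch.softmax(torch.randn(N, N), dim=-1)      # a fixed attention matrix for the check
V = torch.randn(N, d, requires_grad=True)
dO = torch.randn(N, d)

O = P @ V
O.backward(dO)

dV = P.T @ dO                                     # the paper's formula, without dropout
print(torch.allclose(dV, V.grad, atol=1e-6))      # True
```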
Okay, I will copy some code and explain it step by step. The next operation we need to do is to compute dK. To compute dK we need dST, 07:00:30.240 |
and to compute dST we need to get dPT. So let's go one by one, from 07:00:39.200 |
the end to the beginning of these formulas, so that we understand where everything is used and where 07:00:47.280 |
everything is created. Let's start from dK: if you look at the paper, dK is equal to the old dK 07:00:55.120 |
plus dS transposed multiplied by a block of Q, and this is what is written here. The 07:01:03.360 |
plus-equal basically means the old value plus the new one, an incremental addition: 07:01:10.240 |
increment the old dK with some new stuff, which is this stuff here, the softmax scale 07:01:17.200 |
(because there is also a softmax scale, this tau here) multiplied by the matrix multiplication 07:01:23.440 |
between the dST block and Q. You can see this Q here, but we don't 07:01:34.960 |
have Q, we have Q transposed, so we take the transpose of Q transposed and it becomes Q again. 07:01:40.560 |
Now let's look at this dST block. dST is calculated as follows: in the formula of the paper we have 07:01:48.720 |
dS, it is here, and it is equal to the block Pij multiplied element-wise 07:01:59.280 |
with dPi minus Di. Now, we don't need dS, we need dS transposed. To compute dS transposed, note this is 07:02:10.000 |
an element-wise multiplication, not a matrix multiplication, which means that when you take 07:02:14.800 |
the transpose of this operation you don't need to swap anything: you just take the 07:02:18.480 |
transpose of the two operands. So to compute dST we take the transpose of P, which is PT and 07:02:25.920 |
we already have that, and then the transpose of everything inside the parentheses, 07:02:31.280 |
that is dPT minus Di, where we swapped the rows with the columns. Now, what is this dPT? Well, in the paper 07:02:40.480 |
we know the formula for dP: dP is here, and it is equal to dO 07:02:51.280 |
multiplied by V transposed. But we don't need dP, we need dPT, and in this case it's not 07:02:58.080 |
an element-wise multiplication, it is a matrix multiplication. So in order to get not 07:03:05.760 |
dP but dPT, we need to take the transpose of the two operands of this matrix multiplication, and 07:03:10.960 |
in a matrix multiplication, when you take the transpose, you also need to invert the order of 07:03:15.280 |
the two operands. So we take VT transposed, which becomes V, the V block, 07:03:22.400 |
matrix multiplied by the other operand, dOi transposed, and that's why we are taking the 07:03:29.120 |
transpose of dO right now. I'm not going through every single pointer, because I already told 07:03:36.240 |
you how to check what a pointer is pointing to and what an offset is referring to. I hope that 07:03:42.400 |
by now you have a better understanding of how these pointers work in Triton, which is also the same way 07:03:48.480 |
in which they work in CUDA, because in the GPU we only get a pointer to the 07:03:55.280 |
starting address of the tensor, and then we need to work out all these indices ourselves. 07:03:59.040 |
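To convince ourselves the whole chain is right, here is the same computation done in plain PyTorch on toy tensors and checked against autograd, working with the transposed blocks like the kernel does: dPT = V dOT, dST = PT element-wise times (dPT minus D), dK = tau times dST Q.

```python
import torch

torch.manual_seed(0)
N, d = 8, 4
tau = 1.0 / d**0.5
Q = torch.randn(N, d)
K = torch.randn(N, d, requires_grad=True)
V = torch.randn(N, d)

S = tau * (Q @ K.T)
P = torch.softmax(S, dim=-1)
O = P @ V
dO = torch.randn(N, d)
O.backward(dO)

D = (dO * O).sum(dim=-1)            # the D_i values, one per query row (the preprocess step)
dP_T = V @ dO.T                     # transpose of dP = dO V^T: swap and transpose the operands
dS_T = P.T * (dP_T - D[None, :])    # element-wise product: transpose each operand, D broadcasts per query
dK = tau * (dS_T @ Q)               # dK = tau * dS^T Q

print(torch.allclose(dK, K.grad, atol=1e-5))   # True
```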
We have computed the dK block, so now we go to the next block of queries, 07:04:07.440 |
the next block of queries because we are keeping the K and V blocks fixed and we are iterating 07:04:16.960 |
through all the queries. So we need to advance the query-transposed pointers 07:04:22.320 |
by the sequence stride, which is what tells us how to go from one query to the next, 07:04:27.200 |
multiplied by the query block size, so that they point 07:04:34.080 |
to the next block of Q that we are going to access; and we do the same for dO, using the same block size 07:04:39.840 |
and its sequence stride, because dO and Q have the same shape. Okay, after we have run the 07:04:48.480 |
for loop over all the queries, we can store the dK and dV blocks, so we write them back as follows, 07:04:56.080 |
and this is the end of our function, guys. We save the dV block exactly in the right position: 07:05:05.520 |
dV is, I believe, already pointing to the right 07:05:10.640 |
batch and to the right head, because we incremented it earlier, and the same holds for dK. 07:05:16.400 |
Then we need to tell it, in the sequence dimension, where it should save this 07:05:20.320 |
block of dK and dV, and this is indicated here: we create the pointers just 07:05:28.240 |
like before. Guys, don't make me do it again, it's really easy if you write it down. You write 07:05:35.120 |
this vector of key and value offsets (not really pointers: they are a range of the 07:05:45.920 |
key and value positions that you need to take from the sequence dimension), you add another dimension, 07:05:52.320 |
the column, so each value is repeated along the columns, and then you add the dimension here 07:05:58.080 |
for the head dimension. Anyway, after we have calculated the pointers where we should store 07:06:03.280 |
dK and dV, we store them, 07:06:12.000 |
I mean, we store them into the dV tensor and the dK tensor. What do we save? We save the dV block and the 07:06:18.480 |
dK block, which are the ones we were incrementally updating in the for loop that we have written. 07:06:25.840 |
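Here is a toy NumPy illustration (hypothetical sizes, not the kernel's actual constants) of how such a 2D grid of pointers is built from a range of sequence positions and a range of head-dimension positions:

```python
import numpy as np

HEAD_DIM = 4
BLOCK_KV = 2
stride_seq, stride_dim = HEAD_DIM, 1          # row-major (SEQ_LEN, HEAD_DIM) layout

start_kv = 4                                  # say this program handles sequence positions 4 and 5
offs_kv = start_kv + np.arange(BLOCK_KV)      # [4, 5]
offs_dim = np.arange(HEAD_DIM)                # [0, 1, 2, 3]

# (BLOCK_KV, HEAD_DIM) grid of flat offsets from the base pointer of this (batch, head)
ptrs = offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim
print(ptrs)
# [[16 17 18 19]
#  [20 21 22 23]]
```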
okay now that we finished this one we can go to the next function that will do the other for loop 07:06:32.400 |
so let's do it okay so now we do the second part of the iteration which is this one so let me just 07:06:40.160 |
copy it and then describe it. Let's write it here. Okay, we use the same launch grid as before; 07:06:49.600 |
of course we need to declare this function. Again, the grid is defined over 07:06:55.200 |
BLOCK_SIZE_MACRO, which is the thing that we keep fixed, and then inside the 07:07:01.520 |
for iteration we take steps of BLOCK_SIZE_MICRO. In this case we are fixing Q and we are iterating 07:07:09.840 |
through K and V, because now we need to compute dQ; so far we have computed dK and dV. 07:07:16.320 |
Okay, I believe the arguments are the same as before. And actually, this is also the reason 07:07:23.680 |
why in the original implementation on the Triton website the author decided to use the same 07:07:30.880 |
for loop but with different arguments, which I found a little confusing; that's why I just 07:07:36.720 |
separated them and repeated the code twice. The goal of this video is to be as easy 07:07:42.240 |
to understand as possible, not to be as efficient as possible. So let's go here and let me copy 07:07:50.400 |
the signature again and define this function here. Okay, so again we first need to move the 07:08:02.720 |
query, key and value pointers to the right place, so that they point to the exact batch and the exact 07:08:08.640 |
head that we are working with in this program. So let's do it, let me check where the code is, 07:08:17.520 |
and the first part is exactly the same as the other for loop that we have written, 07:08:21.840 |
so let's go here; really, I just copied it, it's exactly the same. We check what the 07:08:30.720 |
batch-head index is, we move the query, key and value pointers to the right place, the dO, dQ, dK, dV 07:08:36.480 |
pointers to the right place, and M and D to the right place, exactly like before, so I don't think I need 07:08:41.840 |
to explain that again and then we load a block of q the one that we will keep fixed 07:08:58.800 |
Okay, we define the offsets that we will need to load the blocks of K and V in the head dimension. 07:09:05.840 |
We are going to iterate over K and V, and we will access them as transposed blocks: instead 07:09:12.880 |
of accessing them directly as K and V, we access them as KT and VT, and you know that's 07:09:20.240 |
possible just by changing how we combine the strides. Because we are treating them as 2D blocks, we 07:09:26.960 |
treat offs_kv differently: when you want to access K not transposed, you treat offs_kv 07:09:34.880 |
as a column vector, so each KV offset that you want to 07:09:43.680 |
access is repeated along the columns; here instead we are treating it as a row vector, so it will be 07:09:49.280 |
broadcast along the rows, and that's how you can 07:09:57.520 |
access the transposed version of K and the transposed version of V. 07:10:03.760 |
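The same trick in a few lines of NumPy (again with toy sizes): swapping which range plays the row index and which plays the column index in the pointer arithmetic gives you the transposed block without moving any data.

```python
import numpy as np

SEQ_LEN, HEAD_DIM = 6, 4
K = np.arange(SEQ_LEN * HEAD_DIM).reshape(SEQ_LEN, HEAD_DIM)   # row-major toy tensor
flat = K.ravel()                                               # how the GPU really sees it
stride_seq, stride_dim = HEAD_DIM, 1

offs_kv = np.arange(2)              # the first two key vectors
offs_dim = np.arange(HEAD_DIM)

# K block (BLOCK_KV, HEAD_DIM): offs_kv as a column vector
k_block = flat[offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim]
# K^T block (HEAD_DIM, BLOCK_KV): offs_kv as a row vector
kT_block = flat[offs_dim[:, None] * stride_dim + offs_kv[None, :] * stride_seq]

print(np.array_equal(kT_block, k_block.T))   # True
```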
Another thing we are doing is loading the Q block. Which block? Well, it is based on offs_q, 07:10:09.920 |
which is based on start_q, which is based on the 07:10:16.240 |
exact starting point that this particular program should be working from, because this 07:10:20.640 |
particular program works along two dimensions: the first dimension indicates which batch and which 07:10:26.720 |
head this program should be working with, and the other dimension, which is program index number 07:10:31.440 |
zero, indicates which queries, among the whole sequence length, this particular program is going to 07:10:37.120 |
work with. This is indicated by this index, which should actually be named after Q; I forgot 07:10:43.840 |
to change the name, so let me change it to index_block_q. We skip some queries: how 07:10:50.000 |
many queries we skip is the index of the current program multiplied by how many queries each of the 07:10:57.600 |
previous programs has already processed (the block size). This tells us, inside the sequence length, which 07:11:04.400 |
queries this program needs to select, and that's why we use the start query plus the range that 07:11:10.320 |
is block q so imagine the starting query for this program among all the sequence length is 100 07:11:16.320 |
then this will load the query row 100 101 102 until 100 plus block q minus 1 07:11:24.640 |
This is the range of query vectors that we will load in this program. 07:11:32.720 |
We load the block of Q by using the Q pointer plus the query offsets treated as a 07:11:41.520 |
column vector, broadcast along the columns, plus a row vector where each column is one head- 07:11:51.200 |
dimension index multiplied by its stride. In this case we could actually skip multiplying by that stride, 07:11:55.680 |
because the stride of the last dimension is one: 07:12:01.680 |
by how strides are defined, the stride of the last dimension 07:12:10.000 |
is always one, because to go from one element to the next you only need to move by one element. 07:12:16.720 |
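A one-liner check of this claim about strides, on a toy contiguous tensor:

```python
import torch

x = torch.zeros(2, 3, 4)   # a toy contiguous tensor
print(x.stride())          # (12, 4, 1): one step along the last dim is one element in memory
```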
So we load dQ, which is the thing that we are going to compute in this iteration, 07:12:27.040 |
and then we have dO to load, and for dO we use the same offsets as Q, because dO and 07:12:33.680 |
dQ have the same shape and work the same way. So we load a block of Q and we load the 07:12:40.080 |
corresponding block of dO (dO has the same shape as O, which has the same 07:12:47.440 |
shape as Q). We also need to load the M normalization factors, which are in the M matrix; 07:12:54.400 |
which ones? The ones corresponding to this particular group of queries that we are 07:12:58.800 |
going to work with in this particular program. We start with the offsets of 07:13:06.800 |
the first block of KV, starting from position zero, because we will iterate through 07:13:13.680 |
all the KVs and we start from KV zero, so the key vector zero and the V vector zero, and then 07:13:21.040 |
we move forward by BLOCK_KV vectors at each iteration. I hope I didn't go too fast, because 07:13:30.080 |
most of the things written here are very similar to what we have already done 07:13:34.240 |
in the other for loop, so I don't want to repeat myself too much. What really matters 07:13:41.840 |
is the formulas that we will use, which are exactly the ones in the paper. So we go through 07:13:48.880 |
these blocks of KV and load the first block of K transposed and V transposed, which is loaded like 07:13:59.200 |
this: as usual, you tell Triton the pointers of the elements you want to load, and 07:14:04.320 |
it will load the block you are 07:14:09.360 |
asking for into the SRAM. So all this stuff resides in the SRAM, and Q 07:14:15.520 |
also resides in the SRAM, and dO also resides in the SRAM. Then we compute the query multiplied by the 07:14:23.840 |
transpose of the keys, because we need to compute the P block: the QK block is just the 07:14:31.520 |
current query block multiplied by the transpose of the 07:14:36.880 |
current key block. We access the keys already as transposed, so we don't need 07:14:45.360 |
to transpose them, and anyway, even if we did need to transpose, it 07:14:53.440 |
doesn't require any computation to transpose a matrix: we just access it in a different way, 07:14:58.480 |
because in the memory layout it's always stored as a kind of flattened array. 07:15:06.160 |
Then we compute the P block, which is the output of the softmax: from each query-key score 07:15:10.880 |
we subtract the logsumexp value for this block of queries; that's why, for loading the M 07:15:19.680 |
block, we use the offsets of the queries that we are loading. As you remember, the M block already 07:15:26.400 |
includes the normalization factor, because each M is actually the maximum value of its 07:15:32.160 |
row plus the logarithm of the normalization factor, which, by the properties 07:15:36.800 |
of the exponential, ends up in the denominator. Okay, and then we apply again the autoregressive 07:15:44.240 |
masking. Let me go back to the code here: we have the stage, this one. 07:15:59.520 |
So when we launch the backward pass, stage equal to three indicates that in the forward pass we 07:16:05.760 |
computed causal attention, and one indicates that we computed non-causal attention. If we 07:16:13.520 |
computed causal attention in the forward pass, we also need to mask out these elements in the 07:16:19.120 |
backward pass. So we create the mask, which is true 07:16:28.720 |
only for the elements for which the query index is greater than or equal to the key index: if this is true 07:16:36.400 |
then we don't mask, otherwise we mask. Let's compute the next operation, which is to compute dP 07:16:45.600 |
and dS. Actually, let's compute dQ directly and then explain it like before: we start from 07:16:51.280 |
the end and work backwards to see what is needed to compute it. So if you look at the formula, 07:16:58.160 |
let me check, this one we don't need I think, okay, let's go here to the iPad. What we are 07:17:13.360 |
trying to compute here is dQ. As you can see in the paper, dQ is equal to the old dQ 07:17:23.920 |
plus tau, which is the softmax scale (this stuff here), multiplied by 07:17:31.680 |
the matrix multiplication between the dS block and the K block. The dS block is here, and the K 07:17:41.040 |
block is the transpose of the KT block, because we are accessing K already as a transposed block. 07:17:47.840 |
We could also access K directly as a non-transposed block by swapping the indexing: if you don't 07:17:54.000 |
want to access it as a transposed block, you change where the None goes, like here, so that the KV offsets are treated as a 07:18:01.920 |
column vector and broadcast along the columns, while the head-dimension offsets 07:18:08.240 |
are treated as 07:18:14.720 |
a row vector; if instead you want to access it as K transposed, you just 07:18:20.080 |
swap these two indexings. I hope I didn't mess up anything, so let's move forward. Okay, we 07:18:27.840 |
know that the formula for dQ is exactly the same as the one in the paper, but what is this dS 07:18:35.760 |
block? Let's look at the paper: this dS block is coming from this stuff here, 07:18:43.360 |
dS, which is Pij, the P block, element-wise multiplied by dPi minus 07:18:53.040 |
Di. Now, what is this P block? The P block is exactly the output of 07:19:02.160 |
the softmax, which we already have. What is the dP block? Well, the dP block is exactly 07:19:08.000 |
dO multiplied by V transposed: dO we already loaded, and it's here 07:19:14.800 |
multiplied by the transpose of V, which we already loaded as transposed. And this is how we compute dQ. 07:19:20.080 |
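The same kind of autograd check works for the dQ chain we just assembled (plain PyTorch, toy sizes): dP = dO VT, dS = P element-wise times (dP minus D), dQ = tau times dS K.

```python
import torch

torch.manual_seed(0)
N, d = 8, 4
tau = 1.0 / d**0.5
Q = torch.randn(N, d, requires_grad=True)
K = torch.randn(N, d)
V = torch.randn(N, d)

S = tau * (Q @ K.T)
P = torch.softmax(S, dim=-1)
O = P @ V
dO = torch.randn(N, d)
O.backward(dO)

D = (dO * O).sum(dim=-1, keepdim=True)   # D_i from the preprocess step, kept as a column for broadcasting
dP = dO @ V.T
dS = P * (dP - D)
dQ = tau * (dS @ K)

print(torch.allclose(dQ, Q.grad, atol=1e-5))   # True
```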
Then, of course, we need to move to the next block of keys, so we 07:19:29.280 |
increment the pointers just like before, moving to the next block of keys and values, 07:19:37.760 |
and we also advance them just like before. Then we need to store the result of dQ, and 07:19:45.680 |
this way we only need to do one write to the HBM, by splitting the for loop as we did. 07:19:52.400 |
If you look at the original algorithm (I don't know whether the original algorithm actually 07:19:58.560 |
corresponds to the implementation they did in CUDA, but I don't think so, because it would 07:20:04.400 |
not be so optimized), in the original algorithm in the paper they say that you need to go through 07:20:10.800 |
all the keys, and while going through the keys you need to go through all the queries, and for 07:20:14.960 |
each query block that you visit you need to write back the dQ block while you are updating it, which is 07:20:19.760 |
not optimized. That's why we needed two for loops: one in which we fix the keys and values and iterate through 07:20:26.160 |
the queries, because each key's update depends on all the blocks 07:20:32.720 |
of Q, and then one in which we fix the queries and iterate through all the keys, because one block of 07:20:37.680 |
dQ depends on all the blocks of K. This is why we split it, and this is the second loop that 07:20:43.680 |
we have written. Now we have written everything we needed for Flash Attention: the 07:20:50.800 |
forward pass and the backward pass, so we should be ready to launch the kernel. I hope I 07:20:58.480 |
didn't make any mistake in copying the code; I don't think so. I will try to launch it, and if there 07:21:04.480 |
is any error I will just use my reference code, which I have already written and used as a 07:21:08.720 |
source to copy from. The only difference up to now between my reference code and the one we have written 07:21:14.640 |
is the autotuning, which I didn't explain, so let's talk about it. The autotuning is 07:21:20.240 |
also something that was already present in the original implementation, and I kept it as is. I removed 07:21:25.120 |
the autotuning for the backward pass, but in the forward pass, if you check, there is this code 07:21:32.160 |
here that defines the autotuning configuration for Triton. Triton basically cannot know 07:21:40.560 |
beforehand what the best block size is: the best block size for the query, the 07:21:45.840 |
best block size for the key and values, or for any other dimension that we 07:21:51.520 |
have. We need to try, based on the hardware we are running on, on the available 07:21:57.840 |
SRAM, and on the thread coarsening that Triton can apply. I didn't talk about thread 07:22:03.920 |
coarsening yet: basically, in CUDA you can choose whether each thread does one elementary operation (for example, 07:22:10.000 |
in a matrix addition, each thread doing the addition of one particular element of the output 07:22:15.840 |
matrix) or whether it manages multiple elements; this is called thread coarsening. I didn't 07:22:21.120 |
check the documentation, but I believe Triton does it for you based on the block size that 07:22:26.320 |
you give it and the number of warps that you want (a warp is a group 07:22:33.040 |
of 32 threads that work cooperatively, always running the same instruction at the same time). 07:22:40.560 |
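For reference, the autotuning setup looks roughly like the sketch below (block sizes and the key are illustrative, loosely modeled on the original Triton attention tutorial; the real configuration in the reference code may differ):

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_Q": bq, "BLOCK_SIZE_KV": bkv}, num_stages=stages, num_warps=warps)
        for bq in [64, 128]
        for bkv in [32, 64]
        for stages in [3, 4]
        for warps in [4, 8]
    ],
    key=["SEQ_LEN", "HEAD_DIM"],  # redo the search whenever one of these changes
)
@triton.jit
def _attn_fwd_sketch(Q, K, V, O, softmax_scale, SEQ_LEN,
                     HEAD_DIM: tl.constexpr,
                     BLOCK_SIZE_Q: tl.constexpr, BLOCK_SIZE_KV: tl.constexpr):
    pass  # the real forward-kernel body would go here
```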
The number of stages is more interesting: it's an optimization that Triton does. 07:22:48.480 |
Basically (it is not loop unrolling), let's talk about software 07:22:55.360 |
pipelining, because this is the last part that we need to understand from this code, which is 07:22:58.640 |
the autotuning. I believe the most interesting part here is not choosing the 07:23:02.720 |
block size for Q and the block size for K, because for that you just try whatever 07:23:07.920 |
configuration works best based on the timing: through CUDA, Triton will actually run all these 07:23:14.000 |
configurations for you every time the sequence length or the head dimension changes, and for 07:23:18.160 |
every pair of head dimension and sequence length it will choose the best configuration, the one that runs in 07:23:23.360 |
the least amount of time and gives you the best throughput. So let's look at this 07:23:30.240 |
numstages what is it and how it works so let's do it okay so software pipelining is 07:23:38.880 |
it's used when you have a kind of a for loop so you have a sequential operation 07:23:44.800 |
in which each iteration does not depend on the previous iteration so the operations that 07:23:49.280 |
you're doing in one iteration are independent from what you have done in the previous iteration 07:23:52.480 |
which is more or less what we have done before in our for loops. Actually, I believe there are 07:23:58.400 |
conditions under which this doesn't have to be strictly true, so 07:24:03.920 |
the operations can depend on each other and you can still do software pipelining. 07:24:09.600 |
For example, imagine you have the following for loop, a loop that goes from 1 to N, where 07:24:16.880 |
first you load some data, then you load some other data, then you do a matrix multiplication, and then 07:24:21.280 |
you store some data so here you are reading data here you are reading data here you are 07:24:26.240 |
computing some stuff and here you are writing data if we look at what happens at each iteration 07:24:33.280 |
we will see the following picture imagine our gpu is made up of a compute unit and a unit that is 07:24:40.160 |
dedicated to loading stuff so reading from the memory or writing to the memory what we will see 07:24:46.080 |
in the time scale is that at the first iteration first we are reading some data and the compute 07:24:52.000 |
unit is idle because we need this data then we are reading some more data and the compute unit 07:24:56.480 |
is idle because we need this data then finally we have enough data and then we can compute this 07:25:00.960 |
operation and the reading unit is idle and then we are writing some data back to the memory and 07:25:06.160 |
the compute unit is again idle and then it will be idle for another two time steps until it has 07:25:12.000 |
enough data to run the computation so as you can see this is not very efficient because at any 07:25:18.560 |
time a point in time there is only one unit working and the other is sitting idle 07:25:24.480 |
so one way to optimize this for loop is to do software pipelining and you can tell triton to 07:25:30.480 |
do it for your for loops by telling it how many stages you want so let's see how it works so to 07:25:36.320 |
pipeline a for loop means that first of all you need to convert all these operations into async 07:25:42.480 |
operations. In CUDA, at least on NVIDIA GPUs, there are async loads from 07:25:48.560 |
memory and async writes to memory, which basically means that I spawn a load operation 07:25:55.120 |
and I only check whether it has completed when I actually need the data. So I will 07:26:02.480 |
spawn this operation, and this instruction will return immediately and move to the next instruction; 07:26:07.840 |
here I will spawn another load operation, and this will return immediately and move to the next instruction, 07:26:13.280 |
and then i can compute but before computing i just check if these two operations have completed so i 07:26:19.040 |
can spawn two reads immediately and then just check whether they have completed. So with 07:26:26.560 |
software pipelining what we are doing is we are pipelining operations of different iterations 07:26:32.640 |
into a single iteration. So first, basically, we read the first 07:26:38.720 |
matrix that we need for computing this matrix multiplication. Then, at the next iteration, we read 07:26:45.200 |
the first matrix of the second iteration and also read the second matrix of the 07:26:55.040 |
first iteration. I call these "read A" and "read B", where A means reading the first matrix 07:27:02.080 |
that we need and B means reading the second matrix that we need; all these operations are 07:27:08.880 |
asynchronous. Then I launch more asynchronous operations at the third iteration, which say: 07:27:14.160 |
read the first matrix of the third iteration, read the second matrix of 07:27:23.040 |
the second iteration, and then compute the matrix multiplication, because at the third 07:27:31.040 |
iteration this one and this one should have completed. But while computing the matrix 07:27:37.200 |
multiplication, I don't keep the loading unit idle, because it is still performing this load 07:27:43.600 |
and this load. This can only work if you can spawn async operations. So at the third iteration I can 07:27:51.600 |
compute this matrix multiplication by using this one and this one because they should have finished 07:27:56.720 |
but while i'm computing the matrix multiplication i already spawned some async operations to load 07:28:02.080 |
the data necessary for the second iteration and the third iteration so at the fourth iteration 07:28:08.720 |
i will spawn the loading of the data for the fourth iteration loading the data for the third 07:28:15.040 |
iteration while computing the matrix multiplication of the second iteration because they should have 07:28:19.040 |
already completed by now. Actually, it's not that we just assume they have completed: there are 07:28:25.360 |
primitives in the CUDA language to check whether an operation has completed, so 07:28:31.760 |
before doing the multiplication we will actually check that the async operations have finished; it's 07:28:37.520 |
not that we simply expect them to be done by a certain time. This is like in JavaScript, where 07:28:43.760 |
you have these things called promise i remember and you can wait for the promise to be finished 07:28:50.080 |
before you actually need them but you can spawn as many promise as you want in C# i think they 07:28:55.120 |
are called tasks so you spawn as many tasks as you want and then when you need it then you just wait 07:29:01.680 |
for them only the one that you needed while the other are still running in the background 07:29:06.080 |
asynchronously this is the whole idea of software pipelining software pipelining as you can see only 07:29:13.840 |
works when you have async operations and also it increases the memory requirement for your program 07:29:19.200 |
because when matrix multiplication one is going to run we may have enough data for the first 07:29:28.720 |
two iterations plus half data for the third iteration so we increase the memory requirement 07:29:34.320 |
for the SRAM. Okay, and Triton will do this software pipelining for you: it will convert all 07:29:42.720 |
the loads, all the stores, and maybe also the matrix multiplications into async operations and do the 07:29:48.240 |
pipelining for you. 07:29:53.360 |
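Here is a very rough Python sketch of the idea, with plain NumPy and synchronous stand-ins for what would be asynchronous loads on the GPU: the loads for iteration i+1 are issued while iteration i is still being computed.

```python
import numpy as np

def load_a(i):  # stand-in for an async load of the first matrix of iteration i
    return np.full((2, 2), float(i))

def load_b(i):  # stand-in for an async load of the second matrix of iteration i
    return np.full((2, 2), float(i) + 0.5)

n_iters = 4
acc = np.zeros((2, 2))

# prologue: issue the loads for iteration 0 before the loop starts
a_next, b_next = load_a(0), load_b(0)

for i in range(n_iters):
    a_cur, b_cur = a_next, b_next                      # data for this iteration is already "in flight"
    if i + 1 < n_iters:
        a_next, b_next = load_a(i + 1), load_b(i + 1)  # prefetch the next iteration (async on real hardware)
    acc += a_cur @ b_cur                               # on a GPU this compute overlaps with the prefetch

print(acc)
```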
If you are still confused by how it works, there is another easy way to explain it, because it's already something we do in model training: it is called pipeline 07:29:58.720 |
parallelism. Pipeline parallelism works as follows: we have a very big neural network that 07:30:06.240 |
does not fit in a single gpu so imagine this neural network is made up of three layers layer 07:30:11.040 |
one layer two and layer three but this is so big it does not fit entirely in one single gpu so one 07:30:17.360 |
way would be to put this each layer into one gpu so we put for example layer one into gpu one 07:30:25.520 |
layer two into gpu two layer three into gpu number three so imagine we have an input for this neural 07:30:33.680 |
network so we put it to the first gpu the gpu one will process the layer one and generate some 07:30:39.840 |
output which will be transferred to the gpu two which will calculate its own output and transfer 07:30:45.200 |
it to the gpu three which will compute its own output and finally we will have the output of 07:30:49.280 |
the neural network the problem is when you send the output of the gpu one to the gpu two for the 07:30:55.600 |
GPU two to do its own thing, GPU one is now free, which is a waste of resources: we should always 07:31:02.160 |
keep the GPUs busy. So one thing we can do is, instead of sending the whole mega 07:31:08.880 |
batch to GPU one, we send many smaller batches. How does it work? Imagine that we send batch 07:31:15.760 |
number zero so batch zero uh to the gpu one the gpu one will compute its output and send it to 07:31:23.920 |
the gpu two so now the gpu two is computing the batch number zero so now the batch zero is not 07:31:30.560 |
here anymore but now the gpu one is free so we send another micro batch called the batch one 07:31:39.360 |
then GPU two will finish processing batch zero and will send it to GPU 07:31:46.160 |
number three. Now GPU three has batch number zero and GPU two is free, so, 07:31:52.240 |
hopefully GPU one has also finished, we transfer batch number one 07:31:56.960 |
from GPU one to GPU two, and then GPU one is free again, 07:32:05.760 |
so this slot here becomes batch one; and because GPU one is free we can introduce 07:32:10.720 |
another batch, batch number two, and so on. So, while moving one 07:32:18.160 |
batch from one GPU to the next, we always introduce a new batch at the beginning of the pipeline, and 07:32:23.440 |
the batches shift by one position at every step. This keeps the GPUs always busy. 07:32:30.080 |
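A tiny toy script that prints this schedule (3 hypothetical GPU stages, 4 micro-batches) makes the shifting pattern easy to see:

```python
# 3 pipeline stages ("GPUs"), 4 micro-batches: stage s works on micro-batch b at time step t = b + s.
n_stages, n_microbatches = 3, 4
for t in range(n_stages + n_microbatches - 1):
    row = []
    for s in range(n_stages):
        b = t - s
        row.append(f"GPU{s + 1}: batch {b}" if 0 <= b < n_microbatches else f"GPU{s + 1}: idle")
    print(" | ".join(row))
```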
There is only one problem with pipeline parallelism, which is this bubbling effect, because to fill the 07:32:35.440 |
pipeline you need a warm-up phase at the beginning. Actually, in pipeline parallelism you also 07:32:40.240 |
have the problem of the backward step, because the backward step has to run exactly in reverse 07:32:44.880 |
order with respect to how you received the micro-batches. In Triton, when doing software 07:32:51.520 |
pipelining, you have the problem of the prologue and the epilogue, because you need to 07:32:57.440 |
fill the pipeline to start it, and at the end you need to 07:33:04.320 |
drain all the stuff that is currently in the pipeline. So only in the first steps and in 07:33:10.400 |
the last steps of this for loop, not all the units of the GPU may be working simultaneously. 07:33:18.000 |
What does this mean? It means that in order to use pipelining you want the number of iterations 07:33:23.840 |
of your for loop to be much bigger than the number of stages your iteration is 07:33:28.400 |
divided into. In this case we have four stages (these are what "stages" refers to), so you want the number 07:33:33.520 |
of iterations to be much larger than the number of stages. All right guys, finally 07:33:39.760 |
i have completed the video um i hope that you learned a lot from this video i believe that we 07:33:45.520 |
can run the Triton code, so let's run it. Let's see: I copied everything, I believe we also 07:33:52.240 |
put in the code to test it, but we didn't add the main method, which we can copy right now. I 07:34:00.400 |
really hope there is no error. Let me check if I am 07:34:09.440 |
on the right machine; I am, so let's just run the program and pray. If there is an error I will just copy my own 07:34:20.880 |
reference implementation, but I hope it works, because otherwise I forgot something. I'm 07:34:27.280 |
running my code on an H100 because my company has H100s. If you have a smaller GPU, what you can do is 07:34:34.720 |
reduce the sequence length or the batch size (I think the batch size is already one when we 07:34:41.600 |
call it; actually no), so you can reduce the batch size, the number of heads, the sequence 07:34:45.600 |
length; you can even put the head dimension equal to 8 and the sequence length equal to 16. Let's check. 07:34:51.680 |
We run it: Triton backward returned an incorrect number of gradients, expected 5, got 1. 07:34:59.600 |
We probably forgot a return statement. I believe... yes, I forgot the return statement 07:35:08.800 |
here: in the backward pass, after running the last for loop, we need to return the stuff that 07:35:14.160 |
we have computed. Fingers crossed again. Okay, it passed. So the backward pass that is computed by torch 07:35:24.080 |
is equivalent to our backward pass up to an absolute error of 10 to the power of minus 2. 07:35:31.040 |
As you can see, the backward that we run here is different from the backward that 07:35:37.360 |
we run here, because when you apply TritonAttention it will introduce a new node 07:35:42.400 |
in the computation graph of our tensors, which will include this TritonAttention operator, 07:35:47.920 |
and when PyTorch wants to compute the backward pass it will just call the backward function 07:35:52.080 |
of this TritonAttention to compute it, and it will populate the grad attribute of all the tensors 07:35:58.400 |
that are inputs to this TritonAttention. And this is how PyTorch autograd works, guys. 07:36:04.960 |
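That error message, by the way, is exactly what autograd enforces: a custom torch.autograd.Function must return one gradient per input of forward. A minimal, hypothetical example of the mechanism (not our attention operator):

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):        # two inputs, so backward must return two values
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one gradient per forward input; `factor` is a plain float, so its gradient is None
        return grad_out * ctx.factor, None

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)   # tensor([2., 2., 2.])
```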
Thank you for watching my video, guys. It has been super demanding: 07:36:11.120 |
I spent many months, first of all, teaching myself about Triton, about CUDA, about Flash Attention, 07:36:17.360 |
etc. I also have a full-time job, so it is really hard to make videos like this: I need to 07:36:23.440 |
dedicate the nights, the mornings, the weekends. I spent three days just to record this 07:36:27.920 |
video, because sometimes I don't like how I explain something, sometimes I make a mistake, or sometimes I 07:36:32.800 |
need to restart because what I'm doing is wrong, etc. I believe there should be no big errors 07:36:39.760 |
in what I have done so far, but for sure my notation is quite bad, because all 07:36:45.280 |
the mathematics I know is self-taught; I learned it by myself, and because I didn't learn it 07:36:51.760 |
in academia I have bad habits that I'm trying to get rid of. So I use bad notation: 07:36:57.040 |
sometimes I write something with a capital letter, sometimes lowercase, sometimes I just 07:37:01.760 |
forget an index, etc. I'm trying to fix these problems. I believe I have explained everything, 07:37:08.240 |
so you should have all the knowledge to derive all the formulas that you see in the 07:37:14.320 |
Flash Attention paper, and you should also have a mental picture of how the 07:37:21.200 |
attention calculation works block by block. I know that I could have spent 20 hours explaining 07:37:27.360 |
things better, but I also have a life and a wife, so I cannot make 100-hour videos. 07:37:35.920 |
Also, there were some interruptions while making this video: I had some wisdom teeth removed, so it took me 07:37:41.120 |
more than a week to recover because it was so painful. So thank you guys for 07:37:48.640 |
watching my video; I hope you learned a lot this time as well. As you can see, Triton is something 07:37:53.600 |
new and there is not much documentation, so something that I have said about Triton may not be totally 07:37:58.480 |
correct, because there really is very little documentation: all the Triton I have 07:38:03.040 |
learned is by looking at code written by others and trying to understand it. And I think 07:38:12.400 |
that's it guys so i wish you a wonderful day and see you next time on my channel