
Flash Attention derived and coded from first principles with Triton (Python)


Chapters

0:00 Introduction
3:10 Multi-Head Attention
9:06 Why Flash Attention
12:50 Safe Softmax
27:03 Online Softmax
39:44 Online Softmax (Proof)
47:26 Block Matrix Multiplication
88:38 Flash Attention forward (by hand)
104:01 Flash Attention forward (paper)
110:53 Intro to CUDA with examples
146:28 Tensor Layouts
160:48 Intro to Triton with examples
174:26 Flash Attention forward (coding)
262:11 LogSumExp trick in Flash Attention 2
272:53 Derivatives, gradients, Jacobians
285:54 Autograd
300:00 Jacobian of the MatMul operation
316:14 Jacobian through the Softmax
347:33 Flash Attention backwards (paper)
373:11 Flash Attention backwards (coding)
441:10 Triton Autotuning
443:29 Triton tricks: software pipelining
453:38 Running the code

Transcript

Hello guys, welcome back to my channel. Today we are going to explore FlashAttention. Now, we are going to explore FlashAttention from first principles, which means that we will not only code FlashAttention, we will actually derive it. So we pretend that the FlashAttention paper never existed: we look at the attention computation, we look at the problems it has, and we try to solve them step by step, pretending that FlashAttention never existed.

This will give us a deep understanding of how it works, and we will also combine theory with practice because we will code it. Now, in order to code FlashAttention we will need to write a kernel for our GPU. In my specific case I will be using an NVIDIA GPU, so a CUDA kernel, but instead of writing C++ code we will use Triton, which is a way of converting Python directly into kernels that can run on the GPU. You can think of Triton as a compiler that takes in Python and converts it into something that can run on the GPU.
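Just to make the idea of Triton as a Python-to-GPU compiler concrete, here is a minimal sketch of a Triton kernel, modeled on the standard vector-add tutorial example. This is my own illustration, not code from the video; the kernel name and block size are arbitrary choices, and it needs a CUDA GPU to run.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance (roughly, each block of threads) handles BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements          # guard against out-of-bounds accesses
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
assert torch.allclose(out, x + y)
```

The point is only that the kernel is written in Python syntax, yet it is compiled by Triton into code that runs directly on the GPU.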

So let's look at the topics for today. First of all, I will give an introduction to multi-head attention, because we need to look at what attention is, how it's computed, and what the problems are in computing it. Then we will look at the most critical part of the attention computation, the softmax, and how it impacts the computation and its complexity.

We will look at what online softmax is. Then we will explore what the GPU is, because we are going to write a kernel that will run on the GPU, so we need to understand, for example, what the difference is between the CPU and the GPU, what a kernel is, and how it differs from a normal program that you write for the CPU.

We will look at how tensors are laid out in memory: row-major layout, column-major layout, strides, etc. We are going to look at block matrix multiplication, Triton, software pipelining, and all the optimizations that Triton applies to our code. Finally we will be able to code the FlashAttention forward pass, but of course we are not satisfied with coding only the forward pass.
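As a tiny preview of the tensor-layout topic, here is a sketch of how PyTorch exposes the strides of a row-major tensor and what a transpose does to them (my own example, not from the video):

```python
import torch

a = torch.arange(12, dtype=torch.float32).reshape(3, 4)
print(a.stride())                    # (4, 1): row-major, moving down one row skips 4 elements
print(a.t().stride())                # (1, 4): transpose just swaps the strides, no data is copied
print(a.t().is_contiguous())         # False: the data is still laid out row-major for `a`
print(a.t().contiguous().stride())   # (3, 1): .contiguous() physically reorders the memory
```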

We also want to code the backward pass. But in order to code the backward pass we also need to understand how autograd and gradient descent work in the case of custom operations, so we need to understand what derivatives, gradients, and Jacobians are. Then we calculate the gradients of the common operations that we use in FlashAttention, and finally we will have enough knowledge to code the backward pass.
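Since coding the backward pass means plugging a custom operation into autograd, here is a minimal sketch of the torch.autograd.Function skeleton that such a custom operation uses, with a plain matmul as a toy stand-in. This is my own illustration; the actual FlashAttention version comes much later in the video.

```python
import torch


class MyMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # Save whatever the backward pass will need to compute gradients.
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # For Y = A @ B: dL/dA = dL/dY @ B^T, dL/dB = A^T @ dL/dY.
        return grad_out @ b.t(), a.t() @ grad_out


a = torch.randn(4, 5, requires_grad=True)
b = torch.randn(5, 3, requires_grad=True)
MyMatmul.apply(a, b).sum().backward()   # populates a.grad and b.grad
```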

For this reason this video is going to be super long, but I hope you don't mind, because we are going to learn a lot. Of course you may be wondering: all of this requires a lot of knowledge that you may not have. But that's not a problem, because that's my problem: in this video I will make sure that if you only have high-school calculus (so you know what derivatives are), the basics of linear algebra (you know what matrix multiplication is, or what the transpose of a matrix is), a basic knowledge of the attention mechanism (for example, you have watched my previous video on the Attention Is All You Need paper), and a lot of patience, that should be enough to understand all of this video. All the topics that I introduce, I will introduce in such a way that I pretend you don't know anything about them, so we try to derive everything from first principles, everything from scratch.

Okay, now that we have seen the introduction, let's go see the first part of the video, which is multi-head attention. All right, let's talk about multi-head attention. Now, I am using the slides from my previous video on Attention Is All You Need, so we can quickly review what multi-head attention is and how it works.

I hope you remember the formula: the softmax of the query multiplied by the transpose of the keys, divided by the square root of dk, all multiplied by V, because we will be using it a lot throughout the video. Now, multi-head attention starts from an input sequence, or two input sequences in case we are talking about cross attention.
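To keep that formula in front of us, here is a naive single-head reference of softmax(Q K^T / sqrt(d_k)) V in PyTorch (a sketch of mine, not the video's code). This is exactly the computation FlashAttention will later reproduce without ever materializing the full N by N matrix in global memory.

```python
import math
import torch


def naive_attention(q, k, v):
    # q, k, v: (N, d_k). S and P are the full (N, N) matrices that
    # FlashAttention avoids materializing in the GPU's global memory.
    s = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    p = torch.softmax(s, dim=-1)
    return p @ v


q, k, v = (torch.randn(8, 128) for _ in range(3))
out = naive_attention(q, k, v)   # shape (8, 128)
```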

In the simple case of self-attention we have one input sequence which is a sequence of in the case of language model a sequence of tokens where we have sec number of tokens and each token is represented by an embedding so a vector with d model dimensions. The first thing that we do is we convert this input sequence into query key and values through three linear projections one called wq one called wk one called wv which in pytorch are represented through linear layers and these linear layers are of d model by d model so they do not change the shape of the input tensor and then after we do this job of projecting them they become three different sequences one called query one called key and one called value so here i'm calling them q prime k prime and v prime then we divide them into smaller embeddings so each of this token which is made up of d model dimensions we divide it into smaller tokens each one suppose we have four heads each one will have d model divided by four dimensions so this one is a sequence of tokens where each token is not the entire token but a part of the embedding of each token and this one is a another part of the embedding of the tokens and this one is another part of the embedding of the token etc and we do this job for the query key and value sequence then we compute the attention as follows so the softmax of the query multiplied by the transpose of the key divided by the the square root of dk where dk is the dimension of each head so how many dimensions each head is working with and then we do the multiplication with v and this will give us the output of the attention mechanism for each head and this job is done independently for each head this should be clear to you if it's not please watch my previous video on the attention mechanism because we will be working with this scenario a lot now then we take this o the output of each head and then we concatenate it back in order to get the representation of each token as a full embedding so before we split this embedding into smaller embeddings this one here is called the q1 q2 q3 q4 then after we compute the attention we get back the output of each head and we concatenate it together to get back the full embedding dimension which is this edge here we run it through another linear projection called wo which will be the output of the multi head attention now flash attention is not concerned with all of these operations actually flash attention is only concerned with the operation that require optimization and the operations that require optimizations are this one so the softmax of the query multiplied by the transpose of the key divided by the square root of dk multiplied by v which means that the projection of the input sequence through wq wk and wv is not something that flash attention is concerned about because that's a matrix multiplication so when you use a linear layer it's just a matrix multiplication of the input with the weight matrix of the linear layer and this kind of operation so the matrix multiplication is one of the most um optimized operation that we have in the gpu because the manufacturer of the gpu usually also releases um the necessary library for computing the the matrix multiplication so actually these are quite fast and they do not require any optimization so flash attention will pretend that the query is has already been passed through by wq and the key has already passed through wk and the v has already passed from wb moreover flash attention will not be concerned with the projection with wo 
because that's also a matrix multiplication because the wo is always represented in pytorch as a linear layer so it's a matrix multiplication and matrix multiplication as we have seen are very optimized so there is nothing to optimize there but what we need to optimize in terms of speed is this operation here softmax of the query multiplied by the transpose of the keys divided by the square root of vk multiplied by v all right guys so now we have rehearsed what is multi-head attention i also want to give you a lot of visualization which is basically here in the paper of the multi-head attention we can see that we have the input that is v k and q so q k and v each of them runs through a linear layer which is the w q w k and w v then we do the scaled dot product attention which is done independently for each head so each head will do query multiplied by the transpose of the key divided by the square root of dk where each query and each key is not the full embedding of each token but a part of the embedding of the token because we split them into smaller embeddings and eventually we take all the output of each of this head which are computed in parallel so that's why you see this dimension edge in the depth we concatenate them and then we run them through w o what are we concerned with we are concerned with optimizing this particular block here the scaled dot product attention so let's start our journey one thing that is very important to understand is why do we even need a better implementation of the attention mechanism and if you look at the flash attention paper you will notice the following part this is the paper flash attention one and in the flash attention one paper they describe the attention implement implementation as it's done naively when using pytorch so first we do the multiplication of the query multiplied by the transpose of the keys then we apply the softmax to the output of this operation and finally we multiply the output of the softmax with the v matrix to obtain the output of the attention the way this implementation is done by pytorch without any optimization is as follows so we load the first of all these tensors are residing in the gpu the gpu is made up of two main memories one is called the hbm which is the dram which is the the ram of the gpu which is the 40 gigabyte of the a100 for example so it's the biggest memory that we have in the gpu and then there are there is the shared memory so the problem of the gpu is that accessing this hbm so the global it's also called the global memory it's very very slow compared to the shared memory however the shared memory it's much much smaller compared to the hbm and what they claim in the flash attention paper is that the operation of the attention is i/o bound meaning that if we keep accessing the global memory the overall operation of computing the attention is not because computing all these operations it's slow but because we keep accessing the global memory which is slow so we call this kind of operations i/o bound so the only way to improve this situation is to compute the attention inside the shared memory of the gpu which is much smaller which is much closer to the cores that actually do the computation so we will need to kind of also split the attention computation into smaller blocks that can reside in the shared memory and we will see later in how this is possible through block matrix multiplication and this is in the paper here they call it the tiling and it's a very how to say use the technique when doing when 
writing kernels for the gpu which are usually involve some kind of matrix multiplication so now we know what problem the flash attention is trying to solve it's trying to make sure that we do not need to access the hbm so the high bandwidth memory when computing the attention but copying only a part of each matrix inside the local memory so the shared memory of the gpu that is closer to the cores and computing a part of the output matrix there then copying that part to the output in that is residing in the hbm and keep doing it for all the blocks in which we can divide this query key and value matrices and later we will see how this blocked computation is done but also we will see that the biggest problem in computing this block computation is the softmax because the softmax needs to access all the row of the s matrix to apply the softmax because the the softmax needs to have a normalization factor which is the sum of all the exponentials of all the values to which it is applied row wise and we will see later how we will solve this problem so let's move on all right guys um okay when i say guys i mean guys and girls because i don't know in my usually i just say guys too you know but please girls don't feel excluded so we saw that first of all flash attention is only concerned in optimizing this softmax of the transpose of softmax of the query multiplied by three divided by the square root of dk multiplied by b and we need to introduce a little bit of notation so that we don't get lost in the future slides first of all this is the formulas i took from the flash attention paper but for now we let's pretend flash attention never existed so we are trying to solve the problem step by step now um we should treat this q as something that has as the sequence that is the output of the input sequence that has already passed through wq the k as something that has already passed through wk and v as something that has already passed through wv because we don't want to optimize the matrix multiplication because it's already fast enough another thing is let's talk about what are the dimensions of these matrices so we can then understand what is the the dimensions of the output of this operation so we will see treat q as a sequence of tokens with n tokens so n tokens where each token is d has d dimensions so lowercase d dimensions why because usually we take the queries and then we split them into multiple heads so we have we pretend we have already done this splitting so we pretend we are ready to cover input sequence we already run it through wq and then we have already split it into multiple heads and each of this head will do the following operation so the the one we already saw and so the usual formula query multiply the transpose of the keys and each of this head will work with these dimensions for the query for the key and for the value sequence so now let's look at the the dimensions of the output so the first operation that we will do is the query multiplied by the transpose of the keys where the transpose of the keys is a matrix that originally is n by d but become but with the transpose will be d by n so d by n and the result will be a matrix that is n by n because in the matrix multiplication the outer dimensions become the dimension of the output matrix what do what is the next operation that we do we take the output of this operation so the query multiply by transpose of the keys and we run it through a softmax operation and we will see what is the softmax operation which preserves the shape 
of the input so it doesn't change the shape of the input matrix it just changes the values of it and then we take the output of the softmax and we multiply it by v which will change the which will change the of course the shape because the p matrix is n by n so this one is n by n and v is n by d so this one the output will be n by d the outer dimensions of this matrix multiplication now let's look at the details of each of these operations so when we do query multiply by transpose of the keys we will get a matrix that is n by n where each value in this matrix is a dot product of a row of q and a column of k in particular the first element of this matrix will be the dot product of the first query with the first key vector the second element will be the dot product of the first query with the second key vector and the third element will be the first query with the third key etc etc and the let's say the the last row of this matrix will be the dot product of the last query with the first key then the last query with the second key the last query with the third key etc etc until the last query with the last key you may also notice that here i have written query transpose the key because when we what is q1 first of all q1 is the first row of the query matrix so a little bit of background on matrix multiplication so we know that when we do matrix multiplication each output element is one row of the first matrix with one column of the second matrix but we are doing the product of the first matrix with the transpose of the second so it will be the dot product of the one row of the query matrix with one row of the key matrix because we are doing the multiplication with key k transposed when you take a vector from a matrix the usual notation so the in in as in how to say in in in mathematics in a linear algebra we always pretend that a vector is a column vector so we cannot just write q multiplied by k because that would be mean that would mean we are doing the dot product of we are doing the kind of the matrix multiplication of one column matrix with one column matrix that is not possible because the shapes do not match so as a notation we write that we do the dot product of the first matrix the transpose which is a column vector but we transpose it so it becomes a row vector with the second vector this is just because of notation guys so you just need to pretend that this is the first query with the first key then the first query with the second key the first query with the third key etc etc etc so we are doing dot products of vectors then we apply this softmax operation the softmax operation what it will do it will transform each of these dot products which are scalars so the output of a dot product is a scalar and it will transform each of these numbers in such a way that they become kind of a probability distribution row wise which means that each of these numbers is between 0 and 1 and when we sum up these numbers together they are sum up to 1 and this condition this property will be valid for each row so this row also will sum up to 1 this row will sum up to 1 and this row will sum up to 1 etc etc etc let's see what is the softmax operation now given a vector so let's call it x which is made up of n dimensions the softmax is defined as follows so it is the the softmax basically transforms this transforms this vector into another vector with the same dimension where each item of the output vector is calculated as follows so the height element of the output vector is the exponential of the 
element input element divided by the summation of all the exponentials of all the dimensions of the vector basically this is called the normalization factor to make it all these numbers between 0 and 1 we usually normalize that's why it's called the normalization factor and we use the softmax because we want each of these numbers to be positive we don't want the stuff the output of this operation to be negative so that's why we use the exponential but there is a problem the problem is imagine our input vector is made up of many numbers that are maybe large so for example let's say x1 is equal to 100 x2 is equal to 200 x3 is equal to 300 which is can happen if we do the exponential of these numbers so the exponential of 100 that is going to be a huge number it's going to very close to infinity at least compared to what we can store in a computer so the output of exponential of 100 may not fit into a floating point 32 or a floating point 16 number or even an integer of 32 bit so we cannot compute it because it will overflow our our variable our integer that is storing this value this output so we talk in this case about numerical instability so every time you hear the term numerical instability in computer science it means that the number cannot be represented within a fixed representation with the bits we have available which are usually 32 bit or 16 bit we have also 64 bit but that will be too expensive to use so let's try to find a solution to make this stuff here computable and numerically stable in order to make this softmax operation numerically stable which means that we want these numbers to not explode or to become too small that they are not representable we need to find a solution and luckily it's quite easy so the softmax as we have seen before it is the following formula so each number is exponentiated and then we divide it by this normalization factor which is just the sum of the exponential of each input dimension of the input vector if we multiply the numerator and the denominator of a fraction with a constant with a number then the fraction will not change so that's what we are going to do we are multiplying the numerator and the denominator with this factor c as long as c is not equal to zero of course then we can take this c and by using the distributive property of the product with respect to the sum we can bring this c inside of the summation as you can see here then we can also write every number as the exponential of the log of itself because the exponential and the log will cancel out and then we can by using the properties of the exponentials we know that the product of two exponential is equal to the sum of the is equal to the exponential of the sum of the arguments of each exponential and we do it on the numerator and in the denominator then we just call this quantity minus log c equal to k or k is equal to minus k is equal to log c so we can replace this quantity with k we can do that because this is a constant that we have chosen and we just are assigning it to another constant so basically by doing this derivation we can see that we can sneak in a value inside of this exponential that if chosen carefully can reduce the argument of this exponential and we will choose this k equal to the maximum element inside of the input vector that we are applying the softmax to so that each of this argument will be either zero in case xi is equal to the maximum element that we are processing of the vector or it will be less than zero and we know that the exponential when it's 
equal to zero will be equal to the output of the exponential will be one so the argument when it's zero it will be equal to one and when it's smaller than zero so it's in the negative range it will be between zero and one so which is easily representable with floating point 32 for example so this exponential will not explode anymore so basically to apply the softmax to a vector in a numerically safe way we need to find a k constant which is the maximum value of this vector and when we apply it we need to subtract each element minus this constant that we have chosen so let's look at the algorithm to compute the softmax so first of all given a vector or given an n by n matrix because we want to apply the softmax to this matrix here which is n by n we need to go through each row of this matrix and for each row we need to find the maximum value among the elements which takes time complexity linear with respect to the size of the vector to the size of the row to which we are applying the softmax then we need to compute the normalization factor which is this stuff here and we we cannot compute it before the step number one because we need to have the maximum element to compute this summation here and after we have calculated the normalization factor we can then divide each element's exponential by the normalization factor and we cannot do the step number three before calculating the normalization factor because we need to divide each number by the normalization factor so if you like pseudocode for algorithms this is an algorithm for computing the softmax that we have seen right now so first we find the maximum of the row to which we are applying the softmax then we compute the normalization factor and then we apply the softmax to each element which means that we calculate compute the exponential of each element minus the maximum value of the vector divided by the normalization factor now this pseudocode is an algorithm that is quite slow because look at a practical example imagine we have this vector here first we need to do step one find the maximum value in this vector which is number five and this takes linear time computation then we need to calculate the normalization constant which is the sum of the exponential of each element minus the maximum value so e to the power of 3 minus 5 plus e to the power of 2 minus 5 etc etc this we will call it l and then each we need to go again through this vector again and take the exponential of each element minus the maximum divided by the normalization factor so to apply the softmax to an n by n matrix we need to go through each element of this matrix three times and these operations must be done sequentially so we cannot start operation two until we have done operation one and we cannot start operation three until we have done one and two so this is quite slow only to apply an operation that doesn't even change the shape of the matrix it's just uh normal uh normalizing the values so there must be a better way that that does not involve three sequential operations in which we need to go through this matrix three times let's see all right guys let's rehearse what is the problem that we are trying to solve the problem statement is the following can we find a better way to compute the softmax that does not involve going through the vector three times because let's look at the pseudocode of the algorithm for computing the local the softmax that we have found so far imagine we have a vector made up of four elements the first thing that we need to do is to 
compute the maximum element in this vector which means going through this for loop here that allow us to compute the maximum element in this vector which means that we start from the left side of the vector and iteratively go to the right side so we start from the first element arrive to the end and we compare the previously found maximum with the current element to find the global maximum basically this means that uh i i know that this is very simple uh i'm probably sure that you don't need to this example but making this example will help us understand what we will do next so please bear with me even if it's super simple what i'm doing okay we at the beginning m0 is equal to minus infinity m1 is basically the for loop at the iteration number one which means that we are m1 will be equal to the maximum of the previous estimate of the m which is minus infinity with the current element which is three so it will become equal to three then m2 will be equal to the maximum of the previously computed maximum so m1 so three with the current element which is two so it will be equal to three m3 will be equal to the maximum of the previously computed maximum so three with the current three with the current element which is five so it will be equal to five and m4 will be equal to the maximum of the previously computed maximum and the current element so it will be equal to five so this allow us to compute the maximum element so at the fourth iteration we will have the maximum the global maximum independently of what is the input array um delete okay after we have computed the maximum which we know is five we can compute the normalization factor so let's start with the l0 l0 is equal to zero l1 will be equal to the exponential of l0 so actually sorry it will be l0 plus the exponential of the current element so three minus the maximum element we have found in the previous for loop so five then l2 will be equal to l1 plus the exponential of the the current element so it's two minus the maximum then l3 will be equal to l2 plus the exponential of the current element five minus five then l4 will be equal to the l3 plus exponential of one minus five if you expand this l this will be basically equal to e to the power of three minus five plus e to the power of two minus five plus e to the power of five minus five plus e to the power of one minus one minus five after we have computed this normalization factor we can use it to normalize the each element in the input vector which means that the x new x1 so x1 prime let's see will be equal to e to the power of what's the first element three minus five divided by l that we computed in the previous for loop so the l at the fourth iteration the new x2 so x2 prime will be equal to the e to the power of two minus five divided by l4 and x3 prime will be equal to the e to the power of five minus five divided by l4 etc etc for all the elements i know this is super simple but it will help us later so in this for loop we have that we need to go through the vector three times because first we need to compute this for loop then we need to compute this for loop and then we need to compute another for loop we cannot do them not in this sequence because in order to compute this for loop we need to have the maximum element because we need it here and we cannot compute this for loop until we have computed the previous one because we need to have the normalization factor however we are stubborn and let's try to fuse these two operations into one for loop which means that we go 
through the array and simultaneously compute mi and in the same iteration we also try to compute lj of course we will not be able to compute lj because we don't have the global maximum because we didn't go through the old array yet however let's try to use the locally and whatever estimate we have of the maximum so far so let's try to use instead of mn let's try to use mi so the local maximum that we have computed so far so if we apply the softmax in this way in this fused way to this vector we will have the following iterations so this is our array or vector and the first step is mi so m1 will be equal to the previous maximum which is minus infinity with the current element so the maximum minus infinity and the current element is equal to 3 and l1 will be equal to the previous l so l0 which is starts from 0 plus e to the power of the current element minus we should be using the global maximum but we don't have the global maximum so let's use the whatever maximum we have so far so we can use 3 now at the second iteration we are at this element of the vector and we compute the maximum so far so the maximum so far is the previous maximum and the current element so the maximum of the previous maximum and the current element which is the maximum between 3 and 2 which is 3 and the normalization factor is the previous normalization factor plus exponential of 2 minus 3 which is the current element minus whatever maximum we have so far now if our array were made only of these two elements so 3 and 2 then whatever we have computed is actually correct because the maximum that we have found is a 3 and it's actually the global maximum and the normalization factor that we have computed is actually correct because each of the exponential has been computed with the global maximum because the first element was computed using 3 as the with the argument minus 3 and also the second element was computed with the argument with the argument having minus 3 in the in the argument which is the global maximum of the vector however when we arrive at the third iteration so let me delete this vector so let me arrive here at the third iteration the maximum will change which will also cause our normalization factor to get to to be wrong because we arrive at the element number 3 so the number 5 here and we compute the maximum so the maximum is the comparison of the previous maximum and the current element so the new maximum becomes 5 and the normalization factor is the previous normalization factor so l2 plus the exponential of the current element minus the current estimate of the maximum which is 5 however if you look at this l3 this is wrong why because l3 is equal to if you expand this summation it will be equal to e to the power of 3 minus 3 plus e to the power of 2 minus 3 plus e to the power of 5 minus 5 this exponential here is using 5 as the global maximum this exponential here is using 3 as the global maximum and this one is using 3 as the global maximum so the first two elements have been computed thinking that the global maximum is 3 but actually we later we found a better global maximum which is 5 so which makes this normalization factor wrong however can we fix at the third iteration whatever normalization we have computed so far up to the second iteration actually we can because if we expand this so as we have here we have expanded it what we need here is here to have a minus 5 because that's actually the global maximum that we have found so far not the minus 3 that we had at the previous iteration so and 
here we also need to fix this replace this minus 3 with minus 5 how can we do that well if we multiply this one here and this one here with a correction factor that will sneak in a new maximum inside of this exponential then we solve the problem and actually this correction factor is very easy to calculate because at the third iteration if we multiply l2 so the previously computed normalization factor with this factor here which is the exponential of the previous estimate of the maximum minus the current estimate of the maximum so 5 we will see that e by the properties of the exponentials this one here will become e to the power of 3 minus 3 plus 3 minus 5 so this minus 3 will cancel out with this 3 and also the second factor will have this 3 will cancel out with this minus 3 will cancel out with this 3 and they will become e to the power of 3 minus 5 and 2 to the power of e to the power of 2 minus 5 which is actually correct because at the third iteration we should be actually happy we should be using minus 5 as the maximum of the array so far so basically what we have found is a way to fix whatever normalization factor we have computed so far while iterating through the array when we found we when we find a better maximum compared to what we have so far and when we don't need to fix anything then the formula still stands because what we did here as a multiplication as a correction factor so this is the correction factor this correction factor is nothing more than the previous maximum so the previous estimate of the maximum minus the current estimates of the maximum at the current iteration so the current max so this is basically m of i minus 1 and this is m of i so the current maximum at the current iteration and let me delete it otherwise it remains forever in my slides so basically when we arrive to the last element we will see that the maximum doesn't change because we compare the previous maximum with the current element which is less than the previous maximum so the maximum doesn't change and we don't need to fix anything because the the the previous l3 so the previously computed normalization factor is correct because they have all been using the minus 5 so when we don't need to fix anything we just multiply by e to the power of the previous maximum minus the current maximum which is e to the power of zero in this case so it's not fixing anything so we have found a way to fix the previously computed normalization factor while going through the array even if at the current iteration we don't have the global maximum yet so that every time the maximum changes we can fix and every time it doesn't change we just multiply with e to the power of zero which is like multiplying with one so the new algorithm that we have found for the softmax is the following so we start with m0 equal to minus infinity we start with l0 equal to zero we go through the array we compute the locally the local maximum so up so the maximum so far from the zeroth element to the ith element so to the element at which we are doing the iteration and the previously computed li can be fixed by using this correction factor which is e to the power of the previous maximum minus the current maximum plus the exponential of the current element minus the current estimate of the maximum in this way we go through the array only once and we obtain two values the global maximum at at the end at the same time the normalization factor and then we can use it to compute the softmax so we made three transformed three passes through the 
array into two passes through the array and this is very important and we will see how we actually use it to derive flash attention the example that i have given you so far is not really a proof that our algorithm will work in every case because we made a very simple example by using a vector made up of four elements but does our new algorithm work in every single case with whatever the numbers are we need to prove that so we will prove that by induction so what first of all what are we trying to prove we have fused the first two for loops into one for loop as you can see here what we expect is that at the end of this for loop this mn so the m at the last iteration will be actually the global maximum in the vector and this ln so the l at the last iteration will be equal to the sum of all the exponential of all the elements minus the maximum element of the vector so the global maximum of the vector and we need to prove that because what i did before was an example and that was not really a rigorous proof and the way we will prove it is by induction which is a typical way of proving this kind of theorems now proof by induction basically works in the following way we need to prove that our algorithm works for a base case for example with n equal to one and then we pretend we assume that the algorithm works on n and we need to prove that it also works for n plus one if this holds then we have proven our algorithm for every possible n because it will work for the base case so for example n equal to one and then by using the induction step we say so this if it works for n and then it also works for n plus one then it means that it will also work for two but then if it works for two then it should also work for three because of the induction step that we will prove and if it works for three then it will also work for four etc etc up to infinity so let's prove it for the base case which is n equal to one it's very simple so at n equal to one this for loop will only have one iteration so m m1 and l1 m1 will be the maximum of the previous m which is minus infinity because we initialize m0 equal to minus infinity so it will be equal to the maximum of the previous m and the current element which is x1 so it will be equal to x1 whatever x1 we is uh x1 usually will never be equal it cannot be equal to minus infinity um because it's a number in a fixed representation so it cannot be minus infinity um so the the x the m1 at the end so it will because we have only one element n equal to one this is m1 is also the last um m of this it of this for loop it will be equal to the global maximum of the vector made up of only one element and l1 will be equal to the previous l which we start from zero so l0 multiplied by a correction factor which will be in this case e to the power of minus infinity because the correction factor is the previous estimate of the max of the max minus the current estimate of the max but the previous estimate of the max is minus infinity minus x1 it is equal to minus infinity so this one will be this will be cancelled out and then plus e to the power of x1 minus the current maximum which is x1 so m1 and if this one will be equal to the sum of all the elements of the vector which is made up of only one element minus the maximum element in the array which is x1 so we have proven that it works for n equal to one now we assume that it works for n does it also work for an array of vector or with a vector of size n plus one so let's see what happens at the n plus one iteration at the n plus 
one iteration we will be doing the maximum of the previous estimate of m which is the m at the nth iteration and the current element so xn of plus one this by the properties of the max function it will be actually equal to the maximum of the global vector up to n plus one because the maximum will choose whatever is the maximum between the previous estimate and the current estimate and ln plus one which is the normalization factor at the n plus one iteration will be equal to the ln so the previous estimate not previous estimate but the previous normalization factor at the nth iteration multiplied by the correction factor which is the previous maximum minus the current maximum plus the exponential of x the current element minus the current estimate of the maximum but ln we have we assume that this property so this algorithm works up to n so ln is for sure equal to the sum of all the exponentials of the previous of the vector up to n minus the local maximum of the vector up to the nth element which is mn we multiply by the correction factor if there is something to correct which will be the previous maximum minus the current maximum plus the exponential of the current element minus the current estimate of the maximum now by the properties of the exponentials so we can bring this one inside of the summation and we will see that this mn and this mn will cancel out because it will be exponential of xj minus mn plus mn minus mn plus one so this mn and this mn will cancel out and we obtain this one plus this factor here that remains unchanged however you can see that this stuff here is exactly the argument of this summation for the at the iteration n plus one so it is this one is e to the power of xj where j is going from one to n minus mn plus one plus e to the power of xn plus one minus mn plus one so the j only appears here and it's equal maximum to n and this is similar to being a j with n plus one so we can increase the index of this summation by one and it will be the same and it will result in the same summation so we have proven that also at the n plus one iteration we will have that the l will be equal to the sum of all the elements of the array the exponential of all the elements of the array up to the n plus one element minus the maximum up to the n plus one element so we have proven that if it works and then it also works for n plus one this is enough to prove that it works for all size of arrays don't worry if you didn't get the proof by induction it is if it's the first time you are seeing this kind of proof it may take a little bit to to get it if you want to learn a little bit more about proof by induction i recommend watching some other proof it's very simple it's just you need to get into the right mindset anyway let's move forward all right let's talk about block matrix multiplication i know that you want to jump to the code immediately and we will go there we just need a little more theory actually so imagine we are doing a matrix multiplication so we have a matrix a we want to multiply it with a matrix b and it will produce an output matrix c imagine the dimensions of the first matrix are m by k the second matrix is a k by n it will produce an output matrix that is m by n now imagine we want to parallelize the computation of this output matrix i know that i didn't talk about gpus yet so we will not talk about gpus we will talk about parallelization in the case of a multi-core cpu with which you are very probably familiar with because right now in nowadays when you buy a 
computer you have a cpu and usually you can buy a single core cpu or multi-core like a two core four core eight core etc etc each of the these cores are actually kind of small cpus inside your cpu that can execute operations in parallel how to parallelize the matrix multiplication imagine you have this matrix multiplication to parallelize each of the output element in this c matrix is a dot product of a row of the a matrix with a column of the b matrix for example this element on the top left is the dot product of the first row of a and the first column of b this element on the top right of c is the dot product of the first row of a and the last column of b this element on the bottom left is the dot product of the last row of a and the first column of b etc etc for all the other elements now to parallelize this computation we need as many cores as is as there are elements in c if we want to parallelize it so if m and n are very small then maybe we have enough cores but imagine m and n are quite big we imagine like 100 by 100 we don't have 10 000 cores right now in the cpus so how can we parallelize a matrix operation by using less cores than there are elements in the matrix itself that's when we talk about block matrix multiplication basically block matrix multiplication means that you can divide the original matrix into smaller blocks of elements and then the operations of matrix multiplication can be computed between these blocks for example imagine we have a matrix that is 8 by 4 it means that it has 8 rows and 4 columns which means that it has 32 elements and then we are multiplying it with another matrix that is 4 by 8 so it has 4 rows and 8 columns so it also has 32 elements the output matrix will should have 64 elements we don't have 64 cores so how can we parallelize it imagine we only have 8 cores now with 8 cores we can divide this original matrix a into 4 blocks where the first block is this top left block of 2 by no 4 by 2 elements so um let's say um 8 elements on the top left and then 8 elements on the top right of this matrix then 8 elements on the bottom left and 8 elements in the bottom right of this matrix these are 4 blocks then we divide also the b matrix into um 8 blocks where each block is made up of 4 elements so this b11 is the top left 4 elements in the original matrix this b4 is the top right 4 elements in the original matrix this b21 is the um bottom left 4 elements in the original matrix etc etc etc how do we do this block matrix multiplication we can watch these matrices as made only by their blocks so we can view this matrix here as made up only by its blocks we can view this matrix here as made up only by its blocks and the output of this multiplication will be a matrices that is computed in the same way as the original matrix but where the output of each dot product will not be a single element of the output matrix but it will be a block of elements of the output matrix for example the top left block here is the dot product of the first row of this matrix with the first column of this matrix and it will be computed as follows so it will be a11 multiplied by b11 plus a12 multiplied by b21 and this output will not be a single scalar but it will be uh well let me count it should be eight elements so it should be four um made up it should be a block of four elements or eight elements let me let me count actually so because we have eight blocks and it should be made up of eight elements let's we can see that here um how to find the dimensions of this output block 
well we can check what is a11 a11 is four by two so it's eight elements in a smaller matrix made up of eight elements where the elements are distributed in four rows and two columns we are multiplying it by b11 which is a smaller matrix compared to the original made up of two by two elements so four elements so when we multiply four by two multiplied by two by two it will produce a four by two output block matrix so block so if we do this computation here block by block it will produce a block of output elements of the original matrix so not not a single scalar but a block of outputs which makes it very easy to parallelize because if we have only eight cores we can assign each output block to one core and each core will not produce one output element of the original matrix but it will produce eight elements of the original matrix as a four by two matrix so basically block matrix allow us to to do the matrix multiplication either by element by element so like in the original matrix so each row with each column or blocks by blocks in the same way like we do normal matrix multiplication because the the matrix multiplication that we are doing between blocks is the same way as we do matrix multiplication with the original matrix and it will produce not a scalar but a block and now let's see why this is very important for us so why should we care about block matrix multiplication because we are trying to compute the following operation so the query multiplied by the transpose of the keys and then we will should apply the softmax of this operation and then we should multiply the output of the softmax with v for now let's ignore the softmax let's pretend that we are not going to apply any softmax so we take the output of the query multiplied by the transpose of the keys and we just multiply it by v to obtain the output of that edge which is wrong of course but it simplifies our tractation of what we are going to do next so for for this moment let's pretend that we are not going to apply any softmax so we just do the query multiplied by transpose of the keys and directly we multiply the result of this operation with v this will result in a matrix that is n by d so n tokens each made up of an embedding of d dimensions so lowercase d dimensions and we know that query key and values are themselves matrices of n by d dimensions so the n tokens which made up of an embedding of d dimensions so imagine we have a query matrix and the key and the value matrix that are 8 by 128 so we have 8 tokens each token is made up of 128 dimensions we can divide as we have seen each when we compute a matrix multiplication we can divide our matrix into blocks how we choose the blocks is up to us as long as the operating the shapes of the blocks match when doing the matrix multiplication so for example in the previous case we divided our matrix a into blocks such that the the shape of the block matrix so the matrix that is made up only of the blocks is compatible with the block matrix b so that this operation is possible so this is the only requirement that we need to be aware when doing the block matrix multiplication the shapes of the blocked matrix so the matrix that is made only of the blocks should match in the matrix multiplication for the rest it doesn't matter how we divide it so imagine that we choose to divide this query matrix into blocks of rows and we can do that we don't have to necessarily divide also the columns we can just divide the rows so that each q is not a single row but it's a group of two rows so 
q1 is a group of the first two rows of the q matrix of the q sequence q2 is the group of the second two rows of the q sequence etc etc and we do the same also for v for k we don't do it because we are actually going to multiply with k transposed so we do the subdivision directly on k transposed so we so we have the q which has been divided into groups of rows and then we have a k transposed which is a matrix that is 108 by 8 because it's the transpose of the keys which is 8 by 108 and we decide to divide each of the column group of columns of k into a single block so the k1 is the first two columns of k transposed k2 is the second group of two columns in k transposed etc etc until k4 which is the last two columns in k transposed the first operation that we do is the multiplication query multiplied by the transpose of the keys which basically means that we need to multiply each query with all the keys then the second query with all the keys etc etc now each query is not a single row of the q sequence it's a group of two rows of this q sequence and each k is not a single column of k transposed it's a group of two columns of k transposed but doesn't matter because we have seen that the matrix multiplication if we write the matrices as made up of blocks we just compute it in the same way when we do a normal matrix multiplication so we are multiplying this matrix by this matrix and for what we know this matrix here is made up of four rows with some dimensions which is 128 dimensions and this one here is made up of how many rows 128 rows and four columns i didn't draw the columns because it's too many to draw here but you need to pretend it's a lot of dimensions one for each 128 for each vector and here you need to pretend that this is 128 rows when we do the matrix multiplication we apply the normal matrix multiplication procedure which is each output element so this first of all the output shape of this matrix of this matrix multiplication will be four by four because it's the outer dimensions of the two metrics that you are multiplying the first element of the output will be the dot product of this vector here with this vector here the second element so this one here will be the dot product of this vector here with this vector here however this is not vector and this is not a vector so it's actually a matrix multiplication in this case this element here is not a scalar it is a group of elements of the output matrix because we are doing block matrix multiplication and how many elements it will be well we know that the original q1 is a 2 by 128 the k1 is 108 by 2 so it will be a group of 2 by 2 elements of the output matrix so we are doing the matrix multiplication of the q1 with k1 then q1 with k2 then q1 with k3 q1 with k4 etc etc for the first row and then the second row will be q2 with all the k's and the q3 with all the k's and q4 with all the k's so as you can see when we do matrix multiplication we don't even care if what is underlying is a block or a vector or a scalar we just apply the same procedure first row of the black block matrix multiplication with the first column of the matrix of the second matrix and then the first row with the second column the first row with the third column etc etc let's then multiply because the formula says that we need to multiply query with the transpose of the keys and then multiply by b all of these are block matrices now as you can see from my using of colors every time i refer to the original matrix i use the blue color and every time i refer to 
the block matrix i use the pink color so we need to multiply the output of the query multiply by the transpose of the key then by v because we are skipping for now the softmax and later we will see why so if we want to do this multiplication we need to do the following so it will be uh this matrix is made up of blocks and block matrix multiplication just ignores this fact and just does the matrix multiplication like it is a normal matrix multiplication so we do the first row with the first column then the first row with the second column then the third row the first row with the third column etc etc so the first block of row how is going to be calculated this output in the output matrix of this matrix multiplication well it will be the first row so the dot product of the first row the dot product because it's not really a dot product it's the actually the matrix multiplication of the first row but in a dot product way let's say with the first column which is made up of v1 v2 v3 and v4 so it will be this element with v1 plus this element with v2 plus this element with v3 plus this element with v4 and this will be the first output element the second output block will be this row with this column which will be this element with v1 this element plus this element with v2 plus this element with v3 plus this element with v4 and this will produce the second output block etc etc also for the third and the fourth block output let's look at what is each block made up of so each block is made up of the um the first element so query one multiplied by key one because um it's the result of the query multiplied by the keys with the v1 of the second matrix plus the this element with this one plus this element with this one plus this element with this one so the pseudocode for generating this output of this attention mechanism which is not really attention mechanism because we skip the softmax but i just want you to get into the habit of thinking in terms of blocks is the following so we take each query block we go through each query and as you can see let's look at actually what this output is made up of it is made up of the query one multiplied by key one and the result multiplied by v1 then the query one with k2 then the result multiplied by b2 then the query one with k3 and the result multiplied by v3 plus the query one with the k4 and result multiplied by v4 this is basically what we are doing is the dot product of this row with this column made up of blocks so the the pseudocode for generating this first row is the query is then query number one and then we iterate through the keys and the values from one to four and we sum iteratively so for each block basically to generate this output matrix and if you for each row we will see that it's a different query with all the keys and values and then this will be the the query number three with all the keys and values and this will be the query four with all the keys and values so to generate this output matrix we need to do we iterate through the queries and this will be one row of this output matrix and then we need to do this iterative sum of the query i that we are iterating through multiplied by the jth k and v and we keep summing them iteratively and that would that will produce the output matrix or you can see here i know that what i have done so far is not useless not useful for flash attention but it's useful for us to get into the mindset of computing this product by blocks because later we will use it also with the softmax all right guys i i know 
All right guys, I know that what we have computed so far is not really the attention mechanism, because we have skipped the softmax, so somehow we need to restore it. The next 10 or 20 minutes are going to be really challenging, because I will do a lot of operations involving many different blocks, many matrix multiplications and variants of the softmax, so it may be difficult to follow. Don't give up: you can watch this part two or three times, and each time you will understand it better. I recommend watching it at least until we reach the flash attention algorithm before going back to re-watch it, because reaching the algorithm will give you a better picture of everything that happened along the way, and then you can re-watch to deepen your understanding. I also recommend taking pen and paper and writing down the operations you see, along with the shapes of each block involved in these matrix multiplications, so that you understand what is happening and remember which element or block I am referring to. After this small motivational speech, let's start. What we did so far was query multiplied by the transpose of the keys, but each query is not a single row of the query sequence: it is a block of rows. In our case Q1 is two rows of the query sequence, because we chose a block size of two rows, and K1 transposed is not one column of the Kᵀ matrix but two columns, because we chose it that way. If you don't remember, go back and check: K1 is two columns and Q1 is two rows of the original matrices. Every time I use the blue color I am referring to the original matrix, and every time I use the pink/violet color I am referring to the block matrix, that is, a block of elements of the original matrix. The first thing we did, Q multiplied by Kᵀ, produces a block matrix as output that we will call S, where each element S_ij is a block: S11 = Q1·K1ᵀ, S12 = Q1·K2ᵀ, S13 = Q1·K3ᵀ, and so on for all the rows and columns. Next we should apply the softmax, because the formula is softmax of the query multiplied by the transpose of the keys. However, I want to restore the softmax with a twist: we will apply a simplified version called softmax*, which is just the softmax without the normalization. Let me write what that means, using the same orange color I chose for the softmax. The softmax of a vector is applied element-wise: the i-th element of the output vector is the exponential of the i-th element of the input vector minus the maximum element of the input vector, divided by a normalization factor, which is calculated as the
summation from j = 1 up to N of the exponential of x_j minus x_max. So we take the exponential of each element minus x_max. Why do we subtract x_max? To make the exponential numerically stable and computable, because otherwise it would explode; and because we apply it to the numerator, we also need to apply it to the denominator. The softmax* operation is exactly like the softmax but without the normalization part, which means it is just the numerator of the softmax: we modify each element of the vector by taking the exponential of that element minus the maximum of the vector to which we are applying softmax*. Why did I introduce this softmax* operation? Because we will be applying it to the S matrix we just computed. We apply softmax* to each element of S, but each element of S is itself a matrix, because S is a block matrix. For example S11 is a 2×2 matrix, because it comes from the product of a group of rows of Q and a group of columns of Kᵀ. Suppose S11 is made up of four generic elements a, b, c and d, and suppose the maximum of the top row is a and the maximum of the bottom row is d. Then softmax* applied to S11 gives exp(a − a) and exp(b − a) in the top row (because a is the maximum of that row), and exp(c − d) and exp(d − d) in the bottom row (because d is the maximum of that row). That is how softmax* modifies each block of this block matrix. Let me delete this scratch work, otherwise it will remain in my slides forever, and later I want to share the slides with you so you can use them. After we have applied softmax* to each element of the S matrix we call the result the P matrix, and each element P_ij is again a 2×2 block: P11 = softmax*(S11), where S11 = Q1·K1ᵀ, P12 = softmax*(S12), where S12 = Q1·K2ᵀ, and so on for all the elements of S. Now that we have applied softmax*, the next operation according to the attention formula — softmax(Q·Kᵀ) multiplied by V — is to multiply the result by V. I know we did not apply the real softmax, only softmax* (the softmax without normalization); later we will see how to compensate for the missing normalization, because we can do it at the end, and it is something we are allowed to do.
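As a quick sanity check, here is a minimal NumPy sketch of the row-wise safe softmax and the softmax* variant described above (the function names are mine, just for illustration):

```python
import numpy as np

def softmax(x):
    """Row-wise safe softmax: subtract the row max, exponentiate, normalize."""
    m = x.max(axis=-1, keepdims=True)          # x_max per row, for numerical stability
    e = np.exp(x - m)
    return e / e.sum(axis=-1, keepdims=True)   # divide by the normalization factor

def softmax_star(x):
    """softmax*: same as softmax but WITHOUT the normalization (numerator only)."""
    m = x.max(axis=-1, keepdims=True)
    return np.exp(x - m)

S11 = np.array([[4.0, 1.0],
                [0.0, 3.0]])                   # a 2x2 block with elements a, b / c, d
print(softmax_star(S11))                       # [[exp(0), exp(-3)], [exp(-3), exp(0)]]
print(softmax(S11))                            # the same values divided by each row's sum
```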
So we take the P matrix (the result of softmax* applied to the S matrix) and multiply it by V. How do we do it? P is a matrix made up of blocks: P11 is not a scalar but a 2×2 matrix, and we do not multiply it by the original V sequence but by the blocked V, just like before, where each V_j is not one row of V but a group of two rows. (For now, please ignore what I have written here on the slide; we will use it later.) So we multiply this block matrix P, which has four block rows, by this block matrix V, which also has four block rows, and as you remember, the algorithm for block matrix multiplication is the same as for ordinary matrix multiplication, except that the elements are blocks. So let's write what we are doing: O = P·V. The first output block row is the first block row of P with the block column of V: P11·V1 + P12·V2 + P13·V3 + P14·V4. This produces the first output "row" of O, but it is not really one row, it is two rows, and we can verify that from the shapes: P11 is a 2×2 matrix, and we multiply it by V1, which is a block of two rows of V, so a 2×128 matrix, and the result is 2×128. So this output block is a block of two rows of the output matrix. I know this is hard to follow, because we have to visualize the matrices both as blocks and as the original matrices at the same time; that is why I highly recommend pausing the video, thinking it through and writing things down, because it is not easy to keep track of the shapes just by memorizing them. Anyway, we are computing the first output block of O, and this output should be the result of the softmax multiplied by V. But the softmax has not been applied to an entire row of the S matrix: we computed softmax* in each block independently of the other blocks, which means the maximum we used in each block is not the global maximum of that row of S but the local maximum of the block. And this is wrong, because when we apply the softmax we should use the global row maximum. Let me give you an example without blocks, otherwise it is not easy to follow. When we do normal attention, Q multiplied by Kᵀ produces a matrix that is N×N,
a sequence-by-sequence matrix: in this example, six by six. The element in position (1,1) is the dot product of the first query with the first key, which I write as q1ᵀ·k1 — because, as I said before, we always treat vectors as column vectors, so to write a dot product we need one row vector times one column vector, which is why the first one is transposed (if that confuses you, you can also just write q1·k1; it is only wrong from a notation point of view). So the first element of the row is q1 with k1, the second is q1 with k2, the third is q1 with k3, and so on up to q1 with k6. When we do the softmax, we calculate the maximum over this entire row. However, what we are actually doing is a block matrix multiplication, and as you remember, with blocks we group rows of queries and rows of keys together: in this case two queries form one query block and two keys form one key block. The second row of the matrix would be q2·k1, q2·k2, q2·k3, q2·k4, q2·k5 and q2·k6. Each block of the block matrix is computing a 2×2 group of elements of this original matrix — these four elements here — and if we apply softmax* to each block, we are not using the maximum of the whole row: we are only using the maximum inside each block. That means when we later use these values in the downstream product with the V matrix, we would be summing values that are wrong, because each value is based on a maximum that is not the global maximum of its row but only the local maximum of its block; this block uses its own local maximum, the next block uses its own, and so on. In other words, when we sum P11·V1 with P12·V2, P11 may have been computed with a local maximum that is different from the local maximum of P12, and P13 may have yet another one. So we need a way to fix the maximum that was used to compute the exponentials in P11 whenever a later block contains a bigger maximum, because the maximum in the softmax should be the maximum of the whole row, not the one belonging to each block. This leads to our next step: how to fix it. First, let me introduce a little pseudocode for computing this output block matrix; later we will use this pseudocode to adjust the error we made in earlier blocks whenever a future block — say P13 — has a better maximum than P11 or P12.
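Before the pseudocode, here is the core trick in isolation: a tiny NumPy sketch (my own illustrative example) showing that a block computed with its local maximum can be reconciled with the true row maximum after the fact, by multiplying it with exp(m_local − m_global):

```python
import numpy as np

row = np.array([1.0, 5.0, 9.0, 2.0])     # one full row of S, split into two blocks of 2
block1, block2 = row[:2], row[2:]

m1, m2 = block1.max(), block2.max()      # local maxima: 5 and 9
m_global = max(m1, m2)                   # the maximum of the whole row: 9

p1_local = np.exp(block1 - m1)           # softmax* of block 1 with its LOCAL max
p1_fixed = p1_local * np.exp(m1 - m_global)   # correction factor applied afterwards

assert np.allclose(p1_fixed, np.exp(block1 - m_global))  # same as using the global max
```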
To compute the output matrix O, for example its first block row, let's first recall what P11 is (let me also delete this; it is not needed anymore). P11 is softmax*(Q1·K1ᵀ), P12 is softmax*(Q1·K2ᵀ), P13 is softmax*(Q1·K3ᵀ) and P14 is softmax*(Q1·K4ᵀ). So to compute the first output block we need the softmax* of the query block Q1 with each of the key blocks K1, K2, K3 and K4, which means we need a for loop over all the key blocks while keeping the query block fixed. To compute the first output block row, we initialize the output with zeros (we need to initialize it somehow), compute P11 = softmax*(Q1·K1ᵀ) and add P11·V1, then add P12·V2 (where P12 uses Q1 and K2), then P13·V3, and so on — that is why we have this inner loop. However, the output we compute this way is wrong, because, as I told you, we computed each softmax* using statistics — the maximum value — belonging to each block, and not the maximum over the whole row of the original matrix. How do we fix that? We actually already have a tool: the algorithm we derived before, called the online softmax (I am not sure I called it by that name earlier, but that is its name), which lets us fix the result of previous iterations while computing the current iteration. Let's quickly review it. Imagine we are working with a single vector of N elements: we do a for loop in which we iteratively compute the maximum up to the i-th element, and we fix the normalization factor computed in previous iterations whenever the current element gives a better maximum. If this is not clear, go back and watch the online softmax part, because it is exactly what we are going to use to fix the P11 and P12 blocks when we find a better maximum in P13 or P14. So let's see how to apply the online softmax to our case. You may be wondering why we are going through all this trouble. The reason we introduced block matrix multiplication is that we want to compute the matrix multiplication in parallel: each P block is independent of the others (each uses the maximum belonging to its own block), so they can be computed independently. But then we need to aggregate their values, and to aggregate them we need to fix the values that were computed independently, because a block computed independently only has a local view, not a global one. That is why we are building this system for fixing values that were computed independently. So how do we fix it? Let's look at the following algorithm. First of all, this output block, as I said before, is a block of two rows, each with 128 dimensions (we verified that by checking the shapes of P11 and V1).
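As a reminder of the online softmax just reviewed, here is a minimal NumPy sketch for a single vector, with a running maximum m and a running normalization factor l that gets corrected whenever a better maximum shows up (variable names are mine):

```python
import numpy as np

x = np.random.randn(16)

m = -np.inf          # running maximum seen so far
l = 0.0              # running normalization factor (sum of exponentials)
for xi in x:
    m_new = max(m, xi)
    # Fix the sum accumulated with the old maximum, then add the current term.
    l = l * np.exp(m - m_new) + np.exp(xi - m_new)
    m = m_new

online = np.exp(x - m) / l            # softmax built from the online statistics
safe = np.exp(x - x.max())
assert np.allclose(online, safe / safe.sum())
```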
This means that for each output block we need to keep track of two maxima and two normalization factors. Up to now I have ignored the normalization factor — we said we apply softmax*, the softmax without normalization — but eventually we will need it, so we want an algorithm that fixes the maximum used to compute each P block and simultaneously accumulates the normalization factor, which we apply at the end. It works as follows. We start by initializing the maximum to minus infinity, one for each row that we are computing (our output block has two rows, so one maximum for the top row and one for the bottom row), the normalization factors to zero (we haven't summed anything yet), and the output block to all zeros (we haven't added anything to it yet). To compute this output block we go through all the key blocks to produce P11, P12, P13 and P14, while the query block stays fixed at Q1. Step one: we compute the row max — the maximum of each row — of the first score block Q1·K1ᵀ (I called it P11 by mistake; it is the S block, which on the slide I label S1, since the query block is fixed and I index only by the key block). Then we compute P11 = softmax*(S1), that is, the exponential of S1 minus the local maximum m1, and we add P11·V1 to the output. For now the output is initialized to zero, so ignore this correction term here; I will explain it in a moment — for now O1 is just P11·V1. At step two, in the local block S2, we may find a better maximum for the top row and the bottom row: we compute m2, the new estimate of the row maximum, which may be better than the previous maximum for each of the two rows, or may not be, so we need a mechanism that fixes the previous work when it is better and changes nothing when it is not. Here is how: we compute the new row maximum m2 for the current block, we compute P12 = softmax*(S2), i.e. exp(S2 − m2), and then we need to add it to the output. But we may have found a better maximum, so how do we fix O1, which only used the maximum local to S1? We know we can fix it using exponentials, because each element of O1 is just an exponential without normalization (we applied softmax*). So how do we fix one exponential with another exponential? We multiply O1 — a 2×128 matrix with a top row o11, o12, …, o1,128 and a bottom row o21, o22, …, o2,128 — by a diagonal matrix built as follows: exp(m1 − m2) is a vector of two elements (the element-wise exponential of a two-element vector is another two-element vector), and we place those elements on the diagonal of a 2×2 matrix; that is what the diag notation means — the diagonal contains the elements of the vector to which we
apply the diag operation, and all the other entries of the matrix are zeros. Concretely, the first diagonal entry is the first element of the vector exp(m1 − m2), the second diagonal entry is the second element of that vector, and the off-diagonal entries are zeros: diag just means take the vector and distribute it along the diagonal of an n×n matrix, where n is the size of the vector. If we multiply diag(exp(m1 − m2)) by O1, the result fixes every element of the top row with the first exponential and every element of the bottom row with the second exponential, which cancels out the m1 that was used in the previous iteration and introduces the m2 computed in the current iteration into every element of this O block. Look at how the product works: the first output element is the dot product of the first row of the diagonal matrix with the first column of O1, so o11 is multiplied by this exponential while o21 is multiplied by zero and does not contribute; o12 is likewise fixed by the first exponential only; and so on, all the dimensions of the first row are fixed by the first exponential, while all the dimensions of the second row are fixed by the second exponential, the second scalar element of the vector exp(m1 − m2). This part was really challenging, so let me summarize: we compute P12·V2, and we fix everything already accumulated in the output by multiplying it by this diagonal correction factor; when we compute step 3 we will fix step 2 in the same way, and so on. Now let's talk about the normalization factor, which we have been ignoring. The normalization factor is something we can compute while computing these maxima, because it is provided by the pseudocode of the online softmax algorithm we saw before: while computing the running maximum we can also compute the running normalization factor, fixing the one from the previous iteration, and that is exactly what we are doing here. At the first iteration we compute the normalization factor using the local maximum (for now you can ignore the correction term, because l0 is zero, so that term vanishes and we are only computing the new summation). When computing l2, the normalization at the second iteration, we fix l1 with an exponential — which, guess what, is exactly the same exponential that fixes the output: the previous estimate of the maximum minus the current estimate — and add the new normalization term computed with the local maximum. If we keep doing this, at the end we obtain a correct output for this block, but still without the normalization. How do we apply the normalization? We need to divide each element of O by the normalization factor, and because we keep accumulating the normalization factor while iterating through this for loop, we can simply wait until the end.
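Putting the last few steps together, here is a minimal NumPy sketch of the inner loop for one query block: running row maxima m, running normalization factors l, the diagonal correction applied to the accumulated output at every step, and the normalization applied only at the end (all names are mine; this mirrors the derivation, not the paper's exact notation, and uses the running maximum for numerical stability):

```python
import numpy as np

B, D, n_kv = 2, 128, 4                         # query-block rows, head dim, number of K/V blocks
Qi = np.random.randn(B, D)
K_blocks = [np.random.randn(B, D) for _ in range(n_kv)]
V_blocks = [np.random.randn(B, D) for _ in range(n_kv)]

m = np.full(B, -np.inf)                        # one running max per row of the block
l = np.zeros(B)                                # one running normalization factor per row
O = np.zeros((B, D))

for Kj, Vj in zip(K_blocks, V_blocks):
    S = Qi @ Kj.T                              # (B x B) block of scores
    m_new = np.maximum(m, S.max(axis=1))       # better row max, if the new block has one
    P = np.exp(S - m_new[:, None])             # softmax* with the current max estimate
    corr = np.exp(m - m_new)                   # correction factor for previous iterations
    l = corr * l + P.sum(axis=1)               # fix the old l, then add the new row sums
    O = np.diag(corr) @ O + P @ Vj             # fix the old O, then add the new P V term
    m = m_new

O = np.diag(1.0 / l) @ O                       # apply the normalization only at the end

K, V = np.vstack(K_blocks), np.vstack(V_blocks)
ref = np.exp(Qi @ K.T - (Qi @ K.T).max(axis=1, keepdims=True))
ref = ref / ref.sum(axis=1, keepdims=True) @ V
assert np.allclose(O, ref)                     # matches the exact softmax(Q K^T) V
```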
At the end of the iteration we apply the normalization factor: we take the last output and divide it by l4, the normalization factor accumulated at the fourth iteration, and that restores the real softmax. All right guys, now that we have derived the algorithm for computing the output of attention block-wise — fixing the softmax that is computed independently in each block, and applying the normalization at the end — I also want to prove that last step. When we introduced the algorithm that computes the softmax in an online way, we proved by induction that it is correct, so at the end of the loop the l of the last iteration really is the normalization factor needed for the softmax. We do not apply the normalization while computing the output iteratively, multiplying the query block by all the key blocks; we apply it at the end of the four iterations, when we have the last output O4 and we know that the last l4 contains exactly the normalization factor we need for each row. Remember that O4 is a block of output rows: the output of attention has the same shape as the input query sequence, so O4 is a block of a sequence of tokens that we need to normalize, and the correct factors are in l4. So let's prove this simple formula, diag(l4)⁻¹·O4. l4 is a vector containing as many elements as there are rows in O4; since we grouped two rows of queries with two columns of keys, the output block O4 contains two rows, so l4 contains two normalization factors — call them l4[1] and l4[2]. O4, as you can see from its shape, is a 2×128 matrix: a first row of 128 dimensions and a second row of 128 dimensions. The first thing we do with l4 is turn it into a diagonal matrix, which will be 2×2 because l4 has two elements: l4[1] and l4[2] on the diagonal, zeros elsewhere. Then we compute the inverse of this matrix. The inverse of a diagonal matrix is simply the diagonal matrix where each element on the diagonal is replaced by its reciprocal — this is standard linear algebra, I am not making it up — so the inverse has 1/l4[1] and 1/l4[2] on the diagonal. Then we multiply this inverse by O4, which is 2×128, so a 2×2 matrix times a 2×128 matrix gives a matrix that
is 2×128. The first element of the first output row is the dot product of the first row of the inverse diagonal matrix with the first column of O4: the first element of O4 gets divided by l4[1], while the element below it is multiplied by zero and does not contribute. The second element of the first output row is the dot product of that same row with the second column, so again we are just dividing the second element of the first input row by l4[1]. In the same way, the second output row: its first element is the dot product of the second row of the inverse with the first column, where the first-row entry is multiplied by zero and does not contribute, so the first element of the second input row gets divided by l4[2]. In short, this operation divides all the elements of the first row by l4[1] and all the elements of the second row by l4[2], which is exactly what we need to do when we normalize: apply the normalization factor row by row. This should help you visualize why this operation normalizes the output vectors at the end while producing the same result. Now let's proceed further.
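Here is the same normalization step written out in NumPy: multiplying by the inverse of the diagonal matrix is just a row-wise division (a tiny sketch with made-up numbers):

```python
import numpy as np

O4 = np.random.randn(2, 128)                      # the accumulated (unnormalized) output block
l4 = np.array([3.0, 5.0])                         # one normalization factor per row

normalized = np.linalg.inv(np.diag(l4)) @ O4      # diag(l4)^{-1} @ O4, as in the formula
assert np.allclose(normalized, O4 / l4[:, None])  # identical to dividing each row by its l
```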
All right guys, finally we are ready to see the flash attention forward pass and compare it with what we have derived so far. If you look at the flash attention paper — what I am showing is actually the flash attention 2 forward pass, and later I will explain the differences between flash attention 1 and flash attention 2 — I didn't want to jump directly to this algorithm, because I believe the derivation, even if it was a little difficult to follow, gave you some intuition about what is happening. Even if you understood 50 percent of it, that's enough, because later we will also code it and you should reach something like 90 percent: every new pass over the material should improve your understanding. So, in flash attention 2 we take as input our query, key and value, which are sequences of tokens, each token being a vector of d (lowercase d) dimensions. We divide the query into blocks — how many depends on the parameter Br, the number of query rows we group together into one block — and we do the same with K and V, dividing them into blocks determined by the parameter Bc. Then we initialize the output O that we want to produce. What is flash attention computing? The softmax of the query multiplied by the transpose of the keys (scaled by a constant factor), multiplied by V. And it computes it this way: first there is an outer loop over the query blocks, which corresponds to the pseudocode we saw before, because we want to compute each block of the output matrix in parallel with respect to the others. We want to compute these output blocks independently: this output block depends on query block one and all the keys, the next one depends on query block two and all the keys, the one after that on query block three and all the keys — where query one is not the first query row but the first block of queries, query two is not the second row of the query matrix but its second block, and so on. That is why we have this outer iteration over all the blocks: we want to compute all the blocks of the output matrix in parallel.

But to compute each output block we need an iteration over all the keys, which is why we have an inner loop on the key blocks, and inside it we do exactly the same operations we have done by hand. First we compute the S block, which is the current block of queries with the corresponding block of keys. Then we compute the maximum local to the current S block and compare it with the maximum of the previous iteration, because that is what we do in the online softmax. Then we compute the P block, which is the softmax* of the S block minus that maximum. Then we compute the normalization factor, which is the sum of the exponentials of the softmax*, while fixing the normalization factor of the previous step — and we know how to fix it: we multiply by the exponential of the previous maximum minus the current maximum, which is exactly what this factor is. Then we compute the output using the same correction factor we saw before: the diagonal matrix whose diagonal contains the elements of the vector exp(previous maximum − current maximum), multiplied by the output of the previous step — because the previous step was based on the previous P, which used the previous maximum — plus the current P·V, which is based on the current maximum and will be fixed by the next iteration if needed. At the end, after we have gone through all the key blocks, we have computed the whole output block but without the normalization factor, which is applied at the end: while going through the keys we have accumulated the normalization factor l for the softmax (inside the loop we only compute softmax*, so nothing is normalized), so at the end this instruction takes the accumulated normalization factor and applies it to O — the difference between softmax* and the actual softmax is just the division by the normalization factor, and this instruction divides each row of the output block by the corresponding factor, one per row. Later we will also see what this SRAM and HBM business is about and why we need to save this extra quantity; for now I just want you to concentrate on the operations, which are exactly the operations we have done so far. You should now have enough knowledge to follow what is written in the flash attention paper for the forward pass algorithm: basically it is just block matrix multiplication, and while computing each block we fix the previous blocks using the tricks of the exponential.
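For reference, extending the single-block sketch above with the outer loop over query blocks gives a complete CPU mock-up of the forward pass just described. It follows the structure of the algorithm but is only an illustration: it omits the scaling factor, the causal mask and the saved statistics from the paper, and all names are my own:

```python
import numpy as np

def attention_forward_blockwise(Q, K, V, Br=2, Bc=2):
    """Blockwise 'flash-style' forward pass on the CPU, for illustration only."""
    N, D = Q.shape
    O = np.zeros((N, D))
    for qs in range(0, N, Br):                     # outer loop: one output block per Q block
        Qi = Q[qs:qs + Br]
        m = np.full(len(Qi), -np.inf)              # running row maxima
        l = np.zeros(len(Qi))                      # running normalization factors
        Oi = np.zeros_like(Qi)
        for ks in range(0, N, Bc):                 # inner loop: iterate over the K/V blocks
            S = Qi @ K[ks:ks + Bc].T
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            corr = np.exp(m - m_new)
            l = corr * l + P.sum(axis=1)
            Oi = corr[:, None] * Oi + P @ V[ks:ks + Bc]
            m = m_new
        O[qs:qs + Br] = Oi / l[:, None]            # normalize only at the very end
    return O

Q, K, V = (np.random.randn(8, 16) for _ in range(3))
S = Q @ K.T
ref = np.exp(S - S.max(axis=1, keepdims=True))
ref = ref / ref.sum(axis=1, keepdims=True) @ V
assert np.allclose(attention_forward_blockwise(Q, K, V), ref)
```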
All right, now that we have seen the forward pass of flash attention, before we can implement it we still need a bit more knowledge, because we don't know anything yet about GPUs, about CUDA or about Triton — and that is what we are going to look at next. All right guys, it's time to explore the GPU and the CUDA programming model. Let's start by comparing the CPU and the GPU, which will help us understand how CUDA works. First of all, what are CUDA and the GPU? The GPU is the hardware unit that we buy, and CUDA is the software stack made by NVIDIA to write software for the GPUs they sell; AMD has its own software stack, and other manufacturers have their own. In this video we will see examples of CUDA kernels, but the knowledge you gain also applies to other GPUs. The first difference between a CPU and a GPU is their purpose. Your computer is running on a CPU right now, and your operating system interfaces with it using the so-called scheduler: you are probably running a browser and some other software, and the scheduler switches between them very fast on the CPU so that it looks to you like the processes are running concurrently. This is actually a fake kind of parallelism, unless your CPU has multiple cores, which nowadays CPUs do: a CPU usually has a few cores — dual core, quad core, eight cores — and each core can execute instructions in parallel. The main purpose of the CPU is to execute many different tasks and switch between them very fast: maybe a browser running a small game, a movie player, a word processor, a utility to download files, and so on. Most of these programs are not compute intensive; they are I/O bound, meaning most of the time they are waiting for the network or the disk, and they are very different from each other in purpose — a browser is completely different from a movie player, which is completely different from a word processor. So the job of the CPU is to reduce the latency of processing all these operations, and each of its execution units (cores) is highly optimized: each core has a part tasked with figuring out the next instruction to run, or predicting which branch will be taken — for example, with an if condition, the branch predictor can guess the most likely next instruction and optimize for it. The CPU also has a lot of caches to reduce the latency of loading data from all the devices it interfaces with: the RAM, of course, but also the disk and peripherals like the printer, the mouse and the keyboard. The GPU, on the other hand, is not built to do many different things at the same time: it is built to do one thing, or a few things, on a massive amount of data. The operations we run on the GPU require a lot of computation, and for this reason most of the physical area of the GPU is
dedicated to compute units — this green stuff you can see here, called cores — while the part dedicated to the control area (the part tasked with deciding the next instruction to run or doing optimizations on the program) is very small. You may be thinking: does that make the GPU slower than the CPU? Not really, because the many cores compensate for these higher latencies. I could give you a lot of theory about the GPU, but I think the best way to understand the CUDA programming model is to jump into the code, so we don't get bored. Imagine we have a very simple task: two vectors a and b, and we want to compute their sum into another vector c, where each item is the element-wise sum of the corresponding items of a and b. How would you do this on the CPU? With a for loop: the index starts at zero, c[0] = a[0] + b[0], then c[1] = a[1] + b[1], and so on over all the elements of the vector. On the GPU we want to do the same operation but in parallel, because we have a lot of compute units (cores) and we want all of them to work at the same time. The first thing to understand is how to divide the work into sub-units and dedicate each core to one sub-unit. A simple subdivision would be: the first core does the first sum, the second core the second sum, the third core the third, and so on — so with an eight-element vector we need eight cores to do this element-wise sum. We will call these cores threads, which should also remind you of multi-threading in operating systems, where multiple threads work concurrently on the same or similar job. Now let's look at the code. What I am going to show you is a CUDA kernel written in C, but you don't have to understand C or this code in detail: what I want you to get is the intuition, because later we will translate this knowledge into Triton, which is Python, and you should already be familiar with Python. So let's go to the code: I have a very simple vector addition here. First of all, how does a vector sum work on a GPU? The GPU is interfaced with a CPU, and the CPU first has to tell the GPU what data it will work with: the CPU holds these vectors and needs to transfer them to the GPU, the GPU does the vector sum, and then the CPU copies the result back from the GPU and makes it available to the program. That is what we do here: we allocate three vectors of size n — one called a, one called b, and one for the output — and initialize the items of a and b randomly, so each a[i] is a random number between 0 and 100, excluded. Then we allocate memory on the GPU to hold these vectors and copy them to the GPU: we copy the a vector and the b vector. Of course we don't copy the result vector's contents, because that is what we want the GPU to populate; we only allocate it on the GPU. We don't copy our output vector
to the GPU, because on the host it just contains meaningless values. Then we launch the kernel. Launching the kernel means launching a program that the GPU should execute in parallel on multiple threads (multiple cores), where each thread does a unit of work that is independent of the others (they can actually depend on each other, but we will not talk about synchronization here). What we are saying in this line is: launch one block of threads — later we will see what blocks are, so you can ignore that part for now — and launch n threads, so n parallel operations, with the following arguments: the output array where we want to save the data, the input arrays a and b, and the number of elements. Let's see what happens inside this function. It follows a CUDA-specific syntax: this __global__ keyword is one of the extra keywords CUDA adds on top of the C language, so it is not really C, it is CUDA C. The function itself is very simple, as you can see. The first thing to understand is that CUDA cannot know what each thread should do: the mapping between the data and the work of each thread is up to us as software engineers. What CUDA does, when we ask it to launch n threads in parallel, is allocate n threads and assign a unique identifier to each of them. In our simple case, with a vector of eight elements, it assigns the first thread index zero (on the slide I wrote one, but that is wrong, so let me fix it): this is thread zero, this is thread one, thread two, three, four, five, six and seven. What we are saying in the kernel is that the item each thread should process is equal to its thread index: thread zero processes the item with index zero, thread one the item with index one, thread two the item with index two — and that is exactly what this line of code does: the item index i is the thread identifier, the thread id (later we will see why there is this .x). The next line is the element-wise sum itself: the output at the i-th position is the a vector at the i-th position plus the b vector at the i-th position. You may also have noticed the if statement. Why do we need it, if we already know we are launching n threads and each thread id will be between zero and n − 1, so i should always be less than n? It is needed because when CUDA launches a number of threads, that number is always rounded up to a multiple of a unit, which is 32 for CUDA (the warp size, by the way). So if our vector has 34 elements and we ask CUDA to launch 34 threads, it will not launch exactly 34: it will launch 64 threads, the next multiple of 32.
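Since we will eventually write kernels in Triton anyway, here is a preview of what the same vector addition looks like there — a minimal sketch of my own, not the code from the video. Triton works at the granularity of blocks of elements rather than individual threads (more on that later in the video), and the mask plays the same role as the if (i < n) guard. It needs a CUDA-capable GPU to run:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(a_ptr, b_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                        # which block of elements this program owns
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n                                 # same role as `if (i < n)` in the CUDA kernel
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, a + b, mask=mask)

a = torch.randn(1000, device="cuda")
b = torch.randn(1000, device="cuda")
out = torch.empty_like(a)
grid = (triton.cdiv(a.numel(), 128),)                  # number of blocks (ceiling division)
add_kernel[grid](a, b, out, a.numel(), BLOCK_SIZE=128)
assert torch.allclose(out, a + b)
```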
So what we need to do is ask only the threads that have a corresponding element to do the work, and ask all the others — the ones without a corresponding element, because the vector is not large enough for all of them — to do nothing, that is, not to enter this if statement. There is one more thing to learn here. In the CUDA programming model (and I believe in other GPUs too), a group of 32 threads is called a warp, and these 32 threads share the same control unit. Let's go back to the slide: this yellow unit here is the control unit, the part of the GPU hardware tasked with determining the next instruction to run. If a group of threads shares the same control unit, it means they always execute the same statement at the same time — they always work in lockstep on the same instruction; it cannot be that one thread is working on one instruction while another is working on a different one. What does this mean at the programming level? When we launch a group of threads (and CUDA will spawn more threads than we need if the number of elements is not a multiple of 32), they all execute the thread-id statement, and each of them gets its own value of the thread id: they execute the same instruction, but the data at each instruction may differ, because each thread has its own registers. So they all reach this statement, and the first thread has i equal to zero, the second has i equal to one, and so on, even though they are executing the same instruction. This programming model is called single instruction, multiple data — CUDA likes to call it single instruction, multiple threads — but it doesn't matter for us: it just means they always execute the same instruction while the values of the variables may differ. After that statement they all reach the if: some of them evaluate the condition to true and some to false, which would mean some should enter the if body and some should not. But because the control unit is the same for all of them, they are all forced to step through it. How does CUDA manage this control divergence? All the threads for which the condition is true enter the if and execute the instructions inside it, while the threads for which the condition is false still step through the same if block — they cannot avoid it, since all threads in the warp must be on the same instruction at any time — but they simply do not perform any operation; they sit idle. This is called control divergence, and it can reduce the throughput of your program, so you want to minimize it. You may be wondering why the GPU doesn't dedicate a control unit to each core so that they can work independently of each other: because control units are expensive in chip area, and it is much more efficient to add more workers
instead of adding a control unit for each worker. This is a design choice of the GPU, and it works fine. Okay, now that we have seen how a kernel works, let's move on to another example. The next example is the same vector addition as before, but imagine the vector is very large — say one million elements. Of course we could do it like before and launch a kernel with one million threads, but CUDA will reject it, because the GPU does not have one million threads to run in parallel. So how do we proceed? Usually we are working with very big matrices or very big vectors, so we need to process a massive amount of data; how do we manage a parallel computation when we do not have enough compute cores? One way is to divide the input vector into blocks of elements. For example, imagine our GPU only has 32 cores in total: we can divide the input vector into blocks of size 32, so the first 32 elements are the first block, the next 32 are the second block, then the third block, and the last 32 elements are the last block. This way we can ask the GPU to work on one block at a time: it processes the first block, then the second, then the third, then the fourth. It also lets the GPU manage the sub-units of work itself: if the blocks have 32 elements but the GPU has 64 cores, it can schedule two blocks at the same time, because it has enough cores. So we increase the granularity of our data to let the GPU decide how many blocks to schedule — that is the reason blocks exist in CUDA. Let me make a concrete example with a very simple assumption: our GPU has only four cores, and we have n = 8 elements. We can divide this vector into groups of two elements: this is block number one, this is block number two, block number three and block number four, and we ask CUDA to launch a kernel made up of four blocks, where each block has two threads. Let me show the code: in the kernel launch, the first value between these symbols tells CUDA how many blocks we have, and the second tells how many threads there are in each block. We want n divided by the block size number of blocks, where the block size in my picture is two, so we get 8 / 2 = 4 blocks, each of size two. That is what we are doing here: the number of blocks is the ceiling (because n may not be a multiple of the block size) of n divided by the block size, and this defines our grid — the grid basically tells how many blocks we have, and each block is made up of block-size threads. The next problem is: how do we assign the work to each of
these threads? When we launch a kernel with this configuration — the number of blocks and the number of threads per block — CUDA does the following: it assigns each block an index called the block id, where the block id of the first block is zero, and inside each block it assigns a thread id, so the first thread of each block is thread zero and the second is thread one. The second block has block id one, with its threads zero and one; the third block has block id two, with its threads zero and one; and so on until the last block, which has block id three. The problem is: based only on the index of the block and the index of the thread, how do we map each thread to the element of the vector it should work with? One simple assignment: this thread here should work with element zero, this one with element one, this one with element two, this one with element three, then four, five, six and seven. How do we find this mapping given only the block id and the thread id? With a very simple formula: the element index — which I call i in the code — is equal to the block id multiplied by the size of each block (the block size) plus the thread id. For the first thread it is 0 × 2 + 0 = 0; for the next one it is 0 × 2 + 1 = 1; for the first thread of the second block it is 1 × 2 + 0 = 2; and so on — you can check that this formula works for every thread. So when we launch a CUDA kernel we tell the GPU how many blocks we want and how many threads there are in each block, but CUDA has no way of knowing how to map each thread to the element it should work with: that is up to us, and it is what we do when we write this kernel, saying that each thread works with the i-th element of the vector, where i is the block id to which the thread belongs, multiplied by the block size (how many threads there are in each block), plus the thread id. Going back to the slides: by choosing a block size of two and having four cores, the GPU can choose to run one block or two blocks concurrently, depending on how many free cores it has. That is why we want to work block by block: it lets the GPU choose how to parallelize the operations if it has enough cores, and we don't need n cores for an n-element vector — we divide it into smaller blocks and let the GPU manage the scheduling. Let's see one last example, and then we move on to Triton.
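To convince yourself that the formula covers every element exactly once, here is a tiny Python simulation of the launch configuration just described (8 elements, block size 2). The names blockIdx and threadIdx mirror the CUDA built-ins, but the loop itself is just an illustration of the index math:

```python
n, BLOCK_SIZE = 8, 2
num_blocks = (n + BLOCK_SIZE - 1) // BLOCK_SIZE        # ceiling division, as in the launch grid

covered = []
for blockIdx in range(num_blocks):                     # the block ids CUDA would assign
    for threadIdx in range(BLOCK_SIZE):                # the thread ids inside each block
        i = blockIdx * BLOCK_SIZE + threadIdx          # the mapping formula from the kernel
        if i < n:                                      # the bounds guard
            covered.append(i)

assert covered == list(range(n))                       # every element handled exactly once
```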
Imagine now we want to do a matrix addition instead of a vector addition. In a matrix addition the data lives along two axes, rows and columns; we usually call the vertical axis the y-axis and the horizontal axis the x-axis. Using the same block intuition as before — dividing the input data into blocks — this is how we can divide the work of the matrix addition: we divide the rows into blocks and call them block zero, block one and block two, and we do the same on the x-axis, so column block zero, block one and block two (x being the column axis and y the row axis). We don't even have to choose the same block size for the rows and the columns: we could group three columns and two rows together instead of two and two. As we said before, when we launch a CUDA kernel, CUDA just assigns ids to the blocks and to the threads inside each block; it is up to us to understand how to map the block id and its thread id to the data element that particular thread should work with. In the case of matrix addition we could say that each thread works with one element of the output matrix C: it computes the sum of the corresponding elements of A and B and writes it into the output matrix C. How do we do that? Imagine we have six rows and six columns. One easy way is to divide the rows into three blocks of two rows each, and the columns into three blocks of two columns each. CUDA will launch as many blocks as there are combinations of row blocks and column blocks: three column blocks times three row blocks, so nine blocks. CUDA identifies the dimensions of the blocks based on the axes along which we divided the data: we call the columns the x dimension and the rows the y dimension, and it launches one block for every combination of x and y — nine in this case. So this is block (0,0), this is block (0,1), this is block (0,2), this one is block (1,0), then (1,1), (1,2), and so on. Inside each block we also divide the threads along the two dimensions: thread zero and thread one along the x-axis, and thread zero and thread one along the y-axis, so each block has two threads per axis, identified as thread zero and thread one on that axis. Now let's look at how the launch grid works in this case. Imagine we have a matrix with num_rows rows and num_cols columns, and we want to divide the rows into blocks of a given row block size and the columns into blocks of a given column block size. The number of blocks we need is defined here: this is just a fancy way of writing the ceiling of num_rows divided by the row block size, and this is just a fancy way of writing the ceiling of num_cols divided by the column block size, and together they tell us how many blocks we have along the rows and along the columns. The grid you see here, which tells CUDA how many blocks we have, is a tuple of three values, which tell
It tells CUDA how many blocks we want on the x dimension, how many on the y dimension and how many on the z dimension; we are not going to use the z dimension because we only have a matrix. Then, inside each block, we say how many threads we want along the x dimension and along the y dimension. Since we chose the columns as the x dimension, we are saying how many blocks we want for the columns and how many for the rows, and then, inside each block, how many threads we want for the column block and how many for the row block. This defines our launch grid, and CUDA will launch exactly this configuration: as many blocks as there are combinations of x and y, and inside each block it assigns thread ids, so there will be two threads along the x-axis and two threads along the y-axis of each block. Now let's try to understand how to map the block id on the x-axis, the block id on the y-axis and the thread ids on the x and y axes to one element of the output matrix. Let's look at the code. Each element of a matrix is identified by two indices, a row index and a column index. The row index is the block id on the y-axis multiplied by the block size plus the thread id on the y-axis, and the column index is the block id on the x-axis multiplied by the block size plus the thread id on the x-axis. Let's see why this makes sense. This thread here works with row zero, because its block id on the y-axis is zero and its thread id is zero, so block id times block size plus thread id gives zero. Which column does it work with? The block id on the x-axis is zero, multiplied by the block size (which is two, but times zero is still zero), plus thread id zero, so column zero. For this other element, the block id on the y-axis times the block size plus the thread id gives row zero, and the same formula on the x-axis gives column one. Let's do one more: for this thread here the block id on the y-axis is one, multiplied by the block size two, which gives row two (and that makes sense, because this is row zero, this is row one and this is row two), and the column is the block id on the x-axis, which is one, multiplied by the block size two, plus thread id one, so three. So this thread works with element (2, 3), and now the formula makes sense: this is how we use the block id and the thread id inside each block to decide which element each thread should work with. As I said before, CUDA has no notion of which element a given thread should work with; that is up to us, based only on the block id and thread id that CUDA assigns. Then we make sure that the row index is less than the number of rows and the column index is less than the number of columns. Why? Because, as I said before, when we launch blocks and threads, CUDA will round the number of threads up to a multiple of 32.
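To make this index arithmetic concrete, here is a small plain-Python sketch of the same mapping (the sizes match the 6-by-6 example above; the function name is just for illustration, it is not code from the video):

```python
import math

# Illustrative sizes, matching the 6x6 example with 2x2 blocks
NUM_ROWS, NUM_COLS = 6, 6
BLOCK_SIZE_Y, BLOCK_SIZE_X = 2, 2  # rows per block, columns per block

# Launch grid: one block per (row-block, column-block) combination
blocks_y = math.ceil(NUM_ROWS / BLOCK_SIZE_Y)  # 3 row blocks
blocks_x = math.ceil(NUM_COLS / BLOCK_SIZE_X)  # 3 column blocks

# Same formula used in the kernel:
# row = blockIdx.y * blockDim.y + threadIdx.y, col = blockIdx.x * blockDim.x + threadIdx.x
def element_for(block_y, block_x, thread_y, thread_x):
    row = block_y * BLOCK_SIZE_Y + thread_y
    col = block_x * BLOCK_SIZE_X + thread_x
    return row, col

assert (blocks_y, blocks_x) == (3, 3)
assert element_for(0, 0, 0, 0) == (0, 0)  # first thread of block (0, 0)
assert element_for(1, 1, 0, 1) == (2, 3)  # the example worked out above
```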
This means that some of those threads have no data to work with, so all the threads whose element does not exist just sit idle outside of this if statement, while the ones that do have an element enter it and do their job. Inside the if, we calculate the index of the matrix element this thread should work with as the row index multiplied by the number of columns plus the column index. This is just another way of writing a[row_index][col_index]: the way we allocate arrays in C or C++ is as a flattened array where all the rows sit one after another, so we need to locate the element inside that array from its row and column indices, and this is the formula that does it. If you have never worked with arrays in C or C++ it doesn't matter, because later we will see tensor layouts and this will become much clearer; if you have, then you already know how to index an element inside a multi-dimensional array. Then we compute the output as usual. I know this has been a lot of information, so what should we remember from it? First, we decide how to divide the work on whatever matrix or vector we are working with: we tell CUDA how many blocks we want and how many threads we want in each block, and based on the block id and the thread id we come up with a strategy to map each thread to its sub-unit of work, that is, to the part of the matrix or vector that particular thread should work with. The next step for us is to understand tensor layouts, because we are going to work with tensors and we need to understand how they are laid out in the memory of the GPU (and of the CPU as well): what the row-major and column-major layouts are, what the stride is, and so on, and then convert all the knowledge we have about CUDA into Triton so that we can code our kernel with Triton. So let's go. All right guys, finally it's time for us to explore tensor layouts. Why do we need to explore them? Because, as we have seen in the CUDA kernel examples, when you give a matrix or a vector to a CUDA kernel, CUDA will not give you the whole matrix like in Python, where you can access each element by its index; it will just give you a pointer to the starting element of that matrix or vector, and it's up to you to calculate the memory address of all the remaining elements. So suppose we have a simple vector in PyTorch: a tensor with only one dimension, of shape 7, which is the number of elements in the first dimension (for now ignore the property called stride, I will explain it later). How will this tensor be saved in the memory of the CPU or the GPU? Suppose the starting address of the first element is 100, and suppose each element is a 16-bit floating-point number, so each element occupies two bytes: the start address of the second element will be 102, and the third element will be at 104.
The fourth element will be at 106, and so on. This is exactly what you get in C when you allocate a vector or a matrix with malloc: the memory allocator will just allocate enough memory to store all the elements and give you a pointer to the start address of that memory, and it's up to you to figure out where each element is stored inside that block. To do this we introduce a property called the stride. The stride tells us how many elements we need to skip to arrive at the next element along a particular dimension. In the case of a vector we only have one dimension (you can think of it as the x dimension, or the columns dimension: this is the first column, this is the second, the third, the fourth, and so on), so to go from one element to the next we just skip one element: we increase our pointer by one element each time, and this lets us do a for loop over the tensor. Let's look at a more complicated case, a matrix. A matrix is two-dimensional; suppose we have the following matrix made up of six elements, with two rows and three columns, so the shape of this tensor is (2, 3). How will it be saved in memory? It will just be a flattened matrix. This is called the row-major layout (there is also a column-major layout that we will not discuss): the elements of the first row, followed immediately by the elements of the second row. Imagine the memory address of the first element is 62: to go to the next element we increase the address by the number of bytes each element occupies, which is two, so the second element is at 64, the third at 66, and the next row starts immediately after the end of the first row. Now let's introduce the stride property: the stride tells us how many elements we need to skip, in each dimension, to arrive at the next element of that dimension. For example, imagine we want all the elements of the first row. Let's call this tensor t, so t[0]: this indexing says select only the first row and give me all of its elements. How does it work? Starting from the pointer to the first element, it selects the first row and then moves the index one element at a time, selecting the first, the second, the third. How does it know it needs to move one element at a time? Because in this dimension the stride is one: the stride tells us how many elements to skip to reach the next element in that dimension. Now imagine we want t[1], that is, all the elements of the row with index one: starting from the pointer to the first element, it first needs to skip some elements along the first dimension, because we are not selecting row zero, we only want the element with index one of the first dimension.
first dimension which basically means the row with index one so because it will start from the first pointer to the first element it will it needs to know how many elements to skip and how many element to skip is given by the stride so the stride tells us how many elements you need to skip to arrive to the next element of the first dimension so in this case it will take the pointer to the first element skip three elements and it will be starting with the second row and then inside this row it will go through the second in the the index of the second dimension in which the stride is one so it will just go one after another and it will return only this part of the memory so to rehearse the stride is just a a number that tells us how many elements you need to skip in each dimension to arrive to the next index in that dimension so it means that to go from one row to the other we need to skip three elements to go from one column to the other we need to skip one element why is the stride useful well the stride is useful because it allows us to reshape tensors very easily and without doing any computation let's see okay imagine we want to reshape a matrix imagine initially the shape of this matrix is a two by three so we have a two row by three columns and we have a stride calculated as follows means that to go from one row to the other you need to skip three elements and to go from one column one row to the other you need to skip three elements and to go from one column to the next you need to skip one element so you need to jump by one element if we want to reshape it into this shape so three by two basically we want to have three rows and two columns we can reshape it without actually changing its memory layout just by changing the stride because look at this physical configuration of the tensor and we can access this same tensor as this shape or as this shape exactly by using the same physical view because to go from one row to the next here the stride is a three so we need to skip three elements it means that the starting address the starting element of the second row is given by the start pointer plus three elements so exactly here the second row will start and each element of the second row is one after another because the stride of the second dimension is one so you can see that to get the second row we can just start from here and then go one after another and get all these elements which is exactly the second row suppose we want to obtain the second row of this view here of this shape of this reshaped matrix how to do that let's look at the stride the stride now is a two in the row it means that to go from one row to the next we need to skip two elements so if we want to select this row here we go from the starting point of the memory so this start pointer we skip the first two elements because the stride says that to go from one row to the next you need to skip two elements so we arrive here and then we select exactly two elements which are one after another because the stride in the second dimension is one so the stride allow us to reshape the tensor without changing the physical layout on how it is stored in the memory moreover the stride also allow us to get the transpose of a matrix without changing the shape of how it is stored in the memory so without changing the arrangement of the elements in the memory and this is very cool because we can view the same matrix as without the transpose and also the transpose version of the matrix without changing anything in the memory so it 
It comes for free, just by working with the indices and the stride. To transpose a matrix along two dimensions, we just need to swap the strides of the two dimensions we want to transpose. For example, imagine we want the transpose of this matrix: we just swap the strides. If we want the second row of the transposed matrix, how do we get it? We always have the pointer to the first element, to where the tensor starts in memory, and the stride says that to go from one row to the next we need to skip one element, which is correct, because as you can see the second row of the transpose starts at what is the second element in memory; so we skip one element and get the starting point of the second row. Then, to go from one element to the next within the same row, we skip three elements, so the second element of the second row is three elements after its first element: we skip this one, we skip this one, and we arrive at this 8, which is exactly the element in the second column of the second row of the transpose. So the stride lets us do two things: it lets us reshape the tensor without having to reallocate it in a different configuration in memory, and it lets us transpose a matrix without having to rearrange its elements in memory. This is great, because moving memory around is expensive, so it's great that this stuff comes for free. One more thing: you know that in PyTorch there are two methods to reshape a tensor, one called reshape and one called view. After transposing a matrix by swapping the strides of two dimensions, you cannot reshape the tensor for free anymore. Why? Because of how the stride is normally computed: the stride of a dimension is the product of the shapes of all the following dimensions, so the stride of dimension zero is the product of the shape entries starting from index one. It's not easy to see with a 2D matrix because we don't have enough elements, so let's do it with a 3D tensor: a tensor of shape (2, 4, 3), meaning two matrices, each made up of four rows and three columns. Its stride is computed as follows: the stride of dimension zero is the product four times three, so twelve; the stride of dimension one is three; and the stride of the last dimension is one, because there is no further dimension after it. When we transpose by swapping strides, this property is lost: the stride is no longer the product of the trailing shapes, and the tensor is no longer contiguous.
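Here is a small PyTorch check of this contiguity point, using the same (2, 4, 3) example (a sketch, not code from the video):

```python
import torch

x = torch.arange(24).reshape(2, 4, 3)
print(x.stride())          # (12, 3, 1): each stride is the product of the shapes that follow it

xt = x.transpose(1, 2)     # swap the strides of dims 1 and 2, no data movement
print(xt.stride())         # (12, 1, 3): no longer the product of the trailing shapes
print(xt.is_contiguous())  # False

# view() needs a contiguous tensor, so this would raise a RuntimeError:
# xt.view(2, 12)
xt.contiguous().view(2, 12)  # works, but pays the cost of rearranging the memory
xt.reshape(2, 12)            # reshape() does the copy for you when it has to
```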
So after transposing a tensor this way you cannot do further free reshaping: PyTorch cannot view a tensor after it has been transposed, because transposing just swaps two strides and the stride is no longer the product of the trailing shapes, so you need to actually reallocate the tensor (make it contiguous) if you want to reshape it after it has been transposed. This is a fairly advanced property; it doesn't matter if you remember it, it's just a curiosity. Anyway, what is the stride used for? Three things: first, it is used to index the tensor, so that just from the pointer to its starting address we can access any row and any column we like; second, it allows us to reshape the tensor for free, without rearranging the elements in memory; and third, it allows us to transpose the tensor however we like, just by swapping the strides of the two dimensions we want to transpose. Now that we have also seen how a tensor is stored in memory, we can finally look at Triton and some examples. All right guys, now that we have seen how tensor layouts work and how CUDA works, we can look at some examples of Triton kernels to see how Triton differs from CUDA. If you go to the Triton website you will find some tutorials, so let's work through one together. First of all, the code I will be writing for FlashAttention is based on the fused attention tutorial you can find there, but with some modifications, because I simplified the code a lot: I removed the fp8 implementation; the backward pass of the fused attention tutorial only works for causal attention, while my code will work for both causal and non-causal attention; and instead of using exp2 (two to the power of x), which they use because it maps to a faster hardware unit, I use the original FlashAttention formulation with the base-e exponential. I simplified my code as much as possible to make it easy to follow rather than optimized, so it will certainly be slower than the fused attention tutorial, but it should be more comprehensible. Anyway, let's go to the vector addition tutorial, which shows how to do a vector addition with Triton; this should get you into the mindset of how to write Triton kernels. Instead of writing the kernel first and then calling it, let's do the opposite: let's see how the kernel is called and then explore how it works. I have already copied the vector addition tutorial from the website, so let's look first at what we want to achieve: we have an input vector x and an input vector y, and we want to compute their sum, once with torch and once with Triton by calling this add method, and then compare the two outputs; they should be equal, or at least their difference should be very small, because there is always some rounding error when working with floating-point numbers. The size of this vector is about 98,000 elements, and we want to work in a blocked way: as you remember, with CUDA you can do a vector addition by spawning a large number of threads.
Each thread does one operation, but when the number of threads you have is not enough, you divide the input vector into blocks, and that's what we do here. So let's look at this add method. It first allocates the memory for the output vector, then it computes the launch grid. The launch grid tells Triton, just like in CUDA, how many blocks of threads we want to launch. If you remember, in the CUDA kernel we specified how many blocks we want and how many threads we want per block; in Triton we only say how many blocks we want, and we don't force how many threads to launch, Triton will choose that; we just say what each group of threads should do. In this case we divide our n elements (the 98,000) into blocks of size BLOCK_SIZE, which is initialized to 1024, so the grid size is the ceiling division, ceil(n_elements / BLOCK_SIZE). That is how many blocks we want; what each block should do is inside the kernel. When we launch the kernel, we specify the launch grid in the square brackets and the arguments of the kernel in the round parentheses. Now let's go to the kernel. You can see that Triton will not give us access to the tensor x; it gives us a pointer to its first element, and this takes us back to tensor layouts: the reason we studied layouts and strides is that this add kernel runs on the GPU, and the GPU does not index tensors like PyTorch, with all the dimensions, broadcasting and other fancy stuff; it just gives you the pointer to the first element of the tensor in memory, and it's up to you to compute the indices of all the elements you want to access. So x_ptr is the pointer to the first element of the x vector, y_ptr is the pointer to the first element of the y vector, then we have the pointer to the output vector where we store the result of this vector addition, we specify how many elements our vectors have, and the BLOCK_SIZE, that is, how many items each block should process, which may not correspond to how many threads each block will have. You may be confused: in CUDA we specified how many threads each block should have, so the granularity we manage is the thread; here we are saying that a group of threads should work with this quantity of data, and it's up to Triton to optimize the number of threads it actually uses (there are ways to influence it, by specifying the number of warps, but we will see that later). For now, just remember that this kernel processes BLOCK_SIZE elements of the input vectors. First of all we need to identify which block we are. In CUDA we used the blockIdx.x variable to get the identifier of the block, which tells us which group of elements we should work with; in Triton we do the same with the program id, and where CUDA has the x, y and z axes, Triton calls them dimensions 0, 1 and 2.
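Following the structure of the Triton vector-addition tutorial, the host-side wrapper we just walked through looks roughly like this (the kernel itself is sketched a bit further down):

```python
import torch
import triton

def add(x: torch.Tensor, y: torch.Tensor):
    # Pre-allocate the output vector on the GPU
    output = torch.empty_like(x)
    n_elements = output.numel()
    # Launch grid: one program per block of BLOCK_SIZE elements,
    # i.e. ceil(n_elements / BLOCK_SIZE) programs along axis 0
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # Launch grid in square brackets, kernel arguments in round parentheses
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
```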
Here we have one-dimensional data, so we only use one axis to get the block index, which is the pid; in Triton this is called the program id, and it's more intuitive to think of it as a program: this is a kind of program running in parallel with other programs that have different program ids. Based on the program id we can work out the starting element this program should work with: the block start is just pid multiplied by BLOCK_SIZE, so program 0 starts from element 0, program 1 starts from element 1024, and program 2 starts from element 2048, meaning it skips the first 2048 elements. Next we define which elements to load from the x and y vectors, starting from those pointers. Because each program in Triton works with a block of elements rather than a single element, we need to say which elements to load, so we build a list of offsets relative to the starting address: for program 0 the offsets are the block start (0) plus the indices from 0 to 1024 excluded; for program 1 this results in the vector 1024, 1025, 1026, 1027 and so on up to 2047; for program 2 the offsets are 2048, 2049 and so on up to 3071. Also, as you remember, when we launch a grid the number of threads is not based exactly on the number of elements in the block or in the vector, it is always a multiple of a base number, usually 32, which means this program may have more threads than it needs, and some threads should not load any data or compute any sum. That is why we need the mask: all the offsets we load must be below the number of elements. Imagine the vector has 2060 elements instead: the offsets of the third program go from 2048, 2049, up to 2059, but also 2060, 2061, 2062 and so on, and we said we only have 2060 elements, so the elements from index 2060 onwards don't exist, and we need to tell the threads working with those offsets not to load anything. The mask says: among all the offsets this block works with, load only those that actually exist, those for which the mask is true. Then we load the elements of the current program, a group of elements defined by these offsets, only the ones for which the mask is true; all the others are ignored (we could also specify what to load where the mask is false, with another parameter, but we won't use it here). We load the corresponding group of elements of the y vector, and we compute the output, x plus y. If you remember, in CUDA we did something like output[i] = x[i] + y[i], one element at a time, because each thread worked with one index.
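For reference, the kernel described here, again following the tutorial, looks roughly like this (the store on the last line is what we discuss next):

```python
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Which block of elements this program is responsible for
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    # Offsets of the elements of this block, e.g. 2048..3071 for pid == 2
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out the offsets that fall beyond the end of the vector
    mask = offsets < n_elements
    # Load a whole block of x and y from global memory, add them, store the result
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)
```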
Here, instead, we work with a group of elements: this x is a block of elements, at most of size BLOCK_SIZE (actually of size BLOCK_SIZE), this y is a block of elements from the y vector, and we compute the output group by group: we sum a block of x with the corresponding block of y and write the result into output. Then we need to store this output into the output tensor, through output_ptr, which is the pointer to the first element of the output vector. Where should we store this block, whose shape is BLOCK_SIZE? At the same offsets from which we loaded x: if this program worked with indices 2048, 2049 and so on, then the output should be written at those same offsets, using the mask as well, because we may not have enough elements and we only want to write the ones that are actually present in the vector. So, to repeat, the reason we need the mask is that CUDA launches a number of threads that is always a multiple of a base unit, which may not be a multiple of the size of the vector we are working with, so we need a way to tell some threads to do nothing when their data is not there. Let's rehearse what we have seen so far: in CUDA, the program we write is at the thread level, it says what each single thread should do; in Triton, we work with a block of threads and we say what data that block of threads should work with. All right guys, finally the moment has come: we are going to code the FlashAttention forward pass right now in Triton, but first let's rehearse the algorithm. The goal of the attention mechanism, and of FlashAttention specifically, is to compute the output of the following formula: O = softmax(Q · Kᵀ / sqrt(d_head)) · V, that is, the query multiplied by the transpose of the key, divided by the square root of the head dimension, then the softmax, and then everything multiplied by V. In this video we will code the forward pass and also the backward pass, but before coding the backward pass we need to understand how autograd works, what a gradient is, what a Jacobian is, how to derive the gradient of the softmax and of the matrix multiplication operation, and so on; that will be another part of the video. For now let's concentrate on the forward pass. Right now we have some tools: we know there is this thing called the GPU that can parallelize operations among multiple cores; we know that in CUDA we can parallelize by writing a program that defines what each thread should do, or we can follow the Triton programming model, where we write in Python what each group of threads should do; and the mapping between the identity of a thread (or of a group of threads) and the elements it should work on is up to us, the programmers. The same will happen when we write FlashAttention. So let's see what we can parallelize in FlashAttention.
First of all, the code you see in the FlashAttention forward pass takes as input query, key and value, which are matrices of shape N by d. However, in a transformer network we usually don't have only one sequence of d-dimensional tokens, we have many sequences, and this d is the lowercase d, the number of dimensions dedicated to each head; and we don't have only one head, we have multiple heads. So the algorithm you see here is what each head of each batch element should do. Moreover, we have seen, when talking about block matrix multiplication, that we can parallelize the computation of the output: this output block here depends on query block one and all the keys, this one depends on query block two and all the keys, this one on query block three and all the keys, and so on. Because this block only depends on query block one and that block only depends on query block two, they can be computed independently of each other, sharing, of course, the keys. Another thing we need to understand for Triton is the shared memory. In the GPU we have the high-bandwidth memory, which is kind of the RAM of the GPU: when you buy an A100 and they tell you it has 40 gigabytes, that's the amount of high-bandwidth memory, the DRAM. Looking at the structure of the GPU, we have this DRAM, which is the big memory of the GPU, and then each streaming multiprocessor (let's call it a block of threads) also has a shared memory, which is much, much smaller than the DRAM. What changes between these two memories is that access to the DRAM is very slow, while access to the shared memory is very, very fast. One thing that differs between CUDA and Triton is that in the CUDA examples, whenever we loaded some data we were loading it directly from the global memory: when we launch a CUDA kernel, as you remember from my C++ code, we first copy the tensors (or vectors) from the CPU to the GPU, where they live in the GPU's global memory, and then we load elements directly from that global memory, which is much slower. What happens with attention is that the naive computation, the one we can do with torch, is slow partly because the access to global memory is slow, so we want to use the shared memory as much as possible: we want to reuse the elements loaded from global memory into shared memory, so we don't have to go back to global memory every time we need elements of these vectors and matrices. This is what happens in Triton: whenever you load some data, you are copying it from global memory to shared memory; whatever operations you do are done on shared memory; and when you store, you copy the data from shared memory back to global memory. This makes everything much faster, because we always work with elements that have already been loaded into shared memory, and this shared memory is shared by all the threads that belong to the same block.
In Triton we have an abstraction level that doesn't make us work directly with single threads: we always work with a group of threads belonging to the same block, which share this shared memory. So in Triton we copy information from global memory to shared memory, do some operations with it, and then store it back to global memory, and this is what we are going to do with FlashAttention. Now let's review the FlashAttention algorithm. In FlashAttention we have an outer for loop over all the query blocks and an inner loop over all the key blocks. In the original algorithm, FlashAttention 1, the outer loop was over the keys and the inner loop over the queries, which made it less parallelizable; with the outer loop over the queries, and since, as we saw before, the output can be computed independently for each block of queries, it is much easier to parallelize. In fact we don't actually run this outer for loop: we just spawn many kernels, each one handling one iteration of the outer loop, that is, a different query block. The inner for loop is something we do have to iterate through, so each Triton kernel works with one query block and iterates through all the key blocks, and inside each key block we do the operations we explored before; at the end of this for loop we store the output block back into the high-bandwidth memory. Another thing to notice is that the query, key and value here are N by d per head, but, as I said, in a transformer model we don't have only one sequence, we have many sequences, so we can also parallelize over the number of sequences in the batch, because each batch element can be processed independently; and each sequence has multiple heads, and each head can also work independently. We know from the Attention Is All You Need paper that this is exactly the meaning of multi-head attention: each head computes its attention independently. So we will also parallelize along the head dimension. Moreover, looking at the definition of the query block, we can split the queries into blocks, and each query block can work independently of the other query blocks, each producing one output block. So this is how we are going to parallelize: we parallelize over the sequences in the batch, inside each sequence over the heads, and inside each head over the query blocks. How many programs will be working in parallel, at most? The number of sequences in the batch, so the batch size, multiplied by the number of heads, multiplied by the number of blocks we divide the query sequence into (let's call the block size BLOCK_SIZE_Q). All right, now that we have seen this, let's actually go code it. I have already introduced a bit the differences between my implementation of FlashAttention and the one you can find in the Triton documentation: first of all I don't work with fp8, even though recent GPUs support it and it is of course much faster, because I believe it is unnecessary for our explanation.
The second difference is that in the FlashAttention on the Triton website the backward pass is only implemented for causal attention, while I implement it for both causal and non-causal attention, even if it's slower (and later I will actually give you an exercise on how to improve it). The third main difference is that I make explicit use of the softmax scale, so I apply the scale where it is needed. Another difference is that in the online Triton implementation of FlashAttention the exponential is not really e to the power of x but 2 to the power of x, compensated by using a logarithm, probably because exp2 is implemented by a faster unit; in my case I keep the original base-e exponential, because I want to follow the original algorithm and make it easier to match the code with the FlashAttention paper. So, I know I have created a lot of hype: let's do it. Let's start by creating a new file, let's call it program.py. Just like before, when I introduced Triton, I will start by writing the code that uses our kernel, and then we will write the kernel, and we will only be coding the forward pass of the kernel for now. Let's start by importing what we need, which is just torch and triton (and let me check that Copilot is off, so I don't have to worry about that), and then let's write the code that tests our Triton implementation against the naive implementation of the attention mechanism. We create our query, key and value sequences for testing. As you remember, the query has shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM): the batch dimension because we have multiple sequences, each sequence has a number of heads, it is made up of SEQ_LEN tokens, and each token is identified by HEAD_DIM dimensions; this is because each token has already been split among the heads, and if you removed the NUM_HEADS dimension and concatenated all the head dimensions back together you would get the full embedding of each token. We initialize the query, key and value sequences using a normal distribution (this code I took from the Triton tutorial, so nothing new), and we set requires_grad because we want to compute the gradients with respect to query, key and value; we will see later why, but it is because we also want to test the backward pass, even though we won't code it right now. The first thing we do is define our softmax scale: as you remember, the formula is the query multiplied by the transpose of the keys, divided by the square root of the head dimension (d_k, or d_head as it is sometimes called), so the scale is one over the square root of the head dimension, and we can compute it right away. Then we also define dO; later we will see what it is, it will be needed for the backward pass, so don't worry about it for now. Next, the naive implementation of attention, which is very simple: first we define the mask, and we use it only if the attention we are computing is causal (as you can see, we pass a parameter called causal that says whether we want causal or non-causal attention), along with the dtype.
The dtype is float16, because we want to work directly with 16-bit floating-point numbers; we won't work with fp8, simply because my implementation is not meant to be as fast as the one in the Triton tutorial, but I believe it is much easier to comprehend. So we define the mask, we compute the product of the query with the transpose of the key divided by the square root of the head dimension (that's why we multiply by the softmax scale), and if the attention is causal we use the mask we computed, replacing all the dot products where the mask is zero with minus infinity; the softmax will then turn those minus infinities into zeros, because we apply the softmax next, and the softmax is applied by rows, just like in normal attention. The reference output is then the product of the softmax output with V: this is the naive implementation of the attention mechanism. Then we also derive the gradients of the output with respect to the inputs, in this case V, K and Q; later we will see exactly what we are doing here. Next we want to compare this reference implementation with our Triton implementation, so let's call it: our Triton implementation will be a class called TritonAttention that we invoke through a method called apply (later we will see what this method is), passing the query, key and value, whether we want causal attention, and the softmax scale it should use; it should produce the output of the softmax multiplied by V. Then we can also run the backward pass on it, the same backward that we will eventually implement in TritonAttention, and finally we compare the result of TritonAttention.apply with the reference implementation using allclose, which compares the elements of two tensors and checks that their absolute difference is no more than a given tolerance; we are not using a relative tolerance, just the absolute distance between corresponding elements of the two tensors. The implementation we will build works with causal and non-causal attention, both forward and backward, while the one on the Triton website has a forward that works for both but a backward that only works for causal attention; it is, however, highly optimized, so if you want to learn more tricks on how to optimize Triton kernels, there is a lot of knowledge there. Anyway guys, now let's implement this TritonAttention class, at least the forward pass.
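Putting together the test code just described, it might look roughly like this (a sketch under my own naming assumptions: TritonAttention is the class we write next, and the tolerance value is illustrative):

```python
import torch

def test_op(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, causal, dtype=torch.float16):
    shape = (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
    # Q, K, V are already the outputs of the W_Q, W_K, W_V projections
    Q = torch.empty(shape, dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()
    K = torch.empty(shape, dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()
    V = torch.empty(shape, dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()
    softmax_scale = 1.0 / (HEAD_DIM ** 0.5)
    dO = torch.randn_like(Q)  # upstream gradient, used later to test the backward pass

    # Naive reference: softmax(Q K^T / sqrt(HEAD_DIM)) V, with an optional causal mask
    MASK = torch.tril(torch.ones((SEQ_LEN, SEQ_LEN), device="cuda"))
    P = torch.matmul(Q, K.transpose(2, 3)) * softmax_scale
    if causal:
        P[:, :, MASK == 0] = float("-inf")          # the softmax turns these -inf into zeros
    P = torch.softmax(P.float(), dim=-1).to(dtype)  # the softmax is applied row by row
    ref_O = torch.matmul(P, V)
    ref_O.backward(dO)                              # reference gradients for Q, K, V

    # Our implementation, called through the autograd.Function we are about to write
    tri_O = TritonAttention.apply(Q, K, V, causal, softmax_scale)

    # Compare with an absolute tolerance: fp16 rounding makes the results slightly different
    assert torch.allclose(ref_O, tri_O, atol=1e-2, rtol=0.0)
```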
Okay, so: every time you want to introduce a new operation into torch, you implement it by deriving from the torch.autograd.Function class. Every operation in torch, whether it's the softmax, the ReLU, SwiGLU or whatever, is implemented as a class that derives from this Function class, and it should provide two methods, one called forward and one called backward: the forward should produce the output of the operation, and the backward should compute the gradient of the loss with respect to the inputs of that operation; later we will see how that works. For now let's concentrate on the forward pass. To implement it we create a static method called forward which takes as its first argument something called the context. As you know, when training neural networks we have a forward pass and a backward pass, and during the backward pass we need to reuse the activations of the computation nodes from the forward pass; this context lets us save the information, the necessary activations, that we will need during the backward pass, and later we will see, in the FlashAttention algorithm, exactly what we need to save. For example, during the backward pass we will recompute on the fly the query multiplied by the transpose of the keys for each block, but we don't want to recompute the normalization factor or the maximum value of each row, so we save them; and actually we won't save two values, we will save a single one, using a trick called logsumexp that we will see later. Anyway, this context is just a kind of storage area where we can save whatever will be necessary to compute the backward pass. Then we have the inputs of the operation: the query, key and value, which are three tensors, the causal flag saying whether we compute causal attention, and the softmax scale to apply, one over the square root of the head dimension (which, by the way, we could also compute on the fly from the shape, but it doesn't matter). The first thing we do is extract the shapes of these objects and make sure they are what we expect: the shape of query, key and value is batch size by number of heads by sequence length by head dimension, and we make sure the head dimension matches for query, key and value, which it should, because each vector must be of the same size. Then we pre-allocate the output tensor where we will save our output. As you remember, the output of the attention mechanism has the same shape as the query sequence, where the query, key and value here, I want to remind you, are not the raw inputs of the attention layer but already the outputs of W_Q, W_K and W_V, because FlashAttention is not concerned with optimizing those matrix multiplications, only with what comes after them. Actually, it is not quite true that the output has the same shape as the query, key and value: it has the same shape as the query, but not necessarily as the key and value. Why? Because of cross attention, where the query comes from one sequence and the key and value come from another sequence, each passing through its own projection, and the two sequences may not have the same length. So the shape of the attention output depends only on the shape of the query sequence, not on the key and value sequence.
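The scaffolding just described might look roughly like this (a sketch; the kernel launch and the M tensor are added in the following steps):

```python
import torch

class TritonAttention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, Q, K, V, causal, softmax_scale):
        # Shapes are (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM); HEAD_DIM must match
        HEAD_DIM_Q, HEAD_DIM_K, HEAD_DIM_V = Q.shape[-1], K.shape[-1], V.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        # The output has the same shape as the query sequence
        O = torch.empty_like(Q)
        # ... here we will launch the Triton kernel, then stash what the backward needs
        ctx.save_for_backward(Q, K, V, O)  # plus the logsumexp tensor M, added later
        ctx.causal = causal
        ctx.softmax_scale = softmax_scale
        return O

    @staticmethod
    def backward(ctx, dO):
        # Implemented only after we derive the gradients later in the video
        raise NotImplementedError
```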
This happens during cross attention; in language models we usually work with self-attention, so it should not come up, at least in causal language models. Then we have the stage; later we will see what it is, but basically the stage is just a number that says whether the operation we are going to do is for causal or non-causal attention. Then we need to define our launch grid. The launch grid tells us how many parallel programs should be launched by Triton (well, they will actually be launched by CUDA, but we always work with Triton as our interface to CUDA). As I said before, we want to parallelize along the batch dimension, because each sequence in the batch can be processed independently, and within each sequence each head can be processed independently, so we have at least batch size multiplied by number of heads programs. For each of those we have another dimension: we divide the queries into blocks of queries. As you remember from block matrix multiplication, we don't work with the original query matrix where each row is one token; we work with groups of queries, so each block of queries is a group of tokens of the query sequence. So we are launching our kernels (blocks of threads, groups of threads) along two dimensions, just like a CUDA kernel can be launched along the x and y dimensions: one dimension tells us which head of which batch element we are going to work with, and the other tells us which group of queries of that sequence we are going to work with. How many groups of queries are there? The sequence length divided by the number of queries we group together: BLOCK_SIZE_Q tells us how many queries there are in each block of queries, and this cdiv is just the ceiling division, so it is equal to the ceiling of the sequence length divided by BLOCK_SIZE_Q, which is the number of blocks of Q we have. So let's rehearse: we have a tensor Q of shape batch size by number of heads by sequence length by head dimension, and each FlashAttention program works with one (sequence length, head dimension) slice. We have also seen that FlashAttention has two loops, the outer loop over the query blocks and the inner loop over the key blocks, and that the query blocks can be processed independently, so we can spawn as many programs in parallel as there are blocks of Q. This grid tells us how many programs there are that can work in parallel; it is then the GPU that, based on its resources, decides how many of them actually run in parallel. If it has enough resources to run them all at once, wonderful; if not, it will run them sequentially, one after another. The last dimension is like the z dimension in the CUDA launch grid, and we don't use it because we don't need an additional level of parallelism.
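In code, the launch grid just described could be written like this (a sketch; BATCH_SIZE, NUM_HEADS and SEQ_LEN are assumed to come from the shape of Q, and the block size value is only illustrative):

```python
import triton

# Inside TritonAttention.forward: how many programs can run in parallel (at most)
BLOCK_SIZE_Q = 64  # queries per block; illustrative value, later chosen by autotuning
grid = (
    triton.cdiv(SEQ_LEN, BLOCK_SIZE_Q),  # axis 0: which block of queries
    BATCH_SIZE * NUM_HEADS,              # axis 1: which head of which sequence in the batch
    1,                                   # axis 2: unused, no extra level of parallelism
)
```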
All right, this is our launch grid: we will launch a number of parallel programs, parallel kernels (where each kernel in Triton is a group of threads), equal to the batch size multiplied by the number of heads multiplied by the number of blocks we divided the Q sequence into. Let's continue. Then there is this M: it is another tensor we will need, the logsumexp for the backward pass, and we will see at the end of the forward pass what it is needed for, but basically you can think of it as the maximum of each row, plus a bit more. To recompute the query multiplied by the keys in the backward pass without recomputing the maximum of each row and the normalization factor of the softmax, we would have to save two things, the row maximum and the normalization factor; however, thanks to the logsumexp trick, we only need to save one value, which in the FlashAttention algorithm is this L_i here: the maximum of each row plus the logarithm of the normalization factor, L_i = m_i + log(l_i). When computing the backward pass we need to recompute this block, query multiplied by the transpose of the keys, on the fly, and to apply the softmax we need the maximum of each row and the normalization factor; we don't recompute them during the backward because we already computed them during the forward, so we save this information, and we don't need to save the two values separately, we aggregate them into a single value L_i, and later we will see how to use it. All right, we have defined this as well, so we can proceed. Now we launch our kernel; don't be scared, it's going to be a little long. We launch the kernel for the forward pass by specifying the launch grid, so how many of these programs should run (at most) in parallel, and we pass the query, the key, the values, the softmax scale, the M tensor (the information we need to save for the backward pass; in the pseudocode of the FlashAttention algorithm it is called L, but here I call it M, I think because it was called M in the original code), and O, where our kernel should save its output. And, as you remember, we don't get the nice tensor indexing we are used to in torch: we only get a pointer to the starting element of Q, a pointer to the starting element of K and one to the starting element of V, and then we have to work out the memory locations of all the other elements. To calculate those indices we need the strides, because the stride tells us how many elements to skip to move along each dimension, and that's why we pass the stride of every dimension of every tensor. Actually, in our case Q, K and V have the same dtype and the same shape, so we shouldn't really need to pass all the strides for each of these tensors, since they should be identical; but the original code passed them, so I kept it.
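As a quick aside on what this M tensor will store, here is a tiny sanity check of the logsumexp idea on a single row of scores (just an illustration, not part of the kernel):

```python
import torch

S = torch.randn(8)           # one row of (Q K^T) * softmax_scale
m = S.max()                  # row maximum (the "safe softmax" shift)
l = torch.exp(S - m).sum()   # normalization factor of the safe softmax
M = m + torch.log(l)         # the single value we store per row (the logsumexp)

# In the backward pass the softmax row can be rebuilt from S and M alone:
P_row = torch.exp(S - M)
assert torch.allclose(P_row, torch.softmax(S, dim=0))
```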
The strides will let us index the tensors: starting from the pointer to the first element, we will be able to reach any element we want. Then we pass the information about the shapes, so the batch size, the number of heads, the sequence length and the head dimension, which are the same for all of them, and then the stage, which indicates whether we compute causal or non-causal attention. Let's not implement the kernel yet and keep writing this method. Next we save some information that will be needed for the backward pass in the context variable I told you about: we save the query, key and value, which are the tensors we want gradients for during the backward pass, we also store the M tensor and the O tensor, and we store the causal flag as well, because if we computed causal attention during the forward pass, then during the backward pass we need this information to mask out the parts that should not contribute to the gradient; we will see that when we get to the backward pass. For now let's concentrate on this _attn_fwd kernel that we need to implement. A Triton kernel is just a Python method with a particular decorator called triton.jit, so we copy and paste the signature; this decorator is what turns a method into a Triton kernel. As you can see, we pass the query, key and value matrices along with other information: the M matrix (please don't confuse the M matrix with an attention mask; the causal mask we will generate on the fly, because here we only deal with causal or non-causal attention and we do not accept custom masks), then the strides of all these tensors, the batch size, the number of heads, the sequence length and the head dimension (the shape of each tensor), and the block sizes BLOCK_SIZE_Q and BLOCK_SIZE_KV: BLOCK_SIZE_Q says how many queries we group together to make one block of the Q matrix, and BLOCK_SIZE_KV says how many keys and values we group together to make one block of the K and V matrices, which is what we do in block matrix multiplication. The stage is a number indicating whether we are doing causal or non-causal attention: it will be 3 if causal and 1 if not. The first thing we do is verify some information: we check that BLOCK_SIZE_KV is less than or equal to the head dimension. To be honest, I don't think we need this check with my code, because I removed most of the constraints, but it was present in the original code so I kept it; it relates to the autotuning process (later we will see what autotuning is, which variables we tune, how many stages and how many warps we choose), so let's leave it for later; you can comment it out or keep it, it shouldn't matter. Next, as I said before, when we launch a grid we get identifiers for the programs, just like in CUDA we had block identifiers on the x and y axes: we launched ceil(SEQ_LEN / BLOCK_SIZE_Q) programs along axis 0 and batch size multiplied by number of heads programs along axis 1 of the launch grid.
These identifiers will help us figure out which part of the queries this program should work with, and which batch and which head this program is associated with; they correspond to the block ids in CUDA. The program id on axis zero tells us which block of the queries we are going to work with. Why do we have a block of queries? Because, as we saw before, the output can be computed independently for each block of queries, while each block of queries has to iterate through all the keys and values. The program id on axis one ranges over the product of the batch size and the number of heads, which means we have as many programs on that axis as there are (batch, head) combinations, so this id tells us which batch and which head this particular program is associated with: to get the batch index we divide this id by the number of heads, and to get the head index inside that batch we take this id modulo the number of heads. Okay, the next thing we need to do: remember that the q parameter of this attention forward method was created as a tensor (it is the input of the forward function, the one called when we do attention.apply), but when we pass a tensor to a Triton kernel, inside the kernel it is not really a tensor — it is a pointer to the first element of that tensor in memory. Now that we know which batch and which head we are going to work with, we need to index this tensor to select the right batch and the right head inside it; conceptually we want something like q[index_batch, index_head, :, :], i.e. we need to enter the tensor at the location where the (sequence length, head dimension) slice for this batch and this head starts. For that we need to generate an offset by which to move this pointer, because right now it points at the beginning of the entire tensor, so we need to move along the batch dimension and along the number-of-heads dimension, and to do that we need to use the strides.
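As a sketch, the kernel entry point and the decomposition of the program ids might look like this (parameter names are placeholders and the real signature in the video may differ slightly; BLOCK_SIZE_Q and BLOCK_SIZE_KV are assumed to come from the autotuner discussed later):

```python
import triton
import triton.language as tl

@triton.jit
def _attn_fwd(Q, K, V, softmax_scale, M, O,
              stride_Q_batch, stride_Q_head, stride_Q_seq, stride_Q_dim,
              stride_K_batch, stride_K_head, stride_K_seq, stride_K_dim,
              stride_V_batch, stride_V_head, stride_V_seq, stride_V_dim,
              stride_O_batch, stride_O_head, stride_O_seq, stride_O_dim,
              BATCH_SIZE, NUM_HEADS: tl.constexpr, SEQ_LEN: tl.constexpr,
              HEAD_DIM: tl.constexpr,
              BLOCK_SIZE_Q: tl.constexpr, BLOCK_SIZE_KV: tl.constexpr,
              STAGE: tl.constexpr):
    # Axis 0: which block of queries this program handles.
    block_index_q = tl.program_id(0)
    # Axis 1: a combined index over (batch, head).
    index_batch_head = tl.program_id(1)
    index_batch = index_batch_head // NUM_HEADS  # which batch
    index_head = index_batch_head % NUM_HEADS    # which head inside that batch
    # (rest of the kernel follows: block pointers, inner loop, epilogue)
```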
So we create the qkv_offset: it is the index of the batch multiplied by the stride of the batch dimension, which tells us how many elements we need to skip to get to the next batch (for batch zero we skip nothing, because we are already pointing at the first element of that batch; for batch one we skip that many elements), plus the index of the head multiplied by the stride of the head dimension, which tells us how to go from one head to the next. Then, to select the block, Triton gives us a fairly recent helper function that lets us index elements inside a tensor without doing all the indexing math by hand, which can be confusing for beginners: it is called make_block_ptr, and I will be using it here. Basically make_block_ptr takes as input a pointer; let me go through it one piece at a time so I don't confuse you with everything at once. The base pointer is q plus qkv_offset, so it is no longer pointing at the first batch but exactly at our batch — the one this particular program should be working with — and inside that batch at the particular head this program should be working with. In other words, it points to the sub-tensor q[index_batch, index_head, :, :]: since we have already selected the batch and the head, it is a two-dimensional tensor whose remaining dimensions are the sequence length and the head dimension. So we are telling make_block_ptr: take this pointer, which refers to a tensor of shape (sequence length, head dimension), and here are the strides of those two dimensions for the q tensor; inside this query tensor we want to select one block of queries, the block this program should work with. I think I need to use the iPad here, otherwise it can be hard to visualize, so let me open a new page. All right, we have a Q tensor — and this construction is the same for all the other tensors, so if you understand it for one you understand it for all — of shape batch by number of heads by sequence length by head dimension. With this line, q plus qkv_offset, we are already selecting the right batch and the right head, which means we have already moved our pointer forward so that it does not point to the first batch and the first head but to the exact batch and the exact head this program is working with.
Which basically means that right now it points to a tensor made up of two dimensions: sequence length and head dimension. Inside this tensor we also need to select the right block of queries for this program. The sequence dimension contains all the queries, so we need to skip some of them: we skip block_index_q multiplied by BLOCK_SIZE_Q queries, because those will be processed by other programs that have a different program id. So with this line we are selecting not only the right batch and the right head inside Q, but also the right position along the sequence dimension, pointing exactly to the start of the query block this particular program should work with. We also tell make_block_ptr the shape of the block we want to create; the block has two dimensions, the sequence dimension and the head-dim dimension (the last one), and we are already pointing at the right start of the sequence dimension because we have already skipped the queries that other programs, with other values of block_index_q, will handle. As for this order argument, I honestly don't know exactly what it is — you can try (0, 1) or (1, 0). I think it is some optimization hint for Triton: I read the online documentation and couldn't find much about it, so it is something I will investigate, but even if you put (0, 1) it doesn't seem to matter. I believe it is related to whether you want the transposed or non-transposed view of the block, and later we will see how we can transpose the key block without any transpose operation, just by swapping the strides like we have seen before. Now, make_block_ptr is not strictly necessary, but it makes our life easier when indexing this pointer: we can treat it nearly the same way we treat a tensor in PyTorch, moving by one index along a dimension without computing the strides ourselves. Later, when doing the backward pass, I will not use it and will do all the pointer indexing by hand, so you can compare indexing a tensor with and without make_block_ptr. To recap what we are creating: a pointer to the right batch index and the right head index, which already skips some queries based on block_index_q, so it points exactly to the block of queries this particular program should work with. Let's look at the V and K blocks now. The V block is similar to the query one, but we do not go further inside: we only index by batch and by head — unlike the Q block pointer, which also skips a number of queries.
So for V we are doing v[index_batch, index_head, :, :] and we are not skipping anything else: the offsets are zero in both the sequence dimension and the head dimension. Now let's look at the K block pointer, which is different, because as you know, when computing the FlashAttention algorithm we need access to the block of queries and to the blocks of keys transposed. So when accessing K we shouldn't access it like Q: we should swap the two dimensions we want to transpose, and that's very simple with make_block_ptr. We say: go to the K tensor, select the right batch, select the right head, and what is inside is a two-dimensional tensor with dimensions (sequence length, head dim); but we don't want sequence length first and head dim second, we want head dim first and sequence length second, i.e. the transpose. How do we transpose it? We simply tell make_block_ptr to read this tensor with the two strides swapped — first use the stride of the head dimension and then the stride of the sequence dimension — and the block shape is not (sequence, head dim) but (head dim, block of KV). Why don't we put the whole sequence dimension here? Because we want to advance block by block later: we are not selecting the entire sequence length, just one block of KVs, and later we will use another method to move to the next block. I hope that showing the indexing like this makes it a little easier to follow: for each tensor we go into the right batch and the right head, and for the query we also skip some query blocks, because each program works with its own small query block, while for the key and value each program needs to iterate through all of them, so we just point to the first K and V block and then advance one block at a time during the for loop we will write later. Then, for the output, we also create a block pointer, just like for query, key and value: we index by batch, we index by head, and we do not select everything inside, because here too we skip some blocks of rows. As I said before, the output has the same shape as the query, and this particular program, with its particular block_index_q, works with one block of queries and therefore produces exactly one block of the output matrix; we need to select exactly that one, so we point the pointer to the place where we should start writing by skipping block_index_q multiplied by BLOCK_SIZE_Q rows, selecting exactly the block that this particular program will produce.
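Putting the four block pointers together, a sketch of this part of the kernel could look as follows, continuing the kernel sketch above (stride names are placeholders; since Q, K, V and O share the same shape here, the same offset is reused for all of them):

```python
# Continuing inside the kernel sketch (hypothetical names).
qkv_offset = index_batch * stride_Q_batch + index_head * stride_Q_head

Q_block_ptr = tl.make_block_ptr(
    base=Q + qkv_offset,
    shape=(SEQ_LEN, HEAD_DIM),
    strides=(stride_Q_seq, stride_Q_dim),
    offsets=(block_index_q * BLOCK_SIZE_Q, 0),  # skip other programs' query blocks
    block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
    order=(1, 0),
)
V_block_ptr = tl.make_block_ptr(
    base=V + qkv_offset,
    shape=(SEQ_LEN, HEAD_DIM),
    strides=(stride_V_seq, stride_V_dim),
    offsets=(0, 0),                             # start at the first KV block
    block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
    order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
    base=K + qkv_offset,
    shape=(HEAD_DIM, SEQ_LEN),                  # transposed view of K
    strides=(stride_K_dim, stride_K_seq),       # strides swapped = transpose
    offsets=(0, 0),
    block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
    order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
    base=O + qkv_offset,
    shape=(SEQ_LEN, HEAD_DIM),
    strides=(stride_O_seq, stride_O_dim),
    offsets=(block_index_q * BLOCK_SIZE_Q, 0),  # same block of rows as Q
    block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
    order=(1, 0),
)
```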
When I speak about "this particular program", I mean the program identified by this program id on the zeroth axis and this program id on the first axis, because each of these programs will (hopefully) run in parallel and each of them has a different value of block_index_q and index_batch_head. Okay, so now our pointers point to the right positions where they should either read or write, and thanks to make_block_ptr these pointers can also be treated almost directly as tensors — that is why we specify the shapes — because Triton provides methods to work with block pointers as if we were accessing tensors, so we can index them like tensors. Basically Triton is doing the stride calculations for us so we don't have to do them by hand; later, in the backward pass, we will avoid make_block_ptr and do the indexing by hand. As you know, we are processing a single block of queries, so let's go back to the algorithm on the iPad so we don't lose sight of what we are doing. Each program is parallelized along the query-block dimension, so each program works with a different query block, and then it needs a for loop over all the key and value blocks. Right now we have just moved our pointers to the right positions: to the query block we should work with, and to the beginning of the key and value blocks, based on the batch and head this particular program is associated with. Now that we are pointing inside the big tensors (which have the batch, number-of-heads, sequence-length and head-dim dimensions) at the right batch and the right head, these tensors have effectively become two-dimensional: they only span the sequence length and the head dimension. Next we need some more information that we will use later. The first is the offsets of each query inside the current block of queries that this program works with, given by the following line. There are BLOCK_SIZE_Q of them, because each block of queries is made up of BLOCK_SIZE_Q queries; each query is a token, and along the head dimension it is not the whole embedding of the token but the part of the embedding corresponding to the head this program works with. So we are generating the offsets that will let us load these particular queries from the big tensor containing all the queries, and we know our queries start at position block_index_q multiplied by BLOCK_SIZE_Q. If this is program number zero and the block size is four, they will be the queries with index 0, 1, 2 and 3; but if we are program number three, we need to skip 3 times 4 = 12 queries, so the offsets will point to queries 12, 13, 14 and 15, and so on.
We do the same for the keys and values: the KV offsets are a range of key/value indices that we need at each iteration, and at the beginning, since our K and V pointers point to the start of the key/value sequence for this particular batch and this particular head, we are pointing at the first block of keys and values. In the query case we skip some queries, because this program works with a single query block; here we skip nothing, because we need to iterate through all the keys and values, so we point to the first KV block. If BLOCK_SIZE_KV is equal to four, these offsets will be 0, 1, 2 and 3. Now, as you remember, inside the FlashAttention algorithm we compute a block of queries multiplied by the transpose of a block of keys, and to each such block we apply the softmax* — if you remember, the softmax* is the softmax without the normalization; while computing it we also accumulate the normalization factor without applying it, and we apply it at the end. So for each block of Q K^T we need the maximum of each row and the normalization factor of each row, and that's why we need the two following statistics. Each of them is a block of numbers, one per query in our block of queries. The maxima are initialized with minus infinity, just like in the algorithm I showed before — you can also check the FlashAttention algorithm, where m is initialized with minus infinities. So far we are creating these: we initialize m_i, we will initialize l_i, and we will initialize the output block, and then we will write the inner loop, exactly as in the algorithm we have seen before. The l's are initialized next, and the O block, as you can see from the FlashAttention algorithm, is initialized with zeros: it is the output block that this particular program will compute, based on its position in the batch and the head, and it has size BLOCK_SIZE_Q (the number of queries in this block) by head dimension. If you want to visualize it, it is equal to one block of rows of the output matrix. So now we have initialized the output block, m_i and l_i, where m_i is the maximum of each row in this query block and l_i is the normalization factor of each row in this query block.
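A sketch of these initializations, with the same placeholder names as before (the exact initial value of l_i is an assumption taken from the Triton tutorial this code follows):

```python
# Continuing inside the kernel sketch.
offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)  # query indices of this block
offs_kv = tl.arange(0, BLOCK_SIZE_KV)                               # indices inside the first KV block

# Running row maximum, initialized to -inf as in the algorithm.
m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf")
# Running normalization factor; starting at 1.0 is harmless because the first
# correction factor exp(-inf - m) = 0 wipes it out.
l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0
# Output accumulator: one block of rows of the output matrix.
O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)
```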
Now we need to do the inner for loop of the FlashAttention algorithm, and we will create a separate method to run it. I am following the same structure as the code you see in the tutorial on the Triton website: whether we are running causal attention or not, we make this for loop, and then we will make another for loop, and I will show you why — let me first write it and then we will see. This function is the inner loop: it needs to go through all the key and value blocks one by one, and for each key/value block it needs to fix the previously computed softmax* block. So basically we are going to write a function that iterates over all the KV blocks, computes the query multiplied by the transpose of the keys (with the query block fixed for this program and the key block being the one we iterate over), and for each of these products computes the maximum of each row, computes the softmax* (the softmax without the normalization factor), keeps updating the statistic l (the normalization factor that we apply at the end of the loop), and at the same time updates the output. As you remember, the output is P11·V1 + P12·V2 + ..., but we need to fix the previously computed P11: every time we add something to the output, we first correct the output of the previous iteration and then introduce the P and V blocks of the current iteration. Here the author of the code on the Triton website decided to split this for loop into two steps. Why? Because in causal attention we don't want a query to attend to keys that come after it, while in non-causal attention every query attends to every key. That would normally mean putting an if statement inside the loop over all the keys and values, checking whether the current query comes before or after the current key when doing causal attention. Instead, by splitting the loop into two steps we say: first, iterate over all the KV blocks whose indices are strictly smaller than those of the current query block — for these we compute the attention the same way in both the causal and the non-causal case; then, for everything to the right of this block — where the key index is greater than the query index — we don't need to compute anything at all in the causal case, because it would be masked out: after the softmax it becomes zeros, so it does not contribute to the output and we don't even have to compute it. That is why we split this for loop into two steps: first we handle everything to the left of the diagonal of the Q K^T matrix, i.e. all the positions where the key index is smaller than the query index; then the diagonal itself; and we skip everything to the right of the diagonal when working with the causal mask, while in the non-causal case we compute both the left part and the right part of the diagonal. Don't worry, when we write this for loop it will be clearer; I just wanted to give a little introduction. So let's go code this inner loop. It will work with the particular query block we have located — ah, I don't see the Q block because we didn't load it; we actually forgot to load it, so let's load it. As you remember, in Triton we load data from the high-bandwidth memory to the SRAM, the shared memory, by using the load statement, and we tell it to load the
query block that we should be working with. This pointer, Q_block_ptr, is already pointing at the right block — it already skips all the blocks that other programs work with — and it will load a tensor of size (BLOCK_SIZE_Q, HEAD_DIM), the right block of queries. We pass it to the inner loop, along with: the output block, where it should accumulate the output; l_i and m_i, the per-row statistics (m_i the maximum of each row, l_i the normalization factor of each row of the query block); the query block this program should work with; the K and V block pointers, pointing at the beginning because we will iterate through them inside the inner for loop; the softmax scale to use when computing query multiplied by the transpose of the keys; the block sizes, i.e. how many queries there are in each Q block and how many keys/values in each KV block; the stage, which tells us whether we are on the left side of the diagonal or on the diagonal itself, i.e. whether we need to apply the causal mask based on where we are; the offsets offs_q and offs_kv, which are the lists of indices of the queries and of the keys/values inside each block; and the entire sequence length, because in the for loop we need to iterate over the whole sequence, block of KV by block of KV. All right, let's write this method (later we will need to come back and continue the calling method). We have already seen its signature: it is just another kernel, and it can be called from the first kernel — something you can also do in CUDA, calling one CUDA kernel from another. Then, based on the stage of this inner loop, we decide what to do. When we use causal attention, we only want a query to attend to keys and values whose index is less than or equal to its own — we don't want the query to attend to keys and values that come after it — and in that case we pass the value 3 for the stage parameter. In the causal case, 4 minus 3 equals 1, so in this first call we only work with the range of key and value blocks that go from zero up to the current block of Q, i.e. all the keys whose index is less than the indices of the queries we are working with: the left part of the causal mask. Let me draw it, otherwise I think it's going to be very difficult to follow, so let me open a new page on the iPad. We have used this page before, so let's clear it.
Now I want you to think of the following matrix as a block matrix (let's draw it in pink, since I've been drawing everything in pink). In the rows of this query-times-transpose-of-the-keys matrix we have the blocks of queries — we are not looking at one single block now, we are looking at all the blocks: query block 1, query block 2, query block 3, query block 4, each made up of multiple query tokens — and in the columns we have the key blocks: key block 1, key block 2, key block 3, key block 4 (drawn very ugly, but okay). When calculating causal attention, i.e. with the causal mask, we only want a query to attend to keys that come before it, so when we apply the causal mask everything above the diagonal is made up of zeros. For the blocks strictly below the diagonal we never have to mask anything. Why? Because all the keys in those blocks have an index smaller than the index of every corresponding query, provided the block size of the query and the key match. Imagine each block of queries is made up of three queries: this is query 0, 1, 2; this is query 3, 4, 5; this one is 6, 7, 8; and this one is 9, 10, 11 — twelve queries in total. We will have the same indices for the keys if we choose the same block size: this key block is keys 0, 1, 2; this one is 3, 4, 5; this one is 6, 7, 8; and so on. In the blocks below the diagonal the key indices are always smaller than the query indices, so none of those dot products will ever be masked out, even in the case of the causal mask. Along the diagonal, however, within a block some queries have an index that is greater than or equal to that of the keys and some don't, because these are blocks of queries against blocks of keys, so some entries need to be masked out and some don't. That is why we divide our for loop into multiple steps: the first step covers everything to the left of the diagonal, where we don't need to mask anything; then there is another step for the diagonal itself, where we do need to mask; and everything to the right of the diagonal, in the case of causal attention, we don't even compute, because we already know that after the softmax it will be made up of zeros. If you look at the FlashAttention algorithm, that contribution is zero because we are multiplying zeros with V, so it does not change the output — so why even compute that part of the matrix if we already know it is not going to contribute? We just skip all those iterations, and this is why we are splitting the for loop. I hope it is much clearer now. All right, let's go back. So stage 1 covers the part to the left of the diagonal; stage 2 is the part exactly on the diagonal, where some dot products need to be done with masking and some don't; and for non-causal attention we simply go from zero to the sequence length in a single pass, because we don't need to mask anything.
This is what the stage tells us: the lower and upper index of the key blocks that this particular call should work with. Now, this tl.multiple_of function just tells Triton that a given number is a multiple of another number, so Triton can make some optimizations. To recap the stages: in causal attention we pass STAGE = 3 to the kernel, so the first call to the inner function gets 4 minus 3 = 1, and with stage 1 we go through the key and value blocks that are to the left of the diagonal with respect to the query block we are working with. In non-causal attention the kernel's STAGE is 1, so 4 minus 1 = 3, and with stage 3 we execute the branch of the if statement that goes through all the keys and values. Only for causal attention do we make another call, with stage equal to 2, which iterates only over the group of KV blocks that lie exactly on the diagonal of the big block matrix Q K^T — the one made up of all the blocks — where some things need to be masked out and some don't, because inside each block some keys have an index below that of the query and some have an index above it. So only in the causal case is this function called twice: the first time with stage equal to 1 and the second time with stage equal to 2. Now that this should be clear, let's proceed further. Since we need to do the inner for loop of FlashAttention, we point the K and V block pointers at the first block of keys and values that this loop should work with, which depends on the stage: if it's the first call to this function, they point at the very first block (in both the causal and non-causal case); if it's the second call, which only happens for causal attention, they point exactly at the KV block on the diagonal. Then we write the for loop over the key and value blocks. Inside it, we let the compiler know that start_kv is always a multiple of BLOCK_SIZE_KV, because we move from one KV block to the next, block by block; this doesn't change anything from a logic point of view, it is just a hint that lets Triton do some other optimizations.
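A sketch of the range selection and of the loop skeleton inside this inner function (hypothetical name _attn_fwd_inner; here STAGE is the 1/2/3 value passed by the caller as just described):

```python
# Continuing the kernel sketch, now inside the inner function.
if STAGE == 1:
    # Causal, first call: only the KV blocks strictly to the left of the diagonal.
    lo, hi = 0, block_index_q * BLOCK_SIZE_Q
elif STAGE == 2:
    # Causal, second call: only the KV blocks on the diagonal (need masking).
    lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q
    lo = tl.multiple_of(lo, BLOCK_SIZE_Q)  # hint for the compiler
else:
    # Non-causal (STAGE == 3): iterate over the whole sequence of KV blocks.
    lo, hi = 0, SEQ_LEN

# Move the K and V block pointers to the first block of this range.
K_block_ptr = tl.advance(K_block_ptr, (0, lo))   # K is transposed: (HEAD_DIM, SEQ_LEN)
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))   # V is (SEQ_LEN, HEAD_DIM)

for start_kv in range(lo, hi, BLOCK_SIZE_KV):
    start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV)  # another compiler hint
    # ... loop body shown in the next sketch ...
    pass
```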
Now, the first thing we see in the FlashAttention algorithm is that we need to compute the product of our query block with the current KV block of this iteration. The query has already been loaded by the caller of this function, so we only need to load the current block of K, indicated by the K pointer, and then we do the matrix multiplication of the query block with the current block of K — which is already transposed, because when we defined the K block pointer we defined it with the strides swapped, so we are reading the tensor already transposed. In other words, we are computing the query multiplied by the transpose of the keys. Now this next part: if the stage is 2, we are exactly on the diagonal, where some queries have an index greater than or equal to that of the keys and some don't, so we need to apply the causal mask only in this case. We define the mask that we should apply: it is true when the index of the query is greater than or equal to the index of the key/value, and all the values for which the mask is not true will be masked out. We also apply the softmax scale: so far we only computed Q K^T, but we also need to divide by the square root of the head dimension, and we do it here. Since we have already computed the product, we can compute the maximum of each row, and then we subtract it, because the next operation in the FlashAttention algorithm is the one I call softmax*, and as you remember the softmax* subtracts from each element of the S matrix (the query multiplied by the transpose of the keys) the maximum of its row. Before computing the row maximum, we mask out all the elements that need to be masked in stage 2, along the diagonal; masking out simply means replacing with minus infinity, before applying the softmax, all the values for which the mask is false. So, to recap what we have computed: we computed Q K^T, we masked it in case we need to mask (and we only need to mask when we are on the diagonal; in all the other cases we just multiply by the softmax scale), and then we subtracted m_ij, the maximum value of each row, because we need to compute the softmax* operation — the softmax without the normalization — which in the FlashAttention algorithm is exactly the operation that produces P_ij. So now we can compute the P_ij block, which is the exponential of the QK block from which we already subtracted the row maximum at the previous instruction; all that is left is to apply the exponential, and that is what we do here. Then we need the sum over the rows for the normalization factor: for the current KV block we have the P_ij block, and to compute the normalization factor of the softmax we need to keep summing up these exponentials (later we will fix the normalization factor computed at the previous steps, but we will do that later). So for now we just compute the normalization factor for the current block, which is the sum of all the values in each row, the same as what we did before when I showed you the algorithm.
For each block we take the row sum of the P matrix, where the P matrix is the exponential of S minus m, and for now we did not apply the correction to the previous block. So we have computed l_ij for the current K and V block, and next we compute the correction factor for the previous block. If you remember the formula from the paper, the correction factor is the exponential of the previous estimate of the maximum minus the current estimate of the maximum, which is exactly what we compute here. We will see later why m_i is the previous estimate and m_ij is the current one: m_ij comes from the current block we are processing, while m_i is the one from the previous iteration, because later we will overwrite m_i with m_ij — I am just following the FlashAttention algorithm so far. So I compute the correction factor for the previous l_i (in the FlashAttention algorithm it is this term here) and then I apply it: l_i becomes the previous l_i multiplied by the correction factor, plus the current l_ij, the one coming from the P block computed with the K and V of the current iteration. Then, as you remember from the formula, we compute the P block and we need to multiply it by the V block, so we need to load the V block, based on the pointer V is pointing at at the beginning of this iteration (in stage 3, for example non-causal attention, it will be pointing at the first V block). Then there is just a type conversion, to make sure this is in float16, and we compute the output block: we take P, multiply it by V, and add the result to the output. Let's go through this line piece by piece. First we fix the previous output block with the correction factor — this alpha term here, the correction factor for the previous block — so for now we have just fixed the previous block but not yet added the new P·V. To add the new P·V we call the dot product — actually it is a matrix multiplication — of P and V, and the third argument tells this matrix multiplication to use the output block as the accumulator. It is exactly the same as writing O_block += P_block @ V_block, just optimized: this dot function needs somewhere to store its intermediate results anyway, so why not store them directly where they should go? Since a matrix multiplication is just repeated sums of dot products, the accumulator keeps summing the results into this block, which gives exactly the same result as doing the matrix multiplication separately and adding it to the O block. That is why this argument is called the accumulator.
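To keep the algebra of one iteration in one place, these are the updates we just walked through, where j indexes the current KV block and rowmax/rowsum act row by row:

$$
\begin{aligned}
S^{(j)} &= \mathrm{softmax\_scale}\cdot Q_i\,(K^{(j)})^{T}, &
m_i^{(j)} &= \max\!\big(m_i^{(j-1)},\ \mathrm{rowmax}(S^{(j)})\big),\\
P^{(j)} &= \exp\!\big(S^{(j)} - m_i^{(j)}\big), &
\alpha &= \exp\!\big(m_i^{(j-1)} - m_i^{(j)}\big),\\
l_i^{(j)} &= \alpha\, l_i^{(j-1)} + \mathrm{rowsum}\!\big(P^{(j)}\big), &
O^{(j)} &= \alpha\, O^{(j-1)} + P^{(j)} V^{(j)}.
\end{aligned}
$$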
All right, so we have also computed the output, and then we save the new estimate of the maximum for the current iteration into m_i, so that at the next iteration we can use it to compute the correction factor. That finishes the current block, and we can move on to the next one: we advance our K and V pointers by one block of K and V. We advance them differently because the V block pointer refers to a tensor of shape (sequence length, head dim), so we advance the sequence dimension by BLOCK_SIZE_KV, while the K block is actually the K-transposed block — we exchanged its strides and its shape, so it is (head dim, sequence length) — so we don't change the head dimension and we advance the sequence dimension by BLOCK_SIZE_KV. In both cases we are simply moving to the next block of K and the next block of V. I hope you were able to follow the FlashAttention algorithm; I tried to use the same names and more or less the same logic, always writing the formula I am referring to, so hopefully you didn't get lost. I think the only difference between the algorithm as written in the paper and this code is this alpha, the correction factor, but I hope it is easily understandable. Then we just return the O block, l_i (the normalization factor for each row in the current output block, which is also a Q block, because we work with one Q block independently of the other programs) and m_i (the maximum value for each row), which will be needed for the backward pass: in the backward pass we will recompute the query multiplied by the transpose of the keys block on the fly, and we need to apply the softmax to it, but instead of recomputing the quantities we already computed during the forward pass, we save them and reuse them during the backward pass, which saves us some computation.
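Here is a compact sketch of one iteration of this inner loop, mirroring the steps above (Q_block is the query block loaded by the caller with tl.load(Q_block_ptr); names are placeholders and details such as the exact masking constant may differ from the code written in the video):

```python
# Continuing the kernel sketch: one iteration of the inner loop.
K_block = tl.load(K_block_ptr)                      # (HEAD_DIM, BLOCK_SIZE_KV), already transposed
QK_block = tl.dot(Q_block, K_block)                 # (BLOCK_SIZE_Q, BLOCK_SIZE_KV)

if STAGE == 2:
    # On the diagonal: apply the causal mask before taking the row max.
    mask = offs_q[:, None] >= (start_kv + offs_kv[None, :])
    # -1.0e6 acts as "minus infinity": it becomes 0 after the exponential.
    QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6)
    m_ij = tl.maximum(m_i, tl.max(QK_block, 1))     # new running row maximum
    QK_block -= m_ij[:, None]
else:
    m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale)
    QK_block = QK_block * softmax_scale - m_ij[:, None]

P_block = tl.math.exp(QK_block)                     # softmax* (unnormalized)
l_ij = tl.sum(P_block, 1)                           # row sums of the current block
alpha = tl.math.exp(m_i - m_ij)                     # correction factor for the previous blocks
l_i = l_i * alpha + l_ij

V_block = tl.load(V_block_ptr)
P_block = P_block.to(tl.float16)
O_block = O_block * alpha[:, None]                  # fix the previously accumulated output
O_block = tl.dot(P_block, V_block, O_block)         # O += P @ V, O_block used as accumulator

m_i = m_ij                                          # becomes the "previous" maximum next iteration

# Move to the next KV block.
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV))
```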
Now I know it's time to talk about the log-sum-exp trick, because we are going to use it, so let's go back to the calling method. We have made two calls to this inner function when working with causal attention: one call to work with all the key blocks to the left of the diagonal of the query-key matrix, and another call to work only with the key blocks that lie exactly on the diagonal, where some values need to be masked out and some don't. Moreover, by doing this we avoid computing the dot products, in the causal case, for all the positions where the key index is greater than the query index, saving some computation, because after the softmax they would result in zeros anyway and would not contribute to the output, so it should be faster. Okay, back to the calling method: there is one last thing we need to do, which is to compute the log-sum-exp, and now I will show you what it is. In order for the backward pass to recompute the softmax without having to recalculate the normalization factor and the maximum value of each row, we should actually be saving two different things: the maximum for each row in the query block, and the normalization factor for each query in the query block. However, there is a trick — it's not really called the log-sum-exp trick, because that name is used for another purpose, so let's call it log-sum-exp trick number two — and it goes like this. When we compute Q K^T we get a matrix made up of dot products: q1·k1 and q1·k2 in the first row, q2·k1 and q2·k2 in the second row. Then we need to apply the softmax row by row, so for each of these row vectors the softmax modifies each element as follows: softmax(x_i) equals the exponential of x_i minus the maximum of the current vector, divided by the normalization factor, which is the sum over all possible j (here n = 2, since each vector is made up of two elements) of the exponential of x_j minus x_max. Now, in the forward pass of the FlashAttention algorithm this maximum is called m_i and this summation is called l_i, and as you can see in the code we do not save m_i and l_i separately: we save m_i plus the logarithm of l_i. What happens is that when we compute the backward pass we need to recreate this matrix on the fly, which means recomputing the query multiplied by the transpose of the keys, and then we should apply the softmax; to apply the softmax we would need both of these quantities, but we only have the single value m_i + log(l_i). So we compute the softmax as follows — let me call it softmax2 so we don't get confused: softmax2(x_i) is the exponential of each element minus the saved value corresponding to its row, i.e. the exponential of x_i minus m_i minus log(l_i). If we expand this expression — because the exponential of a sum is the product of the exponentials — it becomes the exponential of (x_i minus m_i) divided by the exponential of log(l_i), which, guess what, equals the exponential of (x_i minus m_i) divided by l_i: exactly the normalization factor, and we also have the m_i subtraction. So instead of saving two values we save only one, and when we apply it, the properties of the exponential automatically take care of normalizing each value.
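Written out, the trick is:

$$
\mathrm{softmax}(x)_i=\frac{e^{\,x_i-m}}{\sum_{j=1}^{N} e^{\,x_j-m}}=e^{\,x_i-m-\log l},\qquad m=\max_j x_j,\quad l=\sum_{j=1}^{N} e^{\,x_j-m},
$$

so saving the single value $L_i = m_i + \log l_i$ per row is enough: in the backward pass we recompute $S = QK^{T}$ on the fly and obtain the already-normalized $P_{ij} = e^{\,S_{ij}-L_i}$ directly.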
If you don't remember the properties of the exponential, they are very simple: the exponential of a plus b equals the exponential of a multiplied by the exponential of b, and the exponential of a minus b equals the exponential of a divided by the exponential of b. This is the trick we are using, and that's why we don't need to save two different values: we save one value, and when we apply it, it automatically takes care of the normalization because of these properties. All right, let's move forward. We have created this value that we will use during the backward pass. Now, as you remember, in the FlashAttention algorithm we don't normalize each block while computing it: we normalize the output at the end, and this is exactly what we do here — we normalize the block after we have accumulated all the normalization factors we need for all the rows that belong to the current output block. Then we save this m_i, which now holds, for each row, the maximum plus the logarithm of the normalization factor that we will need for the backward pass. We need to store it in a tensor that we will use during the backward pass, so we need to figure out which tensor that is: it's the tensor we called M, which has the dimensions (batch size, num heads, sequence length), and we need to select the right place in it, i.e. the right batch index and the right head index. So we advance this pointer by the following offset: M plus index_batch_head multiplied by the sequence length. The index_batch_head is the id of the current program, which encodes which batch and which head we are working with; for each (batch, head) combination there is one sequence length's worth of values, because each token in the sequence has its own maximum and its own normalization value, so based on the current combination of batch and head we can skip a number of sequence lengths that other programs will process. Since in this tensor the sequence length is the last dimension and the first two dimensions are combined into the batch-times-heads index given by program id number one, we skip sequence_length multiplied by index_batch_head elements: M points at the first element of the entire tensor, so this skips the heads and the batches that belong to other programs, based on the combined index index_batch_head that this particular program is working with. Then we add offs_q: each invocation of the attention forward kernel works with one query block, and that query block covers certain query indices, given by the offs_q variable you saw earlier — the number of queries we skip because other programs process them, plus the range of queries inside this particular block. Imagine this particular program works with the queries that go from 12 to 15: then offs_q will be 12, 13, 14 and 15.
The normalization factor and the maximum value — combined into this single value per row — only exist for the query indices of this block, 12, 13, 14 and 15 in the example, and that's why we also add offs_q, which already encodes which queries this particular program works with. So now we can store m_i, because we have the pointer to the place where it should be saved, and we can also store the output block that was computed by our inner for loop. And that, guys, is the forward step of FlashAttention.
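A sketch of this epilogue, with the same placeholder names as before:

```python
# Continuing the kernel sketch: the epilogue after the inner loop.
m_i += tl.math.log(l_i)            # L_i = m_i + log(l_i), saved for the backward pass
O_block = O_block / l_i[:, None]   # apply the normalization factor only at the end

# M has shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN): skip to this (batch, head)
# combination, then to the rows of this query block.
m_ptrs = M + index_batch_head * SEQ_LEN + offs_q
tl.store(m_ptrs, m_i)
tl.store(O_block_ptr, O_block.to(O.type.element_ty))
```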
Now we should move on to the backward pass, and we already have some of the ingredients, because we have seen the log-sum-exp trick and we know how to use it to recompute the query-key block on the fly during the backward pass. What are we missing to understand the backward pass? Well, first of all, what is the backward pass and why do we even need one? We need to understand PyTorch's autograd, how it works, how to compute the gradient, what a gradient is, what a Jacobian is, and whether we even need to compute the Jacobian when computing gradients in the backward pass; and we need to derive all the formulas of the backward pass by hand. So if you are in for the challenge, let's continue. Before looking at the FlashAttention backward algorithm, and before looking at PyTorch's autograd, we should look at what derivatives, gradients and Jacobians are, so that when we talk about them we don't feel lost — I will do a very fast recap of these topics. What is a derivative? When we have a function that takes a real value as input and outputs a real value, we talk about derivatives, defined as follows: the derivative of the function with respect to its variable x is the limit, as the step size h goes to zero, of the function evaluated at x plus h minus the function evaluated at x, divided by the step size. Intuitively, it is the ratio of how much the output changes for a small change of the input, which also gives you the intuition of why the derivative is the slope of the tangent line to the function at the point where it is evaluated. I will use the following notations for the derivative: I am used to writing f'(x), but it can also be written as df(x)/dx, or dy/dx where y is the output of the function, and they all mean the same thing, the definition above. If we rearrange this formula and take h to the other side, we can also write: the function evaluated at x plus h is approximately f'(x), the derivative at the point x, multiplied by h, the step size, plus f(x). (This is also how the Euler method for solving differential equations is derived, but that's not today's topic.) If we call this h "delta x", then f(x + Δx) is approximately — approximately because the limit only holds when h is very, very small — equal to f'(x) multiplied by Δx, plus f(x). You can also read this as follows: if x changes by a little amount, and this little amount is Δx, how much will y change?

You can read this formula as follows: if x changes by a little amount Δx, how much will y change? y will change by exactly this amount: the derivative of y with respect to x, dy/dx, multiplied by how much x has changed. So dy/dx tells us how much y changes for a small change of x, and if we multiply it by the actual change of x it tells us exactly how y will be affected. I don't want to stay on this too long, but I would like to use this intuition to introduce the chain rule. Imagine we have a function of a function, so z = f(g(x)): we can think of x being mapped into a variable y through the function g, and then y being mapped into a variable z through the function f. If x changes by a little bit, and by a little bit I mean Δx, how much will y change?

Well, y will change by Δy. What is Δy? Δy is the derivative of y with respect to x multiplied by the change in x, so Δy = (dy/dx)·Δx. But if y changes, it will also affect z, because there is a direct mapping between y and z. So how much will z change for a small change in y?

Let's see. If y changes from its old value by a small step Δy, then z will also change by some Δz, and Δz = (dz/dy)·Δy. If we replace this Δy with the Δy we computed in the expression above, we arrive at the chain rule: it tells us how z will be affected, Δz, for a small change in x, and it is the product of the two derivatives, the one of y with respect to x and the one of z with respect to y. This is the chain rule we study in high school: if you want to compute dz/dx, it is dz/dy multiplied by dy/dx.
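As a quick illustration (again just a sketch of mine, not part of the kernel code), we can check the chain rule numerically on a concrete composition, say g(x) = x² and f(y) = sin(y):

```python
import math

# Check dz/dx == (dz/dy) * (dy/dx) for z = f(g(x)) with g(x) = x**2, f(y) = sin(y).
def g(x): return x ** 2          # y = g(x), dy/dx = 2x
def f(y): return math.sin(y)     # z = f(y), dz/dy = cos(y)

x = 1.3
h = 1e-6

# Numerical derivative of the composition.
dz_dx_numeric = (f(g(x + h)) - f(g(x))) / h
# Chain rule: dz/dy evaluated at y = g(x), times dy/dx.
dz_dx_chain = math.cos(g(x)) * (2 * x)

print(dz_dx_numeric, dz_dx_chain)  # the two values agree closely
```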

This rule is very intuitive if you think about the following example: think of z as the price of cars and x as the price of oil. How much will a small change in the price of oil affect the price of a car? Well, this small change in the price of oil will first affect an intermediate variable y, which could be the price of electricity, and then the price of electricity affects the price of the car through the derivative of the price of the car with respect to the price of electricity. So, to get the effect of the price of oil on the price of the car, we just multiply the two effects, and this is the intuition behind the chain rule.

Anyway, let's talk about gradients. When we have a function that takes a vector as input and produces a scalar, we don't talk about derivatives anymore, we talk about gradients. Imagine a function that takes as input a vector with n dimensions and produces a scalar. When do we deal with this kind of function? Loss functions, for example: a loss is always a scalar, while the input is a tensor. Think of the cross-entropy loss: it takes a sequence of tokens, each with its own logits, and computes one single number, the loss. How do we view the effect of the input on the output in this case? If x changes by a little amount, this little amount is not a number anymore but a vector Δx, and y will be affected by dy/dx multiplied by Δx. However, since Δx is a vector (x1 changes by a little bit, x2 changes by a little bit, and so on up to xn), this product is a dot product: y is affected by the change in x1, by the change in x2, and so on up to xn, and the contribution of x1 is the partial derivative of y with respect to x1 multiplied by how much x1 has changed, plus the contribution of x2, which is the partial derivative of y with respect to x2 multiplied by how much x2 has changed, and so on until the last contribution, from xn. The chain rule in this case applies exactly as in the scalar case, so the formula does not change. I just want you to remember that here we are talking about a gradient, and the gradient is a vector made up of all the partial derivatives of the output with respect to each variable in the input vector.

When we talk about a function that takes a vector as input and produces a vector as output, we don't talk about a gradient anymore, we talk about a Jacobian. If the input x of this function changes by a little amount Δx, which is a vector, then the output y will also change, by a Δy which is not a number anymore but a vector, and this vector is the result of dy/dx multiplied by Δx. Since Δx is a vector and the result has to be a vector, this dy/dx has to be a matrix, and this matrix is called the Jacobian. It has as many rows as there are output variables and as many columns as there are input variables (we will talk about the conventions in a moment): the first row contains the partial derivatives of the first output variable with respect to all the input variables, the second row those of the second output variable, and the last row contains the partial derivatives of the last output variable with respect to all the variables in the input vector.
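To see these shapes concretely, here is a small PyTorch sketch (my own illustration, with made-up functions): the gradient of a vector-to-scalar function has the same shape as the input, while the Jacobian of a vector-to-vector function has one row per output and one column per input.

```python
import torch
from torch.autograd.functional import jacobian

x = torch.randn(3, requires_grad=True)

# Vector -> scalar: the "derivative" is a gradient with the same shape as x.
y = (x ** 2).sum()
y.backward()
print(x.grad.shape)          # torch.Size([3])

# Vector -> vector: the "derivative" is a Jacobian, one row per output, one column per input.
def f(v):
    return torch.stack([v[0] * v[1], v[1] + v[2], v[2] ** 2, v.sum()])

J = jacobian(f, x.detach())
print(J.shape)               # torch.Size([4, 3]): 4 outputs, 3 inputs (numerator layout)
```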
Now let's talk about notation. The Jacobian I have written here follows the numerator layout, also called the numerator convention. There is another convention, called the denominator convention (or denominator layout), in which the number of rows is not equal to the number of output variables but to the number of input variables. So the fact that we write the Jacobian like this is just a convention: you can write it in the denominator convention simply by transposing it, and the formula for the chain rule changes accordingly. For now I want the chain-rule formula to look exactly like the scalar case, which is why I am using the numerator notation; later we can switch from one notation to the other just by doing a transposition.

Okay, now that we have reviewed what a derivative, a gradient and a Jacobian are, let's talk about what happens when we take the derivative of a tensor with respect to another tensor. In this case we still talk about a Jacobian, but it is called the generalized Jacobian. If we have a function that takes as input a tensor whose shape is (n1, n2, ..., n_dx) and produces an output tensor of shape (m1, m2, ..., m_dy), the formula for the chain rule does not change: if x changes by a little amount Δx, which is a tensor, y will be affected by dy/dx multiplied by Δx, a tensor product, where dy/dx is the generalized Jacobian, with shape given by all the dimensions of the output followed by all the dimensions of the input. This is very abstract for now, but we will see a concrete case, because we will derive the gradient of the loss, during the backward pass, with respect to each input of the matrix multiplication operation, and we will do the same for the softmax and then for the attention. I don't want to jump between too many topics; I just wanted us to get into the right mindset: derivatives when we have scalar functions, gradients when the output is a scalar and the input is a vector, Jacobians when input and output are both vectors, generalized Jacobians when input and output are tensors, and the chain rule always works in the same way.

All right, let's talk about autograd. I will do the scalar case first and then we will extend it to the tensor case. Imagine we have a very simple computation graph. Why do we have a computation graph? Because we are talking about neural networks, and neural networks are nothing more than computation graphs where we have some input, we have some parameters, and we do some operations with these inputs and parameters. Suppose you have an input a, and this a is multiplied by a weight w1, a parameter, just a scalar, and it produces an output y1. This y1 is then summed with another number b1 and produces y2. This y2 is then raised to the power of two (not the exponential, just the square of its input), and it produces y3, and this y3 is our loss function, so it is a scalar.

Now, to apply gradient descent we want to compute the gradient of the loss function with respect to each of the leaves of this computation graph. What are the leaves? The parameter nodes and the input nodes. There are two ways to do that. The first: if you have access to the expression that relates the input directly to the loss, you can compute the derivative directly (in this case it is not even a gradient, it is scalar versus scalar). Imagine you want the derivative of the loss φ with respect to w1 and you have the exact expression relating w1 to φ: you just derive that expression with respect to w1, which, since it is the square of a function, gives two times the inner function multiplied by the derivative of the content of that function with respect to the variable we are deriving, and that produces the resulting expression. The other way is the chain rule: the derivative of φ with respect to w1 is the derivative of φ with respect to y3, multiplied by the derivative of y3 with respect to y2, multiplied by the derivative of y2 with respect to y1, multiplied by the derivative of y1 with respect to w1. If we do this chain of multiplications we obtain the same result: you can see that the expression we get is exactly equal to the one computed directly.

By doing this procedure we notice something. To compute the derivative of φ with respect to w1 we are doing a chain of multiplications, but what is each partial product in this sequence? The product of the first two factors is nothing more than dφ/dy2, the product of the first three is nothing more than dφ/dy1, and all of them combined are dφ/dw1. What PyTorch does during the backward pass is exactly this, because PyTorch knows the computation graph that relates the output, the loss function, to the variable for which we want the gradient. Right now we are talking about derivatives, not gradients, but the mechanism is exactly the same. PyTorch is like a person who knocks on the door of this "power of two" operation and says: hey, if I give you the gradient of the loss with respect to y3, which is one, because the loss and y3 are actually the same thing, can you give me the gradient of the loss with respect to y2? PyTorch does not implement a symbolic autograd system: it does not know the symbolic expressions that led to the output, it only knows which functions computed the output, and each function is a class in Python that implements two methods, one for the forward step and one for the backward step. The forward step takes the input, in this case y2, and computes the output y3; the backward step takes the gradient of the loss with respect to its output and needs to compute the gradient of the loss with respect to its input.
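This forward/backward contract is what torch.autograd.Function exposes. Here is a minimal sketch (my own toy example, not the Triton attention class) of an operator that squares its input and answers PyTorch's question by applying the chain rule locally, which is exactly what the next paragraphs explain:

```python
import torch

class Square(torch.autograd.Function):
    """Toy y = x**2 operator implementing the forward/backward contract."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)      # remember what the backward step will need
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is dLoss/dOutput (the upstream gradient PyTorch hands us);
        # we return dLoss/dInput = dLoss/dOutput * dOutput/dInput = grad_output * 2x.
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x

x = torch.tensor(3.0, requires_grad=True)
loss = Square.apply(x)
loss.backward()
print(x.grad)   # tensor(6.) == 2 * x
```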
How can the function do that? It's very simple, thanks to the chain rule. PyTorch knocks on the door of this function and says: hey, if I give you the gradient of the loss with respect to your output, can you give me the gradient of the loss with respect to your input? Yes, it can, because of the chain rule: this operator just takes the gradient of the loss with respect to its output, multiplies it by the Jacobian (in the scalar case, the derivative) of its output with respect to its input, and the result is the gradient of the loss with respect to its input. Then PyTorch takes this result and knocks on the door of the next operator, the summation, and asks the same question: if I give you the gradient of the loss with respect to your output, can you give me the gradient of the loss with respect to your input? Yes, this operator can do it, because it just needs to apply the chain rule: it takes the gradient of the loss with respect to y2, provided by PyTorch, multiplies it by the derivative of its output with respect to its input, and obtains the gradient of the loss with respect to its input. Then PyTorch takes the output of this backward step and knocks on the door of the next operator, the product, and asks the same question again, and this operator does exactly the same job: upstream gradient times the Jacobian of its output with respect to its input gives the gradient of the loss with respect to its input.

This is how PyTorch runs the backward step: it walks the computation graph backwards, one operator at a time, knocking on the door of each operator and asking always the same question, "if I give you the gradient of the loss with respect to your output, can you give me the gradient of the loss with respect to your input?", and each operator just applies the chain rule to produce the gradient that PyTorch needs. Why can't PyTorch do it by itself? Because PyTorch does not do symbolic mathematics: it does not have access to the exact expression each function computes; it just uses each function as a black box with a forward and a backward method.

However, with the Jacobian we have a problem, and let's see what it is. Up to now we have been working with a computation graph made up of scalars, but everything we said works in the tensor case as well: PyTorch still goes operator by operator asking always the same question, and each operator still applies the chain rule. Imagine now that all of these operators are working not with scalars but with tensors.
Then the derivative of the output with respect to the input of each operator is not a derivative anymore: it is a Jacobian, and since output and input are tensors, a generalized Jacobian. It also means that the other quantity, the derivative of the loss with respect to the operator's input, is not a derivative but a gradient, because the loss is always a number while the input, in this case y1, is a tensor: scalar output, tensor input, so we talk about gradients. We will call it the downstream gradient, the one the operator needs to compute, while the upstream gradient is the one PyTorch gives to each operator, the gradient of the loss with respect to that operator's output. Each operator needs to come up with the downstream gradient by using its Jacobian. However, the Jacobian has a problem. Let's see.

Imagine we are implementing a simple operation, a matrix multiplication: it takes as input a tensor X, multiplies it by a matrix W made up of parameters, and produces an output matrix Y. Suppose X is an N by D matrix and W is D by M, so Y will be N by M. Usually the input X is a sequence of vectors, each with D dimensions; you can think of it as a sequence of tokens, each token a vector of D dimensions. Usually we have many tokens, so N is at least 1024 (in the most recent language models we even have millions of tokens), D is also usually at least 1024, and M is at least 1024 as well, so let's make it 2048 (I like powers of two, by the way).

The problem with the Jacobian is this: if we want to compute the downstream gradient by multiplying the upstream gradient with the Jacobian, this Jacobian is huge. Look at the dimensions: it is a generalized Jacobian, a tensor of shape (N, M, N, D), so it has 1024 × 2048 × 1024 × 1024 elements, which is more than two trillion. It is impossible to materialize this matrix in memory, in the RAM of the GPU, because it is far too big.
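Just to make that size concrete, here is the arithmetic, a throwaway snippet using the example sizes above:

```python
# Size of the generalized Jacobian dY/dX for Y = X @ W with the example shapes above.
N, D, M = 1024, 1024, 2048

num_elements = N * M * N * D             # shape (N, M, N, D)
bytes_fp32 = num_elements * 4            # 4 bytes per float32 element

print(f"{num_elements:,} elements")      # 2,199,023,255,552 elements (~2.2 trillion)
print(f"{bytes_fp32 / 2**40:.1f} TiB")   # ~8.0 TiB in float32, far beyond any GPU's memory
```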
But we need to compute this downstream gradient, because PyTorch needs it to continue calculating the gradient of the loss with respect to each node in the computation graph. So how can we proceed? The first thing we should notice is that this Jacobian is actually a very, very sparse matrix, and I want to show you why. Look at the effect of the input on the output: the input is a sequence of tokens, so this is token number one, a vector of 1024 dimensions, then another input token, then another, then another, and we multiply by the W matrix, which is made up of columns. X is N by D and W is D by M, so the product is an N by M matrix, also a sequence of tokens, each made up of M dimensions: the first output token, the second output token, the third output token, the fourth output token. Now, each output row is the dot product of the corresponding input row with all the columns of W, so the derivative of each of its dimensions with respect to the dimensions of all the other tokens is zero, because the other tokens do not contribute to this output row. The Jacobian will therefore contain zeros every time we compute the derivative of a dimension of one output token with respect to any element of a different input token. That's why we can come up with a better formula for computing the downstream gradient, one that does not involve materializing the Jacobian: the Jacobian itself is sparse. So let's see how to perform this computation without materializing the Jacobian in the case of matrix multiplication, because we need it for Flash Attention.

All right guys, before proceeding to the formulas of Flash Attention's backward pass, let's look at how to compute the gradient of the loss with respect to the inputs of the matrix multiplication operation. PyTorch of course already knows how to do this for its own matmul, but in Flash Attention we are creating a custom kernel, and this custom kernel fuses multiple operations into one. So when PyTorch knocks on the door of our operator, the TritonAttention operator we have built, it will ask for the gradient of the loss with respect to Q, K and V, because those are the inputs of our function. If you look at the code we have written so far, TritonAttention will be a node in the computation graph that takes Q, K and V as input and produces an output; PyTorch will give us dO, the gradient of the loss with respect to that output, and will ask this class to compute the gradient of the loss with respect to Q, K and V. Because we are fusing several operations together (we compute on the fly the softmax of query times transposed keys and then multiply by V to obtain the output), we need to compute these gradients internally, and because inside this fused operation there is a matrix multiplication, we need to derive by hand the gradient of the loss with respect to the inputs of a matrix multiplication, so that we can provide it to PyTorch. That's why we need this formula. I will derive it in a very simple way, and then we will do the same for the softmax, because these are the two things we need to derive by hand in order to derive the formulas of Flash Attention's backward pass.

So let's start. Imagine we have a node in the computation graph called matmul, and this node computes Y = X·W. What will PyTorch give us when computing the backward pass of this node? It will give us dφ/dY, the gradient of the loss with respect to the output of this node, and it will ask us to compute the gradient of the loss with respect to X and the gradient of the loss with respect to W.
The easier one to work with, and the one I will show (for the other one I will not do the derivation in the video, but I will attach the PDF slides on how it is computed, because the two are derived in a very similar way and I don't want to make the video longer than necessary), is the gradient of the loss with respect to the input X. So how do we compute it by hand without materializing the Jacobian? As we have seen, we cannot simply apply the chain rule with a materialized Jacobian, which would be the easiest way, because the Jacobian is a huge matrix that cannot even fit in the memory of the GPU. We need a smarter way: we exploit the fact that the Jacobian is sparse, hoping to obtain a formula that never materializes this big sparse matrix.

When dealing with this kind of derivation I always recommend making example tensors. Suppose X is a tensor of size N by D with N = 1 and D = 3, W is a matrix of shape D by M with M = 4, and as a consequence Y has shape N by M, that is, 1 by 4. PyTorch will give us dφ/dY, the gradient of the loss with respect to the output of this operator, a tensor of shape N by M, and we need to compute dφ/dX, which should be a tensor of shape N by D: a gradient always has the same shape as the denominator, because it is a scalar output derived with respect to each element of the input.

So let's create example matrices, work out the output, and then try to work out the gradient. The output is a 1 by 4 matrix computed as follows: the input, which is 1 by 3, let's call it [x11, x12, x13], is multiplied by the matrix W, which has 3 rows and 4 columns: [[w11, w12, w13, w14], [w21, w22, w23, w24], [w31, w32, w33, w34]]. The result of this matrix multiplication is one row by four columns. The first element of the output is x11·w11 + x12·w21 + x13·w31. The second element is x11·w12 + x12·w22 + x13·w32. The third element is x11·w13 + x12·w23 + x13·w33. The fourth element, the same row with the last column, is x11·w14 + x12·w24 + x13·w34. This is the output Y of the matrix multiplication.

PyTorch will give us the gradient of the loss with respect to Y, dφ/dY, which, being a gradient, has the same shape as the denominator, so 1 by 4. We don't know what its values will be, they are provided by PyTorch, so let's give them generic names: dy11, dy12, dy13, dy14.

To compute the downstream gradient, the chain rule says dφ/dX = dφ/dY · dY/dX, where dφ/dY is provided by PyTorch and dY/dX is the Jacobian. Let's materialize the Jacobian just this once and do the multiplication, to see if something simplifies. dY/dX means the derivative of every output element with respect to every input element: we have four output elements and three input elements, so it is a 4 by 3 matrix. The first row contains the derivatives of y11 with respect to x11, x12 and x13: in y11 the variable x11 appears only multiplied by w11, so the derivative with respect to x11 is w11; with respect to x12 it is w21; with respect to x13 it is w31. The second row, the partial derivatives of the second output with respect to all the x's, is w12, w22, w32. The third row is w13, w23, w33, and the fourth row, the partial derivatives of the last output with respect to all the x's, is w14, w24, w34. But look at this Jacobian: it is just Wᵀ, the transpose of W. So we do not need to materialize the Jacobian at all: we can take whatever gradient PyTorch gives us, multiply it by Wᵀ, and we get the downstream gradient. Let me rewrite it so we know what we are doing: dφ/dX = dφ/dY · dY/dX, and since dY/dX = Wᵀ, we have dφ/dX = dφ/dY · Wᵀ. So, to provide the downstream gradient that PyTorch needs, we just take the upstream gradient and multiply it by Wᵀ, and that gives us the gradient of the loss with respect to the input X of the matrix multiplication. In the same way we can also write the formula for the gradient of the loss with respect to W, and it is dφ/dW = Xᵀ · dφ/dY; this is the one we don't derive here, but the derivation follows exactly the same procedure.

How do you remember these formulas? There is a mnemonic rule: these are the only ways the shapes can work out so that dφ/dX has the shape of X and dφ/dW has the shape of W. dφ/dY has the same shape as Y, so N by M; W is D by M, so Wᵀ is M by D; the product dφ/dY · Wᵀ is therefore N by D, exactly the shape of X. Likewise, Xᵀ is the transpose of X, so it is D by N, dφ/dY is N by M, and the product Xᵀ · dφ/dY is D by M, exactly the shape of W. If you try any other arrangement the shapes don't work out, so this is a mnemonic for how to compute the gradients of the inputs of a matrix multiplication (the input matrix X and the parameter matrix W), given the gradient of the loss with respect to its output.
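We can sanity-check these two formulas against PyTorch's own autograd; this is just a verification sketch with random tensors, not code we will reuse in the kernel:

```python
import torch

N, D, M = 2, 3, 4
X = torch.randn(N, D, requires_grad=True)
W = torch.randn(D, M, requires_grad=True)

Y = X @ W
dY = torch.randn(N, M)            # pretend this is the upstream gradient dPhi/dY
Y.backward(dY)                    # let autograd compute dPhi/dX and dPhi/dW

# Our hand-derived formulas.
dX_manual = dY @ W.detach().T     # dPhi/dX = dPhi/dY @ W^T   -> shape (N, D)
dW_manual = X.detach().T @ dY     # dPhi/dW = X^T @ dPhi/dY   -> shape (D, M)

print(torch.allclose(X.grad, dX_manual))  # True
print(torch.allclose(W.grad, dW_manual))  # True
```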
Now we need to derive the gradient of the output of the softmax with respect to the input of the softmax, because that is the other operation inside our fused attention: we are fusing together matrix multiplications and the softmax. This is the second ingredient we need to understand the backward pass of Flash Attention, so let's do it. For this derivation I will use the same notation as the Flash Attention paper, and let's give it a title: the gradient through the softmax.

The first operation we do when computing attention is the product of the queries with the transpose of the keys. We do it in a blockwise way, block by block, but it doesn't matter, because the end result is the same, so we can write S = Q·Kᵀ. Then we apply the softmax to the result of this operation and call the output P = softmax(S). After the softmax, we take P and multiply it by V to obtain the output: O = P·V. Now, as I said before, PyTorch's autograd will treat our attention computation as a black box. The computation graph looks like this: we have a query input, a key input and a value input, which are sequences of tokens, each with some embedding dimension; they are fed into a node called attention, which is the function we started coding before, and this node outputs a tensor O. PyTorch will give us the gradient of the loss with respect to this output, because, as you remember, PyTorch knocks on the door of each operator and asks the same question: if I give you the gradient of the loss with respect to your output, can you give me the gradient of the loss with respect to your inputs? And this is exactly what we need to figure out: given the gradient of the loss with respect to the output O, how do we compute the gradient of the loss with respect to Q, the gradient of the loss with respect to K and the gradient of the loss with respect to V?
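To keep the three intermediate names S, P and O in mind, here is the unfused computation in plain PyTorch, just as a reference sketch for the derivation (I am deliberately ignoring the scale factor and the causal mask to keep it minimal):

```python
import torch

def attention_reference(Q, K, V):
    """Unfused attention: the computation our Triton kernel fuses into one op
    (softmax scaling and causal masking omitted for brevity)."""
    S = Q @ K.transpose(-1, -2)        # scores: (seq_len, seq_len)
    P = torch.softmax(S, dim=-1)       # row-wise softmax
    O = P @ V                          # output: (seq_len, head_dim)
    return O

Q = torch.randn(8, 16)
K = torch.randn(8, 16)
V = torch.randn(8, 16)
print(attention_reference(Q, K, V).shape)   # torch.Size([8, 16])
```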
However, there is no direct connection between Q and O, or between K and O, because there are two intermediate operations: first a matrix multiplication, then a softmax, then another matrix multiplication. We do have a tool that tells us how the gradient propagates through multiple operations applied in sequence, and it's called the chain rule; but we have also seen that applying the chain rule naively, by materializing the Jacobian, is infeasible. So we need to understand how to apply the chain rule without materializing the Jacobian, and that is what we are going to figure out for one of the operations inside the attention computation, the softmax. That's why we are doing this derivation, which I promise is the last one, and then we will finally go and code the backward pass of Flash Attention. We cannot jump directly to the code, because if we looked at the formulas for the backward pass we would not understand where they come from.

Okay, now we can start, so let me delete this stuff. Remember that the softmax is applied row-wise to the S matrix: each row is softmaxed independently of the others. So let's see what happens to one single row of this matrix. For simplicity I could drop the row index and just call it s, but okay guys, let's carry the index along and call it s_i: it is one row of the S matrix, which in PyTorch tensor notation would be S[i, :], the i-th row and all the columns. I know it is ugly notation, but it helps you understand. This is a vector of n dimensions. We apply the softmax to this vector and obtain an output vector that we call p_i, so p_i = softmax(s_i). As we have seen, the softmax does not change the shape of its input, it only transforms each element, so the output is also a vector of ℝⁿ.

What is the softmax? It is defined as follows: p_ij, the j-th element of the vector p_i, is equal to exp(s_ij) divided by a normalization factor, which is the sum over l = 1 up to n of exp(s_il). Now, first of all, you may be wondering: the softmax we apply during the forward pass of the attention is not really this softmax, because, if you remember, each argument of the exponential was reduced by the maximum element of the vector we apply the softmax to, so it was exp(s_ij − s_i,max), and the arguments in the denominator were reduced by s_i,max as well. However, we also proved that this is equivalent to the standard softmax without the subtraction in the argument: the subtraction is only added to make the computation numerically safe, and mathematically the two definitions are equivalent.
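You can convince yourself of this equivalence numerically in a few lines (illustrative snippet only):

```python
import torch

s = torch.randn(8) * 10                      # one row of S, with largish values
naive = torch.exp(s) / torch.exp(s).sum()    # standard softmax
safe = torch.exp(s - s.max()) / torch.exp(s - s.max()).sum()  # numerically safe version

print(torch.allclose(naive, safe))           # True: the subtracted max cancels out
```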
Of course, on a computer the version without the subtraction can overflow and become numerically unstable, but from a mathematical point of view it is the same thing. This also means that it doesn't matter how you compute the forward pass: if it is equivalent to another mathematical definition, you can always use that other definition to compute the backward pass, and it will result in the same value. If that wasn't clear, here is a simpler example. Do you remember this formula from high school: cos²(x) + sin²(x) = 1? Imagine we compute an output y = cos²(x) and then we need the derivative of y with respect to x. It does not matter whether you compute it as the derivative of cos²(x) or as the derivative of 1 − sin²(x): the result is exactly the same, because the two definitions are equivalent. This is why we don't need the extra subtraction in the exponentials: the two definitions are mathematically equivalent, and we only use the numerically safe one on the computer because we need something that does not overflow.

All right. Now, what do we want to obtain? We want the gradient of the loss with respect to the input vector of the softmax, s_i, given the gradient of the loss with respect to its output vector, p_i, and we can obtain it with the chain rule: dφ/ds_i = dφ/dp_i multiplied by the Jacobian dp_i/ds_i. The chain rule is always valid; the question is what this Jacobian looks like. So let's look at its generic element: the derivative of the j-th element of the output with respect to the k-th element of the input. Remember what the Jacobian is: each element of the output vector derived with respect to each element of the input vector. And how is the output obtained? By the definition of the softmax, p_ij = exp(s_ij) / Σ_{l=1..n} exp(s_il), and we derive this with respect to s_ik. To make it concrete, suppose p_i is a vector of three elements, p_i1, p_i2, p_i3, and s_i is also a vector of three elements, s_i1, s_i2, s_i3: the first row of the Jacobian is the derivative of the first output element with respect to each of the input elements, the second row is the derivative of the second output element with respect to each input element, and the third row likewise. What we are trying to understand is what the generic element of this Jacobian looks like, where the index j refers to the output vector and the index k refers to the input vector.
So, what can happen when we compute this element? It is the derivative of a ratio of two functions, and we know from high school the quotient rule: the derivative of f(x)/g(x) with respect to x is (f'(x)·g(x) − g'(x)·f(x)) / g(x)². Let's apply it here. There are two cases to consider: either the variable we are deriving with respect to, s_ik, has the same index as the output element being derived (so we are computing the derivative of p_i1 with respect to s_i1, or of p_i2 with respect to s_i2, and so on), or it does not (for example p_i1 with respect to s_i2 or s_i3).

Suppose first that j = k, so we derive an output element with respect to the input element with the same index. The derivative of the numerator, exp(s_ij), with respect to s_ij is exp(s_ij) itself, just like the derivative of e^{x1} with respect to x1 is e^{x1}. So we get exp(s_ij) multiplied by the denominator of the fraction, the summation over l of exp(s_il), minus the derivative of the denominator with respect to the variable being derived, multiplied by the numerator. The denominator is the sum of the exponentials of all the input elements, so if we derive it with respect to one particular input element, only the term that contains that element survives and all the other terms become zero: its derivative is exp(s_ik). So we subtract exp(s_ik) multiplied by the numerator exp(s_ij), and everything is divided by the denominator squared, (Σ_{l=1..n} exp(s_il))². The two terms in this numerator have a common factor, exp(s_ij), so we can collect it: exp(s_ij) · (Σ_l exp(s_il) − exp(s_ik)), all divided by (Σ_l exp(s_il))². And because the denominator is a square, we can split this expression into two factors: exp(s_ij) divided by Σ_l exp(s_il), multiplied by (Σ_l exp(s_il) − exp(s_ik)) divided by the same summation.
The first factor is nothing more than the output element p_ij, because it is just the softmax applied to s_ij, and we know that we call the result of the softmax applied to s_ij by the name p_ij, one element of the output vector p_i. The second factor is equal to 1 minus exp(s_ik) divided by Σ_l exp(s_il), and that fraction is the softmax applied to s_ik, so it is p_ik. So in this case, when the variable we derive with respect to has the same index as the output element, the entry is p_ij · (1 − p_ik).

The other case is when the two indices are not the same, so j ≠ k. What happens then? The derivative of the numerator with respect to a different variable is zero, just like the derivative of e^{x1} with respect to x2 is zero, so the entire first term of the quotient rule becomes zero, no matter what g(x) is. We are left with minus the derivative of the denominator with respect to s_ik, and since only one term of the summation contains s_ik, that derivative is exp(s_ik), multiplied by f(x), the numerator of the fraction, which is exp(s_ij), all divided by the denominator squared, (Σ_{l=1..n} exp(s_il))². I believe I didn't forget anything, so let's continue: we can again separate this into minus exp(s_ik) divided by Σ_l exp(s_il), multiplied by exp(s_ij) divided by Σ_l exp(s_il). The first fraction is nothing more than the softmax applied to the k-th element of the s_i vector, and the second is the softmax applied to the j-th element, so we know what to call them: the entry is −p_ik · p_ij.

So in the end we have two cases for the generic entry of this Jacobian: when the indices match, j = k, the entry is p_ij · (1 − p_ik), which, since j = k, is the same as p_ij · (1 − p_ij); and when j ≠ k the entry is −p_ik · p_ij. Now that we know what the two typical entries of this Jacobian look like, let's see what the whole Jacobian looks like in matrix form.
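Before writing the matrix out, we can double-check these two entry formulas against autograd on a random row; again, just a throwaway sketch:

```python
import torch
from torch.autograd.functional import jacobian

s = torch.randn(5)                 # one row s_i of the S matrix
p = torch.softmax(s, dim=0)        # p_i = softmax(s_i)

J = jacobian(lambda x: torch.softmax(x, dim=0), s)   # J[j, k] = d p_ij / d s_ik

# Rebuild the Jacobian from the two hand-derived cases.
J_manual = torch.empty(5, 5)
for j in range(5):
    for k in range(5):
        J_manual[j, k] = p[j] * (1 - p[k]) if j == k else -p[k] * p[j]

print(torch.allclose(J, J_manual, atol=1e-6))   # True
```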
It will be an n by n matrix, where n is the size of the input vector and of the output vector. As you remember, in the numerator convention the first row of the Jacobian is the derivative of the first output with respect to all the inputs. So the first entry is the derivative of p_i1 with respect to s_i1: here j and k match, so it is p_i1·(1 − p_i1). The entry to its right, in position (1, 2), is the derivative of p_i1 with respect to s_i2: j and k do not match, so it is −p_i1·p_i2. The third entry, you can check it by yourself, is −p_i1·p_i3, and so on until the end of the row, which is −p_i1·p_in. The second row looks like this: the first entry is the derivative of p_i2 with respect to s_i1, the indices do not match, so it is −p_i2·p_i1; the next entry is the derivative of p_i2 with respect to s_i2, the indices match, so it is p_i2·(1 − p_i2); the third is −p_i2·p_i3, and so on until −p_i2·p_in. All the rows continue like this until the last one: its first entry is the derivative of the last output element with respect to the first input element, p_in with respect to s_i1, the indices do not match, so it is −p_in·p_i1; then −p_in·p_i2, −p_in·p_i3, et cetera, until the last entry of the last row, the derivative of p_in with respect to s_in, and careful here: the two indices match, so it is p_in·(1 − p_in), not −p_in·p_in.

This is what the Jacobian looks like. Let's see if we can find a pattern for generating it. The first thing we can notice is that this Jacobian is symmetric: the element in position (1, 2) is equal to the element in position (2, 1), and if you expand the third row you will see the same thing; the top-right corner is equal to the bottom-left corner. The second thing we can notice is that only the elements on the diagonal are different: they have an additional term. Look at the first diagonal element: it can also be written as p_i1 − p_i1·p_i1, and the second diagonal element as p_i2 − p_i2·p_i2. So the diagonal elements actually look just like all the others, a negated product of two elements of p_i, except that they have an additional term: p_i1 on the first diagonal element, p_i2 on the second, and so on. We can also say that this matrix contains, negated, all the possible products of p_ij with p_ik, which we can obtain with an outer product, that is, the product of a column vector with the transpose of the same column vector: if p is a column vector, p·pᵀ gives you all the possible products of pairs of its elements.
Here is a simple case: take a column vector with elements p1, p2, p3 and multiply it by the row vector p1, p2, p3 (here I dropped the i index just to keep the example light, but these are really elements of the generic row p_i). A 3 by 1 matrix times a 1 by 3 matrix gives a 3 by 3 matrix whose entries are p1·p1, p1·p2, p1·p3, and so on: all the possible products between pairs of elements. Moreover, on the diagonal of our Jacobian we have those additional terms, which are exactly the elements of the p_i vector itself. So we can write this Jacobian as the diagonal matrix that has the elements of p_i on its diagonal, minus the vector p_i multiplied by its own transpose, because we need every entry to be a product of one element of p_i with another element of p_i, we need the additional terms only on the diagonal, and all the entries coming from the outer product are negated, which is why we need the minus sign. If you look at the Flash Attention paper, they give you exactly this formula: if y = softmax(x), then the Jacobian is diag(y) − y·yᵀ, where y is a column vector.

All right guys, I know this has been long, so let's take a pause, and then, first of all, let's check the mathematics of the backward pass of Flash Attention. We will go through it briefly, I will not do any more derivations, but I will explain it, and then we finally switch to coding it. So let's go.

All right guys, now we can finally look at the backward pass of Flash Attention. We will be looking at the algorithm, and if you open the appendix of the Flash Attention paper you will see part B.2, where they derive the backward pass step by step. I don't want to repeat every step of that derivation, because it would take too long, but I want to give you all the tools necessary to understand it. Let's start from the conventions and notations they are using in the paper. The first thing to rehearse is the name of each matrix. As you know, in the forward pass we compute the query multiplied by the transpose of the keys and call the output S; then we apply the softmax, row by row, to this S matrix and it becomes the P matrix; then we take P and multiply it by the V matrix to obtain the output of the attention. Let's look, for example, at how the i-th row of the output is computed from the P matrix and the V matrix, so that we can understand the notation they use in the paper. The way I read their formula is: the i-th row of the output, written as a column vector, because in linear algebra, whenever we write the name of a vector, it is by convention a column vector, even though the origin of this particular vector is actually a row of the output matrix.
Let's try to understand what a row of the output of a matrix multiplication looks like, so that we can understand how to go from one form to the other. Let's write a generic matrix multiplication: a matrix A, of which we only write one row, with elements a11, a12, a13, even though A has many rows, let's say n rows by 3 columns, because we only want to study the effect of one row. We multiply it by another matrix B, which must then have 3 rows, and let's give it 4 columns: its first row is b11, b12, b13, b14, its second row is b21, b22, b23, b24, and its third row is b31, b32, b33, b34. I know I am not being very rigorous with the notation, I should be using capital letters when referring to single items of a matrix, but please forgive me for this. The output of this matrix multiplication will be another matrix that is n by 4, so we have four columns for each output row. I want to write the first output row as a vector and understand what each of its dimensions is, so let's call it o_1, the first row of the output written as a column vector. Its first dimension is the dot product of the first row of A with the first column of B, so a11·b11 + a12·b21 + a13·b31. The second dimension is the dot product of the same row of A with the second column of B, so a11·b12 + a12·b22 + a13·b32. The third dimension is a11·b13 + a12·b23 + a13·b33, and the fourth dimension is a11·b14 + a12·b24 + a13·b34. This is the first output row of the O matrix, a vector we call o_1, and each of these four dimensions is a scalar.
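Here is a quick numerical check of what we just wrote by hand, which also verifies the equivalent form we derive next, each output row as a combination of the rows of B (a throwaway sketch with random matrices):

```python
import torch

A = torch.randn(5, 3)
B = torch.randn(3, 4)
O = A @ B

# Row 0 of the product, element by element (what we just computed by hand)...
row_by_dot_products = torch.stack([A[0] @ B[:, c] for c in range(4)])

# ...and the same row as a weighted sum of the rows of B: o_1 = sum_j a_1j * b_j
row_by_row_combination = sum(A[0, j] * B[j, :] for j in range(3))

print(torch.allclose(O[0], row_by_dot_products, atol=1e-6))      # True
print(torch.allclose(O[0], row_by_row_combination, atol=1e-6))   # True
```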
Using PyTorch-style tensor notation for these rows — b1 is the first row of B with all its dimensions, b2 the second, b3 the third — we get o_1 = a11·b1 + a12·b2 + a13·b3, a sum of scalar-vector products. More generally, the i-th row of the output can be written as a summation over j from 1 to 3 (where 3 is the number of columns of A) of a_ij multiplied by b_j, where b_j is the j-th row of B, written explicitly as a vector. And this is exactly what the paper does: in the multiplication P times V, the i-th row of the output, which they call o_i — a column vector by notation, but its elements are the elements of the i-th row of O — equals the i-th row of P, the matrix on the left of the multiplication, times all the columns of V, which can also be written as a summation over j of p_ij multiplied by v_j, where v_j is the j-th row of the V matrix. And p_ij is the output of the softmax: the exponential of the softmax input divided by the row's normalization factor, where the softmax input is the dot product of one query with one key — that is why you see q_i·k_j inside the exponential in the paper. This was the first step in understanding the derivation. The next tool we have already studied is the backward pass of the matrix multiplication and of the softmax. Let's rehearse the matmul formulas: given Y = X·W and the gradient of the loss with respect to Y, the output of the operation, we know how to derive the gradient of the loss with respect to either input. To get the gradient with respect to X we take the upstream gradient — the gradient with respect to the output — and multiply it by the transpose of W; to get the gradient with respect to W we take Xᵀ, the input transposed, and multiply it by the upstream gradient. One of these two formulas we derived explicitly earlier and the other we did not, but the procedure to derive them is exactly the same.
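As a quick sanity check of these two reference formulas, here is a minimal, illustrative sketch — autograd is the ground truth, and the shapes are arbitrary:

```python
import torch

# For Y = X @ W with upstream gradient dY:
#   dL/dX = dY @ W^T      (gradient w.r.t. the left operand)
#   dL/dW = X^T @ dY      (gradient w.r.t. the right operand)
X = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
W = torch.randn(4, 5, dtype=torch.float64, requires_grad=True)
Y = X @ W

dY = torch.randn_like(Y)   # stand-in for the upstream gradient
Y.backward(dY)

assert torch.allclose(X.grad, dY @ W.T)
assert torch.allclose(W.grad, X.T @ dY)
```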
In attention, the last product we compute is O = P·V. What PyTorch gives us as input during the backward pass is the gradient of the loss with respect to the output, and we need to use it to derive the gradient of the loss with respect to Q, with respect to K and with respect to V, so that these can be consumed by the operations that come before in the computation graph. But to get to the gradients with respect to query, key and value we first need the gradient with respect to each intermediate operation. So, starting from the last operation, O = P·V: computing the gradient of the loss with respect to V, given the gradient of the loss with respect to O, is exactly like computing the gradient with respect to W in our reference matrix multiplication — V is the matrix on the right — so, purely by analogy, it equals the transpose of the matrix on the left multiplied by the upstream gradient, which in the paper's notation is dV = Pᵀ·dO, the formula you see here. The other derivation is the gradient with respect to P, which plays the role of the matrix on the left, so it is like the gradient with respect to X in the reference formulas: the upstream gradient times the transpose of the other matrix, which the paper writes as dP = dO·Vᵀ. How do they compute dV row by row? They call v_j the j-th row of the V matrix and write dv_j = Σ_i p_ij·do_i. Let's see how to arrive at that. A useful trick: whenever you see a transpose in a matrix multiplication and you don't want to carry it around, give the transposed matrix a new name, work with that, and substitute the transpose back at the end. Here dV = Pᵀ·dO, so let's call F = Pᵀ (I always use F when it's available), which gives dV = F·dO. From the derivation above we know that the j-th row of the output of a matrix multiplication is dv_j = Σ_i f_ji·do_i — a sum of scalar-vector products of the elements of the j-th row of F with the corresponding rows of dO. But F is not a matrix we actually have: it is the transpose of P, so f_ji = p_ij, because transposition swaps the two indices. Therefore dv_j = Σ_i p_ij·do_i, which is exactly the formula on the right in the paper, and it lets us compute one output row of dV.
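Here is the same statement as a tiny runnable check (shapes and names are arbitrary, chosen only for the example): each row of dV = Pᵀ·dO is a weighted sum of rows of dO, with the weights coming from one column of P.

```python
import torch

P = torch.softmax(torch.randn(6, 6, dtype=torch.float64), dim=-1)
dO = torch.randn(6, 4, dtype=torch.float64)

dV = P.T @ dO   # the matrix form: dV = P^T dO

# the row-wise form used in the paper: dV_j = sum_i P_ij * dO_i
for j in range(P.shape[1]):
    dV_j = sum(P[i, j] * dO[i] for i in range(P.shape[0]))
    assert torch.allclose(dV[j], dV_j)
```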
We also know that p_ij is just the output of the softmax: the exponential of the softmax input divided by the normalization factor of the corresponding row — since we are iterating over rows indexed by i, it is the normalization factor of row i. So P = softmax(S), and the i-th row of P is the softmax of the i-th row of S, which is what is written in the paper. We also know from our derivation the Jacobian of the softmax operation: if x is the input and y = softmax(x) is the output, the Jacobian of y with respect to x is diag(y) − y·yᵀ, and we saw earlier that this matrix is symmetric. One thing in the paper's formula may still confuse you: in the chain rule we have always written the downstream gradient as dφ/dx = dφ/dy · dy/dx, that is, upstream gradient multiplied by Jacobian. That only works if you build the Jacobian in the numerator convention — one of the two conventions for writing a Jacobian, and the one we have used so far — in which case dφ/dx and dφ/dy are row vectors. If instead you want to treat the gradients as column vectors, you need to take transposes, or equivalently build the Jacobian in the denominator convention. So why does the paper's formula have the Jacobian multiplied by the upstream gradient, rather than the upstream gradient multiplied by the Jacobian? Only because there the gradient is treated as a column vector. To turn a row-vector equation into a column-vector one you take the transpose of both sides; and in a matrix multiplication (A·B)ᵀ = Bᵀ·Aᵀ — the transpose applies to each factor independently, but the order of the factors is inverted, and remember matrix multiplication is not commutative. So, transposing both sides, dφ/dx as a column vector — which the paper calls ds_i — equals dy/dx, the Jacobian now in the denominator layout, multiplied by dφ/dy as a column vector. That is why in the paper the Jacobian sits on the left of the upstream gradient.
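It is worth multiplying this Jacobian out once, because it explains where the per-row scalar D_i used by the backward algorithm comes from (this matches the identity used in the paper's appendix, only rewritten in our notation, with P_i and dP_i the i-th rows of P and dP treated as column vectors as above):

$$
dS_i \;=\; \big(\operatorname{diag}(P_i) - P_i P_i^{\top}\big)\, dP_i
\;=\; P_i \circ dP_i \;-\; \big(P_i^{\top} dP_i\big)\, P_i
\;=\; P_i \circ \big(dP_i - D_i\big),
\qquad
D_i = P_i^{\top} dP_i .
$$

And since dP_i = V·dO_i and O_i = Vᵀ·P_i, we also have D_i = P_iᵀ·V·dO_i = O_iᵀ·dO_i: D_i is just the dot product of the i-th row of O with the i-th row of dO, which is exactly the quantity the backward pass will precompute once and reuse for both dQ and dK.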
What else do we need? I know there is a lot in this derivation, but I prefer to go directly to the code, otherwise it gets too boring; while writing the code I will go back to the formulas so we can match what we are doing with the paper — I think that is the best way, so let's proceed. All right guys, now we can finally code the backward pass. Before we code it, let's look at the algorithm of the backward pass as written in the paper — this is the FlashAttention 1 paper — and we will follow the structure of the code that is present on the Triton website, so splitting it this way is not my idea, but I simplified it: my version is different from the one you can find online because it is a simplified version, and it works with both causal and non-causal attention. If you look at this algorithm, you can see an outer loop over all the K and V blocks and an inner loop over all the query blocks. However, to compute dQ — the gradient of the loss with respect to the Q matrix — we need an iteration over all the K blocks, and to compute each dK block we need an iteration over all the Q blocks. If we followed the loop exactly as written, it would involve writing to the high-bandwidth memory — the DRAM of the GPU — at every inner iteration, which is not efficient, or it would require some sort of synchronization between blocks, which is also not efficient. So we split this double loop into two parts: since each dQ depends on a loop over the Ks and each dK depends on a loop over all the Qs, to compute dK (and dV) we fix the K/V block and iterate through all the Q blocks, and then we do another iteration in which we fix the Q block and iterate through all the K/V blocks to compute dQ. That is the structure we will follow, and it is an idea I took from the original implementation on the Triton website. Another thing to notice: to compute a dQ vector and a dK vector we need this element called D_i, and it is shared between the two, so we can precompute it once and reuse it for both. D_i is introduced in the paper as the dot product of the vector do_i with the vector o_i. So the first thing we will do is loop over all the vectors in O and dO and compute their dot products to obtain the D_i elements; then we will use them when computing dQ and dK, in the two loops I just described — one in which we fix the Q block and iterate through all the keys, and one in which we fix the keys and iterate through all the queries. Now that we know, more or less, the structure of the code, let's start by writing this backward function. Do you remember these saved tensors? They are all the information that we save during the forward pass in order to compute the backward pass. To optimize memory utilization, in FlashAttention we do not save the query multiplied by the transpose of the key, because that would be a sequence-by-sequence matrix: too big to save into the HBM, the DRAM, during the forward pass and then bring back from the HBM into the local memory. I want to remind you that in Triton, compared with CUDA, what we do is load data from the high-bandwidth memory into the shared memory (the SRAM), do all the operations there, and then, when we call the store method, write the result from the shared memory back into the high-bandwidth memory. Materializing the whole S matrix, saving it to the HBM and then reading it back would be very slow,
and secondly it would be really expensive: nowadays we compute attention over thousands and thousands of tokens, so imagine saving a 5000-by-5000 matrix for each batch and for each head — that is far too much to store. So the idea of FlashAttention is to recompute on the fly, during the backward pass, whatever we can recompute: loading it back would be memory-I/O bound anyway, so it is faster to recompute it than to save it and restore it from memory. Okay. So we saved some tensors during the forward pass, and now we can access them back during the backward pass; they are saved in the context, which is a kind of dictionary made available by PyTorch. We get back the query, key and value, and, as you know, PyTorch during autograd will just give us the gradient of the loss with respect to the output of our attention implementation — this TritonAttention — and we need to compute dQ, dK and dV using only that gradient. Then we do some checks; I know I could make this code even smaller, for example by not checking that the strides are the same (inside the kernels I always assume they are), but my goal was to simplify the code taken from the Triton website, not to optimize it. Next we create the tensors in which we will store the result of this backward pass, namely dQ, dK and dV. As you know from the definition of the gradient, the gradient has the same size as the tensor with respect to which we compute it, because the numerator is always a scalar and we compute one partial derivative per element of the input. Then we read the batch size and so on, and later we will see what the number of warps and the number of stages are: the number of warps indicates how many warps of threads each program will use, and the number of stages is the number of stages used in software pipelining — we will see what software pipelining is when we talk about autotuning. Then we define some block sizes. In the original code they are called, I think, BLOCK_KV1, BLOCK_KV2, BLOCK_Q1 and BLOCK_Q2, which I found confusing, so I call them BLOCK_MACRO and BLOCK_MICRO: in each loop, the block we keep fixed is the macro one and the block we iterate over is the micro one — that is the naming I am using. Then, as I said before, we need to precompute the D_i elements that we saw in the paper, and that is the first kernel we are going to launch; this kernel has its own launch grid because later we will want to tune it with respect to its own parameters.
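Before looking at that kernel, note that what it has to compute is trivial to express in plain PyTorch — the point of writing it as a kernel is only to produce D directly on the GPU, next to the other backward kernels. A reference version (purely illustrative, not the code we are about to write) would be:

```python
import torch

def reference_D(O: torch.Tensor, dO: torch.Tensor) -> torch.Tensor:
    # O, dO: (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
    # D_i = rowsum(dO_i * O_i): one scalar per row of the output, so the
    # result has shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN), the same shape as M.
    return (O * dO).sum(dim=-1)
```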
This preprocess kernel will precompute all the D_i elements, which we will need later when computing dQ and dK, and D_i depends only on O and dO. So let's create another function, the backward preprocess kernel. What is the preprocess grid? It is the launch grid of this kernel: the kernel is launched independently for each batch and for each head, and each program works with one block of vectors of O — how many vectors? BLOCK_SIZE_MACRO of them, so 128 vectors of O. Let me copy the signature of this function. It takes a pointer to the matrix O, a pointer to dO, and a pointer to the matrix D where we will store the D_i elements. We have one D_i for each vector in the output of the attention, which is why the shape of D is (batch size, number of heads, sequence length) — the same shape as M, the matrix we saved during the forward pass, which contains, for each row, the maximum element and the normalization factor of the softmax combined in log-sum-exp form, so that when we apply it, it subtracts the per-row maximum and normalizes at the same time (which I think I proved previously). So let's write it. First we extract the indices of this program: each program has two identifiers, equivalent to the block indices in CUDA. Along axis 0 of the launch grid we defined which block of O vectors this particular program works with, and along axis 1 which batch and which head inside that batch it works with. So the index along axis 0 is the block index of Q — it is called q here because I copied it from the original code, but I could just as well have called it o, since O has the same shape as Q. It identifies which group of vectors in the O matrix this program handles: this program has to skip the vectors that are (or will be) processed by the other programs running in parallel, so it only works with the vectors of O that have the following indices. Imagine the block size is 4 instead of 128, for simplicity, and that the first 32 vectors are managed by other programs: then offs_q for this program would be 32, 33, 34 and 35 (indices start from zero). This tells me which vectors of the O matrix, among all of them, this particular program is going to work with. Then we also extract the batch-head index, which tells us which batch and which head inside that batch this program handles — this is dimension 1 of our launch grid — and finally we define the offsets along the head dimension, because from each vector we need to
load all of its dimensions: it is a vector of offsets telling us which dimensions to load from each vector, and we load all of them — we do not split the work along the head dimension, only along the sequence-length dimension. You will also see in this part of the video that, while writing the backward pass, we will not use make_block_ptr like we did in the forward pass; here we work directly with indexing, using the strides. So let's load a single block of rows of O — which, I remind you, has the same shape as Q, and that is why we can call its block size BLOCK_SIZE_Q. The load function accepts an array of pointers, possibly multi-dimensional if you want to load multi-dimensional data; in this case we load two-dimensional data, a block of rows of O, which should be a tensor of shape (BLOCK_SIZE_Q, HEAD_DIM). We have to tell it where in the O tensor to find this block: first we need to skip the batches and heads that are processed by other programs. The O tensor has shape (batch size, number of heads, sequence length, head dim), so each batch-head pair owns sequence length times head dim items: batch 0 / head 0 has SEQ_LEN × HEAD_DIM items, batch 0 / head 1 has the same number, batch 0 / head 2 as well, and so on. So, from the start of the O tensor, we skip index_batch_head × SEQ_LEN × HEAD_DIM items, where index_batch_head is the combined batch-and-head index of this program — it encodes both which batch and which head inside the batch. Once we are pointing at the start of the current batch and head, we select a two-dimensional tile whose row offsets are given by offs_q — that is why we write offs_q[:, None], which adds an extra dimension for the columns — and whose column offsets are given by offs_dim[None, :]. So this selects a tensor of shape (BLOCK_SIZE_Q, HEAD_DIM) inside the big four-dimensional tensor, skipping everything that belongs to the batches and heads processed by other programs. I always talk in terms of programs because that is the Triton terminology; in CUDA terms, a program corresponds roughly to a block of threads of the kernel. All right, this one is done — I hope it is decently clear.
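Putting this together with the dO load and the D_i computation described in the next paragraph, a minimal sketch of the whole preprocess kernel might look like the following — the names, the argument list and the absence of masking (so SEQ_LEN must be a multiple of BLOCK_Q) are my own simplifications, not the exact code from the video:

```python
import triton
import triton.language as tl

@triton.jit
def attn_backward_preprocess(
    O, dO,                        # pointers, logical shape (BATCH, HEADS, SEQ_LEN, HEAD_DIM)
    D,                            # pointer, logical shape (BATCH, HEADS, SEQ_LEN)
    SEQ_LEN,
    BLOCK_Q: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    # launch grid: (SEQ_LEN // BLOCK_Q, BATCH_SIZE * NUM_HEADS)
    block_index_q = tl.program_id(0)      # which block of O rows this program handles
    index_batch_head = tl.program_id(1)   # which (batch, head) pair this program handles

    offs_q = block_index_q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    offs_dim = tl.arange(0, HEAD_DIM)

    # skip the (batch, head) slices owned by other programs, then pick a
    # (BLOCK_Q, HEAD_DIM) tile with explicit pointer arithmetic
    tile = (index_batch_head * SEQ_LEN * HEAD_DIM
            + offs_q[:, None] * HEAD_DIM + offs_dim[None, :])
    O_block = tl.load(O + tile).to(tl.float32)
    dO_block = tl.load(dO + tile).to(tl.float32)

    # D_i = rowsum(dO_i * O_i): one scalar per row of this block
    D_block = tl.sum(dO_block * O_block, axis=1)   # shape (BLOCK_Q,)
    tl.store(D + index_batch_head * SEQ_LEN + offs_q, D_block)
```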
um all right so then we also load a single block of d o in the same way because we are going to load a group of vectors from all the sequence length also from d o and the d o has the same shape as o which has the same shape as q and that's why we can use the um the the block index we call it q because it's equivalent because they have the same shape okay and how to compute this d i element well it's written in the paper so if we go in the in the what is it man if we go here it shows you how to compute the d i of each given a block of d o and a block of o it tells you how to compute d i which is the row sum which means the sum of by rows for each row we will have one sum for each vector in the o matrix we will have some of the element wise product so this stuff here is the element wise product of d o i multiplied by o i so it's not a matrix multiplication it's element wise product which means each element of one matrix with the corresponding element of the second matrix and the output shape it will be the same as the two matrices which must have the same shape okay so we compute this d i block which will have shape block size q because we will have one sum for each vector then well we need to store it somewhere so we need to calculate where to store it inside of the d matrix well the d matrix is i remember correctly has the same shape as m so it should be batch size a number of heads and sequence length so we need to select the right batch and the right head and also the right position inside of the sequence length based on the block index q that we have okay so let me index okay all right because we already um so we skip um again just like before we know that the d is of this size each batch and each head will have sequence length number of elements so how many number of elements we need to skip from the starting of the tensor is sequence length multiplied by the combined index batch size head number and plus we need to also skip some queries based on our block index q and it's already this skipping is already done inside of off skew so we add off skew and then once we have computed the index where we should store this d i block why did i even call it d block let's store it so let me i didn't call it d block i think it was already in the original code but this is d i and this big matrix d is actually the matrix that includes all the d i for one for each token in the sequence length all right so the pre-processing has been done now we need to do prepare the two for loops as you remember i said before we will be doing two for loops one in which we fix the query and we iterate through all the keys and values and one in which we fix the key and value block and we iterate through all the queries and while coding it i will always show you the formula from the paper so don't worry let's start with the next iteration so first we create the launch grid for the next iteration as the launch grid is always the same so we first because we we need to keep one block fixed and iterate through all the other blocks the block that we keep fixed will define how many programs we have that run in parallel and the block that is fixed has a block size macro number of elements that's why we create a sequence length divided by block size macro number of blocks thread blocks or programs in this axis the axis two in this grid is i could have used also the axis one indifferently i think it was already done here in the original code it's we will indicate which batch and which head inside of each batch we are going to 
work with. Just like in the forward pass, we will also use a variable called STAGE: if the attention we are computing is causal it will be equal to 3, and if it is non-causal it will be equal to 1. In the first iteration we fix the K and V blocks and iterate through all the Q blocks, in steps of BLOCK_SIZE_MICRO query vectors. Let's look at the call: we launch it with this launch grid, which defines how many programs we have — the number of K/V blocks is the sequence length divided by BLOCK_SIZE_MACRO, because that is the block we keep fixed in this function — and inside it we go through all the query blocks in steps of BLOCK_SIZE_MICRO, which I defined as 32 (later we will talk about autotuning and how to tune these values). So I pass the Q, K and V tensors — they point to the beginning of the tensors, that is, to the first batch, the first head, the first token and the first dimension — then the softmax scale, then dO, dQ, dK and dV. M is needed because, as we said before, we did not save the P matrix to the HBM: we want to recompute it on the fly during the backward pass, since the query multiplied by the transpose of the keys is a very big matrix to save in the HBM and restore. But we do not need to recompute the per-row normalization factor and maximum element needed to apply the softmax: those were already computed during the forward pass and saved into this matrix M, which stores, for each row, the maximum plus the logarithm of the normalization factor; thanks to the log-sum-exp trick, subtracting it both applies the maximum and normalizes each value at the same time. Then we have the D tensor that we computed just above, with all the D_i values, one for each vector in the O tensor. Then we pass the number of heads, the sequence length, and the block sizes: the one we use for K/V is the macro block size, and the micro block size is always the one we iterate over — I think this naming makes it easier to see which block is fixed and which one we iterate on. Later we will also see why we use a different block size for the iterated dimension: it is related to the number of stages into which Triton can split your for loop thanks to software pipelining. Then we pass the head dimension, the STAGE, which indicates whether the attention we computed in the forward pass was causal or not, and the number of warps and number of stages, which for now we keep fixed but will revisit when we talk about autotuning (I know I sometimes repeat the same things over and over — I should change that). Okay, let's write the signature of this function — we have already described its arguments, so let's go directly to the meat. The first thing we need to do is work out the offsets by which we need to move the Q, K and V pointers: first of all we need to enter the right batch and the right head inside that batch. We compute the index of the batch just like during the forward pass, by taking the program index along axis 1 — which combines the batch index and the head index — and dividing it
by the number of heads to get which batch this program is working with and to get the head we just do the modulus just like in the for loop for one person the offset batch head indicates let me check what is it for okay it enters the right batch and the right head so what is the stride if you remember correctly the stride tells us how many items you need to skip in that dimension to arrive to the next index in the same dimension so if we want to skip index number of batch we multiply it by the stride batch which is how many elements you need to skip to arrive to the next batch plus we also need to enter the right head so we multiply the index of the head multiplied by the stride of the head to enter exactly in that head in the tensor for each of the q k and v matrices plus we also have this is will be used for if i remember for m and d because m and d only don't have the um the head dimension head dimension so they are only batch size number of heads sequence length so we just use the index batch multiplied by sequence length because for each batch and on each head we will have sequence length number of items so you can think of it at the stride to move from one batch head to the next batch head uh or to the yeah so uh let's move the pointers and this was so we move the pointer q k and v by the offset batch head because we want to enter the right um batch and the right head inside of these big tensors and we do it also for d o d q d k and d v because they have the same shape as a q k and v and d o also has the same shape as q so they have the same shape so we move by the same uh by the same offset all right so then we move m and d to move them to the right starting point on which the sequence of the current head and the current batch and the current head starts so they are pointing to the first vector of the sequence dedicated to the current batch and the current head and the same is true for q k and v and the d o d q d k and v okay then we load some other stuff because here we fix in this iteration in this method we are going to do a for loop in which we fix k v and we iterate through q so we first need to load this deeps block of k v and we do it as follows as follows so we know we need to load a 2d tensor so we need to define what are the ranges in the second dimension of each vector k and v that we need to load and it's defined by this by this vector then we want to understand which kv block this particular program is going to work with so this particular program is going to skip some kvs that will already be managed by other programs that may be running in parallel and how to understand what this program should be working with in based on the index of the program zero which is defined on the sequence length divided by the block size macro and if you remember block size macro is the thing that we fix so it's telling us this program id zero will tell us how many block size macro kv are already being managed by other programs so we shouldn't care about them so we skip them so let's go back here and this is the number of vectors that we need to skip so our kv start from start kv and how many we need to load them well depends on what is the block kv this block kv is equal to block size macro so it will be 128 vectors so we define our tensors two-dimensional tensors that we will store in the sram because in triton every time you load something you load it from the hbm into the sram so we define where they should be saved in the sram and they are initially zeros and now we load them so we 
load them as follows we say that okay in the k in the k tensor pointer which is already pointing to the right index to the right batch and to the right head because that's something that we did here we say we should need we need to load the right sequence of keys which should start from offski because this already includes how many we should skip in the sequence length dimension and for each of these vectors we need to load all the dimensions in the in the head dimension dimension because the k if i want to remind you is batch number of heads sequence length and head dim now by using this line we are skipping to the right b and to the right head so it's like we already indexed here and here we already selected an index so right now this k is pointing to the beginning of a tensor of two dimension and we tell okay we don't want all the sequence we want some part of this sequence which part the one that is indicated by this start kv and how many of in the sequence length we want well we want uh all right i think it's easy to write it like this so we can write it that from start kv to start kv plus block kv uh so we want this number of tensor exactly at this location and for head dimension what do we want to select we want to select all the dimensions so we say that we want from zero to head dimension which is exactly this offskdim okay uh we do it for the k block and we do it for the v block here i think i didn't change the comment this should be block kv and this should be block kv before it was called the block kv1 right like in the original code i simplified a little bit the naming i think this one is better easier to follow because in the original code they also do for two for loops but in the second for loop they will do it backward just to not change the structure of the loops but i think mine is more verbose but easier to understand and probably less efficient mine is much less efficient um then we have offsq because we need to understand for each block of queries how many vectors we need to load and it's indicated by this offsq and how many are them it's a block q block q in the color of this method was block size micro so it is 32 vectors okay um now we need to access q vectors and o vectors trans uh no q vectors but already transposed and the o vectors also we need to access them because we are going to iterate through queries and o vectors actually also why because let's look at here let's look at the formulas in the paper to compute vj so to compute the dvj that's what we are trying to compute here we need to iterate through all the do vectors and to compute dk we need to iterate through all the qi vectors because the qi is a block of vectors so that's why we need um and why do we need to access a q as a transposed because we need to compute let me show you here pij transposed to compute pij transposed we need to we need the q transposed because the pij would be the softmax of the query multiplied by the transpose of the keys after we apply the softmax it becomes p but if you want the transposed of p then you need to do query transposed k multiplied by query transposed so that's why we access the query transposed instead of queries and the way we access the query transposed is just by playing with the stride so let's do it like this and i have also written the comment on why we can do it so this is equivalent to accessing the query uh how many first okay what is this um what is this operation uh what is this operation here this is saying go to the query starting point starting um 
pointer — the q pointer, which is already pointing to the right batch and the right head this particular program works with — and build from it a two-dimensional grid of pointers. In this expression the query offsets are laid out along the columns (each query's starting offset is repeated down the rows), whereas to select plain rows of queries we would lay them out along the rows; if you want the transposed queries, you simply swap the two dimensions. Let me first show it without the transposition, because it is simpler. To access the query pointers without transposing, you go to the Q tensor and create a 2D grid of pointers where the rows hold the starting offset of each query you want, replicated along the columns. That is the meaning of indexing with None: it is equivalent to calling unsqueeze(1) in PyTorch, i.e. adding a column dimension to the offs_q vector, whose values are then repeated along the columns — how many columns? As many as needed when it is broadcast against the other tensor it is summed with. So it is a combination of unsqueezing and broadcasting: we take the query vectors indicated by offs_q and, for each query vector, we select all the head dimensions indicated by offs_dim. If you swap which operand is broadcast along which axis, you get the transpose of the query block you are accessing — so this single expression is equivalent to the two lines "access Q, then transpose". Let me write down what happens at the pointer level. Think of offs_q as a vector of offsets; we multiply it by the sequence stride, which tells us how many elements we must skip to go from one query vector to the next. If the head dimension is 128, the stride of the sequence dimension is 128, meaning that to go from one query vector to the next you move forward by 128 elements — remember that in memory a tensor is always stored flattened, one dimension after the other: if you have three rows and four columns, you get the first four columns (the first row), then the next four, then the next four, row after row. It is hard to visualize until you write it down, so let's write it down. offs_q starts as the range 0, 1, 2, 3, and so on up to 32; multiplying element-wise by the sequence stride, the first entry skips no elements, the second skips exactly 128 elements, the third two times 128, the fourth three times 128, and so on. Then we add a second dimension to this vector, broadcast it over HEAD_DIM columns, and to each entry we add the column offsets given by offs_dim.
How many columns does offs_dim have? HEAD_DIM of them — and for simplicity let's pretend the head dimension is 4 instead of 128. Then the column of row offsets is 0, 4, 2·4, 3·4, and the offsets we add along the dim dimension are 0, 1, 2, 3, each multiplied by the stride of the last dimension, which is 1 because it is the last dimension. So to the row offset 0 we add 0, 1, 2, 3; to the row offset 4 we add 0, 1, 2, 3; and the same for 8 and for 12. The result of the operation is therefore: first row 0, 1, 2, 3; second row 4, 5, 6, 7; third row 8, 9, 10, 11; fourth row 12, 13, 14, 15. In other words, starting from where the q pointer points, it selects elements 0 to 3 — which are exactly the head dimensions of the first vector we should be selecting — then elements 4 to 7, 8 to 11 and 12 to 15, which are the second, third and fourth query vectors, because in memory they are stored flattened, one after another. So it selects all of them and it also creates a virtual tensor with the shape we want to view it in: as we saw before, given a tensor's layout in memory you can view it with whatever shape you like, and the reshaping is always free — it does not involve rearranging the elements in memory. I hope it is clearer now; it was quite convoluted, and whenever I get stuck I just draw things out like this. I think you should do it too, because trying to imagine everything in your head is always difficult.
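Here is the same walkthrough as a tiny runnable example, with the same made-up sizes (SEQ_LEN = 4, HEAD_DIM = 4); it also shows the swapped-broadcast version that yields the transposed block:

```python
import numpy as np

SEQ_LEN, HEAD_DIM = 4, 4
Q = np.arange(SEQ_LEN * HEAD_DIM).reshape(SEQ_LEN, HEAD_DIM)
flat = Q.reshape(-1)                    # how the tensor actually lives in memory
stride_seq, stride_dim = HEAD_DIM, 1    # strides of a row-major (SEQ, DIM) tensor

offs_q = np.arange(SEQ_LEN)
offs_dim = np.arange(HEAD_DIM)

# plain (SEQ, DIM) access: query offsets down the rows, dim offsets across
ptrs = offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim
assert (flat[ptrs] == Q).all()

# transposed (DIM, SEQ) access: just swap which offset gets broadcast where
ptrs_T = offs_dim[:, None] * stride_dim + offs_q[None, :] * stride_seq
assert (flat[ptrs_T] == Q.T).all()
```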
We do the same job for the dO vectors, except that we do not access them transposed, because only Q needs to be transposed. Then we iterate through the sequence dimension of the queries: remember that here we fix the key block and go through all the queries, so the queries start from zero and go up to the sequence length, and the number of steps of this for loop is the sequence length divided by BLOCK_Q — for example 1024 tokens with BLOCK_Q equal to 32 gives 32 steps (1000 would be a bad choice because it is not divisible). At each step of the for loop we load one block of Q, the one indicated by our pointers, and at the end of the iteration we advance the pointers to the next block of Q. We also load the log-sum-exp values stored in the M matrix, because we want to compute Pᵀ on the fly — the transpose of the softmax of query times transposed keys — and instead of computing P and then transposing it, we already access Q as transposed, so we can compute Pᵀ directly. So we load the offsets of the elements we need from this log-sum-exp matrix M, computed during the forward pass, one block of queries at a time — the block we are currently working on in the iteration. Then we compute K times Q transposed. Strictly speaking this is not Pᵀ yet, because we have not applied the softmax: it is Sᵀ. If you want Pᵀ you need the softmax of Sᵀ, and Sᵀ is the transpose of S = Q·Kᵀ; as you remember, when you transpose a matrix multiplication you transpose the two operands and invert their order, so Sᵀ = K·Qᵀ — that is why we multiply K by the transposed queries, and we also scale the result by the softmax scale. To apply the softmax we would take the exponential of each element minus its row maximum, divided by the normalization factor; but with the log-sum-exp trick we just subtract the M value from each element, since M already includes the normalization factor — I think I already did this derivation, so we do not need to go through it again. So now we have the Pᵀ block (in the intermediate formula I really should have written Sᵀ). Then, when doing causal attention, we also need to mask out some values, and notice that here the causal mask is applied after the softmax has been computed. You are used to applying the causal mask before computing the softmax — but that is in the forward pass, because there you do not want the normalization factor to be affected by elements that should be zero. Here the normalization factor has already been computed (with the mask applied) during the forward pass, so it cannot be affected anymore, and we can safely mask after applying the softmax. The mask itself is the same as always: it is true — that is, not masked — for the positions where the query index is greater than or equal to the key index, and all the other values are replaced with zeros.
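This "mask after the softmax" step is easy to doubt, so here is a small check that it gives exactly the same P as the usual "mask before the softmax", as long as the log-sum-exp statistics M come from the masked forward pass (shapes and names are illustrative):

```python
import torch

SEQ = 6
S = torch.randn(SEQ, SEQ, dtype=torch.float64)
causal = torch.tril(torch.ones(SEQ, SEQ, dtype=torch.bool))

# forward pass: mask before the softmax, and save M = rowmax + log(rowsum)
S_masked = S.masked_fill(~causal, float("-inf"))
P_forward = torch.softmax(S_masked, dim=-1)
row_max = S_masked.max(dim=-1, keepdim=True).values
M = (row_max + (S_masked - row_max).exp().sum(dim=-1, keepdim=True).log()).squeeze(-1)

# backward pass: exponentiate everything using M, then zero out the masked part
P_recomputed = torch.exp(S - M[:, None]) * causal
assert torch.allclose(P_forward, P_recomputed)
```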
After we have the masked Pᵀ block we can compute dV — and I will point to the corresponding formula in the paper. We load a block of dO; why? Look at how the dV block is computed in the paper: it is an incremental sum, dV += Pᵀ·dO_i (the paper writes P with a "dropped" superscript to indicate P after dropout, but this implementation does not support dropout, and very few models actually use dropout in the attention). Note that dO_i, and likewise Q_i, always refer to the same block of rows in their respective tensors: this inner iteration index i selects a block of Q and a block of dO at the same positions, which works because dO has the same shape as Q, so we walk through the blocks of Q and of dO simultaneously — dO is needed for dV and Q is needed for dK. So we compute dV exactly as in the paper: the Pᵀ block multiplied by the dO block. Next we load the D_i elements that we precomputed at the beginning with the first kernel, the backward preprocess, because we will need them for dK; we load exactly as many of them as the number of query vectors we loaded, i.e. BLOCK_SIZE_MICRO of them. Okay, I will copy the next piece and explain it step by step. The next operation is computing dK; to compute dK we need dSᵀ, and to compute dSᵀ we need dPᵀ, so let's go from the end back to the beginning of these formulas, so we understand where everything is created and where it is used. Start from dK: in the paper, dK is equal to the old dK plus τ — the softmax scale — times dSᵀ multiplied by a block of Q; the "+=" just means an incremental addition, incrementing the old dK with this new term, which is the softmax scale multiplied by the matrix multiplication between the dSᵀ block and Q. But we do not have Q, we have Qᵀ, so we take the transpose of Qᵀ and it becomes Q again. Now the dSᵀ block: in the paper, dS is equal to a block P_ij multiplied element-wise by (dP_i minus D_i). We do not need dS, we need dSᵀ, and since this is an element-wise multiplication, not a matrix multiplication, transposing it does not invert anything: we just transpose the two operands. So dSᵀ is the transpose of P — which is Pᵀ, and we already have that — multiplied element-wise by the transpose of what is inside the parentheses, i.e. dPᵀ minus D with rows and columns swapped. And what is dPᵀ? In the paper, dP = dO·Vᵀ, but we need dPᵀ, and this time it is a matrix multiplication, so to transpose it we take the transposes of the two operands and also invert their order: Vᵀ transposed becomes V, so dPᵀ = V·dOᵀ — and that is why we take the transpose of dO here. I am not going to walk through every single pointer again, because I have already shown you how to check what a pointer is pointing to and what an offset is referring
to i hope that now you have a better understanding on how these pointers work in triton which is also the same way in the in which they work in cuda because in the gpu we only get a pointer to the starting point to the starting address of the tensor and then we need to work out all these indices we have computed the dk block so we now go to the next query to the next block of queries and so the next block of queries because we are fixing k and v blocks and we are iterating through all the queries so we need to move the query transpose the pointers by stride sequence which means that how can we go from one query to the next and we multiply it with the current block q which is a vector which indicates the pointers to the current element in q that we are accessing and we do it also for do and we use the block q as element and the stride q because do and q all have the same shape okay after we have run the for loop of all the queries we can store this dk and dv block so we write it back as follows and this is the end of our function guys so we save the dv block exactly in the position inside of the current okay dv is already i believe pointing to the right batch and to the right head because we incremented it here and also in the case of dk then we need to tell it in the sequence dimension where they should save this block of k and v and this is indicated by this one we say and we create the the pointers just like before guys don't make me do it again it's a really easy if you write it down like you write this vector of key and values pointers which is not pointers actually they are a range of the of key and value that you need to take from the sequence dimension you add another dimension that is the column so you repeat each value in the columns and then you add the dimension here for the head dimension anyway after we have calculated the pointers where we should store the dk and the dv we store them in the the pointers of we store them in the dv i mean we store them in the dv tensor and the dk tensor what do we save we save the dv block and the dk block which is the one that we were incrementally changing in the for loop that we have written okay now that we finished this one we can go to the next function that will do the other for loop so let's do it okay so now we do the second part of the iteration which is this one so let me just copy it and then we we describe it uh let's write it here okay we use the same launch grid as before of course we need to declare this function and again we um we because the grid is defined for the block size macro for what is the thing that we keep fixed and then we in the side of the for iteration we do um steps of block size micro in this case we are fixing q and we are iterating through k and v because we need to compute dq right now we have computed dk and dv okay the i believe the arguments are the same as before so and actually this is also the reason why in the original implementation on the triton website the author decided to um to use the same for loop but with different arguments and i believe it was a little confusing so that's why i just separated them i just repeat the code twice it's the goal of this video is to be as easy to understand as possible not to be as efficient as possible so uh let's go uh here so let me copy the signature again let me define this function here okay so uh again we need to first move the query key and value uh to the right pointer so which will point to the exact batch and the exact head that we are working with in 
this program so um let's do it let me check where is the code here and the first part is exactly the same as the other for loop that we have written so let's go here and really is i just copied so it's exactly the same so we check what is the index batch head we move the query key value pointers to the right place the d o d q d k d v point to the right place the m and d to the right place exactly like before so i don't think i need to explain that again and then we load a block of q the one that we will keep fixed so dq let me load a lot of stuff here actually okay we define the offset that we will need to load the blocks of k and p in the head dimension because we are going to iterate in the k and v we will access them as transposed blocks so instead of accessing them directly as a k and v we access access them as a kt and pt and you know that that's possible just by changing the strides in this case because we are treating them as a 2d vectors we treat the offs kv when you want to access k as just not transposed but k you treat this offs kv as a row vector sorry a column vector so you repeat on the rows each k offset that you want to access in this case we are repeating it as a we are treating it as a row vector so it will be repeated on the rows um sorry it will be broadcasted on the column dimension and that's how you can access the transposed version of k and how you can access the transposed version of v another thing that we are doing is we are loading the q vector which vector well based on offs q which is the q vector which vector well based on offs q which is based on the start q which is based on the exact starting point in which this particular program should be working with because this particular program works as two dimensions the first dimension indicate which batch and which head this program should be working with and the second dimension which is the program index number zero indicates which among all the sequence length which query this particular program is going to work with this is indicated by the index block this should be actually q in this case i forgot to change the name so actually let me change it so it's index q because we start we skip some q how many q we skip based on the index of the current program multiplied by how many blocks have already been processed by the previous programs this will tell us inside of the sequence length what are the queries that this one needs to select so that's why we use the start query plus the range that is block q so imagine the starting query for this program among all the sequence length is 100 then this will load the query row 100 101 102 until 100 plus block q minus 1 this is the range that we of the query vectors that we will load in this program we load the block of q by using a q plus the offset repeated on the columns so we treat it as a column vector but we repeat broadcast it on the rows vector where each column will be one head dimension multiplied by the stride in this case we actually can also not multiply by the stride because the stride in the dimension dimension so the last dimension of the batch is one because to go from one um actually the stride um how it is defined the stride of the last dimension is always one because to go one element to the next element you should move to move it to by one element um so we load the dq which is the stuff that we are going to compute in this um iteration and then we have a do that we need to load and the do we use the same offset as q because the do and dq have the same 
Besides that, we also need to load the M values, the normalization factors stored in the M tensor. Which ones? The ones corresponding to the group of queries this program works with, which is why we index M with the query offsets. Then we set up the K and V offsets: as you can see, they start at the first KV block, from position zero, because we will iterate through all the keys and values starting from key vector zero and value vector zero, and at each iteration we move forward by BLOCK_KV vectors. I hope I didn't go too fast, but most of what is written here is very similar to what we already did in the other for loop, so I don't want to repeat myself too much. What matters are the formulas we will use, which are exactly the ones in the paper. So we go through the blocks of K and V: we load the first block of K^T and V^T as usual, by telling Triton which pointers to load, and it brings the block you asked for into the SRAM. All this stuff resides in SRAM, and Q and dO also reside in SRAM. Then we compute the query block multiplied by the transpose of the key block, because we need it to compute the P block: the QK block is just the current query block times the transpose of the current key block. We already access the keys as transposed, so we don't need to transpose anything, and even if we did, transposing a matrix requires no computation: we would just access it differently, because in memory it is always stored as a flattened array anyway. Then we compute the P block, which is the output of the softmax: from each entry of the QK block we subtract the log-sum-exp value stored in M for this block of queries (that is why we load the M block with the query offsets) and take the exponential. As you remember, M already includes the normalization factor: each M entry is the maximum value of its row plus the logarithm of the normalization factor, which by the properties of the exponential ends up in the denominator.
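Written out: if S_ij is the scaled score, m_i the row maximum and l_i the softmax denominator computed during the forward pass, the value we stored is M_i = m_i + log(l_i), so recovering the softmax output needs only a single exponential:

$$
P_{ij} \;=\; \frac{e^{S_{ij} - m_i}}{\ell_i} \;=\; e^{S_{ij} - m_i - \log \ell_i} \;=\; e^{S_{ij} - M_i}
$$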
Okay, then we apply the autoregressive masking again. Oops, what did I do? Let me go back to the code. We have this STAGE parameter: when we launch the backward pass, STAGE equal to 3 indicates that the forward pass computed causal attention, and 1 indicates that it computed non-causal attention. If we computed causal attention in the forward pass, we also need to mask out the same elements in the backward pass. So we build a mask that is true only for the elements where the query index is greater than or equal to the key index: where it is true we keep the value, otherwise we mask it out. Then let's compute the next operations, dP and dS. Actually, let's compute dQ directly and then explain it like before: we start from the end result and work back to what is needed to compute it. So let's look at the formula; let me go to the iPad. What we are trying to compute here is dQ, and as you can see in the paper, dQ is equal to the old dQ plus tau, the softmax scale (this value here), times the matrix multiplication between the dS block and the K block. The dS block is here, and the K block is the transpose of the K^T block, because we are accessing K already as a transposed block. We could also access K directly, not transposed, by swapping the indexing: if you don't want the transposed access, you treat this offset as a row vector by putting None in the first position, so it is broadcast along the columns, and the other offset stays a column vector; if you want K^T, you just invert the two. I hope I didn't mess anything up, so let's move forward. So the formula for dQ is exactly the one in the paper, but what is this dS block? Looking at the paper, dS comes from this expression: dS_i is the P_i block element-wise multiplied by (dP_i - D_i). And what is the P block? It is exactly the output of the softmax, which we already have. What is the dP block? dP is exactly dO multiplied by V^T: dO we already loaded, and V we already loaded as transposed. That is how we compute dQ. Then of course we need to move to the next block of keys and values, so we increment the K and V pointers just like before, and finally we store the result of dQ. This way we only need one write of dQ to the HBM, thanks to splitting the algorithm into the two for loops. If you look at the original algorithm in the paper (I don't know if it actually corresponds to the CUDA implementation they wrote; I don't think so, because it would not be very optimized), it says to loop over all the keys and, while going through the keys, loop over all the queries, and for each query block you visit, write back dQ while updating it, which is not efficient. That is why we needed two for loops: one in which we fix a block of keys and values and iterate over all the query blocks, because each dK and dV block depends on all the blocks of Q, and one in which we fix a block of queries and iterate over all the key blocks, because each dQ block depends on all the blocks of K and V. And this is the second loop that we have written.
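Since we are jumping between the paper and the code, here is a compact PyTorch mock-up (not the Triton kernel itself) of what this second loop computes for a single batch and head. The function name and block size are made up, and for simplicity the query block is the whole sequence; M is the per-row log-sum-exp and D is rowsum(dO * O) from the preprocessing step.

```python
import torch

def dq_block_loop(Q, K, V, dO, M, D, softmax_scale, causal=True):
    """Reference for the second backward loop: Q, dO, M, D stay fixed,
    we sweep over the K/V blocks and accumulate dQ."""
    SEQ_LEN = Q.shape[0]
    BLOCK_KV = 32                                   # hypothetical block size
    dQ = torch.zeros_like(Q)
    q_idx = torch.arange(SEQ_LEN)[:, None]
    for start_kv in range(0, SEQ_LEN, BLOCK_KV):
        k = K[start_kv:start_kv + BLOCK_KV]         # (BLOCK_KV, HEAD_DIM)
        v = V[start_kv:start_kv + BLOCK_KV]
        qk = softmax_scale * (Q @ k.T)              # S block
        p = torch.exp(qk - M[:, None])              # M = row max + log(normalizer)
        if causal:
            k_idx = torch.arange(start_kv, start_kv + BLOCK_KV)[None, :]
            p = torch.where(q_idx >= k_idx, p, torch.zeros_like(p))
        dp = dO @ v.T                               # dP = dO @ V^T
        ds = p * (dp - D[:, None])                  # dS = P * (dP - D)
        dQ += softmax_scale * (ds @ k)              # dQ += tau * dS @ K
    return dQ
```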
Now we have written everything we need for FlashAttention, the forward pass and the backward pass, so we should be ready to launch the kernel. I hope I didn't make any mistakes copying the code; if there is an error I will just use my reference code, which I had already written. The only difference so far between my reference code and the one we have written is the autotuning, which I haven't explained, so let's talk about it. The autotuning was already present in the original implementation I based this on and I kept it as is; I removed it for the backward pass, but if you check the forward pass, there is this code that defines the autotuning configurations for Triton. Triton cannot know beforehand what the best block size is for the queries, for the keys and values, or for any other dimension we have: it has to try, based on the hardware you are running on, on how much SRAM is available, and on the thread coarsening Triton can apply. I haven't talked about thread coarsening: in CUDA you can choose whether each thread does one elementary operation (for example, in a matrix addition, each thread adds exactly one element of the output matrix) or manages multiple elements, and that is called thread coarsening. I didn't check the documentation, but I believe Triton handles it for you based on the block size you give it and the number of warps you request. A warp is a group of 32 threads that execute the same instruction at the same time. So I believe the interesting part here is not really choosing BLOCK_SIZE_Q and BLOCK_SIZE_KV: you just try whatever configurations, and Triton will benchmark all of them for you every time the sequence length or the head dimension changes, and for every pair of sequence length and head dimension it will keep the configuration that runs in the least amount of time, that is, the one that gives the best throughput.
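For reference, an autotuning setup of this kind looks roughly like the following. This is a sketch: the block sizes, num_stages and num_warps values, and the kernel signature are illustrative, not necessarily the exact ones used in this video's code.

```python
import triton
import triton.language as tl

# Candidate configurations: Triton benchmarks them and caches the winner
# for each (SEQ_LEN, HEAD_DIM) pair it encounters.
configs = [
    triton.Config(
        {"BLOCK_SIZE_Q": block_q, "BLOCK_SIZE_KV": block_kv},
        num_stages=num_stages,
        num_warps=num_warps,
    )
    for block_q in [64, 128]
    for block_kv in [32, 64]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
]

@triton.autotune(configs=configs, key=["SEQ_LEN", "HEAD_DIM"])  # re-tune when these change
@triton.jit
def _attn_fwd(Q, K, V, O, softmax_scale,
              SEQ_LEN: tl.constexpr, HEAD_DIM: tl.constexpr,
              BLOCK_SIZE_Q: tl.constexpr, BLOCK_SIZE_KV: tl.constexpr):
    # ... the forward kernel we wrote earlier goes here ...
    pass
```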
The more interesting parameter is num_stages. It is an optimization that Triton applies; it is not loop unrolling, it is called software pipelining, so let's see what it is and how it works. Software pipelining is used when you have a for loop, a sequential computation, in which each iteration does not depend on the previous one, which is more or less what we had in our loops (actually, there are cases in which the iterations can depend on each other and you can still pipeline, but let's keep it simple). Imagine the following loop, running from 1 to N: first you load some data (say matrix A), then you load some other data (matrix B), then you do a matrix multiplication, and then you store the result. So you have two reads, one compute step and one write. Now imagine the GPU is made up of a compute unit and a unit dedicated to loading and storing, that is, reading from and writing to memory. If you look at the timeline of one iteration, you see the following picture: first we read some data and the compute unit is idle because it is waiting for that data; then we read some more data and the compute unit is still idle; then we finally have enough data and we run the matrix multiplication, and now the load unit is idle; then we write the result back to memory and the compute unit is idle again, and it will stay idle for another two time steps of the next iteration, until it has enough data to run the next computation. As you can see this is not very efficient, because at any point in time only one unit is working and the other is sitting idle. One way to optimize this loop is software pipelining, and you can ask Triton to do it for your loops by telling it how many stages you want. To pipeline a for loop you first need to turn these operations into asynchronous operations. In CUDA, at least on NVIDIA GPUs, there are async loads from memory (and async writes): you spawn a load, the instruction returns immediately and moves on, and you only check whether the load has completed when you actually need the data. So I can immediately spawn two reads and just check later that they have finished. With software pipelining we interleave operations from different iterations into a single iteration of the transformed loop: at the first step we issue the read of the first matrix of iteration 1 (call it read A1); at the second step we issue read A2 and read B1; at the third step we issue read A3 and read B2, and we compute the matmul of iteration 1, because A1 and B1 should have completed by then, and while we compute, the load unit is not idle, it keeps working on the loads we already issued. At the fourth step we issue the load for iteration 4 and the second load for iteration 3 while computing the matmul of iteration 2, and so on. And it is not that we simply hope the loads have completed: the CUDA language has primitives to check whether an async operation has finished, so before doing the multiplication we actually wait for the loads we need. This is like promises in JavaScript, or tasks in C#: you can spawn as many as you want, and you only wait for the one you actually need when you need it, while the others keep running in the background. This is the whole idea of software pipelining.
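To make it concrete, here is a toy plain-Python version of that restructuring, with a thread pool standing in for the asynchronous copy engine; load_block and matmul are made-up stand-ins, not GPU operations. The loads for iteration i+1 are issued before the compute of iteration i, so they overlap.

```python
from concurrent.futures import ThreadPoolExecutor
import time

def load_block(name, i):      # stand-in for an asynchronous load from global memory
    time.sleep(0.01)
    return (name, i)

def matmul(a, b):             # stand-in for the compute stage
    time.sleep(0.01)
    return (a, b)

N = 8
with ThreadPoolExecutor(max_workers=2) as load_unit:
    # Prologue: kick off the loads for iteration 0 before entering the loop.
    fut_a = load_unit.submit(load_block, "A", 0)
    fut_b = load_unit.submit(load_block, "B", 0)
    for i in range(N):
        a, b = fut_a.result(), fut_b.result()   # wait only for the data iteration i needs
        if i + 1 < N:
            # The loads for iteration i+1 run in the background
            # while the "compute unit" below works on iteration i.
            fut_a = load_unit.submit(load_block, "A", i + 1)
            fut_b = load_unit.submit(load_block, "B", i + 1)
        print("iteration", i, "computed", matmul(a, b))
```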
As you can see, software pipelining only works when you have async operations, and it also increases the memory requirement of your program: while the matmul of iteration 1 is running, we may already be holding the data of the first two iterations plus part of the data of the third, so we need more SRAM. Triton will do this software pipelining for you: it converts the loads, the stores and, I believe, also the matrix multiplications into async operations and builds the pipeline. If you are still confused about how it works, there is an easier way to understand it, because it is something we already do in model training: it is called pipeline parallelism, and it works as follows. We have a very big neural network that does not fit in a single GPU. Imagine it is made of three layers, layer 1, layer 2 and layer 3, but it is so big it cannot live entirely on one GPU. One solution is to put each layer on its own GPU: layer 1 on GPU 1, layer 2 on GPU 2, layer 3 on GPU 3. Now imagine we have an input: we send it to GPU 1, which runs layer 1 and produces an output; that output is transferred to GPU 2, which computes its own output and transfers it to GPU 3, which computes its own output, and finally we have the output of the neural network. The problem is that while GPU 2 is working on the output of GPU 1, GPU 1 is sitting free, which is a waste of resources: we should always keep the GPUs busy. So instead of sending one mega-batch to GPU 1, we send many smaller micro-batches. We send micro-batch 0 to GPU 1; GPU 1 computes its output and sends it to GPU 2, which is now processing micro-batch 0, and since GPU 1 is now free we feed it micro-batch 1; then GPU 2 finishes micro-batch 0 and sends it to GPU 3, GPU 2 becomes free, GPU 1 (which has hopefully finished) passes micro-batch 1 to GPU 2, and we introduce micro-batch 2 on GPU 1, and so on: at every step each micro-batch shifts one position down the pipeline and a new one enters at the beginning, which keeps the GPUs always busy. The only problem of pipeline parallelism is the bubble effect (plus, in pipeline parallelism, the fact that the backward step has to run in the reverse order of the micro-batches): you need to fill the pipeline at the beginning and drain whatever is still in it at the end. Triton's software pipelining has the same issue, called the prologue and the epilogue: only during the first few and the last few steps of the for loop will some of the units of the GPU sit idle instead of all working simultaneously.
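Here is a tiny toy schedule that prints who is doing what at each time step; the numbers are made up purely for illustration, and the idle slots at the two ends are exactly the fill and drain bubbles just described.

```python
# Toy schedule: 3 pipeline stages (GPUs) and 5 micro-batches. At time step t,
# GPU g works on micro-batch t - g if it exists, otherwise it sits idle.
NUM_GPUS, NUM_MICRO_BATCHES = 3, 5
for t in range(NUM_GPUS + NUM_MICRO_BATCHES - 1):
    slots = []
    for g in range(NUM_GPUS):
        b = t - g
        slots.append(f"GPU{g + 1}: batch {b}" if 0 <= b < NUM_MICRO_BATCHES
                     else f"GPU{g + 1}: idle")
    print(" | ".join(slots))
```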
What does this mean in practice? It means that to benefit from pipelining you want the number of iterations of your for loop to be much larger than the number of stages each iteration is divided into; in our example there are four stages, and these are exactly what Triton calls stages. All right guys, we have finally covered everything, and I hope you learned a lot. I believe we can now run the Triton code, so let's run it. I copied everything, and I believe we also added the code to test it, but we didn't add the main method, which we can copy right now. I really hope there is no error. Let me check that I am on the right machine; I am, so let's just run the program and pray. If there is an error I will just copy my own reference implementation, but I hope it works, because otherwise it means I forgot something. I am running this code on an H100, because my company has H100s; if you have a smaller GPU you can reduce the sequence length, the batch size (I think it is already one when we call it), the number of heads, and you can even set the head dimension to 8 and the sequence length to 16. Let's check: "triton backward returned an incorrect number of gradients: expected 5, got 1". We probably forgot a return statement. Yes, I forgot the return statement here: in the backward pass, after running the last for loop, we need to return the gradients we have computed. Fingers crossed again... okay, passed! So the backward pass computed by torch is equivalent to our backward pass up to an absolute error of 10^-2. And note that the backward we run here is different from the backward we run there: when you apply the Triton attention, it introduces a new node into the computation graph of our tensors, this TritonAttention operator, and when PyTorch wants to compute the backward pass it will simply call the backward function of this operator, which will populate the grad attribute of all the tensors that were inputs to it. And this is how PyTorch autograd works.
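For completeness, here is a minimal sketch of the kind of test being run here. It assumes the autograd.Function we wrote is called TritonAttention with a forward signature of (Q, K, V, causal, softmax_scale), five inputs, which is exactly why the backward must return five values (dQ, dK, dV, None, None); adjust the names to whatever you used, and the 1e-2 tolerance is the one mentioned above.

```python
import torch

def test_triton_attention(BATCH=1, NUM_HEADS=2, SEQ_LEN=16, HEAD_DIM=8,
                          causal=True, dtype=torch.float16):
    # Small sizes on purpose, so this also runs on modest GPUs.
    shape = (BATCH, NUM_HEADS, SEQ_LEN, HEAD_DIM)
    Q = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
    K = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
    V = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
    softmax_scale = HEAD_DIM ** -0.5
    dO = torch.randn_like(Q)

    # Reference: naive PyTorch attention.
    scores = (Q @ K.transpose(-2, -1)).float() * softmax_scale
    if causal:
        mask = torch.tril(torch.ones(SEQ_LEN, SEQ_LEN, device="cuda", dtype=torch.bool))
        scores = scores.masked_fill(~mask, float("-inf"))
    ref_O = torch.softmax(scores, dim=-1).to(dtype) @ V
    ref_O.backward(dO)
    ref_dQ, ref_dK, ref_dV = Q.grad.clone(), K.grad.clone(), V.grad.clone()
    Q.grad = K.grad = V.grad = None

    # Ours: the Triton kernel wrapped in an autograd.Function (hypothetical name/signature).
    tri_O = TritonAttention.apply(Q, K, V, causal, softmax_scale)
    tri_O.backward(dO)

    atol = 1e-2
    assert torch.allclose(ref_O, tri_O, atol=atol, rtol=0.0)
    assert torch.allclose(ref_dQ, Q.grad, atol=atol, rtol=0.0)
    assert torch.allclose(ref_dK, K.grad, atol=atol, rtol=0.0)
    assert torch.allclose(ref_dV, V.grad, atol=atol, rtol=0.0)
    print("PASSED")
```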
Guys, thank you for watching this video. It has been super demanding: I spent many months first of all teaching myself about Triton, CUDA, FlashAttention and so on, and I also have a full-time job, so it is really hard to make videos like this. I have to dedicate the nights, the mornings and the weekends to it, and I spent three days just recording, because sometimes I don't like how I explained something, sometimes I make a mistake, sometimes I have to restart because what I am doing is wrong, and so on. I believe there are no big errors in what I have done so far, but for sure my notation is not great: all the mathematics I know is self-taught, I didn't learn it in academia, so I have bad habits I am trying to get rid of, like writing the same quantity sometimes with a capital letter and sometimes lowercase, or just forgetting an index. I am working on fixing these problems. I believe I have explained everything, so you should now have all the knowledge needed to derive all the formulas you see in the FlashAttention paper, and you should also have a mental picture of how the attention computation works block by block. I know I could have spent another 20 hours explaining things even better, but I also have a life and a wife, so I cannot make 100-hour videos. There were also some interruptions while making this video: I had some wisdom teeth removed, and it took me more than a week to recover because it was so painful. So thank you guys for watching, and I hope you learned a lot this time as well. As you can see, Triton is something new and there is not much documentation, so something I have said about Triton may not be totally correct, because really there is very little documentation: all the Triton I have learned comes from looking at code written by others and trying to understand it. And I think that's it, guys. I wish you a wonderful day, and see you next time on my channel.