
Flash Attention derived and coded from first principles with Triton (Python)


Chapters

0:00 Introduction
3:10 Multi-Head Attention
9:06 Why Flash Attention
12:50 Safe Softmax
27:03 Online Softmax
39:44 Online Softmax (Proof)
47:26 Block Matrix Multiplication
88:38 Flash Attention forward (by hand)
104:01 Flash Attention forward (paper)
110:53 Intro to CUDA with examples
146:28 Tensor Layouts
160:48 Intro to Triton with examples
174:26 Flash Attention forward (coding)
262:11 LogSumExp trick in Flash Attention 2
272:53 Derivatives, gradients, Jacobians
285:54 Autograd
300:00 Jacobian of the MatMul operation
316:14 Jacobian through the Softmax
347:33 Flash Attention backwards (paper)
373:11 Flash Attention backwards (coding)
441:10 Triton Autotuning
443:29 Triton tricks: software pipelining
453:38 Running the code

Whisper Transcript

00:00:00.400 | Hello guys, welcome back to my channel. Today we are going to explore FlashAttention.
00:00:04.640 | Now, we are going to explore FlashAttention from first principle which means that not only
00:00:09.600 | we will code FlashAttention, we will actually derive it. So we pretend that the paper, FlashAttention
00:00:14.640 | paper, never existed and we look at the attention computation and we look at the problem it has and
00:00:20.240 | we try to solve it step by step pretending that FlashAttention never existed. This will give us
00:00:25.120 | a deep understanding of how it works and also we will combine theory with practice because we will
00:00:30.320 | code it. Now, in order to code FlashAttention we will need to write a kernel for our GPU and in
00:00:37.120 | our specific case I will be using an NVIDIA GPU so a CUDA kernel but instead of writing C++ code
00:00:42.560 | we will use Triton, which is a way of converting Python directly into CUDA kernels that can run
00:00:50.000 | directly on the GPU. You can think of Triton as a compiler that takes in Python and
00:00:54.880 | converts it into something that can run on the GPU. So let's look at the topics for today. First
00:01:00.880 | of all I will give an introduction to multi-head attention because we need to look at what is
00:01:05.040 | attention and how it's computed and what are the problems in computing this attention.
00:01:08.960 | Then we will look at actually the most critical part of the attention computation is this Softmax
00:01:14.080 | and how it impacts the computation and complexity. We will look at what is online Softmax. Then we
00:01:20.240 | will explore what is the GPU because we are going to write a kernel that will run on the GPU so we
00:01:24.480 | need to understand what is the difference for example the CPU and the GPU and what is a kernel
00:01:28.480 | and how it differs from a normal program that you write for the CPU. We will look at how tensors
00:01:34.560 | are laid out in memory, so row-major layout, column-major layout, strides, etc. We are going to look
00:01:41.760 | at block matrix multiplication, Triton, software pipeline and all the optimization that Triton does
00:01:46.640 | to our code. Finally we will be able to code the FlashAttention forward pass but of course we are
00:01:52.320 | not satisfied only by coding the forward pass. We also want to code the backward pass but in order
00:01:57.040 | to code the backward pass we also need to understand how autograd works and the gradient descent works
00:02:02.400 | in the case of custom operations so we need to understand what are derivatives, what are
00:02:06.880 | gradients, what are Jacobians and then we calculate the gradient of the common operations that we use
00:02:11.680 | in FlashAttention and finally we will have enough knowledge to code the backward pass. For this
00:02:16.800 | reason this video is going to be super long but I hope you don't mind because we are going to learn
00:02:21.360 | a lot. Of course you may be wondering all of this requires a lot of knowledge that you may not have
00:02:27.280 | but that's not a problem because that's my problem because in this video I will make sure that if you
00:02:31.680 | only have high school calculus so you know what are derivatives you have basics of linear algebra
00:02:36.320 | like you know what is matrix multiplication or what is the transpose of a matrix and you have
00:02:40.880 | a basic knowledge of attention mechanism so like for example you have watched my previous video on
00:02:45.040 | the attention is all you need paper and you have a lot of patience that should be enough to understand
00:02:50.880 | all of this video because all the topics that I will introduce I will always introduce them in
00:02:55.360 | such a way that I pretend that you don't know anything about the topic so we try to derive
00:03:00.640 | everything from first principle everything from scratch. Okay now that we have seen the introduction
00:03:06.000 | let's go see the first part of the video which is the multi-head attention. All right let's talk
00:03:12.160 | about multi-head attention. Now I am using the slides from my previous video attention is all
00:03:17.040 | you need so we can look at very fast at what multi-head attention is and how it works. I hope
00:03:23.120 | you remember the formula: softmax of the query multiplied by the transpose of the key divided by
00:03:27.440 | the square root of dk, all multiplied by V, because we will be using that a lot throughout the video. Now multi-head
00:03:33.760 | attention starts from an input sequence, or two input sequences in case we are talking about cross
00:03:38.400 | attention. In the simple case of self-attention we have one input sequence which is, in
00:03:44.240 | the case of a language model, a sequence of tokens, where we have seq number of tokens and each token
00:03:51.440 | is represented by an embedding so a vector with d model dimensions. The first thing that we do
00:03:58.480 | is we convert this input sequence into query key and values through three linear projections one
00:04:06.240 | called wq one called wk one called wv which in pytorch are represented through linear layers
00:04:13.520 | and these linear layers are of d model by d model so they do not change the shape of the input
00:04:19.520 | tensor and then after we do this job of projecting them they become three different sequences one
00:04:29.200 | called query one called key and one called value so here i'm calling them q prime k prime and v
00:04:34.960 | prime then we divide them into smaller embeddings so each of this token which is made up of d model
00:04:43.920 | dimensions we divide it into smaller tokens each one suppose we have four heads each one will have
00:04:51.680 | d model divided by four dimensions so this one is a sequence of tokens where each token is not
00:04:58.400 | the entire token but a part of the embedding of each token and this one is a another part of the
00:05:03.920 | embedding of the tokens and this one is another part of the embedding of the token etc and we
00:05:08.720 | do this job for the query key and value sequence then we compute the attention as follows so the
00:05:15.520 | softmax of the query multiplied by the transpose of the key divided by the the square root of dk
00:05:20.960 | where dk is the dimension of each head so how many dimensions each head is working with and then we
00:05:29.360 | do the multiplication with v and this will give us the output of the attention mechanism for each
00:05:34.480 | head and this job is done independently for each head this should be clear to you if it's not please
00:05:41.120 | watch my previous video on the attention mechanism because we will be working with this scenario a
00:05:47.120 | lot now then we take this o the output of each head and then we concatenate it back in order to
00:05:57.200 | get the representation of each token as a full embedding so before we split this embedding into
00:06:05.360 | smaller embeddings this one here is called the q1 q2 q3 q4 then after we compute the attention we
00:06:12.640 | get back the output of each head and we concatenate it together to get back the full embedding
00:06:17.760 | dimension, which is this H here; we run it through another linear projection called wo, which
00:06:23.120 | will be the output of the multi head attention now flash attention is not concerned with all of
00:06:29.440 | these operations actually flash attention is only concerned with the operation that require
00:06:34.240 | optimization and the operations that require optimizations are this one so the softmax of
00:06:40.400 | the query multiplied by the transpose of the key divided by the square root of dk multiplied by v
00:06:44.800 | which means that the projection of the input sequence through wq wk and wv is not something
00:06:51.440 | that flash attention is concerned about because that's a matrix multiplication so when you use a
00:06:56.960 | linear layer it's just a matrix multiplication of the input with the weight matrix of the linear
00:07:01.760 | layer and this kind of operation so the matrix multiplication is one of the most um optimized
00:07:08.560 | operation that we have in the gpu because the manufacturer of the gpu usually also releases
00:07:14.480 | um the necessary library for computing the the matrix multiplication so actually these are quite
00:07:20.400 | fast and they do not require any optimization so flash attention will pretend that the query
00:07:25.520 | has already been passed through wq, the key has already passed through wk, and the v has
00:07:31.600 | already passed through wv. Moreover, flash attention will not be concerned with the projection with
00:07:37.840 | wo because that's also a matrix multiplication because the wo is always represented in pytorch
00:07:42.960 | as a linear layer so it's a matrix multiplication and matrix multiplication as we have seen are very
00:07:49.680 | optimized so there is nothing to optimize there but what we need to optimize in terms of speed
00:07:55.840 | is this operation here softmax of the query multiplied by the transpose of the keys divided
00:07:59.440 | by the square root of dk, multiplied by V. All right guys, so now we have rehearsed what is
00:08:05.440 | multi-head attention i also want to give you a lot of visualization which is basically
00:08:10.240 | here in the paper of the multi-head attention we can see that we have the input that is v
00:08:18.160 | k and q so q k and v each of them runs through a linear layer which is the w q w k and w v
00:08:26.080 | then we do the scaled dot product attention which is done independently for each head so each head
00:08:33.760 | will do query multiplied by the transpose of the key divided by the square root of dk where each
00:08:39.040 | query and each key is not the full embedding of each token but a part of the embedding of the
00:08:43.680 | token because we split them into smaller embeddings and eventually we take all the output of each of
00:08:49.440 | this head, which are computed in parallel, so that's why you see this dimension h in the depth; we
00:08:55.200 | concatenate them and then we run them through wo. What are we concerned with? We are concerned
00:09:00.560 | with optimizing this particular block here, the scaled dot product attention. So let's start
00:09:05.840 | our journey.
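
As a rough, single-sequence sketch of the computation just rehearsed (my own illustration with made-up names, not code shown in the video), multi-head attention in PyTorch could look like this: the input is projected with wq, wk and wv, each embedding is split into heads, attention is computed independently per head, and the heads are concatenated and projected with wo.

    import math
    import torch

    def multi_head_attention(x, wq, wk, wv, wo, num_heads):
        # x: (seq_len, d_model); wq, wk, wv, wo: (d_model, d_model)
        seq_len, d_model = x.shape
        head_dim = d_model // num_heads
        # project the input, then split each embedding into num_heads smaller embeddings
        q = (x @ wq).view(seq_len, num_heads, head_dim).transpose(0, 1)  # (heads, seq, head_dim)
        k = (x @ wk).view(seq_len, num_heads, head_dim).transpose(0, 1)
        v = (x @ wv).view(seq_len, num_heads, head_dim).transpose(0, 1)
        # scaled dot-product attention, done independently for each head:
        # this is the only block that flash attention tries to optimize
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(head_dim)  # (heads, seq, seq)
        out = torch.softmax(scores, dim=-1) @ v                   # (heads, seq, head_dim)
        # concatenate the heads back and apply the output projection
        return out.transpose(0, 1).reshape(seq_len, d_model) @ wo
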
00:09:12.640 | One thing that is very important to understand is why we even need a better implementation of the attention mechanism, and if you look at the flash attention paper you will
00:09:17.360 | notice the following part this is the paper flash attention one and in the flash attention one paper
00:09:23.680 | they describe the attention implementation as it's done naively when using pytorch: first
00:09:29.760 | we do the multiplication of the query with the transpose of the keys, then we apply the
00:09:35.680 | softmax to the output of this operation, and finally we multiply the output of the softmax
00:09:40.240 | with the v matrix to obtain the output of the attention.
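
In code, that naive computation looks roughly like this (a sketch of mine, not the paper's listing); note that both S and P are full N-by-N matrices that get materialized:

    import math
    import torch

    def naive_attention(q, k, v):
        # q, k, v: (N, d) for a single head
        s = (q @ k.transpose(-1, -2)) / math.sqrt(q.shape[-1])  # S: (N, N), materialized
        p = torch.softmax(s, dim=-1)                            # P: (N, N), materialized
        return p @ v                                            # O: (N, d)
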
00:09:46.160 | The way this implementation is done by pytorch without any optimization is as follows. First of all, these tensors are
00:09:55.120 | residing in the gpu the gpu is made up of two main memories one is called the hbm which is the dram
00:10:03.200 | which is the the ram of the gpu which is the 40 gigabyte of the a100 for example so it's the
00:10:10.880 | biggest memory that we have in the gpu, and then there is the shared memory. The problem
00:10:19.360 | of the gpu is that accessing this hbm, which is also called the global memory, is very very
00:10:25.680 | slow compared to the shared memory; however, the shared memory is much much smaller compared to
00:10:30.560 | the hbm. What they claim in the flash attention paper is that the operation of the attention is
00:10:37.200 | i/o bound, meaning that the overall operation of computing
00:10:47.120 | the attention is slow not because computing all these operations is slow, but because we keep accessing
00:10:52.880 | the global memory, which is slow. We call this kind of operation i/o bound, so the only way to
00:11:00.080 | improve this situation is to compute the attention inside the shared memory of the gpu which is much
00:11:06.480 | smaller which is much closer to the cores that actually do the computation so we will need to
00:11:13.120 | kind of also split the attention computation into smaller blocks that can reside in the shared
00:11:19.600 | memory and we will see later in how this is possible through block matrix multiplication
00:11:25.760 | and in the paper here they call it tiling, and it's a very, how to say, widely used technique
00:11:32.640 | when writing kernels for the gpu, which usually involve some kind of matrix multiplication
00:11:40.640 | so now we know what problem the flash attention is trying to solve it's trying to make sure that
00:11:47.120 | we do not need to access the hbm so the high bandwidth memory when computing the attention
00:11:53.680 | but copying only a part of each matrix inside the local memory so the shared memory of the gpu that
00:12:01.440 | is closer to the cores and computing a part of the output matrix there then copying that part
00:12:09.120 | to the output in that is residing in the hbm and keep doing it for all the blocks in which we can
00:12:15.120 | divide this query key and value matrices and later we will see how this blocked computation is done
00:12:21.920 | but also we will see that the biggest problem in computing this block computation is the softmax
00:12:27.280 | because the softmax needs to access all the row of the s matrix to apply the softmax because
00:12:34.720 | the the softmax needs to have a normalization factor which is the sum of all the exponentials
00:12:42.400 | of all the values to which it is applied row wise and we will see later how we will solve this
00:12:47.680 | problem, so let's move on. All right guys, and when I say guys I mean guys and girls, because
00:12:55.840 | usually I just say guys, you know, but please girls don't feel excluded. So we
00:13:03.280 | saw that, first of all, flash attention is only concerned with optimizing this softmax of the
00:13:08.960 | query multiplied by the transpose of the keys, divided by the square root of dk, multiplied
00:13:14.480 | by V, and we need to introduce a little bit of notation so that we don't get lost in the future
00:13:20.480 | slides first of all this is the formulas i took from the flash attention paper but for now we
00:13:25.920 | let's pretend flash attention never existed so we are trying to solve the problem step by step
00:13:30.080 | now, we should treat this q as the sequence that is the output
00:13:37.760 | of the input sequence that has already passed through wq, the k as something that has already
00:13:43.280 | passed through wk and v as something that has already passed through wv because we don't want
00:13:49.040 | to optimize the matrix multiplication because it's already fast enough another thing is let's
00:13:55.040 | talk about what are the dimensions of these matrices so we can then understand what is the
00:13:59.520 | the dimensions of the output of this operation. So we will treat q as a sequence of tokens
00:14:07.280 | with n tokens, so n tokens where each token has d dimensions, so lowercase d dimensions.
00:14:16.160 | Why? Because usually we take the queries and then we split them into multiple heads, so we
00:14:22.480 | pretend we have already done this splitting: we pretend we already took our input sequence, we
00:14:27.040 | already ran it through wq, and then we have already split it into multiple heads, and each of these heads
00:14:33.280 | will do the following operation so the the one we already saw and so the usual formula query
00:14:40.400 | multiply the transpose of the keys and each of this head will work with these dimensions
00:14:45.920 | for the query for the key and for the value sequence so now let's look at the the dimensions
00:14:52.560 | of the output so the first operation that we will do is the query multiplied by the transpose of the
00:14:56.800 | keys where the transpose of the keys is a matrix that originally is n by d but become but with the
00:15:03.040 | transpose will be d by n so d by n and the result will be a matrix that is n by n because in the
00:15:10.480 | matrix multiplication the outer dimensions become the dimension of the output matrix
00:15:14.480 | What is the next operation that we do? We take the output of this operation, so the query
00:15:20.400 | multiply by transpose of the keys and we run it through a softmax operation and we will see what
00:15:24.320 | is the softmax operation which preserves the shape of the input so it doesn't change the shape of the
00:15:30.720 | input matrix it just changes the values of it and then we take the output of the softmax and
00:15:36.480 | we multiply it by v, which will of course change the shape, because the
00:15:43.840 | p matrix is n by n so this one is n by n and v is n by d so this one the output will be n by d the
00:15:54.800 | outer dimensions of this matrix multiplication now let's look at the details of each of these
00:16:00.080 | operations so when we do query multiply by transpose of the keys we will get a matrix that
00:16:04.560 | is n by n where each value in this matrix is a dot product of a row of q and a column of k
00:16:13.840 | in particular the first element of this matrix will be the dot product of the first query with
00:16:19.840 | the first key vector the second element will be the dot product of the first query with the second
00:16:26.080 | key vector and the third element will be the first query with the third key etc etc and the
00:16:31.920 | let's say the the last row of this matrix will be the dot product of the last query with the first
00:16:39.680 | key then the last query with the second key the last query with the third key etc etc until the
00:16:44.960 | last query with the last key you may also notice that here i have written query transpose the key
00:16:53.440 | because when we what is q1 first of all q1 is the first row of the query matrix
00:17:00.640 | so a little bit of background on matrix multiplication so we know that when we do
00:17:06.800 | matrix multiplication each output element is one row of the first matrix with one column of the
00:17:12.560 | second matrix but we are doing the product of the first matrix with the transpose of the second so
00:17:18.480 | it will be the dot product of the one row of the query matrix with one row of the key matrix
00:17:24.960 | because we are doing the multiplication with k transposed. When you take a vector from a matrix,
00:17:34.240 | the usual notation, as in, how to say, in mathematics, in linear algebra, is that
00:17:42.240 | we always pretend that a vector is a column vector, so we cannot just write q multiplied
00:17:48.240 | by k, because that would mean we are doing
00:17:55.920 | kind of a matrix multiplication of one column matrix with another column matrix, which is not possible
00:18:02.480 | because the shapes do not match. So as a notation we write the transpose of the first
00:18:08.160 | vector, which is a column vector but we transpose it so it becomes a row vector,
00:18:13.520 | multiplied with the second vector. This is just notation, guys, so you just need to pretend that
00:18:20.320 | this is the first query with the first key then the first query with the second key the first
00:18:24.640 | query with the third key etc etc etc so we are doing dot products of vectors then we apply this
00:18:31.840 | softmax operation the softmax operation what it will do it will transform each of these dot products
00:18:38.960 | which are scalars so the output of a dot product is a scalar and it will transform each of these
00:18:45.280 | numbers in such a way that they become kind of a probability distribution row wise which means that
00:18:52.000 | each of these numbers is between 0 and 1, and when we sum up these numbers together they sum up
00:18:58.880 | to 1, and this property will be valid for each row: so this row also will sum up to 1,
00:19:05.760 | this row will sum up to 1 and this row will sum up to 1 etc etc etc let's see what is the softmax
00:19:12.720 | operation. Now, given a vector, let's call it x, which is made up of n dimensions, the softmax is
00:19:21.760 | defined as follows: the softmax basically transforms this vector
00:19:29.680 | into another vector with the same dimension, where each item of the output vector is calculated as
00:19:35.280 | follows: the i-th element of the output vector is the exponential of the i-th input element
00:19:42.400 | divided by the summation of the exponentials of all the dimensions of the vector.
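
Written out, the formula being described is:

    \mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}
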
00:19:48.240 | This is called the normalization factor: to make all these numbers between 0 and 1 we
00:19:55.600 | usually normalize, that's why it's called the normalization factor, and we use the softmax
00:20:01.520 | because we want each of these numbers to be positive, we don't want the output
00:20:06.480 | of this operation to be negative, so that's why we use the exponential. But there is a problem:
00:20:11.840 | the problem is imagine our input vector is made up of many numbers that are maybe large so for
00:20:18.000 | example let's say x1 is equal to 100, x2 is equal to 200, x3 is equal to 300, which can happen.
00:20:26.160 | If we do the exponential of these numbers, the exponential of 100 is going to be a
00:20:31.600 | huge number, it's going to be very close to infinity, at least compared to what we can store in a
00:20:37.120 | computer, so the output of exponential of 100 may not fit into a floating point 32 or a floating
00:20:44.240 | point 16 number, or even an integer of 32 bits, so we cannot compute it because it will overflow
00:20:53.440 | the variable that is storing this output. So we talk in this case about
00:21:00.080 | numerical instability so every time you hear the term numerical instability in computer science
00:21:05.840 | it means that the number cannot be represented within a fixed representation with the bits we
00:21:11.520 | have available which are usually 32 bit or 16 bit we have also 64 bit but that will be too expensive
00:21:18.720 | to use so let's try to find a solution to make this stuff here computable and numerically stable
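
As a concrete illustration of the overflow (a quick check of mine, not from the video):

    import torch

    x = torch.tensor([100.0, 200.0, 300.0])   # float32 by default
    print(torch.exp(x))        # tensor([inf, inf, inf]): e^100 already overflows float32
    print(torch.exp(x).sum())  # inf, so the normalization factor is unusable
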
00:21:25.440 | in order to make this softmax operation numerically stable which means that we want
00:21:31.600 | these numbers to not explode or to become too small that they are not representable
00:21:36.880 | we need to find a solution and luckily it's quite easy so the softmax as we have seen before it is
00:21:42.960 | the following formula so each number is exponentiated and then we divide it by this
00:21:47.040 | normalization factor which is just the sum of the exponential of each input dimension of the
00:21:51.600 | input vector if we multiply the numerator and the denominator of a fraction with a constant
00:21:58.000 | with a number then the fraction will not change so that's what we are going to do we are multiplying
00:22:01.760 | the numerator and the denominator with this factor c as long as c is not equal to zero of course
00:22:08.080 | then we can take this c and by using the distributive property of the product with
00:22:16.320 | respect to the sum we can bring this c inside of the summation as you can see here
00:22:20.320 | then we can also write every number as the exponential of the log of itself because the
00:22:27.520 | exponential and the log will cancel out, and then, by using the properties of the exponentials,
00:22:35.520 | we know that the product of two exponentials is equal to the
00:22:40.400 | exponential of the sum of the arguments of each exponential, and we do it in the numerator and in
00:22:46.000 | the denominator. Then we just call this quantity minus log c equal to k, or equivalently minus k
00:22:54.960 | is equal to log c, so we can replace this quantity with k. We can do that because this is a constant
00:23:01.600 | that we have chosen and we are just assigning it to another constant. So basically by doing
00:23:08.480 | this derivation we can see that we can sneak in a value inside of this exponential that if chosen
00:23:15.440 | carefully can reduce the argument of this exponential and we will choose this k equal
00:23:21.360 | to the maximum element inside of the input vector that we are applying the softmax to
00:23:26.480 | so that each of this argument will be either zero in case xi is equal to the maximum element
00:23:33.840 | of the vector that we are processing, or it will be less than zero, and we know that
00:23:40.320 | when the argument of the exponential is equal to zero the output of the exponential will be one,
00:23:45.200 | and when the argument is smaller than zero, so it's in the
00:23:50.160 | negative range, the output will be between zero and one, which is easily representable with floating point
00:23:56.080 | 32, for example. So this exponential will not explode anymore.
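
Summarizing the derivation in one line, with c chosen so that k = -log c is the maximum of the vector:

    \frac{e^{x_i}}{\sum_j e^{x_j}}
      = \frac{c\,e^{x_i}}{c\sum_j e^{x_j}}
      = \frac{e^{x_i + \log c}}{\sum_j e^{x_j + \log c}}
      = \frac{e^{x_i - k}}{\sum_j e^{x_j - k}},
      \qquad k = -\log c = \max_j x_j
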
00:24:04.400 | So basically, to apply the softmax to a vector in a numerically safe way, we need to find a constant k, which is the maximum value of
00:24:12.800 | this vector and when we apply it we need to subtract each element minus this constant that
00:24:18.400 | we have chosen so let's look at the algorithm to compute the softmax so first of all given a vector
00:24:25.760 | or given an n by n matrix because we want to apply the softmax to this matrix here which is n by n
00:24:32.400 | we need to go through each row of this matrix and for each row we need to find the maximum value
00:24:38.800 | among the elements which takes time complexity linear with respect to the size of the vector
00:24:44.720 | to the size of the row to which we are applying the softmax then we need to compute the normalization
00:24:49.760 | factor which is this stuff here and we we cannot compute it before the step number one because we
00:24:56.720 | need to have the maximum element to compute this summation here and after we have calculated the
00:25:02.960 | normalization factor we can then divide each element's exponential by the normalization factor
00:25:08.640 | and we cannot do the step number three before calculating the normalization factor because
00:25:13.200 | we need to divide each number by the normalization factor so if you like pseudocode for algorithms
00:25:20.080 | this is an algorithm for computing the softmax that we have seen right now so first we find the
00:25:25.120 | maximum of the row to which we are applying the softmax then we compute the normalization factor
00:25:31.200 | and then we apply the softmax to each element which means that we calculate compute the
00:25:35.360 | exponential of each element minus the maximum value of the vector divided by the normalization
00:25:41.840 | factor now this pseudocode is an algorithm that is quite slow because look at a practical example
00:25:50.240 | imagine we have this vector here first we need to do step one find the maximum value in this
00:25:55.520 | vector which is number five and this takes linear time computation then we need to calculate the
00:26:00.960 | normalization constant which is the sum of the exponential of each element minus the maximum
00:26:06.960 | value so e to the power of 3 minus 5 plus e to the power of 2 minus 5 etc etc this we will call it
00:26:14.000 | l and then each we need to go again through this vector again and take the exponential of each
00:26:20.160 | element minus the maximum divided by the normalization factor so to apply the softmax to
00:26:27.040 | an n by n matrix we need to go through each element of this matrix three times and these
00:26:33.760 | operations must be done sequentially so we cannot start operation two until we have done operation
00:26:38.880 | one, and we cannot start operation three until we have done one and two. So this is quite slow
00:26:45.440 | only to apply an operation that doesn't even change the shape of the matrix, it's just
00:26:51.280 | normalizing the values. So there must be a better way that does not involve
00:26:57.760 | three sequential operations in which we need to go through this matrix three times. Let's see.
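
For reference, here is the three-pass algorithm just described as a plain Python sketch (an illustration of mine):

    import math

    def safe_softmax(x):
        # Pass 1: find the maximum of the row
        m = max(x)
        # Pass 2: compute the normalization factor from the shifted exponentials
        l = sum(math.exp(xi - m) for xi in x)
        # Pass 3: normalize each shifted exponential
        return [math.exp(xi - m) / l for xi in x]

    print(safe_softmax([3.0, 2.0, 5.0, 1.0]))  # the example vector used in the discussion
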
00:27:02.480 | all right guys let's rehearse what is the problem that we are trying to solve the problem statement
00:27:08.640 | is the following can we find a better way to compute the softmax that does not involve going
00:27:15.120 | through the vector three times because let's look at the pseudocode of the algorithm for computing
00:27:20.480 | the local the softmax that we have found so far imagine we have a vector made up of four elements
00:27:26.080 | the first thing that we need to do is to compute the maximum element in this vector which means
00:27:30.960 | going through this for loop here that allow us to compute the maximum element in this vector which
00:27:36.800 | means that we start from the left side of the vector and iteratively go to the right side so
00:27:41.520 | we start from the first element arrive to the end and we compare the previously found maximum with
00:27:47.680 | the current element to find the global maximum. I know that this
00:27:53.360 | is very simple, and I'm pretty sure that you don't need this example, but making this example will
00:27:58.720 | help us understand what we will do next, so please bear with me even if it's super simple what i'm
00:28:03.520 | doing okay we at the beginning m0 is equal to minus infinity m1 is basically the for loop at
00:28:14.000 | the iteration number one which means that we are m1 will be equal to the maximum of the previous
00:28:19.920 | estimate of the m which is minus infinity with the current element which is three so it will become
00:28:26.960 | equal to three then m2 will be equal to the maximum of the previously computed maximum so m1
00:28:33.840 | so three with the current element which is two so it will be equal to three m3 will be equal to
00:28:40.400 | the maximum of the previously computed maximum so three with the current three with the current
00:28:46.880 | element which is five so it will be equal to five and m4 will be equal to the maximum of the
00:28:53.200 | previously computed maximum and the current element so it will be equal to five so this allow
00:28:57.920 | us to compute the maximum element so at the fourth iteration we will have the maximum the global
00:29:02.320 | maximum, independently of what is the input array. Okay, after we have computed the maximum,
00:29:10.080 | which we know is five, we can compute the normalization factor. So let's start with
00:29:15.520 | l0: l0 is equal to zero, and l1 will be
00:29:23.120 | l0 plus the exponential of the current element, so three, minus the maximum element we have found in
00:29:29.680 | the previous for loop so five then l2 will be equal to l1 plus the exponential of the
00:29:35.920 | the current element so it's two minus the maximum then l3 will be equal to l2 plus the exponential
00:29:45.360 | of the current element five minus five then l4 will be equal to the l3 plus exponential of one
00:29:54.080 | minus five if you expand this l this will be basically equal to e to the power of three minus
00:30:02.160 | five plus e to the power of two minus five plus e to the power of five minus five plus e to the
00:30:07.760 | power of one minus one minus five after we have computed this normalization factor we can use it
00:30:16.320 | to normalize the each element in the input vector which means that the x new x1 so x1 prime let's
00:30:23.040 | see will be equal to e to the power of what's the first element three minus five divided by l that
00:30:34.240 | we computed in the previous for loop so the l at the fourth iteration the new x2 so x2 prime will
00:30:43.360 | be equal to the e to the power of two minus five divided by l4 and x3 prime will be equal to the e
00:30:50.640 | to the power of five minus five divided by l4 etc etc for all the elements i know this is super
00:30:58.080 | simple but it will help us later so in this for loop we have that we need to go through the vector
00:31:05.600 | three times because first we need to compute this for loop then we need to compute this for loop
00:31:10.560 | and then we need to compute another for loop we cannot do them not in this sequence because in
00:31:15.680 | order to compute this for loop we need to have the maximum element because we need it here
00:31:19.680 | and we cannot compute this for loop until we have computed the previous one because we need to have
00:31:24.240 | the normalization factor however we are stubborn and let's try to fuse these two operations into
00:31:30.800 | one for loop which means that we go through the array and simultaneously compute mi and in the
00:31:37.440 | same iteration we also try to compute lj of course we will not be able to compute lj because we don't
00:31:43.440 | have the global maximum because we didn't go through the old array yet however let's try to
00:31:50.080 | use the locally and whatever estimate we have of the maximum so far so let's try to use instead of
00:31:56.240 | mn let's try to use mi so the local maximum that we have computed so far so if we apply the softmax
00:32:02.720 | in this way in this fused way to this vector we will have the following iterations so this is our
00:32:10.160 | array or vector and the first step is mi so m1 will be equal to the previous maximum which is
00:32:17.440 | minus infinity with the current element so the maximum minus infinity and the current element
00:32:22.480 | is equal to 3 and l1 will be equal to the previous l so l0 which is starts from 0 plus e to the power
00:32:31.600 | of the current element minus we should be using the global maximum but we don't have the global
00:32:36.240 | maximum so let's use the whatever maximum we have so far so we can use 3 now at the second iteration
00:32:42.560 | we are at this element of the vector and we compute the maximum so far so the maximum so far
00:32:48.480 | is the previous maximum and the current element so the maximum of the previous maximum and the
00:32:52.640 | current element which is the maximum between 3 and 2 which is 3 and the normalization factor
00:32:59.120 | is the previous normalization factor plus exponential of 2 minus 3 which is the current
00:33:04.720 | element minus whatever maximum we have so far now if our array were made only of these two elements
00:33:12.640 | so 3 and 2 then whatever we have computed is actually correct because the maximum that we
00:33:19.120 | have found is a 3 and it's actually the global maximum and the normalization factor that we
00:33:24.800 | have computed is actually correct because each of the exponential has been computed with the global
00:35:29.680 | maximum, because the first element was computed with the argument minus 3, and
00:35:37.200 | also the second element was computed with the argument having minus 3 in
00:35:42.720 | it, which is the global maximum of the vector. However, when we arrive at the third
00:33:48.560 | iteration so let me delete this vector so let me arrive here at the third iteration the maximum
00:33:54.400 | will change which will also cause our normalization factor to get to to be wrong because
00:34:00.640 | we arrive at the element number 3 so the number 5 here and we compute the maximum
00:34:07.760 | so the maximum is the comparison of the previous maximum and the current element so the new maximum
00:34:13.520 | becomes 5 and the normalization factor is the previous normalization factor so l2 plus the
00:34:19.760 | exponential of the current element minus the current estimate of the maximum which is 5
00:34:26.240 | however if you look at this l3 this is wrong why because l3 is equal to if you expand this
00:34:34.240 | summation it will be equal to e to the power of 3 minus 3 plus e to the power of 2 minus 3
00:34:42.720 | plus e to the power of 5 minus 5 this exponential here is using 5 as the global maximum this
00:34:50.320 | exponential here is using 3 as the global maximum and this one is using 3 as the global maximum
00:34:55.440 | so the first two elements have been computed thinking that the global maximum is 3 but
00:35:00.560 | actually we later we found a better global maximum which is 5 so which makes this normalization
00:35:05.920 | factor wrong however can we fix at the third iteration whatever normalization we have computed
00:35:13.280 | so far up to the second iteration actually we can because if we expand this so as we have
00:35:20.800 | here we have expanded it what we need here is here to have a minus 5 because that's actually
00:35:27.360 | the global maximum that we have found so far not the minus 3 that we had at the previous iteration
00:35:32.880 | so and here we also need to fix this replace this minus 3 with minus 5 how can we do that well if
00:35:39.040 | we multiply this one here and this one here with a correction factor that will sneak in a new maximum
00:35:47.520 | inside of this exponential then we solve the problem and actually this correction factor
00:35:52.160 | is very easy to calculate because at the third iteration if we multiply l2 so the previously
00:35:57.840 | computed normalization factor with this factor here which is the exponential of the previous
00:36:03.040 | estimate of the maximum minus the current estimate of the maximum so 5 we will see that
00:36:08.960 | e by the properties of the exponentials this one here will become e to the power of 3 minus 3 plus
00:36:16.800 | 3 minus 5 so this minus 3 will cancel out with this 3 and also the second factor will have this
00:36:23.200 | 3 cancel out with this minus 3, and they will become e to the
00:36:28.400 | power of 3 minus 5 and e to the power of 2 minus 5, which is actually correct
00:36:34.640 | because at the third iteration we should actually be using minus 5 as the
00:36:40.400 | maximum of the array so far so basically what we have found is a way to fix whatever normalization
00:36:48.960 | factor we have computed so far while iterating through the array when we found we when we find
00:36:55.680 | a better maximum compared to what we have so far and when we don't need to fix anything then the
00:37:02.160 | formula still stands because what we did here as a multiplication as a correction factor so this is
00:37:08.080 | the correction factor this correction factor is nothing more than the previous maximum so the
00:37:15.840 | previous estimate of the maximum minus the current estimates of the maximum at the current iteration
00:37:21.440 | so the current max so this is basically m of i minus 1 and this is m of i so the current maximum
00:37:30.400 | at the current iteration and let me delete it otherwise it remains forever in my slides
00:37:35.440 | so basically when we arrive to the last element we will see that the maximum doesn't change because
00:37:42.480 | we compare the previous maximum with the current element which is less than the previous maximum so
00:37:47.760 | the maximum doesn't change and we don't need to fix anything because the the the previous l3 so
00:37:55.600 | the previously computed normalization factor is correct because they have all been using the minus
00:38:01.200 | 5 so when we don't need to fix anything we just multiply by e to the power of the previous maximum
00:38:08.080 | minus the current maximum which is e to the power of zero in this case so it's not fixing anything
00:38:13.440 | so we have found a way to fix the previously computed normalization factor while going
00:38:20.240 | through the array even if at the current iteration we don't have the global maximum yet so that every
00:38:27.200 | time the maximum changes we can fix and every time it doesn't change we just multiply with e to the
00:38:32.400 | power of zero which is like multiplying with one so the new algorithm that we have found for the
00:38:38.320 | softmax is the following so we start with m0 equal to minus infinity we start with l0 equal to zero
00:38:44.560 | we go through the array we compute the locally the local maximum so up so the maximum so far
00:38:52.640 | from the zeroth element to the ith element so to the element at which we are doing the iteration
00:38:59.680 | and the previously computed li can be fixed by using this correction factor which is e to the
00:39:06.160 | power of the previous maximum minus the current maximum plus the exponential of the current
00:39:12.560 | element minus the current estimate of the maximum. In this way we go through the array only once
00:39:19.120 | and we obtain two values at the end at the same time: the global maximum and the
00:39:27.120 | normalization factor, and then we can use them to compute the softmax. So we turned
00:39:33.520 | three passes through the array into two passes through the array, and this is very important,
00:39:40.080 | and we will see how we actually use it to derive flash attention.
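
Putting the fused algorithm into a quick Python sketch (my own illustration):

    import math

    def online_softmax(x):
        # Single fused pass: running maximum and running normalization factor
        m, l = float("-inf"), 0.0
        for xi in x:
            m_new = max(m, xi)
            # the correction factor e^(m_old - m_new) fixes the old normalizer,
            # then we add the term for the current element
            l = l * math.exp(m - m_new) + math.exp(xi - m_new)
            m = m_new
        # Second (and last) pass: normalize
        return [math.exp(xi - m) / l for xi in x]

    print(online_softmax([3.0, 2.0, 5.0, 1.0]))
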
00:39:46.880 | The example that I have given you so far is not really a proof that our algorithm will work in every case, because we made a very
00:39:52.800 | simple example by using a vector made up of four elements but does our new algorithm work in every
00:40:00.400 | single case with whatever the numbers are we need to prove that so we will prove that by induction
00:40:06.960 | so what first of all what are we trying to prove we have fused the first two for loops into one
00:40:13.280 | for loop as you can see here what we expect is that at the end of this for loop this mn so the
00:40:21.200 | m at the last iteration will be actually the global maximum in the vector and this ln so the
00:40:27.840 | l at the last iteration will be equal to the sum of all the exponential of all the elements
00:40:34.240 | minus the maximum element of the vector so the global maximum of the vector and we need to prove
00:40:41.440 | that because what i did before was an example and that was not really a rigorous proof and the way
00:40:47.520 | we will prove it is by induction which is a typical way of proving this kind of theorems
00:40:53.120 | now proof by induction basically works in the following way we need to prove that
00:40:59.200 | our algorithm works for a base case for example with n equal to one and then we pretend we assume
00:41:08.800 | that the algorithm works on n and we need to prove that it also works for n plus one if this holds
00:41:17.120 | then we have proven our algorithm for every possible n because it will work for the base
00:41:22.720 | case so for example n equal to one and then by using the induction step we say so this if it
00:41:29.200 | works for n and then it also works for n plus one then it means that it will also work for two but
00:41:34.080 | then if it works for two then it should also work for three because of the induction step that we
00:41:38.240 | will prove and if it works for three then it will also work for four etc etc up to infinity so let's
00:41:44.560 | prove it for the base case which is n equal to one it's very simple so at n equal to one this
00:41:51.680 | for loop will only have one iteration so m m1 and l1 m1 will be the maximum of the previous m which
00:41:59.120 | is minus infinity because we initialize m0 equal to minus infinity so it will be equal to the
00:42:07.600 | maximum of the previous m and the current element which is x1 so it will be equal to x1 whatever x1
00:42:13.200 | is. x1 can never be equal to minus infinity, because it's a
00:42:20.960 | number in a fixed representation, so it cannot be minus infinity. And because we have only one
00:42:29.040 | element (n equal to one), this m1 is also the last m of
00:42:37.280 | this for loop, so it will be equal to the global maximum of the vector made up of only one
00:42:42.560 | element and l1 will be equal to the previous l which we start from zero so l0 multiplied by a
00:42:50.480 | correction factor which will be in this case e to the power of minus infinity because the correction
00:42:54.960 | factor is the previous estimate of the max minus the current estimate of the max, but the
00:43:01.360 | previous estimate of the max is minus infinity, and minus infinity minus x1 is equal to minus infinity, so this
00:43:07.120 | term will be cancelled out, and then plus e to the power of x1 minus the current maximum,
00:43:14.640 | which is x1, so m1. And this will be equal to the sum of the exponentials of all the elements of the vector, which
00:43:24.000 | is made up of only one element minus the maximum element in the array which is x1 so we have proven
00:43:31.280 | that it works for n equal to one now we assume that it works for n does it also work for an array
00:43:39.200 | of vector or with a vector of size n plus one so let's see what happens at the n plus one iteration
00:43:49.200 | at the n plus one iteration we will be doing the maximum of the previous estimate of m which is
00:43:54.880 | the m at the nth iteration and the current element so xn of plus one this by the properties of the
00:44:03.120 | max function it will be actually equal to the maximum of the global vector up to n plus one
00:44:10.400 | because the maximum will choose whatever is the maximum between the previous estimate and the
00:44:16.320 | current estimate and ln plus one which is the normalization factor at the n plus one iteration
00:44:23.520 | will be equal to the ln so the previous estimate not previous estimate but the previous
00:44:27.920 | normalization factor at the nth iteration multiplied by the correction factor which is the
00:44:35.040 | previous maximum minus the current maximum plus the exponential of x the current element
00:44:43.600 | minus the current estimate of the maximum but ln we have we assume that this property so this
00:44:54.160 | algorithm works up to n so ln is for sure equal to the sum of all the exponentials of the previous
00:45:03.600 | of the vector up to n minus the local maximum of the vector up to the nth element
00:45:12.800 | which is mn we multiply by the correction factor if there is something to correct which will be the
00:45:21.360 | previous maximum minus the current maximum plus the exponential of the current element minus the
00:45:26.880 | current estimate of the maximum now by the properties of the exponentials so we can bring
00:45:35.440 | this one inside of the summation and we will see that this mn and this mn will cancel out because
00:45:42.240 | it will be exponential of xj minus mn plus mn minus mn plus one so this mn and this mn will
00:45:49.520 | cancel out and we obtain this one plus this factor here that remains unchanged however you can see
00:45:55.920 | that this stuff here is exactly the argument of this summation at the iteration n plus one:
00:46:04.800 | it is e to the power of xj, where j goes from one to n, minus m of n plus one, plus
00:46:13.280 | e to the power of x of n plus one minus m of n plus one. So the j only goes up
00:46:21.360 | to n, and this last term is like the term with j equal to n plus one, so we can increase the index of this summation
00:46:29.200 | by one and it will be the same and it will result in the same summation so we have proven
00:46:36.240 | that also at the n plus one iteration we will have that the l will be equal to the sum of
00:46:43.920 | the exponentials of all the elements of the array up to the n plus one
00:46:49.760 | element, minus the maximum up to the n plus one element. So we have proven that if it works for n,
00:46:57.840 | then it also works for n plus one, and this is enough to prove that it works for arrays of any size.
00:47:04.880 | Don't worry if you didn't get the proof by induction: if it's the first time you are
00:47:11.680 | seeing this kind of proof it may take a little bit to get it. If you want to learn a little
00:47:17.520 | bit more about proof by induction I recommend watching some other proofs; it's very simple,
00:47:22.000 | you just need to get into the right mindset. Anyway, let's move forward.
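
The proof aside, here is a quick numerical sanity check of mine that the two-pass online softmax matches PyTorch's softmax on random inputs (the helper is repeated from the sketch above so the snippet stands alone):

    import math
    import torch

    def online_softmax(x):
        m, l = float("-inf"), 0.0
        for xi in x:
            m_new = max(m, xi)
            l = l * math.exp(m - m_new) + math.exp(xi - m_new)
            m = m_new
        return [math.exp(xi - m) / l for xi in x]

    x = torch.randn(1024)
    ours = torch.tensor(online_softmax(x.tolist()))
    print(torch.allclose(torch.softmax(x, dim=0), ours, atol=1e-6))  # True, up to float rounding
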
00:47:29.680 | block matrix multiplication i know that you want to jump to the code immediately and we will go
00:47:34.720 | there we just need a little more theory actually so imagine we are doing a matrix multiplication
00:47:40.800 | so we have a matrix a we want to multiply it with a matrix b and it will produce an output matrix
00:47:47.200 | c imagine the dimensions of the first matrix are m by k the second matrix is a k by n it will
00:47:54.800 | produce an output matrix that is m by n now imagine we want to parallelize the computation of this
00:48:01.600 | output matrix i know that i didn't talk about gpus yet so we will not talk about gpus we will
00:48:08.800 | talk about parallelization in the case of a multi-core cpu with which you are very probably
00:48:14.720 | familiar with because right now in nowadays when you buy a computer you have a cpu and usually you
00:48:21.120 | can buy a single core cpu or multi-core like a two core four core eight core etc etc each of the
00:48:27.760 | these cores are actually kind of small cpus inside your cpu that can execute operations in parallel
00:48:33.520 | how to parallelize the matrix multiplication imagine you have this matrix multiplication
00:48:39.520 | to parallelize each of the output element in this c matrix is a dot product of a row of the
00:48:46.320 | a matrix with a column of the b matrix for example this element on the top left is the dot product
00:48:53.280 | of the first row of a and the first column of b this element on the top right of c is the dot
00:49:00.240 | product of the first row of a and the last column of b this element on the bottom left is the dot
00:49:07.040 | product of the last row of a and the first column of b etc etc for all the other elements now to
00:49:13.600 | parallelize this computation we need as many cores as is as there are elements in c if we want to
00:49:20.080 | parallelize it so if m and n are very small then maybe we have enough cores but imagine m and n are
00:49:28.080 | quite big we imagine like 100 by 100 we don't have 10 000 cores right now in the cpus so how can we
00:49:36.720 | parallelize a matrix operation by using less cores than there are elements in the matrix itself
00:49:44.640 | that's when we talk about block matrix multiplication basically block matrix
00:49:50.160 | multiplication means that you can divide the original matrix into smaller blocks of elements
00:49:57.360 | and then the operations of matrix multiplication can be computed between these blocks for example
00:50:04.800 | imagine we have a matrix that is 8 by 4 it means that it has 8 rows and 4 columns which means that
00:50:13.680 | it has 32 elements and then we are multiplying it with another matrix that is 4 by 8 so it has 4
00:50:22.560 | rows and 8 columns, so it also has 32 elements. The output matrix should have 64 elements.
00:50:31.680 | we don't have 64 cores so how can we parallelize it imagine we only have 8 cores now with 8 cores
00:50:39.600 | we can divide this original matrix a into 4 blocks where the first block is this top left block of
00:50:47.520 | 4 by 2 elements, so, let's say, 8 elements on the top left, and then 8 elements
00:50:58.960 | on the top right of this matrix then 8 elements on the bottom left and 8 elements in the bottom
00:51:04.240 | right of this matrix these are 4 blocks then we divide also the b matrix into um 8 blocks
00:51:11.200 | where each block is made up of 4 elements so this b11 is the top left 4 elements in the original
00:51:19.680 | matrix, this b14 is the top right 4 elements in the original matrix, this b21 is the
00:51:28.080 | bottom left 4 elements in the original matrix etc etc etc how do we do this block matrix
00:51:33.600 | multiplication we can watch these matrices as made only by their blocks so we can view this
00:51:40.720 | matrix here as made up only by its blocks we can view this matrix here as made up only by its blocks
00:51:48.320 | and the output of this multiplication will be a matrices that is computed in the same way as the
00:51:55.600 | original matrix but where the output of each dot product will not be a single element of the output
00:52:02.880 | matrix but it will be a block of elements of the output matrix for example the top left block here
00:52:10.720 | is the dot product of the first row of this matrix with the first column of this matrix
00:52:18.480 | and it will be computed as follows so it will be a11 multiplied by b11 plus a12 multiplied by b21
00:52:25.920 | and this output will not be a single scalar but a block; well, let me count how many elements it should
00:52:33.600 | have: the output matrix has 64 elements
00:52:40.320 | and we have eight blocks, so each block should be
00:52:47.600 | made up of eight elements. We can see that here: how do we find the dimensions of this output
00:52:55.520 | block well we can check what is a11 a11 is four by two so it's eight elements in a smaller matrix
00:53:05.920 | made up of eight elements where the elements are distributed in four rows and two columns
00:53:10.960 | we are multiplying it by b11 which is a smaller matrix compared to the original made up of
00:53:16.560 | two by two elements so four elements so when we multiply four by two multiplied by two by two it
00:53:23.280 | will produce a four by two output block matrix so block so if we do this computation here block by
00:53:32.240 | block it will produce a block of output elements of the original matrix so not not a single scalar
00:53:38.400 | but a block of outputs which makes it very easy to parallelize because if we have only eight cores
00:53:44.320 | we can assign each output block to one core and each core will not produce one output element of
00:53:50.480 | the original matrix but it will produce eight elements of the original matrix as a four by two
00:53:56.400 | matrix. So basically block matrix multiplication allows us to do the matrix multiplication either element
00:54:05.360 | by element, like in the original matrix, so each row with each column, or block by block, in the
00:54:10.320 | same way as we do a normal matrix multiplication, because the matrix multiplication that we are
00:54:15.280 | doing between blocks works the same way as the matrix multiplication with the original matrix,
00:54:21.120 | and it will produce not a scalar but a block. Now let's see why this is very important for us.
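
Here is the 8-by-4 times 4-by-8 example as a sketch (mine): the output is computed one block at a time, each output block being 4 by 2, and the result matches the ordinary matrix multiplication.

    import torch

    A = torch.randn(8, 4)   # split into a 2x2 grid of (4, 2) blocks
    B = torch.randn(4, 8)   # split into a 2x4 grid of (2, 2) blocks
    C = torch.zeros(8, 8)   # split into a 2x4 grid of (4, 2) output blocks

    BLOCK_M, BLOCK_K, BLOCK_N = 4, 2, 2
    for i in range(0, 8, BLOCK_M):              # block-rows of A / C
        for j in range(0, 8, BLOCK_N):          # block-columns of B / C
            acc = torch.zeros(BLOCK_M, BLOCK_N)
            for kk in range(0, 4, BLOCK_K):     # inner (shared) block dimension
                acc += A[i:i+BLOCK_M, kk:kk+BLOCK_K] @ B[kk:kk+BLOCK_K, j:j+BLOCK_N]
            C[i:i+BLOCK_M, j:j+BLOCK_N] = acc   # each core would compute one such block

    print(torch.allclose(C, A @ B))  # True
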
00:54:26.320 | so why should we care about block matrix multiplication because we are trying to
00:54:33.920 | compute the following operation so the query multiplied by the transpose of the keys
00:54:38.640 | and then we will should apply the softmax of this operation and then we should
00:54:42.240 | multiply the output of the softmax with v for now let's ignore the softmax let's pretend that
00:54:48.960 | we are not going to apply any softmax so we take the output of the query multiplied by the transpose
00:54:54.400 | of the keys and we just multiply it by v to obtain the output of the attention which is wrong of course
00:54:58.560 | but it simplifies the discussion of what we are going to do next so for this moment let's
00:55:04.800 | pretend that we are not going to apply any softmax so we just do the query multiplied by
00:55:08.400 | transpose of the keys and directly we multiply the result of this operation with v this will
00:55:12.720 | result in a matrix that is n by d so n tokens each made up of an embedding of d dimensions
00:55:19.680 | so lowercase d dimensions and we know that query key and values are themselves matrices of n by d
00:55:28.320 | dimensions so the n tokens which made up of an embedding of d dimensions so imagine we have a
00:55:36.400 | query matrix and the key and the value matrix that are 8 by 128 so we have 8 tokens each token is
00:55:44.000 | made up of 128 dimensions. as we have seen, when we compute a matrix multiplication
00:55:51.600 | we can divide our matrices into blocks; how we choose the blocks is up to us as long as
00:55:59.760 | the shapes of the blocks match when doing the matrix multiplication so for example
00:56:05.760 | in the previous case we divided our matrix a into blocks such that the shape of the block matrix
00:56:13.360 | so the matrix that is made up only of the blocks is compatible with the block matrix b
00:56:20.080 | so that this operation is possible so this is the only requirement that we need to be
00:56:25.040 | aware when doing the block matrix multiplication the shapes of the blocked matrix so the matrix
00:56:30.000 | that is made only of the blocks should match in the matrix multiplication for the rest it doesn't
00:56:35.280 | matter how we divide it so imagine that we choose to divide this query matrix into blocks of rows
00:56:42.960 | and we can do that we don't have to necessarily divide also the columns we can just divide the
00:56:47.200 | rows so that each q is not a single row but it's a group of two rows so q1 is a group of the first
00:56:54.640 | two rows of the q matrix of the q sequence q2 is the group of the second two rows of the q sequence
00:57:01.120 | etc etc and we do the same also for v for k we don't do it because we are actually going to
00:57:07.760 | multiply with k transposed so we do the subdivision directly on k transposed so we so we have the q
00:57:15.680 | which has been divided into groups of rows and then we have a k transposed which is a matrix
00:57:21.920 | that is 128 by 8 because it's the transpose of the keys which is 8 by 128 and we decide to divide
00:57:31.120 | each of the column group of columns of k into a single block so the k1 is the first two columns
00:57:38.960 | of k transposed k2 is the second group of two columns in k transposed etc etc until k4 which
00:57:46.400 | is the last two columns in k transposed the first operation that we do is the multiplication query
00:57:51.840 | multiplied by the transpose of the keys which basically means that we need to multiply each
00:57:56.640 | query with all the keys then the second query with all the keys etc etc now each query is not a single
00:58:04.960 | row of the q sequence it's a group of two rows of this q sequence and each k is not a single column
00:58:12.080 | of k transposed it's a group of two columns of k transposed but doesn't matter because we have
00:58:17.200 | seen that the matrix multiplication if we write the matrices as made up of blocks we just compute
00:58:22.880 | it in the same way when we do a normal matrix multiplication so we are multiplying this matrix
00:58:28.720 | by this matrix and for what we know this matrix here is made up of four rows with some dimensions
00:58:36.640 | which is 128 dimensions and this one here is made up of how many rows 128 rows and four columns
00:58:46.960 | i didn't draw the columns because it's too many to draw here but you need to pretend it's a lot of
00:58:53.760 | dimensions one for each 128 for each vector and here you need to pretend that this is 128
00:59:00.080 | rows when we do the matrix multiplication we apply the normal matrix multiplication procedure
00:59:05.760 | which is first of all the output shape of this matrix
00:59:10.400 | multiplication will be four by four because those are the outer dimensions of the two matrices that you
00:59:14.560 | are multiplying the first element of the output will be the dot product of this vector here
00:59:21.280 | with this vector here the second element so this one here will be the dot product of this
00:59:27.120 | vector here with this vector here however these are not actually vectors so it's
00:59:33.840 | actually a matrix multiplication in this case this element here is not a scalar it is a group of
00:59:40.320 | elements of the output matrix because we are doing block matrix multiplication and how many elements
00:59:45.760 | it will be well we know that the original q1 is a 2 by 128 the k1 is 128 by 2 so it will be a group
00:59:54.640 | of 2 by 2 elements of the output matrix so we are doing the matrix multiplication of the q1
01:00:01.840 | with k1 then q1 with k2 then q1 with k3 q1 with k4 etc etc for the first row and then the second
01:00:10.240 | row will be q2 with all the k's and the q3 with all the k's and q4 with all the k's so as you can
01:00:16.000 | see when we do matrix multiplication we don't even care if what is underlying is a block or a vector
01:00:22.880 | or a scalar we just apply the same procedure first row of the first block matrix with
01:00:30.320 | the first column of the second matrix and then the first row with the second
01:00:36.400 | column the first row with the third column etc etc let's then multiply because the formula says
01:00:43.920 | that we need to multiply query with the transpose of the keys and then multiply by v all of these are
01:00:49.280 | block matrices now as you can see from my using of colors every time i refer to the original matrix i
01:00:57.680 | use the blue color and every time i refer to the block matrix i use the pink color so we need to
01:01:04.160 | multiply the output of the query multiply by the transpose of the key then by v because we are
01:01:08.720 | skipping for now the softmax and later we will see why so if we want to do this multiplication we
01:01:14.560 | need to do the following so it will be uh this matrix is made up of blocks and block matrix
01:01:22.080 | multiplication just ignores this fact and just does the matrix multiplication like it is a normal
01:01:26.560 | matrix multiplication so we do the first row with the first column then the first row with the second
01:01:33.200 | column then the first row with the third column etc etc so how is this output block row
01:01:40.320 | in the output matrix of this matrix multiplication going to be calculated
01:01:46.080 | well it will be the dot product of the first row (not really a dot product, it's
01:01:55.280 | actually the matrix multiplication of the first row, but arranged like a dot
01:02:00.320 | product) with the first column which is made up of v1 v2 v3 and v4 so it will be this
01:02:09.360 | element with v1 plus this element with v2 plus this element with v3 plus this element with v4
01:02:16.640 | and this will be the first output element the second output block will be this row
01:02:26.080 | with this column which will be this element with v1 plus this element with
01:02:33.360 | v2 plus this element with v3 plus this element with v4 and this will produce the second output
01:02:39.280 | block etc etc also for the third and the fourth block output let's look at what is each block
01:02:47.360 | made up of so each block is made up of the first element so query one multiplied by key one
01:02:54.240 | (because it's the result of the query multiplied by the keys) with the v1 of the second matrix
01:03:01.600 | plus this element with this one plus this element with this one plus this element with
01:03:07.760 | this one so the pseudocode for generating this output of this attention mechanism which is not
01:03:14.560 | really attention mechanism because we skip the softmax but i just want you to get into the
01:03:19.120 | habit of thinking in terms of blocks is the following so we take each query block
01:03:25.920 | we go through each query and as you can see let's look at actually what this output is made up of
01:03:34.080 | it is made up of the query one multiplied by key one and the result multiplied by v1 then the query
01:03:40.400 | one with k2 then the result multiplied by v2 then the query one with k3 and the result multiplied
01:03:47.280 | by v3 plus the query one with the k4 and result multiplied by v4 this is basically what we are
01:03:53.280 | doing is the dot product of this row with this column made up of blocks so the the pseudocode
01:04:02.160 | for generating this first row is the query is then query number one and then we iterate through the
01:04:09.280 | keys and the values from one to four and we sum iteratively so for each block basically to generate
01:04:16.720 | this output matrix and for each row we will see that it's a different query with all the keys
01:04:22.800 | and values and then this will be the the query number three with all the keys and values and
01:04:27.600 | this will be the query four with all the keys and values so to generate this output matrix we need
01:04:33.040 | to do the following we iterate through the queries and this will be one row of this output matrix and then we need
01:04:39.120 | to do this iterative sum of the query i that we are iterating through multiplied by the jth k
01:04:46.640 | and v and we keep summing them iteratively and that will produce the output matrix.
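A minimal NumPy sketch of the pseudocode just described, with hypothetical sizes (8 tokens, 128 dimensions, blocks of 2 rows) and the softmax still deliberately skipped:

```python
import numpy as np

# Hypothetical sizes: N=8 tokens, d=128 dimensions, blocks of B=2 rows (no softmax yet).
N, d, B = 8, 128, 2
Q = np.random.randn(N, d)
K = np.random.randn(N, d)
V = np.random.randn(N, d)

O = np.zeros((N, d))
for i in range(0, N, B):                  # outer loop over query blocks Q_i
    acc = np.zeros((B, d))
    for j in range(0, N, B):              # inner loop over key/value blocks K_j, V_j
        S_ij = Q[i:i+B] @ K[j:j+B].T      # (B x B) block of Q K^T
        acc += S_ij @ V[j:j+B]            # accumulate the (B x d) contribution
    O[i:i+B] = acc

assert np.allclose(O, (Q @ K.T) @ V)      # same as the unblocked product (softmax still omitted)
```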
01:04:53.040 | now i know that what i have done so far is not directly useful for flash
01:04:59.280 | attention yet but it's useful for us to get into the mindset of computing this product by blocks
01:05:04.800 | because later we will use it also with the softmax all right guys i know that what we have
01:05:11.680 | computed so far is not really
01:05:16.240 | the attention mechanism because we have skipped the softmax so somehow we need to
01:05:21.680 | restore it and the following i think 10 or 20 minutes are going to be really really
01:05:28.320 | challenging because i am going to do a lot of operations that will involve a lot of different
01:05:34.240 | blocks and a lot of different matrix multiplication and the variants of the softmax so it may be
01:05:40.880 | difficult to follow however don't give up you can watch this part two or three times and every time
01:05:47.600 | you will have a better understanding i also recommend watching it until we reach the flash
01:05:53.440 | attention algorithm before going back to re-watch it because
01:06:00.480 | reaching the flash attention algorithm will give you a better understanding of
01:06:05.920 | what has happened so far and then you can re-watch it to deepen your understanding another thing that
01:06:11.760 | i recommend is take pen and paper and write exactly the operations that you are seeing
01:06:17.440 | and write the shapes of each of the blocks and elements that take part
01:06:24.400 | in these matrix multiplications so that you better understand what is happening and better
01:06:32.480 | remember what i mean when i refer to a particular element or a particular block okay after giving this small
01:06:39.760 | motivational speech let's start so what we have done so far was query multiplied by the transpose
01:06:46.880 | of the keys however each query is not a single row of the query sequence but it's a block of
01:06:54.000 | queries it's a block of rows in our particular case this q1 is not one row of the query sequence
01:07:02.000 | it's two rows of the query sequence because we have chosen as a block size a group of two rows
01:07:07.280 | and this k transposed one is not one column of the k transposed matrix is two columns of the
01:07:16.240 | k transposed matrix because we have chosen it like this and if you don't remember let's go back to
01:07:21.040 | see it here we have chosen k1 is two columns and q1 is two rows of the query original matrix
01:07:30.000 | and every time i use the blue color i am referring to the original shape and every time i'm using the
01:07:36.560 | pink or violet whatever it is i am referring to the block matrix so it's a block of elements of
01:07:43.920 | the original matrix okay now the first thing that we have done was a query multiplied by the
01:07:50.880 | transpose of the keys and this produces a block matrix as output that we will call s where each
01:07:58.560 | element sij so the s11 element of this matrix will be the query one with the k transposed one
01:08:06.720 | this s12 will be query one with k transpose the two s13 will be query one with k transpose the
01:08:14.720 | three etc etc for all the rows and for all the columns then we should be applying the softmax
01:08:20.400 | because if you remember the formula is softmax of the query multiplied by the transpose of the keys
01:08:24.800 | however i want to restore the softmax operation but with a twist which means that we will apply
01:08:31.520 | the simplified version of the softmax and we will call it softmax star which is just the softmax
01:08:38.160 | without the normalization so let me write it for you what it means let's do it with the same color
01:08:45.520 | that i chose for the softmax which is orange so the softmax if you remember correctly
01:08:52.800 | is applied to a vector element wise so each element is modified according to the
01:09:03.760 | following formula so the ith element of the output vector to which we are applying the softmax
01:09:08.640 | is equal to the exponential of the ith element of the input vector minus the maximum element
01:09:20.080 | in the input vector divided by a normalization factor that is calculated according to this
01:09:26.320 | summation that is going from j equal to 1 up to n of the exponential of xj minus x max so basically
01:09:37.920 | we are doing the exponential of each element minus this x max and why if you remember correctly
01:09:44.320 | are we subtracting this x max to make this exponential numerically safe to compute because
01:09:50.480 | otherwise it will explode and because we are applying it to the numerator we also need to
01:09:54.880 | apply to the denominator okay the softmax star operation is exactly like the softmax but without
01:10:01.520 | the normalization part which means that it's just the numerator of the softmax so we will modify
01:10:06.640 | each element of the vector to which we apply the softmax star according to this formula
01:10:11.680 | let me move it to be more aligned like this so we just do an element wise operation that is the
01:10:19.040 | exponential of each element minus the maximum of the vector to which we are applying softmax star.
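In code, the difference between the softmax and this "softmax star" might be sketched like this (NumPy, with softmax_star as an illustrative name):

```python
import numpy as np

def softmax(x):
    # numerically safe softmax: subtract the row max before exponentiating
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def softmax_star(x):
    # "softmax*": the same exponentials, but WITHOUT dividing by the normalization factor
    return np.exp(x - x.max(axis=-1, keepdims=True))

x = np.random.randn(2, 4)
l = softmax_star(x).sum(axis=-1, keepdims=True)     # the missing normalization factor
assert np.allclose(softmax_star(x) / l, softmax(x)) # dividing by l recovers the real softmax
```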
01:10:24.480 | okay now why did i introduce this softmax star operation because we will be applying it to the
01:10:32.240 | matrix that we have computed so far which is this s matrix so we apply the softmax star to each
01:10:39.200 | element of this s matrix but each element of this s matrix is itself a matrix because it's a block
01:10:47.040 | matrix and each element of this s matrix so for example the element s11 is a two by two matrix
01:10:53.600 | because it is coming from the product of two matrices which are a group of rows and a group
01:10:58.880 | of columns from the q and the k so for example this s11 is what is let's draw it actually this
01:11:07.040 | s11 will be for example made up of four elements
01:11:15.120 | let's give them generic names and call them a
01:11:21.120 | b c and d when we apply the softmax star to this s11
01:11:33.520 | it will result in a matrix
01:11:43.360 | where each element is the exponential of that element minus the maximum of its row now we
01:11:52.480 | don't know which is the maximum so let's choose one suppose that the maximum for this row is a
01:11:56.880 | and the maximum for this row is d the first element of the output of this softmax star applied
01:12:03.760 | to this block s11 will be the exponential of a minus a because that's what we chose as the maximum
01:12:13.440 | for this row the second element will be the exponential of b minus a because it's the maximum
01:12:19.760 | for that row then in the bottom row it will be the exponential of c minus d because that's the
01:12:28.160 | maximum for the bottom row and this will be the exponential of d minus d and that's
01:12:34.400 | how the softmax star will modify each block in this block matrix let me delete this stuff
01:12:41.920 | otherwise it will remain in my slides forever and later i want to share the slides with you guys so
01:12:46.720 | you can use my same slides so delete delete okay after we have applied the softmax star to each of the
01:12:56.080 | elements in this s matrix we will call the result the p matrix and each element p11 will again be a block
01:13:03.920 | of two by two elements so p11 will be the softmax star applied to s11
01:13:15.200 | where s11 is query one times k transposed one and p12 will be the softmax star applied to
01:13:21.760 | s12 where s12 is query one multiplied by k transposed two etc etc for all the elements
01:13:29.760 | of s okay now that we have applied this softmax star operation the next operation that we should
01:13:36.480 | be doing according to the formula of the attention is the softmax of the query multiplied by the
01:13:41.600 | transpose of the keys then the result of the softmax multiplied by v i know that we didn't
01:13:46.880 | apply the real softmax we apply the softmax star which is softmax without the normalization
01:13:52.560 | later we will see how to compensate this lack of normalization because we will do it at the end
01:13:58.560 | and it's something that we can do okay so we take this p matrix which is the result of the softmax
01:14:06.480 | star applied to this s matrix and we multiply it by v what how do we do it well it's a block or
01:14:14.640 | it's a matrix made up of blocks of matrices so p11 is actually not a scalar but it's a matrix
01:14:21.920 | of two by two elements and we need to multiply it by v but we don't multiply with the original
01:14:27.440 | sequence v but with the blocked sequence v just like before where each v is not one row of v but
01:14:34.320 | it's a group of rows of v and how many rows is it it is two rows of v for now please ignore
01:14:42.800 | completely whatever i have written here because we will use it later so we need to do this product
01:14:47.440 | of this matrix here which is made up of blocks remember with this matrix here which is made up
01:14:52.880 | of blocks it is made up of four rows where each row is not really a row it is a block of rows
01:15:00.800 | and this one it is made up of four by four elements where each element is not really a
01:15:04.720 | scalar but it's a matrix so as you remember in the block matrix multiplication when the algorithm for
01:15:11.760 | computing the matrix multiplication is the same as the normal matrix multiplication except that
01:15:16.240 | we use blocks so what i am doing guys is the following operation so let's write it somewhere
01:15:22.960 | let's say o is equal to p multiplied by v okay so the first output row (a row that is not
01:15:33.520 | really a row but a block of rows) will be computed as follows the first row of this block matrix
01:15:41.200 | with the first column of this v matrix and we are treating it like a block matrix
01:15:50.800 | so it will be p11 multiplied by v1 plus p12 multiplied by v2 plus p13 multiplied by v3 plus
01:16:02.320 | p14 multiplied by v4 this will produce the first output row of o but it's not really a row because
01:16:11.920 | it's a made up of two rows so this stuff here is not one row it is two row and we can prove that
01:16:19.360 | because what is p11 p11 is let's write it somewhere so p11 is a 2x2 matrix yeah 2x2 and
01:16:30.160 | we are multiplying it with v1 which is a block of two rows of v so it is a two rows by 128
01:16:38.720 | dimensions so it is equal to 2 by 128 so this stuff here is 2 by 128 so this block here the
01:16:50.080 | output block that we're computing is a block of two rows of the output matrix that we are computing
01:16:55.120 | i know this is really difficult to follow because we are involving blocks so we need
01:17:03.120 | to visualize at the same time matrix as blocks and as the original matrix that's why i highly
01:17:09.360 | recommend you to pause the video think it through write down whatever you need to write down because
01:17:16.240 | it's not easy to follow it just by memorizing the shape so you you actually need to write down
01:17:21.600 | things anyway we are computing the first output block of the output o matrix now if you
01:17:31.120 | remember this output here should be the output of the softmax
01:17:42.560 | multiplied by v now this softmax has not been applied to the entire row of this s matrix
01:17:52.560 | basically to compute this softmax star what we did was to compute the softmax star
01:18:00.880 | at each block independently from the other blocks which means that the maximum that we are using to
01:18:07.360 | compute each softmax star is not the global maximum for the row of this s matrix but the
01:18:15.680 | local maximum of each block and this is wrong actually because when we compute the softmax
01:18:22.640 | we apply the softmax we should be using the global row i want to give you an example
01:18:30.080 | without using blocks because otherwise i think it's not easy to follow so when we do the normal
01:18:36.320 | attention so we have a query multiplied by the transpose of the keys this produces a matrix
01:18:43.120 | that is n by n so sequence by sequence let's say in this drawing it is 6 by 6
01:18:50.640 | okay this one here should be the dot product of the first
01:19:05.520 | query with the first um let me use because query one transpose the key one uh this is because
01:19:16.880 | as i said before when we do the product of two vectors we always treat them as column vectors
01:19:23.120 | so when you want to write the dot product you cannot multiply two column vectors you need to
01:19:27.840 | multiply one row vector with one column vector that's why we transpose this one if it confuses
01:19:32.400 | you you can also write q1 k1 that's totally fine it's just uh wrong from a notation point of view
01:19:38.880 | anyway the first one will be the dot product of the query one with the k1 the second element will
01:19:45.760 | be the dot product of the query one with the k2 the third will be the query one with the k3 etc etc
01:19:52.480 | etc so this is q1 with k1 q1 with k2 q1 with k3 q1 with k4 anyway
01:20:11.120 | when we do the softmax we actually calculate the maximum on this entire row however what we are
01:20:18.400 | doing is we are actually doing a block matrix multiplication and as you remember um when we do
01:20:26.800 | by blocks we are grouping together rows of queries and rows of keys and in this particular case we
01:20:35.920 | are grouping two queries together to create one uh one group of queries and two keys together to
01:20:43.920 | create one block of keys so we need another row of this one so let me add the second row which
01:20:51.200 | should be query 2 k1 query 2 k2 query 2 k3 query 2 k4
01:21:04.000 | query 2 k5 and query 2 k6 each of these blocks here is computing
01:21:16.160 | two by two elements of the original matrix if we had never applied
01:21:23.600 | the blocks so it is computing these four elements here and if we apply the softmax
01:21:32.960 | star to each of these blocks we are not using the maximum element in this row we are only using the
01:21:39.680 | maximum element in each block which means that when we use it in the downstream product with
01:21:46.320 | the v matrix we will be summing values that are wrong because each of these values here will be based
01:21:54.240 | on a maximum that is not the global maximum for this row it is the local maximum of this block
01:22:01.120 | here and this block here will use the local maximum of
01:22:08.320 | this block here and this block here will use the local maximum of this block here etc etc so
01:22:15.520 | what i'm trying to say is that when you sum p11 v1 with p12 v2 p11 may have used a local maximum
01:22:24.480 | that is different from the local maximum of p12 and p13 may have yet another
01:22:30.800 | local maximum different from those of p11 and p12 so we need to find a way to fix the maximum that was used
01:22:40.720 | to compute the exponential here with the maximum found here in case the maximum here is higher than
01:22:49.600 | the one local to p11 so if we have found for example here a maximum that is higher than the
01:22:55.760 | maximum used here then we need to fix this one and this one because the maximum in the
01:23:00.960 | softmax should be the maximum for the whole row not the one belonging to each block and this leads
01:23:07.440 | to our next step how to fix this first of all let me introduce a little pseudo code for computing
01:23:15.040 | this output matrix here which is an output block matrix and later we will use this pseudo code to
01:23:23.120 | adjust the error that we have made in some blocks in case a future block say p13 has a better
01:23:30.640 | maximum than p11 or p12 so to compute this output matrix o for example to compute
01:23:39.280 | the first row what is p11 let's go back and let me delete also this one
01:23:46.880 | it's not needed anymore p11 is the softmax star of q1 k1 p12 is the softmax star of q1 k2 p13 is
01:23:58.960 | the softmax star of q1 k3 p14 is the softmax star of q1 k4 which means that to compute this block
01:24:10.960 | here we first need to compute p11 what is p11 well p11 is the softmax star of a block of q
01:24:19.520 | and another block of k which in the case of the first row of the output matrix means that
01:24:26.320 | it is the softmax star of the query 1 with k1 the softmax star of the query 1
01:24:34.160 | with k2 the softmax star of the query 1 with k3 the softmax star of the query 1 with k4 which
01:24:39.920 | means that we need to make a for loop through all the keys while keeping the query
01:24:46.400 | fixed so to compute the first output row we need to do the softmax star to produce p11 we need to
01:24:53.680 | do the softmax star of query 1 k1 and we sum it initially to zeros because we don't we need to
01:25:02.640 | initialize our output somehow and we initialize it with zeros then we sum the next p12 which is
01:25:10.320 | the query 1 with the k2 and then we sum the next p13 which is a query 1 with the k3 etc etc that's
01:25:16.960 | why we have this inner loop here all right so however this output that we are computing is
01:25:23.280 | wrong because as i told you we have computed the softmax star using statistics (the maximum value)
01:25:30.000 | belonging to each block and not to the overall row of the original matrix
01:25:35.360 | how to fix that we actually have a tool we have seen before an algorithm called the
01:25:41.520 | online softmax i don't know if i referred to it before as the online softmax but it's called the
01:25:45.680 | online softmax that allows to fix previous iterations when we are computing the current
01:25:52.400 | iteration let's review how the online softmax works imagine we are working with a
01:26:00.560 | single vector made up of n elements what we do is a for loop where
01:26:07.600 | we compute iteratively the maximum up to the i-th element and we fix the normalization factor
01:26:15.920 | computed in previous iteration in case we found a better maximum at the current element
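As a refresher, a one-pass sketch of the online softmax for a single vector might look like this (plain Python/NumPy, not the Triton code we will write later):

```python
import numpy as np

def online_softmax(x):
    # One pass over x keeping a running max m and running normalizer l;
    # whenever a larger maximum is found, the old l is rescaled ("fixed") by exp(m - m_new).
    m = -np.inf
    l = 0.0
    for xi in x:
        m_new = max(m, xi)
        l = l * np.exp(m - m_new) + np.exp(xi - m_new)
        m = m_new
    return np.exp(x - m) / l

x = np.random.randn(16)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), ref)
```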
01:26:22.480 | if this is not clear guys go back and watch the online softmax part because this is very important
01:26:27.360 | because this is what we are going to use to fix this p11 p12 blocks in case we found better
01:26:33.280 | maximum in p13 or p14 etc so let's see how to apply this online softmax to this case here
01:26:42.240 | so that we can compute this output now you may be wondering why we are going through all this trouble
01:26:49.360 | the real reason is this first of all why did we introduce block matrix multiplication
01:26:55.920 | because we want to compute the matrix multiplication in parallel so you can think that each of these p11 p12 etc blocks
01:27:02.000 | because they are independent from each other and because each of them are using the maximum
01:27:06.720 | belonging to each block they can be computed independently from each other then however we
01:27:11.760 | need to somehow aggregate their value and to aggregate the value we need to fix the values
01:27:17.920 | that have been calculated independently because we didn't when computing values independently we
01:27:22.480 | don't have a global view we have a local view so we compute local blocks p11 p12 p13 etc etc
01:27:29.440 | and then when we aggregate these values we need to fix them so that's why we are trying to come
01:27:36.240 | up with this system of fixing values that have been calculated independently so how to fix this
01:27:44.000 | let's look at the following algorithm first of all this o block here as i said before it is a block
01:27:52.400 | of two rows where each row is made up of 128 dimensions and we have seen that before by
01:27:59.600 | checking the dimensions of p11 and v1 the result of p11 v1 which means that for each output block
01:28:06.880 | we need to take care of two maximums and two normalization factors so up to now i didn't use
01:28:13.840 | the normalization factor we said that we are applying the softmax star which is the softmax
01:28:17.600 | without the normalization but eventually we will need to compute this normalization so we want to
01:28:23.760 | create an algorithm that fixes the maximum used to compute each of this p11 and also computes
01:28:30.640 | simultaneously the normalization factor and at the end we will apply this normalization factor
01:28:36.240 | and the way we will do it is as follows we start with initializing the maximum to minus infinity
01:28:43.040 | one for each row that we are computing so our output block is made up of two rows so we need
01:28:47.840 | one maximum for the top row and one maximum for the bottom row and also the normalization factor
01:28:53.600 | which we initialize with zero because we didn't sum anything for now and the output we initialize
01:28:58.880 | it with all zeros because we didn't sum anything to this output for now to
01:29:07.040 | compute this output block here we need to go through
01:29:14.320 | all the keys to produce p11 p12 p13 p14 while the query is the query block number one
01:29:21.840 | so the first step that we do is we compute the maximum of the first block p11
01:29:30.880 | which is the row max so the maximum for each row of the block q1 k1 this is not p11 it's s1 sorry
01:29:42.400 | guys this is s11 so we compute the maximum of this one and we call it actually s1 as you can see here
01:29:50.800 | and then we can calculate p11 which is the softmax star which is the exponential of the
01:30:00.080 | query multiple query one k1 so s1 minus the maximum in the local group s1 and we add it to our output
01:30:11.920 | for now the output is initialized with zero so for now ignore this part here i will explain it later
01:30:19.440 | so for now all one should be equal only to p11 v1 now at the step number two we may find in the
01:30:29.760 | local group s12 so this one is s12 we may find a better maximum for the top row and the bottom row
01:30:39.040 | and this maximum is the m2 which may be better than the previous maximum for each of these two
01:30:46.320 | row but may also not be so we need to find a way to fix in case it's better and to not fix anything
01:30:52.720 | in case it's not better and the way we do it is this so we compute the new maximum of the current
01:30:59.200 | local row query two we calculate the p12 which is the softmax star of s2 which is s2 minus m2 which
01:31:11.200 | is the local maximum and then we need to add it to the output however in this case we may have found
01:31:18.800 | a better maximum so how to fix the o1 which only used the maximum that was local to s1 well we know
01:31:28.400 | that we can fix that by using exponentials because each of this element of o1 is just an exponential
01:31:35.920 | without the normalization because we are applying softmax star so how to fix an exponential with
01:31:41.920 | another exponential so basically we are saying that we multiply o1 which is a matrix so let me
01:31:49.920 | show you what is this matrix so o1 is a matrix made up of two rows so as you can see here i
01:31:58.160 | have the shape of o1 it's a 2x128 matrix so this is the top row so o11 o12 blah blah until o1 128
01:32:11.680 | then o21 o22 blah blah and o2128 we need to fix this value how we basically just using the
01:32:25.120 | exponential that we have used in the online softmax that we have seen before so if we multiply this
01:32:31.120 | matrix here by a diagonal matrix that is made as follows it's a diagonal matrix made up of two
01:32:39.280 | elements because the exponential of m1 minus m2 which is an element-wise exponential will be a
01:32:46.160 | vector of two elements and this basically means
01:32:53.280 | a diagonal matrix where on the diagonal we have the elements of the vector to which we are applying
01:32:58.800 | this diag operation which means that this value here will be the exponential of the first element
01:33:06.000 | of m1 minus m2 so let me write it exponential of m1 minus m2 taking the
01:33:21.840 | first element then here is a zero here will be zero and let's delete this one
01:33:28.720 | and write here the exponential of m1 minus m2 but taking the second element of this vector
01:33:38.320 | so basically this diag notation means just take the vector and distribute
01:33:44.560 | it over an n by n matrix where n is the size of the vector to which it is applied and all the other
01:33:50.720 | elements of this matrix should be zeros this is what this diag means if we do this operation here
01:33:57.040 | we will see that the output of this multiplication will fix each element of the top row using this
01:34:06.400 | exponential and the bottom row with this exponential which will basically cancel out
01:34:12.880 | this m1 that was computed in the previous iteration and introduce the m2 that we have computed in the
01:34:18.720 | current iteration in each of these elements in this o block matrix okay so this output will be
01:34:32.240 | this element will multiply this one so it will fix o11 with this factor here and o21 will
01:34:41.520 | be multiplied by zero so it will not contribute to this first output element
01:34:47.920 | so this element here will only depend on o11 fixed by the exponential of m1 minus m2 (the
01:34:56.400 | first element of this vector) and then o12 will also be fixed by this
01:35:04.240 | exponential here but not by this one and all the dimensions of the first row will be fixed by this
01:35:08.800 | exponential and all the dimensions of the second row here will be fixed by this exponential here
01:35:15.920 | this this scalar here which is the second element of the vector exp of m1 minus m2
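A tiny NumPy check of what this diagonal multiplication does (the values here are made up): left-multiplying a 2-row block by diag(exp(m1 - m2)) simply rescales each row by its own correction factor.

```python
import numpy as np

O1 = np.random.randn(2, 128)               # the partial output block (2 rows)
m1 = np.random.randn(2)                    # old row maxima (hypothetical values)
m2 = m1 + np.random.rand(2)                # new row maxima (hypothetical values)
corr = np.exp(m1 - m2)                     # one correction factor per row

# diag(corr) @ O1 rescales each ROW of O1 by its own factor:
assert np.allclose(np.diag(corr) @ O1, corr[:, None] * O1)
```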
01:35:22.240 | okay it was really challenging this one so what we are doing is we compute p12 and we fix
01:35:32.400 | all the elements in o1 by multiplying by this diagonal matrix this correction
01:35:39.520 | factor here and when we compute step 3 we will fix step 2 etc etc now let's talk
01:35:48.240 | about the normalization factor because for now we have been ignoring it the normalization factor
01:35:54.160 | is something that we can compute while computing these maximums because it is provided in the
01:36:00.560 | pseudocode of the online algorithm that we have seen before for the softmax so while computing
01:36:05.680 | the maximum we can actually compute the normalization factor by fixing the normalization
01:36:11.040 | factor of the previous iteration and this is exactly what we are doing here so at the first
01:36:16.000 | iteration we compute the normalization factor using the local maximum and at the second iteration so
01:36:22.000 | you can for now ignore uh this one because we are not fixing l0 with anything because l0 will be 0
01:36:27.840 | so we are just basically um we are just computing this summation here so l0 will be 0 so this
01:36:35.520 | factor here will be 0 um and when computing l2 so the normalization step at the second iteration
01:36:43.120 | we will fix l1 with an exponential which guess what it's exactly the same exponential that fixes
01:36:50.800 | the maximum uh the p11 so it is the previous estimation of the maximum minus the current
01:36:57.920 | estimation of the maximum plus the new uh normalization factor using the local maximum
01:37:04.320 | and we keep doing this job at the end we will obtain a correct output for this uh matrix for
01:37:13.760 | for this block here but without the normalization how to apply the normalization well the
01:37:20.400 | normalization is something that is we need to divide each element of this o by the normalization
01:37:27.200 | factor but because while iterating through this for loop we also calculate the
01:37:34.400 | normalization factor we keep accumulating it until we reach the end of the iteration and then
01:37:40.240 | we apply the normalization factor so we take the last output and we just divide it by l4 which is
01:37:46.240 | the normalization factor calculated at the fourth iteration and that will fix the softmax all right
01:37:53.760 | guys so now that we have derived the algorithm of how to compute this output of the attention
01:38:00.000 | blockwise while also fixing the softmax which is done independently in each single block
01:38:05.600 | we know that the normalization is done at the end and i want to also prove it when
01:38:12.560 | we introduced the algorithm that computes the softmax in an online way we proved by induction
01:38:19.120 | that this algorithm is correct so at the end of this algorithm this l of the last iteration will
01:38:25.760 | actually be the normalization factor that we can apply to get the softmax so we don't apply the
01:38:33.360 | normalization while computing this output in an online iterative way by multiplying the
01:38:39.600 | query with all the blocks of keys we apply it at the end of these four iterations and at the end of
01:38:47.840 | these four iterations we will have the last output and we also know that the last l will contain the
01:38:55.280 | exact normalization factor that we need to apply to each row because this o of four is a block
01:39:02.160 | of output rows which is if you remember from the attention mechanism each output the output
01:39:09.920 | of the attention has the same shape as the input query vector which is a sequence of tokens so this
01:39:17.280 | o is a sequence of tokens that we need to apply the normalization to and we know that the correct
01:39:23.600 | factor is l4 so let's prove this simple formula l4 is a vector that contains as many
01:39:31.200 | elements as there are rows in o4 so in this o block of rows suppose that it contains two rows
01:39:42.080 | like in the algorithm that i have described so far in which we pretend that we are grouping
01:39:46.480 | two rows of queries with two columns of keys together so the output o the block o will contain
01:39:54.960 | two rows of the output so we will have two normalization factor in this l4 vector here
01:40:01.600 | what we are doing with this formula is we are taking this l4 vector and we are creating a
01:40:08.480 | diagonal matrix with it and then we are computing the inverse of this diagonal matrix so l4 is a
01:40:16.240 | vector that contains two normalization factors so it's l i don't know let's call it l l4 element 1
01:40:24.720 | and l4 element 2 this is our l4 vector then we have o4 o4 is a matrix as you can see from the
01:40:38.160 | shape is a 2 by 128 matrix so o is let's copy it actually oh no it's not copied o4 is a matrix that
01:40:48.720 | is two rows with 128 elements so the first row with 128 dimensions and the second row with 128
01:40:58.400 | dimensions the first thing that we are doing with this l4 is we are converting it into a diagonal
01:41:04.080 | matrix which will be a diagonal matrix 2 by 2 because it contains two elements so it will become
01:41:10.240 | something like this so it will be l4 the first element of l4 0 and then 0 l4 the second element
01:41:19.520 | of this vector then we are computing the inverse of this matrix the inverse of a diagonal matrix
01:41:26.640 | is just the diagonal matrix with each element on the diagonal that becomes its reciprocal
01:41:33.760 | this is from linear algebra i'm not making this up so
01:41:38.320 | the inverse of this matrix here is equal to the same diagonal matrix but where each element is
01:41:48.720 | 1 over l4 the first element of l4 0 0 and 1 over l4 the second element of l4 and then we are
01:42:00.000 | multiplying this stuff here so let me delete some stuff so this stuff here is getting multiplied
01:42:06.880 | by o which is a matrix that is a 2 by 128 so we are doing this multiplication now
01:42:15.440 | multiply now the output of this so this is a 2 let me write it 2 by 2 multiplied by 2 by 128
01:42:27.920 | will be a matrix that is 2 by 128 where the first dimension of the first row of the output of this
01:42:40.640 | operation will be the dot product of this call this row here with the first column so basically
01:42:49.600 | we are dividing this element here by l4 the first element of l4 the second output element here will
01:42:56.240 | be the dot product of this row with the second column so we are dividing
01:43:02.160 | the second element of this input row by l4 the first element of l4 because
01:43:09.680 | all the elements of the second row will be multiplied by 0 so they will not contribute
01:43:13.440 | to this output row while for the second output row this element here will be the
01:43:18.480 | dot product of this row with the first column the first element here is multiplied by 0 so it will
01:43:25.280 | not contribute to this output so it's only the
01:43:30.800 | first element of the second row of the input matrix here that will be divided by l4's second element so basically
01:43:37.040 | this will divide all the elements in the second row and this will divide all the
01:43:42.720 | elements in the first row producing this one here which is exactly what we need to do when
01:43:47.040 | we want to normalize we need to apply this normalization factor and this should help you
01:43:52.080 | better visualize why this operation is normalizing the vectors of the output at the end and still
01:43:57.920 | obtaining the same result. now let's proceed further.
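Putting the whole derivation together, a minimal NumPy sketch of the blockwise forward pass we have built (running row max, running normalizer, correction of the previous partial output, and division by l only at the end; the 1/sqrt(d) scaling is omitted, and this is not the Triton kernel we will write later) could look like this:

```python
import numpy as np

def blockwise_attention(Q, K, V, B=2):
    # Sketch of the derived algorithm: one query block at a time (in practice these run in parallel),
    # looping over key/value blocks with a running row max m and running normalizer l.
    N, d = Q.shape
    O = np.zeros((N, d))
    for i in range(0, N, B):
        Qi = Q[i:i+B]
        m = np.full((B, 1), -np.inf)               # running row maxima
        l = np.zeros((B, 1))                       # running normalization factors
        Oi = np.zeros((B, d))
        for j in range(0, N, B):
            S = Qi @ K[j:j+B].T                    # scores block
            m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
            P = np.exp(S - m_new)                  # softmax* of the block
            alpha = np.exp(m - m_new)              # correction factor for what was accumulated so far
            l = alpha * l + P.sum(axis=-1, keepdims=True)
            Oi = alpha * Oi + P @ V[j:j+B]         # same as diag(alpha) @ Oi + P @ V_j
            m = m_new
        O[i:i+B] = Oi / l                          # equivalent to diag(l)^(-1) @ Oi, applied at the end
    return O

N, d = 8, 16
Q, K, V = np.random.randn(N, d), np.random.randn(N, d), np.random.randn(N, d)
S = Q @ K.T
P = np.exp(S - S.max(axis=-1, keepdims=True))
ref = (P / P.sum(axis=-1, keepdims=True)) @ V
assert np.allclose(blockwise_attention(Q, K, V), ref)
```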
01:44:04.400 | all right guys finally we are ready to see the flash attention forward pass by also comparing it with what we have derived so far
01:44:11.120 | so if you look at the flash attention paper first of all this is the flash attention 2 forward pass
01:44:16.720 | and later i will explain what are the differences between the flash attention 1 and the flash
01:44:20.240 | attention 2 i didn't want to jump directly to this forward pass because i believe that even if the
01:44:27.440 | derivation like the derivation was a little uh difficult to follow i believe that it gave you
01:44:33.280 | some intuition into what is happening so even if you understand 50 percent of it that's enough
01:44:37.760 | because later we will also code it and you should reach like a 90 percent of understanding so every
01:44:43.200 | time we introduce some new information it should improve your your understanding so basically in
01:44:49.600 | flash attention and flash attention 2 especially as input we have our
01:44:56.400 | query key and values which are sequences of tokens each token made up of a vector
01:45:01.760 | of lowercase d dimensions and we divide this query guess what into blocks
01:45:09.920 | into how many blocks well it depends on this parameter br which is the size of the
01:45:16.800 | query block that we want to choose so how many rows of query we want to group together into one
01:45:22.400 | block and we also do it with the k and v and we divided that into blocks of depending on this
01:45:32.080 | parameter bc then we also initialize the output which is the output that we want to produce so
01:45:38.480 | what is the flash attention computing well the flash attention is computing the following so
01:45:43.760 | it's computing the softmax of the query multiplied by the transpose of the keys divided
01:45:51.280 | by some normalization factor and multiplying that by v so that's what it's going to compute
01:45:59.920 | and it's going to compute it this way first of all there is an outer loop through the queries
01:46:05.520 | which corresponds to the same pseudo code that we have seen before because we want to compute
01:46:11.520 | each block of the output matrix in parallel with respect to the others so basically we
01:46:19.120 | want to compute this output block and this output block independently this output block here
01:46:26.480 | depends on the query one and all the keys this output block here depends on the query two and
01:46:33.040 | all the keys this output block here depends on the query three and all the keys where query one
01:46:38.560 | is not the first query but it's the first group of queries or first block of queries and query two is
01:46:44.320 | not the second row of the query matrix but it's the second block of the
01:46:49.360 | query matrix etc etc and so that's why we have this outer iteration among all the blocks because
01:46:59.760 | we want to compute all those blocks of the output matrix in parallel but to compute each of this
01:47:05.360 | output block we need to go to an iteration among all the keys that's why we have an inner loop on
01:47:11.520 | the keys and we do exactly the same operation that we have done so far by hand so first we compute
01:47:18.560 | the s matrix which is what the each block of query with the corresponding block of the keys
01:47:24.080 | then we compute the local maximum to the current s block this is the local maximum and we
01:47:32.160 | compare it with the maximum of the previous iteration because that's what we do in the
01:47:36.400 | online softmax then we compute the p block which is the softmax star of the s block
01:47:45.200 | minus the local maximum of the s block then we compute the normalization factor what is the
01:47:53.440 | normalization factor it is the summation of all the exponential of the softmax star but
01:48:01.200 | by fixing the normalization factor of the previous step and we know how to fix the
01:48:07.760 | normalization factor because we just multiply by an exponential which is the previous maximum
01:48:13.680 | minus the current maximum that's what this factor is and then we compute the output exactly using
01:48:19.520 | the same correction factor that we have seen before which is the diagonal matrix made up of
01:48:24.720 | the diagonal where on the diagonal you have the elements of this vector here which is the
01:48:30.800 | exponential of the previous maximum minus the current maximum multiplied by the output of the
01:48:36.320 | previous step because we want to fix the previous step because it was based on the previous p which
01:48:40.640 | was using the maximum of the local previous p plus the current p v which is based on the current
01:48:47.920 | local maximum and it will be fixed by the next iteration okay and at the end after we have
01:48:55.520 | gone through all the keys so we have computed the whole output block but we didn't apply the
01:49:02.000 | normalization factor and it's applied at the end because while going through each key we are
01:49:07.200 | calculating the l normalization factor for the softmax because inside of this for loop we are
01:49:13.360 | just computing the softmax star so we are not normalizing each value so at the end someone has
01:49:18.720 | to normalize it and it will be this instruction here which is use the normalization factor that
01:49:28.640 | we have computed over all the iterations and apply it to each element of O because the difference
01:49:34.240 | between the softmax star and the actual softmax is just the division by the normalization factor
01:49:42.720 | and this instruction here is actually dividing each O with the corresponding normalization
01:49:48.080 | factor one for each row of the block each row in the output block that we are computing
01:49:54.800 | later we will also see what this SRAM is and what the HBM is
01:50:03.040 | for now i just want you to concentrate on the operations that we are doing and they are exactly
01:50:08.320 | the same operations that we have done so far later we will see also why do we need to save this stuff
01:50:14.160 | here and etc etc but for now you should have enough knowledge to be able to follow what is
01:50:20.880 | written in the flash attention paper with respect to the forward pass algorithm and what
01:50:27.360 | we are doing basically is just block matrix multiplication and while computing this block
01:50:32.160 | we fix the previous block by using tricks of the exponential all right now that we have seen
01:50:39.920 | forward pass of the flash attention before we can implement it we still lack a little bit of
01:50:44.560 | knowledge because we don't know anything about the GPUs and we don't know anything about CUDA
01:50:49.360 | and we don't know anything about Triton so that's what we are going to see next all right guys it's
01:50:55.280 | time for us to explore finally the GPU and the CUDA programming model well let's start by comparing
01:51:02.240 | the CPU and the GPU and this will let us understand how CUDA works so first of all what is
01:51:07.440 | CUDA and what is the GPU the GPU is the hardware unit that we buy and CUDA is a software
01:51:13.600 | stack made by NVIDIA to write software for the GPUs that they sell AMD has its own software
01:51:21.760 | stack and other manufacturers have their own in this particular video we will be seeing examples
01:51:27.280 | of CUDA kernels but the knowledge that you will get can apply also to other GPUs now the first
01:51:33.680 | difference between a CPU and a GPU is its purpose your computer is right now running
01:51:40.240 | on a CPU and your operating system is interfacing with the CPU using the so-called scheduler
01:51:48.080 | so right now probably you are running a browser you are also running some other software on your
01:51:52.240 | computer and the scheduler is tasked with switching between them very fast on
01:51:58.720 | your CPU in such a way that it looks like to you that the processes are running concurrently
01:52:03.200 | this actually is a fake kind of parallelism unless your CPU also has multiple cores
01:52:10.560 | which nowadays CPUs do have so a CPU usually has one or multiple cores but not so many of them so
01:52:17.120 | usually have a dual core or quad core or eight core CPU and each of these cores can execute
01:52:23.680 | instructions in parallel the main purpose of the CPU is to execute
01:52:32.320 | many different tasks switching between them very fast so maybe you have a browser that is
01:52:38.240 | running a small game and then you have another movie player but then you have a word processor
01:52:43.440 | and then you maybe have some utility to download files etc so most of these
01:52:50.640 | programs are actually not compute intensive but I/O bound meaning that most of the time
01:52:55.120 | they are either waiting for the network or they are waiting for the disk and they are very different
01:53:00.640 | from each other in the purpose so the browser is completely different from a movie player
01:53:04.640 | and it's completely different from a word processor so the job of the CPU is to actually
01:53:10.720 | reduce the latencies of processing all these operations and it's highly optimized
01:53:17.040 | to optimize each of these execution units called cores which means that each core has a part
01:53:24.000 | that is tasked with understanding first of all what the next instruction to run is or with predicting
01:53:31.200 | the branch that is what the next operation may be based on the conditions that you are running for
01:53:37.920 | example if you have an if condition the branch predictor can understand what the most
01:53:42.560 | likely next instruction is and can do some optimizations also the CPU has a lot of
01:53:48.560 | caches to reduce the latencies in loading data from all the devices it can interface with it can
01:53:54.000 | interface with the RAM for sure but it can also interface with the disk it can also interface
01:53:59.680 | with some peripherals like the printer like the mouse like the keyboard etc etc on the other hand
01:54:05.680 | the GPU is not tasked to do many different things at the same time but it's tasked to do
01:54:10.960 | one thing or a few things but on a massive amount of data so the operations that we do on the GPU
01:54:18.160 | require a lot of computation and for this reason most of the area so the physical area
01:54:26.000 | of the GPU is dedicated to compute units so this green stuff that you can see here and these are
01:54:32.080 | called cores and you can see that the part that is dedicated to the control area so the part that is
01:54:38.720 | tasked with understanding what is the next instruction to run or to do some optimization
01:54:44.000 | in the program is very little you may be thinking well does that make the
01:54:50.800 | GPU slower compared to the CPU well not really because we have many more cores that
01:54:56.720 | can compensate for these higher latencies okay i can give you a lot of knowledge about the GPU
01:55:04.800 | from a theoretical point of view i think the best way to understand the CUDA programming model is
01:55:09.040 | just to jump into the code so we don't get bored okay imagine we have a very simple task we
01:55:15.280 | have two vectors a and b and we want to calculate the sum of these two vectors
01:55:22.640 | and save the result into another vector c where each item is the element wise sum of the
01:55:29.120 | corresponding item of a and b how would you proceed with this task on the CPU well you would
01:55:36.560 | do a for loop for example so for example you would make a for loop that starts from the first index
01:55:43.120 | so the index is zero and c of zero is equal to a of zero plus b of zero then c of one is equal to
01:55:50.560 | a of one plus b of one etc and you do a for loop on all the elements of this vector in the GPU we
01:55:59.600 | want to do the same operation but in parallel because we have a lot of compute units called
01:56:04.480 | cores and we want all of them to work in parallel so the first thing that we need to understand is
01:56:09.200 | how to divide the work that we are going to do into sub units of work and dedicate each core
01:56:16.640 | to one subunit one simple subdivision would be okay the first core should do this summation
01:56:22.640 | the second core should do this summation the third core should do the summation etc etc so
01:56:28.720 | imagine we have an eight-element vector: we need eight cores to do this element-wise summation.
01:56:36.640 | We will call these cores threads, because it should also remind you of the multi-threading
01:56:42.000 | that we already use in operating systems, where multiple threads work concurrently on the same
01:56:46.800 | or similar job in the GPU. Let's look at the code now. The code that I am going to show
01:56:53.520 | you is a CUDA kernel and it's written in c but you don't have to understand c and you don't have to
01:56:59.040 | understand this code what i want you to understand is the intuition behind it because later we will
01:57:03.760 | need this knowledge and convert it into triton which is python and you should already be familiar
01:57:08.000 | with python so let's go to the code and i have a very simple vector addition we can see it here
01:57:16.240 | Okay, first of all, how do we do a vector summation? Usually the GPU is interfaced with a CPU,
01:57:26.000 | and the CPU has to first tell the GPU what data it is going to work with: the CPU
01:57:33.600 | holds these vectors and needs to transfer them to the GPU, then the GPU does the
01:57:39.920 | vector summation, then the CPU has to copy the output back from the GPU to
01:57:45.920 | the CPU and make it available to the program. This is what we are going to do here: we are
01:57:52.080 | going to allocate three vectors of size n, one called a, one called b, and one for the output.
01:58:00.560 | We initialize their items randomly, so a of i is a random number between 0 and 100 (excluded),
01:58:09.440 | then we allocate memory on the GPU to hold these vectors and we copy them to the GPU: we
01:58:16.800 | copy the a vector to the GPU and the b vector to the GPU. Of course we don't copy the output vector, because
01:58:21.760 | that's what we want the GPU to populate: its current content is just random values,
01:58:28.000 | so we only allocate space for it on the GPU. Then what we do
01:58:36.400 | is launch the kernel. Launching the kernel means that we launch a program that the GPU should
01:58:41.680 | execute in parallel on multiple threads or multiple cores, where each of these threads should do a unit of
01:58:48.000 | work that is independent from the others (actually they can be dependent on the
01:58:52.800 | others, but we will not be talking about synchronization). So we launch this kernel, and
01:58:58.720 | what we are saying in this line is launch one block of threads and later we will see what are
01:59:04.960 | blocks, but you can ignore this one for now. What we are saying here is: launch n
01:59:10.320 | threads, so n parallel operations, with the following arguments: the output where we want
01:59:18.160 | to save the data, the input array a, the input array b, and the number of elements. Let's see what happens
01:59:25.760 | inside of this method. This method follows a syntax that is CUDA
01:59:33.360 | specific: this __global__ keyword is an addition, because CUDA C is like a superset of the C language with
01:59:39.120 | some extra keywords that belong to CUDA, so it's not really C, it's CUDA C. It's a very
01:59:48.080 | simple method, as you can see, and the first thing to understand is that CUDA cannot know what each
01:59:55.760 | thread should do: we have to tell each thread what to do, so the mapping between the data and
02:00:02.400 | what each thread should do is up to us as software engineers. What CUDA will do, when we
02:00:08.880 | ask it to launch n threads in parallel, is allocate n threads and assign a unique identifier
02:00:15.600 | to each of these threads. In our simple case we can see it like this: imagine we have a vector of eight elements;
02:00:23.120 | it will assign the first thread the index zero
02:00:29.760 | (here on the slide I wrote one, but that's wrong, so let me fix the numbering),
02:00:36.480 | so this will actually be thread zero, this will be thread one, this will be thread two,
02:00:41.120 | thread three, thread four, thread five, thread six and thread seven. So let me delete this one
02:00:48.080 | so we don't get confused.
02:00:51.680 | and what we are saying is that the item that each thread should process is equal to its thread
02:01:08.160 | index so this is the thread zero so it should process the item with index zero this is the
02:01:14.240 | thread one and it should process the item with index one this is the thread number two and it
02:01:19.040 | should process the item with index two and this is what we are doing in this line of code we are
02:01:24.480 | saying which item each thread should process, which is exactly the thread identifier, the thread id.
02:01:33.120 | Later we will see why we have this .x, but that's for later. The next thing that you should
02:01:40.080 | see is that we are setting the output at the i-th position equal to the a vector at the i-th
02:01:48.240 | position plus the b vector at the i-th position, so it's a very simple element-wise summation.
02:01:52.480 | You may have noticed this if statement. Why do we need an if statement if we already know that we
02:01:58.640 | are going to launch eight threads? We already know that we are
02:02:07.040 | going to launch n threads, so i should of course be less than n, because each thread id will be
02:02:12.160 | between zero and n minus one. So why do we need this if condition? It is needed because when
02:02:18.640 | CUDA launches a number of threads, this number of threads is always a multiple of a
02:02:26.720 | base unit, which is 32 in the case of CUDA. So if we have, say, 34 elements in a vector
02:02:34.080 | and we ask CUDA to launch 34 threads, CUDA will not launch exactly 34: it will effectively run 64 threads,
02:02:40.960 | a multiple of 32, which is the warp size by the way. So what we need to do is
02:02:49.840 | ask only the threads that have a corresponding element
02:02:58.240 | to do the work, and all the others, for which the vector
02:03:03.120 | is not large enough, should do nothing, so they should not enter this if statement.
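To make this mapping concrete, here is a tiny Python sketch (illustrative only, not real CUDA code) of what the kernel logically does: one virtual thread per element, each computing its own index i and skipping the work when i falls outside the vector.

```python
# Toy simulation of the CUDA vector-add kernel: one "thread" per element,
# plus the bounds check that idles the extra threads in the last warp.
WARP_SIZE = 32

def launch_vector_add(a, b, n):
    c = [0.0] * n
    # Hardware schedules threads in whole warps, so round n up to a multiple of 32.
    num_threads = ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE
    for thread_id in range(num_threads):   # in CUDA these run in parallel
        i = thread_id                       # i = threadIdx.x
        if i < n:                           # guard: extra threads do nothing
            c[i] = a[i] + b[i]
    return c

a = [float(x) for x in range(34)]
b = [10.0] * 34
print(launch_vector_add(a, b, 34)[:5])      # [10.0, 11.0, 12.0, 13.0, 14.0]
```

In real CUDA the loop body runs in parallel, one iteration per hardware thread; the guard is what keeps the extra threads of the last warp harmless.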
02:03:08.400 | There is another thing that we should learn about the threads:
02:03:16.320 | in the CUDA programming model (but I believe
02:03:21.360 | also in other GPUs) a group of 32 threads is called a warp, and these 32 threads
02:03:30.480 | share the same control unit. So let's go back to the slide: as you can see here we
02:03:39.600 | have this yellow unit here in the GPU, and a group of threads will share the same control unit. What
02:03:45.280 | is this control unit? It's a part of the hardware of the GPU that is tasked with
02:03:50.960 | understanding what the next instruction to run is. Now, if a group of threads is sharing the same
02:03:57.120 | control unit, it means that this group of threads will always execute the same statement at any time:
02:04:03.520 | they will always work in synchrony, always on the same instruction. It cannot
02:04:10.000 | be that one thread is working on one instruction and another one is working on another instruction.
02:04:14.480 | What does this mean at the programming level? When we launch a group of threads,
02:04:21.200 | CUDA will of course spawn more threads than we need if the number of elements of our
02:04:26.720 | vector is not a multiple of 32. This means that the threads will first execute
02:04:33.520 | this operation, and each of them will have its own value of this thread id, so they will execute the
02:04:40.640 | same instruction but the data at each instruction may be different, because each of them has its
02:04:45.840 | own registers. For example, they will all reach this statement here,
02:04:52.560 | and the first thread will have i equal to zero, the second thread will have i equal to one, etc.,
02:04:57.200 | even if they are executing the same instruction. This programming model is called single
02:05:01.760 | instruction multiple data; CUDA likes to call it single instruction multiple thread. It doesn't matter
02:05:07.120 | for us: it just means that they will always execute the same instruction, but the value of the
02:05:12.480 | variables may be different. Then, after executing this statement, they will reach this statement
02:05:19.040 | here, the if statement, and of course some of them will evaluate its condition to true and some of
02:05:24.640 | them will evaluate it to false, which also means that some of them should enter this
02:05:30.480 | if statement and some of them should not. However, because the control
02:05:35.920 | unit is the same for all of them, they will all be forced to step through this if statement even if they
02:05:41.040 | should not. So how does CUDA manage this control divergence? It basically works like
02:05:47.200 | this: all the threads for which the condition is true will enter the if and will
02:05:52.960 | execute the instructions inside of it, and all the threads for which the condition is false
02:06:00.800 | will step through the body of the if as well, because they cannot not
02:06:06.080 | step through it (they must always be executing the same instruction at any time), but they will
02:06:10.720 | just not perform any operation inside of it: they will just sit idle. This is called
02:06:18.000 | control divergence, and it can reduce the throughput of your program, so you want to
02:06:23.920 | minimize it. But you may be wondering: why doesn't the GPU dedicate a control unit to each core, so
02:06:31.360 | that they can work independently from each other? Because the control unit is expensive to add to
02:06:35.600 | the chip area of the GPU: it's much more efficient to add more workers instead of adding a
02:06:42.000 | control unit for each worker. So this is a design choice of the GPU, and it works fine.
02:06:48.560 | okay now that we have seen how a kernel works let's move forward to another example
02:06:54.880 | All right, the next example that we are going to see is the same as before,
02:07:02.000 | so we are going to do a vector addition, but imagine that we have a very large vector, say
02:07:07.360 | a vector with 1 million elements. Of course we could do like before and launch a kernel with
02:07:13.760 | 1 million threads, but the problem is that CUDA will reject it: the GPU does not have 1 million threads
02:07:20.240 | to run in parallel. So how can we proceed in this case? Usually we are working with very
02:07:25.040 | big matrices or very big vectors, so we need to process a massive amount of data. How do we manage
02:07:31.680 | a parallel computation when we do not have enough compute cores?
02:07:40.400 | One way is to divide the input vector into blocks of elements. For example,
02:07:48.000 | imagine our GPU only has 32 cores in total: we may divide our input vector into
02:07:56.480 | blocks of size 32 such that the first 32 element are the first block the next 32 element are the
02:08:02.800 | second block the third 32 element the third block and the last 32 element are the last block
02:08:08.880 | in this way we can ask the gpu to work on one block at a time so we can say okay work on the
02:08:16.560 | first block and after it has processed the first block it can work on the second block and then
02:08:22.640 | the third block and the fourth block this also allows the gpu itself to manage a subunit of
02:08:29.200 | work, because imagine now that we have blocks of 32 elements but a GPU with 64 cores:
02:08:36.320 | the GPU can then schedule two blocks at the same time, because it has enough cores.
02:08:42.480 | So we need to increase the granularity of our
02:08:49.040 | data to let the GPU decide how many blocks to schedule; this is the reason we introduce blocks
02:08:56.080 | in CUDA. Let me make a concrete example, with a very simple assumption: imagine our
02:09:02.400 | GPU only has four cores, and we have n equal to eight elements,
02:09:10.480 | so eight elements and four cores in total. What we can do, for example, is divide this
02:09:20.320 | vector into groups of four elements, or even fewer, let's say two elements at a time. So this
02:09:27.120 | is the block number one this is the block number two this is the block number three and this is
02:09:34.000 | the block number four we can ask CUDA to launch a kernel that is made up of four blocks and where
02:09:43.120 | each block is made up of two threads. So when we launch the CUDA kernel (we can show the code now),
02:09:51.040 | the first value inside these triple angle brackets tells CUDA how many blocks
02:10:00.240 | we want, and the second value tells how many threads we want for each
02:10:09.920 | block. In our case we want n divided by the block size number of blocks, where the block size in my
02:10:18.560 | picture is two. So how many blocks will we have? The number of blocks
02:10:30.640 | is n divided by two, where two is the block size, and this will be equal
02:10:40.800 | to four blocks, each of size two. This is what we are doing here: we are saying that
02:10:48.880 | the number of blocks is the ceiling (because n may not be a multiple of the block size)
02:10:55.200 | of n divided by the block size, and this tells how many blocks we have. This will
02:11:01.280 | define our grid: the grid basically tells how many blocks we have,
02:11:05.920 | and then each block is made up of block size number of threads. Then the problem is: how do
02:11:12.240 | we assign the work to do to each of these threads when we launch a kernel like this
02:11:18.800 | with this configuration so the number of blocks and the number of threads per block CUDA will do
02:11:24.720 | the following job: it will assign each block an index called the block id, where the block
02:11:34.880 | id of the first block is zero so let me write here so this will have the first block will have a
02:11:40.560 | block id equal to zero and in each block it will assign a thread id and the thread id of the first
02:11:49.920 | thread of each block will be the thread zero and the second thread will be the thread number one
02:11:55.440 | the second block will have a block id block id equal to one and the first thread of this block
02:12:03.600 | will be the thread number zero and the second thread of this block will be the thread number one
02:12:07.840 | the third block will have a block id block id equal to two and the first thread will be the
02:12:16.160 | thread number zero and the second thread will be thread number one etc until the last block which
02:12:21.760 | will be equal to three this will be thread number zero and thread number one the problem is now
02:12:27.280 | based only on the index of the block and the index of the thread how can we map it to what element of
02:12:34.480 | the vector each thread should work with? One simple assignment would be the following: you can see
02:12:41.920 | that in this case we need this thread here to work with element zero, this one
02:12:49.200 | should work with element one, this one with element number two, this one with element
02:12:54.560 | number three, this one with four, this one five, six and seven.
02:13:05.280 | How can we find the mapping, given only the block id and the thread id? How can we find which
02:13:13.200 | element each thread should correspond to? Well, it's a very simple formula: the element
02:13:18.080 | id, which in the code I call i, is equal to the block id
02:13:28.640 | multiplied by the size of each block, which is the block size,
02:13:33.200 | plus the thread id,
02:13:40.400 | because in the case of the first thread this will be equal to zero multiplied by two plus zero which
02:13:49.280 | is zero in this case it will be equal to zero multiplied by two which is zero plus one and
02:13:54.640 | it will be equal to one; in this case it will be equal to two, because the block id is equal to one, and one
02:13:59.600 | multiplied by two is two, plus zero is equal to two, etc. And you can see that this
02:14:04.400 | formula works for all the threads. So, when we launch a CUDA kernel, we are telling the
02:14:10.800 | GPU how many blocks we want and how many threads there are in each block, but CUDA has no way
02:14:17.760 | of knowing how to map each thread to the element it should work
02:14:28.720 | with: that's up to us, and that's what we are doing here when we are creating this kernel. We
02:14:36.560 | are telling that each thread should work with the i-th element of the vector, where i
02:14:43.360 | is calculated as follows: the block id to which this thread belongs multiplied by the block size,
02:14:48.960 | so how many threads there are in each block, plus the thread id, and this tells which element
02:14:56.400 | this particular thread should work with. Let's go back to the slides: by choosing the
02:15:06.240 | block size equal to two and having four cores, the GPU can choose to run one block or two blocks
02:15:14.000 | concurrently if it has enough free cores. That's why we want to work block by block:
02:15:20.080 | it allows the GPU to choose how it wants to parallelize the operations if it has enough
02:15:25.120 | cores, and we don't need to have n cores for an n-element vector: we can divide it into smaller
02:15:32.240 | blocks and let the GPU manage the scheduling.
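As a quick sanity check of the formula, here is a small Python sketch (again just an illustration of the index arithmetic, not CUDA) that enumerates blocks and threads and shows that block_id * BLOCK_SIZE + thread_id covers every element exactly once.

```python
# Illustrative sketch: how blockIdx.x * blockDim.x + threadIdx.x maps
# (block, thread) pairs to element indices for an 8-element vector.
N = 8
BLOCK_SIZE = 2
num_blocks = (N + BLOCK_SIZE - 1) // BLOCK_SIZE   # ceiling division -> 4 blocks

for block_id in range(num_blocks):
    for thread_id in range(BLOCK_SIZE):
        i = block_id * BLOCK_SIZE + thread_id     # element this thread handles
        if i < N:                                 # guard for a last, partial block
            print(f"block {block_id}, thread {thread_id} -> element {i}")
```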
02:15:37.360 | Let's see one last example and then we move on to Triton. Imagine now that we want to do a matrix addition instead of a vector addition.
02:15:43.360 | now in a matrix addition we have data that we can see on two axes one is the rows and one is
02:15:51.200 | the columns. Usually we represent the vertical axis as the y-axis and the horizontal axis as the
02:16:00.080 | x-axis. By using the same blocked intuition that we used before, so dividing the input data
02:16:09.760 | into blocks, this is how we can divide the labor of our matrix addition into blocks: for example,
02:16:16.880 | we can divide our rows into blocks and call this one block zero, this one block one,
02:16:23.120 | and this one block two. The same we can do on the x-axis: we can choose this one as
02:16:30.080 | block zero, this one as block one and this one as block two on the x-axis, where x is the column
02:16:36.320 | axis and the y is the row axis. We don't even have to choose the same block size for the rows and the
02:16:43.440 | columns: we could even group together three columns and two rows instead of two
02:16:50.400 | and two. In this case, as we said before, when we launch a CUDA kernel, CUDA
02:16:56.800 | will just assign ids to the blocks and to the threads in each block; then it's up to us to understand
02:17:03.760 | how to map the id of the block and its corresponding thread id to the data element
02:17:09.360 | that this particular thread should work with. So in the case of matrix addition we
02:17:15.120 | could say that each thread should work with one element of the output matrix c, so it will
02:17:21.600 | compute the sum of the corresponding a element and b element and write it to the output matrix c.
02:17:30.640 | So how do we do it? Imagine we have six rows and six columns: one easy way would be to divide
02:17:38.800 | the rows into three blocks, each made up of two rows, and the columns into three blocks,
02:17:45.920 | each made up of two columns. CUDA will launch as many blocks as there are combinations of
02:17:55.200 | the rows and column blocks so in this case we have three blocks for the columns and three
02:18:00.720 | blocks for the rows, so it will launch nine blocks. This is the block number 0,0, because
02:18:10.000 | CUDA will identify the dimensions of the block based on the axes along which we have divided the data:
02:18:17.920 | we will call the columns the x dimension and the rows the y dimension,
02:18:22.960 | so it will launch as many blocks as there are combinations of x and y; in this case we have nine.
02:18:28.240 | So this will be the block 0,0, this will be the block 0,1, this will be the block
02:18:33.040 | 0,2, this one will be the block 1,0, 1,1 and 1,2, etc. Inside of each block we will also divide the
02:18:42.480 | threads into x threads and y threads along the two dimensions, so this will be the thread 0 and
02:18:48.720 | the thread 1 along the x axis, and this will be the thread 0 and the thread 1 along the
02:18:56.080 | y axis, and each block will have two threads along each axis, and they will be
02:19:03.840 | identified as thread 0 and the thread 1 so let's look at how the launch grid works in this case
02:19:10.560 | So imagine we have a matrix with num_rows number of rows and num_cols number
02:19:19.360 | of columns, and we want to divide the rows into blocks of rows-block-size rows and the columns into
02:19:27.600 | blocks of cols-block-size columns. We define the number of blocks that we need like this:
02:19:35.440 | this is just a fancy way of writing the ceiling of num_rows divided by the rows block size,
02:19:41.440 | and this is just a fancy way of writing the ceiling of num_cols divided by the
02:19:46.800 | cols block size. This tells us how many blocks we will have on the rows and how many we will
02:19:50.800 | have on the columns the grid you can see here which tells us how many blocks we have is a tuple
02:19:57.040 | that accepts three values which tells how many blocks we want on the x dimension how many we
02:20:02.960 | want on the y dimension and how many we want on the z dimension we are not going to use the z
02:20:08.480 | dimension because we only have a matrix then inside of each block how many threads we want
02:20:14.240 | for the x dimension and for the y dimension. As the x dimension we have chosen the columns, so we
02:20:20.720 | are saying how many blocks we want for the columns and how many blocks we want for the rows, and then,
02:20:25.440 | inside of each block, how many threads we want for the column block and how many threads we want for
02:20:30.720 | the row block. This will define our launch grid, and what CUDA will do is just launch the
02:20:37.280 | following configuration: it will launch as many blocks as there are combinations of x and y, and
02:20:42.400 | inside of each block it will assign thread ids along both axes,
02:20:50.720 | so there will be two threads on the x-axis and two threads on the y-axis of each block.
02:20:55.840 | Now let's try to understand how to map, just based on the
02:21:02.800 | block id on the x-axis, the block id on the y-axis and the thread ids on the x and y axes, to one element of
02:21:09.840 | the output matrix. Let's look at the code. First, we can use the following formula to identify which
02:21:17.680 | row this thread should work with, because each element of a matrix is identified by
02:21:26.160 | two indices: one is the row identifier and one is the column identifier. The row identifier we can
02:21:31.680 | compute as the block id multiplied by the block size plus the thread id. Let's see why it
02:21:38.400 | makes sense. In this case, for example, this thread will work with row zero, because the
02:21:44.640 | block id on the y-axis is zero and the thread id is zero, so block id multiplied by block
02:21:51.600 | size, which is zero, plus zero will be zero, so this thread will be working with row number zero.
02:21:58.160 | And which column will it be working with? It will be working with the block id on the x-axis, which is zero, multiplied
02:22:04.000 | by the block size on the column (the block size is two, but multiplied
02:22:09.760 | by zero it will be zero), plus the thread id zero, so it will be zero. This element here:
02:22:16.560 | its row will be the block id of the y-axis multiplied by the block size plus the thread id, so it will be
02:22:24.720 | row zero, and its column will be element one. Let's see another one,
02:22:31.200 | for example this element here: which element will this thread work with?
02:22:40.000 | Well, its row will be the block id on
02:22:45.760 | the y-axis multiplied by the block size, so one multiplied by two, so our row will be
02:22:51.840 | row number two, which makes sense because this is row zero, this is row
02:22:58.160 | one and this is row two. And the column will be the block id on the x-axis, which in this case
02:23:07.040 | is equal to one, multiplied by the block size, which is equal to two, plus the thread id, which is one: two plus one is equal to three.
02:23:13.120 | So this thread here will work with the element at row two, column three, and now this formula makes sense.
02:23:20.400 | so this is how we use the block id and the thread id inside of each block to map it to which element
02:23:26.720 | this particular thread should work with. As I said before, CUDA has no way of knowing which
02:23:33.200 | element this particular thread should work with: this is up to us, based only on the block id and
02:23:38.640 | the thread id that CUDA assigns. Then we make sure that the row index is less than the number of rows
02:23:45.440 | and the column index is less than the number of columns. Why? Because, as I said before, when we launch
02:23:50.480 | blocks and threads, CUDA will round up the number of threads to a multiple of 32,
02:23:56.960 | which means that some of these threads should not work with any data. So we make sure that all
02:24:01.520 | the threads that don't have a corresponding element to work with just sit idle
02:24:06.960 | at this if statement, while the ones that do have one execute its body and do some work. So we
02:24:13.680 | calculate the index of the element of the matrix that this particular thread should work with as
02:24:20.960 | follows which is the row index multiplied by the number of columns plus the column index
02:24:26.080 | This is just another way of writing
02:24:36.160 | a[row_index][col_index], but the way we allocate arrays in C or C++ is as a flattened array where all
02:24:49.680 | the rows are one after another, so we need to identify the element inside of the array based
02:24:55.360 | on its row index and column index, and this is the formula that we use to identify it.
02:24:59.760 | If you have never worked with arrays in C++ or C it doesn't matter, because later we will see
02:25:08.800 | tensor layouts and this will be much more clear, but if you have already worked with them, then you
02:25:13.280 | already know how to index an element inside of a multi-dimensional array in C++. Then we compute
02:25:20.400 | the output as usual. I know that this has been a lot of information, so what should we
02:25:26.560 | remember from this? The first thing is that we decide how to divide
02:25:31.040 | the work on whatever matrix or whatever vector
02:25:36.720 | we are working with: we tell CUDA how many blocks we want and how many threads we want
02:25:42.000 | in each block, and, based on the block id and the thread id, we should come up with a
02:25:48.080 | strategy for how to map each thread to a subunit of work, so to which part of the matrix or which part of the
02:25:53.760 | vector that particular thread should work with.
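Here is the same recap as a Python sketch of the 2D index arithmetic (illustrative only, not CUDA): the launch grid computed with ceiling divisions, the row and column computed from block and thread ids, and the flattened row-major index.

```python
# Illustrative sketch of the 2D mapping used for matrix addition:
#   row = blockIdx.y * blockDim.y + threadIdx.y
#   col = blockIdx.x * blockDim.x + threadIdx.x
#   idx = row * num_cols + col   (row-major, flattened storage)
NUM_ROWS, NUM_COLS = 6, 6
BLOCK_ROWS, BLOCK_COLS = 2, 2
blocks_y = (NUM_ROWS + BLOCK_ROWS - 1) // BLOCK_ROWS   # ceiling division
blocks_x = (NUM_COLS + BLOCK_COLS - 1) // BLOCK_COLS

a = list(range(NUM_ROWS * NUM_COLS))
b = [10 * v for v in a]
c = [0] * (NUM_ROWS * NUM_COLS)

for by in range(blocks_y):
    for bx in range(blocks_x):
        for ty in range(BLOCK_ROWS):
            for tx in range(BLOCK_COLS):
                row = by * BLOCK_ROWS + ty
                col = bx * BLOCK_COLS + tx
                if row < NUM_ROWS and col < NUM_COLS:   # bounds guard
                    idx = row * NUM_COLS + col          # flatten (row, col)
                    c[idx] = a[idx] + b[idx]

assert c == [x + y for x, y in zip(a, b)]
```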
02:26:00.880 | Now the next step for us is to understand tensor layouts, because we are going to work with tensors and we need to understand how
02:26:06.240 | tensors are laid out in the memory of the GPU (or of the CPU as well, actually). We need to understand
02:26:13.120 | what the row-major layout and the column-major layout are, what the stride is, etc.,
02:26:18.800 | and convert all the knowledge that we have about CUDA into Triton, so that we can then code our kernel with
02:26:24.400 | Triton. So let's go. All right guys, finally it's time for us to explore tensor layouts.
02:26:32.640 | Now, why do we need to explore tensor layouts? Because before, we have seen some examples
02:26:38.480 | of CUDA kernels, and when you give a matrix or a vector to a CUDA
02:26:46.160 | kernel, CUDA will not give you the entire matrix like in Python, where you
02:26:51.440 | can access each element by its index: CUDA will just give you a pointer to the starting
02:26:57.600 | element of that particular matrix or vector, and then it's up to
02:27:03.440 | you to calculate the memory address of all the remaining elements. So suppose that we have a
02:27:09.120 | simple vector in pytorch this simple vector could be the following which is a vector of shape 7
02:27:16.480 | because it's a tensor with only one dimension with shape 7 which is the number of elements in the
02:27:21.200 | first dimension for now ignore this property called the stride and later i will explain it
02:27:26.880 | what is it how this tensor will be saved in the memory of the cpu or in the gpu it will be saved
02:27:34.080 | as follows suppose that the starting address of the first element is the address 100 and suppose
02:27:40.560 | that each element is made up of a floating point of 16 bit so it means that each element will occupy
02:27:46.640 | two bytes so the start address of the second element will be the address 102 and the third
02:27:52.480 | element will be 104 and the fourth element will be 106 etc etc etc so this is exactly what you get
02:28:00.720 | when, in C, you allocate a vector or a matrix with malloc: when you allocate
02:28:08.240 | memory with malloc in C, the memory allocator will just allocate enough memory to
02:28:15.520 | store all the elements, and it will give you a pointer to the start address of this memory;
02:28:20.400 | then it's up to you to understand where each of these elements is stored in that block of memory.
02:28:25.200 | To do this, we introduce a property called the stride. The stride tells us how many
02:28:32.640 | elements we need to skip to arrive at the next element in a particular dimension. In this case,
02:28:38.640 | for example in the case of a vector we only have one dimension which is the x dimension
02:28:43.760 | or the columns dimension you can think of it so this is the first column this is the second the
02:28:49.120 | third the fourth fifth etc etc um so in order to arrive from one element to the next we just need
02:28:55.040 | to skip one element so to go from here we need to just increase our pointer by one element and
02:29:00.320 | then to go here we need to increase the pointer again by one element, etc. This allows us to do a for loop
02:29:06.320 | on this tensor. Let's look at a more complicated case, the matrix. A matrix is two-
02:29:13.040 | dimensional, and suppose we have the following matrix, which is made up of six elements, with
02:29:18.800 | two rows and three columns, so the shape of this tensor will be two by three, because we have
02:29:24.400 | two rows and three columns. How will this matrix be saved in memory? In memory it will
02:29:31.840 | just be a flattened matrix: this is called the row-major layout (there is also
02:29:38.400 | another one, called column-major layout, that we will not be discussing). So how will it be stored
02:29:44.720 | in the memory? As follows: the elements of the
02:29:50.000 | first row, followed immediately by the elements of the second row. Imagine
02:29:56.400 | that the memory address of the first element is 62: to go to the next element
02:30:01.200 | we need to increase the memory address by the number of bytes that each element occupies, which
02:30:06.080 | is two bytes, so the address of the second element will be 64, the third element will be 66,
02:30:12.960 | and the next row will start immediately after the end of the first row let's introduce this
02:30:19.200 | property, the stride. The stride tells us how many elements we need to skip, in
02:30:24.560 | each dimension, to arrive at the next element of that dimension. For example, imagine we want to
02:30:32.320 | get all the elements of the first row.
02:30:36.640 | Let's call this tensor t, so we index t[0]: this indexing
02:30:49.040 | says: in the first dimension select only the first
02:30:54.400 | row, and give me all the elements of that row. How does this indexing work? Starting
02:31:01.040 | from the pointer to the first element, it will select only the first row, and then it will move
02:31:08.560 | the index one element at a time, so it will select the first one, the second one, the third
02:31:15.760 | one. How does it know that it needs to move one element at a time? Because in this dimension
02:31:20.560 | the stride is one, and the stride tells us how many elements we need to skip to arrive at the next
02:31:25.920 | element in that dimension. Imagine now that we want to get
02:31:36.240 | t[1], so the row with index one and all of its elements. First of all
02:31:44.560 | it needs to skip some elements in the first dimension: it needs to skip row zero,
02:31:49.680 | because we are not selecting it; we only want to select the element with index one of the first
02:31:54.080 | dimension, which basically means the row with index one. Because it starts from the
02:32:00.640 | pointer to the first element, it needs to know how many elements to skip, and how many
02:32:06.160 | elements to skip is given by the stride: the stride tells us how many elements we need to
02:32:10.480 | skip to arrive at the next element of the first dimension. So in this case it will take the pointer
02:32:15.520 | to the first element, skip three elements, and it will be at the start of the second row; then,
02:32:20.640 | inside this row, it will go through the index of the second dimension, in which
02:32:26.160 | the stride is one, so it will just go one element after another, and it will return only this part of the
02:32:31.040 | memory. So, to rehearse: the stride is just a number that tells us how many elements we need to skip
02:32:40.880 | in each dimension to arrive at the next index in that dimension. It means that to go from one
02:32:46.880 | row to the next we need to skip three elements, and to go from one column to the next we need to skip
02:32:51.200 | one element. Why is the stride useful? Well, the stride is useful because it allows us to reshape
02:32:59.360 | tensors very easily and without doing any computation. Let's see. Imagine we want to
02:33:06.080 | reshape a matrix, and initially the shape of this matrix is two by three, so we have two
02:33:11.680 | rows by three columns, and the stride is calculated as follows: it means that to go from one row to the
02:33:16.800 | next you need to skip three elements,
02:33:20.960 | and to go from one column to the next you need to skip one element, so you
02:33:27.040 | jump by one element. If we want to reshape it into the shape three by two, so we want
02:33:33.600 | to have three rows and two columns we can reshape it without actually changing its memory layout
02:33:45.760 | just by changing the stride because look at this physical configuration of the tensor and we can
02:33:53.280 | access this same tensor as this shape or as this shape exactly by using the same physical view
02:34:00.160 | because to go from one row to the next here the stride is a three so we need to skip three
02:34:06.400 | elements it means that the starting address the starting element of the second row is given by
02:34:12.720 | the start pointer plus three elements so exactly here the second row will start and each element
02:34:19.760 | of the second row is one after another because the stride of the second dimension is one so you can
02:34:26.000 | see that to get the second row we can just start from here and then go one after another and get
02:34:32.080 | all these elements which is exactly the second row suppose we want to obtain the second row of
02:34:37.120 | this view here of this shape of this reshaped matrix how to do that let's look at the stride
02:34:43.360 | the stride now is a two in the row it means that to go from one row to the next we need to skip
02:34:48.640 | two elements so if we want to select this row here we go from the starting point of the memory
02:34:55.360 | so this start pointer we skip the first two elements because the stride says that to go
02:35:01.920 | from one row to the next you need to skip two elements so we arrive here and then we select
02:35:06.480 | exactly two elements which are one after another because the stride in the second dimension is one
02:35:11.360 | So the stride allows us to reshape the tensor without changing the physical layout of how it
02:35:21.440 | is stored in the memory. Moreover, the stride also allows us to get the transpose of a matrix without
02:35:29.280 | changing how it is stored in the memory, so without changing the arrangement of
02:35:33.600 | the elements in the memory, and this is very cool because we can view the same matrix
02:35:39.760 | without the transpose and also as the transposed version without changing anything in the memory:
02:35:44.400 | so it comes for free just by working with the index and the stride so to transpose the matrix
02:35:50.320 | along two dimensions we just need to swap the stride along these two dimensions that we want
02:35:54.480 | to transpose so in this case for example imagine we want to get the transpose of this matrix
02:35:59.440 | we just need to swap the strides. So if we want to get the second row of the transposed matrix,
02:36:04.880 | how do we get it? Well, we always have the pointer to the first element where the tensor
02:36:11.280 | is stored, so to the beginning of where the tensor is stored in memory, and the stride says that in order
02:36:18.240 | to go from one row to the next we need to skip one element, which is correct because, as you can
02:36:25.680 | see, the second row of the transpose starts exactly at the second element in memory, so we just skip by one
02:36:31.280 | and we get the starting point of the second row; then, to go from one element to the next
02:36:38.080 | within the same row, we need to skip three elements, so the second element of the second row will be
02:36:44.480 | three elements after the first element of the second row: so after the two we skip three
02:36:50.880 | elements, we skip this one, we skip this one, and we arrive at this eight, which is exactly the
02:36:55.600 | second column of the second row. So basically the stride, as you can see, allows us
02:37:02.160 | to do two things: one, it allows us to reshape the tensor without having to reallocate it in
02:37:08.880 | another configuration in memory; secondly, it allows us to transpose a matrix without having to
02:37:14.400 | rearrange the elements in memory, which is great, because moving memory around is expensive
02:37:19.040 | and rearranging memory is expensive, so it's great that this stuff comes for free.
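You can check all of this directly in PyTorch; a minimal sketch (the printed numbers assume the 2x3 float16 tensor from the example):

```python
import torch

t = torch.arange(6, dtype=torch.float16).reshape(2, 3)
print(t.shape, t.stride())            # torch.Size([2, 3]) (3, 1): skip 3 elements per row, 1 per column

v = t.view(3, 2)                      # reshape: same memory, only shape and stride change
print(v.shape, v.stride())            # torch.Size([3, 2]) (2, 1)
print(t.data_ptr() == v.data_ptr())   # True: no data was moved

tt = t.t()                            # transpose: just swaps the two strides
print(tt.shape, tt.stride())          # torch.Size([3, 2]) (1, 3)
print(t.data_ptr() == tt.data_ptr())  # True: still the same memory
```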
02:37:25.360 | Another thing: you know that in PyTorch there are two
02:37:33.760 | methods to reshape a tensor, one called the reshape method and one called the view method.
02:37:39.280 | After transposing a matrix by swapping the strides of the two dimensions that you
02:37:47.760 | want to transpose, you cannot reshape the tensor for free anymore. Why? It comes down to how
02:37:56.240 | the stride is computed. Let me show you with a
02:38:02.880 | concrete example: the stride of a dimension is just the product of the shape of all the
02:38:11.840 | following dimensions, so the stride of the zeroth dimension is just the product of the elements of the shape
02:38:18.400 | starting from index number one. It's not easy to see with a 2D matrix
02:38:23.520 | because we don't have enough elements,
02:38:28.880 | so let's do it with a 3D tensor. This is a tensor with three dimensions, so it has a shape
02:38:35.360 | of (2, 4, 3), which means that we have two matrices, each made up of four rows and
02:38:40.960 | three columns. The stride is calculated as follows: the stride of the zeroth dimension is just
02:38:49.280 | the product of four by three, and this three here comes from the product of three with
02:38:55.360 | one, because there is no further dimension after the three. When we transpose,
02:39:00.960 | this stride property is lost, and after transposing this tensor by swapping the strides we
02:39:09.840 | cannot do further reshaping operations for free: basically the tensor is no longer contiguous.
02:39:16.560 | This is a somewhat advanced property, and it doesn't matter if you know it or not, but if you
02:39:22.080 | are curious: in PyTorch you cannot view a tensor after it has been transposed,
02:39:29.600 | because PyTorch, to transpose a tensor, will just swap the two strides, which breaks the stride
02:39:35.440 | property: the stride will no longer be the product of the following shapes, so this
02:39:42.560 | would have to be two, for example, and this would have to be one, but after transposing this
02:39:50.160 | property is lost, so you need to actually reallocate the tensor if you want to reshape it after it has
02:39:55.360 | been transposed. It doesn't matter if you remember this, it's just a curiosity.
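A quick PyTorch check of this behaviour (a minimal sketch; the exact error message wording may differ between versions):

```python
import torch

t = torch.arange(6).reshape(2, 3)
tt = t.t()                      # strides become (1, 3): not contiguous any more
print(tt.is_contiguous())       # False

try:
    tt.view(6)                  # view needs a compatible (contiguous) layout
except RuntimeError as e:
    print("view failed:", type(e).__name__)

flat = tt.contiguous().view(6)  # .contiguous() copies into a fresh row-major buffer
# or: tt.reshape(6), which copies only when it has to
print(flat)                     # tensor([0, 3, 1, 4, 2, 5])
```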
02:40:01.200 | Anyway, what is the stride used for? The stride is used for three things:
02:40:05.680 | first of all, it is used to understand how to index the tensor, so just by having a pointer to the
02:40:12.400 | starting address of this tensor we can index it however we like, so we can
02:40:19.440 | access any row and any column; moreover, it allows us to reshape the tensor for free, so without
02:40:26.880 | rearranging the elements in memory; and third, it allows us to transpose the tensor however
02:40:33.280 | we like just by swapping the strides of the two dimensions that we want to transpose.
02:40:38.400 | Now that we have also seen how a tensor is stored in memory, we can finally go see
02:40:44.880 | Triton and some examples. All right guys, now that we have seen how tensors work,
02:40:53.440 | how tensor layouts work and how CUDA works, we can see some examples of Triton kernels to see how Triton
02:40:59.040 | differs from CUDA. If you go on the Triton website you will find some tutorials, like in
02:41:07.600 | this section here, so let's work through one tutorial together to understand how Triton is
02:41:13.280 | different from CUDA. If you go to the tutorials there are many examples. First of all, the code
02:41:19.200 | that i will be coding for flash attention is based on this tutorial here fused attention
02:41:23.040 | that you can see here but with some modifications because i simplified the code a lot i removed for
02:41:28.160 | example, the fp8 implementation. Also, this code here on the fused attention only
02:41:34.800 | works, in the backward pass, for causal attention, while my code will work for causal
02:41:39.360 | and non-causal attention. Another modification I did is that, instead of using the
02:41:44.720 | base-2 exponential that they use here to make things faster (because the base-2 exponential is
02:41:49.360 | implemented with a faster hardware unit), I use the original implementation of Flash Attention, which
02:41:56.800 | uses the exponential with base e, etc. So I simplified my code as much as possible to make
02:42:02.800 | it simple to follow instead of making it optimized, so for sure my code will be slower than the
02:42:08.240 | fused attention that you see here, but mine should be more comprehensible, easier to follow. Anyway,
02:42:14.880 | let's go to the vector addition tutorial and if you go to the vector addition tutorial there are
02:42:20.080 | some examples on how to do a vector addition with triton this should allow you to get into the
02:42:25.280 | mindset of how to write kernels with triton instead of writing first the kernel and then calling it
02:42:31.920 | let's do the opposite so let's see how to call this kernel and let's explore how it works so i
02:42:37.360 | have already copied the tutorial vector addition from the website so let's look at first of all
02:42:42.960 | what we want to achieve we have an input vector x and an input vector y and we want to compute
02:42:49.920 | the vector addition which means that with the torch we want to do the following operation
02:42:54.160 | and also we want to do the same operation also with the triton by calling this method add and
02:42:59.200 | then we want to compare the two vectors output and they should be equal or at least the difference
02:43:04.480 | should be very very small because of course there is always some rounding error in case you are
02:43:08.160 | working with floating point numbers the size of this vector is 98 000 elements and we want to
02:43:15.680 | work in a blocked way. As you remember, before, with CUDA, you can do vector addition by
02:43:21.920 | spawning a large number of threads, each doing one operation, but when the number of threads
02:43:27.360 | that you have is not enough, then you need to divide the input vector into blocks, and this
02:43:32.000 | is what we are going to do here. So let's look at this add method. This add method basically
02:43:37.600 | will first allocate the necessary memory for the output vector, then it will compute the
02:43:45.360 | launch grid. The launch grid tells Triton, just like in CUDA, how
02:43:52.240 | many blocks of threads we want to launch. If you remember, in the
02:43:58.960 | CUDA kernel we specify how many blocks we want and then how many threads we want for each block;
02:44:05.360 | in the case of Triton we tell how many blocks we want, but we don't force how many threads to
02:44:15.360 | launch: it will be Triton that chooses how many threads to launch, and we just tell what each
02:44:22.320 | group of threads should do. In this case, for example, we divide our number of elements n,
02:44:29.600 | which is 98,000, into blocks of size BLOCK_SIZE, which is initialized as 1024. To
02:44:39.280 | calculate the grid size you do the ceiling division, so basically this means
02:44:45.840 | the ceiling of n elements divided by the block size: this is the meaning of this expression, so how many
02:44:56.160 | blocks we want. Now, what each block should do is inside of the kernel, so let's go to the kernel.
02:45:03.280 | When we launch the kernel, we specify the launch grid inside the square brackets, and
02:45:09.120 | then, in the round parentheses, we specify the arguments of this kernel. So let's go to the kernel:
02:45:14.960 | we see that Triton will not give us access to the tensor x; it will give us a pointer to the
02:45:25.040 | first element of this tensor, and this takes us back to tensor layouts. The reason we studied
02:45:30.560 | tensor layouts and strides and all that stuff is that this add kernel
02:45:37.840 | will run on the GPU, and the GPU does not index tensors like PyTorch, with all the
02:45:46.400 | dimensions and the broadcasting and all this fancy stuff: the GPU will just give you the pointer
02:45:52.800 | to the first element of this tensor in memory, and then it's up to you to compute the indexes
02:45:58.960 | of all the elements that you want to access so this x ptr is the pointer to the first element
02:46:05.280 | of the x vector this y pointer is the first the pointer to the first element of the y
02:46:10.960 | vector then we have the pointer to the output vector where we want to store the result of this
02:46:17.120 | matrix addition we specify how many elements our vectors have and what is the block size so
02:46:23.120 | how many items each block should process which may not correspond to how many threads each
02:46:30.480 | each kernel will have you may be confused because okay in triton in coda we specified how many
02:46:39.360 | threads each block should have so the granularity that we manage is the thread level
02:46:46.480 | here we are saying it's a group of thread that should work with this quantity of data then it's
02:46:52.480 | up to triton to optimize the number of threads that it will actually use actually there are
02:46:57.600 | tricks there are ways to say how many threads we actually want by specifying the number of words
02:47:01.840 | but we will see that later for now just remember that this thread this kernel here will process a
02:47:08.720 | number of elements in the input vectors how many number how many elements block size number of
02:47:14.800 | elements first of all we need to identify which block we are we are in coda we use the the variable
02:47:24.320 | called the block id.x to identify the identifier of the block which tells us which group of elements
02:47:30.320 | we should be working with in triton you do the same by using program id and in coda the block
02:47:38.640 | id can be along the x y and z axis in triton these are called the dimension 0 1 and 2 here we have
02:47:47.200 | one dimensional data so we only use one axis to specify the block index so we get the block index
02:47:54.160 | which is the p id in this day in triton this is called the program id it's more intuitive to think
02:48:00.800 | of it as the program like this is a kind of a program that is running in parallel with other
02:48:05.360 | programs that will have different program id and based on the program id we can understand
02:48:11.120 | what is the starting element this program should work with so this blue block of threads should
02:48:16.240 | work with them and together that is just the p id multiplied by the block size so the p id 0 should
02:48:22.080 | be working with the element that starts from the element 0 the p id 1 should start with the element
02:48:27.680 | 1024 and the p id 2 should start from the element 2048 so it should skip the first 2048 elements
02:48:34.960 | and start with the element with index 2048 next we define how to load these elements
02:48:42.400 | based on the pointer in which of the x and the y vector to do that we specify a list of offsets
02:48:53.280 | with respect to the starting address that we want to load so because each program in triton works
02:48:59.680 | with a group of data so not one single element but a block of elements we mean we need to understand
02:49:08.480 | which elements to load so the offset of these elements in the case of the program id 0 it will
02:49:13.920 | load the block start so 0 plus the elements from index 0 to 1024 excluded with the program element
02:49:27.920 | 1 this basically will result in a vector that is well the program start with p id equal to 1 will
02:49:35.600 | be 1024 then 1025 1026 1027 etc etc until 2047 with the program number let's say 2 this this
02:49:50.960 | offset will be the elements 2048 2049 blah blah blah until 3000 and something now we also as you
02:50:03.520 | remember when we create when we launch a grid the number of threads is not always based on the number
02:50:13.520 | of elements in the block or the number of elements in your vector it is always a multiple of a base
02:50:19.120 | number which is usually 32 which means that the grid this program may have more threads that it
02:50:26.320 | needs so some threads should not be doing anything so should not be loading any data and should not
02:50:31.920 | be computing any summation so what we this is what we why we need this mask this means that
02:50:37.600 | if all these offsets that we are loading it should be at most up to n elements because imagine you
02:50:45.440 | have not 1000 2000 imagine you have a vector of 2060 elements which means that this offset for
02:50:56.240 | the the third program of this kernel will load the offset that go from 2048 2049 blah blah 2060 and
02:51:07.360 | then also 2061 2062 etc etc but we said that we only have a 2060 elements so all the
02:51:15.280 | elements of 2061 62 etc until 3000 and something they don't exist so we need to tell somehow that
02:51:23.360 | all the threads that are working with these elements should not load anything that's why
02:51:27.920 | we need this mask this mask tells load among all the offsets that this block should work with
02:51:34.720 | only those elements that actually exist for which this mask is true then we load the elements of
02:51:42.800 | this current program which is a group of elements defined by these offsets and only the one that
02:51:50.960 | for which this mask is true so only the one that actually exists all the others should be ignored
02:51:55.680 | and we can also specify what it should load in case this the mask is false with another parameter
02:52:04.160 | but we will not see that here we also load the group of elements of the y vector and then we
02:52:09.680 | compute the output x plus y so if you remember previously in CUDA we we did something like this
02:52:16.800 | like the output of i is equal to the x of i plus the y of i so we did it one element at a time
02:52:26.080 | because each thread was working with one index here we are working with a group of elements so
02:52:31.040 | this x is a group of elements is a block of elements at most of size block size actually
02:52:38.720 | of size block size and it's this y is a group of elements from the y vector and we are computing
02:52:46.720 | the output group by group so this this is summing a group of elements of x with the corresponding
02:52:54.800 | group in y and writing it in output then we need to restore this output we need to store it in the
02:53:02.160 | output tensor output ptr that you can see here which is a pointer to the first element of the
02:53:07.440 | output vector and we say that where should we store this output vector which is of size shape
02:53:14.240 | of this vector here is block size where should we save it well in the same offset to where which we
02:53:21.840 | loaded x so if this program worked with the index 2048 2049 etc etc then all this output should be
02:53:30.720 | written in the same offset 2048 2049 etc up to 3000 and something using the mask as well because
02:53:39.360 | we don't want to write all the values of this block size because maybe we don't have enough
02:53:43.600 | elements so only write the one that are actually present in the vector so the reason we need the
02:53:49.520 | mask is because CUDA will launch a number of thread that is always a multiple of a base unit
02:53:53.920 | that may not be a multiple of the vector size that we are working with so we need to find a way to
02:54:00.720 | tell some threads to not do anything for those that the data is not available so let's rehearse
02:54:07.360 | what we have seen so far: in CUDA the program that we write is at the thread level, so we describe what each thread
02:54:12.400 | should do; in Triton we work with a block of threads, and we describe what data this block of threads should work with.
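For reference, here is a minimal sketch of the vector-addition kernel we just walked through, written in the style of the standard Triton tutorial; the names (add_kernel, BLOCK_SIZE, etc.) are assumptions and may differ slightly from the code shown on screen.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                       # which block of elements this program owns
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # e.g. 2048..3071 for pid == 2, BLOCK_SIZE == 1024
    mask = offsets < n_elements                       # switch off the offsets past the end of the vector
    x = tl.load(x_ptr + offsets, mask=mask)           # load a whole block of x (only where the mask is true)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)  # write the block back at the same offsets

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
```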
02:54:21.360 | All right guys, finally the moment has come:
02:54:29.920 | we are going to code the flash attention forward pass right now in Triton, but let's rehearse the
02:54:36.480 | algorithm first. The goal of the attention mechanism, specifically in flash attention, is
02:54:43.200 | to compute the attention output, which is the output of the following formula:
02:54:48.240 | the query multiplied by the transpose of the key, divided by the square root of the head dimension;
02:54:52.000 | we apply the softmax to that, and then multiply everything by V.
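In symbols, the output we want to compute is

$$
O = \mathrm{softmax}\!\left(\frac{Q K^{T}}{\sqrt{d_{\text{head}}}}\right) V
$$

where the softmax is applied row by row.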
02:55:01.680 | coding the forward pass and also the backward pass but before coding the backward pass we need
02:55:07.200 | to understand how the autograd works we need to understand what is the gradient what is the
02:55:12.080 | jacobian how to derive the gradient of the softmax operation how to derive the gradient
02:55:16.960 | of the matrix multiplication operation etc etc so that is going to be another part of the video
02:55:21.840 | for now let's concentrate on the forward pass right now we have some tools so we know that
02:55:26.880 | we have this thing called the gpu that can parallelize operation among multiple cores
02:55:31.200 | we know that in CUDA we can parallelize operations by writing a program that is the
02:55:36.960 | definition of what each thread should do, or we can follow the Triton programming model, which is
02:55:43.440 | telling in Python what each group of threads should do. The mapping between what each thread
02:55:51.600 | should do and which element it should work with is up to us, the programmers, and
02:55:58.080 | the same happens in Triton: we say how many blocks of threads we want and how much data each
02:56:05.840 | block of threads should process, so that's the block size that we saw in the vector
02:56:11.440 | addition, but the mapping between the elements of the vector and the identity of each group
02:56:20.160 | of threads, so the program id that we saw, is up to us, and the same will happen when we code
02:56:26.080 | flash attention. Let's see what we can parallelize in flash attention. First of all, the code
02:56:33.360 | that you see in the forward pass of flash attention takes as input query, key and value,
02:56:40.160 | which are matrices of N by d; however, usually in a transformer network we don't
02:56:48.640 | have only one sequence made up of d dimensions, we have many sequences made up of d dimensions,
02:56:55.520 | and this d is the lowercase d, which is the number of dimensions dedicated to each head; but
02:57:02.720 | we don't have only one head, we have multiple heads, so the algorithm that you see here is what each
02:57:09.360 | head of each batch should do. Moreover, we have seen before, when talking about
02:57:19.680 | block matrix multiplication, that we can parallelize the computation of the output, because this output
02:57:25.600 | block here depends on the query block one and all the keys, this one here depends on the
02:57:31.040 | block of queries two with all the keys, and this one here is the query block three with all the keys, etc.
02:57:37.440 | So because this one only depends on the query block one and this one only
02:57:42.960 | depends on the query block two, they can work independently from each other, while sharing, of
02:57:47.920 | course, the keys. Another thing that we need to understand about Triton is the shared memory.
02:57:55.920 | In the GPU we have the high-bandwidth memory, which is kind of the RAM of the GPU:
02:58:04.720 | when you buy an A100 they tell you that it has 40 gigabytes; that's the amount of memory in the
02:58:10.640 | high-bandwidth memory, so the DRAM. So let's look at the actual structure of the GPU,
02:58:15.440 | which is here we have this dram which is the big memory that we that the gpu has and then each
02:58:24.640 | streaming multiprocessor so it's a let's call it a block of threads actually also have a shared
02:58:32.160 | memory so inside of the gpu actually we have we have these streaming multiprocessors and
02:58:37.280 | these streaming multiprocessors have a part of memory called the shared memory which is much
02:58:42.160 | smaller than the dram like much much much smaller what changes between these two memories the access
02:58:49.360 | to the dram is very slow and the access to the shared memory is very very very fast so one thing
02:58:56.080 | that is different between cuda and triton is that whenever you load some information in cuda you are
02:59:02.080 | loading that information directly from the global memory because when we launch a cuda kernel first
02:59:07.680 | of all as you remember in my c++ code we first copy the tensors from or the vectors from the
02:59:14.480 | cpu to the gpu and they reside in the global memory of the gpu then we load these elements
02:59:21.600 | directly from the global memory, but the access to the global memory is usually much
02:59:27.120 | slower. So what happens is that the attention computation in its
02:59:32.160 | naive version, the one that we can do with torch, is very
02:59:37.280 | slow, because the access to the global memory is very slow, so we want to use the shared memory as much as possible:
02:59:43.120 | we want to reuse the elements loaded from the global memory into the shared
02:59:48.480 | memory so that we don't need to access the global memory every time to load elements from the
02:59:52.640 | vectors or the matrices and this is what happens also in triton so in triton whenever you load
02:59:58.800 | some data you are copying the information from the global memory to the shared memory
03:00:03.120 | then whatever operations that you are doing is done on the shared memory and then when
03:00:08.240 | you store the information you are copying the data from the shared memory to the global memory
03:00:12.400 | this makes it much faster so we always work with the elements that have been loaded in the shared
03:00:18.640 | memory and this shared memory basically it's shared for all the threads that belong to the same
03:00:25.440 | block in triton we have an abstraction level that doesn't make us work directly with the threads
03:00:32.160 | so we always work with a group of threads that belong to the same block that share this shared
03:00:36.560 | memory so in triton we are copying information from the global memory to the shared memory we
03:00:41.280 | do some operation with it and then we store back to the global memory and this is what we are going
03:00:45.360 | to do with flash attention now let's review the algorithm of flash attention so in flash attention
03:00:50.880 | we have an outer for loop that goes over all the query blocks and then an inner loop
03:00:57.760 | that goes through all the key blocks.
03:01:04.080 | In the original flash attention algorithm, flash attention 1, the outer loop was on the
03:01:10.880 | keys and the inner loop was on the queries; this made it less parallelizable. Why? Because now the outer loop
03:01:17.840 | is on the queries, and we have seen before that the output of this attention can be computed
03:01:24.720 | independently for each block of queries, so it's much easier to parallelize. So for this outer for loop
03:01:30.240 | we don't actually have to run a for loop: we just spawn many kernels, each working with one iteration
03:01:35.760 | of this outer for loop, so each working with a different query block.
03:01:41.120 | and the inner for loop is something that we have to iterate through so each triton kernel will
03:01:47.920 | work with one query block and then iterate through all the key blocks
03:01:52.720 | and inside of this key block we have already seen the operations that we are going to do, which
03:01:59.200 | we explored before, and at the end of this for loop we need to store back the output
03:02:05.840 | in the high-bandwidth memory. This is how we are going to work. Another thing that we
03:02:13.760 | should notice is that these query, key, value are N by d, as I said before, but usually in a
03:02:21.600 | transformer model we don't have only one sequence, we have many sequences, so we can also parallelize
03:02:28.800 | over the number of sequences that we have in the batch, because each sequence in the batch can work independently
03:02:33.120 | from the others, and inside each sequence we have multiple heads, so each head
03:02:40.720 | also can work independently from the others, because we know from the Attention Is All
03:02:44.320 | You Need paper that this is the meaning of a head, that's the meaning of multi-head attention:
03:02:48.640 | each head can compute the attention independently from the others. So we will also
03:02:52.560 | parallelize along the head dimension and moreover if you look at this definition of the query block
03:02:59.520 | we can also split the query into blocks and each query block can work independently from the other
03:03:04.320 | query blocks by in producing one output block this is how we are going to parallelize so we are going
03:03:10.240 | to parallelize each sequence in the batch but inside of each sequence we are going to parallelize
03:03:15.600 | each head, and inside of each head we are going to parallelize each query block. So how many programs
03:03:21.920 | will we have working in parallel at most? It will be the number of sequences in the batch, so
03:03:28.080 | the batch size,
03:03:34.080 | multiplied by the number of heads,
03:03:37.600 | multiplied by the number of blocks that we will divide the query sequence into,
03:03:47.200 | where each block contains BLOCK_SIZE_Q queries.
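As a quick sanity check with made-up sizes (these numbers are only an illustration, not the ones used later):

```python
# Hypothetical sizes, just to count the maximum amount of parallelism
BATCH_SIZE, NUM_HEADS, SEQ_LEN, BLOCK_SIZE_Q = 4, 8, 1024, 64

num_query_blocks = (SEQ_LEN + BLOCK_SIZE_Q - 1) // BLOCK_SIZE_Q  # ceiling division -> 16
num_programs = BATCH_SIZE * NUM_HEADS * num_query_blocks         # 4 * 8 * 16 = 512 programs at most
```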
03:03:50.480 | All right, now that we have seen this, let's actually go and code it.
03:04:02.640 | i have already introduced a little bit the differences between my implementation of the
03:04:08.160 | flash attention and the one that you can find on the triton documentation which is first of all i
03:04:12.480 | don't work with fp8 because i believe this is unnecessary for our explanation it's of course
03:04:18.240 | much faster, because recent GPUs also support fp8. The second difference is that in the
03:04:25.920 | flash attention on the Triton website the backward pass is only implemented for the
03:04:31.840 | causal attention, but in my case I implement it for the causal and the non-causal attention, even if
03:04:36.400 | it's slower, and later I actually want to give you an exercise on how to improve it.
03:04:41.680 | The third main difference is that I make explicit use of the
03:04:48.080 | softmax scale, so I actually use the scale when needed. Another difference is that in the flash
03:04:56.480 | attention implementation on the Triton website, e^x is not really computed as e to the power of x but as 2 to
03:05:03.120 | the power of x, and then they compensate for it by using a logarithm, probably because
03:05:10.160 | the implementation of 2 to the power of x is faster than e to the power of x; but in my
03:05:15.600 | case I retain the original exponential because I want to follow the original algorithm, to make it
03:05:20.480 | simpler to visualize the code along with the algorithm as in the flash attention paper
03:05:25.120 | So, I know I have created a lot of hype, so let's do it. Let's start by creating a new
03:05:34.720 | file, let's call it program.py. Just like before, when I introduced Triton, I will start by coding
03:05:41.120 | first the code that will use our kernel and then we code the kernel, and we will only be coding the
03:05:45.920 | forward pass of the kernel. So let's start by importing what we need to import, which is just
03:05:53.280 | torch and triton. Secondly, let me check... okay, Copilot is already off,
03:06:00.400 | so i don't have to worry about that let's start to implement the code that will test our
03:06:05.840 | implementation of the triton and compare it with the naive implementation of the attention mechanism
03:06:10.240 | so we create our query key and value sequence for testing
03:06:17.760 | which, if you remember, means the query has a batch size dimension,
03:06:24.160 | because we have multiple sequences; each sequence has a number of heads and it's made up of
03:06:30.080 | SEQ_LEN tokens, and each token is identified by HEAD_DIM dimensions.
03:06:37.440 | This is because we have already split each token into smaller tokens, each
03:06:43.440 | with its own head dimension; if you remove the num heads dimension then you
03:06:50.000 | concatenate back all the dimensions of this head dim. We initialize the query, key and value sequences
03:06:57.040 | by using a normal distribution; this code I took from the Triton tutorial, so it's nothing
03:07:02.960 | different, and we require the gradient because we want to compute the gradient with respect
03:07:08.080 | to query, key and value, and we will see later why: because we
03:07:12.960 | want to test also the backward pass, even though we will not be coding it now.
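A sketch of what this test setup could look like; the sizes are placeholders, not necessarily the values used in the video:

```python
import torch

BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = 8, 16, 4096, 64  # placeholder sizes

def make_tensor():
    # (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), fp16, sampled from a normal
    # distribution, with gradients enabled so the backward pass can be tested too
    return (
        torch.empty((BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), dtype=torch.float16, device="cuda")
        .normal_(mean=0.0, std=0.5)
        .requires_grad_()
    )

Q, K, V = make_tensor(), make_tensor(), make_tensor()
```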
03:07:18.160 | So the first thing that we do is define our softmax scale: as you remember, the formula is query multiplied
03:07:25.840 | by the transpose of the keys and then divided by the square root of the head dimension,
03:07:33.040 | so d_k, or d_head as it is sometimes called. We can already compute this
03:07:42.240 | one: it is one over the square root of the head dimension.
03:07:46.880 | Then we also define dO, and later we will see what this is, but basically it will be
03:07:55.200 | needed for the backward pass; don't worry if you don't understand what dO is, later we will
03:08:04.400 | see it. Let's do the naive implementation of the attention, which is very simple. First we
03:08:10.960 | define the mask, and we use this mask only if the attention we are computing is causal; as you can
03:08:16.720 | see, we pass this parameter called causal that tells if we want to compute the causal attention
03:08:22.480 | or the non-causal attention, and the dtype, which is float16, because we want to work directly
03:08:28.160 | with 16-bit floating point numbers. We will not be working with fp8;
03:08:34.880 | my implementation is actually not as fast as the one in the tutorial
03:08:39.840 | on the Triton website, but I believe it's much easier to comprehend. So we define the mask,
03:08:48.320 | we compute the product of the query multiplied by the transpose of the key, divided by the square
03:08:53.680 | root of the head dimension: that's why we are multiplying by softmax scale. If the attention
03:08:58.160 | we are computing is causal, then we use this mask that we have computed, so we replace all
03:09:04.160 | the dot products where this mask is equal to zero with minus infinity, and then the softmax
03:09:09.920 | will replace these minus infinities with zeros, because then we are applying the softmax, and
03:09:14.800 | the softmax is applied by rows, just like in the normal attention we compute. The second thing
03:09:21.120 | that we do: the output is the product of the output of the softmax with V,
03:09:28.240 | so this is the reference output of the naive implementation of the attention mechanism.
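A minimal sketch of this naive reference, assuming the shapes described above (the function and variable names are illustrative, not necessarily those in the actual file):

```python
import torch

def naive_attention(Q, K, V, causal, softmax_scale):
    # Q, K, V: (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
    SEQ_LEN = Q.shape[-2]
    P = torch.matmul(Q, K.transpose(-2, -1)) * softmax_scale    # scores: (B, H, SEQ_LEN, SEQ_LEN)
    if causal:
        # lower-triangular mask: token i can only attend to tokens <= i
        MASK = torch.tril(torch.ones((SEQ_LEN, SEQ_LEN), device=Q.device))
        P = P.masked_fill(MASK == 0, float("-inf"))             # -inf turns into 0 after the softmax
    P = torch.softmax(P.float(), dim=-1).to(Q.dtype)            # softmax applied row by row
    return torch.matmul(P, V)
```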
03:09:34.640 | Then we also want to derive the gradients of the output with respect to
03:09:42.320 | the inputs, in this case V, K and Q; later we will see what we are
03:09:51.600 | doing here. Then we want to compare this reference implementation with our Triton
03:09:58.480 | implementation so let's do it so our triton implementation will be implemented as a class
03:10:03.680 | called triton attention that we will call using this method called apply and later we will see
03:10:09.200 | what is this method in which we pass the query key and value if we want to compute the causal
03:10:13.520 | attention the softmax scale that it should be using and it should produce some output which
03:10:18.560 | is the output of the softmax multiplied by V. Then we can also run the
03:10:23.680 | backward, and this backward will be the same backward that we will compute with the
03:10:28.560 | Triton attention, and then we can compare the result of our
03:10:38.000 | implementation, so this TritonAttention.apply, with the reference implementation, which is this
03:10:43.440 | one here. We use the function allclose, which basically compares
03:10:49.280 | the elements of two tensors and makes sure that their absolute difference is no more than this
03:10:54.560 | value; we are not using the relative distance, we are just using the absolute distance between the
03:10:59.440 | corresponding elements of the two tensors. The implementation
03:11:04.720 | that we will build will work with the causal attention and also with not causal attention
03:11:08.480 | while for the one that we saw on the Triton website, the forward
03:11:15.120 | pass works with both causal and non-causal, while the backward pass only works in the case of
03:11:19.200 | the causal attention. But the one online is highly optimized, so if you want to
03:11:25.120 | learn a little more tricks on how to optimize triton kernels there is a lot of knowledge there
03:11:30.320 | anyway guys now let's try to uh implement this triton attention at least the forward pass so
03:11:36.080 | let's go to implement this triton attention class
03:11:39.600 | Okay, every time you want to introduce a new operation into torch, you need to
03:11:50.560 | implement your operation by deriving from this autograd.Function class. Every
03:11:56.400 | operation in torch, whether it's the softmax or, I don't know, the ReLU or the
03:12:03.360 | SwiGLU or whatever, is always implemented as a class that
03:12:08.240 | derives from this Function class, and it should provide two methods, one called forward and one
03:12:13.200 | called backward: the forward should produce the output of this operation and the
03:12:17.280 | backward should compute the gradient of the loss with respect to
03:12:22.960 | the inputs of that function, and later we will see how that works. For now let's concentrate
03:12:28.480 | on the forward pass. To implement the forward pass we need to create a static method
03:12:32.880 | called forward, which takes as input one thing called the context. As you know, in autograd,
03:12:41.600 | when training neural networks we have the forward pass and the backward pass; when computing the
03:12:46.240 | backward pass we need to reuse the activations of each of the computation nodes from the forward
03:12:51.600 | pass, and this context basically allows us to save the necessary activations
03:12:57.520 | that we will need during the backward pass. Later we will see, in the flash
03:13:02.400 | attention algorithm, what information we need to save in order to compute the backward pass. For
03:13:07.680 | example, during the backward pass we will need to recompute on the fly
03:13:13.040 | the query multiplied by the transpose of the keys for each block, but we don't want to
03:13:18.800 | recompute the normalization factor or the maximum value for each row, so we will save those two
03:13:23.520 | values; actually, we will not save two values, we will save one value, using a trick called the
03:13:28.720 | log-sum-exp trick that we will see later. Anyway, this context is just a kind of storage
03:13:35.920 | area where we can save some stuff that will be necessary for us to compute the backward, and
03:13:41.040 | you can save whatever you like.
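As a skeleton of this pattern (the bodies here are placeholders; the real forward and backward are what we build in the rest of the video):

```python
import torch

class TritonAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, causal, softmax_scale):
        O = ...  # launch the Triton forward kernel here
        # stash whatever the backward pass will need
        ctx.save_for_backward(Q, K, V, O)
        ctx.causal = causal
        ctx.softmax_scale = softmax_scale
        return O

    @staticmethod
    def backward(ctx, dO):
        Q, K, V, O = ctx.saved_tensors
        dQ = dK = dV = ...  # computed later in the video
        # one gradient per forward input; non-tensor inputs get None
        return dQ, dK, dV, None, None
```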
03:13:47.360 | Then we have the inputs of this operation, which are the query, key and value (three tensors), together with causal (whether we are going to compute the causal attention)
03:13:52.720 | and the softmax scale that we should apply, which is one over the square root of the
03:13:56.800 | head dimension, and which we could actually also compute on the fly by checking the
03:14:04.240 | shape of these tensors, but okay, it doesn't matter. Anyway, the first thing that we are going to do
03:14:09.520 | is to extract the shapes of these objects and make sure all the shapes are what we expect them to be
03:14:14.560 | so the shape of the query key and value is a batch size by number of heads by sequence length
03:14:19.840 | by head dimension we make sure that the head dimension matches for the query key and value
03:14:25.920 | they should match because each vector should should be of the same size
03:14:30.560 | Then we pre-allocate the output tensor, so where we should save our output.
03:14:38.640 | As you remember, the output in the attention mechanism has the same shape as the query,
03:14:44.720 | key and value sequences, where the query, key and value sequences, I want to remind you, are not the
03:14:49.440 | query, key and value of the input of the attention, which is a sequence of tokens, but are already the output
03:14:54.880 | of W_Q, W_K and W_V, because flash attention is not concerned with optimizing those matrix
03:15:01.040 | multiplications, but only with what comes after W_Q, W_K and W_V. So we pre-allocate the output tensor where
03:15:08.640 | we will store this output, which has the same shape as the query, key and value matrices.
03:15:16.240 | Actually, no, that's not true: it has the same shape as the query, but it may not be the same
03:15:23.600 | as the key and value. Why? Because there is this thing called cross attention, where the query, key
03:15:29.920 | and value are different projections, through W_Q, W_K, W_V, not of the same
03:15:36.560 | input sequence but of two sequences. So cross attention happens when we have a query that
03:15:41.440 | comes from one sequence and the key and value come from another sequence, and they pass through
03:15:47.440 | their own W_K, W_V, and the two sequences may not have the same length. So the shape of the output of
03:15:53.280 | the attention only depends on the shape of the query sequence, not of the key and value sequence;
03:15:58.880 | this happens during cross attention, but usually in language models we always work
03:16:03.360 | with the self-attention so that should not happen at least in the causal language models
03:16:08.480 | then we have the stage and later we will see what is this stage basically the stage it's just a
03:16:16.400 | number that tells if the operation that we are going to do later is for the causal attention
03:16:22.160 | or for the non-causal attention. Then we need to define our launch grid. The launch grid tells
03:16:28.400 | us how many parallel processes need to be launched by Triton (actually they will be launched by CUDA,
03:16:34.800 | but we always work with Triton as an interface to CUDA). In Triton,
03:16:41.840 | as i said before we want to parallelize along the batch dimension so each batch each sequence
03:16:48.560 | in the batch should work independently from each other not only each inside of each sequence in
03:16:54.160 | the batch each head should work independently from each other so at least we have a batch size
03:16:58.720 | multiplied by number of heads programs and for each of this program we have another
03:17:05.840 | dimension called the we divide the query into blocks of queries so as you remember when talking
03:17:15.200 | about a block matrix multiplication we don't work with the query as the original matrix query matrix
03:17:20.560 | so where each query is one vector or one token we work with group of queries so each block of
03:17:27.840 | queries is a group of tokens in the query sequence so we are saying that we want to launch at a
03:17:34.560 | number of kernels or blocks of threads or a group of threads along two dimensions so just like the
03:17:44.480 | cuda kernel can be launched along two dimension x and y here we are launching programs along two
03:17:50.000 | dimensions one dimension that tells us which batch which head of which batch we are going
03:18:07.040 | to work with, so which head of which batch element we are going to work with, and inside this we are
03:18:16.400 | going to say: okay, within this sequence, which group of queries are we going to work with?
03:18:25.840 | Overall, the number of groups of queries is the sequence
03:18:30.880 | length divided by the number of queries that we want to group together, and the block size Q
03:18:37.200 | tells us how many queries there are in each block of queries. This cdiv is just the ceiling
03:18:44.960 | division, so it is equal to, let me write it here, the ceiling of the sequence length
03:18:57.840 | divided by the block size Q; this tells us how many blocks of Q we have. So let's rehearse: we
03:19:05.680 | have a tensor Q that is batch size by number of heads, and each flash attention program
03:19:10.720 | will work with the remaining sequence length by head dimension part. Moreover, we have seen that
03:19:17.200 | flash attention has two loops: one is the outer loop over all the query blocks, one is the inner
03:19:22.960 | loop over all the key blocks. We have seen that the query blocks can work independently from each
03:19:28.880 | other, so we can spawn as many programs in parallel as there are blocks of Q, because they
03:19:34.080 | can work in parallel. So this grid tells us how many programs there are that can work in parallel;
03:19:39.360 | then it will be the GPU that, based on its resources, will decide how many programs actually work in
03:19:39.360 | parallel if it has enough resources to make them all work in parallel wonderful if it doesn't have
03:19:45.200 | enough resources to make them work in parallel it will launch them sequentially one after another
03:19:57.760 | and the last dimension is like the z dimension in the CUDA launch grid,
03:20:05.040 | and we don't want to use it because we don't want an additional level of parallelism. All right, this
03:20:12.240 | is our launch grid: we will launch a number of
03:20:18.800 | parallel programs, or parallel kernels (and each kernel in Triton works on a group of
03:20:30.000 | threads), which is the batch size multiplied by the number of heads, multiplied by the number of blocks
03:20:38.320 | that we have divided the Q sequence into.
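A sketch of this launch grid, assuming the sizes and names used above (BLOCK_SIZE_Q is a compile-time meta-parameter of the kernel):

```python
import triton

BATCH_SIZE, NUM_HEADS, SEQ_LEN = 8, 16, 4096  # placeholder sizes

grid = lambda meta: (
    triton.cdiv(SEQ_LEN, meta["BLOCK_SIZE_Q"]),  # axis 0: which block of queries
    BATCH_SIZE * NUM_HEADS,                      # axis 1: which head of which batch element
    1,                                           # axis 2: unused, no extra level of parallelism
)
# the kernel is then launched as _attention_forward[grid](<its arguments>)
```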
03:20:44.400 | Okay, let's continue. Then we will see what this one is: this M is another tensor that we will need, and it's the log-sum-
03:20:49.280 | exp for the backward pass; we will see, not at the end of this
03:20:54.960 | video but at the end of the forward pass, what it's needed for, but basically you can think of
03:21:02.560 | it as the maximum for each row. To recompute the query multiplied by the key in the
03:21:07.120 | backward pass, if we don't want to recompute the maximum for each row and
03:21:12.000 | the normalization factor of the softmax, we should save two things: one is the maximum for each row
03:21:19.680 | and one is the normalization factor. However, by using the log-sum-exp trick we can save only
03:21:26.800 | one value, which, as you can see in the algorithm of flash attention, is this stuff here,
03:21:36.480 | so this L_i, which is the maximum for each row
03:21:46.320 | plus the logarithm of the normalization factor. Basically, when computing the
03:21:51.360 | backward pass we need to recompute on the fly this block here, so this query multiplied by the
03:21:55.120 | transpose of the keys, but to apply the softmax, as you remember, we need to have the maximum for each row and the
03:22:01.440 | normalization factor. We don't recompute them during the backward because we have already
03:22:05.600 | computed them during the forward, so we save this information; but we don't need to save these two
03:22:10.720 | pieces of information separately, we can aggregate them into one single value called L_i, and later we will see how we can use it.
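In symbols, with m_i the row maximum and l_i the softmax normalizer, the single value saved per row is

$$
L_i = m_i + \log l_i, \qquad \text{so that} \qquad \frac{e^{\,s_{ij} - m_i}}{l_i} = e^{\,s_{ij} - L_i},
$$

which is exactly what the backward pass needs to rebuild the softmax probabilities without storing m_i and l_i separately.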
03:22:20.800 | All right, so we have defined this one as well, and we can proceed further. Now
03:22:29.920 | we launch our kernel; don't be scared, it's going to be a little long. So here we are
03:22:36.720 | launching the kernel for the forward pass by defining what the launch grid is, so how many of
03:22:36.720 | this program should run in parallel at most we are passing the query we are passing the key we
03:22:42.160 | are passing the values we are passing the softmax scale the m which is the information that we need
03:22:47.120 | to save for the backward pass it's actually the l in the code of the pseudo code of the flash
03:22:53.760 | attention algorithm here i call it m i think because also in the original code it was called m
03:23:00.640 | the o where the our kernel should save its output and then as you remember
03:23:07.280 | we don't get all the nice access by indexing tensor like we are used to in torch we only get
03:23:18.080 | a pointer to the starting element of q a pointer to the starting element of k and to the starting
03:23:23.680 | element of v and then we have to figure out all the index in the memory of the other elements
03:23:30.080 | how to calculate the index we need the stride because the stride tells us how many elements
03:23:35.440 | to skip to go from one dimension to the other and that's why we are passing the stride for
03:23:48.400 | each dimension of each tensor. Actually, in our case we are only working with Q, K and V, which are
03:23:55.360 | of the same dtype and of the same shape, so we should not actually need to pass all the
03:24:01.520 | strides for each of these tensors, because they should have the same strides; however, in the
03:24:07.440 | original code I believe they were passing them, so I kept it. So the strides will allow us to
03:24:16.160 | index this pointer, to access the elements of this tensor: just by using the
03:24:21.360 | pointer to the starting element and the strides we will be able to index any
03:24:21.360 | element we want in the tensor then we pass the information of these shapes so the batch size
03:24:28.320 | the number of heads the sequence length and the head dimension and which is the same for all of
03:24:34.720 | them and then the stage the stage indicates if we are going to compute the causal attention or not
03:24:39.680 | causal attention. So let's not implement it yet and let's continue writing this method. Then
03:24:46.560 | we need to save some information that will be needed for the backward pass, using this
03:24:50.400 | context variable that I told you about before. So we save some information for the backward pass, which is
03:24:55.840 | the query, key and value, which are the tensors for which we want to compute the gradient during the
03:25:01.200 | backward pass, and then we need to store also this M tensor and this O tensor, and we
03:25:13.280 | need to also store the causal variable, because if we computed the causal attention during the
03:25:18.640 | forward pass, then during the backward pass we need to have this information, because
03:25:25.440 | we need to mask out the things that we don't want to contribute to the gradient but we will
03:25:31.040 | see that later when computing the backward pass for now let's concentrate on this attention forward
03:25:35.600 | So we need to implement this forward kernel that you can see, the _attention_
03:25:41.600 | forward method. Now, a Triton kernel is just a Python method with a particular decorator called
03:25:48.400 | triton.jit, so we copy and paste the signature; this is what makes a method become a Triton kernel.
03:25:56.800 | As you can see, here we pass the query, key and value matrices along with other information:
03:26:03.360 | the M matrix (please don't confuse the M matrix with the mask; the mask we will
03:26:07.680 | generate on the fly, because we are only concerned in this case with causal
03:26:13.760 | attention or non-causal attention, we do not accept custom masks here), and then we pass the strides of
03:26:21.520 | all these tensors, the batch size, the number of heads, the sequence length, the head dimension,
03:26:26.160 | which is the shape of each of these tensors, and the block size Q and the block size KV. The block
03:26:33.600 | size Q indicates how many queries we want to group together to make one block of the Q matrix,
03:26:39.920 | and the block size KV indicates how many keys and values we want to put together to make one block of the
03:26:46.160 | K and V matrices, which is what we do when we do block matrix multiplication. This stage is a
03:26:52.720 | number that indicates whether we are doing causal or non-causal attention: it will be 3
03:26:58.640 | in case it's causal and 1 in case it's not causal.
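Roughly, the signature looks like the following sketch; the exact argument list and which arguments are marked as compile-time constants are assumptions (the real code also passes the strides of K, V and O):

```python
import triton
import triton.language as tl

@triton.jit
def _attention_forward(
    Q, K, V,                      # pointers to the first element of each tensor
    softmax_scale,
    M,                            # per-row statistics saved for the backward pass
    O,                            # pointer to where the output must be written
    stride_Q_batch, stride_Q_head, stride_Q_seq, stride_Q_dim,
    # ... the strides of K, V and O follow the same pattern ...
    BATCH_SIZE, NUM_HEADS, SEQ_LEN,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,   # how many queries per Q block
    BLOCK_SIZE_KV: tl.constexpr,  # how many keys/values per K/V block
    STAGE: tl.constexpr,          # 3 for causal, 1 for non-causal
):
    ...
```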
03:27:04.000 | Okay, the first thing that we do is to verify some information: we verify that the block size of the KV is less than or equal to
03:27:12.560 | the head dimension. To be honest, I don't think we need it with my code, because I removed most of the
03:27:18.640 | constraints; this check was also present in the original code, so I kept it, but it all
03:27:24.160 | depends on the auto-tuning: later we will see what the auto-tuning process is, and
03:27:29.680 | what variables we are going to auto-tune for, how many stages we will choose, how many warps we
03:27:34.560 | will choose, etc. So let's leave it for later; you can comment it out or you can keep it, it shouldn't
03:27:39.600 | matter. The first thing that we do, as I said before: we launch a grid, so a grid is a series
03:27:46.880 | of programs where we will have some identifiers like in the cuda we had an identifier for the
03:27:52.240 | blocks on the x-axis and on the y-axis; in Triton we get this identifier for the programs. We launched
03:27:59.600 | sequence length divided by block size Q programs along the zeroth axis and
03:28:09.600 | batch size multiplied by number of heads along the first axis of the launch
03:28:16.880 | grid, which will help us identify which part of the query we are going to work with in this program,
03:28:25.200 | in this kernel, and also which batch and which head this program should work with. That's
03:28:33.040 | what we are going to do now we are trying to understand what part of the input we should
03:28:37.680 | work with based on the ids of the program which corresponds to the block id in cuda
03:28:43.680 | so let me copy so the program id zero indicates it's this stuff here tells us which part of the
03:28:53.760 | queries so which block of the queries we are going to work with why do we have a block on the query
03:28:59.520 | because as we saw before the output can be computed independently for each block of the
03:29:03.760 | queries while each block of the query has to iterate through all the key and values
03:29:08.960 | so this is what will tell us what is the index of the block of the queries we are going to work
03:29:15.040 | with in this particular program then we can understand also which index which batch and
03:29:20.800 | which head this program is associated with. The program id number one ranges over the product of the
03:29:26.880 | batch size and the number of heads: it means that we will have as many programs on axis number
03:29:32.960 | one as indicated by this product, so this product will tell
03:29:41.360 | us which batch and which head this particular program is associated with. To get the index of
03:29:48.400 | the batch, we just divide this number by the number of heads, and
03:29:54.160 | to get the head index inside this batch, we just take this number modulo the number of heads.
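Inside the kernel this decomposition is roughly (tl is triton.language):

```python
block_index_q = tl.program_id(0)               # which block of queries along the sequence
index_batch_head = tl.program_id(1)            # combined index over batch * heads
index_batch = index_batch_head // NUM_HEADS    # which sequence in the batch
index_head = index_batch_head % NUM_HEADS      # which head inside that sequence
```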
03:30:00.960 | Okay, the next thing that we need to do: first of all, when we pass a tensor, because
03:30:12.720 | as you can see here the q parameter to this attention forward method is a tensor because
03:30:18.480 | it's the input of this function forward function and this forward function is called here when we
03:30:24.320 | do attention dot apply and it's this q stuff here and this q stuff here has been created as a tensor
03:30:31.040 | so when we pass a tensor to a triton kernel it's not really a tensor it is a pointer to the first
03:30:36.720 | element of that tensor in the memory now we need to understand because now we know which batch we
03:30:42.800 | are going to work with and which head we are going to work with we need to index this tensor to
03:30:47.680 | select the right batch and the right head inside of the right batch which means that basically
03:30:53.840 | we have this Q tensor, so we need to do something like Q
03:31:01.440 | [index_batch, index_head], where index_head indicates which head we are going
03:31:11.200 | to work with, and we need to select everything that is inside
03:31:17.040 | these indices. So we need to enter the tensor at the right location, where the particular
03:31:23.840 | sequence length and head dimension for this batch and for this head start. For that we need to
03:31:29.520 | generate an offset by which we need to move this pointer, because this pointer
03:31:36.160 | is pointing at the beginning of the entire
03:31:41.040 | tensor, so we need to move in the batch size dimension and in the number of heads dimension.
03:31:45.840 | To do that, we generate the following offset, which will tell us where this
03:31:52.640 | particular batch and where this particular head start in this tensor, and to do that we need to
03:31:58.640 | use the strides. So what we are going
03:32:03.520 | to do is create the qkv_offset, which will be the index of the batch
03:32:11.520 | multiplied by the stride for the batch dimension, which tells us how many elements we need to
03:32:17.680 | skip to get to the next batch, multiplied by the index of the batch that we
03:32:23.600 | want: for the zeroth batch we don't skip anything, because we are already pointing to the first
03:32:27.760 | element of that batch, but if we are at batch one we will skip that many elements;
03:32:32.720 | plus we also need to skip some heads. How many heads we need to skip depends on which head
03:32:38.320 | we are going to work with, and what tells us how to go from one head to the next is the stride
03:32:44.080 | of the head dimension, so we multiply the index of the head, so the head that we should be working with,
03:32:49.280 | by stride_Q_head. All right, then we select... now Triton helps us with a new function that I
03:33:02.560 | think was added quite recently, that helps us index elements inside of a tensor without having to
03:33:08.880 | deal with all the complex indexing math that can be confusing for beginners. So I will be using a
03:33:16.960 | few methods to help us with this indexing, and this function is called make_block_
03:33:25.760 | ptr, and it's the following. Basically, this make_block_ptr takes as input a pointer,
03:33:33.600 | not a vector; it takes as input a pointer. In this case we are saying:
03:33:41.920 | create a block that has the following shape, that is, sequence length by head dimension. But let me
03:33:48.720 | do it one by one, actually, because I don't want to confuse you guys with all this stuff altogether. Okay, so
03:33:54.400 | to start, there is a pointer that is right now pointing at Q plus qkv_offset, so right now it is
03:34:02.320 | not pointing at the first batch but it's pointing exactly to our batch, so the batch that this
03:34:08.160 | particular program should be working with, and inside this batch to the particular head that
03:34:13.200 | this program should be working with which is basically saying that we have we are pointing
03:34:19.760 | to a tensor that is as follows so we are pointing to the following tensors which is the right head
03:34:26.480 | the right sorry the right batch and the right head and then we are selecting everything inside
03:34:34.160 | so it's pointing to the first element of this particular tensor this tensor particular tensor
03:34:41.040 | because we have already selected the batch and the head it is a two-dimensional tensor with this the
03:34:46.560 | following shape because the following dimensions are sequence length and head dim so we are saying
03:34:52.480 | take this pointer which contains a tensor of the following shape sequence length and head dimension
03:35:00.640 | and i'm also giving you the strides of these dimensions that are in this pointer so the the
03:35:07.280 | two dimensions that are that we need are the sequence dimension and the head dim dimension
03:35:12.320 | which is this one for the q tensor and um and in this um in this query tensor we want to select
03:35:27.040 | a block of queries based on the query on the block of queries that this program should be
03:35:33.040 | working with so i think i need to maybe probably use the ipad otherwise it can be very confusing
03:35:39.840 | to visualize so uh let's do it actually let me see if i can create another here and let's use the ipad
03:35:53.840 | All right, okay, so we have a Q tensor, because this construct we will be using for
03:36:02.960 | all the other tensors, so if you understand it for one tensor you understand it for all the others.
03:36:07.040 | We have a Q tensor that is batch by
03:36:14.800 | number of heads, then the sequence length and then the head dimension,
03:36:19.840 | with the following line so the this line here so when we create a q plus qkv offset
03:36:32.400 | we are already selecting the right batch dimension and already the right head dimension which means
03:36:40.000 | that we have already forwarded our q to not point to the first batch and the first head but to point
03:36:48.800 | to the exact batch that this program is working with and the exact head that this program is
03:36:53.040 | working with which basically means that right now it is pointing at a tensor that is made up of these
03:36:58.960 | two dimensions now inside of this tensor we also need to select the right block of query that this
03:37:07.840 | program should work with and this dimension here so the sequence dimension is all the queries so
03:37:14.880 | we need to select the right queries so we need to skip some queries how to skip some queries well
03:37:21.200 | we say that we need to skip block index multiplied by block size q number of queries because they
03:37:28.560 | will be processed by another by another program that will have this number here the program id
03:37:36.480 | will be different so we are selecting with this line not only inside of the q the right index
03:37:44.880 | and the head but also the right position in this dimension in the sequence length dimension that
03:37:50.000 | will point to the exact to the starting point of the exact query block that this particular program
03:37:56.160 | should be working with this is what is happening and we are also creating this block basically
03:38:02.080 | later we will see how it can be used to to create a block of the shape we are telling what is the
03:38:09.840 | the size of this tensor so this tensor has two dimensions because we are pointing to the
03:38:16.960 | beginning of the right query sequence so it has only two dimensions the sequence dimension and
03:38:22.400 | the head dim dimension so it's the last dimension and we are already pointing to the right beginning
03:38:29.280 | of the sequence dimension because we have already skipped some queries why we are skipping some
03:38:35.440 | queries because these queries will be handled by another program that will have a block index q to
03:38:40.000 | to some other value. And this order... actually I don't know exactly what this order is; you can try to put (0, 1)
03:38:48.400 | or (1, 0), I think it's some optimization that Triton does; I have read the online documentation and I
03:38:52.880 | couldn't find anything about it, so this is something that I will investigate, but actually
03:38:58.160 | even if you put (0, 1) it doesn't matter, so I think it's something that tells Triton whether you
03:39:05.600 | want the transposed version of this block or the non-transposed version, and later we
03:39:10.080 | will see actually how we can transpose the key block without doing any transpose operation:
03:39:15.280 | we will just change the strides, like we have seen before. Now, this make_block_ptr
03:39:23.280 | is not something that is necessary but it makes our life easier when we will index this particular
03:39:30.000 | pointer so we can treat this pointer nearly as nearly in the same way when we work with the
03:39:36.320 | tensor in pytorch we will be able to skip one increase one index in one dimension without
03:39:43.600 | having to do the computation of the strides later when doing the backward pass i will not use this
03:39:49.360 | one and do all the pointer indexing by hand so you can check the differences of indexing a tensor
03:39:54.560 | by using make_block_ptr and by not using it. Anyway, to rehearse what we are creating: we are
03:40:00.800 | creating a pointer to the right index in the batch, to the right index in the head dimension, and we
03:40:07.920 | are already skipping some queries based on the block index q, so this pointer is already
03:40:15.040 | pointing to the right block of queries that this particular program should be working with
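Putting the last few steps together, a sketch of the offset and of the Q block pointer (the variable and stride names are assumptions):

```python
# Enter the tensor at the right (batch, head); the same offset works for K and V
# because they have the same shape and strides.
qkv_offset = index_batch * stride_Q_batch + index_head * stride_Q_head

Q_block_ptr = tl.make_block_ptr(
    base=Q + qkv_offset,                        # points at Q[index_batch, index_head, :, :]
    shape=(SEQ_LEN, HEAD_DIM),                  # the 2-D view we are left with
    strides=(stride_Q_seq, stride_Q_dim),
    offsets=(block_index_q * BLOCK_SIZE_Q, 0),  # skip the query blocks owned by other programs
    block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
    order=(1, 0),
)
```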
03:40:18.800 | let's look instead at the v and the k block now so let's copy the v block now which is
03:40:28.960 | similar to the query but we are not going inside we are only indexing by the index batch and the
03:40:35.360 | index head. So this one (let me write it here): for the query we were already skipping this amount of
03:40:43.200 | queries, that is what we were indexing with that make_block_ptr, so we were in the right batch,
03:40:50.400 | in the right head, and we were skipping some queries; here we are just indexing by batch and by head, so
03:40:58.880 | we are doing V[index_batch, index_head], and we are not skipping anything,
03:41:07.680 | because, you see, this offset is equal to zero in the first dimension and in the second dimension,
03:41:12.000 | so we are not skipping anything on the sequence length and we are not skipping anything in the
03:41:16.160 | head dim dimension. All right, so let's look at the K block pointer,
03:41:25.520 | and this is different because as you know when computing the flash attention algorithm we need
03:41:30.880 | to have access to the block of queries and all the block of the key transposed so when accessing the
03:41:38.720 | key we shouldn't access it like we are accessing q we should invert the two dimensions that we want
03:41:46.000 | to transpose for and that's very simple with make block ptr and you can see it here we say that we
03:41:53.040 | want to point to the right index and to the right head and the tensor inside of it so let's let me
03:42:00.640 | write it here so later i can explain in line by line so what we are doing here is go to the k
03:42:05.440 | tensor select the right batch select the right head select everything that is inside so it's a
03:42:13.120 | tensor of two dimensions with the sequence length and the head dim because we you can see here
03:42:18.400 | here sequence length and head dims etc but we don't want first sequence length and then head
03:42:27.680 | dim we want first head dim and then sequence length so we want to transpose it how to transpose it
03:42:32.400 | we just say that you need to read this tensor with the two strides transposed so we are saying
03:42:37.840 | first use the stride of the dimension dimension and then use the stride of the sequence dimension
03:42:44.400 | and the shape of this tensor is not sequence head dim it's a head dim sequence and it's a block
03:42:55.360 | of kvs why we are not putting directly the sequence dimension here because we want to
03:43:04.000 | skip block by block later so we are not selecting all the sequence length in the sequence dimension
03:43:09.200 | we are just selecting a block of kvs and later we will use another method to go to the next block
03:43:14.960 | so i hope that by showing you the indexing like this it's a little easier to follow the indexing
03:43:22.000 | so for each tensor we are going in the right batch in the right head dimension and for the
03:43:27.280 | query we are skipping some query blocks, because each program will work with a different
03:43:33.120 | query block; but for the key and value, each program needs to iterate through all the keys and values,
03:43:40.000 | so we just point to the first key and value block, and then we will
03:43:45.600 | advance one block at a time during the for loop that we will do later.
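A sketch of the corresponding V and K block pointers; note how for K the shape and strides are swapped, so the block we read is already K transposed, without any transpose operation:

```python
V_block_ptr = tl.make_block_ptr(
    base=V + qkv_offset,
    shape=(SEQ_LEN, HEAD_DIM),
    strides=(stride_V_seq, stride_V_dim),
    offsets=(0, 0),                        # start from the first KV block
    block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
    order=(1, 0),
)

K_block_ptr = tl.make_block_ptr(
    base=K + qkv_offset,
    shape=(HEAD_DIM, SEQ_LEN),             # transposed view: (head_dim, seq)
    strides=(stride_K_dim, stride_K_seq),  # strides swapped accordingly
    offsets=(0, 0),                        # start from the first KV block
    block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
    order=(0, 1),
)
```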
03:43:51.120 | Then for the output also we can make a block pointer; this basically creates a
03:44:00.480 | pointer, just like in the query, key and value case, in which we select the right index batch.
03:44:06.080 | So what we are doing is: we are indexing by batch, we are indexing by head, and we are selecting
03:44:12.560 | everything that is inside... well, no, we are not selecting everything inside, we are skipping
03:44:17.040 | also in this case some blocks of queries, because, as I said before, the output has the same shape
03:44:24.240 | as the query. So this particular program, the one that will have this
03:44:32.080 | particular block index q, will only work with one block of the queries, which will produce only
03:44:37.840 | one block of the output matrix, and we need to select exactly that one, so we can point this
03:44:44.320 | pointer exactly to the point where we should start writing. So let's skip also in this case
03:44:49.680 | block index q multiplied by block size q rows, so we select exactly the block that
03:45:01.120 | this particular program will produce. When I speak about this particular program, I mean
03:45:08.960 | the program that is identified by this program id on the zeroth axis and this program id on the
03:45:15.200 | first axis, because each of these programs will run in parallel, hopefully, and each of them will have
03:45:20.480 | a different value for the block index q and index batch head. Okay, now we have pointed our pointers
03:45:28.560 | to the right position where they should either read some information or write
03:45:32.800 | some information. By using make_block_ptr, these pointers can also be treated directly as tensors;
03:45:39.040 | that's why we specify the shapes of these tensors, because Triton right now provides
03:45:44.160 | some methods to work directly with the pointers as if they were the
03:45:52.880 | tensors we are accessing, so we can index them like tensors. All right, so basically Triton
03:46:01.440 | is just doing some calculations for you based on the strides, so you don't have to do them by hand; but
03:46:05.760 | later, when we do the backward pass, we will avoid using make_block_ptr and we will see the
03:46:10.400 | indexing done by hand. All right, as you know, we are processing a single block of queries, so
03:46:20.800 | let's go back to the algorithm, otherwise we lose sight of what we are doing.
03:46:26.640 | So let's go here and let me show my iPad. All right, as you know, we will parallelize
03:46:36.720 | along the query block dimension so each program will work with a different query block
03:46:40.560 | and then we need to do a for loop on all the key and value blocks right now we just
03:46:49.200 | moved our pointers to the right position to select the right query block that we should work with
03:46:54.320 | and to the beginning of the keys and values block that we should work with based on which
03:47:06.240 | index and which head this particular program should be working with. All right, now that we
03:47:11.040 | have pointed our pointers to the right position in which our program should be working, inside
03:47:19.360 | the big tensors that have the batch dimension,
03:47:25.920 | the number of heads dimension, the sequence length dimension and the head dimension; because
03:47:29.840 | we are pointing to the right batch and we are pointing to the right head, these tensors have
03:47:35.760 | become two-dimensional tensors: they are only tensors on the sequence length
03:47:35.760 | and on the head dimension now we need some more information that we will use later
03:47:42.320 | the first information that we need is the offsets of each query inside of the current block of
03:47:50.640 | queries that this particular program should be working with and that is given by the following
03:47:56.320 | line, so let me copy and paste, which is this one. The offsets of the queries: first of all,
03:48:05.840 | how many of them are there? Block size Q, because each block of queries is made up of block size
03:48:12.320 | Q queries. What is each query? It's a token, and on the head dimension, the dim
03:48:19.920 | dimension, it is not the whole embedding of the token but a part of the embedding of each token;
03:48:26.160 | which part? The part corresponding to the head that this particular program is going to work with.
03:48:32.000 | So we are generating the offsets that will load these
03:48:36.560 | particular queries from the big tensor that contains all queries, and we know that
03:48:41.840 | our queries start at the position block index q multiplied by block size q. So if this is program
03:48:48.880 | number zero, and imagine block size is equal to four, they will be the queries with index
03:48:55.120 | zero, one, two and three; but imagine we are program number three, which means that we need
03:49:00.880 | to skip three multiplied by four, so 12 queries, so it will point to the queries number 12, 13, 14 and 15.
03:49:10.400 | etc all right and we do the same for the key and values initially the key and values is a range of
03:49:20.160 | keys and values that we need at each iteration and at the beginning because our pointer for the k
03:49:26.800 | and v is pointing to the beginning of the sequence of key and value for this particular batch and for
03:49:32.480 | this particular head we are pointing to the first block of key and value so we are not skipping
03:49:37.760 | anything in the query case we are skipping because our program will only work with one single block
03:49:42.560 | of queries in this case we don't skip anything because we need to iterate through this key and
03:49:47.280 | values so we are pointing to the first block of key values so imagine block size kv is equal to
03:49:53.600 | four so this stuff here will be equal to zero one two and three all right
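For concreteness, the two offset computations just described can be sketched like this (variable names are assumptions, following the layout of the Triton fused-attention tutorial):

```python
# Sketch (names assumed): offsets of the queries owned by this program,
# and of the first key/value block it will visit.
offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)  # e.g. [12, 13, 14, 15] for program 3 with BLOCK_SIZE_Q = 4
offs_kv = tl.arange(0, BLOCK_SIZE_KV)                               # e.g. [0, 1, 2, 3]: start of the K/V sequence
```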
03:50:02.400 | now as you remember inside of the flash attention algorithm we need to compute a block of query multiplied by the
03:50:07.760 | transpose of the keys and to each of this block we need to apply the softmax star if you remember
03:50:13.200 | what is the softmax star it is the softmax of without the normalization so while computing
03:50:20.000 | the softmax star we also actually compute the normalization factor without applying it and we
03:50:25.040 | apply the normalization factor at the end so for each block of query multiplied by transpose of
03:50:30.480 | the keys we need to have the maximum for each row in this particular block and the normalization
03:50:36.720 | factor for each row so that's why we need these two following statistics which is this one
03:50:44.080 | and this is basically a block of numbers based on how many queries we
03:50:52.000 | have in our block of queries each one initialized with minus infinity just like in the algorithm
03:50:57.600 | that i showed before so let me go back to the slides in case we forgot or actually you can
03:51:03.680 | also check the flash attention algorithm we initialize it with minus infinities so far
03:51:07.520 | we are creating this stuff here so we are initializing the mi we will be initializing
03:51:11.680 | the li and we will be initializing the o and then we will show the inner loop here and this is exactly
03:51:19.440 | the algorithm that we have seen before so we initialize m with minus infinities now we initialize
03:51:25.200 | also the l's so let me go back to the code all right so the l's are initialized with this number
03:51:37.040 | here so here as we can see from the flash attention algorithm the o
03:51:44.240 | block is initialized with zeros so that's why we initialize a block this is the output block that
03:51:49.840 | this particular program will compute which is based on the position in the batch and the position in
03:51:55.520 | the index so it is one block of the size of block size q so how many queries there are in this block
03:52:01.200 | by head dimension which if you want to visualize it let's go back to the slides it is equal to
03:52:09.440 | one block of this matrix here so it's one block of the output matrix so one
03:52:19.600 | block of rows okay so let's go back to the code now all right so now we have initialized
03:52:29.520 | the output the mi and the li where mi is the maximum for each row in this particular
03:52:36.480 | query block and li is the normalization factor for each of the
03:52:44.720 | rows in our query block
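As a minimal sketch of this initialization (names assumed; in the tutorial-style code the l's start at 1, which is harmless because the first correction factor exp(-inf) zeroes that initial value out):

```python
# Sketch (names assumed): running statistics and output accumulator for this query block.
m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf")  # running row maxima, start at -inf
l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0           # running softmax normalizers
O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)   # output accumulator, starts at zeros
```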
03:52:53.200 | now we need to do the inner for loop of the flash attention algorithm we will create a separate method that will run the inner loop so let me copy it
03:53:01.040 | here and i am following the same structure of the code that you see in the tutorial of
03:53:06.560 | the triton website so basically if we are running the causal attention or even if we are not running
03:53:14.480 | the causal attention we make this for loop and then we will make another for loop and i will
03:53:20.160 | show you why so let me first write it and then we will see so this function here will be the
03:53:27.280 | inner loop this inner loop needs to go through all key and value blocks one by one and for each
03:53:35.440 | key and value block it needs to fix the previously calculated softmax star
03:53:43.440 | block so basically what we are doing here we will need to create a function as the following where
03:53:50.720 | we are going to iterate on all the key value block we will need to compute the query multiplied by
03:53:56.000 | the transpose of the keys using the query block that is fixed for this program and the key is
03:54:01.360 | block is the one that we are iterating it through and for each of these queries we need to calculate
03:54:06.400 | what is the maximum for each row we need to compute the softmax star so the softmax without
03:54:13.040 | the normalization factor we need to keep this statistics l which is the normalization factor
03:54:18.400 | that we will apply at the end of the iteration of the for loop and at the same time we need to
03:54:26.080 | update the output so as you remember the output is p11 multiplied by v1 plus p12 multiplied by v2
03:54:34.080 | but we need to fix the previous p11 so every time we sum to
03:54:39.920 | the output we need to fix the output of the previous iteration
03:54:44.960 | and then we introduce the p and v block of the current iteration so here the author of the
03:54:57.040 | code the one that you see on the triton website decided to split this for loop into two
03:55:03.520 | steps why because when we have causal attention
03:55:10.560 | we don't want the query to attend to the keys that come after it
03:55:17.360 | while in the non-causal attention we let all the queries attend to all the keys which also means
03:55:23.920 | that we would need to have some kind of if statement inside of
03:55:29.120 | this for loop through all the keys and values in which we need to check if this particular query
03:55:35.120 | that we are working with comes before or after the key and value in case we are doing the causal
03:55:41.360 | attention so instead of iterating through all the key and values also in the case of the causal
03:55:48.160 | attention by splitting it into two steps we are saying first let's iterate through all the key
03:55:56.000 | and values for which the index is smaller than the current query block's index and for this we need
03:56:02.960 | to compute the attention both in the causal and in the non-causal case then for all the elements on
03:56:09.280 | the right of this block so for which the key index is more than the q index in the case of
03:56:15.760 | causal attention we don't need to compute anything because it will be masked out because in the soft
03:56:20.800 | max it will become zeros so it will not contribute to the output so we don't even have to compute it
03:56:26.320 | this is why we split this for loop into two steps so first we iterate over all the parts that
03:56:34.400 | are to the left of the diagonal of the query multiplied by the key matrix so all the values for which
03:56:41.200 | the key index is less than the query index and then we skip all the parts to the right
03:56:49.120 | of this diagonal in case we are working with the causal mask but in case of the non-causal mask we
03:56:53.760 | compute the left part and the right part of this diagonal all right don't worry when we write
03:56:59.600 | this for loop it will be more clear so i just wanted to give a little introduction so let's go
03:57:04.800 | uh code this inner loop what will this inner loop do it will work with this particular query block
03:57:10.400 | that we have found so this q block right i don't see the q block because i didn't
03:57:17.520 | load it yet so we need to load the query block actually we forgot to load it so
03:57:25.120 | as you remember in triton we load data from the high bandwidth memory to the sram so to the shared
03:57:31.680 | memory by using the load statement and we are telling it to load the query block that we should be
03:57:37.280 | working with because this pointer q block ptr is already pointing to the right block that we should
03:57:43.360 | be working with so it's already skipping all the blocks that other programs should be working with
03:57:48.560 | and it will load a tensor of size block size q by head dim so the right block of queries
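In code this load is a single statement; a sketch assuming the block-pointer name used here:

```python
# Sketch (q_block_ptr assumed): bring this program's query block from HBM into SRAM.
Q_block = tl.load(q_block_ptr)  # shape: (BLOCK_SIZE_Q, HEAD_DIM)
```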
03:57:57.840 | and we pass it to this inner loop to which we pass the output so where it should write this output
03:58:05.200 | the li and mi where mi is the maximum for each row of each
03:58:12.960 | query and li is the normalization factor for each query and the query block this program
03:58:19.360 | should be working with the beginning of the key and value block pointer because we need to iterate
03:58:25.200 | through them so we just point it to the beginning and then inside the for inner for loop we will
03:58:30.000 | iterate through them then the softmax scale that we should use when computing query multiplied by
03:58:34.880 | the transpose of the keys the block size so how many queries we have in each block of q
03:58:40.560 | and how many keys and values we have in each block of kv this is a stage that tells us whether we are
03:58:48.480 | on the left side of the diagonal or on the right side of the diagonal so it will tell us if we need
03:58:52.800 | to apply the causal mask or not based on where we are
03:59:00.160 | the offset q and the offset kv are just the offsets of the query and key inside of each q
03:59:06.480 | and kv block which is a list of indices that tells us which queries and keys we are working with
03:59:14.160 | and then the sequence length the entire sequence length because in the for loop we need to iterate
03:59:20.400 | to all the sequence length block by block so block of kv block of kv block of kv all right
03:59:26.800 | let's write this let's write this method and later we actually need to continue this method again
03:59:32.080 | so let's go and let me go here
03:59:36.240 | all right
03:59:40.960 | so this method we have already seen the signature so it's just another kernel so it can be called
03:59:50.640 | by the first kernel and this is something you can also do in cuda you can actually call
03:59:55.920 | one cuda kernel from another cuda kernel and then based on the stage of this inner
04:00:03.680 | loop we decide what we need to do so when we are using causal attention we only want to
04:00:12.560 | apply the attention to the keys for which the index is less than or equal to that of the query so we don't
04:00:20.000 | want the query to attend to keys and values that come after it then we pass the value three
04:00:28.240 | for the stage parameter now in the causal case this will become four minus three which is equal
04:00:35.760 | to one so what will happen is that we will only work with the range of keys and values that are
04:00:44.240 | from zero up to the current block of q so all the keys whose index is less than
04:00:51.600 | the index of the queries we are working with so the left part of the causal mask let me draw
04:00:58.720 | it otherwise i think it's going to be very difficult to follow so let's do it actually
04:01:03.120 | so let's open a new one and let's go here all right so we have been using this one before so
04:01:11.040 | we can do it again clear page all right now i want you to think of the following
04:01:19.840 | matrix as a block matrix so let's draw it in pink because i have been drawing it all in pink
04:01:24.800 | we know that in the rows of this query multiplied by the transpose of the keys we have the
04:01:31.520 | blocks of queries so we are not watching one single block we are watching all the blocks
04:01:37.120 | right now so this is the query block one this is the query block two this is the query block
04:01:41.840 | three this is the query block four each of this query block is made up of multiple tokens of
04:01:46.960 | queries and then we have the key blocks let's do it like this very ugly but okay key block one
04:01:57.520 | key block two key block three key block four when calculating
04:02:04.000 | the causal attention so with the causal mask you want only the query to attend to keys that
04:02:12.480 | come before it so when we apply the causal mask this stuff here will be made up of zeros this
04:02:18.240 | stuff here will be made up of zeros this stuff here will be made up of zeros and this stuff here
04:02:22.480 | and this stuff here and this stuff here all made up of zeros we never have to mask out anything
04:02:28.720 | when we are in this case because when we are in this
04:02:34.160 | particular scenario we don't need to mask out anything for sure why because all the
04:02:39.520 | keys in this block of keys will have an index that is smaller than the index of
04:02:46.800 | the corresponding queries in case the block size of the query and the key matches
04:02:52.880 | so imagine each block of queries is made up of
04:02:57.920 | three queries so this is the query number 0 1 and 2 this is the query number 3 4 and 5
04:03:07.600 | this will be the query number 6 7 and 8 and this will be the query number 9 10 and 11 in total we have
04:03:15.200 | 12 queries we will have the same indices also for the keys in case we choose the same size for the
04:03:21.920 | blocks so this key block here will be the key number 0 1 and 2 this will be the key number 3
04:03:32.400 | 4 and 5 this will be the key number 6 7 and 8 etc etc now what happens is that in this case as you can see the
04:03:42.720 | key indices of the keys are always smaller than the indices of the queries so we don't need to
04:03:48.960 | mask out anything even in the case of the causal mask because we are sure that in this case all
04:03:54.240 | of these dot products will never be masked out also in this case all these dot products will
04:03:59.360 | never be masked out and also in this case will never be masked out will never be masked out
04:04:02.960 | and will never be masked out and in this case however along the diagonal some of the queries
04:04:09.920 | will have an index that is bigger than that of the keys and some of them
04:04:16.000 | will not have an index that is bigger than that of the keys because these are blocks of
04:04:22.160 | queries and blocks of keys some of them need to be masked out and some of them don't need to be
04:04:26.720 | masked out so we are dividing our for loop into multiple steps the first step that we are doing
04:04:32.720 | is all to the left of this diagonal in which we don't need to mask out anything then we will see
04:04:38.240 | another step here in which we we need to mask out and then everything to the right of this will be
04:04:46.800 | we will not even compute in the case of causal attention because we already know it's made up
04:04:50.400 | of zero so it will not compute so the product query multiplied by the transpose of the keys
04:04:55.280 | after the softmax will be made up of zeros so if you look at the flash attention algorithm so
04:04:59.840 | this stuff here the contribution will be zero because we are multiplying zero with v it will
04:05:05.920 | be zero so we don't need to change the output so why even compute this part of the matrix if we
04:05:11.200 | already know it's not going to contribute to the output so we just skip all those iterations
04:05:15.280 | and this is why we are splitting the for loop i hope now it's much more clear all right so let's
04:05:22.560 | go back okay so we are now on the left part of the diagonal in the case of stage number one
04:05:29.520 | in the case of stage number two it's the part exactly on the diagonal in which we need to
04:05:35.680 | do some dot products and some other dot products we don't need to do and then for the non-causal
04:05:40.560 | attention we just go through from zero to the sequence length without doing this multi-step
04:05:45.120 | because we don't need to mask out anything so this is why we have this stage this tells us
04:05:51.920 | what is the lower and higher index of the key block that this particular stage should be working with
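A sketch of how the stage selects the range of key/value positions the inner loop will visit (names assumed, following the tutorial's structure):

```python
# Sketch (names assumed): lower and upper K/V index visited by the inner loop for each stage.
if STAGE == 1:    # causal, blocks strictly to the left of the diagonal: no masking needed
    lo, hi = 0, block_index_q * BLOCK_SIZE_Q
elif STAGE == 2:  # causal, the blocks on the diagonal: the causal mask is applied here
    lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q
    lo = tl.multiple_of(lo, BLOCK_SIZE_Q)  # compiler hint
else:             # non-causal (STAGE == 3): visit every K/V block
    lo, hi = 0, SEQ_LEN
```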
04:05:59.440 | all right um now this function here multiple of is just telling triton that this number here is a
04:06:07.920 | multiple of this number so triton can make some optimizations so the stage one happens when
04:06:14.400 | we are doing a causal attention so stage number three in this function and four minus three will
04:06:21.120 | become one so imagine we are in the causal attention we will go through the key and value
04:06:26.400 | block that are to the left of the diagonal with respect to the query block that we are working
04:06:32.320 | with in case we are doing non-causal attention in this first call to the inner
04:06:40.320 | function the stage will be one so four minus stage will be equal to three so we will
04:06:46.960 | execute this part of the if statement so we will go through all the keys and values then only in the
04:06:55.120 | causal attention case as you can see here we will do another call that will only be done
04:07:00.960 | along the diagonal in which we need to mask out something and we don't need to mask out something
04:07:05.120 | because inside of each block there will be some keys that have an index below the index of the
04:07:11.280 | query and some that have an index above the index of the query so only in the causal attention we will
04:07:16.240 | call this function twice the first time with the stage equal to one and the second time with the
04:07:22.160 | stage equal to two and the second time we will only iterate through the group of key v blocks
04:07:29.360 | that are exactly on the diagonal of the matrix query multiply by transpose of the keys the big
04:07:36.160 | matrix that is made up of all the blocks all right now that this should be clear let's proceed
04:07:42.320 | further so let's um because we need to do the for loop the inner for loop of the flash attention
04:07:49.120 | let's go and load the first blocks of key and values which is exactly the one that the key
04:07:55.920 | and v blocks are currently pointing at which is the 0 0 block so we define the pointers
04:08:05.200 | basically we point the key and value block pointers to the first key and value block that
04:08:12.720 | this for loop should be working with which will be based on the stage so if it's the first call to
04:08:19.040 | this function they will be pointing to the first block in the case of the causal and not causal
04:08:25.600 | if it's the second call to this function which only happens in the case of the causal attention
04:08:30.800 | they will be pointing exactly to the key and value block to the diagonal
04:08:34.400 | all right then we need to make the for loop so let's do it
04:08:44.640 | so we loop over the keys and values and what we do is we let the compiler know that
04:08:52.960 | this number here the start kv will always be a multiple of the block size kv because we will be
04:08:58.400 | moving from one kv block to the next kv block block by block so we let the compiler know that
04:09:04.320 | this number here start kv is a multiple of block size kv it doesn't change anything from a logic
04:09:08.880 | point of view we are just giving some hint to the compiler so it can do some other
04:09:13.280 | optimization that triton does now the first thing that we see in the flash attention algorithm is
04:09:20.880 | we need to compute the product of the query so this is the particular block of the query that
04:09:25.600 | we are working with and the current kv block in this iteration so let's do it so we compute q
04:09:33.200 | times k transposed the query has already been loaded by the caller of this function we have loaded it
04:09:40.480 | here we have already loaded the query but we need to load the current block of k
04:09:46.880 | so we load the current block of k indicated by the k pointer and we do the matrix
04:09:54.720 | multiplication of the current block of query with the current block of k
04:10:00.480 | which is already transposed because when we defined the k block pointer we
04:10:08.640 | defined it already with the strides exchanged so we are reading the tensor already transposed so we
04:10:14.800 | are doing the query multiplied by the transpose of the keys basically okay
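A sketch of these two statements (names assumed; the K block pointer was created with swapped strides, so loading it already gives K transposed):

```python
# Sketch (names assumed): load the current K block (already laid out as K^T) and compute S = Q K^T.
K_block = tl.load(k_block_ptr)       # shape: (HEAD_DIM, BLOCK_SIZE_KV), i.e. K transposed
QK_block = tl.dot(Q_block, K_block)  # shape: (BLOCK_SIZE_Q, BLOCK_SIZE_KV)
```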
04:10:22.560 | now let's do this part here basically saying okay if the stage is two which is when we are
04:10:30.560 | exactly on the diagonal we know that some of the queries will have an index that is bigger than
04:10:35.760 | that of the keys and some of them will have an index that is smaller than that of the keys
04:10:40.080 | so we need to apply the causal mask only in this case so basically what we do is we define
04:10:46.960 | the mask that we should be applying so the mask will mask out all the values for which this
04:10:53.600 | mask is not true so the mask is true when the index of the query is greater than or equal to the index of
04:11:00.960 | the k and v's and then okay we apply the softmax scale so as you remember here we only computed
04:11:09.440 | query multiplied by transpose of the keys but we also need to divide by the square root of
04:11:13.600 | head dimension and we do it here and then we because we already computed the
04:11:22.960 | the product we can calculate the maximum for each row and then we subtract it because
04:11:31.440 | later in the flash attention algorithm we have another operation which is
04:11:36.080 | what i call the softmax star and as you remember the softmax star needs to subtract from each
04:11:46.560 | element of the s matrix so the query multiplied by the transpose of the keys the maximum
04:11:51.600 | for each row so we can already compute the maximum for each row but before
04:11:58.640 | computing the maximum for each row we need to mask out all the elements that will be masked out
04:12:04.000 | in the stage number two which is along the diagonal and how do we mask out we just replace
04:12:10.960 | with minus infinity before applying the softmax all the values for which the mask is false
04:12:15.840 | so right now what have we computed we have computed the query multiplied by the
04:12:20.960 | transpose of the keys we have masked out in case we need to mask and we need to mask only when
04:12:25.840 | we are along the diagonal in all the other cases we don't need to mask out anything we just multiply
04:12:30.480 | by the softmax scale and then we subtract the mij where mij is the maximum value for each row
04:12:39.120 | because we need to compute the softmax star operation which is the softmax without the
04:12:43.280 | normalization which in the flash attention algorithm is exactly this operation which
04:12:47.520 | will produce the pij okay so let's go here so now we can compute the pij block which is this stuff
04:12:56.080 | here which is the exponential of the qk block variable from which we have already subtracted
04:13:03.920 | the mij at the previous instruction so now we can just
04:13:11.680 | apply the exponential and this is what we are doing here okay
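Putting the last few steps together, a sketch following the two branches just described (names assumed, tutorial-style):

```python
# Sketch (names assumed): scale, mask only on the diagonal stage, track row maxima, then softmax*.
if STAGE == 2:
    mask = offs_q[:, None] >= (start_kv + offs_kv[None, :])          # True where the key is not in the future
    QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6)  # masked scores pushed towards -inf
    m_ij = tl.maximum(m_i, tl.max(QK_block, 1))                      # new running row maxima
    QK_block -= m_ij[:, None]
else:
    m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale)
    QK_block = QK_block * softmax_scale - m_ij[:, None]
P_block = tl.math.exp(QK_block)  # softmax*: exponentials without the normalization
```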
04:13:20.000 | then we need to compute the sum of the rows for the normalization factor so for the current block we
04:13:29.760 | have the pij block for the current kv block to compute the normalization
04:13:38.960 | factor for the softmax we need to keep summing up these exponentials and later we will fix
04:13:45.200 | the normalization factor that we computed at the previous step but we will do that
04:13:52.080 | later so now we just computed the normalization factor for the current block which is just the
04:13:56.560 | sum of all the values on a single row which is the same as what we did before here as you can see
04:14:04.400 | here when i show you the algorithm so for each block we do the row sum as you can see here
04:14:13.440 | of the p matrix what is the p matrix is the exponential of the s minus m and for now we
04:14:21.280 | didn't apply the the correction to the previous block that's it so we computed the lij for the
04:14:28.240 | current k and v block and then we compute the correction factor for the previous block so the
04:14:34.400 | correction factor for the previous block if you remember the formula from the paper is this one
04:14:39.200 | is the exponential of the previous estimate of the maximum minus the current estimate of the maximum
04:14:44.160 | which is exactly this one so the previous estimate of the maximum minus the current
04:14:47.760 | estimate of the maximum we will see later why mi is the previous estimate of the maximum and
04:14:54.000 | mij is the current estimate of the maximum because it is coming from the current block that we are
04:14:58.640 | computing mi is let's say the one of the previous iteration because
04:15:06.480 | later we will override mi with mij but i'm just following the flash attention algorithm so far
04:15:13.840 | so i am computing the correction factor of the previous li which in the flash attention algorithm
04:15:19.600 | is let me show you this stuff here so it is this stuff here this one here
04:15:27.440 | okay and then we apply the correction factor so we multiply the previous li
04:15:39.440 | by the correction factor and add the current lij which is the one coming from the current p block
04:15:44.480 | the one computed with the current k and v in the current iteration and right now we
04:15:50.320 | are doing this operation so li is equal to the previous li multiplied by the correction factor plus lij
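In code, these statistics updates look like this (a sketch, names assumed as before):

```python
# Sketch (names assumed): per-block row sums, correction factor, running normalizer update.
l_ij = tl.sum(P_block, 1)        # row sums of the current softmax* block
alpha = tl.math.exp(m_i - m_ij)  # correction factor for the statistics of previous iterations
l_i = l_i * alpha + l_ij         # fix the old normalizer and add the new contribution
```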
04:15:55.520 | all right and then what do we need to do as you remember the formula is
04:16:02.240 | we calculate the p block and then we need to multiply by the v block
04:16:07.520 | so we need to load the v block so let's load it
04:16:13.680 | we load the v block based on the pointer of the v block to which the
04:16:20.240 | pointer v is pointing at the beginning of this iteration in case we are in stage number
04:16:26.400 | three so in case we are doing for example non-causal attention this will be pointing to the
04:16:30.400 | first v block and then okay here there is just a type conversion so we make sure this is
04:16:39.760 | in floating point 16 and then we compute the output block so we are computing the following
04:16:48.160 | so we just take p multiplied by v and we add it to the output and this is what we are doing here
04:16:54.960 | we take p we multiply it by v and add it to the o block let's go actually to this line one by one
04:17:02.800 | so first of all we need to fix the previous output block with the correction factor
04:17:07.920 | that we have here so we can fix the previous block with this alpha term here which is the
04:17:13.600 | correction factor for the previous block and so we just fix the previous block for now but we
04:17:19.760 | didn't add the new pv so to add the new pv we do the product of p and v and this third argument
04:17:26.160 | tells this dot function which is actually a matrix multiplication
04:17:31.840 | to use this element here as the accumulator so this is exactly the same as doing
04:17:38.080 | o block plus equal to p block multiplied by the v block
04:17:47.920 | this is just optimized because anyway this dot function here needs some place where to store
04:17:55.040 | the intermediate results so why not just store it where it should actually go and because
04:18:01.600 | a matrix multiplication is made of dot products and a dot product is just a repeated
04:18:07.120 | sum this dot will keep summing the result into this block here which
04:18:16.000 | will give exactly the same result as if we had done the matrix multiplication separately
04:18:22.960 | and added it to the o block so that's why this argument is called the accumulator
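A sketch of this output-block update (names assumed):

```python
# Sketch (names assumed): rescale the old output block and accumulate the new P @ V contribution.
V_block = tl.load(v_block_ptr)
P_block = P_block.to(tl.float16)             # cast so the matmul can use the tensor cores
O_block = O_block * alpha[:, None]           # fix the contribution of previous iterations
O_block = tl.dot(P_block, V_block, O_block)  # O_block += P_block @ V_block (third argument is the accumulator)
```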
04:18:30.080 | okay all right so we have also computed the output and then we save the new estimation of
04:18:37.440 | the maximum for the current iteration and it becomes mi so at the next iteration we can use
04:18:43.440 | it to calculate the correction factor and then we have finished for the current block and then we
04:18:49.760 | can move on to the next block so we advance our k and v pointers by one block of k and v
04:18:57.840 | we advance it differently because we know that the v block is a pointer to a tensor of shape
04:19:05.280 | let me write it here this is a tensor of shape sequence length head dim so we need to increase
04:19:15.040 | the sequence length dimension by one block size kv while the k block is actually the k transpose
04:19:23.200 | block and it is transposed because we have exchanged the strides and the shape
04:19:28.720 | so its shape is
04:19:31.120 | head dimension sequence length so we don't change the head dimension we just advance the sequence
04:19:39.360 | length by block size kv so basically we are just going to point to the next block of k and
04:19:45.920 | to the next block of v
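A sketch of that pointer advance (names assumed; the K block pointer is laid out as K transposed, so its sequence axis is the second one):

```python
# Sketch (names assumed): move to the next K/V block along the sequence dimension.
v_block_ptr = tl.advance(v_block_ptr, (BLOCK_SIZE_KV, 0))  # V block: (SEQ_LEN, HEAD_DIM)
k_block_ptr = tl.advance(k_block_ptr, (0, BLOCK_SIZE_KV))  # K^T block: (HEAD_DIM, SEQ_LEN)
```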
04:19:52.720 | i hope you were able to follow the algorithm of flash attention i tried to use the same names and more or less the same logic always writing the formula that
04:19:57.280 | i'm referring to so hopefully you didn't get lost i think the only difference that there is between
04:20:02.400 | the flash attention algorithm as written on the paper and this code is probably this alpha which
04:20:06.800 | is the correction factor but i hope it's easily understandable anyway then we just return the o
04:20:13.440 | block so o block and li which is the normalization factor for each row in the current output block
04:20:24.480 | which is also a q block because we are working with one q block independently from the other
04:20:30.000 | programs and mi is the maximum value for each row which will be needed for the backward pass
04:20:37.280 | because in the backward pass we will compute the query multiplied by the transpose of the key
04:20:42.480 | block on the fly we need to also apply the softmax but instead of re-computing the stuff which we
04:20:47.680 | already computed during the forward pass we just save them and reuse them during the backward pass
04:20:52.480 | which will save us some computation now it's time to talk about the log-sum-exp trick because
04:20:59.200 | we are going to use it so let's go back to the old method so let's go here all right so we have
04:21:06.640 | computed two calls of this function in case we are working with causal attention when
04:21:12.880 | computing causal attention we call this function once to work with all the key and value blocks
04:21:17.520 | that are to the left side of the diagonal of the query key matrix then we do another call of this
04:21:23.280 | function to work only with those blocks of keys that exactly lie on the diagonal of the query key
04:21:30.800 | matrix because in this case some of the values need to be masked out and some of them do not
04:21:37.520 | need to be masked out moreover by doing this we can avoid computing the dot products for all those
04:21:44.400 | values in the causal case for which the index of the key is higher than
04:21:51.760 | the index of the query saving some computation because anyway they will result after the
04:21:56.560 | softmax in zeros and they will not contribute to the output so it should be faster okay now let's
04:22:03.520 | go back to the this method here so calling method and there is one last thing that we need to do
04:22:08.480 | which is we need to compute the log sum exp and now i will show you what is it so in order for
04:22:16.320 | the backward pass to recompute the softmax without having to recalculate the normalization factor and
04:22:21.760 | the maximum value for each row we should be actually saving two different stuff one is the
04:22:26.640 | maximum for each row in the query block and one is the normalization factor for each query in the
04:22:32.400 | query block however there is a trick and the trick is okay it's not really called log sum exp trick
04:22:38.960 | because the log sum exp trick is used for another purpose but let's call it log sum exp trick number
04:22:44.720 | two so the log sum exp trick number two is something like this so let me open the slides
04:22:52.640 | so when we do query multiply that transpose of the keys we get a matrix that is made up of dot
04:22:59.680 | products so something like this like this is one dot product so let's call it query one transpose
04:23:05.920 | the key one query one transpose the key two this is a query two transpose the key one and this is
04:23:13.200 | a query two transpose the key two then we need to apply the softmax right so the softmax is what
04:23:20.480 | is the let's write the formula of the softmaxes for each of these vectors so this is a vector
04:23:25.040 | and this is a vector because we applied it by rows for each of these vectors this will modify
04:23:30.080 | element wise each element as follows so the softmax of x i is equal to the exponential
04:23:37.040 | of x i minus oh i didn't leave enough space so let's move this stuff here back
04:23:44.000 | and this stuff here to the left all right it will be the exponential of each element
04:23:52.960 | minus the maximum for the current vector to which we are applying the softmax
04:23:57.600 | divided by the normalization factor which is the summation over all possible j's
04:24:05.360 | where n in this case is equal to 2 because we have each vector is made up of two elements
04:24:09.280 | of the exponential of x i minus x max now imagine we already have x max and we already have this
04:24:21.200 | summation in the flash attention algorithm in the forward pass this stuff here is called l i
04:24:26.240 | and this stuff here is called m i what we are going to save in the code you can see here
04:24:32.880 | we are saving actually not m i and l i separately we will be saving m i plus the logarithm of l i
04:24:41.680 | so we are going to save m i plus the log of l i so what will happen is that when we
04:24:54.240 | compute the backward pass we need to recreate this matrix here on the fly which means that
04:24:59.680 | we need to recompute the query multiplied by the transpose of the keys and then we should
04:25:06.080 | apply the softmax to apply the softmax we would need this stuff and this stuff here but we have
04:25:11.200 | only this stuff here so this is the m i plus the logarithm of l i so when we're computing the
04:25:17.040 | softmax we will compute the following so we will compute the softmax as follows we will define
04:25:23.600 | let's call it a new softmax so let me use another color here we will apply the softmax as follows so
04:25:33.680 | softmax of x i let's call it softmax 2 because i don't want to confuse it with the normal softmax it
04:25:46.880 | is equal to the exponential of each element minus we will subtract this value here
04:25:54.400 | the one corresponding to the current row to which we are applying the softmax
04:25:58.400 | so it will be the exponential of x i minus m i minus the log of l i if we expand this expression
04:26:09.040 | this will become the following because the exponential of a sum
04:26:16.800 | is equal to the product of the two exponentials we can also write it like this so it will be
04:26:21.040 | the exponential of x i minus m i divided by the exponential of the log of l i which guess what
04:26:36.800 | it is equal to the exponential of x i minus m i divided by l i which is exactly the normalization
04:26:45.440 | factor and we also have m i so instead of saving two values we save only one value and when we
04:26:51.040 | apply it the exponential's properties will take care of actually also normalizing each value to
04:26:56.080 | which we apply it if you don't remember the properties of the exponential it is very simple
04:27:01.040 | so the exponential of a plus b is equal to the exponential of a multiplied by the exponential
04:27:10.320 | of b and the exponential of
04:27:18.080 | a minus b is equal to the exponential of a divided by the exponential of b
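In symbols, the whole trick is:

```latex
\mathrm{softmax}(x_i)
= \frac{e^{\,x_i - m_i}}{l_i}
= e^{\,x_i - m_i - \log l_i}
= e^{\,x_i - L_i},
\qquad L_i = m_i + \log l_i .
```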
04:27:28.080 | and this is the trick that we're using so that's why we don't need to save two different
04:27:32.000 | values we just need to save one value and then when we apply it the normalization will automatically
04:27:36.800 | be taken care of because of the properties of the exponential all right let's move
04:27:42.560 | forward so we have also created this value that we will use during the backward pass now as you
04:27:51.280 | remember in the flash attention algorithm we don't normalize each block while computing it we normalize
04:27:57.200 | the output at the end and this is exactly what we are going to do here so we normalize the block at
04:28:04.000 | the end after we have computed all the normalization factors that we need for all the rows that belong
04:28:08.560 | to the current output block we save this m i what is this m i it encodes both the normalization
04:28:17.920 | factor and the maximum for each row that we will need for the backward pass so we need to save it
04:28:24.640 | in a tensor that we will use during the backward pass so we need to understand which tensor is this
04:28:29.920 | and it's the tensor that we called m which is a tensor of a batch size num heads and sequence
04:28:35.440 | length dimensions so we need to select the right point in this tensor to select to where we should
04:28:42.160 | save this m i values so we need to select the right batch size index and the right number of head
04:28:49.120 | index so we advance this pointer by the following offset which is m plus the index batch head
04:29:01.520 | because okay the index batch head is the index of the current program that
04:29:09.520 | includes information about which head we are working with and which batch we are working with
04:29:14.640 | because for each batch and for each head we have a sequence length we can skip
04:29:23.440 | a number of sequence lengths based on this index okay what we are doing is basically we are
04:29:32.240 | skipping for each batch and for each head we will have a sequence length because each token in the
04:29:42.800 | sequence has a maximum value and each token in the sequence has a normalization value
04:29:47.520 | so based on the current combination of batch and head we can skip a number of sequence lengths that
04:29:54.240 | other programs will process so because in this tensor we have the sequence length as the last
04:30:01.440 | dimension and we have the combined index of the batch size and the number of heads we can
04:30:08.160 | skip a number of sequence lengths based on the combined index which is given by the program index
04:30:13.760 | number one which is the index batch head that we have here and this is why we skip here a sequence
04:30:20.240 | length number multiplied by the index batch head this m is pointing to the first element of the
04:30:29.600 | entire tensor so we are skipping the heads and the batch based on the combined index index batch head
04:30:37.920 | that this particular program is working with and then we have offs q offs q is there because each
04:30:43.920 | of these kernels the attention forward method will work with one query block
04:30:52.240 | each query block has some indices for the exact queries it includes and this is given by the offs q
04:31:01.600 | variable that you can see here which is how many blocks of queries we need to skip because they
04:31:06.560 | will be processed by other programs plus the range of queries that this
04:31:12.800 | particular block of queries has so imagine this particular program is working with the queries
04:31:20.080 | that go from let's say 12 to 15 then this will be 12 13 14 15 so the normalization factor
04:31:29.280 | and the maximum value for each row we only have that for these indices of queries
04:31:36.960 | so 12 13 14 and 15 and that's why we need to also skip the number of queries that this particular
04:31:42.880 | program works with which is already included in this offs q variable all right so now we
04:31:50.480 | can store the mi because we have the pointer to where it should be saved and we can also
04:31:56.400 | store the output which was computed by our inner for loop and this guys is the forward step of flash attention
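A sketch of this epilogue (names assumed, matching the quantities described above; M is the tensor of shape batch size, num heads, sequence length mentioned earlier):

```python
# Sketch (names assumed): final normalization, the logsumexp statistic, and the stores.
m_i += tl.math.log(l_i)           # save m_i + log(l_i): the single statistic the backward pass needs
O_block = O_block / l_i[:, None]  # apply the softmax normalization only once, at the end
m_ptrs = M + index_batch_head * SEQ_LEN + offs_q
tl.store(m_ptrs, m_i)
tl.store(o_block_ptr, O_block.to(tl.float16))
```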
04:32:03.520 | now we should go forward which is we should compute the backward
04:32:11.600 | pass we also have all the ingredients for computing the backward pass because we have
04:32:15.520 | already seen this trick which is the log sum x trick so we already know what um how to use it
04:32:22.000 | to compute the query key block during the backward pass on the fly what we miss to understand the
04:32:27.840 | backward pass well we need to understand what is the first of all what is the backward pass why do
04:32:32.160 | we even need a backward pass we need to understand what is the autograd of pytorch how does it work
04:32:37.680 | how to compute the gradient what is the gradient what is the jacobian
04:32:43.200 | and when computing the gradient in the backward pass do we even need to compute it so we need to
04:32:47.520 | derive all the formulas of the backward pass by hand so if you are in for the challenge let's
04:32:52.240 | continue all right so now before looking at the flash attentions backward pass at the algorithm
04:32:59.920 | we need to understand why we even need a backward pass and to understand why we even need a backward
04:33:04.320 | pass so before looking at the autograd of pytorch we should be looking at what is what are derivatives
04:33:10.480 | what are gradients what are jacobians so that when we talk about derivatives gradients and jacobians
04:33:14.480 | we don't feel lost so i will do a very fast let's say rehearsal of what these topics are
04:33:22.240 | now what is the derivative when you have a function that takes as input a real value and
04:33:27.840 | outputs a real value we talk about derivatives which is defined as follows the derivative of
04:33:33.680 | the function with respect to its variable x is defined as the limit for a step size that
04:33:39.520 | goes to zero of the function evaluated at x plus h so x plus the step size minus f evaluated at the x
04:33:48.240 | at x divided by the step size so intuitively we are saying it is the ratio of how much the output
04:33:56.160 | changes for how much the input has changed this also gives
04:34:02.400 | you the intuition of why the derivative also tells you the
04:34:11.600 | inclination of the tangent line to the function at the point at which it's evaluated
04:34:16.880 | i will use also the following notation to denote the derivative so the derivative i am used to
04:34:24.960 | write it as like this so f prime of x but it's also possible to write it as a d of f of x with
04:34:30.720 | respect to the x or d of y where y is the output of the function with respect to x
04:34:35.600 | and they are all equal to the same thing which is the definition above if we invert this formula
04:34:41.760 | here and we take h to the left side we can also write the following so if we want to evaluate the
04:34:49.200 | function at the position x plus h we can evaluate it as f prime of x so the derivative
04:35:00.000 | of the function at the point x multiplied by h which is the step size plus f of x this is
04:35:07.280 | actually also how we derive Euler's method for solving differential equations but that's
04:35:12.960 | not the topic of today so this h we can also call it delta x so f of x plus delta x is more or less
04:35:21.200 | because here we have a limit that says when this only happens when h is very very very small so
04:35:25.600 | that's why we put this more or less approximately so f of x plus delta x is more or less equal to
04:35:33.280 | f prime of x multiplied by delta x plus f of x this you can also read it as follows that if by
04:35:41.440 | inverting this formula if x changes by a little amount and this little amount is delta x how much
04:35:52.320 | y will change? y will change by this exact amount which is the derivative of y with respect to x so
04:36:00.240 | dy with respect to dx multiplied by how much x has changed so this dy dx tells us how much
04:36:08.320 | y will change with a small change of x if we multiply it by the actual change of x it will
04:36:14.560 | tell us how exactly y will be affected
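In symbols:

```latex
f'(x) = \lim_{h \to 0} \frac{f(x+h) - f(x)}{h},
\qquad
f(x + \Delta x) \approx f(x) + f'(x)\,\Delta x
\quad\Longleftrightarrow\quad
\Delta y \approx \frac{dy}{dx}\,\Delta x .
```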
04:36:23.520 | i don't want to stay too much on this but i would like to use this intuition to introduce the chain rule because imagine we have a function of a
04:36:31.600 | function so imagine we have z is equal to f of g of x we can think of x being mapped into a variable
04:36:39.600 | y through the function g and then y being mapped into a variable z through the function f if x
04:36:46.880 | changes by a little bit and by a little bit i mean delta x how much y will change? well y will change
04:36:53.200 | by delta y what is delta y? delta y is the derivative of y with respect to x multiplied
04:36:58.400 | by the step size of x but if y changes it will also affect z because there is a direct mapping
04:37:05.920 | between y and z so how much z will change for a small change in y? let's see so if y changes
04:37:13.360 | from the old y by a small step delta y then z will also change by some delta z and this delta z
04:37:21.520 | is the dz on dy multiplied by delta y if we replace this delta y with the delta y that we
04:37:29.840 | have computed in the expression above we arrive to the chain rule it will tell us how z will be
04:37:36.800 | affected so this is delta z what is the effect on z for a small change on x and it's the product
04:37:45.600 | of the two derivatives the one of y with respect to x and the one of z with respect to y
04:37:52.240 | and this is the chain rule that we study in high school so if you want to compute dz on dx
04:37:59.440 | it is dz on dy multiplied by dy on dx
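Written out:

```latex
\Delta y \approx \frac{dy}{dx}\,\Delta x,
\qquad
\Delta z \approx \frac{dz}{dy}\,\Delta y \approx \frac{dz}{dy}\,\frac{dy}{dx}\,\Delta x
\quad\Longrightarrow\quad
\frac{dz}{dx} = \frac{dz}{dy}\,\frac{dy}{dx}.
```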
04:38:07.360 | this is very intuitive if you think about the following example you can think of z as the price of cars and x as the price of oil how much will
04:38:16.880 | a small change in the price of oil affect the price of a car? well this small change in the
04:38:23.200 | price of the oil will affect for example a variable y which could be the price of electricity
04:38:30.000 | so how much will the price of electricity affect the price of a car it's through the
04:38:37.360 | derivative of the price of the car with respect
04:38:42.080 | to the price of electricity so to get the effect of the price of oil on the price of the car we just
04:38:49.280 | multiply the two effects and this is the intuition behind the chain rule anyway let's talk about
04:38:55.200 | gradients so when we have a function that as input takes a vector and produces a scalar we talk not
04:39:02.320 | anymore about derivatives we talk about gradients so imagine we have a function that takes as input
04:39:09.280 | a vector made up of two dimensions but n dimension in general and it produces a scalar when do we
04:39:15.680 | have to deal with this kind of function for example loss functions loss functions are something that
04:39:21.200 | are always a scalar as output and as input they take tensors so for example imagine the cross
04:39:27.760 | entropy loss it will take a sequence of tokens each tokens with its own logics and it will compute
04:39:34.800 | one single number which is the loss so how to view the effect on the output with respect to the input
04:39:44.160 | in this case well if x changes by a little amount and this little amount is not anymore a number but
04:39:50.400 | it's a vector so if x changes then the old x plus delta x is a vector sum and y will also be
04:40:00.400 | affected by what y will be affected by dy on dx multiplied by delta x however this delta x is not
04:40:08.720 | a number anymore it's a vector because x1 may change by a little bit x2 will change by a little
04:40:14.880 | bit x3 will change by a little bit x4 until xn will change by a little bit so this is actually
04:40:20.880 | a dot product of this vector multiplied by this vector why a dot product because y will be affected
04:40:28.880 | by the change in x1 it will be affected by the change in x2 it will be affected by the
04:40:34.880 | change in x3 up to xn and each of the contribution of the contribution of x1 will be the partial
04:40:42.080 | derivative of y with respect to x1 multiplied by how much x1 has changed plus the contribution of
04:40:49.520 | x2 will be the partial derivative of y with respect to x2 multiplied by how much x2 has
04:40:55.280 | changed blah blah blah until the last contribution of xn so and the chain rule in this case also
04:41:02.320 | applies in the same way as in the scalar case so the formula does not change also for the chain
04:41:06.960 | rule here i just want you to to remember to remind you that in this case we are talking about a
04:41:13.280 | gradient and the gradient is just a vector made up of all the partial derivatives of the output
04:41:20.160 | with respect to each of the input variables that are in the input vector
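In symbols, for a scalar output y of n input variables:

```latex
\Delta y \;\approx\; \nabla f(\mathbf{x}) \cdot \Delta\mathbf{x}
\;=\; \frac{\partial y}{\partial x_1}\,\Delta x_1
+ \frac{\partial y}{\partial x_2}\,\Delta x_2
+ \dots
+ \frac{\partial y}{\partial x_n}\,\Delta x_n .
```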
04:41:27.760 | when we talk about a function that has as input a vector and produces a vector then we don't talk about gradient anymore we talk
04:41:34.000 | about jacobians so if our input x the input x of this function changes by a little amount and this
04:41:41.840 | delta x is a vector then the output y will also change and this output y will change by a delta y
04:41:49.520 | that is not a number anymore it is a vector and this vector is the result of this quantity dy on
04:41:56.400 | dx multiplied by delta x delta x is a vector so this one has to be a vector and this one
04:42:04.560 | here has to be a matrix and this matrix is called the jacobian it is a matrix that has as many rows
04:42:11.120 | later we will talk about the notation so it has as many rows as there are output variables
04:42:16.880 | and as many columns as there are input variables the first row will be the partial derivative of
04:42:24.080 | the first output variable with respect to all the input variables the second row will be the
04:42:29.440 | partial derivative of the second output variable with respect to all the input variables and the
04:42:35.200 | last row will be the partial derivatives of the last output variable with respect to all the input
04:42:41.200 | variables in the input vector
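In the numerator layout just described:

```latex
\Delta\mathbf{y} \;\approx\; J\,\Delta\mathbf{x},
\qquad
J = \frac{\partial \mathbf{y}}{\partial \mathbf{x}} =
\begin{pmatrix}
\dfrac{\partial y_1}{\partial x_1} & \cdots & \dfrac{\partial y_1}{\partial x_n}\\
\vdots & \ddots & \vdots\\
\dfrac{\partial y_m}{\partial x_1} & \cdots & \dfrac{\partial y_m}{\partial x_n}
\end{pmatrix}.
```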
04:42:48.640 | now let's talk about the notations the jacobian that i have written here is written according to the numerator layout
04:42:56.240 | this is also called the numerator
04:43:02.000 | convention and there is another convention called the denominator convention or notation
04:43:06.480 | in which the number of rows is not equal to the number of
04:43:15.440 | output variables but equal to the number of input variables so the fact that we choose to
04:43:22.640 | write the jacobian as follows is based on a convention you can also write the jacobian
04:43:29.840 | according to the denominator convention just by transposing this jacobian here and also the formula
04:43:35.200 | for the chain rule changes accordingly for now i want to keep the formula for the chain rule just
04:43:40.480 | like the one for the scalar case so that's why i am using this notation here but later we can change
04:43:46.000 | between one notation to the other just by doing a transposition okay now that we have reviewed what
04:43:52.000 | is derivative what is a gradient and what is a jacobian let's talk about what happens when we
04:43:59.360 | take derivatives with respect to tensors of a tensor with respect to another tensor in this case
04:44:04.400 | we talk about the jacobian but it's called the generalized jacobian so if we have the function
04:44:10.400 | that as input takes a tensor of dx dimensions where this is kind of the shape
04:44:19.120 | of the tensor so the first element of the shape is n1 the second element of the shape of the input
04:44:24.320 | tensor is n2 etc etc until n dx and it produces an output tensor that has this shape so m1 m2
04:44:34.000 | mdy in this case the formula for the chain rule doesn't change and if x changes by a little amount
04:44:44.000 | so by delta x which is a tensor y will also be affected by how much by dy on dx multiplied by
04:44:53.840 | delta x and this is a tensor product it will be a jacobian this is called generalized jacobian
04:45:01.520 | with the following shape so all the dimensions of the output multiplied by all the dimensions of
04:45:07.280 | the input all right this is very abstract for now we will see actually a concrete case of this one
04:45:14.320 | because we will be deriving
04:45:20.240 | the gradient of the loss when computing the backward pass with respect to each of the inputs of the
04:45:26.800 | matrix multiplication operation and we will do it also for the softmax and we will do it also for
04:45:30.800 | the attention so i don't want to jump to too many topics i just wanted us to get into the right
04:45:35.440 | mindset so we know that derivatives when we have scalar functions gradients when the output is a
04:45:40.880 | scalar input is a vector jacobian when the input and output are both vectors generalized jacobian
04:45:46.640 | when the input and output are tensors the chain rule always works in the same way all right let's
04:45:54.560 | talk about autograd i will do the scalar case and then we will extend it to the tensor case
04:46:01.040 | so imagine we have a very simple computation graph why we have computation graph because
04:46:05.120 | we are talking about neural networks and neural networks are nothing more than computation graphs
04:46:09.680 | where we have some input we have some parameters and we do some operations with this input and
04:46:13.440 | parameters suppose that you have an input a and this input a is multiplied by a weight
04:46:19.440 | it's a parameter weight it's just a scalar and it produces an output y1 this y1 is then summed up
04:46:26.800 | with another number called b1 and it produces y2 this y2 is then raised to the power of 2 so this
04:46:33.040 | is e to the power of 2 it's just the power of 2 of the input and it produces y3 and this y3 becomes
04:46:39.520 | our loss function so it's a scalar now what we want to do to apply gradient descent is we want
04:46:47.040 | to compute the gradient of the loss function with respect to each of the input of this computation
04:46:53.920 | graph so each of the leaves of this computation graphs what are the leaves it's this node here
04:46:58.800 | so the parameter nodes and input nodes and to do that there are two ways one is if you have access
04:47:07.440 | to the expression that relates directly the input to the output so the to the loss then you can
04:47:15.760 | directly compute the gradient or rather the derivative in this case because it's not a gradient it's a
04:47:20.640 | scalar with respect to a scalar so in this case imagine you want to compute the derivative of the loss with
04:47:25.760 | respect to w1 imagine we have access to the exact expression that relates w1 to the phi which
04:47:35.200 | is our loss we can compute it as follows so we just derive this expression with respect to w1
04:47:40.960 | which is two times because this is the power of two of a function so it is two multiplied by the
04:47:46.880 | function multiplied by the derivative of the content of this function with respect to the
04:47:51.840 | variable that we are deriving so it will become the following expression there is another way
04:47:57.600 | which is by using the chain rule. So the derivative of phi with respect to w1 is the
04:48:04.800 | derivative of phi with respect to y3 (the output of the last node), multiplied by the
04:48:11.760 | derivative of y3 with respect to y2 (the output of the previous node), multiplied by the
04:48:17.600 | derivative of y2 with respect to y1, and then by the
04:48:22.720 | derivative of y1 with respect to w1. If we do all this chain of multiplications we will obtain the
04:48:29.040 | same result and you can see that here this stuff here is exactly equal to this stuff here by doing
04:48:35.920 | this procedure here we will note something that is i want to zoom out a little bit okay to compute
04:48:44.240 | the derivative of phi with respect to w1 we are doing all this chain of multiplication but what
04:48:53.120 | is each factor in this sequence of multiplications? Well, this stuff
04:49:01.120 | here is nothing more than the derivative of phi with respect to y2, these multiplications here are
04:49:08.080 | nothing more than the derivative of phi with respect to y1, and all of them
04:49:14.960 | combined are the derivative of phi with respect to w1. So what will PyTorch do?
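For reference, here is a minimal sketch (not from the video) of this exact scalar graph in PyTorch, comparing autograd with the chain rule written out by hand:

```python
# Minimal sketch (not from the video): the scalar graph
# y1 = a * w1, y2 = y1 + b1, phi = y3 = y2**2, and the chain rule by hand.
import torch

a  = torch.tensor(3.0)
w1 = torch.tensor(0.5, requires_grad=True)
b1 = torch.tensor(1.0)

y1 = a * w1
y2 = y1 + b1
phi = y2 ** 2          # the "loss"
phi.backward()

# chain rule by hand: dphi/dw1 = (dphi/dy3) * (dy3/dy2) * (dy2/dy1) * (dy1/dw1)
#                              =      1     *   2*y2    *     1     *    a
manual = 2.0 * (a * w1.detach() + b1) * a
print(w1.grad, manual)  # both tensor(15.)
```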
04:49:22.560 | pytorch will do the backward pass because pytorch knows what is the computation graph that relates
04:49:30.000 | the output so the loss function in this case and the variable for which we want to compute the
04:49:36.080 | gradient right now we are talking about derivatives so it's not gradient but the mechanism is exactly
04:49:42.080 | the same. So PyTorch is like a person that knocks at the door of this
04:49:52.400 | operation and says: hey, power-of-two operation, if I give you the gradient of the
04:50:03.360 | loss with respect to y3 which is one because the loss and y3 are actually the same can you give me
04:50:09.840 | the gradient of the loss with respect to y2 because pytorch actually does not implement
04:50:15.440 | an autograd system in the sense that it does not know the symbolic operations that led to the
04:50:21.280 | output; it just knows what are the functions that computed the output, and each
04:50:27.760 | function is a class in Python that implements two methods: one is the forward step
04:50:33.760 | and one is the backward step the forward step takes the input so in this case y2 and computes
04:50:38.560 | the output y3 the backward step will take the gradient of the loss with respect to its output
04:50:45.920 | and needs to compute the gradient of the loss with respect to its input how can we do that well
04:50:51.680 | it's very simple. Let me copy this expression and paste it here, so it's easier to follow.
04:50:58.560 | PyTorch will knock at
04:51:06.320 | the door of this function here and will say: hey, if I give you the gradient of the loss
04:51:14.480 | function with respect to your output can you give me the gradient of the loss function with respect
04:51:19.200 | to your input yes the function can do it why because of the chain rule this operator here
04:51:25.120 | this function here can just do take the loss the gradient of the loss function with respect to its
04:51:30.480 | output multiply it by the jacobian or in this case the derivative of its output with respect to its
04:51:38.640 | input and it will be equal to the gradient of the loss with respect to its input then pytorch will
04:51:44.080 | take this one and knock the door at the next operator which is this one this summation
04:51:48.400 | and we'll say hey if i give you the gradient of the loss with respect to your output can you give
04:51:54.720 | me the gradient of the loss with respect to your input yes this operator can do it because this
04:51:59.920 | operator just needs to apply the chain rule so it will take the gradient of the loss with respect
04:52:03.920 | to y2, which is provided by PyTorch, and by multiplying it with the Jacobian (in this
04:52:11.600 | case it's just a derivative) of its output with respect to its input, it can compute
04:52:16.480 | the gradient of the loss with respect to its input. Then PyTorch will take the output of this
04:52:23.360 | backward pass and will knock the door of the next operator which is this product and will ask again
04:52:28.800 | the same question hey if i give you the gradient of the loss with respect to your output can you
04:52:34.080 | give me the gradient of the loss with respect to your input yes this will do the same exact job it
04:52:39.200 | will take the gradient of the loss with respect to the output multiplied by the jacobian of the
04:52:44.400 | output with respect to the input and obtain the gradient of the loss with respect to the
04:52:48.640 | input and this is how pytorch runs the backward step it runs one operator at a time backwards
04:52:56.800 | in the computation graph knocking the door of each operator and asking always the same question if i
04:53:03.040 | give you the output the gradient of the loss with respect to your output can you give me the gradient
04:53:07.680 | of the loss with respect to your input and each operator will just apply the chain rule to
04:53:12.240 | to to get this to get this gradient to calculate this gradient that pytorch needs
04:53:17.520 | why pytorch cannot do it by itself because pytorch does not do symbolic mathematics it does not have
04:53:25.200 | access to the exact expression that each function is computing it just uses the function as a black
04:53:30.800 | box that computes forward and backward however with the jacobian we have a problem and let's
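This forward/backward contract is what torch.autograd.Function encodes; here is a minimal sketch (not from the video) for the power-of-two node of our toy graph:

```python
# Minimal sketch (not from the video) of the "knock at the door" contract:
# backward receives dL/d(output) (the upstream gradient) and must return
# dL/d(input) by applying the chain rule, without any symbolic math.
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y2):
        ctx.save_for_backward(y2)
        return y2 ** 2                   # y3 = y2^2

    @staticmethod
    def backward(ctx, grad_y3):          # grad_y3 = dL/dy3 (upstream gradient)
        (y2,) = ctx.saved_tensors
        return grad_y3 * 2 * y2          # dL/dy2 = dL/dy3 * dy3/dy2

y2 = torch.tensor(2.5, requires_grad=True)
loss = Square.apply(y2)
loss.backward()
print(y2.grad)  # tensor(5.) = 2 * y2
```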
04:53:37.280 | However, with the Jacobian we have a problem; let's see what it is. So far we have been working with a computation graph that is
04:53:43.680 | made up of scalars but the things that we have said they work in the scalar case but also in
04:53:48.640 | the tensor case so let's go back see what is our computation graph we have seen that pytorch will
04:53:55.600 | go operator by operator asking always the same question if i give you the gradient of the loss
04:54:00.400 | with respect to your output can you compute me the gradient of the loss with respect to your input
04:54:05.120 | and each operator can just apply the chain rule to compute that imagine now that all of these
04:54:11.840 | operators are working not with scalars but are working with tensors which means that the derivative
04:54:17.600 | of the output with respect to the input of each operator is not a derivative it will be a jacobian
04:54:23.600 | because the output will be a tensor a generalized jacobian and input will be a tensor which means
04:54:30.000 | also that this quantity here so the derivative of the loss with respect to the input in this case
04:54:35.120 | will not be a derivative, it will be a gradient, because the output (the loss) is always a scalar,
04:54:39.920 | while the input, in this case y1, will be a tensor; when the output is a scalar and the input is a tensor, we talk
04:54:49.120 | about gradients. So this will be a gradient, and we will call it the downstream gradient that the
04:54:54.880 | operator needs to compute this will be the upstream gradient that pytorch will give to the
04:54:59.920 | each of these operators so the gradient of the loss with respect to the output of each operator
04:55:05.520 | and each operator needs to come up with this downstream gradient by using the jacobian
04:55:11.760 | however the jacobian has a problem let's see so imagine we are implementing a simple operation
04:55:18.640 | that is the matrix multiplication. The matrix multiplication takes as input an x tensor,
04:55:26.000 | it multiplies it by a w matrix made up of parameters and produces a y matrix as output
04:55:31.920 | suppose that x is let's call it n by d matrix w is let's say d by m matrix and so y will be a
04:55:45.920 | n by m matrix. Usually the input x is a sequence of, let's say, vectors, each with
04:55:58.880 | d dimensions so you can think of it as a sequence of tokens each token is a vector made up of d
04:56:04.880 | dimensions usually we have many tokens so suppose that n usually is at least 1024 at least in the
04:56:13.440 | most recent language models we even have millions of tokens actually so and d is also actually quite
04:56:20.880 | big: it is usually at least 1024 as well, so let's say this one is also 1024, and m is also at least
04:56:31.600 | 1024, so let's actually make it 2048 (I like powers of two, by the way). So the problem
04:56:39.920 | of the Jacobian is this: if we want to compute this downstream gradient by multiplying
04:56:45.040 | the upstream gradient with the jacobian this jacobian matrix is huge because look at the
04:56:51.520 | dimensions here: it will be a
04:57:02.400 | generalized Jacobian, so it will be a tensor whose shape is the shape of the output, n by m, followed by the shape of the input x, so
04:57:10.800 | n by d. How many elements will it have? Well, it will have n, which is 1024, multiplied by m, which is 2048,
04:57:20.320 | multiplied by n again, 1024, multiplied by d, which is 1024: that is more than two trillion elements.
04:57:31.200 | So it is impossible to materialize this matrix in the memory of the GPU,
04:57:39.440 | because it would be far too big.
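A quick back-of-the-envelope check (not from the video) with these example sizes:

```python
# Size of the generalized Jacobian of Y = X @ W for N = 1024, D = 1024, M = 2048:
# its shape is (N, M, N, D).
N, D, M = 1024, 1024, 2048
elements = N * M * N * D
print(f"{elements:,} elements")                        # 2,199,023,255,552
print(f"{elements * 4 / 1024**4:.1f} TiB in float32")  # 8.0 TiB
```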
04:57:46.640 | But we need to compute this downstream gradient, because PyTorch needs it to continue calculating the gradient of the loss function
04:57:51.760 | with respect to each of the nodes in the computation graph so how can we proceed
04:57:56.640 | the first thing that we should notice is that this this jacobian is actually a sparse matrix
04:58:04.080 | and i want to show you why it is actually is a super super super sparse matrix because if you
04:58:10.240 | look at the input what is the effect of the input on the output the input is a sequence of tokens
04:58:17.120 | so this is the token number one it's a vector of some dimensions 1024 dimension then we have
04:58:24.320 | another token as input then we have another tokens as input then we have another tokens
04:58:29.280 | as input and we multiply by the w matrix which is made up of some columns some columns so this
04:58:37.680 | one is n by d right yes and w is d by m so d by m this will produce a matrix that is n by m
04:58:50.080 | so it will be also a sequence of tokens each made up of m dimensions so it will be a matrix like
04:58:56.480 | this so this will be the first output token this will be the second output token this will be the
04:59:02.560 | third output token and this will be the fourth output token now this output row here is the dot
04:59:11.200 | product of this input row with all the columns so the derivative of each of these dimensions
04:59:19.280 | with respect to the dimensions of all the other tokens will be zero because they do not contribute
04:59:24.320 | to this output. So the Jacobian will have zeros every time we are calculating the derivative
04:59:32.960 | of this first dimension with respect to any element of the other tokens. That's why we can always
04:59:40.960 | come up with a better formula for computing this downstream gradient that does not involve the
04:59:46.000 | materialization of the Jacobian, because the Jacobian itself is sparse. So let's see how we
04:59:51.920 | can optimize this computation without materializing the jacobian in the case of matrix multiplication
04:59:57.280 | because we need it for flash attention. All right guys, so before proceeding to look at
05:00:04.800 | the formulas of the backward pass of flash attention, let's look at how to compute the gradient
05:00:10.640 | of the matrix multiplication operation with respect to its inputs.
05:00:15.920 | Now, PyTorch already knows how to compute the gradient of the inputs of a plain
05:00:23.040 | matrix multiplication, i.e. the gradient of the loss with respect to the inputs of the matrix
05:00:26.720 | multiplication operation, but in flash attention we are creating a custom kernel, which means that
05:00:32.480 | the custom kernel is fusing multiple operations into one operation so when pytorch will knock
05:00:39.360 | the door of our operator it will ask the our operator which is the triton attention operator
05:00:45.280 | that we have built what is the gradient of the loss function with respect to q k and v because
05:00:50.080 | that's the input of our function so if we look at the code that we have built so far you can see
05:00:55.280 | that our triton attention will be a node in the computation graph that takes as input q k and v
05:01:02.560 | and produces an output then pytorch will give us the gradient of the loss with respect to that
05:01:09.040 | output, so it will give us dO, the gradient of the loss with
05:01:14.160 | respect to O, and then it will ask this class here, Triton attention, to compute the gradient
05:01:21.440 | of the loss with respect to Q, K and V. Because we are fusing multiple operations together, so we are
05:01:26.720 | computing on the fly the softmax of the query multiplied by the transpose of the key, and then multiplying
05:01:32.000 | the result of the softmax by V to compute the output, we need to compute these
05:01:38.640 | gradients internally. So, because among the operations that
05:01:45.280 | we are fusing together there is a matrix multiplication, we need to derive by hand
05:01:50.240 | the gradient of the loss function with respect to the inputs of the
05:01:56.480 | matrix multiplication operation so that we can provide it to pytorch that's why we need to derive
05:02:03.760 | this formula. I will derive it in a very simple way, and then we will do it for
05:02:12.000 | the softmax as well because these are the two things that we need to derive by hand to derive
05:02:15.840 | the formula of the flash attention's backward pass so let's start imagine we have a computation
05:02:23.680 | graph a node in the computation graph called the matrix multiplication and this node in the
05:02:29.200 | computation graph is doing a matrix multiplication so it is computing the following operation
05:02:33.360 | y is equal to x multiplied by w now what pytorch will give us as input when computing the backward
05:02:43.600 | pass of this node pytorch will give us the gradient of the loss so it will give us d phi
05:02:49.840 | with respect to dy so the output of this node and will ask us to compute the gradient of the
05:02:58.400 | loss function so the gradient of the loss function with respect to dx and the gradient of the loss
05:03:04.400 | function with respect to dw the easiest one to work with and the one that i will be showing and
05:03:10.560 | the other one i will not show in the video but i will attach the pdf slide on how it is computed
05:03:14.960 | because they are very similar in the way they are computed so i don't want to make the video
05:03:19.440 | too long for unnecessary reasons let's compute the gradient of the loss function with respect
05:03:28.240 | to the input so with respect to x all right so how to do that by hand without materializing the
05:03:37.120 | jacobian because as we have seen we cannot just use the chain rule by materializing the jacobian
05:03:42.160 | which would be the easiest way because the jacobian is very big matrix that cannot even
05:03:46.960 | fit in the memory of the gpu so we need to find a smarter way we exploit the fact that the jacobian
05:03:53.120 | is sparse, so hopefully we will get a formula that does not involve the materialization of a very
05:03:58.560 | big, sparse Jacobian. Let's see. When dealing with these kinds of derivations
05:04:07.360 | I always recommend making some example tensors. So suppose that x is a tensor of size, let's say,
05:04:15.920 | n by d, where n is equal to one and d is equal to, let's say, three, and
05:04:25.200 | w is a tensor also or a matrix with the shape let's say d by m
05:04:36.800 | where m is equal to let's say four and y will have as a consequence the shape n by m
05:04:47.440 | so it will have the shape well one by four what pytorch will give us and pytorch will give us
05:04:58.320 | the following quantity so it will give us this stuff here so the gradient of the loss function
05:05:03.840 | with respect to the output of this operator which is y so it will give us a vector or a tensor
05:05:10.240 | actually with the following dimension which is n by m and we need to compute the gradient of the
05:05:19.280 | loss function with respect to x which should be a tensor of shape n by d because when dealing
05:05:26.240 | with the gradient it always has the shape of the input variable because it's the output which is a
05:05:33.200 | scalar with respect to each element in the input so it has the same shape as the denominator
05:05:38.160 | all right so when dealing with this kind of problems i always recommend to create example
05:05:44.160 | matrices and then work out what happens to the output and then try to work out the
05:05:49.280 | the gradient matrix so let's do it so let's see that what is how is the output computed well
05:05:57.440 | the output will be a matrix that is a one by four computed as follows it will be the input
05:06:05.920 | so one by three so let's call the input x one one x one two x one three it will be multiplied
05:06:16.160 | by another matrix w that it has dimension three by four so it will be three rows by four columns
05:06:25.280 | so it will be w 1 1 w 1 2 w 1 3 w 1 4 then w 2 1 w 2 2 w 2 3 w 2 4 w 3 1 w 3 2 w 3 3 w 3 4
05:06:48.560 | if we do this matrix multiplication it will be well it will produce the following matrix
05:06:55.040 | that is okay this is one row by three columns this is three column three rows by four columns
05:07:01.040 | so the output will be a matrix that is one by four so one row by four columns so it will be
05:07:08.560 | let me write it with a smaller because otherwise it will never fit here so
05:07:15.520 | let's do it like this it will be x 1 1 multiplied by w 1 1 plus x 1 2 multiplied by w 2 1 plus x
05:07:29.280 | 1 3 multiplied by w 3 1 and this will be the first element of the output the second element
05:07:37.120 | of the output will be x 1 1 with w 1 2, plus x 1 2 with w 2 2,
05:07:53.360 | plus x 1 3 with w 3 2 this will be the second element of the output matrix the third element
05:08:04.800 | of the output matrix will be let me move this stuff on the left otherwise it will never fit
05:08:08.720 | so okay, I think now it can fit. This will be x 1 1 with w 1 3
05:08:20.720 | plus x 1 2 with w 2 3 plus x 1 3 with w 3 3, and then we multiply the same
05:08:35.440 | row with the last column so it will be x 1 1 w 1 4 plus x 1 2 w 2 4 plus x 1 3 w 3 4
05:08:49.760 | this will be the output y if we do the matrix multiplication what pytorch will give us
05:08:56.880 | it will give us the gradient of the loss so it will give us delta phi with respect to delta y
05:09:04.000 | because it's a gradient it has the same shape as the denominator so it has a shape that is 1 by 4
05:09:10.640 | let's call it because we don't know what this value will be they will be provided to us by
05:09:15.600 | pytorch let's just give them generic name like d y 1 1 d y 1 2 d y 1 3 and d y 1 4 like this
05:09:28.880 | now to compute the the downstream gradient that we need to provide to pytorch we should
05:09:37.360 | be computing the we should be materializing the jacobian which is which is okay let's write the
05:09:46.560 | chain the chain rule formula so we need to provide delta phi to with respect to delta x which is
05:09:54.880 | equal to delta phi with respect to delta y this is provided by pytorch multiplied by the jacobian
05:10:02.800 | which is delta y with respect to delta x now instead of materializing this jacobian let's try
05:10:10.320 | to do this let's materialize it now and let's do the multiplication of these two quantities
05:10:16.640 | to see if something simplifies so this stuff here will be dy with respect to dx which means the
05:10:24.400 | derivative of every output y with respect to every input x how many output we have we have
05:10:32.080 | four elements as the output which is this stuff here and we have three element as input in the x
05:10:39.280 | matrix. So it will be as follows (let me copy it, because my screen is
05:10:47.600 | not big enough, and remember that x is x 1 1, x 1 2, x 1 3). So delta y with respect to delta x
05:10:58.400 | will have the following entries so the y1 with respect to x 1 1 and as you can see y1 only has
05:11:09.760 | one x 1 1 appearing as multiplied by w 1 1 so the derivative with respect to x 1 1 will be w 1 1
05:11:16.560 | then y 1 1 so this stuff with respect to x 1 2 it will be w 2 1
05:11:27.360 | then y 1 1 with respect to x 1 3 will be w 3 1. The second row of this matrix will be the
05:11:37.920 | partial derivative of the second output, so y 1 2, with respect to all the x
05:11:45.680 | inputs which will be the derivative partial derivatives of this stuff here with respect to
05:11:51.520 | every x which is w 1 2 w 2 2 i guess and w 3 2 now let me check if it's what i'm doing is correct
05:12:05.200 | yes because i've already done it so i can always double check and then we have w the partial
05:12:13.840 | derivatives of this stuff here with respect to all the x which is w 1 3 w 2 3 and w 3 3
05:12:25.280 | then the partial derivatives of the last output so y 4 with respect to all the x which will be
05:12:33.600 | w 1 4, w 2 4 and w 3 4. We obtain the following Jacobian, and this Jacobian, as you can see,
05:12:50.160 | is just equal to w transposed so we don't need to materialize the jacobian we can just do the
05:12:57.520 | multiplication of whatever gradient pytorch is giving us multiply it by w transposed and
05:13:05.120 | we will get the downstream gradient. So let me rewrite it, so we know what we are
05:13:10.720 | doing: d phi on dx is equal to d phi on dy multiplied by dy on dx, but we have
05:13:23.760 | seen that dy on dx is just equal to w transposed, so this is equal to d phi on dy multiplied by
05:13:33.840 | w transposed and this gives us the downstream gradient so in order to provide the downstream
05:13:38.960 | gradient that pytorch need we just need to take whatever gradient pytorch will give us multiplied
05:13:43.280 | by w transposed and it will give us the gradient of the loss function with respect to the input
05:13:48.880 | x of the matrix multiplication in the same way we can also write the formula for the gradient
05:13:55.520 | of the loss function with respect to w, and it is equal to x transposed multiplied by d phi
05:14:03.360 | with respect to dy. How to remember these formulas? There is a mnemonic rule,
05:14:14.720 | which is these are the only possible ways for this to have the shape of x and this to have the
05:14:24.480 | shape of w because this one's this stuff here will have the same shape of y so it will be
05:14:30.880 | n by m this stuff here will have shape of w transposed w is d by m so w transpose should be
05:14:42.400 | m by d and the resulting operation of this matrix multiplication or tensor
05:14:48.960 | multiplication will be n by d which is exactly the same shape as x
05:14:56.320 | in this case we will have that x transposed is the transpose of x, and since x is n by d, x transposed is d by n,
05:15:09.360 | multiplied by d phi with respect to dy which is a gradient so it has the same shape as the
05:15:15.040 | denominator so it has n by m and the output will have shape d by m which is exactly the
05:15:29.520 | shape of w. So, to remember them: this is the only way these shapes work out,
05:15:37.280 | otherwise they don't. So this is a mnemonic rule on how to remember how to
05:15:40.880 | compute the gradient of the inputs of a matrix multiplication given the gradient of the loss
05:15:46.160 | with respect to the output of the matrix multiplication and the inputs to the
05:15:49.600 | matrix multiplication are the input matrix x and the parameter matrix w.
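Here is a minimal sketch (not from the video) that checks both formulas against PyTorch's own autograd on a tiny example:

```python
# Minimal sketch (not from the video): for Y = X @ W, check that
# dPhi/dX = dPhi/dY @ W^T and dPhi/dW = X^T @ dPhi/dY.
import torch

N, D, M = 2, 3, 4
X = torch.randn(N, D, requires_grad=True)
W = torch.randn(D, M, requires_grad=True)

Y = X @ W
phi = Y.sum()            # any scalar loss will do
phi.backward()

dY = torch.ones(N, M)    # dphi/dY when phi = Y.sum()
print(torch.allclose(X.grad, dY @ W.T))  # True
print(torch.allclose(W.grad, X.T @ dY))  # True
```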
05:15:55.600 | Now we need to derive the gradient of the output of the softmax with respect to the input of the softmax, because that's another
05:16:01.360 | operation that we do in our fused attention because we are fusing many operations together
05:16:05.280 | which is matrix multiplication and the softmax so this is the second ingredient that we need
05:16:09.680 | to understand the backward pass of flash attention so let's do it i will use to make
05:16:15.440 | this derivation i will use the same notation as in the flash attention paper so first of all
05:16:21.040 | let's write the title of this stuff which is the gradient through the softmax
05:16:33.840 | the first operation that we do in during computation of the attention is we compute
05:16:44.400 | the product of the query multiplied by the transpose of the keys we do in a blockwise
05:16:48.320 | way it means that we do it block by block but it doesn't matter because the end result is the same
05:16:52.800 | so we can also we can write s equal to q multiplied by the transpose of the keys
05:16:58.240 | and then we apply the softmax to this operation to the result of this operation and we call this
05:17:05.040 | output p, which is the softmax of s. After we have applied the softmax, we take the output
05:17:15.440 | of the softmax we multiply it by v to obtain the output so the output is equal to p multiplied by v
05:17:21.520 | now we need to understand how to because as i said before pytorch autograd works in the
05:17:31.280 | following way pytorch will treat our attention computation as a black box so we will have a
05:17:37.040 | computation graph like the following we will have a query input a key input and a value input which
05:17:44.560 | are sequences of tokens each one with some embedding dimension these are fed to some
05:17:50.640 | black box called the attention which is our implementation of the attention which is the
05:17:56.800 | function that we started coding before this will be fed as input to this node in the computation
05:18:03.200 | graph and the computation graph will output a an output tensor o what pytorch will give us
05:18:10.480 | pytorch will give us the gradient of the loss with respect to the output so as you remember
05:18:17.680 | pytorch knocks the door knocks the door at each operator and says if i give you the gradient
05:18:23.920 | of the loss with respect to your output can you give me the gradient of the loss with respect to
05:18:28.560 | your inputs and this is what we need to figure out so given the gradient of the loss with respect
05:18:34.400 | to the output we need to understand how to compute the gradient of the loss with respect to the
05:18:40.400 | query, the gradient of the loss with respect to the key, and the gradient of the loss with respect to the value. However,
05:18:48.480 | there is no direct connection between q and o or k and o because there are two intermediate
05:18:54.560 | operations so one there is a first a matrix multiplication then there is a softmax then
05:18:59.120 | there is an additional matrix multiplication however we have tools that allow us to understand
05:19:04.080 | how the gradient propagates through multiple operations when they are applied in sequence
05:19:08.880 | and that's called the chain rule however we have seen that applying the chain rule in its
05:19:14.000 | naive way by materializing the jacobian is infeasible so we need to understand
05:19:19.120 | how to apply the chain rule without materializing the jacobian and that's what we are going to
05:19:24.480 | figure out for one of the operations inside of this attention computation which is the softmax
05:19:29.840 | and that's why we are going to do this derivation which i promise is the last one that we will do
05:19:34.320 | and then we will finally go to code the backward pass of flash attention we cannot proceed directly
05:19:40.000 | to coding the backward pass of the flash attention because if we look at the formulas on how it is
05:19:43.680 | computed we will not understand how the the derivation comes out okay now we can start
05:19:51.440 | so let me delete this stuff delete and imagine for simplicity now we apply the softmax to a
05:20:01.680 | row wise to this s matrix so each row is softmaxed independently from the others
05:20:08.640 | so let's see what happens to one single row of this matrix and for simplicity i will call it s
05:20:17.040 | so s is a single row of the S matrix. I could also call it s of i, but then we will
05:20:26.160 | have to carry over the index... okay guys, let's just do it, we will carry over the index. All right, so
05:20:33.280 | let's call s_i one row of the S matrix. So s_i is, in
05:20:40.720 | PyTorch tensor notation, the following: from the tensor S we take the
05:20:47.840 | ith row and all the columns this is the definition of si i know it's very ugly notation but it helps
05:20:53.440 | you understand and this is a vector of size and dimensions we apply the softmax to this
05:21:02.000 | vector and we will obtain an output vector and we call it pi pi is equal to the softmax softmax
05:21:12.800 | of si so as we have seen the softmax operation does not change the shape of the input it just
05:21:20.240 | changes each number element-wise. So the output will also be a vector in R^n.
05:21:28.240 | now what is the softmax so the softmax is defined as follows the softmax
05:21:38.480 | output p_ij, so the j-th element of the p_i vector: it is equal to the exponential of the j-th
05:21:51.760 | element of the s_i vector, divided by a normalization factor, which is the sum
05:22:01.600 | over an index (let's not use j or k, let's use l)
05:22:08.160 | going from 1 up to n of e to the power of s_il. All right, so first of all you may be wondering:
05:22:21.840 | the softmax that we are that we apply during the forward pass of the computation of the attention
05:22:28.080 | is not really this softmax because in if you remember what we applied before we were applying
05:22:33.520 | the softmax where each of the argument of the exponential is reduced by the maximum element
05:22:39.760 | in the vector to which we apply the softmax so it was more or less like this so s i j minus s i max
05:22:48.240 | so the maximum element of the s_i vector, and the argument of the denominator was also
05:22:55.760 | reduced by s i max however we also proved that this stuff here is equivalent to the standard
05:23:05.680 | softmax without this reduction in the argument because this reduction in the argument is only
05:23:11.760 | added because we want to make it numerically safe to compute; it's equivalent
05:23:18.080 | to do it without. On the computer, of course, the version without the subtraction can become
05:23:23.920 | numerically unstable, but from a mathematical point of view it is the same thing. Which also
05:23:29.600 | means that it doesn't matter how you compute the forward pass: if it's equivalent to another
05:23:36.320 | mathematical definition you can always use the other mathematical definition to compute the
05:23:40.160 | backward pass it will result in the same value if you didn't understand what i said let me give you
05:23:46.080 | a simpler example. Do you remember the formula from high school,
05:23:54.320 | the one that says cosine squared of x plus sine squared of x is equal to one? Now imagine
05:24:04.480 | we compute an output y is equal to cosine squared of x and then we need to compute the derivative
05:24:14.000 | of y with respect to x it doesn't matter if you compute it as the derivative of cosine squared
05:24:24.880 | of x with respect to x or if you compute it as the derivative of one minus sine squared of x
05:24:37.440 | with respect to x because they will result in exactly the same result because the two definitions
05:24:44.320 | are equivalent and this is why we don't need to add this this factor in the exponential
05:24:49.360 | because the two definitions are equivalent mathematically we just use the numerically
05:24:54.560 | safe one because, when computing on a computer, we need something that is numerically
05:25:00.080 | stable, that will not overflow.
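A quick numeric check of this equivalence (not from the video):

```python
# Subtracting the max inside the exponentials does not change the softmax,
# it only makes the computation numerically safe.
import torch

s = torch.randn(8)
naive = torch.exp(s) / torch.exp(s).sum()
safe  = torch.exp(s - s.max()) / torch.exp(s - s.max()).sum()
print(torch.allclose(naive, safe))  # True
```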
05:25:07.280 | All right, now what do we want to obtain? We want to obtain the gradient of the loss with respect to the input vector of the softmax,
05:25:16.080 | which is the s_i vector given the gradient of the loss with respect to the output of the softmax
05:25:24.160 | which is the p_i vector and we can obtain that with the chain rule multiply that by the jacobian
05:25:33.200 | of p_i with respect to s_i. The chain rule is always valid; let's see what this Jacobian
05:25:47.520 | looks like. All right, so this Jacobian will be d p_i with respect to d s_i.
05:26:00.800 | Let's look at what each element in this Jacobian will look like: the j-th element with respect to
05:26:08.240 | the, let's say, k-th element. So we are looking at what each
05:26:21.840 | element in this Jacobian will look like, which is each element in the
05:26:28.320 | numerator of the Jacobian derived with respect to each element in the denominator of the
05:26:33.840 | Jacobian, in this fraction here. So we are saying: each element in the output vector, derived
05:26:42.480 | with respect to each element in the input vector, is what we are writing here. So how
05:26:48.880 | is the output vector obtained well p_ij we know that it is equal to by the definition of the
05:26:55.040 | softmax: e to the power of s_ij, divided by the normalization factor, which is the sum
05:27:07.120 | for l equal to 1 up to n of e to the power of s_il; and all of this is derived with respect to s_ik.
05:27:24.160 | So, what we are trying to do: suppose the p vector is a vector with
05:27:32.400 | three elements, so p_11, p_12 and p_13; the s vector will be a vector
05:27:45.440 | also with three elements, so s_11, s_12 and s_13. What we are trying to
05:27:54.400 | calculate is what the Jacobian will be: the first row of the Jacobian is the derivative of this one with respect to all
05:27:58.480 | the input vector, then the second row of the Jacobian will be the derivative of this one with
05:28:03.600 | respect to each of the input elements, then the third row of the Jacobian will be this stuff here
05:28:09.200 | derived with respect to each of the input elements of the s vector. We are trying
05:28:14.720 | to understand what the generic element in this Jacobian looks like, based on the j-th element
05:28:20.800 | of the output vector (so the j index refers to the output vector) and the k-th element in the input
05:28:26.000 | vector. All right, so what can happen when we compute this Jacobian? This entry here is the
05:28:37.200 | derivative of a fraction of two functions and we know from high school that the derivative of the
05:28:44.400 | fraction of two functions is as follows: the derivative
05:28:52.160 | of f of x divided by g of x, with respect to x, is equal to
05:29:04.240 | f prime of x multiplied by g of x, minus g prime of x multiplied by f of x,
05:29:18.880 | all divided by g of x to the power of two. Now let's apply it here. So this will
05:29:28.160 | become here we will have two cases either the variable that we are deriving with respect to
05:29:33.840 | so this sik has the same index as the variable being derived so either we are doing a p11
05:29:41.840 | with respect to s11 or we are doing a p11 with respect to something else that has not the same
05:29:47.440 | index so like p11 with respect to s12 or s13 so there are two cases that we need to consider
05:29:53.280 | suppose that we are deriving p11 with respect to s11 or we are deriving p12 with respect to s12
05:29:59.360 | or we are deriving p13 with respect to s13 so we are deriving the element of the output with
05:30:05.360 | respect to the same element in the input with the same index so in this case the this this
05:30:14.880 | derivative will look like the following so it's the derivative of f so the numerator
05:30:21.920 | with respect to the denominator that has the same index so we are saying that in this case
05:30:27.360 | j is equal to k. So the derivative of the numerator, e to the power of s_ij,
05:30:40.720 | with respect to s_ij will be e to the power of s_ij, because the derivative of e to the power of x1 with respect to
05:30:48.960 | x1 is e to the power of x1. So this is equal to (I am reducing the font size now)
05:30:54.480 | e to the power of sij then we need to multiply that by the denominator of the fraction
05:31:05.360 | which is this summation here so the summation over all possible l of e to the power of sil
05:31:14.320 | minus the derivative of the denominator with respect to the variable being derived so this
05:31:24.000 | denominator is the sum of all the exponentials of all the input elements if we derive it with
05:31:30.400 | respect to one particular input element, there will be exactly one term that contains that
05:31:35.040 | input element, and all the other terms will result in zero. So the only
05:31:40.400 | derivative that will survive will be the e to the power of sik with respect to sik
05:31:45.520 | so we write minus e to the power of sik
05:31:57.280 | multiplied by the numerator which is e to the power of sij
05:32:00.480 | all this divided by the denominator to the power of two which is this summation here so l equal to
05:32:12.080 | one up to n e to the power of sil all to the power of two and this stuff here will be equal to well
05:32:22.400 | we can see that these two terms, this one and this one, have a factor in common, which
05:32:28.480 | is e to the power of s_ij, so we can collect it: e to the power of s_ij multiplied by the summation
05:32:38.320 | minus e to the power of sik
05:32:51.040 | all this divided by the denominator which is the power of two of this stuff here so let me just
05:32:57.440 | copy and paste it here (and let me rotate it, because I always write it too small).
05:33:02.000 | All right, and this stuff here is equal to... well, we can separate the two terms, so
05:33:14.000 | we can separate this term here and this term here because the denominator is to the power of two
05:33:20.560 | so we can write it also as e to the power of sij divided by the denominator so which is
05:33:28.960 | summation of l equal one to n e to the power of sil multiplied by this stuff here so this stuff
05:33:40.960 | here divided by the same denominator so there's summation of l equal one up to n
05:33:49.520 | e to the power of s_il, minus e to the power of s_ik divided by the same denominator.
05:34:08.400 | Now, this stuff here is nothing more than the output element p_ij,
05:34:18.480 | because this one is just the softmax applied to the sij element which we know that the output of
05:34:25.120 | the softmax applied to the sij element is called pij because it's one element of the output vector
05:34:30.480 | which we call the p so this stuff here is equal to pij multiplied by this stuff here will be equal
05:34:39.440 | to one minus this stuff here what is this stuff here is the output of the softmax applied to the
05:34:46.560 | sik element so it will be pik so it is equal to one minus pik okay and this is in the case the
05:34:59.360 | variable with respect to which we derive has the same index as the numerator in this fraction here
05:35:09.520 | in this derivative here the other case is when the two variables so the output the index of the
05:35:17.680 | output with respect to the index of the input are not the same in this case we will have another
05:35:23.680 | case. So let me write it again, copying this stuff here:
05:35:35.840 | the other case, in which
05:35:39.120 | j is not equal to k. What happens
05:35:47.840 | in this case it will be well the derivative of the numerator because we need to apply again
05:35:55.760 | this formula here so derivative of the numerator with respect to something that is not the same
05:36:00.640 | variable it will be zero because it's like computing the derivative e to the power of x1
05:36:06.400 | with respect to x2 it will be zero so it will be zero so all the first term here will become zero
05:36:14.000 | no matter what is g of x minus the derivative of the denominator of this fraction here with respect
05:36:21.680 | to the variable s_ik, so g prime of s_ik. This denominator sums over all the input variables, and we are deriving it
05:36:31.920 | with respect to one particular variable of the input so only one item in the summation will
05:36:37.360 | survive so it will be the item sik so it will be e to the power of sik multiplied by f of x which
05:36:50.160 | is the numerator of this fraction, e to the power of s_ij, with a minus sign in front.
05:36:56.240 | Let me see if I forgot something... all divided by the denominator
05:37:02.800 | of this fraction here to the power of two so it is equal to the summation
05:37:11.040 | l equal one up to n of e to the power of sil all to the power of two
05:37:18.160 | I believe I didn't forget anything, so let's continue. Here also we can separate the two factors:
05:37:26.640 | minus e to the power of s_ik divided by
05:37:36.400 | the summation l equal one up to n of e to the power of sil multiplied by
05:37:43.840 | e to the power of sij divided by the summation l equal one up to n of e to the power of sil
05:37:56.160 | this stuff here is nothing more than the softmax applied to the kth element of the si vector
05:38:02.880 | this one here is nothing more than the softmax applied to the jth element of the si vector
05:38:08.480 | So we know what these are, we call them p: so this becomes minus p_ik multiplied by p_ij.
05:38:17.920 | so in the end we have two cases one is the derivative of this stuff here looks like the following
05:38:30.160 | each item in the jacobian looks like the following when the numerator and the denominator have the
05:38:35.280 | same index, so j equal to k, this stuff here is equal to (strictly speaking I
05:38:43.760 | shouldn't be writing it with the equal sign, but it doesn't matter, guys)
05:38:50.560 | p_ij multiplied by one minus p_ik. The other case is when j
05:39:02.480 | is not equal to k; then this stuff here will be equal to
05:39:09.920 | minus p_ik multiplied by p_ij.
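The two cases can also be written in one line with a Kronecker delta (a standard compact form, not spelled out in the video):

$$
\frac{\partial p_{ij}}{\partial s_{ik}} \;=\; p_{ij}\,\bigl(\delta_{jk} - p_{ik}\bigr),
\qquad \delta_{jk} = 1 \text{ if } j = k, \;\; 0 \text{ otherwise.}
$$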
05:39:21.760 | Now that we know what the two typical cases of this Jacobian look like, let's look at what it looks like in matrix form. So this Jacobian will
05:39:28.880 | look like the following it will be a matrix that is more or less like the following
05:39:36.320 | it will be an n by n matrix where n is the size of the input vector and the output vector
05:39:42.080 | at here the first element of the jacobian as you saw as you remember the first row of the jacobian
05:39:50.640 | in the numerator convention is the derivative of the first output with respect to all the input
05:39:58.160 | so this first term here will be the derivative of p11 with respect to s11 so in this case j and k
05:40:08.080 | match so we know that it will be equal to p11 multiplied by 1 minus p11 the second element
05:40:16.400 | to the right of this one so the element one two will be the derivative of p12 with respect to
05:40:23.600 | sorry the p11 with respect to s12 the j and k do not match so we will be in this case here so it
05:40:30.240 | will be minus p11 p12 the third element you can check it by yourself it will be minus p11 p13
05:40:42.000 | blah blah blah until the end which will be minus p11 p1n the second row of this jacobian will be
05:40:52.000 | will look like this so it will be the derivative of p12 with respect to s11 the j and k do not
05:40:58.880 | match so we are in this case here so it will be minus p12 p11 then the next element it will be
05:41:09.440 | the derivative of p12 with respect to s12 so j and k match so we are in the first case so it will be
05:41:16.880 | p12 multiplied by 1 minus p12. Then the third element
05:41:25.840 | will be minus p12 multiplied by p13, blah blah blah, until we arrive at the last one, which
05:41:33.680 | is minus p12 multiplied by p1n, and all the elements like
05:41:41.440 | this until the last row. The first element of the last row will be the
05:41:46.160 | derivative of the last output element with respect to the first input element so it will be the
05:41:53.520 | derivative of p1n with respect to s11 so the two indices do not match so we are in the second case
05:42:02.880 | so it will be minus p1n p11 this will be minus p1n p12 etc etc etc now let me do also the
05:42:16.000 | third element since we are here so minus p1n p13 etc etc etc until the last element of the
05:42:24.560 | last row which will be minus p1n p1n i guess oh oh no that's wrong guys because the two indices
05:42:35.440 | match so it should be p1n multiplied by 1 minus p1n this is what the jacobian will look like
05:42:45.360 | Let's see if we can find a better way to generate this Jacobian, with some pattern recognition.
05:42:51.920 | Let's write it in a different way. The first thing that we can notice is that
05:42:57.600 | this Jacobian is symmetric: you can see that this element is equal to this element, and if you
05:43:02.000 | expand the third row you will see that it's equal to this element; the one in the top right corner
05:43:07.040 | is equal to the one in the bottom left corner. So this matrix is symmetric.
05:43:14.720 | the second thing that we can notice is that only the element in the diagonal are different
05:43:20.880 | they have an additional term because you can look at this element here so let me write this element
05:43:28.320 | here can also be written as p11 minus p11 multiplied by p11 the second element here
05:43:37.600 | in the second row so the second diagonal element of this matrix is p12 minus p12 multiplied by p12
05:43:48.320 | so this element on the diagonal actually look like just like the other elements they just have
05:43:53.920 | an additional term which is p11 in the first diagonal element p12 in the second diagonal
05:44:02.560 | element so we can also say that this matrix here is the product of all the possible combinations of
05:44:11.200 | p_ij with p_ik which we can obtain with an outer product or even with the product of one column
05:44:20.320 | with the transpose of the same column so if you do one column vector for example imagine p is a column
05:44:27.840 | vector and you do p multiplied by p_t you obtain all the possible combinations of products of these
05:44:34.640 | two vectors. I can do a simple case: a column vector with elements p1, p2, p3,
05:44:43.040 | multiplied by the row vector p1, p2, p3, will generate all the possible combinations of products
05:44:53.440 | between the elements of the first vector and the second vector, because this is a three by one
05:44:59.680 | times a one by three, so it will generate a three by three matrix, equal to p1 p1,
05:45:06.960 | p1 p2, p1 p3, etc. etc. Moreover, we can see that on the diagonal
05:45:18.480 | of the matrix we have this additional term: p11 in the first diagonal element,
05:45:26.880 | p12 in the second diagonal element, p13 in the third diagonal element. Actually, calling them p1-something is
05:45:34.720 | not quite right: I should call them p_i-something, and that's why I didn't want to carry the i index. It's not
05:45:40.000 | really p1j, it should be p_ij, because we are doing this for the generic i-th row, the p_i vector,
05:45:49.040 | so let me fix all the indices on the screen accordingly.
05:46:05.440 | Okay, so we can write
05:46:21.440 | this jacobian here also as the diagonal matrix that in the diagonal has all the
05:46:30.240 | element of the p_i vector minus the p vector multiplied by the transpose of itself so with
05:46:38.800 | itself but transposed because we need all the elements to be kind of a combination of one
05:46:44.720 | element of p with itself with another element of p plus only on the diagonal we need some this
05:46:50.320 | additional term which are the elements of p and all the elements of the the the output of this p
05:46:56.800 | multiplied by p transposed are negated that's why we need this minus sign so if you look at the
05:47:01.840 | flash attention paper they give you this formula here they say that if y is equal to the softmax
05:47:08.160 | of x, then the Jacobian will be the diagonal matrix
05:47:21.360 | of y, minus y multiplied by y transposed, where y is considered a column vector.
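A minimal numerical check of this identity (not from the video; plain PyTorch, not the Triton kernel):

```python
# Minimal sketch (not from the video): verify J = diag(p) - p p^T for one row.
import torch

s = torch.randn(5)
p = torch.softmax(s, dim=-1)

J_autograd = torch.autograd.functional.jacobian(
    lambda x: torch.softmax(x, dim=-1), s)
J_formula = torch.diag(p) - torch.outer(p, p)
print(torch.allclose(J_autograd, J_formula, atol=1e-6))  # True
```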
05:47:32.080 | all right guys i know this has been long so let's take a pause and we are going to now
05:47:38.320 | code finally first of all let's check the mathematics of the backward path of flash
05:47:44.160 | attention we will see it briefly i will not do any more derivation but i will explain it
05:47:49.920 | and then we finally switch to coding it so let's go all right guys now finally we can see the
05:47:57.120 | the backward path of the flash attention so we will be looking at the algorithm and if you look
05:48:03.760 | at the the the appendix of the flash attention paper you will see this part b.2 where they derive
05:48:09.680 | the backward path step by step now i don't want to do all the same all the steps of this derivation
05:48:17.120 | because it's going to be too long but i want to give you all the tools necessary to understand it
05:48:21.920 | Now let's start from the conventions and
05:48:30.240 | notations they are using in this paper. So the first thing that we need to rehearse is the
05:48:36.640 | naming of each matrix: as you know, in the
05:48:44.080 | forward pass we do the query multiply by the transpose of the key and the output of this we
05:48:48.400 | call it s then we apply the softmax to this s matrix and it becomes the p matrix the softmax
05:48:55.840 | is applied by rows. Then we take this P matrix and we multiply it by the V matrix to obtain the output
05:49:02.560 | of the attention.
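For reference, a plain PyTorch sketch of this naming (not the Triton kernel; the 1/sqrt(d) scaling is omitted here to match the notation used in this derivation):

```python
# S = Q K^T, P = row-wise softmax(S), O = P V (sketch, not from the video).
import torch

N, d = 4, 8                      # small example sizes
Q, K, V = (torch.randn(N, d) for _ in range(3))

S = Q @ K.T                      # (N, N) attention scores
P = torch.softmax(S, dim=-1)     # softmax applied row by row
O = P @ V                        # (N, d) attention output
```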
05:49:12.080 | Let's look at how the i-th row of the output is computed based on the P matrix and the V matrix, so we can understand the kind of notation that
05:49:17.120 | they are using here in the paper, because the way I read this formula is: the i-th
05:49:22.320 | row of the output, which is a column vector, because in mathematics, in
05:49:28.560 | linear algebra whenever we write the name of a vector it is always by convention a column vector
05:49:35.440 | but the origin of this particular vector is actually a row of the output matrix let's try
05:49:42.320 | to understand what is the output row of a matrix in a matrix multiplication now so that we can
05:49:50.960 | understand how to go from here to here so let's write a generic matrix multiplication for example
05:49:58.400 | an a matrix let's say that it is the following and we only write one row actually let me zoom again
05:50:07.440 | and i want to write smaller so we have enough space so we make a matrix that has a row let's
05:50:14.000 | call it a 1 a 2 a 3 and then we multiply this will be a matrix with many rows like the this one
05:50:23.760 | because we want to study the effect only of one row and we multiply it by another matrix let's
05:50:30.000 | call it this one is the matrix a and it has i don't know let's say n rows by three columns then
05:50:39.200 | we should have another matrix b with three rows and some number of
05:50:47.280 | columns, let's say four columns. So let's call the first row (let me zoom a bit more): it's b11,
05:50:58.640 | b12 b13 b14 then this one should be b21 b22 b23 b24 this should be b31 b32
05:51:17.120 | b33 b34 etc i know i am not very rigorous in my notation i should have called all these elements
05:51:26.960 | with the capital letter a and the capital letter b so this is the notation that you use when
05:51:31.920 | referring to single item of a matrix but please forgive me for this so the output of this matrix
05:51:40.080 | multiplication will be another matrix that is n by 4 so it will be n by 4 so we will have four
05:51:48.640 | columns for each row of the output i want to write the output in a different way so i want to write
05:51:58.400 | it as follows as a vector only so the first output row as a vector and want to understand what is
05:52:06.560 | each dimension of this vector so because otherwise i don't have enough space to write it here
05:52:11.680 | so the first let's write it so let's call it o i want to write what is o of one which is the first
05:52:22.720 | row of the output but written as a column vector so o of one will be here we should use the small
05:52:32.400 | letter o of one should be a vector where the first dimension is the dot product of this stuff here so
05:52:40.320 | the first row of the a matrix with the first column of the b matrix so the first let's say
05:52:47.760 | dimension will be a1 with b11 i should also call this one a11 a12 actually
05:53:00.160 | and a13 so a13 because we have many rows in the a matrix so let me use the correct naming so this
05:53:10.240 | will be a11 with b11 a11 b11 plus a12 multiplied by b21 plus a13 with b31 and this will be the
05:53:28.000 | first dimension of the first row of the output matrix o the second dimension of the first row
05:53:36.000 | of the output matrix o will be the dot product of this row of the a matrix with the second column
05:53:43.040 | of the b matrix and let me write here b so it will be a11 b12 plus a12 b22 plus a13 b32
05:54:03.040 | the third dimension will be a11 b13 plus a12 b23 plus a13 b33 the fourth dimension will be a11
05:54:24.080 | b14 plus a12 b24 plus a13 b34 now this is the output the first output row of the o matrix and
05:54:40.160 | it's a vector called o1 and these are this is the first dimension of this vector this
05:54:45.120 | is the second this was the third and this is the fourth dimension and each of this stuff here is
05:54:49.280 | one scalar. So the output o1, which is the first row of the output matrix, can also be written,
05:55:01.520 | as you can see, as a sum of many vectors, where the first term is a11
05:55:13.040 | multiplying a vector (let me write this one a bit smaller).
05:55:19.760 | Okay. As you can see here, a11 is multiplying a different b number
05:55:26.880 | every time so this is a b11 b12 b13 b14 what is b11 b12 b13 b14 it is the first row of the b
05:55:36.400 | matrix so it is equal to b1 and all the dimensions of the first row then plus then we have the
05:55:46.720 | element a12 multiplied by b21 b22 b23 etc etc and this is the second row of the b matrix so we use
05:55:57.440 | the tensor notation of pytorch to describe this row which is a b2 and all the dimensions of b2
05:56:07.280 | so it looks this is a vector scalar product and plus a13 multiplied by b3
05:56:22.320 | and all the dimensions of b3 this one can also be written as the summation
05:56:29.280 | over all possible i that go from 1 to 3 where 1 to 3 is how many columns there are in the a matrix
05:56:42.400 | of a ij well a1 let's call let's call this one j actually sorry let's call it j
05:56:53.360 | equal to 1 and let's call this the generic ith row of the output matrix will be a i1 a i2 and a i3
05:57:07.840 | each one multiplied by the corresponding row in the b matrix so we can write it as a i j
05:57:16.560 | multiplied by b j where b j is the a row of b we can also write it like this to indicate that this
05:57:29.920 | is a vector and this is exactly what they do here so the output in the output matrix when we do the
05:57:37.360 | multiplication p multiplied by v the ith row of the output matrix we call it o i which is a vector
05:57:46.320 | but by notation it is a column vector where the elements of this column vector are actually the
05:57:52.160 | elements of the ith row of o this is only by notation guys is equal to the ith row of p so
05:58:02.160 | the ith row of the matrix that is on the left in the matrix multiplication multiplied by all the
05:58:07.440 | columns of the v matrix which can also be written as the summation over all the elements of the
05:58:13.680 | ith row of p so all the elements of the ith row of the first matrix the one on the left in the
05:58:19.280 | matrix multiplication multiplied by each vector in the v matrix where the jth vector here in v
05:58:27.200 | is each row of the v matrix so p i j can also be written as what is p i j it is the output of
05:58:38.480 | the softmax so as you know the output of the softmax is e to the power of the element input
05:58:44.640 | of the softmax what is the element input of the softmax is the query multiplied by the transpose
05:58:49.280 | of the keys so it's a dot product between one query and one key and that's why you have this
05:58:54.080 | stuff here in the exponential so this is the first step in understanding this derivation
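(A quick aside that is not in the video: a minimal PyTorch sketch of this row decomposition, with arbitrary example shapes, just to check the formula numerically.)

```python
import torch

# a is (n x 3), b is (3 x 4); the shapes here are just an example
a = torch.randn(2, 3)
b = torch.randn(3, 4)
o = a @ b

# the i-th output row is a weighted sum of the rows of b,
# with weights taken from the i-th row of a: o_i = sum_j a[i, j] * b[j]
i = 0
o_i = sum(a[i, j] * b[j] for j in range(a.shape[1]))
print(torch.allclose(o[i], o_i))  # True
```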
05:59:00.160 | another thing that we have studied so far is how to derive the backward pass of the matrix multiplication
05:59:06.560 | and of the softmax so now let's use it in the matrix multiplication let's rehearse the formula
05:59:13.600 | so if given a matrix multiplication that is y equal to x multiplied by w we know that given
05:59:21.600 | the gradient of the loss function with respect to y so the output of this operation we know how
05:59:27.600 | to derive the gradient of the loss with respect to one of the input of this function which is the
05:59:33.200 | x or w to get the gradient with respect to x we need to take the upstream gradient so the
05:59:40.080 | the gradient with respect to the output multiplied by the transpose of w t and to get the gradient
05:59:47.520 | with respect to w we need to do the xt so the input transposed multiplied by the upstream gradient
05:59:55.600 | this one is the formula that we didn't derive and this one is the formula that
05:59:59.440 | we derived but how to derive them is exactly the same procedure
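(Not part of the video: a minimal PyTorch sketch checking these two reference formulas against autograd; the tensor names and shapes are arbitrary.)

```python
import torch

x = torch.randn(5, 3, requires_grad=True)
w = torch.randn(3, 4, requires_grad=True)
y = x @ w

# pretend the upstream gradient dL/dY is some arbitrary tensor
dy = torch.randn_like(y)
y.backward(dy)

# reference formulas: dL/dX = dL/dY @ W^T  and  dL/dW = X^T @ dL/dY
print(torch.allclose(x.grad, dy @ w.T))  # True
print(torch.allclose(w.grad, x.T @ dy))  # True
```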
06:00:03.040 | in attention the last product that we are doing is o equal to p multiplied by v
06:00:12.880 | what pytorch will give us as input during the backward pass is the gradient of the loss with
06:00:18.880 | respect to the output and we need to use this gradient of the loss with respect to the output
06:00:23.680 | of the attention to derive the gradient of the loss with respect to q with respect to k and with
06:00:29.600 | respect to v so that it can then be used by the operators in the backward pass in the
06:00:35.920 | computation graph in the operations before okay so but in order to arrive at the gradient with
06:00:43.040 | respect to query key and value we need to derive the gradient with respect to each intermediate
06:00:48.480 | operation so the last operation that we do is o equal to p multiplied by v so computing the gradient
06:00:55.840 | of the loss with respect to v given the gradient of the loss with respect to o
06:01:04.160 | is exactly like computing the gradient of the loss with respect to w in a matrix
06:01:10.400 | multiplication and we know that it is equal to pt multiplied by do so just by analogy guys so this is our
06:01:17.280 | reference point and i am just changing the names here and you should understand what is the analogy
06:01:21.920 | here so the gradient of the loss with respect to v which is the matrix on the right which is like
06:01:28.960 | computing it with respect to w it is equal to just like this formula here so the transpose of
06:01:34.960 | the matrix on the left multiplied by the upstream gradient which in the paper they write it as this
06:01:40.960 | so dv is equal to pt multiplied by do and it's the formula that you can see here the other
06:01:48.080 | derivation is how to derive the gradient with respect to dp dp is just like deriving the
06:01:54.400 | gradient of the loss with respect to the matrix that is on the left side of the matrix multiplication
06:01:59.120 | so it is just like deriving the gradient of the loss with respect to x in the reference
06:02:04.400 | formulas which is equal to the upstream gradient multiplied by the transpose of the other matrix
06:02:11.840 | which in the notation of the paper they write it as dp is equal to do multiplied by v transposed
06:02:18.000 | and it's this formula here how they compute this stuff here is exactly as above so as this
06:02:25.520 | derivation here they call dvj the jth row of the dv matrix and they write it as the sum of pij multiplied by doi
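(A small sanity check of my own, applying the same reference formulas to the last matmul of attention, O = P V; shapes and names are illustrative.)

```python
import torch

seq_len, head_dim = 6, 4
p = torch.softmax(torch.randn(seq_len, seq_len), dim=-1).requires_grad_(True)
v = torch.randn(seq_len, head_dim, requires_grad=True)
o = p @ v

do = torch.randn_like(o)   # the upstream gradient dL/dO that autograd would hand us
o.backward(do)

print(torch.allclose(p.grad, do @ v.T))  # dP = dO @ V^T
print(torch.allclose(v.grad, p.T @ do))  # dV = P^T @ dO
```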
06:02:37.680 | how to arrive to this formula here well let's do it so let me write let's see okay
06:02:45.680 | theoretically we know that from this derivation here so from this derivation here or from this
06:02:51.440 | derivation here we know that the i-th row of the output in a matrix multiplication
06:02:57.360 | first of all let's simplify our life every time you see a transpose and you don't like work with
06:03:02.000 | the transpose in a matrix multiplication just give it a different name and then work with the
06:03:06.640 | different name and after when you have derived the formula you resubstitute the transpose
06:03:12.480 | operation in this case we are doing dv is equal to p transpose multiplied by do let's call p
06:03:19.920 | transposed let's give it a name that we are we didn't use so far so let's call it f i always
06:03:24.800 | use f when it's available so we call dv is equal to f do we know from above here from this derivation
06:03:38.720 | here or this derivation here is equivalent that the output of a matrix multiplication so the out
06:03:45.920 | i-th row of the let's know the j-th row let's call it the j-th row dvj is equal to a summation
06:03:55.680 | of each element of the j-th row of f of the first matrix so we do the let's see here we do the sum
06:04:06.720 | by i so let's do it by i it's the sum over all possible i of the i-th element in the j-th
06:04:16.880 | row of the first matrix so fji multiplied dot product not dot product this is a scalar vector
06:04:32.800 | multiplication multiplied by a vector that is let me check what was the formula so it was the j-th
06:04:40.000 | row of the other matrix so in this case it should be the i-th row of the other matrix
06:04:45.760 | do i where this is the i-th row of do and dvj is the j-th row of the dv matrix
06:05:00.480 | but also we know that f is not a matrix that we have it's actually the transpose
06:05:06.400 | of p which means that fji will be equal to pij because in a matrix transposition you invert the
06:05:13.440 | two indices so this is the summation over all possible i's of p not ji but ij multiplied by do i
06:05:24.320 | and this should be equal to the same formula that you see on the right here this allows you to
06:05:28.960 | compute one output row in the v matrix okay and we know that pij is just the output of the softmax
06:05:39.440 | the output of the softmax is the exponential of the input of the
06:05:45.200 | softmax divided by the normalization factor associated with that row so because we are
06:05:52.960 | iterating through the rows by i it will be the normalization factor associated with the
06:05:59.600 | i-th row so we know that the formula is that p is equal to the softmax of s and the i-th
06:06:11.840 | row of p will be the softmax of the i-th row of s and this is what is written here we know from
06:06:18.560 | our derivation that the jacobian with respect to the softmax operation so if we have an input x and
06:06:26.640 | the output is y of the softmax operation the jacobian of this of the y with respect to the x
06:06:33.440 | is equal to diag of y which is a diagonal matrix of the elements of the vector y minus y multiplied
06:06:41.120 | by y transposed and we have also seen before that this matrix is symmetric however you may not
06:06:48.560 | understand this formula here because in the chain rule we always write it like
06:06:55.200 | this we always write that the downstream gradient so d phi with respect to d x should be equal to
06:07:07.200 | the upstream gradient so d phi with respect to d y multiplied by d y with respect to d x
06:07:17.440 | this only works if you make this matrix here as a in the numerator convention the numerator
06:07:26.160 | convention is one of the two convention in which you can create a jacobian we so far we have always
06:07:31.760 | written it as the numerator convention if you use the numerator convention this is a row vector and
06:07:38.480 | this is a row vector however if you want to treat this stuff here as a column vector then you need
06:07:45.920 | to take the transposed or you need to make the jacobian in the denominator convention
06:07:50.720 | how to get this formula here because this formula here is basically doing the
06:07:55.520 | jacobian multiplied by the upstream gradient not the upstream gradient multiplied by the
06:08:02.160 | jacobian and it's only because here we treat it as a column vector and when you want to
06:08:08.240 | transform a row vector into a column vector you take the transpose of both sides of the equation
06:08:12.000 | and let's do it actually so we apply the transpose to the both side of the equation
06:08:20.560 | okay in a matrix multiplication if you do a b transposed it become b transposed multiplied
06:08:30.480 | by a transposed so the transposed is applied independently to each input of the matrix
06:08:35.840 | multiplication but we invert the matrix multiplication and if you remember the
06:08:39.440 | matrix multiplication is not commutative so what we do here is that we say okay it will be the
06:08:45.840 | dphi of dx and here they call it
06:08:48.960 | here they call it dsi so it will basically just become d phi on dx if you treat this
06:09:01.200 | one as a column vector so this one as a column vector will be equal to dy on dx as a column
06:09:09.120 | vector as a jacobian in the denominator layout in this case multiplied by d phi on dy as a column
06:09:19.760 | vector this one is a column vector this is a column vector and this is what you see here that's
06:09:24.000 | why the jacobian is on the left side of the upstream gradient
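(A sketch of my own verifying both claims: the softmax Jacobian is diag(y) - y y^T, and with the column-vector convention the downstream gradient is the Jacobian times the upstream gradient; since this Jacobian is symmetric the transpose does not change it.)

```python
import torch

x = torch.randn(5, requires_grad=True)
y = torch.softmax(x, dim=0)
yd = y.detach()

# jacobian of the softmax: diag(y) - y y^T (note: it is symmetric)
jac = torch.diag(yd) - torch.outer(yd, yd)
auto_jac = torch.autograd.functional.jacobian(lambda t: torch.softmax(t, dim=0), x)
print(torch.allclose(jac, auto_jac, atol=1e-6))

# chain rule, column-vector convention: dL/dx = jacobian @ dL/dy
dy = torch.randn(5)   # pretend this is the upstream gradient dL/dy
y.backward(dy)
print(torch.allclose(x.grad, jac @ dy, atol=1e-6))
```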
06:09:31.120 | what else do we need well i know that there are a lot of things here in this derivation but i prefer actually going directly
06:09:35.840 | to the code otherwise i think it's going to be too boring um so let's go to the code and while
06:09:41.920 | writing the code i go back to the formulas in which we can find the association of what we are
06:09:47.360 | doing and the formula in the paper i think this is the best way so let's proceed further all right
06:09:54.640 | guys now we can finally code the backward pass before we code the backward pass let's look at
06:09:59.040 | the algorithm of the backward pass as written in the paper this is the paper flash attention
06:10:03.920 | one and we will follow the structure of the code that is present on the
06:10:10.560 | triton website so it's not my idea to split it like this but i simplified it in such i simplified
06:10:17.840 | it so it's different than the one that you can find online because mine is a simplified version
06:10:22.960 | and mine works with the causal and non-causal attention um so first if you look at this
06:10:29.360 | algorithm you need to you can see that we have an outer loop through all the k and v blocks
06:10:36.240 | and an inner loop through all the query blocks however as you can see to compute the dq which is
06:10:44.080 | the downstream gradient of the loss with respect to the q matrix we need to have an
06:10:51.920 | iteration through all the k's and to compute each dk block we need to have an iteration through all
06:10:58.960 | the queries so if we follow the loop like it is it would involve writing to the high bandwidth memory
06:11:06.880 | so to the dram of the gpu at every inner iteration and that is not so efficient
06:11:13.120 | and also if we don't want to write it would require some sort of
06:11:18.080 | some sort of synchronization between blocks which is also not very efficient
06:11:24.080 | so we will split this for loop into two parts because we can see that each dq depends
06:11:30.800 | on a loop over the k's and each dk depends on a loop over all the queries so to compute dk we will
06:11:41.200 | fix the kth block and iterate through all the q blocks then we will do another iteration in which
06:11:46.640 | we fix the q block and iterate through all the kv blocks to compute the dq this is what we are
06:11:52.320 | going to follow and this is an idea that i took from the original implementation that is present
06:11:56.560 | on triton website another thing that we can notice here is um where where is it here to compute the
06:12:05.360 | dq and dk so a dq vector and the dk vector we need this element this information here called the di
06:12:15.680 | di and it's shared between the two so we can pre-compute it and then we can reuse it
06:12:21.840 | to compute the dq vector and the dk vector what is this di di is introduced
06:12:32.320 | here and it's the dot product of the doi vector with the oi vector so the first
06:12:41.040 | thing that we will do is a loop over all the vectors in o and do and compute their dot products
06:12:48.000 | to compute this di element then we will use this di element and actually uh let me see yeah and
06:12:56.160 | then we will use this di element to compute dq and dk and we will also have another
06:13:02.960 | two loops one in which we fix the q block and iterate through all the keys and one in which we fix the keys
06:13:08.400 | and iterate through all the queries so let's start
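(A toy illustration of my own of the "fix a K/V block, loop over the Q blocks" idea, using only dV = P^T dO which we have already derived; block sizes and shapes are arbitrary.)

```python
import torch

seq_len, head_dim, block = 8, 4, 2
p = torch.softmax(torch.randn(seq_len, seq_len), dim=-1)
do = torch.randn(seq_len, head_dim)
dv_ref = p.T @ do                                # full-matrix reference

dv = torch.zeros(seq_len, head_dim)
for kv_start in range(0, seq_len, block):        # each iteration plays the role of one "program"
    acc = torch.zeros(block, head_dim)
    for q_start in range(0, seq_len, block):     # inner loop over the Q / dO blocks
        p_blk = p[q_start:q_start + block, kv_start:kv_start + block]
        acc += p_blk.T @ do[q_start:q_start + block]
    dv[kv_start:kv_start + block] = acc          # one single write per fixed K/V block
print(torch.allclose(dv, dv_ref))                # True
```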
06:13:14.800 | now that we know more or less the structure of the code that we're going to write we start by writing this backward function here
06:13:23.920 | uh let me check yeah okay so do you remember this is saved tensor these are all the information that
06:13:35.360 | we save during the forward pass uh to compute the backward pass now to to optimize the memory
06:13:43.760 | utilization in flash attention we don't save the query multiplied by the transpose of the
06:13:50.160 | key matrix because that would be a sequence by sequence matrix that is too big to save into the
06:13:55.680 | hbm in the dram during the forward pass and then i re get it back from the hbm into the local memory
06:14:02.480 | because i want to remind you that in triton uh compared to cuda in triton what we do is we load
06:14:09.040 | stuff from the high bandwidth memory in the shared memory so the sram we do all the operations there
06:14:15.920 | and then after when we call the store method we save the element from the shared memory into the
06:14:22.000 | high bandwidth memory so in order to not materialize this s matrix in its entirety save it to the hbm
06:14:30.160 | and then reget it back which could be very slow and secondly actually it is very expensive because
06:14:35.840 | usually right now we are computing attention on thousands and thousands of tokens so imagine
06:14:41.200 | saving a matrix that is 5000 by 5000 that's a big matrix to save for each batch
06:14:49.280 | and for each head so that would be really too expensive to save so the idea in flash attention
06:14:56.480 | is to recompute what we can compute on the fly during the backward pass because any way if we
06:15:01.760 | were to load it it would be memory i/o bound so it's faster to recompute than to save it and restore
06:15:09.200 | it from the memory this is the idea of flash attention okay so we saved some stuff during the
06:15:16.080 | forward pass and now we can access it back during the backward pass and this stuff is saved in the
06:15:21.440 | context and this it's a it's a kind of a dictionary that is made available by by pytorch all right so
06:15:29.600 | we get back the query key and values and as you know pytorch during the autograd will just give
06:15:35.600 | us the gradient of the loss with respect to the output of our implementation of the attention
06:15:41.440 | of our attention so this is triton attention and then we need to compute dq dk and dv by using only
06:15:49.200 | the gradient of the loss with respect to the output um we do
06:15:54.800 | some checks so here i know i could optimize this code and make it even smaller by for example
06:16:02.080 | checking that here the stride that i am using i actually inside of the code i always uh pretend
06:16:07.840 | that the stride is the same but uh doesn't matter i just take the code from triton and uh try to
06:16:14.000 | simplify it my goal was to simplify it not optimize it so all right we create the um the
06:16:21.520 | vectors the tensors in which we will store the result of this backward pass which is the dq dk
06:16:28.320 | and dv and as you know from what we have seen of the definition of the gradient the size of the
06:16:34.960 | output of the gradient vector is the size of the vector with respect to which we calculate the
06:16:41.760 | gradient because in the numerator is always a scalar and we compute the gradient with respect
06:16:46.080 | to all the elements in the input vector so the output the gradient itself is a vector of the
06:16:51.200 | same size of the element by which we compute the gradient with respect to so uh we get some
06:16:58.480 | information on the batch size blah blah blah and later we will see what is this number of warps
06:17:04.400 | and the number of stages i will not explain it now the number of warps is an
06:17:11.360 | indication of how many threads we want to launch in our grid and the number of stages is the
06:17:15.600 | number of stages used in software pipelining we will see later what is software
06:17:19.360 | pipelining when we talk about the auto tuning then we define some uh blocks uh in the original um
06:17:29.120 | in the original code i think they call it a block kv1 kv2 q1 and q2 i think it was confusing i call
06:17:29.120 | it a block macro and block micro because the thing that we will fix and the thing that we will
06:17:37.920 | iterate over changes once we fix the query block and we iterate through all the
06:17:48.800 | keys and then we will fix the keys and values block and we iterate through the queries the one
06:17:55.200 | that we iterate on is the micro one and the one that we fix is the macro one this is my uh the
06:18:01.360 | naming that i am using um then we as i said before we need to pre-compute the di elements that we saw
06:18:09.440 | in the paper before so that's the first kernel that we are going to launch and this kernel will
06:18:14.800 | have its own launch grid because later we want to optimize the the tuning of this kernel later
06:18:21.520 | we will talk about the tuning with respect to its own parameters so let me see what are we going to
06:18:29.280 | do so here so the first kernel that we are going to launch is this pre-process kernel this pre-
06:18:35.680 | process kernel will pre-compute all the di elements that we need if i remember correctly to compute
06:18:43.120 | dq and dk and this di element depends only on o and do um so let's do it and let's create
06:18:55.360 | another function called the backward preprocess what is the preprocess grid this is the
06:19:01.440 | launch grid of this function of this kernel and this will be launched on a independently for each
06:19:10.000 | batch and for each head and moreover it will work with a block of vectors of o what is
06:19:17.440 | this block what is this number of vectors of o it will be the block size macro so
06:19:25.760 | 128 vectors of o at a time
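(Roughly, the launch grid looks like the sketch below; the variable names and values are illustrative, not necessarily the exact ones used in the code, and the kernel name in the comment is hypothetical.)

```python
import triton

# illustrative values only; in the real code these come from the tensors' shapes
BATCH_SIZE, NUM_HEADS, SEQ_LEN, BLOCK_SIZE_MACRO = 2, 4, 1024, 128

# axis 0: which block of O / dO row-vectors this program handles
# axis 1: which (batch, head) pair this program handles
preprocess_grid = (triton.cdiv(SEQ_LEN, BLOCK_SIZE_MACRO), BATCH_SIZE * NUM_HEADS)
print(preprocess_grid)  # (8, 8)

# the preprocess kernel would then be launched roughly as:
# _attn_bwd_preprocess[preprocess_grid](O, dO, D, SEQ_LEN, ...)
```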
06:19:34.960 | so let me copy the signature of this function and write it here okay this function takes the matrix o so it's a pointer to the matrix o it's a pointer
06:19:43.280 | to the d o and it's a pointer to the matrix d where we will store this di elements and we have
06:19:49.600 | one for each vector in the output that's why the shape of this d is a batch size number head
06:19:57.840 | sequence length it means it's one for each of the output element in the output of the attention
06:20:03.040 | this di where is it actually it's not this one it's this one yeah like m so it has the same shape
06:20:11.760 | as m which is as you can see it is this size here so batch size number heads and sequence length m
06:20:18.320 | if you remember is the matrix that we saved during the forward pass which includes the
06:20:22.160 | normalization factor of the softmax and also the maximum element but in log sum exp format so that
06:20:29.920 | when we apply it will automatically apply the maximum element for each row and also normalize
06:20:34.560 | at the same time which i think i proved previously so let me do it so we write it like this so we
06:20:44.560 | extract the index of this program so this program has two indices as identifiers this is
06:20:55.600 | equivalent to the cuda identifier and this is along the axis 0 so let's see
06:21:01.520 | what did we launch on the axis 0 so on the axis 0 of this launch grid we defined what is the block
06:21:08.400 | of vectors of o that this particular program will work with and the second axis is
06:21:17.280 | which batch and which head inside of each batch this particular program will work with so this
06:21:23.760 | identifies the block index of q so which group of vectors in the o matrix this particular program
06:21:30.400 | will work with here is called q i believe because i copied it from the original code where they call
06:21:36.160 | it q but i could have eventually also called it o um so basically this means that
06:21:45.440 | for this program we need to skip some query vectors that will be or
06:21:52.160 | have been already processed by other programs in parallel so we will only work with a block of
06:21:58.560 | query vectors inside of o that have the following indices so imagine that the query block size is
06:22:06.480 | i think it's 128 the way we have defined it but suppose it's a 4 for simplicity so this one will
06:22:14.480 | be and how many query vectors are there in total we have sequence length number of query vectors so
06:22:22.960 | imagine the query vectors in total are i don't know let's say 64 and the first 32 will be managed
06:22:30.960 | by other programs so this particular offs_q will be equal to 32 33 34 and 35 this tells me which
06:22:40.560 | query vectors or which vectors in the output o matrix among all the vectors in the o matrix this
06:22:46.880 | particular program is going to work with okay so then we extract also the index of the batch which
06:22:54.800 | tells us which batch and which head in each batch this particular program is going to work with
06:23:01.040 | which is the dimension one of our launch grid and then we define the offset of the dimension
06:23:07.440 | because we need to load all the dimensions of each vector so this is a vector that tells
06:23:14.320 | which dimensions we need to load from each vector and we will load all of them so we don't divide on
06:23:18.400 | the head dimension we just divide the load on the sequence length dimension among
06:23:26.080 | multiple programs um you will see in this part of the video that when we are writing the backward
06:23:33.280 | pass that we will not be using the make block pointer like we did during the forward pass
06:23:39.040 | so this function here we will work with directly with indexing by using the strides so let's do it
06:23:47.200 | so let's load a single block of rows of o which i want to remind you has the same shape as q and
06:23:57.040 | that's why we can call it block size q so the o block that we are loading is o so uh the load
06:24:03.840 | function accepts a pointer to what it should load actually not a pointer it accepts a array of
06:24:11.360 | pointers or a multi-dimensional array of pointer in case you want to load a multi-dimensional data
06:24:16.720 | so actually load also allows you to load two-dimensional data in this case we are going
06:24:23.360 | to load two-dimensional data which is a block of rows of o which should be a block a tensor of the
06:24:31.520 | shape block size q in this case multiplied by the other dimension being head dimension
06:24:39.120 | but we need to tell it where in this o matrix it needs to find this data first of all we
06:24:46.160 | need to skip some batches and some heads based on what the head and the batch that will be processed
06:24:52.800 | by other programs so based on the index that this um program will process of the batch and the head
06:24:59.920 | we need to skip all the other batches and heads let's write the shape of this tensor so the o
06:25:07.840 | tensor has a shape batch size number of heads then sequence length
06:25:17.360 | and then head dimension each block and each head will have a sequence length multiplied by dim
06:25:25.360 | head dim number of items so based on our index we skip how many items our index multiplied by head
06:25:32.960 | dimension multiplied by sequence length so what i mean is this the batch zero and the head zero
06:25:40.480 | will have a sequence length multiplied by head dimension items the batch zero and the head one
06:25:47.360 | will also have the same number of items and the batch zero and head two will also have the same
06:25:53.120 | number of items so how many items sequence length multiplied by head dimension do we need to skip
06:25:58.160 | from the starting of the o tensor it is equal to the index of the current batch and head indicator
06:26:04.320 | so because this index indicates both the head in the batch and the head inside of each batch because
06:26:11.360 | it's already the product of the head and the batch so how many we skip indicated by the this index
06:26:20.640 | and after we point to this starting point of the current batch and the current head
06:26:25.760 | we need to select a two-dimensional tensor where the offsets for the rows are indicated by offs_q
06:26:33.840 | and that's why we have this one um this indexing with the colon and none
06:26:41.600 | that takes all these vectors in offs_q and gives them an
06:26:50.880 | additional dimension for the columns and these columns will be the off dim so basically this
06:26:56.160 | will select a tensor of the following shape inside of this big tensor that includes head size and
06:27:03.920 | number of heads this is what we are doing so we are saying select a tensor of this size
06:27:11.280 | inside of one that is made up of four dimensions by skipping the elements of all the batch and
06:27:18.080 | heads that will be processed by other programs i always talk in terms of programs because in
06:27:28.160 | triton these are called programs in cuda you would refer to them as kernels
06:27:28.160 | all right so this one is done i hope it is decently clear um all right so then we also load
06:27:38.560 | a single block of d o in the same way because we are going to load a group of vectors from all the
06:27:46.880 | sequence length also from d o and the d o has the same shape as o which has the same shape as q and
06:27:54.320 | that's why we can use the um the the block index we call it q because it's equivalent because they
06:28:00.080 | have the same shape okay and how to compute this d i element well it's written in the paper so if
06:28:08.160 | we go in the in the what is it man if we go here it shows you how to compute the d i of each given
06:28:17.280 | a block of d o and a block of o it tells you how to compute d i which is the row sum which means
06:28:25.200 | the sum by rows for each row we will have one sum so for each vector in the o matrix we will have
06:28:32.080 | a sum of the element wise product so this stuff here is the element wise product of d o i multiplied
06:28:39.360 | by o i so it's not a matrix multiplication it's element wise product which means each element of
06:28:46.480 | one matrix with the corresponding element of the second matrix and the output shape it will be the
06:28:51.280 | same as the two matrices which must have the same shape okay so we compute this d i block
06:29:00.960 | which will have shape block size q because we will have one sum for each vector
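(As a sanity check of my own, the whole preprocessing step written in plain PyTorch is just a row-wise sum of an element-wise product; the shapes below are illustrative.)

```python
import torch

# D has one scalar per output row: D_i = rowsum(dO_i * O_i), an element-wise product
o  = torch.randn(2, 4, 1024, 64)   # (batch, heads, seq_len, head_dim), example shape
do = torch.randn_like(o)
d  = (do * o).sum(dim=-1)          # shape (batch, heads, seq_len), same shape as M
print(d.shape)
```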
06:29:06.560 | then well we need to store it somewhere so we need to calculate where to store it
06:29:12.800 | inside of the d matrix well the d matrix if i remember correctly has the same shape as m so
06:29:20.240 | it should be batch size a number of heads and sequence length so we need to select the right
06:29:29.120 | batch and the right head and also the right position inside of the sequence length based
06:29:33.440 | on the block index q that we have okay so let me index okay
06:29:43.120 | all right because we already um so we skip um again just like before we know that the d is
06:29:53.680 | of this size each batch and each head will have sequence length number of elements so how many
06:29:59.520 | number of elements we need to skip from the starting of the tensor is sequence length
06:30:04.960 | multiplied by the combined index batch size head number and plus we need to also skip some
06:30:12.640 | queries based on our block index q and it's already this skipping is already done inside
06:30:18.480 | of off skew so we add off skew and then once we have computed the index where we should store this
06:30:24.720 | d i block why did i even call it d block let's store it so let me i didn't call it d block i
06:30:34.320 | think it was already in the original code but this is d i and this big matrix d is actually
06:30:40.240 | the matrix that includes all the d i for one for each token in the sequence length
06:30:45.600 | all right so the pre-processing has been done now we need to do prepare the two for loops as you
06:30:54.480 | remember i said before we will be doing two for loops one in which we fix the query and we iterate
06:31:00.880 | through all the keys and values and one in which we fix the key and value block and we iterate
06:31:05.120 | through all the queries and while coding it i will always show you the formula from the paper so
06:31:10.720 | don't worry let's start with the next iteration so first we create the launch grid for the next
06:31:16.720 | iteration as the launch grid is always the same so we first because we we need to keep one block
06:31:24.000 | fixed and iterate through all the other blocks the block that we keep fixed will define how many
06:31:29.600 | programs we have that run in parallel and the block that is fixed has a block size macro number
06:31:35.760 | of elements that's why we create a sequence length divided by block size macro number of blocks
06:31:41.120 | thread blocks or programs in this axis the second axis in this grid i could have
06:31:50.240 | also used as the first axis indifferently i think it was already done like this in the original code
06:31:55.520 | it will indicate which batch and which head inside of each batch we are going to work with
06:32:01.920 | so and just like the forward pass we will also use a variable called the stage that if the
06:32:10.560 | attention that we are computing is causal it will be equal to three and if we are computing
06:32:14.880 | a non-causal attention then it will be equal to one the first iteration we will fix k and v blocks
06:32:22.960 | and we will iterate through all the q blocks in size of block size micro number of query vectors
06:32:32.000 | so let's look at the signature so we launch it with this launch grid
06:32:38.800 | and we have defined how many programs we have so how many kv blocks will we have
06:32:47.040 | it's a sequence length divided by the block size macro because that's the the block that we will
06:32:51.920 | keep fixed in this uh for loop in this function and then we go through all the query blocks in
06:32:59.200 | size of block size micro which i defined it as 32 and later we will talk about auto tuning and
06:33:05.600 | how to tune these values all right so i pass the query vector the key vector and the v vector uh
06:33:12.560 | sorry not vector tensors now the query tensor k tensor and v tensor and they are pointing to the
06:33:18.800 | beginning of the tensor which means that they are pointing to the first batch and the first
06:33:23.520 | head and the first token and the first dimension of the tensors then we pass the softmax scale we
06:33:30.400 | pass do dq dk and dv m is the one that is needed because as you remember from what we said
06:33:39.040 | before we did not save the p matrix in the hbm because we want to recompute it on the fly during
06:33:45.760 | the backward pass so the query multiplied by transpose of the keys it's a very big matrix
06:33:49.760 | to save in the hbm and restore it so we want to compute it on the fly but we don't need to
06:33:54.560 | recompute the normalization factor and the maximum element for each row to apply the softmax that was
06:34:00.480 | already computed during the forward pass and saved into this matrix m which includes the log sum exp
06:34:06.800 | of the maximum of each row plus the logarithm of the normalization factor with the log sum
06:34:14.080 | exp trick we can just apply it and it will also normalize each value then we have the d
06:34:21.120 | tensor that we computed here with all the di values one for each vector in the o
06:34:27.120 | tensor then we need to pass the number of heads the sequence length the block size that
06:34:34.320 | we want to use for the kv which is the macro block size and the micro block size is always
06:34:38.480 | the one that we iterate on i think using this name it should be easier to understand which one we are
06:34:43.440 | iterating and which we want to keep fixed so the fixed one is macro and the iterating one is the
06:34:48.160 | micro head dimension later we will see why we use a different block size to iterate from because
06:34:56.320 | this is related to the number of stages that triton can divide your for loop into thanks to
06:35:03.200 | software pipelining then we have head dimension the stage indicates if the attention that we
06:35:10.000 | computed in the forward pass was causal or not causal the number of warps and the number of
06:35:15.840 | stages which we defined as fixed but later we will talk about auto tuning so sometimes i repeat the
06:35:22.480 | same stuff over and over so i should change that okay let's write the signature of this function
06:35:33.840 | let's put it here so we already described what is the signature of this function let's go directly
06:35:41.680 | to the meat so the first thing that we need to do is understand the offset by which we need to move
06:35:47.200 | this query key and value and the offset is given by the first wall we need to enter the right batch
06:35:53.840 | and the right head inside of each batch we compute the index of the batch just like during the
06:35:58.720 | forward pass by dividing the program index which combines the index of
06:36:05.040 | the head and of the batch we divide it by the number of heads to get which batch this program
06:36:11.920 | is working with and to get the head we just do the modulus just like in the forward pass
06:36:17.040 | the offset batch head indicates let me check what is it for okay it enters the right batch and the
06:36:24.640 | right head so what is the stride if you remember correctly the stride tells us how many items you
06:36:30.000 | need to skip in that dimension to arrive to the next index in the same dimension so if we want to
06:36:34.480 | skip index number of batch we multiply it by the stride batch which is how many elements you need
06:36:41.200 | to skip to arrive to the next batch plus we also need to enter the right head so we multiply the
06:36:48.160 | index of the head multiplied by the stride of the head to enter exactly in that head in the tensor
06:36:54.560 | for each of the q k and v matrices plus we also have this offset that will be used if i remember correctly for
06:37:02.640 | m and d because m and d don't have the head dimension so they are only
06:37:10.480 | batch size number of heads sequence length so we just use the batch-head index multiplied by sequence
06:37:15.760 | length because for each batch and each head we will have sequence length number of items so you
06:37:19.680 | can think of it as the stride to move from one batch head to the next batch head
06:37:27.200 | so let's move the pointers
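(A sketch of my own of the index arithmetic just described, with illustrative values; inside the kernel the combined index would come from tl.program_id.)

```python
# illustrative values only, not the exact code
BATCH, NUM_HEADS, SEQ_LEN, HEAD_DIM = 2, 4, 1024, 64
stride_batch, stride_head = NUM_HEADS * SEQ_LEN * HEAD_DIM, SEQ_LEN * HEAD_DIM

index_batch_head = 5                            # what tl.program_id(1) would return
index_batch = index_batch_head // NUM_HEADS     # which batch this program works on
index_head  = index_batch_head % NUM_HEADS      # which head inside that batch

# how many elements to skip to land on the right (batch, head) slice of Q, K, V
offset_batch_head = index_batch * stride_batch + index_head * stride_head
# M and D have shape (batch, heads, seq_len), so the "stride" per (batch, head) is SEQ_LEN
offset_batch_head_seq = index_batch_head * SEQ_LEN
print(offset_batch_head, offset_batch_head_seq)
```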
06:37:33.040 | so we move the pointers q k and v by the offset batch head because we want to enter
06:37:43.360 | the right um batch and the right head inside of these big tensors and we do it also for d o d q
06:37:50.960 | d k and d v because they have the same shape as a q k and v and d o also has the same shape as
06:37:56.720 | q so they have the same shape so we move by the same uh by the same offset all right so
06:38:03.680 | then we move m and d to the right starting point at which the sequence of the
06:38:11.440 | current batch and the current head starts so they are pointing to the first
06:38:18.000 | vector of the sequence dedicated to the current batch and the current head
06:38:22.400 | and the same is true for q k and v and the d o d q d k and d v okay then we load some other stuff
06:38:32.400 | because in this iteration in this method we are going to do a for loop in which
06:38:40.480 | we fix k v and we iterate through q so we first need to load this block of k v
06:38:47.040 | and we do it as follows so we know we need to load a 2d tensor so we need to define
06:38:56.160 | what are the ranges in the second dimension of each vector k and v that we need to load
06:39:04.320 | and it's defined by this by this vector then we want to understand which kv block this particular
06:39:14.960 | program is going to work with so this particular program is going to skip some kvs that will
06:39:20.720 | already be managed by other programs that may be running in parallel and how to understand
06:39:32.560 | what this program should be working with is based on the index of the program on axis zero which is defined
06:39:32.560 | on the sequence length divided by the block size macro and if you remember block size macro is the
06:39:37.920 | thing that we fix so it's telling us this program id zero will tell us how many block size macro
06:39:45.680 | kv are already being managed by other programs so we shouldn't care about them so we skip them
06:39:52.080 | so let's go back here and this is the number of vectors that we need to skip
06:39:57.280 | so our kv start from start kv and how many we need to load them well depends on what is the
06:40:04.160 | block kv this block kv is equal to block size macro so it will be 128 vectors so we define
06:40:14.480 | our tensors two-dimensional tensors that we will store in the sram because in triton every time
06:40:22.080 | you load something you load it from the hbm into the sram so we define where they should be saved
06:40:28.080 | in the sram and they are initially zeros and now we load them so we load them as follows
06:40:34.000 | we say that okay in the k in the k tensor pointer which is already pointing to the right
06:40:46.320 | index to the right batch and to the right head because that's something that we did here
06:40:52.160 | we say we need to load the right sequence of keys which should start from
06:41:00.320 | offs_kv because this already includes how many we should skip in the sequence length dimension
06:41:12.640 | and for each of these vectors we need to load all the dimensions
06:41:21.120 | in the head dimension because the k i want to remind you is batch number of heads
06:41:32.800 | sequence length and head dim now by using this line we are skipping to the right batch and to the
06:41:32.800 | right head so it's like we already indexed here and here we already selected an index so right
06:41:38.720 | now this k is pointing to the beginning of a tensor of two dimension and we tell okay we don't
06:41:45.120 | want all the sequence we want some part of this sequence which part the one that is indicated by
06:41:50.960 | this start kv and how many of in the sequence length we want well we want uh all right i think
06:42:00.320 | it's easy to write it like this so we can write it that from start kv to start kv plus block kv
06:42:08.480 | uh so we want this number of tensor exactly at this location and for head dimension what do
06:42:16.480 | we want to select we want to select all the dimensions so we say that we want from zero
06:42:21.280 | to head dimension which is exactly this offs_dim okay uh we do it for the k block and we do it for
06:42:31.760 | the v block here i think i didn't change the comment this should be block kv and this should
06:42:40.320 | be block kv before it was called the block kv1 right like in the original code i simplified
06:42:46.960 | a little bit the naming i think this one is better easier to follow because in the original code they
06:42:51.360 | also do for two for loops but in the second for loop they will do it backward just to not change
06:42:56.400 | the structure of the loops but i think mine is more verbose but easier to understand and probably
06:43:03.120 | less efficient mine is much less efficient um then we have offsq because we need to understand for
06:43:10.240 | each block of queries how many vectors we need to load and it's indicated by this offsq and how many
06:43:18.000 | are they it's block q and block q in the caller of this method was set to block size micro so it is 32
06:43:26.080 | vectors okay um now we need to access the q vectors but already
06:43:38.640 | transposed and the do vectors also we need to access them because we are going to iterate
06:43:46.000 | through queries and do vectors why because let's look at here let's look at the
06:43:54.720 | formulas in the paper to compute vj so to compute the dvj that's what we are trying to compute here
06:44:03.360 | we need to iterate through all the do vectors and to compute dk we need to iterate through all the
06:44:10.240 | qi vectors because the qi is a block of vectors so that's why we need um and why do we need to
06:44:21.840 | access a q as a transposed because we need to compute let me show you here pij transposed to
06:44:30.000 | compute pij transposed we need the q transposed because the pij would be the softmax of
06:44:35.920 | the query multiplied by the transpose of the keys after we apply the softmax it becomes p but if you
06:44:41.840 | want the transposed of p then you need to do k multiplied by query transposed so
06:44:49.680 | that's why we access the query transposed instead of queries and the way we access the query
06:44:55.600 | transposed is just by playing with the stride so let's do it like this and i have also written
06:45:02.880 | the comment on why we can do it so this is equivalent to accessing the query uh how many
06:45:12.080 | first okay what is this um what is this operation uh what is this operation here
06:45:18.160 | this is saying go to the query starting point starting um pointer to the query which is
06:45:26.560 | already pointing to the right batch and to the right head for which this particular program
06:45:31.680 | should work with and select a two-dimensional vector where you repeat the query starting point
06:45:38.720 | along the in this case along the columns but we should be repeating it along the rows because we
06:45:44.160 | want to select rows of queries however if we want to select the query transposed we just invert the
06:45:51.280 | two dimensions so this is a let me actually show you without doing the query transposed so let's
06:45:56.720 | do it simplified like this so to access the query um the query pointers without transposition we can
06:46:06.000 | just do like this go to the query tensor and create a 2d tensor where in the rows you put
06:46:14.720 | the starting point of each query that you want to get and and replicate each of these points
06:46:22.240 | also on the column that's the meaning of adding this dimension none this is equivalent to when
06:46:28.000 | you do in pytorch the unsqueeze like you are calling offs_q dot unsqueeze of
06:46:40.160 | one so this is equivalent to adding the column dimension to this tensor and repeating all the
06:46:47.360 | values that are on the row on all the um on the columns how many columns will be there
06:46:53.840 | it will be broadcasted when we sum it with this tensor here this is a combination of
06:47:06.800 | unsqueezing and broadcasting so we are taking the query vectors indicated by offs_q
06:47:16.720 | and then for each query vector we are selecting all the head dimensions indicated by
06:47:24.480 | dim if you invert this broadcasting it will create the transposed of the query vectors that
06:47:24.480 | you are trying to access so this stuff here is equivalent to the these two lines so accessing
06:47:31.200 | query and then transposing and uh it's something that you can do uh i could write down what is
06:47:46.640 | happening at the pointer level so basically you need to think of offs_q as being a vector of
06:47:54.880 | pointers that we multiply by the sequence stride which tells us how many elements we need to skip
06:48:02.240 | to go from one query vector to the next because the stride of the sequence dimension will be
06:48:09.040 | equal to 128 in the case the head dimension is 128
06:48:17.040 | it means that to go from one query vector to the next you need to go forward by
06:48:17.040 | 128 elements because i want to remind you that in the memory the tensors are always stored like
06:48:23.440 | flattened like each dimension is flattened with the next dimension so imagine you have three rows
06:48:31.040 | and four columns but the first you will have the first three rows then the sorry the first row so
06:48:36.320 | the first four columns then the next four columns then the next four columns row after row
06:48:41.040 | it's difficult to visualize until you write it down so how to write it down take um create a
06:48:50.640 | vector of offs_q so what is offs_q at the beginning it is a range that goes from
06:48:58.720 | 0 to 32 so 0 1 2 3 4 5 6 7 etc etc we are multiplying each one with the stride of the
06:49:11.280 | sequence so this will not skip any element this will skip exactly 128 elements this will skip
06:49:18.000 | exactly implying that the head dimension is 128 this will skip two times 128 elements this will
06:49:25.760 | skip three times 128 elements and then we are adding also another dimension to this vector
06:49:33.920 | so this will be a vector then you broadcast it on head dimension number of columns and to each of
06:49:40.400 | them you add one number so it will become a vector like for okay let me just do it guys otherwise i
06:49:46.960 | think it's too confusing okay so we have a vector that is as follows so zero then we have 128 then
06:49:56.000 | we have two times 128 then we have three times 128 etc etc we are adding how many columns
06:50:03.120 | indicated by off dim so off dim has how many columns so it has a head dim number of columns
06:50:08.480 | please for simplicity let's pretend it's not 128 dimensions let's pretend it's four dimensions so
06:50:14.800 | this will be four this will be two times four this will be three times four we are adding another
06:50:22.320 | dimension that is the dim dimension each one multiplied by the stride of dim which will be
06:50:28.720 | one because it's the last dimension stride dim so we are adding how many columns four
06:50:38.480 | so we are adding um one zero one two three i guess zero one two three right also to this one
06:50:45.920 | we are adding oh my god zero one two three and also to this one we are adding zero one two three
06:50:56.720 | okay and then also to this one we are adding zero one two three so what this will select
06:51:05.280 | this will select from the starting point of the pointer q it will select the element zero
06:51:12.160 | then the element one then the element two and then the element three which is exactly the
06:51:18.240 | head dimension of the first vector that we should be selecting then it will select
06:51:23.280 | the element four from the starting point of the vector the element uh sorry this one let me write
06:51:29.760 | the result of this operation so this one will be zero one two three then it will select the element
06:51:34.720 | four five six seven then it will select the element um eight i guess nine ten eleven
06:51:41.600 | and then it will select the element 12 13 14 15 so from the starting point of where this q is
06:51:52.480 | pointing it will select the first element right after this q the second element right after this
06:51:57.840 | q the third element right after this q etc etc and this will be the you can see that this will be the
06:52:03.600 | first query vector this will be the second query vector this will be the third query vector this is
06:52:07.920 | the fourth query vector because in the memory they are stored one after another they are flattened
06:52:13.200 | so in the memory they are stored like this they are stored like the following they are stored
06:52:19.600 | like this one after another so it will select all of them and it also create a virtual tensor with
06:52:25.920 | the right shape that we want to visualize it into so as you saw as we saw before when you work with
06:52:32.480 | a tensor layout in memory you can always view it as whatever shape you like based on the shape that
06:52:38.400 | you want and the reshaping is always free doesn't involve changing the arrangement of the elements
06:52:44.960 | in the memory i hope now it's more clear so now we can proceed further oh my god it was quite
06:52:51.200 | complicated so whenever i get stuck i just draw things and i think you should do it too because
06:52:56.320 | that's the only way to learn if you try to imagine everything in your head it's always difficult
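(The same drawing as a runnable sketch of my own, with head_dim = 4; swapping which operand gets the extra dimension gives the transposed access for free.)

```python
import torch

seq_len, head_dim = 4, 4
stride_seq, stride_dim = head_dim, 1          # row-major layout
flat = torch.arange(seq_len * head_dim)       # the tensor as it sits in memory, flattened

offs_q   = torch.arange(seq_len)              # which rows (query vectors) we want
offs_dim = torch.arange(head_dim)             # which columns (dimensions) we want

# offs_q[:, None] * stride_seq + offs_dim[None, :] -> a (seq_len, head_dim) grid of offsets
ptrs   = offs_q[:, None] * stride_seq + offs_dim[None, :]
# inverting the broadcasting gives the transposed view without moving any data
ptrs_t = offs_q[None, :] * stride_seq + offs_dim[:, None]

print(flat[ptrs])        # the 4 query vectors, row by row: 0..3, 4..7, 8..11, 12..15
print(flat[ptrs_t])      # the same data, already transposed
print(torch.equal(flat[ptrs].T, flat[ptrs_t]))  # True
```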
06:53:02.080 | and we do the same job for the do vectors so the do vectors we don't
06:53:08.320 | access as a transpose because we don't need them transposed only the q we need transposed
06:53:14.240 | okay then we iterate through the sequence dimension of the query so we start from the query number zero
06:53:21.280 | in the current um well in the query we need to go through the all the sequence length dimension
06:53:27.520 | because only the key we select the right key that we want to work with so i want to remind you here
06:53:32.720 | we fix the key and we go through all the queries but the query we need to start from zero until
06:53:38.080 | sequence length so the number of steps of this for loop will be sequence length divided by block q
06:53:44.800 | so if we have a 1000 elements in the sequence and block q is 32 it will be 1000 divided by 32
06:53:54.320 | 1000 is a bad choice it should be 1024 otherwise it's not divisible so then we go through each
06:54:01.440 | block in this for loop and we load a block of q the first one indicated by our pointer
06:54:07.360 | and at the end of the iteration we will move it to the next to the next block of q
06:54:12.160 | okay we also load the log sum exp values that are stored in the m matrix
06:54:20.800 | because we want to compute on the fly pt pt is the transposed of the softmax of query multiplied
06:54:29.520 | by the transpose of the keys but we don't want to take query multiplied by the transpose of the key and then
06:54:34.160 | do the transpose we just already access q as a transposed so we can already compute the pt instead
06:54:40.400 | of computing p and then transposing it um so we load the offsets of the elements that we need
06:54:49.200 | from this log sum exp matrix which is the m matrix that we computed during the forward pass
06:54:57.040 | and we access a block of q at a time the one we are currently working with in the iteration
06:55:05.440 | then we access a query key transposed already so we do the if you want to get the pt
06:55:17.440 | p should be um this is actually not p because we didn't do the softmax it's actually s t but okay
06:55:24.800 | if you want to get the pt you need to take the softmax of st what is st it's
06:55:34.560 | the transposed of s and what is s it's the query multiplied by the transposed of the key so to
06:55:39.440 | get st you need to do key multiplied by query transposed so as you remember
06:55:45.200 | in the matrix multiplication if you transpose the matrix multiplication you need to also invert the
06:55:51.520 | two element in the matrix multiplication so that's why we are doing a key multiplied by query
06:55:56.720 | transposed this will give us s transposed we are also scaling it with the softmax scale
06:56:03.440 | before we apply the softmax to apply the softmax we just need to do the exponential of each element minus
06:56:10.640 | its maximum divided by the normalization value but with the log sum exp trick we just need to
06:56:16.000 | subtract from each element the m value which already includes the normalization factor
06:56:22.640 | i think i already did the derivation of this so we don't need to go through that again
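(A quick check of my own of that point: with M = rowmax + log of the normalization factor, exp(S - M) is already the fully normalized softmax, no extra division needed.)

```python
import torch

s = torch.randn(6, 6)
row_max = s.max(dim=-1, keepdim=True).values
l = torch.exp(s - row_max).sum(dim=-1, keepdim=True)
m = row_max + torch.log(l)                      # what the forward pass stores in M

p = torch.exp(s - m)                            # what the backward pass recomputes
print(torch.allclose(p, torch.softmax(s, dim=-1)))  # True
```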
06:56:28.000 | okay so now we have the pt block actually so in this formula i should have written st actually
06:56:35.920 | okay then when doing the causal attention we also need to mask out some values
06:56:47.760 | so as you can see here so in this case the causal mask is applied after the softmax has
06:56:56.480 | been computed because usually you are used to applying the
06:57:03.200 | causal mask before computing the softmax but that is during the forward pass
06:57:08.960 | because you don't want the normalization factor to be affected by the element that should be zero
06:57:14.240 | but we already computed the normalization factor so it cannot be affected anymore so
06:57:20.880 | we can mask out after applying the softmax because the normalization factor has already
06:57:25.200 | been calculated based on the fact that we applied the mask and that's why we can apply it after
06:57:30.880 | applying the softmax so the mask is always the same so if the index of the query is greater than
06:57:38.000 | or equal to the index of the key the mask is true in this case for all the values that do not need to be
06:57:44.800 | masked so all the values that do not need to be masked are these ones here and all the other
06:57:50.640 | values will be replaced with zeros
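(A small demonstration of my own of this point: masking after the exponential gives the same P as the forward-pass causal softmax, because M was computed with the mask already applied.)

```python
import torch

seq_len = 6
s = torch.randn(seq_len, seq_len)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# forward pass: mask before the softmax, then store M = rowmax + log(rowsum)
s_masked = s.masked_fill(~causal, float("-inf"))
p_forward = torch.softmax(s_masked, dim=-1)
row_max = s_masked.max(dim=-1, keepdim=True).values
m = row_max + torch.log(torch.exp(s_masked - row_max).sum(dim=-1, keepdim=True))

# backward pass: recompute exp(s - m) on the unmasked scores, then zero the masked entries
p_backward = torch.exp(s - m)
p_backward = torch.where(causal, p_backward, torch.zeros_like(p_backward))
print(torch.allclose(p_forward, p_backward))  # True
```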
06:58:00.640 | all right so after we have the pt block already masked we can calculate dv i will point to the right formula in the paper so we load a
06:58:07.680 | block of do why do we need to load a block of do let's look at the paper so how to compute the dv block
06:58:14.480 | so the dv block is computed as the old dv plus so a repeated sum as you can see as you can see
06:58:23.520 | it's here plus equal the old dv plus pt so here pt dropped indicates the pij after applying the
06:58:34.960 | dropout in this implementation we don't support the dropout and also very few models actually
06:58:40.320 | use the dropout in the attention so pt multiplied by doi so a block of doi and doi is the same
06:58:50.720 | block that should be also doi and ki qi are referring to always the same block of rows in
06:58:58.960 | the respective tensors that's why because this inner iteration i indicates a block of q and a
06:59:08.320 | block of o but we are always referring to the same positions in the tensors because do has the same
06:59:15.520 | shape as dq so we go through the blocks of query and the do simultaneously because one is needed
06:59:23.920 | for dv so for dv we need do and for dk we need q and that's why we compute the dv as follows
06:59:32.960 | just like from the paper so pt block multiplied by do as you can see it's a p transpose multiplied
06:59:38.800 | by the o block so we have computed computed the do block then we need to load the di element that
06:59:47.600 | we computed pre-computed initially with the first call to the function called the attention
06:59:56.000 | backward pre-process because we will need it for dk so let's see and how many of them we are loading
07:00:06.800 | exactly the same number of query that we load because they are we load always the same number
07:00:13.200 | of block size micro number of vectors okay i will copy some stuff and explain it step by step so
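A minimal PyTorch sketch of the dV update just described, dV += P^T dO, with assumed block shapes (this is a reference illustration of the formula, not the Triton kernel itself):

```python
import torch

# dV_j += P^T_ij @ dO_i: dO_blk is the block of dO that shares its row range with
# the current Q block; dV_blk is the accumulator kept across the whole Q loop.
BLOCK_KV, BLOCK_Q, HEAD_DIM = 32, 16, 64
P_T = torch.rand(BLOCK_KV, BLOCK_Q)          # softmax probabilities, transposed (masked if causal)
dO_blk = torch.randn(BLOCK_Q, HEAD_DIM)      # gradient of the output for these queries
dV_blk = torch.zeros(BLOCK_KV, HEAD_DIM)     # accumulator, written back only after the loop

dV_blk += P_T @ dO_blk                       # (BLOCK_KV, HEAD_DIM)
```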
07:00:21.520 | the next operation that we need to do is to compute dk to compute dk we need dst and
07:00:30.240 | to compute dst we need dpt so let's go one by one from
07:00:39.200 | the end to the beginning of these formulas so we understand where everything is used and where
07:00:47.280 | everything is created so let's start from dk if you look at the paper dk is equal to the old dk
07:00:55.120 | plus ds transposed multiplied by a block of q and this is what is written here
07:01:03.360 | plus equal basically means the old plus the new it's an incremental addition so
07:01:10.240 | we increment the old dk with some new stuff which is this stuff here the softmax scale
07:01:17.200 | because there is also a softmax scale this tau here multiplied by the matrix multiplication
07:01:23.440 | between the dst block and q and you can see here this q but we don't
07:01:34.960 | have q we have q transposed so we take the transpose of q transposed and it becomes q again
07:01:40.560 | now let's look at this dst block dst is calculated as follows in the formula of the paper we have
07:01:48.720 | ds it is here and it is equal to a block pij multiplied element wise
07:01:59.280 | with dpi minus di now we don't need ds we need ds transposed and since this is
07:02:10.000 | an element wise multiplication not a matrix multiplication when you take
07:02:14.800 | the transpose of this operation you don't need to invert anything you just need to take the
07:02:18.480 | transpose of the two operands so to compute dst we take the transpose of p which is pt and
07:02:25.920 | we already have that and then the transpose of everything that is inside the parentheses so
07:02:31.280 | dpt minus di where we inverted the rows with the columns so what is this dpt well in the paper
07:02:40.480 | we know the formula for dp dp is here and it is equal to do
07:02:51.280 | multiplied by v transposed but we don't need dp we need dpt and in this case it's not
07:02:58.080 | an element wise multiplication it is a matrix multiplication so in order to get not
07:03:05.760 | dp but dpt we need to take the transpose of the two operands of this matrix multiplication and
07:03:10.960 | in a matrix multiplication when you take the transpose you also need to invert the order of
07:03:15.280 | the two operands so we take vt transposed which becomes v so the v block
07:03:22.400 | matrix multiplied by the other operand doi transposed and that's why we are taking the
07:03:29.120 | transpose of do right here i'm not going through all the single pointers because i already told
07:03:36.240 | you how to check what a pointer is pointing to and what an offset is referring to i hope that
07:03:42.400 | now you have a better understanding of how these pointers work in triton which is also the same way
07:03:48.480 | in which they work in cuda because in the gpu we only get a pointer to the starting
07:03:55.280 | address of the tensor and then we need to work out all these indices
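Putting the chain together, here is a hedged PyTorch sketch of the dK update (dP^T = V dO^T, dS^T = P^T ⊙ (dP^T − D), dK += τ · dS^T Q); block shapes and names are illustrative, not the kernel's:

```python
import torch

BLOCK_KV, BLOCK_Q, HEAD_DIM = 32, 16, 64
softmax_scale = HEAD_DIM ** -0.5             # the tau in the paper's formula

V_blk = torch.randn(BLOCK_KV, HEAD_DIM)
Q_blk = torch.randn(BLOCK_Q, HEAD_DIM)
dO_blk = torch.randn(BLOCK_Q, HEAD_DIM)
P_T = torch.rand(BLOCK_KV, BLOCK_Q)          # recomputed softmax, transposed
D_blk = torch.randn(BLOCK_Q)                 # precomputed rowsum(dO * O), one per query
dK_blk = torch.zeros(BLOCK_KV, HEAD_DIM)     # accumulator kept across the Q loop

dP_T = V_blk @ dO_blk.T                      # transpose of dP = dO @ V^T -> (BLOCK_KV, BLOCK_Q)
dS_T = P_T * (dP_T - D_blk[None, :])         # element-wise; D is per query, broadcast over key rows
dK_blk += softmax_scale * (dS_T @ Q_blk)     # (BLOCK_KV, HEAD_DIM)
```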
07:03:59.040 | we have computed the dk block so now we go to the next block of queries
07:04:07.440 | because we are fixing the k and v blocks and we are iterating
07:04:16.960 | through all the queries so we need to move the query transposed pointers
07:04:22.320 | forward by the sequence stride which tells us how to go from one query to the next
07:04:27.200 | multiplied by the block q size so that the pointers
07:04:34.080 | now point to the next block of q elements and we do the same for do using the same block q
07:04:39.840 | size and stride because do and q have the same shape okay after we have run the
07:04:48.480 | for loop over all the queries we can store the dk and dv blocks so we write them back as follows
07:04:56.080 | and this is the end of our function guys so we save the dv block exactly in the right position
07:05:05.520 | dv is already i believe pointing to the right
07:05:10.640 | batch and to the right head because we incremented it here and the same goes for dk
07:05:16.400 | then we need to tell it where in the sequence dimension it should save this
07:05:20.320 | block of dk and dv and this is indicated here we create the pointers just
07:05:28.240 | like before guys don't make me do it again it's really easy if you write it down you write
07:05:35.120 | this vector of key and value indices which are not pointers actually they are the range of
07:05:45.920 | keys and values that you need to take from the sequence dimension you add another dimension
07:05:52.320 | for the columns so you repeat each value along the columns and then you add the dimension here
07:05:58.080 | for the head dimension anyway after we have calculated the pointers where we should store
07:06:03.280 | dk and dv we store them
07:06:12.000 | in the dv tensor and the dk tensor and what do we save we save the dv block and the
07:06:18.480 | dk block which are the ones that we were incrementally updating in the for loop that we have written
07:06:25.840 | okay now that we finished this one we can go to the next function that will do the other for loop
07:06:32.400 | so let's do it okay so now we do the second part of the iteration which is this one so let me just
07:06:40.160 | copy it and then we describe it let's write it here okay we use the same launch grid as before
07:06:49.600 | of course we need to declare this function and again the grid is defined by
07:06:55.200 | the block size macro which is the thing that we keep fixed and then inside the
07:07:01.520 | for iteration we do steps of block size micro in this case we are fixing q and we are iterating
07:07:09.840 | through k and v because we need to compute dq now that we have computed dk and dv
07:07:16.320 | i believe the arguments are the same as before and actually this is also the reason
07:07:23.680 | why in the original implementation on the triton website the author decided to use the same
07:07:30.880 | for loop but with different arguments and i believe it was a little confusing so that's why i just
07:07:36.720 | separated them and repeated the code twice the goal of this video is to be as easy
07:07:42.240 | to understand as possible not as efficient as possible so let's go here let me copy
07:07:50.400 | the signature again let me define this function here okay so again we first need to move the
07:08:02.720 | query key and value pointers to the right place so that they point to the exact batch and the exact
07:08:08.640 | head that we are working with in this program so let's do it let me check where the code is
07:08:17.520 | the first part is exactly the same as the other for loop that we have written
07:08:21.840 | so let's go here i really just copied it so it's exactly the same we check the
07:08:30.720 | index batch head we move the query key value pointers to the right place the do dq dk dv
07:08:36.480 | pointers to the right place and the m and d to the right place exactly like before so i don't think i need
07:08:41.840 | to explain that again and then we load a block of q the one that we will keep fixed
07:08:48.880 | so dq
07:08:52.880 | let me load a lot of stuff here actually
07:08:58.800 | okay we define the offsets that we will need to load the blocks of k and v in the head dimension
07:09:05.840 | because we are going to iterate over k and v and we will access them as transposed blocks so instead
07:09:12.880 | of accessing them directly as k and v we access them as kt and vt and you know that that's
07:09:20.240 | possible just by changing the strides in this case because we are treating them as 2d tensors
07:09:26.960 | when you want to access k not transposed you treat this offs kv
07:09:34.880 | as a column vector so you repeat each k offset that you want to
07:09:43.680 | access along the rows in this case instead we are treating it as a row vector so the sequence
07:09:49.280 | offsets end up on the columns and that's how you can
07:09:57.520 | access the transposed version of k and the transposed version of v another thing
07:10:03.760 | that we are doing is loading the q vectors which ones well based on offs q which is
07:10:09.920 | based on start q which is based on the
07:10:16.240 | exact starting point that this particular program should be working with because this
07:10:20.640 | particular program works with two dimensions the first dimension indicates which batch and which
07:10:26.720 | head this program should be working with and the second dimension which is program index number
07:10:31.440 | zero indicates which queries among the whole sequence length this particular program is going to
07:10:37.120 | work with this is indicated by the index block which should actually be called q in this case i forgot
07:10:43.840 | to change the name so let me change it to index q because we skip some queries how
07:10:50.000 | many we skip is based on the index of the current program multiplied by how many blocks have already
07:10:57.600 | been processed by the previous programs this tells us which
07:11:04.400 | queries inside the sequence length this program needs to select so that's why we use the start query plus a range of
07:11:10.320 | block q so imagine the starting query for this program among the whole sequence length is 100
07:11:16.320 | then this will load the query rows 100 101 102 until 100 plus block q minus 1
07:11:24.640 | this is the range of the query vectors that we will load in this program
07:11:32.720 | we load the block of q by using q plus the query offsets treated as a
07:11:41.520 | column vector broadcast against the head dimension offsets where each column is one head
07:11:51.200 | dimension multiplied by its stride in this case we could actually also skip multiplying by the stride
07:11:55.680 | because the stride of the last dimension is one
07:12:01.680 | the stride of the last dimension is always one by definition
07:12:10.000 | because to go from one element to the next you only need to move by one
07:12:16.720 | element then we load dq which is the stuff that we are going to compute in this iteration
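An illustrative picture of the offset arithmetic described above, written in PyTorch: the GPU really sees a flat buffer plus strides, the last stride of a contiguous tensor is 1, and "transposed access" is just swapping which index array is broadcast over rows versus columns. The names and shapes are assumptions, not the kernel's actual variables.

```python
import torch

SEQ_LEN, HEAD_DIM = 8, 4
K = torch.arange(SEQ_LEN * HEAD_DIM).reshape(SEQ_LEN, HEAD_DIM)  # pretend K for one (batch, head)
stride_seq, stride_dim = K.stride()            # contiguous: (HEAD_DIM, 1) -> last stride is 1

offs_kv = torch.arange(0, 3)                   # which sequence positions of K we want
offs_dim = torch.arange(0, HEAD_DIM)           # which head dimensions we want
flat = K.flatten()                             # the flat buffer the pointer arithmetic works on

# Non-transposed block: sequence offsets as a column vector -> rows = sequence.
offsets = offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim
K_block = flat[offsets]                        # (3, HEAD_DIM), equals K[0:3, :]

# Transposed block: swap the broadcasting; no data is moved or copied.
offsets_T = offs_kv[None, :] * stride_seq + offs_dim[:, None] * stride_dim
K_T_block = flat[offsets_T]                    # (HEAD_DIM, 3), equals K[0:3, :].T
```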
07:12:27.040 | and then we have do to load and for do we use the same offsets as q because do and
07:12:33.680 | dq have the same shape and they work in the same way so we load a block of q and we load the
07:12:40.080 | corresponding block of do in this case and do has the same shape as o which has the same
07:12:47.440 | shape as q plus we also need to load the m normalization factors which are in the m matrix
07:12:54.400 | which ones the ones corresponding to this particular group of queries that we are
07:12:58.800 | going to work with in this particular program as you can see we start with
07:13:06.800 | the offsets of the first block of kv starting from position zero because we will iterate through
07:13:13.680 | all the kvs and we start from kv zero so the key vector zero and the v vector zero and then
07:13:21.040 | we move forward by block kv number of vectors at each iteration i hope i didn't go too fast because
07:13:30.080 | most of the things written here are very similar to what we have already done
07:13:34.240 | in the other for loop so i don't want to repeat myself too much what matters
07:13:41.840 | is actually the formulas that we will use which are exactly the ones in the paper so we go through
07:13:48.880 | these blocks of kv we load the first block of k transposed and v transposed which is loaded like
07:13:59.200 | this as usual you give triton the pointers of the
07:14:04.320 | elements you want to load and it will load the block that you are
07:14:09.360 | asking for inside the sram so this stuff all resides in the sram and also q
07:14:15.520 | resides in the sram and also do resides in the sram then we compute the query multiplied by the
07:14:23.840 | transpose of the keys because we need to compute the p block so the qk block is just the
07:14:31.520 | current query block multiplied by the transpose of the
07:14:36.880 | current key block and we access the keys already as transposed so we don't need
07:14:45.360 | to transpose them and anyway even if we needed to transpose them it
07:14:53.440 | doesn't require any computation to transpose a matrix we just access it in a different way
07:14:58.480 | because in the memory layout it's always stored as a kind of flattened array
07:15:06.160 | then we compute the p block which is the output of the softmax so from each query key score
07:15:10.880 | we subtract the log sum exp value for this block of queries that's why for loading the m
07:15:19.680 | block we use the offsets of the queries that we are loading and as you remember the m block already
07:15:26.400 | includes the normalization factor because each m is actually the maximum value for each
07:15:32.160 | row plus the logarithm of the normalization factor which with the properties
07:15:36.800 | of the exponential goes into the denominator okay and then we apply again the autoregressive
07:15:44.240 | masking oops what did i do let me go back to the code here so we have the stage this one
07:15:59.520 | so when we launch the backward pass stage equal to three indicates that in the forward pass we
07:16:05.760 | computed the causal attention and one indicates that we computed the non-causal attention so if we
07:16:13.520 | computed the causal attention in the forward pass we also need to mask out these elements in the
07:16:19.120 | backward pass so we create the mask which is true
07:16:28.720 | only for the elements for which the query index is greater than or equal to the key index and if this is true
07:16:36.400 | then we don't mask otherwise we mask let's compute the next operation which is to compute dp
07:16:45.600 | and ds actually let's compute dq directly and then we explain it like before so we start from
07:16:51.280 | the end and go backwards to see what is needed to compute it so if you look at the formula
07:16:58.160 | let me check this one we don't need it i think okay let's go here to the ipad okay what we are
07:17:13.360 | trying to compute here is dq and dq as you can see in the paper is equal to the old dq
07:17:23.920 | plus tau which is the softmax scale which is this stuff here multiplied by
07:17:31.680 | the matrix multiplication between ds and the k block so the ds block is here and the k
07:17:41.040 | block is the transpose of the kt block because we are accessing k already as a transposed block
07:17:47.840 | we could also access k directly as a non-transposed block if you don't
07:17:54.000 | want to access it as a transposed block just do like this like here none this will treat it as a
07:18:01.920 | row vector and broadcast along the columns and also this one you would need to change
07:18:08.240 | actually this one you shouldn't need to change because this one you need to treat as
07:18:14.720 | a column vector for the head dimensions but if you want to access it as k transposed then you just
07:18:20.080 | swap these two operations i hope i didn't mess up anything so let's move forward okay we
07:18:27.840 | know that the formula for dq is exactly the same as the paper one but what is this ds
07:18:35.760 | block let's look at the paper this ds block is coming from this stuff here i believe
07:18:43.360 | ds which is the p block element wise multiplied with dpi minus
07:18:53.040 | di now what is this p block the p block is exactly the output of
07:19:02.160 | the softmax which we already have and what is the dp block well the dp block is exactly
07:19:08.000 | do multiplied by v transposed where do we already loaded and it's here
07:19:14.800 | multiplied by the transpose of v which we already loaded as transposed and this is how we
07:19:20.080 | compute dq then of course we need to move to the next block of keys so we
07:19:29.280 | increment the pointers just like before so we move to the next block of keys and values
07:19:37.760 | and we also advance the other pointers just like before and then we need to store the result of dq and
07:19:45.680 | this way we only need to do one write to the hbm which is why we divided the work into two for loops
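A hedged PyTorch sketch of this dQ update (dP = dO V^T, dS = P ⊙ (dP − D), dQ += τ · dS K), with assumed block shapes; dQ is written back to HBM only once, after the loop over all K/V blocks:

```python
import torch

BLOCK_Q, BLOCK_KV, HEAD_DIM = 16, 32, 64
softmax_scale = HEAD_DIM ** -0.5             # the tau in the paper's formula

K_blk = torch.randn(BLOCK_KV, HEAD_DIM)
V_blk = torch.randn(BLOCK_KV, HEAD_DIM)
dO_blk = torch.randn(BLOCK_Q, HEAD_DIM)
P_blk = torch.rand(BLOCK_Q, BLOCK_KV)        # recomputed softmax for this tile (masked if causal)
D_blk = torch.randn(BLOCK_Q)                 # precomputed rowsum(dO * O), one per query
dQ_blk = torch.zeros(BLOCK_Q, HEAD_DIM)      # accumulator kept across the K/V loop

dP = dO_blk @ V_blk.T                        # (BLOCK_Q, BLOCK_KV)
dS = P_blk * (dP - D_blk[:, None])           # element-wise; D broadcast over the key columns
dQ_blk += softmax_scale * (dS @ K_blk)       # written back to HBM once, after the loop
```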
07:19:52.400 | so if you look at the original algorithm i don't know if the original algorithm actually
07:19:58.560 | corresponds to the implementation that they did in cuda but i don't think so because it would
07:20:04.400 | not be so optimized but in the original algorithm in the paper they say that you need to go through
07:20:10.800 | all the keys and then while going through the keys you need to go through all the queries and for
07:20:14.960 | each query block that you visit you need to write back dq while you are updating it which is
07:20:19.760 | not optimized that's why we needed two for loops one in which we fix the keys and we update
07:20:26.160 | them while iterating over the queries because each key block depends on all the blocks
07:20:32.720 | of q and one in which we fix the queries and iterate through all the keys because one block of
07:20:37.680 | q depends on all the blocks of k and this is why we split it and this is the second loop that
07:20:43.680 | we have written now we have written everything that we needed for flash attention the
07:20:50.800 | forward pass and the backward pass so we should be ready to launch the kernel i hope i
07:20:58.480 | didn't make any mistake in copying the code i don't think so i will try to launch it and if there
07:21:04.480 | is any error i will just use my reference code which i have already written and that i used as a
07:21:08.720 | copy the only difference up to now between my reference code and the one that we have written
07:21:14.640 | is the auto tuning which i didn't explain so let's talk about the auto tuning the auto tuning is
07:21:20.240 | also something that was already present in the original implementation and i kept it as is i removed
07:21:25.120 | the auto tuning for the backward pass but in the forward pass if you check there is this code
07:21:32.160 | here that indicates the auto tuning configuration for triton so triton basically cannot know
07:21:40.560 | beforehand what is the best block size for the query or what is the
07:21:45.840 | best block size for the keys and values or for any other dimension that we
07:21:51.520 | have we need to try based on the hardware that we are running on based on the availability of
07:21:57.840 | sram and based on the thread coarsening that triton can apply i didn't talk about thread
07:22:03.920 | coarsening basically in cuda you can choose if each thread does one single operation for example
07:22:10.000 | in a matrix addition each thread doing one addition of one particular element of the output
07:22:15.840 | matrix or if it manages multiple elements this is called thread coarsening and i didn't
07:22:21.120 | check the documentation but i believe triton does it for you based on the block size that
07:22:26.320 | you give it and the number of warps that you want a warp is a group
07:22:33.040 | of 32 threads that work cooperatively always running the same instruction at the same time
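For reference, this is roughly what such an autotuning declaration looks like in Triton. It is a structural sketch only: the block sizes, keys and the empty kernel body are placeholders, not the video's actual code, and it needs a GPU (and a kernel launch) to actually do anything.

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_Q": bq, "BLOCK_SIZE_KV": bkv},
                      num_stages=ns, num_warps=nw)
        for bq in (64, 128)
        for bkv in (32, 64)
        for ns in (3, 4)
        for nw in (4, 8)
    ],
    key=["SEQ_LEN", "HEAD_DIM"],   # re-run the search whenever these change
)
@triton.jit
def _attn_fwd_sketch(Q, K, V, O, SEQ_LEN, HEAD_DIM: tl.constexpr,
                     BLOCK_SIZE_Q: tl.constexpr, BLOCK_SIZE_KV: tl.constexpr):
    # Body omitted: Triton benchmarks every config above and caches the fastest
    # one for each (SEQ_LEN, HEAD_DIM) pair it encounters.
    pass
```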
07:22:40.560 | the number of stages is more interesting it's an optimization that triton does
07:22:48.480 | basically it is not loop unrolling so let's talk about software
07:22:55.360 | pipelining because this is the last part that we need to understand from this code which is
07:22:58.640 | the auto tuning i believe the most interesting part here is not choosing the
07:23:02.720 | block size q and the block size kv because that is just trying whatever
07:23:07.920 | configuration works best based on the timing triton will actually run all these
07:23:14.000 | configurations for you every time the sequence length or the head dimension changes and for
07:23:18.160 | every pair of head dimension and sequence length it will choose the configuration that runs in
07:23:23.360 | the least amount of time the one that gives you the best throughput so let's look at this
07:23:30.240 | numstages what is it and how it works so let's do it okay so software pipelining is
07:23:38.880 | it's used when you have a kind of a for loop so you have a sequential operation
07:23:44.800 | in which each iteration does not depend on the previous iteration so the operations that
07:23:49.280 | you're doing in one iteration are independent from what you have done in the previous iteration
07:23:52.480 | which is more or less what we have done before in our for loops actually i believe there are
07:23:58.400 | conditions under which this doesn't have to be true so
07:24:03.920 | the operations can depend on each other and you can still do software pipelining
07:24:09.600 | so for example imagine you have the following for loop that goes from one to n and
07:24:16.880 | first you load some data then you load some other data then you do a matrix multiplication and then
07:24:21.280 | you store some data so here you are reading data here you are reading data here you are
07:24:26.240 | computing some stuff and here you are writing data if we look at what happens at each iteration
07:24:33.280 | we will see the following picture imagine our gpu is made up of a compute unit and a unit that is
07:24:40.160 | dedicated to loading stuff so reading from the memory or writing to the memory what we will see
07:24:46.080 | in the time scale is that at the first iteration first we are reading some data and the compute
07:24:52.000 | unit is idle because we need this data then we are reading some more data and the compute unit
07:24:56.480 | is idle because we need this data then finally we have enough data and then we can compute this
07:25:00.960 | operation and the reading unit is idle and then we are writing some data back to the memory and
07:25:06.160 | the compute unit is again idle and then it will be idle for another two time steps until it has
07:25:12.000 | enough data to run the computation so as you can see this is not very efficient because at any
07:25:18.560 | point in time there is only one unit working and the other is sitting idle
07:25:24.480 | so one way to optimize this for loop is to do software pipelining and you can tell triton to
07:25:30.480 | do it for your for loops by telling it how many stages you want so let's see how it works so to
07:25:36.320 | pipeline a for loop means that first of all you need to convert all these operations into async
07:25:42.480 | operations and in cuda at least on nvidia gpus there are async loads from
07:25:48.560 | memory and async writes to memory which basically means that i spawn a load operation
07:25:55.120 | and i only check whether it has completed when i actually need the data so i will
07:26:02.480 | spawn this operation and this instruction will return immediately and move to the next instruction
07:26:07.840 | here i will spawn another load operation and this will return immediately and move to the next instruction
07:26:13.280 | and then i can compute but before computing i just check if these two operations have completed so i
07:26:19.040 | can spawn immediately two reads and then i just check whether they have completed so with
07:26:26.560 | software pipelining what we are doing is pipelining operations of different iterations
07:26:32.640 | into a single iteration so first basically we read the first
07:26:38.720 | matrix that we need for computing this matrix multiplication then at the next iteration we read
07:26:45.200 | the first matrix of the second iteration and also read the second matrix of the
07:26:55.040 | first iteration i call them read a and read b where a indicates reading the first matrix
07:27:02.080 | that we need and b means reading the second matrix that we need all these operations are
07:27:08.880 | asynchronous then i launch another asynchronous operation at the third iteration that says
07:27:14.160 | read the first matrix of the third iteration and read the second matrix of
07:27:23.040 | the second iteration and then compute the matrix multiplication because at the third
07:27:31.040 | iteration this one and this one should have completed but while computing the matrix
07:27:37.200 | multiplication i don't keep the loading unit idle because it is still executing this load
07:27:43.600 | and this load this can only work if you can spawn async operations so at the third iteration i can
07:27:51.600 | compute this matrix multiplication by using this one and this one because they should have finished
07:27:56.720 | but while i'm computing the matrix multiplication i already spawned some async operations to load
07:28:02.080 | the data necessary for the second iteration and the third iteration so at the fourth iteration
07:28:08.720 | i will spawn the loading of the data for the fourth iteration loading the data for the third
07:28:15.040 | iteration while computing the matrix multiplication of the second iteration because they should have
07:28:19.040 | already completed by now actually it's not like we just assume they have completed there are
07:28:25.360 | primitives in the CUDA language to check if an operation has completed so
07:28:31.760 | before doing the multiplication we will actually check if the async operations have finished it's
07:28:37.520 | not like we just assume they have finished based on timing this is like in javascript
07:28:43.760 | you have these things called promises i remember and you can wait for a promise to finish
07:28:50.080 | when you actually need it but you can spawn as many promises as you want in C# i think they
07:28:55.120 | are called tasks so you spawn as many tasks as you want and then when you need one you just wait
07:29:01.680 | for it only the one that you need while the others are still running in the background
07:29:06.080 | asynchronously this is the whole idea of software pipelining software pipelining as you can see only
07:29:13.840 | works when you have async operations and also it increases the memory requirement for your program
07:29:19.200 | because when matrix multiplication one is going to run we may have enough data for the first
07:29:28.720 | two iterations plus half data for the third iteration so we increase the memory requirement
07:29:34.320 | for the SRAM okay and Triton will do this software pipelining for you it will convert all
07:29:42.720 | the loads all the stores and maybe also the matrix multiplications into async operations and do this
07:29:48.240 | pipelining for you
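A toy CPU-only illustration of the same idea, using a worker thread as a stand-in for the asynchronous load unit: the load for iteration i+1 is in flight while iteration i is being computed, with a prologue to fill the pipeline and an epilogue to drain it. This is purely illustrative and is not Triton code.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def load_tile(i):          # pretend async read from HBM
    time.sleep(0.01)
    return i

def compute(tile):         # pretend matmul on the tile
    time.sleep(0.01)
    return tile * tile

results = []
with ThreadPoolExecutor(max_workers=1) as loader:
    pending = loader.submit(load_tile, 0)          # prologue: fill the pipeline
    for i in range(1, 8):
        tile = pending.result()                    # wait only if the load isn't done yet
        pending = loader.submit(load_tile, i)      # prefetch the next tile...
        results.append(compute(tile))              # ...while computing the current one
    results.append(compute(pending.result()))      # epilogue: drain the pipeline
```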
07:29:53.360 | if you are confused by how this works there is another easy way to explain it because it's already something that we do in model training it is called pipeline
07:29:58.720 | parallelism pipeline parallelism works as follows we have a very big neural network that
07:30:06.240 | does not fit in a single gpu so imagine this neural network is made up of three layers layer
07:30:11.040 | one layer two and layer three but this is so big it does not fit entirely in one single gpu so one
07:30:17.360 | way would be to put each layer on one gpu so we put for example layer one on gpu one
07:30:25.520 | layer two into gpu two layer three into gpu number three so imagine we have an input for this neural
07:30:33.680 | network so we put it to the first gpu the gpu one will process the layer one and generate some
07:30:39.840 | output which will be transferred to the gpu two which will calculate its own output and transfer
07:30:45.200 | it to the gpu three which will compute its own output and finally we will have the output of
07:30:55.600 | the neural network the problem is that when you send the output of gpu one to gpu two for
07:31:02.160 | gpu two to do its own thing gpu one is now free which is a waste of resources we should
07:31:08.880 | always keep the gpus busy so one thing that we can do is instead of sending the whole mega
07:31:15.760 | batch to gpu one we send many smaller batches how does it work imagine that we send the batch
07:31:23.920 | number zero so batch zero to gpu one gpu one will compute its output and send it to
07:31:30.560 | gpu two so now gpu two is computing batch number zero so batch zero is not
07:31:39.360 | here anymore but now gpu one is free so we send another micro batch called batch one
07:31:46.160 | then gpu two will finish processing batch zero and will send it to gpu
07:31:52.240 | number three so now gpu three has batch number zero and gpu two is now free and
07:31:56.960 | hopefully gpu one has also finished so we transfer batch number one
07:32:05.760 | from gpu one to gpu two so this here becomes one and then gpu one is free
07:32:10.720 | and because gpu one is free we can introduce
07:32:18.160 | another batch batch number two etc etc etc so while moving one
07:32:18.160 | batch from one gpu to the other we introduce a new batch at the beginning of the pipeline and
07:32:23.440 | they shift by one position at every iteration this will keep the gpus always busy there is only one
07:32:30.080 | problem with pipeline parallelism which is this bubbling effect because you need to create this
07:32:35.440 | pipeline at the beginning okay actually in pipeline parallelism you also
07:32:40.240 | have the problem of the backward step because the backward step has to run exactly in reverse
07:32:44.880 | of the order in which you receive the micro batches while in triton when doing software
07:32:51.520 | pipelining you have the problem of the prologue and the epilogue because you need to build up the
07:32:57.440 | pipeline to start the pipelining and at the end you need to
07:33:04.320 | drain all the stuff that is currently in the pipeline so only in the first steps and in
07:33:10.400 | the last steps of this for loop not all the units of the gpu may be working simultaneously
07:33:18.000 | what does it mean it means that in order to use pipelining you want the number of iterations
07:33:23.840 | of your for loop to be much bigger than the number of stages into which each iteration is
07:33:28.400 | divided in this case we have four stages so you want the number
07:33:33.520 | of iterations to be much larger than the number of stages all right guys finally
07:33:39.760 | i have completed the video i hope that you learned a lot from it i believe that we
07:33:45.520 | can run the triton code so let's run it let's see i copied everything i believe we also
07:33:52.240 | put the code to test it but we didn't put the main method which we can copy right now i
07:34:00.400 | hope there is no error i really hope there is no error let me check if i am
07:34:09.440 | on the right machine i am so let's just run the program and pray if there is an error i will just copy my own
07:34:20.880 | reference implementation but i hope it works because otherwise i forgot something i'm
07:34:27.280 | running my code on an h100 because my company has h100s if you have a smaller gpu what you can do is
07:34:34.720 | reduce the sequence length or reduce the batch size i think it's already one when we
07:34:41.600 | call it oh no you can reduce the batch size the number of heads the sequence
07:34:45.600 | length you can even put head dimension equal to 8 and sequence length equal to 16 let's check
07:34:51.680 | run backward triton backward returned an incorrect number of gradients expected 5 got 1
07:34:59.600 | we probably forgot some return statement i believe yes i forgot the return statement
07:35:08.800 | here so in the backward pass after running the last for loop we need to return the stuff that
07:35:14.160 | we have computed fingers crossed again okay it passed so the backward pass computed by torch
07:35:24.080 | is equivalent to our backward pass up to an absolute error of 10 to the power of minus 2
07:35:31.040 | as you can see this backward that we run here is different from the backward that
07:35:37.360 | we run here because when you apply triton attention it introduces a new node
07:35:42.400 | in the computation graph of our tensors that represents this triton attention operator
07:35:47.920 | and when pytorch wants to compute the backward pass it will just call the backward function
07:35:52.080 | of this triton attention to compute it and it will populate the grad attribute of all the tensors
07:35:58.400 | that are inputs to this triton attention and this is how pytorch autograd works guys
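For completeness, this is roughly how such a custom op plugs into autograd via torch.autograd.Function (the class name and the kernel stand-ins are illustrative, not the exact code from the video); note that backward must return one gradient per forward input, which is exactly the "expected 5, got 1" error seen above when the return statement is missing.

```python
import torch

class TritonAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, causal, softmax_scale):
        O = torch.zeros_like(Q)                  # stand-in for the Triton forward kernel
        M = torch.zeros(Q.shape[:-1])            # stand-in for the saved logsumexp matrix
        ctx.save_for_backward(Q, K, V, O, M)
        ctx.causal, ctx.softmax_scale = causal, softmax_scale
        return O

    @staticmethod
    def backward(ctx, dO):
        Q, K, V, O, M = ctx.saved_tensors
        dQ, dK, dV = torch.zeros_like(Q), torch.zeros_like(K), torch.zeros_like(V)
        # ... launch the preprocess kernel and the two backward loops here ...
        return dQ, dK, dV, None, None            # None for the non-tensor inputs
```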
07:36:04.960 | thank you for watching my video guys it has been super super super demanding
07:36:11.120 | i spent many months first of all to learn myself about the triton about cuda about flash attention
07:36:17.360 | etc also i have a full-time job so it is really hard to make videos like this like i need to
07:36:23.440 | dedicate you know the nights the mornings the weekends i spent three days just to record this
07:36:32.800 | video because sometimes i don't like how i explain something sometimes i make mistakes or sometimes i
07:36:32.800 | need to restart because what i'm doing is wrong etc etc i believe there should be no big errors
07:36:39.760 | in what i have done so far but for sure my notation is not great because all
07:36:45.280 | the mathematics i know is self-taught i learned it by myself and because i didn't learn it
07:36:51.760 | in academia i have bad habits and i'm trying to get rid of them so i use very bad notation
07:36:57.040 | sometimes i write with a capital letter sometimes with lowercase sometimes i just
07:37:01.760 | forget an index etc so i'm trying to solve these problems i believe i have explained everything
07:37:08.240 | so you should have all the knowledge to derive all the formulas that you see in the
07:37:14.320 | flash attention paper and you should also have a mental image of how the
07:37:21.200 | attention calculation works block by block i know that i could have spent 20 hours explaining
07:37:27.360 | things better but i also have a life and i also have a wife so i cannot make 100-hour videos
07:37:35.920 | also there were some interruptions while making this video i had some wisdom teeth removed so it took me
07:37:41.120 | more than one week to recover because it was so painful so thank you guys for
07:37:48.640 | watching my video i hope you learned a lot also this time as you can see triton is something
07:37:53.600 | new there is not much documentation so something that i have said about triton may not be totally
07:37:58.480 | correct because there really is very little documentation all the triton that i have
07:38:03.040 | learned is by looking at code written by others and trying to understand it and i think
07:38:12.400 | that's it guys so i wish you a wonderful day and see you next time on my channel