
LongNet: Scaling Transformers to 1,000,000,000 tokens: Python Code + Explanation


Chapters

0:00 Introduction
5:30 How it works
14:00 Computational Complexity
19:17 Attention visualization

Transcript

Hello guys, welcome to my channel. Today we will be talking about a new model that came out recently. It's called LongNet, and it's a model based on the Transformer that came from Microsoft Research Asia just two weeks ago, I guess. The basic idea is that they want to scale the Transformer, and the wonderful thing is that they managed to scale it to 1 billion tokens.

Now, if you're familiar with language models, you know that the sequence length has a huge impact on the performance of the model, because the sequence length tells you how many tokens can relate to each other when performing the attention mechanism. This determines whether the model has a longer or a shorter context. For a model like GPT, for example, you want a long context, so that the model can look at words that were written a long time ago when calculating the next token. LongNet is able to scale this to 1 billion tokens.

To show you how amazing this is, I want to show you this graph of sequence lengths over time: the sequence length of GPT was just 512, then we have the Sparse Transformer and later models at 12K, 64K, 262K and 1 million tokens, but with LongNet they go 1,000 times further, to 1 billion tokens. Just to imagine the scale: you could basically feed a huge part of the text of Wikipedia to the model, and the model would be able to calculate the attention using all these tokens.

But let's see how this all works. First of all, LongNet, as claimed by the authors, has significant advantages. The first is that its computational complexity is linear in the sequence length, and we will see why. The second is that there is a logarithmic dependency between tokens, which basically means that the more distant two tokens are, the weaker the dependency between them: the attention mechanism is weaker between two tokens that are very far from each other and stronger between two tokens that are close to each other. Third, it can be trained on a distributed network, which means that we can calculate this attention mechanism on a distributed system, so on multiple GPUs or multiple computers. And finally, it is a drop-in replacement for standard attention, which means that if we already have a model that uses the attention mechanism, we can basically just replace the attention mechanism without changing all
the rest of the model, and it will work as before, but with this new, improved attention that can use longer sequence lengths.

If you're not familiar with Transformer models, I kindly ask you to watch my previous video, in which I explained the attention mechanism and the Transformer model. I will briefly review it here to show you what the problem with the attention mechanism was. Here I have the slides from my previous video, and we can see the self-attention mechanism. In self-attention we had the matrices called Q, K and V, and Q was basically the sentence, which is a matrix of size sequence length by d_model, where d_model is the size of the embedding vector of each word. Multiplying the query by the transpose of K to produce this matrix requires a number of operations that is n to the power of 2 multiplied by d_model. Why is this the case? Well, to produce, for example, this item here in the output of the softmax, we need the dot product of this word, the word "your", with the word "your", so the word with itself, and the dot product of two vectors that are d_model long requires d_model multiplications. We need to do this for all the items in this matrix, and there are n to the power of 2 of them, so the sequence length squared. This is the reason why the complexity of self-attention, and also of cross-attention, was in the order of n to the power of 2 multiplied by d.

This comparison is also present in the paper here: the vanilla attention had a complexity of n to the power of 2 multiplied by d, but with this new model, LongNet, we have an attention mechanism that is in the order of n multiplied by d, so it grows linearly with the sequence length, and we will see how. Here, in the introduction, the
authors claim that the sequence length is one of the main bottlenecks of language models and that scaling it is a priority, and here they show how this work improves on previous approaches, basically because it reduces the number of computations needed to calculate the attention, and how they scale it to 1 billion tokens. We will see all of this in detail, and I will also show some visualizations of how it works. The basic principle is that attention allocation decreases exponentially as the distance between tokens grows.

Now let's have a look at the picture to see how this works. Let's go back to my previous slide: before, we had a matrix like this one, so we calculated the attention between every token and every other token. With LongNet we don't do this. Imagine we have a sequence length of 16. First of all, in the upper part of this matrix we don't calculate the attention, because we want the attention mechanism to be causal: we don't want token number one to be related to token number eight, but of course we want token number eight to be able to look at token number one, for example. So the whole upper part is empty. The second thing is that, instead of calculating the dot products of all the tokens with all the other tokens, we split the sequence into small windows of different sizes. First we start with a size of four: n is the number of tokens in this sentence, and we split it into four segments (that's what they are called), each one of size four, and we calculate the attention between all the words in each small box, and we do it for all these small segments. Here we also see another parameter, called the dilation rate, which here is 1 because we are not skipping any token, so we calculate the attention of every token with every other token in the segment. Then we do it again, this time increasing the size of the window: instead of a window of size four, we use a window of size eight, and we calculate the attention between
each word and every second word in this window, since here the dilation rate is 2, and we do it for all the windows until they cover the whole sequence length. Then we double the segment length, so the size of the window that we look at, but we also double the dilation rate, so how many tokens we skip: we relate token number zero with token number four, then we skip three tokens and do the dot product again, then we skip three more and do the dot product again. So here we skip three tokens at a time, here we skip one, and here we skip zero.

Why do we do this? Because in small windows we want the attention mechanism to be more precise. Think of when you read a book: when you read a paragraph, the words in the paragraph should be strongly related to each other, because they're talking about something very specific. This bigger window can be thought of as a chapter: within a chapter we don't need to relate every word to every other word, but only some parts, because the paragraphs in the same chapter more or less talk about the same topics, so we don't need the dot product between every word in the chapter and every other word in the same chapter. And if we go to the book level, we don't want the dot product between every word of the book and every other word; we want some general idea, so the general theme of the book should be present, but not every word related to every other word. This is the same idea used in this attention mechanism: in small windows, with words that are very close to each other, we do all the dot products; for words that are farther from each other we don't do all the dot products; and for very big windows we do even fewer.

Another thing is that the number of dot products in each window is always the same, no matter the size of the window or the dilation rate. Here, for example, we have four plus three plus two plus one dot
products in this window, ten in total, and it's the same number of dot products that we have here, and the same number that we have here.

Now you may be wondering: this is not relating token number one to token number 16, for example, right? Yes, but if we overlap all of these patterns together, we obtain something like this, and you can see that token number one is still not directly related to token number 16, but we can always find a way of going from token number one to token number 16 by using intermediate tokens. We will see later how this is possible; I also made a tool to visualize it.

Let's look at some details from the paper. First they introduce the vanilla Transformer, so this is basically the attention mechanism from the paper "Attention Is All You Need", the same one that we saw here. Then they describe dilated attention: in dilated attention we choose w and r, so a segment size and a dilation rate, and we divide our sequence into n divided by w boxes like these. Here n is 16: if the segment length is 4 we will have 4 boxes, if the segment length is 8 we will have 2 boxes, and so on. Inside each box we keep only every r-th token, so we skip r minus 1 tokens each time. As you can see, all of these boxes are independent: to calculate the attention in this box, for example, you only need the embeddings of the tokens that are in this box, because there is no interconnection between these two boxes, so this one and this one can be calculated in parallel. Next, they calculate the softmax for each box, so the attention inside each of these boxes, and then they combine the results together, basically by concatenation. Another important thing, the one we saw before, is that they don't use just one r or one w: they use a sequence of r's and w's, and we will see here that these sequences of r
and w are geometric sequences with a constant ratio alpha, which in this case is 2. Basically, they start with a small window, for example a first segment length of 4, and each time they multiply the previous window by 2, and the dilation rate by 2 as well: from 4 we go to 8, from 8 to 16, until we reach the sequence length. The same happens with the dilation rate: at the beginning we don't skip any word, then we skip 1 at a time, then we multiply the rate by 2 again and we skip 3 at a time. They combine all of these together using these two equations: basically, they calculate the denominator of the softmax for each of these attention patterns, so for this one, this one and this one, and then use it as the weight of a weighted average, as we can see here.

How do we turn this into multi-head attention? Suppose we have four heads, the segment length is 8 and the dilation rate is 2. As you know, with a dilation rate of 2 we keep every second token. For head number one, we start from zero and skip every other token: we take zero, skip one, take two, skip three, take four, and so on. For head number two, we skip the zero and start from one, keeping the same dilation rate. For heads number three and four we do the same. In this case heads number one and three are the same, because the dilation rate is smaller than the number of heads; if the dilation rate were greater than or equal to the number of heads, we would see four different patterns here. But this is the basic idea of the multi-head version.

Now let's look at the computational complexity of this model and how it is calculated. The computational complexity of this attention mechanism is basically given
by the dot products that we do to calculate the attention, so the softmax of the query multiplied by the key. With the vanilla Transformer we had n to the power of 2 multiplied by d, but here, for each box, we have w divided by r, to the power of 2, multiplied by d, where w is our segment size and r is the dilation rate. What we can see here is that w divided by r is the effective size of the window: in this example w divided by r is always 4, and it is the number of tokens in the window for which we will calculate the dot products. You can see that in this matrix, even if the size of the window is 8, the number of dot products we actually do is not 8 by 8 but 4 by 4, because we are skipping every other token; and even if the size of this window is 16 by 16, we will not be calculating 16 by 16 dot products but 4 by 4, because we are skipping three tokens at a time. This is the idea behind the calculation of the complexity: we don't calculate the dot products between all the tokens in a window, but only (w divided by r) squared of them, and each dot product involves vectors of dimension d, so we also multiply by d. Then n divided by w is the number of boxes: for example, when w is 4 the number of boxes is 4, because 16 divided by 4 is 4; when w is 8 the number of boxes is 2, because 16 divided by 8 is 2; and when w is equal to the sequence length, 16, the number of boxes is 16 divided by 16, so only one. So the number of floating point operations we do is proportional to n divided by w, the number of boxes, multiplied by (w divided by r) squared multiplied by d, the cost of the dot products in each box. You may be wondering that the window size is still
very big, right? If you implement it in NumPy or PyTorch with a dense masked matrix, the number of operations you do for this 16 by 16 window, for example, is still 16 by 16. But there are better ways to represent this kind of matrix, which is called a sparse matrix. If you use a matrix multiplication algorithm that knows the matrix is sparse, you can do far fewer operations: first, you store less information, because you know that most of the matrix is zero, and second, you perform fewer operations, because you can simply skip the dot products for all the positions of the matrix that you know are zero. As far as I understand, the paper also notes that dilated attention can be turned into dense attention by gathering the kept tokens of each box before the attention and scattering the results back afterwards, so existing optimized attention kernels can be reused.

Another thing the authors show: this is the number of floating point operations for one window size, so for one w, but we don't have one w, we have many, and we know that these w's follow a geometric sequence, as written in the paper here: "we set w and r to geometric sequences". A geometric sequence means that to get the next w we take the previous w and multiply it by alpha, which is fixed, starting from w zero. Here, for example, w zero is equal to four, the dilation rate r zero is equal to one, and every time they multiply both by two. Let's go back here. They need to combine the floating point operations for all of these w's and r's, and they do it here; since the w's and r's come from a geometric sequence, the result depends on alpha and w zero, the initial w that you choose. If we look at this expression, we can see that the number of floating point operations you need to calculate this combined attention, so the combination of all these w's and r's, is proportional to n and d: it grows linearly with n and d, just like
it's written here.

Another interesting fact is that even if two tokens are not directly connected to each other by a dot product, we can calculate the information distance between them, that is, how many jumps you need to make to go from one token to the other. Let me explain this better using the notebook that I made here, specifically to learn about this model and to test how it works. Imagine we have a sequence length of 16; in my representation the tokens go from 0 to 15, not from 1 to 16, but the idea is the same. The first attention pattern will be calculated for this box, this box, this box and this box; then another one for this box and this box; and the last one for this box. This is exactly the same as the combined attention that we see in the paper: the overlap of these three attention patterns is exactly the same, just the colors are different. Now let's look at how, for example, token number 0 and token number 15, the last token, are related. We cannot go from token number 0 to token number 15 directly, because there is no dot product between 0 and 15, but we can find a path: we can go from 0 to 12 and from 12 to 15. Let's see: between 0 and 12 there is a dot product, right? There is also a dot product between 12 and itself, because in the attention mechanism we always compute the dot product between every token and itself, and token number 12 is related to token number 15, so there is a dot product between tokens 12 and 15. So token number 0 is related to token number 15 through token number 12, and we can find such a path for all the tokens. I show in this notebook that, for example, all the nodes are reachable from node number zero, so from token number zero we can reach all the tokens by using
different tokens as intermediates. In the paper, here, they show that the maximum number of jumps you need to make to go from one token to any other token grows with the logarithm of the sequence length: if the sequence length becomes, let's say, 10 times bigger, you don't need to make 10 times more jumps to go from one token to another. Why are we talking about jumps? Because they also tell us how strong the relationship between two tokens is. If we calculate the dot product between two tokens, the model can relate those two tokens immediately, but if there are intermediate tokens, the information has to travel through them, which takes more attention layers, so the connection between those two tokens will be weaker. This is what the authors claim: the attention is spread in such a way that its strength decreases exponentially as the distance between tokens grows, or, in other words, the number of jumps you need to make grows with the logarithm of n.

We can do the same with other sequence lengths. Here I used a sequence length of 16, but we can use 32, for example, and visualize it; let's see if it's still readable. Yes. So basically LongNet will do this: it will start with small boxes of size four, then it will also calculate the attention for boxes of size eight, then 16, and also 32. Here we can see the overlapped attention matrix with all the different sizes, but also all the single groups, with the tokens that are directly connected to each other shown in different colors: for example, token number zero is directly connected to token number three, and also
token number four is directly connected to token number five, because they are part of the same box when the attention is calculated, while the other connections have to be inferred. With a sequence length of 32 we can see that every token is still reachable from token number zero, but with different numbers of steps. For example, to go from token number zero to token number 17 we need to pass through 16. Let's see: from 0 to 17 we cannot go directly, because there is no dot product here, but we can go to 16, because there is a dot product here; 16 is related to itself, and 16 is also related to 17, so we can go from 0 to 17 by passing through 16. This can be done for all the nodes, and I also made a graph here to visualize it: from 0 we cannot go directly to 17, but we can go to 16, and from 16 we can go to 17.

So this is the idea of LongNet. Let's go back. We don't calculate the dot products between all the tokens and all the other tokens; instead, we spread the attention mechanism in such a way that words that are very close to each other are directly connected, and words that are far from each other are connected through other tokens.

In the paper they also show how the model can be trained in a distributed way. We already saw why this is possible: all of these boxes are independent from each other. To calculate the attention in this box here, for example, you only need the embeddings of tokens number 0, 1, 2 and 3, and that's it; to calculate the attention in this box here, you need the embeddings of tokens number 0, 2, 4 and 6, but not of the others; and the same holds for this one, and so on. Another interesting thing is that the number of dot products in each box is always constant, so we can choose the configuration in such a way that each computer handles at most that number of dot products. So this mechanism is quite parallelizable, and it's really
important, because it allows us to scale greatly: parallelization matters a lot for us, because we can compute the model in the cloud or on different GPUs and take advantage of it. Another interesting result is the runtime: we can see here that as the sequence length increases, the runtime grows linearly, not quadratically as with the vanilla Transformer, as you can see here. Then, in the rest of the paper, they compare the performance of the model with previous models. My goal is not to show the results, which you can look at by yourself, but to show the attention mechanism of this new LongNet, and I hope it was clear.

I also hope you will use my Python notebook to experiment by yourself. Let me show you how it works: here you define the sequence length that you want to visualize, and the notebook will only visualize short sequence lengths; I think I set the limit to 32, so if the sequence is longer than 32 it will not be visualized, because it's not easy to visualize anyway. To calculate the distance between one token and another, I just do a BFS, a breadth-first search. It's a very ugly, unoptimized one, but it doesn't matter, because I built this notebook in half an hour just to show how the model works, so you are invited to make it better if you want. One last thing that I didn't show, and that is very interesting: we can see that the maximum distance from node zero to any other node is three, and it changes with the logarithm of n. For example, if the sequence length is 16, let's do it again, we can see here the path from node number zero to every other node, and the maximum distance from node number zero to any other node is two. So going from 16 to 32 tokens adds only one jump, which is consistent with logarithmic growth. You can
also change which node you start from, if you want to calculate other paths. For example, here I say that I want to go from node number five to every other node, and here we show that all the nodes are reachable from node number five; here we display the paths, this is the maximum distance from node number five to any other node, and this is the graph.

I hope you liked my video, guys, and I hope it was more or less clear how this mechanism works. I didn't explain all the equations; I have a lot of comments written on the side, because I like to take notes when I read a paper, mostly because I also want to understand the maths behind it. So if you're interested in some part, just write in the comments and I will try to explain it better. But I think most people just want to understand the mechanism, and they are waiting for the official code to be released to see how it works. I hope I didn't make any mistakes, because there is basically no information online about LongNet yet, so everything I told you comes from my own research. I hope you enjoyed the video; please come back to my channel for more videos about deep learning and machine learning, and have a great day.
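To experiment with the segment-and-dilation patterns described in the video, here is a minimal Python sketch in the spirit of the notebook mentioned in the video, not the notebook itself. The helper names (geometric_schedule, dilated_pairs, combined_mask) are hypothetical, and the sketch assumes a single head with offset 0, a first segment length of 4, a first dilation rate of 1 and alpha = 2, so for 16 tokens the patterns are (4, 1), (8, 2) and (16, 4).

```python
import numpy as np


def geometric_schedule(n, w0=4, r0=1, alpha=2):
    """Segment lengths w and dilation rates r, both multiplied by alpha at each step,
    stopping once the segment length would exceed the sequence length n."""
    schedule = []
    w, r = w0, r0
    while w <= n:
        schedule.append((w, r))
        w, r = w * alpha, r * alpha
    return schedule


def dilated_pairs(n, w, r, offset=0):
    """(query, key) pairs computed by one (w, r) pattern.

    The sequence is split into boxes of length w; inside each box only every r-th
    token (starting at `offset`) is kept, and a causal mask keeps key <= query.
    """
    pairs = set()
    for start in range(0, n, w):
        kept = list(range(start + offset, min(start + w, n), r))
        for i, q in enumerate(kept):
            for k in kept[: i + 1]:  # causal: only keys at or before the query
                pairs.add((q, k))
    return pairs


def combined_mask(n, schedule):
    """Boolean n x n matrix: True where at least one pattern computes a dot product."""
    mask = np.zeros((n, n), dtype=bool)
    for w, r in schedule:
        for q, k in dilated_pairs(n, w, r):
            mask[q, k] = True
    return mask


n = 16
schedule = geometric_schedule(n)  # [(4, 1), (8, 2), (16, 4)]
for w, r in schedule:
    per_box = len(dilated_pairs(n, w, r)) // (n // w)
    print(f"w={w:2d} r={r}: {n // w} boxes, {per_box} dot products per box")

# Rows are queries, columns are keys; the upper triangle stays empty (causal).
for row in combined_mask(n, schedule):
    print("".join("#" if cell else "." for cell in row))
```

Each of the three patterns reports 4 + 3 + 2 + 1 = 10 dot products per box, and the printed grid is the overlapped causal mask.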
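The weighted average mentioned in the video can be sketched as follows, under one reading of it: each (w, r) pattern produces an output and a softmax denominator s_i for every position, and the outputs are mixed with weights s_i divided by the sum of all s_j. Each box is gathered into a small dense matrix, attended to causally, and scattered back. This is a single-head NumPy sketch with hypothetical function names (dilated_attention, mix_patterns), not the authors' implementation, and it deliberately skips the usual max-subtraction trick so that s_i stays the true denominator.

```python
import numpy as np


def dilated_attention(q, k, v, w, r):
    """Causal attention for one (w, r) pattern.

    Each box of length w keeps every r-th row, attention is computed densely on the
    kept rows, and the results are scattered back. Returns the output (zeros for
    skipped rows) and the softmax denominator for every row (zero for skipped rows).
    """
    n, d = q.shape
    out = np.zeros_like(v)
    denom = np.zeros(n)
    for start in range(0, n, w):
        idx = np.arange(start, min(start + w, n), r)  # gather the kept positions
        scores = q[idx] @ k[idx].T / np.sqrt(d)
        causal = np.tril(np.ones((len(idx), len(idx)), dtype=bool))
        # No max-subtraction, so s is the true denominator; this is fine for small
        # random inputs but not numerically safe in general.
        weights = np.exp(np.where(causal, scores, -np.inf))
        s = weights.sum(axis=1)
        out[idx] = (weights / s[:, None]) @ v[idx]  # scatter back
        denom[idx] = s
    return out, denom


def mix_patterns(q, k, v, schedule):
    """Weighted average of the pattern outputs with weights s_i / sum_j s_j per row."""
    results = [dilated_attention(q, k, v, w, r) for w, r in schedule]
    outs = np.stack([o for o, _ in results])    # (patterns, n, d)
    denoms = np.stack([s for _, s in results])  # (patterns, n)
    alphas = denoms / denoms.sum(axis=0)        # each column sums to 1
    return (alphas[:, :, None] * outs).sum(axis=0), alphas


rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 16, 8))
out, alphas = mix_patterns(q, k, v, [(4, 1), (8, 2), (16, 4)])
print(out.shape)  # (16, 8)
```

With a single pattern of size n and rate 1, this reduces to ordinary causal attention. If the key sets of the patterns were disjoint, the weighting would give exactly one softmax over their union; here they overlap (token 0, for example, is kept by every pattern), so shared keys are counted more than once.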
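For the multi-head variant, the pattern described in the video (head one starting at token 0, head two at token 1, and so on) can be reproduced by giving head j an offset of j mod r. That offset rule is an assumption chosen to match the description, and head_positions is a hypothetical helper.

```python
def head_positions(segment_start, w, r, head):
    """Positions one head keeps inside a box of length w.

    Assumption: the offset for head j is j mod r, which reproduces the pattern
    described in the video.
    """
    offset = head % r
    return list(range(segment_start + offset, segment_start + w, r))


for w, r in [(8, 2), (16, 4)]:
    print(f"w={w}, r={r}")
    for head in range(4):
        print("  head", head + 1, head_positions(0, w, r, head))
```

With r = 2 the four heads produce only two distinct patterns (heads one and three coincide, as do two and four); with r = 4, which equals the number of heads, all four patterns differ.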
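The complexity argument from the video can be checked numerically. This sketch counts dot-product work as the number of boxes (n divided by w) times (w divided by r) squared times d, summed over the geometric schedule, and compares it with n squared times d for vanilla attention. The function names are hypothetical, and constant factors (a multiply-add counted as two operations, the halving from the causal mask) are deliberately ignored because they do not change how the totals grow with n.

```python
def vanilla_flops(n, d):
    """Dot products between every pair of tokens: n * n of them, each of length d."""
    return n * n * d


def dilated_flops(n, d, w0=4, r0=1, alpha=2):
    """Sum over the geometric (w, r) schedule of (n / w) boxes * (w / r)^2 * d."""
    total = 0
    w, r = w0, r0
    while w <= n:
        total += (n // w) * (w // r) ** 2 * d
        w, r = w * alpha, r * alpha
    return total


d = 64
for n in [2**10, 2**14, 2**18, 2**20]:
    dil, van = dilated_flops(n, d), vanilla_flops(n, d)
    print(f"n={n:>9,}  dilated={dil:>12,}  vanilla={van:>18,}  ratio={van / dil:,.0f}")
```

For a power-of-two n with w starting at 4, r at 1 and alpha = 2, the dilated total works out to exactly (8n - 16)·d, so doubling n roughly doubles it, while the vanilla count quadruples.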
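The breadth-first search used in the video to measure the distance between tokens can be sketched like this. Two tokens are linked when some (w, r) pattern computes a dot product between them, and the graph is treated as undirected, which seems to be what the notebook does since it finds paths from node five to every other node; with a causal mask, information really flows only from earlier to later tokens. The helper names are hypothetical, and the sketch assumes a single head with offset 0.

```python
from collections import deque


def geometric_schedule(n, w0=4, r0=1, alpha=2):
    """(w, r) pairs, both multiplied by alpha each step, while w <= n."""
    schedule, w, r = [], w0, r0
    while w <= n:
        schedule.append((w, r))
        w, r = w * alpha, r * alpha
    return schedule


def build_graph(n, schedule):
    """Undirected graph: tokens are linked if some (w, r) pattern keeps both of them
    in the same dilated box, i.e. a dot product is computed between them."""
    neighbors = {i: set() for i in range(n)}
    for w, r in schedule:
        for start in range(0, n, w):
            kept = range(start, min(start + w, n), r)
            for a in kept:
                for b in kept:
                    if a != b:
                        neighbors[a].add(b)
    return neighbors


def hop_distances(neighbors, source=0):
    """Breadth-first search: fewest jumps from `source` to every reachable token."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        for nxt in sorted(neighbors[node]):
            if nxt not in dist:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    return dist


for n in (16, 32, 64, 128):
    dist = hop_distances(build_graph(n, geometric_schedule(n)))
    print(n, "all reachable:", len(dist) == n, "max jumps:", max(dist.values()))
```

With this schedule, the maximum distance from token 0 is 2 for 16 tokens and 3 for 32 tokens, matching the numbers shown in the video; for example, token 15 is reached through token 12, and token 17 (with 32 tokens) through token 16.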