
LongNet: Scaling Transformers to 1,000,000,000 tokens: Python Code + Explanation


Chapters

0:00 Introduction
5:30 How it works
14:00 Computational Complexity
19:17 Attention visualization


00:00:00.000 | Hello guys, welcome to my channel. Today we will be talking about a new model that came recently
00:00:05.440 | it's called LongNet and it's a model based on the transformer that came from Microsoft Research Asia
00:00:12.720 | just two weeks ago I guess and the basic idea is that they want to scale the transformer
00:00:20.800 | but the wonderful idea is that they managed to scale it to 1 billion tokens. Now if you're
00:00:27.120 | familiar with language models you know that the sequence length makes a huge impact on the
00:00:33.120 | performance of the model because the sequence length tells you how many tokens can relate to
00:00:40.080 | each other when performing the attention mechanism which allows the model for example to have a
00:00:45.920 | longer context or a shorter context for example for a model like GPT you want to be able to have
00:00:51.760 | a long context so that the model can look at words that were written a long time ago to calculate
00:01:00.080 | the next token for example and LongNet is able to scale this to 1 billion tokens so to show you
00:01:07.760 | how amazing this is I want to show you this graph in which we show that for example the sequence
00:01:12.960 | length of GPT was just 512 and then we have the Sparse Transformer with 12K, then 64K, 262K, 1 million, but
00:01:24.160 | they go 1,000 times further with LongNet, 1 billion tokens, and just to imagine the scale it's really
00:01:33.760 | amazing because you can basically feed all the text of wikipedia to the model
00:01:42.160 | and the model will be able to calculate the attention using all these tokens but let's see
00:01:48.160 | how this all works first of all LongNet as claimed by the authors has significant advantages the
00:01:54.640 | first is that the computational complexity is linear with the
00:02:00.320 | sequence length and we will see why and how there is a logarithmic dependency between tokens so it
00:02:07.600 | means that basically the more distant the tokens the weaker the dependency that is
00:02:16.240 | the attention mechanism is weaker between two tokens that are very far
00:02:22.160 | from each other and stronger between two tokens that are close to each other and it can be trained
00:02:29.120 | on a distributed network it means that we can calculate this attention mechanism on a distributed
00:02:35.040 | system so multiple gpus or multiple computers and it is a drop-in replacement for standard attention
00:02:42.240 | which means that if we already have a model that uses the attention mechanism we can basically just
00:02:47.440 | replace the attention mechanism without changing all the rest of the model and it will work as
00:02:53.280 | before but with this new improved attention that can use longer sequence lengths and if you're not
00:03:00.720 | familiar with the transformer models i kindly ask you to watch my previous video in which i
00:03:05.360 | explained the attention model the attention mechanism and the transformer model and i will
00:03:11.600 | review it basically here to show you what was the problem with the attention mechanism before
00:03:16.240 | so here i have the slides from my previous video and we can see here the self-attention mechanism
00:03:22.080 | with the self-attention we had before we had the matrices called q k and v and q was
00:03:27.440 | basically the sentence which is a matrix of sequence length by d model d model is the size
00:03:33.440 | of the vector representing the embedding of each word and when we do the multiplication
00:03:39.840 | of the query by the transpose of k producing this matrix requires a number of
00:03:46.880 | operations that is n to the power of 2 multiplied by d model
00:03:55.920 | why is this the case well to produce for example this item here
00:04:02.000 | in the output of the softmax we need to do the dot product of this word so the word your with
00:04:07.760 | the word your so the word with itself and the dot product of two vectors that are
00:04:16.400 | d model long takes d model multiplications and we need to do this for all the items in this matrix and the items
00:04:24.400 | in this matrix are n to the power of 2 so the sequence length to the power of 2 and this is the reason
00:04:30.320 | why the complexity of the self-attention before but also of the cross attention
00:04:35.600 | before was in the order of n to the power of 2 by d and this comparison is also
00:04:42.400 | present in the paper here so the vanilla attention had a complexity of n to the power of 2 multiplied
00:04:48.000 | by d but with this new model the LongNet we have an attention mechanism that is
00:04:55.360 | in the order of n multiplied by d so it grows linearly with the sequence length and we will see
00:05:01.920 | how so here the introduction the authors claim that actually the sequence length is one of the
00:05:08.000 | main problems with the language models and how scaling it is a priority and here they showed
00:05:15.600 | how this work is better than the other basically because we reduce the number of computations to
00:05:20.720 | calculate the attention and how they scale it to 1 billion we will see all of this in detail i will
00:05:26.640 | also show some visualizations of how this works the basic principle is attention allocation
00:05:33.200 | decreases exponentially as the distance between the tokens grows now let's have a look at the
00:05:39.040 | picture to see how this works before we had let's go to my previous slide before we had a matrix
00:05:47.520 | like this so we calculated the attention between all tokens with all other tokens but with
00:05:53.840 | LongNet we don't do this imagine we have a sequence length of 16 and of course in the upper part of
00:06:00.240 | this matrix we don't calculate the attention because we want the model the attention mechanism
00:06:04.240 | to be causal so we don't want the token number one to be related to the token number eight but
00:06:09.040 | of course we want the token number eight to be able to watch the token number one for example
00:06:13.680 | and so all this part is empty and the second thing is instead of calculating all the attention
00:06:25.040 | all the dot products of all the tokens with all other tokens what we do is we split the sequence
00:06:32.560 | into small windows of different sizes so first we start with the size of four in this case n is
00:06:39.680 | the number of tokens of this sequence we split it into four windows and here they are called segments
00:06:46.800 | each one of size four and we calculate the attention between all the words in this small
00:06:53.280 | box with all the other words in the same box and we do it for all these small segments here here we also see another
00:07:00.880 | parameter called dilation rate which is one here because we are not skipping any token so we are calculating all tokens
00:07:06.400 | with all other tokens we do it again this time however by increasing the size of the window so
00:07:12.320 | we don't use a window of size four we use a window of size eight and we calculate the attention
00:07:18.240 | between every second word in this window so we skip one token each time and we do it for all the
00:07:27.200 | windows until they cover all the sequence length again then we double the segment
00:07:35.200 | length so the size of the window that we watch but we also double the dilation rate so how many
00:07:41.360 | tokens we skip so we relate the token number zero with the token number four and then we skip three
00:07:49.040 | and then we do again the dot product and then we skip three and we do again the dot product so
00:07:53.600 | here we skip three tokens here we skip one and here we skip zero why do we do this because with
00:08:01.280 | smaller windows we want the attention mechanism to be more precise because for example when
00:08:06.720 | you read a book you know when you read a paragraph the words in the paragraph should be very related
00:08:12.080 | to each other because they're talking about something very specific this one can be thought of
00:08:18.000 | as a chapter so in the chapter we don't want to relate all the words of the chapter to
00:08:24.160 | each other but maybe some parts of the chapter because basically in the same chapter the
00:08:30.960 | paragraphs more or less will talk about the same topics but it's not like we need the dot
00:08:37.440 | product between all the words in the chapter with all the other words in the same chapter
00:08:42.160 | and then if we go to the book level we don't want the dot product between every word of the book
00:08:49.120 | with all the other words but we want some general idea so the
00:08:53.600 | general theme of the book should be present but not every word with other words so this is the
00:08:58.960 | idea that we use also for this attention mechanism here for small windows so words that are very
00:09:06.000 | close to each other we do all the dot products for words that are farther from each other we don't do all
00:09:11.760 | the dot products and for very big windows we do even fewer another thing is that the number of
00:09:19.760 | dot products in each window no matter the size of the window or the dilation rate is always the same
00:09:25.280 | so here for example we have four plus three plus two plus one dot products in this window and it's
00:09:31.760 | the same number of dot products that we have here and it's the same number of dot products that we
00:09:36.160 | have here now you may be wondering well this is not relating the token number one to
00:09:46.160 | the token number 16 for example right yes but if we overlap all of them together we
00:09:53.120 | obtain something like this and you can see here that still the token number one is not related
00:09:59.280 | to the token number 16 but we can always find a connection a way of going from token number one
00:10:08.240 | to token number 16 by using intermediate tokens and we will see later how this is possible i also
00:10:13.920 | made a tool to visualize this and let's watch some details from the paper so first we start by
00:10:21.920 | introducing the vanilla transformer so this is basically the attention mechanism as in the
00:10:29.280 | paper attention is all you need and it's the same one that we saw here then basically here they
00:10:35.760 | describe what is the dilated attention so in the dilated attention we choose a w and an r so a
00:10:42.800 | window size and a dilation rate and we divide our sequence into n divided by w boxes like this
00:10:52.240 | here n is 16 if the segment length is 4 we will have 4 boxes if the segment length is 8 we will
00:10:59.520 | have 2 boxes etc and we also skip tokens according to r actually r minus 1 tokens each time and all of these boxes
00:11:10.720 | as you can see are independent because for the attention mechanism to be calculated in this
00:11:16.080 | box for example you only need to have available the embedding of the tokens that are in this box
00:11:23.520 | because there is no interconnection between these two boxes so this one and this one can be calculated
00:11:30.240 | in parallel okay and the next thing is that they calculate the softmax for each box so they
00:11:37.600 | calculate the attention in each of these boxes here and then they combine them together basically
00:11:43.760 | with the concatenation and another interesting thing the one we saw before is that
00:11:50.400 | they don't use just one r or one w they use a sequence of r's and w's and we will see here that
00:11:58.640 | these sequences of r and w are geometric sequences with an alpha constant here so in this case the
00:12:05.280 | alpha constant is 2 basically what they do is they start with a small window so w1 for example equal
00:12:11.520 | to 4 then each time you multiply the previous window by 2 and also the dilation rate by 2
00:12:17.040 | so from 4 we go to 8 from 8 we go to 16 until we reach the sequence length the same happens with
00:12:22.880 | the dilation rate at the beginning we don't skip any word then we start skipping 1 then we multiply
00:12:28.800 | the rate by 2 again and we skip 3 and they combine all of this together using these two equations
00:12:36.720 | basically they calculate the denominator of the softmax for each of these attentions so all of this
00:12:44.000 | this and this and then they use it as a weight for a weighted average we can see it here
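The dilated attention and the weighted combination described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' implementation: the function names `dilated_attention` and `longnet_attention` are made up for this example, it handles a single head without the per-head offsets, it uses a plain (not numerically stabilized) exponential, and it assumes the list of (w, r) pairs includes a pattern with r = 1 so that every token is covered at least once.

```python
import numpy as np

def dilated_attention(Q, K, V, w, r):
    """Causal attention inside each segment of length w, keeping one token every r.

    Returns the per-token output and the per-token softmax denominator.
    Tokens skipped by the dilation get a zero output and a zero denominator.
    """
    n, d = Q.shape
    out = np.zeros_like(V, dtype=float)
    denom = np.zeros(n)
    for start in range(0, n, w):                      # one iteration per box
        idx = np.arange(start, min(start + w, n), r)  # tokens kept in this box
        scores = Q[idx] @ K[idx].T / np.sqrt(d)
        causal = idx[:, None] >= idx[None, :]         # a token sees itself and earlier tokens
        weights = np.where(causal, np.exp(scores), 0.0)
        s = weights.sum(axis=1)                       # softmax denominator per row
        out[idx] = (weights / s[:, None]) @ V[idx]
        denom[idx] = s
    return out, denom

def longnet_attention(Q, K, V, schedule):
    """Mix several (w, r) patterns, weighting each by its softmax denominator."""
    outs, denoms = zip(*(dilated_attention(Q, K, V, w, r) for w, r in schedule))
    total = np.sum(denoms, axis=0)  # assumes every token is covered by some pattern
    return sum((s / total)[:, None] * o for s, o in zip(denoms, outs))

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q, K, V = rng.normal(size=(3, 16, 8))
    print(longnet_attention(Q, K, V, [(4, 1), (8, 2), (16, 4)]).shape)
```

With a single pattern whose window is the whole sequence and r = 1, this reduces to ordinary causal softmax attention, which is a convenient sanity check.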
00:12:50.800 | and how to transform this into a multi-head attention well basically
00:12:56.560 | we do it for each combination of segment length and
00:13:04.800 | dilation rate suppose we have four heads the segment length is 8 and the dilation rate is 2
00:13:10.720 | as you know with the dilation rate of 2 we need to skip every second token so we can calculate it
00:13:19.280 | like this for the head number one we start from zero and we skip every other token so we take
00:13:24.880 | the zero and then we skip the one and then we arrive to the two and then the three we skip
00:13:31.120 | and we take the four etc for the next head we can skip the zero and start from the one so we shift by one and
00:13:38.880 | we keep the same dilation rate and for the head number three and the
00:13:44.960 | head number four we do the same in this case the head number one and the head number three are the
00:13:49.920 | same because actually the stride is smaller than the number of heads if we had a stride that
00:13:56.720 | was bigger than the number of heads or equal to the number of heads we would see four different
00:14:01.280 | patterns here but the basic idea of the multi-head is this one and let's look at the computational
00:14:07.920 | complexity of this model and how it is calculated well the computational complexity of this attention
00:14:14.000 | mechanism is basically given by the dot product that we do to calculate the attention so the
00:14:20.240 | softmax of the query multiplied by the key and with the vanilla transformer we had n to the power of
00:14:26.560 | 2 multiplied by d but here we have w divided by r to the power of 2 multiplied by d so w is our
00:14:34.800 | segment size so let's go here r is the dilation rate and what we can see here is that the number
00:14:46.240 | of dot products that we are doing is this one so the effective size of the window is w divided by r
00:14:54.800 | so w divided by r is 4 which is the effective size of the window right and w divided by r is also the number
00:15:03.440 | of tokens in this window for which we will calculate the dot product because you can see
00:15:09.600 | that this matrix here even if the size is 8 the number of actual dot products that we will do is
00:15:16.000 | actually not 8 by 8 but 4 by 4 because we are skipping every other token and even if the size
00:15:24.480 | of this window is 16 by 16 we will not be calculating 16 by 16 dot products we will be
00:15:30.880 | calculating 4 by 4 dot products because we are skipping three tokens and this is the idea behind
00:15:37.440 | the calculation of the complexity the fact that we are not calculating the dot product between
00:15:43.280 | all the tokens in a window but only w divided by r to the power of 2 which is the number of
00:15:54.000 | dot products that we will do and each dot product involves a vector
00:16:01.600 | of dimension d so we also multiply by d and n divided by w is the number of boxes so for example
00:16:08.640 | here when w is 4 the number of boxes is also 4 because n divided by w is 4 and when
00:16:16.960 | w is 8 the number of boxes is 2 because 16 divided by 8 is 2 and when the window is the whole sequence length
00:16:23.520 | so w is equal to 16 the number of boxes that we get is 16 divided by 16 so only one so we can see that
00:16:32.880 | the number of floating point operations that we are doing is proportional to n divided by w so
00:16:39.200 | the number of boxes and in each box we will do w divided by r to the power of 2 multiplied by d
00:16:44.720 | operations because of the dot product and you may be wondering that the window size is
00:16:52.720 | still very big right so if you do it in numpy or in pytorch actually the number of operations
00:16:59.520 | you will do for example for this window of size 16 by 16 is still 16 by 16 but there are better
00:17:05.520 | ways to represent what are called sparse matrices so this actually is a matrix that is sparse so if
00:17:11.600 | you create a matrix multiplication algorithm that knows this and that can take into consideration
00:17:18.800 | that this matrix is sparse then you can do far fewer operations first you can store less information
00:17:26.800 | because you know that most of the matrix is zero and the second thing is that you can perform
00:17:32.400 | fewer operations so if you can just skip calculating the dot product for all the positions of this
00:17:38.080 | matrix that you know are zero then you do fewer operations and i think the authors
00:17:44.800 | of the paper created some custom kernel for cuda to do this another thing the authors show
00:17:51.200 | okay here this is the number of floating point operations for one window size so for one w but
00:17:58.560 | we don't have one w we have many w's and we also know that these w's are according to a geometric
00:18:05.280 | sequence as written in the paper written here we set w and r to geometric sequences geometric
00:18:12.080 | sequence means that we take the previous w and to get the next w we multiply it by a constant alpha and
00:18:17.600 | this alpha is fixed starting from w zero here for example w zero is equal to four and the
00:18:25.360 | dilation rate r zero is equal to one and every time they multiply by two and let's go back here
00:18:33.600 | okay so they need to combine the floating point operations for all of these w's and r's and they do
00:18:40.400 | it here but considering that these w's and r's are actually the result of a geometric sequence
00:18:46.240 | this becomes dependent on alpha and w zero so the initial w that you choose and if we watch this
00:18:53.440 | expression here we can see that the number of floating point operations that you need to do to
00:18:59.520 | calculate this combined attention here so the combination of all these w's and r's here is
00:19:07.120 | proportional to n and d it grows linearly with n and d so just like it's written here
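To make the counting argument concrete, here is a small sketch that adds up the operations for the geometric sequence of windows and compares it with the vanilla quadratic cost. It assumes the cost of one pattern is 2 · (n / w) · (w / r)² · d, that is, the number of boxes times the dot products per box times the vector size, with a factor 2 taken as an assumption for the two matrix products. The function names are hypothetical, the causal mask is ignored, and n is assumed to be divisible by every window size.

```python
def dilated_flops(n, d, w, r):
    """Approximate cost of one (w, r) pattern: boxes * (tokens kept per box)^2 * d."""
    return 2 * (n // w) * (w // r) ** 2 * d

def vanilla_flops(n, d):
    """Approximate cost of full attention: every token against every token."""
    return 2 * n * n * d

def longnet_flops(n, d, w0=4, alpha=2):
    """Sum the cost over the geometric sequence w0, alpha*w0, ... with r = 1, alpha, ..."""
    total, w, r = 0, w0, 1
    while w <= n:
        total += dilated_flops(n, d, w, r)
        w *= alpha
        r *= alpha
    return total

if __name__ == "__main__":
    d = 64
    for n in (16, 1024, 65536):
        # The ratio longnet_flops / (n * d) stays bounded, while vanilla grows with n.
        print(n, longnet_flops(n, d) / (n * d), vanilla_flops(n, d) / (n * d))
```

Because each term in the sum is the first one divided by a power of alpha, the total stays below 2 · alpha / (alpha − 1) · w0 · n · d, which is why the combined cost is linear in n and d.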
00:19:15.760 | another interesting fact is that even if two words are not connected to each other directly
00:19:22.240 | by a dot product we can calculate the information distance between them that is how many jumps you
00:19:29.920 | need to make to go from one token to another let me explain this better for example let's
00:19:36.640 | watch my notebook that i made here this is a notebook that i made specifically for learning
00:19:42.160 | this model and i wanted to actually test how it works so we imagine we have a sequence length of
00:19:48.480 | 16 and so here in my representation it is from 0 to 15 not from 1 to 16 but the idea is the same
00:19:55.840 | and we know that we will be calculating for example the first attention will be calculated
00:20:01.920 | for this box this box this box this box then another one that will be this one this one and
00:20:09.040 | the last one that is this one and this is exactly the same as the combined attention that we see
00:20:16.960 | here so the overlapping of this attention this attention this attention is exactly the same
00:20:22.480 | just the colors are different now let's look at how for example the token number
00:20:29.440 | 0 and the token number 15 so the last token are related the idea is that we cannot go from token
00:20:35.920 | number 0 to token number 15 directly because there is no dot product between 0 and 15 but we can find
00:20:42.320 | a path to go there so from 0 to 15 we can go from 0 to 12 and from 12 to 15 let's see so from 0 to
00:20:50.640 | 12 there is a dot product right then there is a dot product between 12 and itself because in the
00:20:56.400 | attention mechanism we are always making the dot product between every token and itself
00:21:01.440 | and the token number 12 is related to the token number 15 so there is a dot product between the
00:21:08.640 | token number 12 and 15 so actually the token number 0 is related to the token number 15
00:21:14.000 | through the token number 12 and we can find this path for all the tokens and i
00:21:20.960 | show it in this notebook that for example all the nodes are reachable from the node number zero so
00:21:27.680 | from the token number zero we can reach all the tokens by using different tokens as intermediate
00:21:34.160 | and in this paper let's go to the paper here they show that the maximum number of jumps that
00:21:42.080 | you need to make to go from one token to any other token grows with the logarithm
00:21:51.520 | of the sequence length that is if the sequence length is let's say 10 times bigger you don't
00:21:59.520 | need to make 10 times more jumps to go from one token to another
00:22:05.680 | so why are we talking about jumps because it also tells us how strong is the
00:22:11.840 | relationship between two tokens because if we calculate the dot product between two tokens
00:22:16.880 | then that means that the model will find that dot product immediately so the model
00:22:22.720 | will learn to relate immediately those two tokens but if we have intermediate tokens the model
00:22:28.960 | will take more iterations to find this connection between tokens so the
00:22:33.520 | connection between those two tokens will be weaker and this is what the authors claimed
00:22:38.320 | they claim that the attention mechanism is spread in such a way that the strength of the attention
00:22:45.360 | mechanism becomes exponentially weaker as the distance between tokens increases or in other
00:22:54.560 | words we can say that the number of jumps that you need to make grows with the logarithm of n
00:22:59.440 | and we can do the same for example with other sequence lengths for example here i use a
00:23:06.160 | sequence length of 16 but we can use 32 for example and visualize it
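The overlapped attention pattern discussed here can be reproduced with a short sketch that builds one boolean mask per (w, r) pattern and ORs them together. This is not the code of the notebook shown in the video; the names `pattern_mask` and `overlapped_mask` are invented for this example, and the defaults w0 = 4 and alpha = 2 are taken from the values used in the explanation.

```python
import numpy as np

def pattern_mask(n, w, r):
    """Causal mask for one (segment length w, dilation rate r) pattern."""
    mask = np.zeros((n, n), dtype=bool)
    for start in range(0, n, w):
        idx = np.arange(start, min(start + w, n), r)
        # Inside a segment, each kept token attends to the kept tokens at or before it.
        rows, cols = np.meshgrid(idx, idx, indexing="ij")
        mask[rows, cols] = rows >= cols
    return mask

def overlapped_mask(n, w0=4, alpha=2):
    """OR together the masks for w0, alpha*w0, ... up to n, with r = 1, alpha, ..."""
    mask = np.zeros((n, n), dtype=bool)
    w, r = w0, 1
    while w <= n:
        mask |= pattern_mask(n, w, r)
        w *= alpha
        r *= alpha
    return mask

if __name__ == "__main__":
    # Print the combined pattern for a sequence of 16 tokens (0 to 15).
    for row in overlapped_mask(16):
        print("".join("#" if x else "." for x in row))
```

Each box contributes 4 + 3 + 2 + 1 = 10 entries whatever its size, so a single pattern over 16 tokens has 10 times the number of boxes, which matches the remark that the number of dot products per window is constant.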
00:23:12.080 | and let's see if it's visualizable yes so basically our LongNet will do this
00:23:20.480 | it will start with small boxes of size four then it will also calculate the attention for
00:23:27.440 | the box size eight then also for the box size 16 and also for the box size 32 here we can see the
00:23:36.720 | overlap attention maxi so all the different sizes but also all the single groups so for example the
00:23:43.520 | the all the tokens that are directly connected to each other with different color so the token
00:23:49.600 | number zero is directly connected to the known to the token number three and and also the token
00:23:56.320 | number four is directly connected to the token number five because they are part of the same box
00:24:00.320 | when they are calculated but the other tokens they have to be inferred for example with a sequence
00:24:05.600 | length of 32 we can see that still the token number zero is reachable from every other token
00:24:12.880 | but by different number of steps for example to go from token number zero to token number 17
00:24:19.200 | we need to pass from 16 let's see from 0 to 17 we cannot go directly because there is no dot product
00:24:26.880 | here but we can go to 16 there is a dot product here and 16 is related to itself also and 16 is
00:24:33.760 | also related to 17 so we actually can go from 0 to 17 by passing from 16 and this can be done for
00:24:41.440 | all the nodes and i also made a graph here to visualize this so from 0 we cannot go directly
00:24:46.400 | to 17 but we can go to 16 and from 16 we can go to 17 and this is the idea of the LongNet we
00:24:56.880 | let's go back okay we don't calculate all the dot products so all the
00:25:04.240 | tokens with all the other tokens but we spread this attention mechanism in such a way that words
00:25:11.440 | that are very close to each other are directly connected and words that are far from each other
00:25:16.800 | are connected through other tokens and in the paper they also show how the
00:25:28.240 | model can be trained in a distributed way well we already saw it because all of these boxes
00:25:36.080 | are actually independent from each other so to calculate for example the attention in this box
00:25:41.440 | here you need only the embedding of the token number 0 1 2 and 3 and that's it to calculate
00:25:48.400 | the attention of this box here you need to have the embedding of the token number
00:25:54.720 | 0 2 and 4 and 6 but not of the others and to calculate this one the same etc and another
00:26:05.200 | interesting thing is that the number of dot products in each box is always constant so if
00:26:10.800 | we want we can choose the model in such a way that each computer can hold at most that number of
00:26:20.240 | dot products and so this mechanism is quite parallelizable and it's really important
00:26:29.680 | because it allows us to scale greatly because parallelization
00:26:35.920 | is very important for us because we can compute the model on the cloud or on different gpus and
00:26:40.880 | we can take advantage of this parallelization another interesting thing is the runtime
00:26:48.160 | we can see here that with the sequence length increasing the runtime grows linearly
00:26:54.080 | and not quadratically like with the vanilla transformer you can see here and then
00:26:59.280 | in the rest of the paper they show how the model performs compared to other previous models
00:27:05.760 | now my point is also not to show actually the results which you can look at by yourself
00:27:11.840 | my goal was to actually show the attention mechanism of this new LongNet and i hope it was
00:27:18.640 | clear i hope also you will use my python notebook to experiment by yourself i show you how basically
00:27:26.240 | it works here you define the sequence length that you want to visualize and the notebook will only
00:27:31.360 | visualize short sequence length i think i set this to 32 so if it's bigger than 32 it will not be
00:27:37.040 | visualized because it's not easy to visualize it and basically to calculate the distance
00:27:45.680 | between one token and the other token i just basically do a bfs breadth-first search it's a
00:27:51.760 | very ugly one unoptimized one doesn't matter because i built this notebook in half an hour
00:27:56.800 | just for showing how the model works so you are invited to make it better if you want
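The breadth-first search mentioned above can be sketched as follows. This is not the notebook's code: `build_graph` and `bfs_distances` are names chosen for this example, and the sketch assumes that two tokens are connected whenever they are kept together in the same box of some (w, r) pattern, with connections treated as undirected and with w0 = 4 and alpha = 2.

```python
from collections import deque

def build_graph(n, w0=4, alpha=2):
    """Connect tokens that share a box in any (w, r) pattern of the geometric sequence."""
    neighbors = {i: set() for i in range(n)}
    w, r = w0, 1
    while w <= n:
        for start in range(0, n, w):
            group = range(start, min(start + w, n), r)  # tokens kept in this box
            for a in group:
                neighbors[a].update(group)
        w *= alpha
        r *= alpha
    return neighbors

def bfs_distances(neighbors, source):
    """Number of jumps from source to every reachable token (breadth-first search)."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in neighbors[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist

if __name__ == "__main__":
    for n in (16, 32):
        dist = bfs_distances(build_graph(n), source=0)
        print(n, "reachable:", len(dist) == n, "max jumps:", max(dist.values()))
```

Changing `source` lets you check reachability from any other token, for example token 5.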
00:28:03.520 | last thing that i didn't show and that is very interesting is that we can see that the maximum
00:28:09.920 | node distance from the node zero to any other node is three and it's changing with the logarithm
00:28:16.640 | of n as you can see so if we are for example if our n is equal to let's say 16 let's do it again
00:28:24.000 | if the sequence length is 16 we can see here that this is the path to go from the node number
00:28:32.080 | zero to any other node and the maximum distance to go from node number zero to any
00:28:36.960 | other node is two and it grows with the logarithm of n which here is 16 you can also change which node
00:28:43.680 | you want to go from so if you want to calculate other paths for example here we say that i want
00:28:48.480 | to go from node number five to every other node and here we prove that all the nodes
00:28:55.120 | are reachable from the node number five and here we display the paths okay and this is the maximum
00:29:01.360 | distance from node number five to any other node and this is the graph i hope you like my video
00:29:08.160 | guys and i hope it was more or less clear how this mechanism works i didn't explain all
00:29:15.120 | the equations here i have a lot of comments written on the sides because i like to
00:29:21.840 | take notes when i read the paper mostly because i want to understand also the maths
00:29:26.960 | behind it so if you're interested in some parts just write in the comments and i will try to
00:29:30.960 | explain it better but i think most people just want to understand the mechanism and they are
00:29:34.960 | waiting for the official code to be released to actually watch how it works and i hope i didn't
00:29:40.080 | make any mistakes because basically there is no information online about the long net so everything
00:29:45.040 | i told you is all because of my research and i hope you enjoyed the video so please come back
00:29:52.320 | to my channel for more videos about deep learning and machine learning and have a great day