LongNet: Scaling Transformers to 1,000,000,000 tokens: Python Code + Explanation
Chapters
0:00 Introduction
5:30 How it works
14:00 Computational Complexity
19:17 Attention visualization
00:00:00.000 |
Hello guys, welcome to my channel. Today we will be talking about a new model that came recently 00:00:05.440 |
it's called LongNet and it's a model based on the transformer that came from Microsoft Research Asia 00:00:12.720 |
just two weeks ago I guess and the basic idea is that they want to scale the transformer 00:00:20.800 |
but the wonderful idea is that they managed to scale it to 1 billion tokens. Now if you're 00:00:27.120 |
familiar with language models you know that the sequence length makes a huge impact on the 00:00:33.120 |
performance of the model because the sequence length tells you how many tokens can relate to 00:00:40.080 |
each other when performing the attention mechanism which allows the model for example to have a 00:00:45.920 |
longer context or a shorter context for example for a model like GPT you want to be able to have 00:00:51.760 |
a long context so that the model can look at words that were written a long time ago to calculate 00:01:00.080 |
the next token for example and LongNet is able to scale this to 1 billion tokens so to show you 00:01:07.760 |
how amazing this is I want to show you this graph in which we show that for example the sequence 00:01:12.960 |
length of GPT was just 512, and then we have the Sparse Transformer and later models at 12K, 64K, 262K, up to around 1 million, but 00:01:24.160 |
LongNet goes 1,000 times further, to 1 billion tokens. Just to imagine the scale, it's really 00:01:33.760 |
amazing because you can basically feed the whole text of Wikipedia to the model 00:01:42.160 |
and the model will be able to calculate the attention using all these tokens but let's see 00:01:48.160 |
how this all works first of all LongNet as claimed by the authors has significant advantages the 00:01:54.640 |
first is that it is linear the computation the computational complexity is linear with the 00:02:00.320 |
sequence length and we will see why and how there is a logarithmic dependency between tokens so it 00:02:07.600 |
means that basically the farther apart two tokens are, the weaker their dependency is, so 00:02:16.240 |
the attention mechanism is weaker between two tokens that are very far 00:02:22.160 |
from each other and stronger between two tokens that are close to each other. And it can be trained 00:02:29.120 |
on a distributed network it means that we can calculate this attention mechanism on a distributed 00:02:35.040 |
system so multiple gpus or multiple computers and it is a drop-in replacement for standard attention 00:02:42.240 |
which means that if we already have a model that uses the attention mechanism we can basically just 00:02:47.440 |
replace the attention mechanism without changing all the rest of the model and it will work as 00:02:53.280 |
before but with this new improved attention that can use longer sequence lengths and if you're not 00:03:00.720 |
familiar with the transformer models i kindly ask you to watch my previous video in which i 00:03:05.360 |
explained the attention mechanism and the transformer model, and I will 00:03:11.600 |
review it basically here to show you what was the problem with the attention mechanism before 00:03:16.240 |
so here i have the slides from my previous video and we can see here the self-attention mechanism 00:03:22.080 |
with the self-attention we had before, we had the matrices called Q, K and V, and Q was 00:03:27.440 |
basically the sentence, which is a matrix of sequence length by d model, where d model is the size 00:03:33.440 |
of the vector representing the embedding of each word, and the multiplication 00:03:39.840 |
of the query by the transpose of the K to produce this matrix requires a number of 00:03:46.880 |
operations that is n to the power of 2 multiplied by d model. 00:03:55.920 |
Why is this the case? Well, to produce for example this item here 00:04:02.000 |
in the output of the softmax, we need to do the dot product of this word, so the word your, with 00:04:07.760 |
the word your, so the word with itself, and the dot product of two vectors that are 00:04:16.400 |
d model long costs d model operations, and we need to do this for all the items in this matrix, and the items 00:04:24.400 |
in this matrix are n to the power of 2 so the sequence to the power of 2 and this is the reason 00:04:30.320 |
why the complexity of the self-attention before, but also of the cross-attention, 00:04:35.600 |
was in the order of n to the power of 2 multiplied by d. 00:04:42.400 |
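As a quick illustration (my own sketch, not code from the paper or the video), here is a minimal NumPy version of causal self-attention, just to make that n squared times d model cost concrete; the shapes and the name d_model are illustrative.

```python
# Minimal sketch of vanilla causal self-attention in NumPy (my own illustration,
# not the paper's code), to make the O(n^2 * d_model) cost of Q @ K^T concrete.
import numpy as np

def vanilla_causal_attention(q, k, v):
    # q, k, v: (n, d_model)
    n, d_model = q.shape
    scores = q @ k.T / np.sqrt(d_model)             # (n, n): n^2 dot products, each of length d_model
    causal = np.tril(np.ones((n, n), dtype=bool))   # token i may only attend to tokens j <= i
    scores = np.where(causal, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v                              # (n, d_model)

x = np.random.randn(16, 64)                         # n = 16 tokens, d_model = 64
print(vanilla_causal_attention(x, x, x).shape)      # (16, 64)
```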
This comparison is also present in the paper here: the vanilla attention had a complexity of n to the power of 2 multiplied 00:04:48.000 |
by d, but with this new model, LongNet, we have an attention mechanism that is 00:04:55.360 |
in the order of n multiplied by d, so it grows linearly with the sequence length, and we will see 00:05:01.920 |
how so here the introduction the authors claim that actually the sequence length is one of the 00:05:08.000 |
main problems with the language models and how scaling it is a priority and here they showed 00:05:15.600 |
how this work is better than the other basically because we reduce the number of computations to 00:05:20.720 |
calculate the attention and how they scale it to 1 billion we will see all of this in detail i will 00:05:26.640 |
also show some visualizations of how this works. The basic principle is: attention allocation 00:05:33.200 |
decreases exponentially as the distance between the tokens grows. Now let's have a look at the 00:05:39.040 |
picture to see how this works. Let's go to my previous slide: before, we had a matrix 00:05:47.520 |
like this, so we calculated the attention between all tokens with all other tokens, but with 00:05:53.840 |
LongNet we don't do this. Imagine we have a sequence length of 16, and of course in the upper part of 00:06:00.240 |
this matrix we don't calculate the attention because we want the model the attention mechanism 00:06:04.240 |
to be causal so we don't want the token number one to be related to the token number eight but 00:06:09.040 |
of course we want the token number eight to be able to watch the token number one for example 00:06:13.680 |
and so the whole upper part is empty. And the second thing is, instead of calculating all the attention, 00:06:25.040 |
all the dot products of all the tokens with all other tokens what we do is we split the sequence 00:06:32.560 |
into small windows of different sizes so first we start with the size of four in this case the n is 00:06:39.680 |
the number of tokens of this sentence we split into four segments and here are called segments 00:06:46.800 |
each one of size four and we calculate the attention between all the words in this small 00:06:53.280 |
box with all the other words, and we do it for all these small segments. Here we also see another 00:07:00.880 |
parameter called the dilation rate, which here is 1 because we are not skipping any token, so we are calculating all tokens 00:07:06.400 |
with all other tokens. We do it again, this time however increasing the size of the window: 00:07:12.320 |
we don't use a window of size four we do use a window of size eight and we calculate the attention 00:07:18.240 |
between each word and every other word in this window, and we do it for all the 00:07:27.200 |
windows until they cover the whole sequence length. Then we double the segment 00:07:35.200 |
length, so the size of the window that we watch, but we also double the dilation rate, so how many 00:07:41.360 |
tokens we skip: we relate the token number zero with the token number four, then we skip three 00:07:49.040 |
and do the dot product again, then we skip three and do the dot product again. So 00:07:53.600 |
here we skip every three tokens, here we skip every other token, and here we skip zero. 00:08:01.280 |
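To make this segment-and-dilation pattern concrete, here is a small sketch of mine (not the official LongNet implementation): it builds, for a given segment length w and dilation rate r, the boolean mask of which dot products are computed, and the union over the geometric sequence of (w, r) pairs, under the same assumptions as the figure (causal attention, w and r doubling each time, starting from w0 = 4 and r0 = 1).

```python
# Sketch (my own illustration, not the official LongNet code): build the causal
# dilated-attention pattern for one (w, r) pair and the union over the whole
# geometric sequence of (w, r) pairs.
import numpy as np

def dilated_mask(n, w, r):
    """Boolean (n, n) mask: True where query i attends to key j for one (w, r) pair."""
    mask = np.zeros((n, n), dtype=bool)
    for start in range(0, n, w):                      # split the sequence into n / w segments
        idx = np.arange(start, min(start + w, n), r)  # keep every r-th token inside the segment
        mask[np.ix_(idx, idx)] = True                 # full attention among the kept tokens
    return np.tril(mask)                              # causal: keep only keys j <= i

def longnet_mask(n, w0=4, r0=1, alpha=2):
    """Union of the dilated patterns over the geometric sequence of (w, r)."""
    mask, w, r = np.zeros((n, n), dtype=bool), w0, r0
    while w <= n:
        mask |= dilated_mask(n, w, r)
        w, r = w * alpha, r * alpha
    return mask

print(longnet_mask(16).astype(int))  # 1 = dot product computed, 0 = skipped
# Each (w, r) pair performs the same number of dot products per segment,
# because a segment always contains exactly w / r kept tokens.
```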
Why do we do this? Because for smaller windows we want the attention mechanism to be more precise, because for example when 00:08:06.720 |
you read a book you know when you read a paragraph the words in the paragraph should be very related 00:08:12.080 |
to each other because they're talking about something very specific. This one can be thought of 00:08:18.000 |
as a chapter, and in the chapter we don't want to relate all the words of the chapter to 00:08:24.160 |
each other but maybe some parts of the chapters because basically in the same chapter the 00:08:30.960 |
paragraphs more or less will talk about the same topics, but it's not like we need the dot 00:08:37.440 |
product between all the words in the chapter and all the other words in the same chapter. 00:08:42.160 |
And then if we go to the book level, we don't want the dot product between every word of the book 00:08:49.120 |
and all the other words, but we want some general idea: the 00:08:53.600 |
general theme of the book should be present, but not every word related to every other word. So this is the 00:08:58.960 |
idea that we use also for this attention mechanism: for small windows, so words that are very 00:09:06.000 |
close to each other, we do all the dot products; for words that are farther from each other we don't do all 00:09:11.760 |
the dot products, and for very big windows we do even less. Another thing is that the number of 00:09:19.760 |
dot products in each window, no matter the size of the window or the dilation rate, is always the same: 00:09:25.280 |
so here for example we have four plus three plus two plus one dot products in this window and it's 00:09:31.760 |
the same number of dot products that we have here and it's the same number of dot products that we 00:09:36.160 |
have here. Now you may be wondering: well, this is not relating the token number one to 00:09:46.160 |
the token number 16 for example right yeah but what if we overlapped all of them together we 00:09:53.120 |
obtain something like this and you can see here that still the token number one is not related 00:09:59.280 |
to the token number 16 but we can always find a connection a way of going from token number one 00:10:08.240 |
to token number 16 by using intermediate tokens and we will see later how this is possible i also 00:10:13.920 |
made a tool to visualize this and let's watch some details from the paper so first we start by 00:10:21.920 |
introducing the vanilla transformer so this is the basically the the attention mechanism as in the 00:10:29.280 |
paper attention is all you need and it's the same one that we saw here then basically here they 00:10:35.760 |
describe what the dilated attention is. In the dilated attention we choose a w and an r, so a 00:10:42.800 |
window size and a dilation rate, and we divide our sequence into n divided by w boxes, like this: 00:10:52.240 |
here n is 16 if the segment length is 4 we will have 4 boxes if the segment length is 8 we will 00:10:59.520 |
have 2 boxes, etc. We also keep every r-th token, so we actually skip r minus 1 tokens. And all of these boxes, 00:11:10.720 |
as you can see they are independent because here the attention mechanism to be calculated in this 00:11:16.080 |
box, for example, you only need to have available the embeddings of the tokens that are in this box, 00:11:23.520 |
because there is no interconnection between these two boxes so this one and this one can be calculated 00:11:30.240 |
in parallel. Okay, and the next thing is that they calculate the softmax for each box, so they 00:11:37.600 |
calculate the attention in each of these boxes here and then they combine them together basically 00:11:43.760 |
with a concatenation. And another important, interesting thing, the one we saw before, is that 00:11:50.400 |
they don't use just one r or one w, they use a sequence of r's and w's, and we will see here that 00:11:58.640 |
these sequences of r and w are geometric sequences with a constant alpha; in this case the 00:12:05.280 |
alpha constant is 2 basically what they do is they start with a small window so w1 for example equal 00:12:11.520 |
to 4 then each time you multiply the previous window by 2 and also the dilation rate by 2 00:12:17.040 |
so from 4 we go to 8 from 8 we go to 16 until we reach the sequence length the same happens with 00:12:22.880 |
the dilation rate at the beginning we don't skip any word then we start skipping 1 then we multiply 00:12:28.800 |
it by 2 and we skip every 3 and they combine all of this together using these two equations 00:12:36.720 |
basically they calculate the denominator of the softmax for each of these attentions, so all of these: 00:12:44.000 |
this one and this one, and then they use it as a weight for a weighted average, as we can see here. 00:12:50.800 |
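This is how I read that combination (a sketch under my own reading of the equations, not the official implementation): each (w, r) pattern produces its own output and its own softmax denominator per query position, and the outputs are averaged using the normalized denominators as weights.

```python
# Sketch (my reading of the weighted combination, not the official code): compute
# each masked attention together with its softmax denominator s_i, then average
# the outputs O_i using the normalized denominators as per-query weights.
import numpy as np

def masked_attention(q, k, v, mask):
    d = q.shape[-1]
    scores = np.where(mask, q @ k.T / np.sqrt(d), -np.inf)
    m = scores.max(axis=-1, keepdims=True)
    exp = np.where(mask, np.exp(scores - m), 0.0)
    denom = exp.sum(axis=-1, keepdims=True)          # s_i: softmax denominator per query
    out = (exp / np.maximum(denom, 1e-9)) @ v        # O_i: attention output for this pattern
    return out, denom

def combined_dilated_attention(q, k, v, masks):
    outs, denoms = zip(*(masked_attention(q, k, v, m) for m in masks))
    denoms = np.stack(denoms)                        # (num_patterns, n, 1)
    weights = denoms / denoms.sum(axis=0, keepdims=True)
    return (weights * np.stack(outs)).sum(axis=0)    # weighted average of the O_i

# Example, reusing dilated_mask() from the earlier sketch:
# masks = [dilated_mask(16, 4, 1), dilated_mask(16, 8, 2), dilated_mask(16, 16, 4)]
```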
And how do we transform this into a multi-head attention? Well, basically, 00:12:56.560 |
for each combination of segment length and 00:13:04.800 |
dilation rate, suppose we have four heads, the segment length is 8 and the dilation rate is 2: 00:13:10.720 |
as you know with the dilation rate of 2 we need to skip every second token so we can calculate it 00:13:19.280 |
like this: for head number one we start from zero and we skip every other token, so we take 00:13:24.880 |
token zero, skip token one, take token two, skip token three, 00:13:31.120 |
take token four, etc. Otherwise, for head number two, we skip token zero and start from token one, then skip one and 00:13:38.880 |
take the next, and we keep the same dilation rate. For head number three and 00:13:44.960 |
head number four we do the same; in this case head number one and head number three are the 00:13:49.920 |
same because the dilation rate (the stride) is smaller than the number of heads; if we had a dilation rate that 00:13:56.720 |
was bigger than or equal to the number of heads we would see four different 00:14:01.280 |
patterns here. But the basic idea of the multi-head version is this one. 00:14:07.920 |
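Here is a tiny sketch of that per-head shift (my own illustration of the idea; the offset rule j mod r is my assumption of how the shift could be chosen):

```python
# Sketch (my illustration of the per-head shift described above): with dilation
# rate r, head j starts its sparsification at offset j mod r, so together the
# heads cover different positions inside the same segment.
import numpy as np

w, r, num_heads = 8, 2, 4
for head in range(num_heads):
    offset = head % r                        # head j starts at offset j mod r
    kept = np.arange(offset, w, r)           # positions kept inside each segment of length w
    print(f"head {head}: attends among segment positions {kept.tolist()}")
# Heads 0 and 2 (and 1 and 3) coincide here because r < num_heads,
# exactly as pointed out above.
```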
Now let's look at the computational complexity of this model and how it is calculated. Well, the computational complexity of this attention 00:14:14.000 |
mechanism is basically given by the dot product that we do to calculate the attention so the soft 00:14:20.240 |
max of the query multiplied by the key and with the vanilla transformer we had n to the power of 00:14:26.560 |
2 multiplied by d but here we have w divided by r to the power of 2 multiplied by d so w is our 00:14:34.800 |
segment size so let's go here r is the dilation rate and what we can see here is that the number 00:14:46.240 |
of dot products that we are doing is this one: the number of tokens we actually keep in each window is w divided by r, 00:14:54.800 |
so here w divided by r is 4, and w divided by r is therefore the number 00:15:03.440 |
of tokens in this window for which we will calculate the dot product because you can see 00:15:09.600 |
that this matrix here even if the size is 8 the number of actual dot product that we will do is 00:15:16.000 |
actually not 8 by 8 but 4 by 4 because we are skipping every other token and even if the size 00:15:24.480 |
of this window is 16 by 16 we will not be calculating 16 by 16 dot products we will be 00:15:30.880 |
calculating 4 by 4 dot products because we are skipping three tokens and this is the idea behind 00:15:37.440 |
the calculation of the complexity the fact that we are not calculating the dot product between 00:15:43.280 |
all the tokens in a window, but only among w divided by r of them, so w divided by r to the power of 2 is the number 00:15:54.000 |
of dot products that we will do, and each dot product involves a vector 00:16:01.600 |
of dimension d, so we also multiply by d, and n divided by w is the number of boxes. So for example 00:16:08.640 |
here, when w is 4 the number of boxes is also 4, because n divided by w is 16 divided by 4; when 00:16:16.960 |
w is 8 the number of boxes is 2, because 16 divided by 8 is 2; and when the segment length is 00:16:23.520 |
16, so w is equal to 16, the number of boxes we get is 16 divided by 16, so only one. So we can see that 00:16:32.880 |
the number of floating point operations that we are doing is proportional to n divided by w, so 00:16:39.200 |
the number of boxes, and in each box we will do w divided by r to the power of 2 multiplied by d 00:16:44.720 |
operations because of the dot product. 00:16:52.720 |
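As a quick sanity check (my own back-of-the-envelope script, following the counting we just did), we can count the dot products for one (w, r) pair and sum them over the geometric sequence, and see that the total grows roughly linearly with n instead of quadratically:

```python
# Rough FLOP counter (my own sketch, following the counting above): for one (w, r)
# pair we have n / w boxes, each doing (w / r)^2 dot products of length d.
def flops_one_pair(n, w, r, d):
    num_boxes = n // w
    kept_per_box = w // r
    return num_boxes * kept_per_box ** 2 * d

def flops_longnet(n, d, w0=4, alpha=2):
    total, w, r = 0, w0, 1
    while w <= n:                                 # geometric sequence: (4, 1), (8, 2), (16, 4), ...
        total += flops_one_pair(n, w, r, d)
        w, r = w * alpha, r * alpha
    return total

d = 64
for n in (1024, 2048, 4096):
    print(n, flops_longnet(n, d), n * n * d)      # grows ~linearly in n vs ~quadratically for vanilla
```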
Now you may be wondering: the window size is still very big, right? If you do it naively in NumPy or in PyTorch, the number of operations 00:16:59.520 |
you will do for example for this window of size 16 by 16 is still 16 by 16 but there are better 00:17:05.520 |
ways to represent what are called sparse matrices so this actually is a matrix that is sparse so if 00:17:11.600 |
you create a matrix multiplication algorithm that knows this and that can take into consideration 00:17:18.800 |
that this matrix is sparse, then you can do many fewer operations: first, you can store less information, 00:17:26.800 |
because you know that most of the matrix is zero, and the second thing is that you can perform 00:17:32.400 |
fewer operations, because you can just skip the dot products for all the positions of this 00:17:38.080 |
matrix that you know are zero. And I think the authors 00:17:44.800 |
of the paper created some custom CUDA kernels to do this. 00:17:51.200 |
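Just to illustrate the general point about sparse formats (this is a generic SciPy example of mine, not the paper's CUDA kernels), a sparse representation only stores and multiplies the nonzero entries:

```python
# Generic illustration (not the paper's custom CUDA kernels): a sparse matrix format
# only stores the nonzero entries, and a matrix-vector product only touches those.
import numpy as np
from scipy.sparse import csr_matrix

n = 1024
pattern = np.tril(np.random.rand(n, n) < 0.01)      # a mostly-zero causal pattern (~0.5% nonzero)
dense = pattern.astype(np.float64)
sparse = csr_matrix(dense)
print(dense.size, sparse.nnz)                       # stored entries: ~1,000,000 vs a few thousand
v = np.random.rand(n)
out = sparse @ v                                    # only the stored (nonzero) entries are used
```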
Another thing the authors show: okay, here, this is the number of floating point operations for one window size, so for one w, but 00:17:58.560 |
we don't have one w we have many w's and we also know that these w's are according to a geometric 00:18:05.280 |
sequence as written in the paper written here we set w and r to geometric sequences geometric 00:18:12.080 |
sequence means that to get the next w we take the previous w and multiply it by a constant alpha, and 00:18:17.600 |
this alpha is fixed. We start from w zero, here for example w zero is equal to four and the 00:18:25.360 |
dilation rate r zero is equal to one and every time they multiply by two and let's go back here 00:18:33.600 |
okay so they need to combine the floating point operations for all of this w and r's and they do 00:18:40.400 |
it here but considering that this w and r are actually the result of a geometric sequence 00:18:46.240 |
this becomes a function of alpha and w zero, so the initial w that you choose, and if we look at this 00:18:53.440 |
expression here we can see that the number of floating point operations that you need to 00:18:59.520 |
calculate this combined attention here, so the combination of all these w's and r's, is 00:19:07.120 |
proportional to n and d: it grows linearly with n and d, just like it's written here. 00:19:15.760 |
another interesting fact is that even if two words are not connected to each other directly 00:19:22.240 |
by a dot product we can calculate the information distance between them that is how many jumps you 00:19:29.920 |
need to make to go from one token to the next let me explain this better for example let's 00:19:36.640 |
watch my notebook that i made here this is a notebook that i made specifically for learning 00:19:42.160 |
this model and i wanted to actually test how it works so we imagine we have a sequence length of 00:19:48.480 |
16, and here my representation goes from 0 to 15, not from 1 to 16, but the idea is the same, 00:19:55.840 |
and we know that we will be calculating for example the first attention will be calculated 00:20:01.920 |
for this box this box this box this box then another one that will be this one this one and 00:20:09.040 |
the last one that is this one, and this is exactly the same as the combined attention that we see 00:20:16.960 |
here, so the overlapping of this attention, this attention, and this attention is exactly the same, 00:20:22.480 |
just the colors are different. Now let's look at how, for example, the token number 00:20:29.440 |
0 and the token number 15 so the last token are related the idea is that we cannot go from token 00:20:35.920 |
number 0 to token number 15 directly because there is no dot product between 0 and 15 but we can find 00:20:42.320 |
a path to go there so from 0 to 15 we can go from 0 to 12 and from 12 to 15 let's see so from 0 to 00:20:50.640 |
12 there is a dot product right then there is a dot product between 12 and itself because in the 00:20:56.400 |
attention mechanism we are always making the dot product between every node every token and itself 00:21:01.440 |
and the token number 12 it's related to the token number 15 so there is a dot product between the 00:21:08.640 |
token number 12 and 15 so actually the token number 0 it is related to the token number 15 00:21:14.000 |
through the token number 12, and we can find such a path for all the tokens, and I can prove it: I 00:21:20.960 |
show in this notebook that all the nodes are reachable from node number zero, so 00:21:27.680 |
from token number zero we can reach all the tokens by using other tokens as intermediates. 00:21:34.160 |
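Here is a compact version of what the notebook does (my own quick sketch, rebuilding the same mask as in the earlier sketch): build the combined pattern, treat each computed dot product as an edge, and run a breadth-first search from token 0.

```python
# Sketch (same spirit as my notebook): BFS over the combined dilated-attention
# pattern, checking that every token is reachable from token 0 and how many hops
# are needed at most.
from collections import deque
import numpy as np

def longnet_mask(n, w0=4, alpha=2):              # same construction as the earlier sketch
    mask, w, r = np.zeros((n, n), dtype=bool), w0, 1
    while w <= n:
        for start in range(0, n, w):
            idx = np.arange(start, min(start + w, n), r)
            mask[np.ix_(idx, idx)] = True
        w, r = w * alpha, r * alpha
    return np.tril(mask)

def max_hops_from(mask, source=0):
    adj = mask | mask.T                          # a shared dot product connects both tokens
    dist = [-1] * mask.shape[0]
    dist[source] = 0
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in np.where(adj[u])[0]:
            if dist[v] == -1:
                dist[v] = dist[u] + 1
                queue.append(v)
    return max(dist)                             # -1 would mean some token is unreachable

for n in (16, 32, 64, 128):
    print(n, max_hops_from(longnet_mask(n)))     # grows slowly, roughly like log(n)
```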
And in the paper, let's go to the paper, here they show that the maximum number of jumps that 00:21:42.080 |
you need to make to go from one token to any other token grows with the logarithm 00:21:51.520 |
of the sequence length; that is, if the sequence length is, let's say, 10 times bigger, you don't 00:21:59.520 |
need to make 10 times more jumps to go from one token to another. 00:22:05.680 |
So why are we talking about jumps? Because it also tells you how strong the 00:22:11.840 |
relationship between two tokens is: if we calculate the dot product between two tokens directly, 00:22:16.880 |
then the model will find that relationship immediately, so the model 00:22:22.720 |
will learn to relate those two tokens directly, but if we have intermediate tokens, the model 00:22:28.960 |
will take more iterations to find this connection between tokens, so the 00:22:33.520 |
connection between those two tokens will be weaker, and this is what the authors claim. 00:22:38.320 |
They claim that the attention mechanism is spread in such a way that the strength of the attention 00:22:45.360 |
becomes exponentially weaker as the distance between the tokens increases, or in other 00:22:54.560 |
words we can say that the number of jumps that you need to make grows with the logarithm of n 00:22:59.440 |
And we can do the same with other sequence lengths: for example here I used a 00:23:06.160 |
sequence length of 16, but we can use 32 for example and visualize it, 00:23:12.080 |
and let's see if it's visualizable. Yeah, so basically our LongNet will do this: 00:23:20.480 |
it will start with small boxes of size four, then it will also calculate the attention for 00:23:27.440 |
the box size eight, then also for the box size 16, and also for the box size 32. Here we can see the 00:23:36.720 |
overlapped attention masks, so all the different sizes, but also all the single groups, so for example 00:23:43.520 |
all the tokens that are directly connected to each other, in different colors. So the token 00:23:49.600 |
number zero is directly connected to the token number three, and also the token 00:23:56.320 |
number four is directly connected to the token number five because they are part of the same box 00:24:00.320 |
when they are calculated but the other tokens they have to be inferred for example with a sequence 00:24:05.600 |
length of 32 we can see that still the token number zero is reachable from every other token 00:24:12.880 |
but by different number of steps for example to go from token number zero to token number 17 00:24:19.200 |
we need to pass from 16 let's see from 0 to 17 we cannot go directly because there is no dot product 00:24:26.880 |
here but we can go to 16 there is a dot product here and 16 is related to itself also and 16 is 00:24:33.760 |
also related to 17 so we actually can go from 0 to 17 by passing from 16 and this can be done for 00:24:41.440 |
all the nodes and i also made a graph here to visualize this so from 0 we cannot go directly 00:24:46.400 |
to 17, but we can go to 16 and from 16 we can go to 17, and this is the idea of LongNet. 00:24:56.880 |
Let's go back: okay, we don't calculate all the dot products of all the 00:25:04.240 |
tokens with all the other tokens, but we spread this attention mechanism in such a way that words 00:25:11.440 |
that are very close to each other are directly connected and words that are far from each other 00:25:16.800 |
are connected through other tokens. In the paper they also show how the 00:25:28.240 |
model can be trained in a distributed way well we already saw it because all of these boxes 00:25:36.080 |
are actually independent from each other so to calculate for example the attention in this box 00:25:41.440 |
here you need only the embedding of the token number 0 1 2 and 3 and that's it to calculate 00:25:48.400 |
the attention mechanism of this box here you need to have the embeddings of the tokens number 00:25:54.720 |
0, 2, 4 and 6 but not of the others, and to calculate this one the same, etc. And another 00:26:05.200 |
interesting thing is that the number of dot products in each box is always constant so if 00:26:10.800 |
we want, we can configure the model in such a way that each computer handles at most that number of 00:26:20.240 |
dot products, and so this mechanism is quite parallelizable, which is really important 00:26:29.680 |
because it allows us to scale greatly: 00:26:35.920 |
we can compute the model on the cloud or on different GPUs and 00:26:40.880 |
take full advantage of this parallelization. 00:26:48.160 |
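To make the independence argument concrete, here is a small sketch of mine (not the paper's distributed algorithm): for one (w, r) pair, each segment's attention only touches the embeddings of its own tokens, so the loop below could be split across workers or GPUs.

```python
# Sketch (my own illustration of the independence argument, not the paper's
# distributed algorithm): for one (w, r) pair, every segment is processed using
# only its own tokens, so each iteration could run on a different worker / GPU.
import numpy as np

def causal_attention_block(x):
    n, d = x.shape
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), x @ x.T / np.sqrt(d), -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ x

def dilated_attention_one_pair(x, w, r):
    out = np.zeros_like(x)
    for start in range(0, x.shape[0], w):              # each iteration is independent:
        idx = np.arange(start, min(start + w, x.shape[0]), r)
        out[idx] = causal_attention_block(x[idx])      # it only needs the embeddings in this segment
    return out

x = np.random.randn(16, 64)
y = dilated_attention_one_pair(x, w=8, r=2)            # segments [0..7] and [8..15], dilation 2
```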
Another interesting thing is the runtime: we can see here that as the sequence length increases, the 00:26:54.080 |
runtime grows linearly, not quadratically like the vanilla transformer, as you can see here. And then 00:26:59.280 |
in the rest of the paper they show how the model performs compared to previous models. 00:27:05.760 |
Now, my point is not to show the actual results, which you can look at by yourself; 00:27:11.840 |
my goal was to show the attention mechanism of this new LongNet, and I hope it was 00:27:18.640 |
clear. I also hope you will use my Python notebook to experiment by yourself; I'll show you how 00:27:26.240 |
it works here you define the sequence length that you want to visualize and the notebook will only 00:27:31.360 |
visualize short sequence length i think i set this to 32 so if it's bigger than 32 it will not be 00:27:37.040 |
visualized, because it's not easy to visualize. And basically, to calculate the distance 00:27:45.680 |
between one token and another token, I just do a BFS, a breadth-first search; it's a 00:27:51.760 |
very ugly, unoptimized one, but it doesn't matter, because I built this notebook in half an hour 00:27:56.800 |
just to show how the model works, so you are invited to make it better if you want. 00:28:03.520 |
The last thing that I didn't show, and that is very interesting, is that we can see that the maximum 00:28:09.920 |
node distance from node zero to any other node is three, and it changes with the logarithm 00:28:16.640 |
of n, as you can see. So if, for example, our n is equal to, let's say, 16, let's do it again: 00:28:24.000 |
if the sequence length is 16, we can see here the path to go from node number 00:28:32.080 |
zero to any other node, and the maximum distance to go from node number zero to any 00:28:36.960 |
other node is two, which again grows like the logarithm of n, here with n equal to 16. You can also change which node 00:28:43.680 |
you want to go from so if you want to calculate other paths for example here we say that i want 00:28:48.480 |
to go from node number five to every other node, and here we prove that all the nodes 00:28:55.120 |
are reachable from the node number five and here we display the paths okay and this is the maximum 00:29:01.360 |
distance from node number five to any other node and this is the graph i hope you like my video 00:29:08.160 |
guys, and I hope it was more or less clear how this mechanism works. I didn't explain all 00:29:15.120 |
the equations; here I have a lot of comments and reasoning on the sides, because I like to 00:29:21.840 |
take notes when I read a paper, mostly because I want to understand also the maths 00:29:26.960 |
behind it so if you're interested in some parts just write in the comments and i will try to 00:29:30.960 |
explain it better but i think most people just want to understand the mechanism and they are 00:29:34.960 |
waiting for the official code to be released to actually see how it works. I hope I didn't 00:29:40.080 |
make any mistakes, because basically there is no information online about LongNet yet, so everything 00:29:45.040 |
I told you comes from my own research. I hope you enjoyed the video, so please come back 00:29:52.320 |
to my channel for more videos about deep learning and machine learning and have a great day