Titans: Learning to Memorize at Test Time

00:00:00.000 |
Okay guys, so today we are going to talk about this paper, "Titans: Learning to Memorize at Test Time". 00:00:07.040 |
In this paper, we will be seeing first of all what is the problem we are trying to solve, 00:00:11.680 |
and then what is the solution proposed here, and then we will comment on the pros and cons. 00:00:18.000 |
The way I like to talk about papers is actually to give you the tools to understand the paper 00:00:23.400 |
yourself, so I don't like to just read the paper word by word, because that's something 00:00:26.680 |
you can do by yourself, so I like to talk about the background knowledge that you need. 00:00:30.760 |
We work a little bit on that, then we look at the problem, and then we see the solution. 00:00:35.920 |
The problem we are talking about here is sequence modeling, and up to now, there are two main 00:00:40.800 |
ways in deep learning to do sequence modeling. 00:00:43.120 |
One is called the transformer, and the other is called the recurrent neural networks. 00:00:48.560 |
There are also hybrid variants, which combine a little bit of the attention mechanism with recurrent layers. 00:00:54.480 |
So let's talk about how the sequence modeling is done in these two ways. 00:00:59.040 |
So open a new page, basically imagine you have a very long sequence, imagine let's talk 00:01:05.600 |
about language modeling, which is something we are all familiar with. 00:01:09.180 |
So imagine we want to train a language model. 00:01:11.280 |
How does the training of a language model work? 00:01:13.640 |
Usually we have a sequence, so we want to teach the language model to predict the next token. 00:01:19.040 |
So we have a sequence of tokens, so let's say this is our sequence of tokens. 00:01:24.560 |
So the first token is let's say "I", then the second token is "like", so I always pretend 00:01:32.600 |
like one token is a word and one word is a token, which is not actually the case, but 00:01:36.740 |
for simplicity we will think like it is "I like to eat", let's just say "pizza", okay. 00:01:48.600 |
Imagine we want to train a language model to generate this exact phrase. 00:01:53.240 |
What we do is basically we need a kind of model, which could be a transformer, but it 00:01:57.740 |
could be also a recurrent neural network, and we force it to predict the next token. 00:02:02.120 |
So the job of sequence modeling means that we have some input sequence, and we will call 00:02:07.920 |
it the input, and we are trying to map it to some output, which is this sequence here. 00:02:23.360 |
The language modeling that we do is usually, the model that we train usually is called 00:02:29.040 |
an autoregressive language model, which means that it is, when it makes its prediction, 00:02:34.920 |
it can use all the past words to choose, to predict what is the next word, which means 00:02:40.680 |
that the model, imagine the model should be able to predict exactly this, to generate 00:02:47.280 |
exactly this sentence, so the model, whenever it's fed, whenever it's prompted with the word "I", should predict "like". 00:02:58.480 |
Whenever it's prompted with "I like", it should predict the word "to". 00:03:03.160 |
And whenever it's prompted with "I like to", it should predict "eat". 00:03:11.380 |
Whenever it's prompted, etc., so you can see the pattern here, right, and whenever it's prompted with 00:03:17.320 |
the whole sentence, it should say "end of sentence", which is a special token that 00:03:22.640 |
says, okay, I'm done with the generation process. 00:03:27.540 |
So we take some sentence, which could be a document, which could be a web page, anything, 00:03:31.880 |
we shift the words by one position, and we force the language model to predict the next token. 00:03:37.040 |
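To make this shift-by-one setup concrete, here is a minimal sketch in PyTorch; the token ids, the vocabulary size and the missing sequence model in the middle are placeholders of mine, not anything from the paper or the video.

```python
import torch
import torch.nn.functional as F

# hypothetical token ids for "I like to eat pizza <eos>"
tokens = torch.tensor([5, 12, 7, 9, 30, 2])
inputs, targets = tokens[:-1], tokens[1:]      # shift by one position

vocab_size, d_model = 100, 16
embed = torch.nn.Embedding(vocab_size, d_model)
lm_head = torch.nn.Linear(d_model, vocab_size)

hidden = embed(inputs)                         # (5, d_model); a real transformer or RNN
logits = lm_head(hidden)                       # would sit between these two lines
loss = F.cross_entropy(logits, targets)        # position t is trained to predict token t+1
loss.backward()                                # backpropagate to update the parameters
```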
And there are two principal models to do that. 00:03:40.720 |
So let's say that in between here, we have something called the transformer. 00:03:47.800 |
The transformer basically allows us to do this language modeling in such a way, through 00:03:52.360 |
the attention mechanism, that the output of the language 00:03:56.720 |
model, which is used to compute the loss upon which the language model is trained, can be computed for all positions in parallel. 00:04:04.280 |
Basically, this is also the reason most language models today are transformer based, because 00:04:09.840 |
we want to leverage the GPUs, so if we can parallelize some operations, it's better. 00:04:15.700 |
On the other hand, we also have recurrent neural networks. 00:04:19.640 |
Later we will see what the problems with transformers and recurrent neural networks are. 00:04:37.920 |
So let's label it; by the way, this is called the target sequence. 00:04:44.340 |
So this one, when you train a model, this is called the target sequence. 00:04:48.640 |
You compare what is the actual output of the transformer with what you want the transformer 00:04:53.320 |
to output, which is the target, and you compute the loss, and then, based 00:05:00.560 |
on the gradient, you backpropagate to update the parameters of the model. 00:05:03.720 |
So the model is forced to learn to generate the target given the input. 00:05:12.100 |
We can take this one and replace the transformer with the recurrent neural network. 00:05:17.600 |
And the problem with the recurrent neural network is that it's not parallelizable. 00:05:25.000 |
Recently there are recurrent neural networks that, by exploiting the parallel 00:05:31.520 |
scan, can actually be parallelized, but up to now they are not used in practice. 00:05:47.480 |
The transformer, I will not be talking about the attention mechanism, I suppose you already 00:05:51.480 |
know that, but it's not even important, you just need to remember that the transformer 00:05:55.040 |
is parallelizable, and the recurrent neural network in its basic form is not parallelizable. 00:06:00.040 |
So the recurrent neural networks work as follows. 00:06:03.440 |
When you want to train them, or even when you want to run inference with them, 00:06:08.120 |
because we are doing sequence modeling, imagine we want to train a language model 00:06:11.320 |
to learn this exact sentence here, "I like to eat pizza". The way 00:06:18.600 |
we train them is as follows: we take the first token, the word "I", and we feed it to 00:06:25.600 |
the recurrent neural network, so let's call it recurrent RNN, the recurrent neural network 00:06:36.800 |
will produce an output, which is something we don't know, but we force it to learn the 00:06:44.040 |
target, so the target we want is, well, when it sees I, it should predict like, right? 00:06:52.640 |
So it should predict "like"; based on what it actually produces and what the target is, we compute a loss. 00:07:01.640 |
Then we take the last output; the recurrent neural network not only produces the output 00:07:08.160 |
token, it also produces a state, which encapsulates information about all the input the model has seen so far. 00:07:16.040 |
This is called the hidden state of the recurrent neural network, or also the memory of the RNN. 00:07:23.980 |
So we use this state to, again, feed another token to the RNN, so let me put the input 00:07:32.120 |
below, actually, I think it's easier to visualize this way, so the input here was the word I. 00:07:38.680 |
Then this will produce a new hidden state; let's call it the hidden state at the first time step. 00:07:45.440 |
We feed it again to the recurrent neural network along with a new input, the next token is 00:07:52.440 |
like, and the recurrent neural network will predict something, but we force it to learn 00:07:58.240 |
to predict the word to, for the second time step. 00:08:04.260 |
How can it predict the word to, just given the word like, well, by leveraging the recurrent 00:08:10.000 |
state from the previous time step, which encapsulates all the information about the word I. 00:08:14.480 |
That's how it can predict "to": the recurrent neural network is actually seeing "I like", with 00:08:21.320 |
"like" directly as the input, and "I" indirectly because it's in its hidden state. 00:08:27.480 |
Now, we can do it also for the third token: we take the hidden state at time step two, feed it again to our 00:08:36.120 |
recurrent neural network along with the token "to", which will produce some output, 00:08:40.800 |
and we don't know what it is, but we force it to learn to predict the word "eat". 00:08:51.000 |
How can it do that? Because it can see that the input is "to", but it can also see the history of the input so far through its hidden state. 00:09:00.480 |
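A minimal sketch of this sequential loop, again with made-up sizes; the RNN cell and the head here are just stand-ins:

```python
import torch

d_model, hidden_size, vocab_size = 16, 32, 100
rnn_cell = torch.nn.RNNCell(d_model, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size)

embeddings = torch.randn(5, 1, d_model)   # stand-in for the embedded "I like to eat pizza"
h = torch.zeros(1, hidden_size)           # fixed-size hidden state: the "memory" of the RNN
for x_t in embeddings:                    # one token at a time: this loop is inherently sequential
    h = rnn_cell(x_t, h)                  # the new state summarizes everything seen so far
    logits_t = lm_head(h)                 # prediction for the next token at this time step
```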
When we use the transformer, the transformer can, to predict a particular token, for example, 00:09:06.280 |
the token pizza, it can leverage all the previous input because it's fed all at the same time 00:09:15.360 |
And this input, during training, is the keys and the values, and during inference it is what gets cached as the KV cache. 00:09:25.720 |
So the transformer, in order to predict a particular token, can always see the entire 00:09:30.200 |
sequence, and that's why it's parallelizable. 00:09:32.240 |
So we feed the entire sequence to the transformer to predict each position; because we feed 00:09:36.680 |
the entire sequence, the transformer can see the entire sequence, and it can compute the output at every position in parallel. 00:09:44.240 |
However, with the recurrent neural network, we cannot compute the output at each position 00:09:49.960 |
in parallel, so we have to do it one step at a time. 00:09:56.480 |
The advantage of the transformer is that it is parallelizable, so we can train massive 00:10:00.600 |
models by just increasing the number of GPUs. 00:10:03.520 |
The problem of the recurrent neural network, because it's not parallelizable, we are limited 00:10:07.360 |
because we have to do one, kind of a for loop to train them, so first we generate the first 00:10:12.760 |
one, and then generate the second one, and then the third one, et cetera, et cetera, 00:10:18.920 |
And then there are other problems, like vanishing gradients, et cetera, et cetera. 00:10:25.600 |
So the problems of recurrent neural networks are actually two. 00:10:29.160 |
First of all, they are not parallelizable, so this one here is not parallelizable. 00:10:34.400 |
And the second problem is that we have this recurrent state, this is called also the hidden 00:10:40.240 |
state of the recurrent neural network, which is fixed in size. 00:10:45.320 |
So it can be as big as you want, it can be one megabyte, one gigabyte, whatever you like. 00:10:51.120 |
So once you have chosen your architecture, it's fixed. 00:10:53.680 |
On the other hand, when we use a transformer model, the size of the input that the language model can leverage keeps growing. 00:11:04.000 |
Because when you use, for example, a prompt on ChatGPT, imagine you just feed the first 00:11:08.600 |
token: imagine ChatGPT was trained exactly on this sentence here, and suppose you only 00:11:13.680 |
feed the first token, "I". What ChatGPT will do is predict the next token using 00:11:19.200 |
only "I"; then it will take the word "like", put it back into the input, feed it all again 00:11:24.000 |
to the transformer, so "I like", and the transformer will predict the next token. 00:11:29.280 |
And then it will take the word "to", put it back into the input, so "I like to", and put all these 00:11:35.000 |
three tokens into the language model, the transformer model, and then it will be able 00:11:39.000 |
to predict the next position, et cetera, et cetera. 00:11:41.440 |
So the hidden state, the memory of the transformer, the stuff that we feed to 00:11:46.200 |
the transformer in order to predict the next token, is actually growing, and this is also a problem. 00:11:53.240 |
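A rough sketch of why this memory grows; this is a toy single-head version of a KV cache, not how real implementations are written:

```python
import torch

d_model = 16
k_cache, v_cache = [], []                  # grows by one entry for every generated token

def attend(q_t, new_k, new_v):
    k_cache.append(new_k)                  # keep the new token's key and value forever
    v_cache.append(new_v)
    K, V = torch.stack(k_cache), torch.stack(v_cache)   # (t, d_model): the whole history
    scores = K @ q_t / d_model ** 0.5
    return torch.softmax(scores, dim=0) @ V

out = attend(torch.randn(d_model), torch.randn(d_model), torch.randn(d_model))
```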
So when doing very long sequence modeling, we need two things. 00:11:57.360 |
First of all, we would like the language model to be able to use all the input it 00:12:03.780 |
has seen so far, and that's something that we can easily do with a transformer; 00:12:09.040 |
however, the problem is that with the transformer we have a growing memory, because we need 00:12:12.480 |
to always put all the input, all the tokens, into the transformer for it to predict the next token. 00:12:19.720 |
Or if we have limited memory, we can use a recurrent neural network, but they are not 00:12:24.200 |
parallelizable during training, and the second problem is that they have a fixed memory. 00:12:29.720 |
The fixed memory also has another problem, because it's fixed, we cannot choose what 00:12:34.480 |
is inside, so sometimes the language model may see some information, and sometimes it 00:12:39.720 |
will not be able to see some information, it's like you take one person and you ask 00:12:46.480 |
the person to memorize 3,000 books, I don't think the person will be able to do it, because 00:12:51.520 |
our brain is fixed in size, and the same is the problem with recurrent neural networks. 00:12:58.480 |
Moreover, we have seen many architectures that are trying to improve this memorization 00:13:06.840 |
capability of the recurrent neural networks, for example Mamba, in which they use a particular 00:13:11.840 |
structure of matrix called the HiPPO matrix, that allows memorizing information in a more 00:13:17.440 |
effective way; however, in practice they don't work as well as we would expect. 00:13:23.560 |
Now, in this paper, they say... ok, first of all, before we can talk about this paper, 00:13:32.560 |
let's talk about how we train language models and about the architecture. 00:13:39.140 |
So usually we have some tokens, so let's say some, let's call them input, let me do it 00:13:46.680 |
vertically, I think it's easier, so we have some input tokens, we convert them into embeddings, 00:13:59.860 |
these embeddings are fed to a series of layers of transformers, so for example layer 1, layer 00:14:08.240 |
2, etc., etc., until they produce some output; these are called the logits. Now, what 00:14:18.400 |
happens with the transformer and with the recurrent neural network is as follows: with 00:14:23.520 |
the transformer we have a growing memory, so we have this thing called the KV cache that 00:14:28.720 |
contains all the past tokens, so the transformer can always leverage all the past tokens to predict the next one. 00:14:36.560 |
In the recurrent neural network, we have a past memory that compresses all the past tokens 00:14:42.840 |
into a fixed size memory, that however has its own problem because sometimes the information 00:14:49.360 |
is lost because it's fixed and you're trying to squeeze in a lot of stuff, so we cannot 00:14:54.040 |
decide what is inside, we just hope that the network learns to keep the most important 00:15:00.520 |
information and forgets about the less important information. 00:15:07.560 |
The problem is when we train a language model, we feed it a lot of data, so for example we 00:15:13.360 |
train the language model on the entire wikipedia, we train it on the entire web, and a lot of 00:15:18.960 |
books, so the model has seen kind of all the possible data that exists in this world, we 00:15:28.600 |
hope that when we have, imagine we have a model, a hybrid model, so a transformer but 00:15:35.600 |
with also a recurrent neural network, so imagine that this, suppose that this one here is an 00:15:42.400 |
attention layer, so a transformer layer, let's call it attention, and this one is a recurrent 00:15:48.580 |
neural network, and suppose that this is one of the new fancy recurrent networks that can 00:15:53.480 |
be parallelized actually, there are new architectures actually that can be parallelized, but still 00:16:00.480 |
the problem is that this information here, the RNN, will produce a memory that is fixed 00:16:06.280 |
in size, so if you feed 1000 tokens, this one will contain, will output a memory that 00:16:12.520 |
will be leveraged by the attention that will not be 1000 tokens, it will be less, because 00:16:18.600 |
the goal of the RNN is to compress stuff into some fixed size memory that can be leveraged 00:16:26.140 |
by the transformer model, which is this layer here, attention layer here, the attention 00:16:31.720 |
layer here however is very good at leveraging the data it is being fed, but this data is 00:16:37.220 |
not all the sequence because we have compressed it with the recurrent neural network, and 00:16:43.760 |
we hope that the attention can leverage the information that was compressed by the recurrent 00:16:48.160 |
neural network to do its job of predicting the next token, if we do it this way, so imagine 00:16:55.040 |
we have this architecture which is a hybrid architecture of attention plus recurrent neural 00:16:59.240 |
network, the problem with this architecture is that when you train it, because we do it 00:17:06.840 |
with deep learning, we force the model to learn whatever target we have, it will be 00:17:12.000 |
forced to learn this recurrent neural network to compress the information in such a way 00:17:17.920 |
that the attention can use it, and the attention will be forced to extract whatever information 00:17:23.200 |
is in this compressed state made by the recurrent neural network, this is good, so when you train 00:17:30.080 |
it actually the loss decreases and you see that it performs quite well, however when 00:17:34.520 |
you use it in practice, the problem that you feed to the model may not be something that 00:17:40.280 |
the language model has seen in the past, so maybe we call this data out of distribution 00:17:47.720 |
so the model may not know how to compress it well, what to keep and what to not keep, 00:17:52.940 |
so in this case the recurrent neural network will fail at its task of compressing data, 00:17:58.640 |
and because the data necessary to predict the next token was not compressed well, the 00:18:03.600 |
attention layer will not be able to leverage this data to predict the next token, so at 00:18:08.400 |
training time we see that this hybrid architecture works really fine, but at test time, so when 00:18:13.480 |
we use them, we actually see that they don't work that well, and this is one of the reasons: 00:18:18.480 |
they learn to compress the data they have seen very well, so they know, ok, if I have 00:18:24.040 |
a long source code of Python, I should not concentrate on the, I don't know, some comments 00:18:33.400 |
that maybe are repetitive, but I should concentrate on the code, or maybe I should not, when I 00:18:37.480 |
see some C# code or C code, I should not concentrate on the, maybe the parentheses, because they 00:18:45.780 |
are just, how to say, redundant, but I should concentrate on the expressions, etc, etc, 00:18:53.420 |
so when it sees, so it actually learns to compress the information, but only the information 00:19:02.780 |
that it has seen at training time, now finally we can talk about the paper, so the paper 00:19:09.300 |
claim is, we have these models that need some kind of memory, in the transformer models 00:19:16.100 |
we have this KV cache; the problem with this KV cache is that it's growing, and the problem with 00:19:21.700 |
the growing KV cache is that it requires a lot of memory. Actually, most models are 00:19:28.540 |
constrained by this: the fact that we cannot have a very big context window in the current models 00:19:35.660 |
is because of the inference cost of these models, so they are really, really 00:19:41.900 |
expensive to inference, because we need to keep the KV cache, and there is one KV cache for 00:19:47.860 |
each layer, and the bigger models have a lot of layers, so you need to keep all the 00:19:51.460 |
past tokens for each of the layers of the model, for each token that you need to predict, so 00:19:58.060 |
it's very expensive. Then the alternative to having this infinite memory that keeps growing 00:20:03.300 |
is to have a compressed memory, but this compressed memory only works very well at training time, 00:20:08.100 |
so the claim is, can we have a memory module that is trained at test time, and that's why 00:20:15.860 |
we are talking about learning to memorize at test time, that is effective at retrieval, 00:20:22.900 |
because the goal of the memory is to retrieve the information that is salient, that is needed 00:20:26.620 |
by the model, that is effective in retrieving the information that is being fed exactly 00:20:33.780 |
at test time, not only the one that it has seen at the training time, this is the problem 00:20:38.900 |
that we are trying to solve with titans, now the way they do it is as follows, so they 00:20:45.500 |
say ok, imagine we have a module, imagine we have a module that we will call M, and 00:20:53.180 |
this module let's think of it as a layer in a module, so ok let me draw actually, I think 00:21:00.540 |
it's much easier if we can draw it, let's add a new paper, a new page, so ok, imagine 00:21:09.540 |
we have a very long sequence, we have seen that with the recurrent neural network the 00:21:13.460 |
job of the recurrent neural network is to compress this very long sequence so that the 00:21:17.500 |
transformer can use it, let's do with titans now, how does it differ, and then we will 00:21:25.100 |
check all the details, so we have this input, so let's go here again, so we have this input, 00:21:34.740 |
we transform into embeddings, then we, I will draw a little differently and then later I 00:21:43.380 |
will explain why, we have some, suppose we have a hybrid architecture again of transformer 00:21:48.460 |
and recurrent layers, but I will not draw the recurrent layers, so this is the first 00:21:53.100 |
layer of the, too big I think, let's call it L1, so the first layer with attention, 00:22:00.020 |
the second layer with attention, the third layer with attention, and then we have the 00:22:09.220 |
output which is the logits, ok, I think now it's more visible right, ok, so imagine we 00:22:19.140 |
have another module in this architecture that we will call the memory module, let's call 00:22:26.380 |
it neural memory because this is how they call it here, so let's call it neural memory, 00:22:31.300 |
and I will draw it as external module neural memory, now I want to show you how it would 00:22:45.520 |
work with the neural memory and then we check the detail on how it is actually trained, 00:22:51.580 |
so the way we usually train modules, so imagine, ok let's take a step back, how would we train 00:22:57.700 |
this module, we would feed it a sequence, imagine 1 million tokens, so imagine a very 00:23:04.020 |
big sequence, so let's say 1 million tokens, you convert this sequence of tokens, 1 million 00:23:09.780 |
tokens into embeddings, you run these embeddings in the neural networks, recurrent neural network 00:23:17.620 |
which will compress this 1 million tokens maybe in let's say 1000 tokens because its 00:23:22.060 |
goal is to compress stuff, right. So the sequence that is fed to the attention is smaller; the problem 00:23:27.860 |
of the attention is that it's quadratic, so having a smaller input results 00:23:35.300 |
in a better computation. So we feed these 1000 compressed tokens to the attention and then 00:23:41.900 |
we force it to predict the next token only leveraging this 1000 compressed token, so 00:23:48.660 |
we feed 1 million token but we force the attention layer to predict the next token only leveraging 00:23:55.060 |
much less information, so we hope that the recurrent neural network is good at choosing 00:23:59.860 |
the right tokens to keep and discarding the ones that it doesn't need. Actually, ok, it's 00:24:05.900 |
not really a token pruning mechanism, it's a token compression mechanism but ok you can 00:24:13.260 |
think of it as a token pruning, like it's being fed 1 million tokens and it just keeps 00:24:17.580 |
the top 1000 that are the most important for predicting the next token, and this is done 00:24:26.900 |
at training time, so we feed this 1 million token at training time, we compute the output, 00:24:33.320 |
we know what should be the next token because at training time we know what is the next 00:24:36.540 |
token, we force, we compute the loss with respect to what we think should be the next 00:24:41.660 |
token and then we back propagate to update the parameters of the model and we keep doing 00:24:45.940 |
it for all the sequences that we have, with the titans it would work differently, imagine 00:24:52.580 |
you have 1 million token again and what you do is you do 2 steps, the first thing that 00:25:02.260 |
we do, ok we have this input, we convert it into embeddings, the first thing we do is 00:25:09.260 |
in the training loop, so imagine we are training this titans architecture, we first train this 00:25:15.180 |
neural module to learn to memorize our 1 million tokens and then we ask it to retrieve the 00:25:25.000 |
information necessary for predicting the next token and feed it to the attention layer, 00:25:30.380 |
so this is, let's call it attention layer, so this is an attention layer, this is an 00:25:36.780 |
attention layer and this is an attention layer, so look at the difference here, before we 00:25:44.060 |
had an input, we predicted the output, we compute the loss and we back propagate and 00:25:50.660 |
we update all the parameters of the model, here we will do something different, we have 00:25:54.980 |
an input which is 1 million tokens, we convert them into embeddings, we train this module 00:26:02.380 |
here which is separate and in the paper they refer to it as the inner loop of the training, 00:26:09.340 |
we train this neural memory, and later we will see how we train it, with the sole purpose 00:26:15.180 |
for this neural memory to learn everything about this data so that it can easily retrieve 00:26:23.260 |
this data when we will need it, so we take this 1 million tokens, we convert them into 00:26:28.660 |
embeddings, we train this neural memory in an inner loop, then we take this neural memory 00:26:37.620 |
which has been trained to memorize this data and then we ask it to retrieve whatever information 00:26:45.620 |
is important from whatever it has seen, and use it as input for the attention layers here, 00:26:52.020 |
so that the attention layers can leverage this compressed memory to produce the output 00:26:57.740 |
and predict the next token, this not only at the training but also at test time, so 00:27:05.100 |
when we use the attention with the hybrid architectures, for example attention plus 00:27:10.040 |
recurrent neural networks at test time, so at inference time, what we have is usually 00:27:14.380 |
a prompt; imagine this prompt is huge because you are asking ChatGPT, for example, to analyze 00:27:19.420 |
an entire GitHub repository of a very big project. What will happen is that this 00:27:26.820 |
1 million token will be fed to the recurrent neural network which is fixed now, so we are 00:27:32.020 |
using the model, so we are not changing its parameters anymore; the recurrent neural network's 00:27:38.460 |
job is to compress data, so it will compress these tokens into a smaller sequence that 00:27:43.740 |
we will feed to the attention layer, and it will produce the output logits. However, maybe 00:27:50.500 |
the information that we are feeding to this recurrent neural network is kind of out 00:27:55.260 |
of distribution and the recurrent neural network has never seen something like this, and it 00:27:59.620 |
will do probably a very bad job at compressing this data, so because it will do a very bad 00:28:04.820 |
job at compressing this data, because it doesn't know what to keep and what not to keep, the 00:28:09.220 |
attention layer will not be able to leverage the most important information and then it 00:28:13.620 |
will not be able to predict the next token very well, so it will result in a bad output, 00:28:20.220 |
and with titans even at test time, so even at inference time, we are actually training 00:28:27.540 |
a model, and now I show you how, imagine now we have again a github repository, and it's 00:28:34.460 |
very big and it results in 1 million tokens that we want the language model to analyze, 00:28:39.660 |
we convert it into embeddings, then we take this 1 million tokens, we train on the fly 00:28:46.980 |
this neural memory, whose job will be to just learn as much information as possible about 00:28:53.660 |
this 1 million tokens, retrieve the most salient information, because the neural memory's job 00:28:58.860 |
is to compress information, so now after we have trained it in this inner loop, we retrieve 00:29:04.260 |
this information, we feed it to the attention layers, then the attention layers should be 00:29:09.260 |
able to leverage the information retrieved by the neural memory, so with titans basically 00:29:18.940 |
we don't just have a RNN, which is our memory that is trained at training time and then 00:29:27.660 |
never trained again, and every time it sees something that it has never seen, it just 00:29:31.340 |
goes crazy, we have a neural memory that can be trained at inference time, on the fly, 00:29:39.860 |
with the sole purpose of compressing stuff, and because we are training it at inference 00:29:44.740 |
time, we hope that it will perform better even on data it has never seen, now according 00:29:52.620 |
to the benchmark they published in the paper, but this actually happens in all papers, so 00:29:56.420 |
you never trust the benchmarks, it looks like it is doing a good job, now let's look at 00:30:02.500 |
the details, so I want to remind you, the problem we are solving is long context modeling, 00:30:07.780 |
long context modeling has one issue, which is with the transformer it is very expensive 00:30:12.860 |
to inference for long context, with RNNs we have the problem that we train them on some 00:30:19.780 |
data, but when you use them on something that they have never seen, they don't know how 00:30:23.260 |
to compress and what to keep and what to not keep, so they go crazy, and because they go 00:30:28.400 |
crazy they don't do this job very well, the attention layers cannot leverage this information, 00:30:33.820 |
so they just result in very bad output. With the neural memory we want to train 00:30:39.660 |
on the fly a memory while inferencing the model, to just do the job of compressing 00:30:46.220 |
stuff on whatever data it is fed, now we can look at the details, ok, here they do some 00:30:54.860 |
preliminary, how to say, view of what is memory, what is linear attention, etc, etc, we don't 00:31:02.220 |
care about that for now, they say ok, imagine we have a memory module that only has two 00:31:07.860 |
operations, one is the write operation and one is the read operation, we want to write 00:31:14.700 |
and read at inference time and also at training time to this memory, how do we train this 00:31:21.820 |
memory, first of all this memory, neural memory, is a neural network by itself, meaning that 00:31:28.920 |
you can think of it as an external neural network that is separated from the rest of 00:31:35.540 |
the architecture, that will use this neural memory, so you need to think that you have 00:31:44.140 |
like a transformer module that is leveraging this neural memory, now how to train this 00:31:51.100 |
neural memory at inference time, because that's our problem, at training time we know how 00:31:55.540 |
to do it, we just put the input, compute the output, back propagate, and voila, how to 00:32:02.460 |
do that at inference time? That's what we see here. They say, ok, imagine we have this 00:32:07.580 |
memory; first of all, how do we want to update its information? 00:32:15.820 |
ok, another step back, what we want this memory to do, we want this memory to learn, to extract 00:32:24.220 |
information about whatever it should memorize, and for that they use a very particular loss, 00:32:30.980 |
which is kind of a reconstruction loss. So imagine we have this memory, and we ask it 00:32:37.780 |
to memorize; ok, imagine we have an input sequence, let's call it x, this xt here. We project 00:32:48.380 |
it with two linear projections called Wk and Wv, which are basically equivalent 00:32:54.040 |
to the ones that we use in the attention mechanism. How can this memory do its job very well? 00:33:04.300 |
Only if it learns to recreate the data it has seen, and this is the loss that you see 00:33:12.740 |
here; this is just the L2 loss that you can see here, which basically means it learns to memorize 00:33:20.180 |
the mapping between a projection called key and a projection called value of the same data, 00:33:27.320 |
so it kind of learns to recreate the same data, this is the job of the memory, so if 00:33:34.460 |
I put some stuff I should be able to retrieve the same stuff, so I should be able to get 00:33:39.660 |
as much as possible from the stuff that I put inside, how to train it, how to train 00:33:47.420 |
it is they say ok I have this memory, I want to update this memory by using kind of a gradient 00:33:56.260 |
descent. So how does gradient descent work? Imagine we have a neural network; the basic version 00:34:03.780 |
of gradient descent works as follows: we have a neural network with some parameters, 00:34:11.500 |
let's call them theta, so let's say theta, the parameters theta at time i, so at the 00:34:20.000 |
step i of the training, are updated with the previous parameters of the model, so at the 00:34:26.940 |
previous time, minus a learning rate that we will call gamma, multiplied by the gradient 00:34:34.740 |
of the loss with respect to the parameters of the model. The gradient tells us how we 00:34:43.180 |
should change the parameters in order to maximize the loss, but we move against the direction 00:34:51.940 |
of this gradient, and that's why you see a minus sign; so we update the parameters in 00:34:58.420 |
the direction opposite to the one that would maximize the loss, so we update the parameters 00:35:05.300 |
to reduce the loss. 00:35:12.260 |
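In symbols, the update being described is theta_i = theta_{i-1} - gamma * grad_theta(L): gamma is the learning rate, grad_theta(L) is the gradient of the loss with respect to the parameters, and the minus sign is what moves the parameters against the gradient, toward a lower loss.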
And this is what we do here: we want to update our memory in such a way that we minimize this loss here, which is the memorization loss, which 00:35:19.940 |
is the reconstruction loss that we saw before, so a loss that tells if I ask the memory to 00:35:26.580 |
retrieve some information, which is the key projection of the data, it should recreate 00:35:32.540 |
this data, and this memory, in the paper, they model it as a linear layer, so a linear 00:35:43.660 |
layer is just a matrix multiplication with a weight matrix, so this memory module, so 00:35:48.260 |
m here, is nothing more than just a weight matrix of a linear layer, so we are modifying 00:35:57.820 |
this weight matrix, so the neural memory is just a matrix, w, we are modifying this w 00:36:06.900 |
in such a way that it reduces the reconstruction loss of the data, just the way we train a 00:36:17.540 |
neural network, so we train the neural network with parameters to reduce a loss, and these 00:36:23.620 |
parameters are calculated in such a way that they will result in the smallest loss possible, 00:36:29.980 |
in the same way, we are updating this W matrix, which will be our memory, in such a way that 00:36:37.220 |
it will result in the minimum loss of information possible, because that's the loss against 00:36:42.780 |
which we are optimizing it, which is the reconstruction loss. 00:36:48.860 |
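Here is a minimal sketch of that idea, assuming the memory is literally a single matrix W and using plain SGD for one token (no momentum yet); the sizes, initializations and names are mine, not the paper's code:

```python
import torch

d = 16
W = torch.zeros(d, d, requires_grad=True)    # the neural memory: just a weight matrix
W_K = torch.randn(d, d) / d ** 0.5           # key projection (learned in the outer loop)
W_V = torch.randn(d, d) / d ** 0.5           # value projection (learned in the outer loop)

def memory_loss(W, x_t):
    k_t, v_t = W_K @ x_t, W_V @ x_t          # key and value of the current token
    return ((W @ k_t - v_t) ** 2).sum()      # reconstruction loss: || M(k_t) - v_t ||^2

x_t = torch.randn(d)                         # one incoming token embedding
loss = memory_loss(W, x_t)
grad, = torch.autograd.grad(loss, W)         # how badly W reconstructs this token
theta = 0.1                                  # inner-loop learning rate
with torch.no_grad():
    W -= theta * grad                        # W_t = W_{t-1} - theta * gradient
```

Note that for this linear memory the gradient even has a closed form, 2 * (W @ k_t - v_t) times k_t transposed, which is what makes the chunk-wise parallelization discussed later possible.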
And they call it the surprise: this gradient of the loss with 00:36:56.940 |
respect to the W of this memory, they call 00:37:03.820 |
it the surprise, because the bigger the loss, the bigger the difficulty the model had in reconstructing 00:37:12.380 |
the data, so it means that the model is surprised to see this data, and that's why they call it the surprise. 00:37:22.460 |
If you have ever studied how optimizers work, you will remember that in deep learning we 00:37:30.420 |
have this thing called momentum, so usually we don't update the parameters of the model 00:37:36.060 |
naively like this. Okay, first 00:37:44.420 |
of all, the loss is computed with mini-batch 00:37:49.780 |
gradient descent, which means that we don't compute it over the whole input dataset, but 00:37:58.420 |
over instances of data, so like a small batch of data, and the direction of this gradient 00:38:05.540 |
is actually stochastic, which means that it is not the true direction of the gradient, 00:38:11.380 |
which means that it oscillates, so it is not indicating the 00:38:18.940 |
true direction. Imagine the true direction of the gradient is here, but if we train it 00:38:22.840 |
on the first batch, maybe it's in this direction, maybe on the next batch it's in this direction, 00:38:28.780 |
maybe on the next batch on this direction, etc, on average it will point to the correct 00:38:32.500 |
direction of the gradient, but it will be noisy in each step, because we don't want 00:38:37.060 |
to take steps too confidently in each step of training, we add this momentum term, and 00:38:45.240 |
the momentum term basically kind of creates an exponentially moving average of all the 00:38:51.860 |
gradients, so that we also keep some information about the past gradient that we have computed 00:38:58.140 |
to smooth out the change of the weights, so that we don't trust any single step too much; it's not like 00:39:04.700 |
we weight each step in the same way. And the idea for them to introduce momentum into the surprise 00:39:13.740 |
is as follows, they said ok, if I train my memory to recreate the data, then it can miss 00:39:28.100 |
this new data after it sees some novel data, so maybe there is some new data that the model 00:39:37.460 |
should memorize, but the gradient kind of disappears after a while, so the model will 00:39:42.980 |
miss it. So in order to avoid this problem, they use momentum, just like we do when 00:39:49.060 |
doing model training, and they call it the past surprise; this past surprise is nothing 00:39:57.020 |
more than the accumulated past gradients in the optimizers that we use, for example the Adam optimizer, 00:40:08.420 |
and then there is the momentary surprise, which is the gradient with respect to the current input. 00:40:13.420 |
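Continuing the little sketch from before (same W, memory_loss, W_K, W_V), the update with the past surprise would look something like this; I'm keeping eta and theta as plain constants here, while the paper makes them data-dependent:

```python
S = torch.zeros_like(W)          # the past surprise (a momentum buffer over gradients)
eta, theta = 0.9, 0.1            # decay of the past surprise, inner-loop learning rate

def update_memory(W, S, x_t):
    grad, = torch.autograd.grad(memory_loss(W, x_t), W)
    S = eta * S - theta * grad   # S_t = eta * S_{t-1} - theta * grad  (past + momentary surprise)
    with torch.no_grad():
        W += S                   # M_t = M_{t-1} + S_t
    return W, S
```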
So, to rehearse what we have said so far: we have this memory, which is just a W matrix, that 00:40:19.500 |
we want to optimize in such a way, so we want to change this W continuously, with every 00:40:28.700 |
token that we receive, in such a way that it encapsulates all the information that is 00:40:34.380 |
in this input. And how do we know it captures all the information in this input? 00:40:42.780 |
because we ask it to minimize the loss, the reconstruction loss of the input, now the 00:40:51.180 |
problem is we don't want to do this training of this neural memory just during training, 00:40:58.420 |
but we also want to do it during inference time, because if we do it only during training, 00:41:02.860 |
what happens is that during inference time, every time it will see some new information 00:41:07.240 |
that it has never seen, probably it will do a bad job at compressing, so it will not work, 00:41:11.420 |
so how to do that at inference time, what we will do practically is as follows, so at 00:41:18.180 |
inference time, imagine we have inputs, so the first input, let me write all these formulas 00:41:26.100 |
actually so that we can refer to them, here, this one, and I paste it here, and then we 00:41:39.460 |
also copy the loss, this one, ok, let's learn how it would work at inference time, imagine 00:41:51.540 |
we have 1 million tokens, and ok, actually no, imagine we want to generate a lot of tokens 00:41:58.140 |
and we start with one token only, so the prompt is only one token, what will happen is we 00:42:03.100 |
have this one token, so let's call it one token, we feed it to the model as input, which 00:42:14.460 |
will be converted into embeddings, which will be only one embedding, and we want to train 00:42:19.580 |
our neural memory on this one single token, so it should learn to recreate this one single 00:42:25.580 |
token, how we will do that in practice, we take the memory, first of all we take this 00:42:34.120 |
one embedding and we project it into key and value by doing a matrix multiplication of 00:42:39.500 |
this single token with a matrix called WK and another called WV. Then we compute this, 00:42:50.220 |
this is called the retrieval of the memory, and the retrieval, because the memory is modeled 00:42:55.060 |
only as a W matrix of a linear layer, the retrieval of the information from this memory 00:43:00.740 |
will just be W multiplied by the input, and the input actually they call it QT, so it's 00:43:07.900 |
another projection of the input through the WQ matrix. So this KT comes from WK multiplied 00:43:15.700 |
by X, this VT comes from WV multiplied by X, and this QT comes from WQ multiplied 00:43:26.700 |
by X, this W here is the W of the memory, so this is the memory parameters, and this 00:43:33.780 |
is the memory, so it's the parameters of the memory, but it is also the memory itself, 00:43:39.340 |
we want to update this W, ok, so how to do that, so we project the information of a single 00:43:45.820 |
token with WV, we project it with WK, we compute this term here which is just W multiplied 00:43:54.820 |
by this term here, we compute this loss here, and we compute its gradient, the gradient 00:44:02.900 |
of this loss can be computed with the following formula, they actually specify, I can show 00:44:08.980 |
you also how to derive it actually, so there is a formula here for the gradient, this is 00:44:14.980 |
how we compute the gradient of the loss, how to compute this formula, well, how to derive 00:44:23.020 |
it, let's talk about it, but ok, ok, one second, so they compute the gradient of this loss 00:44:32.500 |
with respect to the parameters of the model, what are the parameters of the model? W, ok, 00:44:38.140 |
so they compute the gradient of the loss of this loss with respect to W, and then we need 00:44:43.300 |
to update W, how to update W? We need to compute this ST term here. This ST term is made of 00:44:51.140 |
the past surprise, but we don't have any past surprise yet, so let's suppose it is 0 for now, 00:44:55.740 |
minus the learning rate, this theta, theta t, 00:45:05.780 |
multiplied by the gradient that we have computed; and then we update this W using this term 00:45:11.980 |
ST, now we have updated our memory, then we retrieve information from this memory, how 00:45:17.860 |
to retrieve information from this memory? we just take this W and we multiply it by 00:45:21.820 |
the, we take X, so our single token, we project it with another matrix called WQ, so that 00:45:29.700 |
it becomes a QT, we multiply it by W, and now we retrieve information, this information 00:45:35.940 |
is then sent to the first layer of the model, as compressed past information, and then to 00:45:43.460 |
the second, to the third, etc, etc, to predict the output, the model will produce the first 00:45:49.140 |
output token, then usually we put this output token back into the prompt, to generate the 00:45:55.620 |
next token, here, because we are not talking about just a transformer model, we are talking 00:46:00.580 |
about a hybrid architecture that has attention layers plus neural memory, we need to update 00:46:06.100 |
our neural memory with this new incoming token, so this new incoming token will again be used 00:46:13.220 |
to update the memory, the memory will be updated with the information of the new token, it 00:46:18.660 |
will not be replaced with only this new token, so we hope that the new memory will encapsulate 00:46:24.740 |
information about the first token that we fed before and the current token, what we 00:46:30.660 |
will do practically, we will take this new token that was output by the model, we will 00:46:35.540 |
project it through WV, and it will become VT, we will project it through WK and it will 00:46:40.860 |
become KT, we compute this loss term, we compute the gradient of this loss, and we update our 00:46:49.300 |
neural memory like before, but we have the past surprise this time, so, because we are 00:46:54.900 |
not just, and we also have the previous memory, so we are updating this W, and hopefully this 00:47:00.500 |
will contain information about token number 2 and token number 1 that we fed before. 00:47:06.580 |
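Putting the pieces together, one decoding step with the neural memory would look roughly like this (still reusing W, S and update_memory from the sketches above; W_Q is another projection analogous to W_K and W_V):

```python
W_Q = torch.randn(d, d) / d ** 0.5     # query projection used when reading from the memory

def generate_step(W, S, x_t):
    W, S = update_memory(W, S, x_t)    # inner loop: first memorize the incoming token
    q_t = W_Q @ x_t                    # project the token into a query
    retrieved = W @ q_t                # read from the memory: its answer for this query
    # 'retrieved' is what gets handed to the attention layers together with x_t,
    # so that they can predict the next token (the outer model is omitted here).
    return W, S, retrieved
```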
Now, as you can see, because we are training the neural memory at test time, because now 00:47:11.540 |
we are inferencing the model, we hope that it will perform better than a neural memory 00:47:16.740 |
that has only been trained at training time, because at each step of this update, the neural 00:47:28.100 |
memory is actually trying to minimize the loss against this particular data, not only 00:47:33.380 |
the data that it has seen during training, but exactly the particular data that 00:47:37.700 |
it is seeing in this moment. I know that I fed you with a lot of information, but I 00:47:43.500 |
hope now it should be a little more clear on practically what it means to have an inner 00:47:48.800 |
loop and an outer loop, so when we train the model, we update the parameters of this big 00:47:55.700 |
model to leverage whatever the memory creates, and the memory does not learn to compress 00:48:04.420 |
information only at training time, but also at inference time, exactly on the data that 00:48:09.780 |
you feed it at inference time. Now let's talk about the problems of this memory, so the 00:48:15.180 |
problem of this memory is that, as you can see, every time we need to run 00:48:19.060 |
a gradient descent step on each single token. So this looks slow: you need to train 00:48:25.860 |
the model, you have a very big list of tokens and you want to train it as fast as possible, 00:48:32.900 |
but if you need to update the memory one token at a time, it's very slow. But fortunately 00:48:37.560 |
in the paper they also propose an algorithm to parallelize this training, and this training 00:48:45.980 |
can be parallelized actually not on the full sequence, but only chunk by chunk, which is 00:48:51.120 |
still better than doing one token at a time, so imagine you have one million tokens, if 00:48:56.020 |
we cannot parallelize it, it means ok, first take the first token, update the memory, then 00:49:00.880 |
take the second token, update the memory, then third token, update the memory, so we 00:49:03.980 |
need to do one million times this and we cannot exploit our GPUs because we have to do one 00:49:10.100 |
operation at a time, what they propose in the paper is a hybrid algorithm, so it's not 00:49:17.020 |
fully parallelizable on this entire sequence, but chunk by chunk, which is a good compromise, 00:49:22.060 |
it means that if you choose, imagine you have one million tokens and you choose a chunk 00:49:27.260 |
size of let's say 1,000, you can parallelize the first 1,000 tokens, then you take the 00:49:37.380 |
next 1,000 tokens and you parallelize within that chunk, so in total you will compute 1,000 sequential steps 00:49:42.720 |
and not one million steps, if you choose a chunk size of 1,000 over a sequence length of one million. 00:49:48.060 |
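A very rough sketch of the chunk-wise idea, which is my own simplification and not the paper's exact algorithm: within a chunk, every token's gradient is taken with respect to the memory state at the start of the chunk, so the whole chunk can be handled with batched matrix multiplications and the memory is updated once per chunk (no momentum here, and the closed-form gradient of the linear memory is used).

```python
import torch

def process_chunk(W, X_chunk, W_K, W_V, theta=0.1):
    K = X_chunk @ W_K.T                      # keys for the whole chunk at once
    V = X_chunk @ W_V.T                      # values for the whole chunk
    err = K @ W.T - V                        # reconstruction error per token, all with the same W
    grad = 2 * err.T @ K / X_chunk.shape[0]  # batched gradient of the mean L2 loss w.r.t. W
    return W - theta * grad                  # one memory update per chunk instead of per token

d, chunk_size = 16, 4
W = torch.zeros(d, d)
W_K, W_V = torch.randn(d, d) / d ** 0.5, torch.randn(d, d) / d ** 0.5
for X_chunk in torch.randn(1000, chunk_size, d):   # a long sequence split into chunks
    W = process_chunk(W, X_chunk, W_K, W_V)
```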
They also say, ok, how can you leverage this neural memory module? You can use it 00:49:54.860 |
as a contextual memory, which means that if you have a hybrid architecture in which you have 00:49:58.700 |
attention and this neural memory, so the one like the one we draw before, what we can do 00:50:06.820 |
is we take the sequence that is input by the user; because the job 00:50:12.700 |
of the neural memory is just to compress information, we retrieve whatever is in the memory, and we 00:50:18.660 |
prepend it to the sequence, along with some other persistent memory tokens, 00:50:25.340 |
ok we can even not talk about the persistent memory tokens because I believe they just 00:50:29.260 |
overdid all this stuff, I mean this system could work even without the persistent memory 00:50:34.820 |
tokens, so we take our sequence, we prepend whatever information is in the memory, we 00:50:43.780 |
feed it to the attention module and we use the output of the attention to update the 00:50:48.300 |
memory and to produce the output, so let's go to our architecture, in this case basically 00:50:57.500 |
it would mean, imagine we have fed already 10 tokens to this memory and now we are trying 00:51:04.220 |
to predict the 11th token, what it would mean is that I would take this 11th token, I would 00:51:11.980 |
input, convert it into embeddings, I would retrieve whatever is inside the neural memory, 00:51:19.660 |
so imagine the neural memory gives me, because its job is compressing, right, even if I fed 00:51:24.100 |
it 10 tokens, it doesn't have to return me 10 tokens, it has to return me a compressed 00:51:27.660 |
version of these 10 tokens, suppose the ratio is like, suppose that the compressed state 00:51:33.620 |
is 5 tokens, so I would take these 5 tokens, prepend it to my single token, it would become 00:51:39.020 |
6 tokens; I feed them to the first attention layer, take the output of the attention, use it to update 00:51:44.060 |
the memory, and combine it with the output of the attention to get the output of this layer, and feed it 00:51:50.020 |
to the next one. This is the neural memory as context usage. 00:51:58.460 |
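In code, one step of this memory-as-context usage could look roughly like the following; this is my own loose interpretation of the description above, reusing update_memory and the W, W_Q matrices from the earlier sketches, with a stock PyTorch attention layer standing in for the attention block:

```python
attention_layer = torch.nn.MultiheadAttention(embed_dim=d, num_heads=1, batch_first=True)

def memory_as_context_step(W, S, x_new):
    # 1. retrieve the compressed history from the neural memory
    retrieved = W @ (W_Q @ x_new)                        # stand-in for a few "memory tokens"
    # 2. prepend it to the current token and attend over this short sequence
    seq = torch.stack([retrieved, x_new]).unsqueeze(0)   # (1, 2, d)
    attn_out, _ = attention_layer(seq, seq, seq)
    # 3. use the attention output to update the memory for the next step
    W, S = update_memory(W, S, attn_out[0, -1])
    return W, S, attn_out
```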
The other usage is memory as gate, which is this architecture here. So in this case I have our 11th token; don't 00:52:09.220 |
think about persistent memory, I believe, it's just an overdoing, you don't have to 00:52:17.780 |
use persistent memory to make this mechanism work, they take this 11th token, they put 00:52:25.180 |
it in the memory, so now we update first the memory, and they also feed it to the attention, 00:52:32.260 |
and then they combine the output of the neural memory, which contains 11 tokens, but when 00:52:37.620 |
we retrieve it only gives us 5 tokens, and then the output of the attention, which we 00:52:41.580 |
only fed 1 token, and it's combined to produce the output, or you can only use the memory 00:52:47.260 |
as a module without any attention, which means that basically you skip all this part, so 00:52:53.460 |
you take your input, which could be 1 token, 1 million token, whatever, you update the 00:52:58.300 |
memory continuously, you take the compressed version of the memory, and you feed it directly 00:53:03.020 |
to the linear layer that will produce the logits, this is what they refer to as memory 00:53:08.780 |
as layer, honestly you can create 1 million variants of this architecture, the point is 00:53:15.660 |
not how you use it, the point is how it works. So I want to emphasize how it works: 00:53:22.300 |
we are training a module at test time, which is different from what we do with recurrent 00:53:29.140 |
neural networks, so recurrent neural networks are trained at training time, and their job 00:53:33.860 |
is to compress data, but because they do very well the job of compressing the data they 00:53:41.500 |
have seen, they may not function very well during inference, because they may see some 00:53:47.240 |
data that they have never seen, however, by having a memory like this that you can train 00:53:51.900 |
at inference time, and with an algorithm that is supposedly parallelizable, we can avoid 00:53:59.180 |
hopefully this problem, because the only job of the memory is to be able to retrieve, so 00:54:04.860 |
I actually like this paper because I believe that it's a novel idea that I didn't think 00:54:10.580 |
about before. I think this is part of a bigger area; actually, ok, I've been researching 00:54:18.140 |
this area for a while, it's called test-time training, but this particular 00:54:25.020 |
architecture is a little bit innovative in this field. What else do we need to know to 00:54:34.380 |
read this paper, I think now you should have the information to read this paper, because 00:54:39.660 |
we have talked about how to update this memory, and what is this memory, this memory is just 00:54:44.180 |
a linear layer, in the paper they also say that ok, this memory doesn't have to be just 00:54:49.220 |
a linear layer, it can be a multilayer perceptron, so it can be for example two layers with an 00:54:54.500 |
activation in between, and it will work in the same way, and the algorithm that they 00:54:58.620 |
have devised, the one that is parallelizable, would also work with this multilayer memory. 00:55:07.820 |
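For example, a two-layer memory could be as simple as this (a sketch of mine; the hidden width and activation are arbitrary choices), trained with exactly the same reconstruction loss as before:

```python
import torch

d = 16
memory_mlp = torch.nn.Sequential(       # the memory as a small MLP instead of one matrix
    torch.nn.Linear(d, 2 * d),
    torch.nn.SiLU(),
    torch.nn.Linear(2 * d, d),
)

def mlp_memory_loss(x_t, W_K, W_V):
    # same objective: reconstruct the value projection from the key projection
    return ((memory_mlp(W_K @ x_t) - W_V @ x_t) ** 2).sum()
```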
We didn't talk much about the persistent memory, but the persistent memory is just a set of tokens that 00:55:11.580 |
are prepended to the input, and they don't belong to the neural memory, they belong to 00:55:16.980 |
the outer loop as they call it here, the outer loop is just this model here, and this is 00:55:25.820 |
the inner loop, but ok, this system can work without persistent tokens, this is my claim, 00:55:34.260 |
If you look at the benchmarks, it looks like, compared to the other architectures that 00:55:38.260 |
are like Mamba and recurrent neural networks, it performs better, if you check the average 00:55:46.140 |
score over all these benchmarks, I believe ok, this is a promising area of research, 00:55:52.100 |
I will probably be looking forward to the code which has not been released yet, but 00:55:57.940 |
thank you guys for spending time with me, I hope I gave you enough at least intuitions 00:56:02.760 |
into how it is happening, and I'm also really eager to look at the code, because I think 00:56:07.020 |
the best way to learn about a new architecture is actually to look at the code. So have a nice day, guys!