Okay guys, so today we are going to talk about the paper "Titans: Learning to Memorize at Test Time". We will first see what problem it is trying to solve, then the solution it proposes, and then we will comment on the pros and the cons.
The way I like to talk about papers is to give you the tools to understand the paper yourself. I don't like to just read the paper word by word, because that's something you can do on your own, so I prefer to cover the background knowledge that you need.
We work on that background a little bit, then we look at the problem, and then we see the solution. So let's talk about the problem. The problem here is sequence modeling, and up to now there have been two main ways to do sequence modeling in deep learning.
One is the transformer, and the other is the recurrent neural network. There are also hybrid variants, which combine the attention mechanism with recurrent neural networks, etc. So let's talk about how sequence modeling is done in these two paradigms. Open a new page: imagine you have a very long sequence, and let's talk about language modeling, which is something we are all familiar with.
So imagine we want to train a language model. How does the training of a language model work? Usually we have a sequence, so we want to teach the language model to predict the next token. So we have a sequence of tokens, so let's say this is our sequence of tokens.
So the first token is let's say "I", then the second token is "like", so I always pretend like one token is a word and one word is a token, which is not actually the case, but for simplicity we will think like it is "I like to eat", let's just say "pizza", okay.
Imagine we want to train a language model to generate this exact phrase. What we do is basically we need a kind of model, which could be a transformer, but it could also be a recurrent neural network, and we force it to predict the next token. The job of sequence modeling is that we have some input sequence, which we will call the input, and we are trying to map it to some output, which is this one, this one, this one, this one.
The model that we train is usually called an autoregressive language model, which means that when it makes its prediction, it can use all the past words to predict the next word. So imagine the model should be able to generate exactly this sentence: whenever it's prompted with the word "I", it should output the word "like".
Whenever it's prompted with "I like", it should predict the word "to". And whenever it's prompted with "I like to", it should predict "eat", etc. You can see the pattern here, right? And whenever it's prompted with the whole sentence, it should output "end of sentence", which is a special token that says, okay, I'm done with the generation process.
This is how we train language models. So we take some sentence, which could be a document, which could be a web page, anything, we shift the words by one position, and we force the language model to predict the next token. And there are two principal models to do that.
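Before we get to those two, just to make the shift-by-one idea concrete, here is a tiny Python sketch. The token strings and the toy sequence are made up for illustration, not taken from the paper.

```python
# Minimal sketch of how (input, target) pairs are built for next-token prediction.
tokens = ["I", "like", "to", "eat", "pizza", "<eos>"]

# Input is the sequence without the last token, target is the sequence shifted by one.
inputs  = tokens[:-1]   # ["I", "like", "to", "eat", "pizza"]
targets = tokens[1:]    # ["like", "to", "eat", "pizza", "<eos>"]

for i, tgt in enumerate(targets):
    print(f"given {inputs[: i + 1]} -> predict {tgt!r}")
```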
One is called the transformer. So let's say that in between here we have something called the transformer. Through the attention mechanism, the transformer basically allows us to compute the output of the language model, which is used to compute the loss upon which the language model is trained, for all positions in parallel.
Basically, this is also the reason most language models today are transformer based, because we want to leverage the GPUs, so if we can parallelize some operations, it's better. On the other hand, we also have recurrent neural networks. Later we will see what are the problems with transformer and recurrent neural networks, so for now we just look at this one.
So the transformer can be parallelized: this one is parallelizable. And then we have another paradigm. By the way, when you train a model, this one here is called the target sequence. You compare the actual output of the transformer with what you want the transformer to output, which is the target, you compute the loss, and then, based on the gradient, you backpropagate to update the parameters of the model.
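As a rough sketch of that training step, assuming a toy PyTorch setup (the tiny embedding-plus-linear model and the fake token ids here are just stand-ins for a real transformer or RNN and a real tokenizer):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in for a language model: embedding -> linear head over the vocabulary.
vocab_size, dim = 100, 32
model = nn.Sequential(nn.Embedding(vocab_size, dim), nn.Linear(dim, vocab_size))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Fake token ids for "I like to eat pizza <eos>"; in reality these come from a tokenizer.
token_ids = torch.tensor([[5, 11, 7, 42, 13, 1]])
input_ids, target_ids = token_ids[:, :-1], token_ids[:, 1:]

logits = model(input_ids)                        # (batch, seq_len, vocab_size)
loss = F.cross_entropy(logits.reshape(-1, vocab_size), target_ids.reshape(-1))
loss.backward()                                   # gradient of the loss w.r.t. the parameters
optimizer.step()                                  # move the parameters against the gradient
optimizer.zero_grad()
```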
So the model is forced to learn to generate the target given the input. This is how we train models. We can take this one and replace the transformer with the recurrent neural network. And the problem with the recurrent neural network is that it's not parallelizable. At least not in its simple form.
Recently there are recurrent neural networks that, by exploiting the parallel scan, can actually be parallelized, but up to now they are not widely used in practice. So, recurrent neural networks: how do they work? I will not be talking about the attention mechanism, I suppose you already know that, but it's not even important here; you just need to remember that the transformer is parallelizable, and the recurrent neural network in its basic form is not.
So recurrent neural networks work as follows, both when you want to train them and when you want to run inference with them. Imagine we want to train a language model to learn this exact sentence here, "I like to eat pizza". We take the first token, the word "I", and we feed it to the recurrent neural network, let's call it RNN. The RNN will produce an output, which is something we don't know in advance, but we force it to learn the target: when it sees "I", it should predict "like", right?
So it should predict "like". Based on what it actually produces and what the target is, we compute the loss and we backpropagate. Then we take the last output; the recurrent neural network not only produces the output token, it also produces a state, which encapsulates information about all the input the model has seen so far.
This is called the hidden state of the recurrent neural network, or also the memory of the recurrent neural network. We use this state again when we feed another token to the RNN. Let me put the input below, actually, I think it's easier to visualize this way: the input here was the word "I".
Then this produces a new hidden state, let's call it the hidden state at time step one. We feed it to the recurrent neural network again along with a new input, the next token, "like", and the recurrent neural network will predict something, but we force it to learn to predict the word "to" for the second time step.
How can it predict the word "to" given just the word "like"? By leveraging the recurrent state from the previous time step, which encapsulates all the information about the word "I". That's how it can predict "to": the recurrent neural network is effectively seeing "I like", with "like" directly as the input and "I" indirectly, because it's in its hidden state.
Now we can do it also for the third token, "to", with the hidden state h2. This is again the recurrent neural network: we feed it the token "to", it will produce some output, we don't know what it is, but we force it to learn to predict the word "eat".
And how can the model learn to predict "eat"? Because it can see that the input is "to", but it can also see the history of the input so far through the hidden state h2. Now what is the problem here? When we use the transformer, to predict a particular token, for example the token "pizza", it can leverage all the previous input, because the input is fed all at the same time to the transformer.
This input, during training, is the keys and the values, and during inference this is called the KV cache. So the transformer, in order to predict a particular token, can always see the entire sequence, and that's why it's parallelizable: we feed the entire sequence to the transformer to predict each position, and because we feed the entire sequence, the transformer can see all of it and can compute the output at each position in parallel.
However, with the recurrent neural network, we cannot compute the output at each position in parallel, so we have to do it one step at a time. So it is not parallelizable. The advantage of the transformer is that it is parallelizable, so we can train massive models by just increasing the number of GPUs.
The problem with the recurrent neural network is that, because it's not parallelizable, we are limited: we have to run a kind of for loop to train it, so first we process the first token, then the second one, then the third one, et cetera, and then we backpropagate.
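To make that for loop concrete, here is a minimal NumPy sketch of a vanilla recurrent step. The shapes and the simple tanh cell are just illustrative assumptions, not the architecture from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
dim_in, dim_hidden, vocab = 16, 32, 100

# Randomly initialized weights of a toy RNN cell (illustrative only).
W_xh = rng.normal(size=(dim_hidden, dim_in))
W_hh = rng.normal(size=(dim_hidden, dim_hidden))
W_hy = rng.normal(size=(vocab, dim_hidden))

embeddings = rng.normal(size=(5, dim_in))   # 5 input tokens, already embedded
h = np.zeros(dim_hidden)                     # fixed-size hidden state (the "memory")

outputs = []
for x_t in embeddings:                       # strictly sequential: step t needs h from step t-1
    h = np.tanh(W_xh @ x_t + W_hh @ h)       # new hidden state summarizes everything seen so far
    outputs.append(W_hy @ h)                 # logits for the next token at this position
```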
And then there are other problems, like vanishing gradients, et cetera, but that's not the main point today. So the problems with recurrent neural networks are actually two. First of all, they are not parallelizable. And the second problem is this recurrent state, also called the hidden state of the recurrent neural network, which is fixed in size.
So it can be as big as you want, it can be one megabyte, one gigabyte, whatever you like, but it's fixed. So once you have chosen your architecture, it's fixed. On the other hand, when we use a transformer model, the size of the input that the language model sees is growing, why?
Because when you use a prompt on ChatGPT, for example, imagine ChatGPT was trained exactly on this sentence here, and suppose you only feed the first token, "I". What ChatGPT will do is predict the next token using only "I", then it will take the word "like", put it back into the input, feed everything to the transformer again, so "I like", and the transformer will predict the next token.
Then it will take the word "to", put it back into the input, so "I like to", feed all three tokens to the transformer model, and it will be able to predict the next position, et cetera. So the "memory" of the transformer, the stuff that we feed to the transformer in order to predict the next token, is actually growing, and this is also another problem.
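Here is a tiny sketch of that generation loop, just to show the context growing at every step. The `model` call is a placeholder for whatever transformer you are running; it is assumed to return next-token logits given the whole sequence so far.

```python
# Sketch of autoregressive generation with a growing context.
def generate(model, prompt_ids, max_new_tokens, eos_id):
    context = list(prompt_ids)               # starts small...
    for _ in range(max_new_tokens):
        logits = model(context)              # the whole (growing) context is used each time;
                                             # in practice the past keys/values are cached (KV cache)
        next_id = max(range(len(logits)), key=lambda i: logits[i])  # greedy pick
        context.append(next_id)              # ...and grows by one token per step
        if next_id == eos_id:
            break
    return context
```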
So when doing very long sequence modeling, we need two things. First of all, we would like the language model to be able to use all the input it has seen so far, and that's something we can easily do with a transformer; however, the problem is that with the transformer we have a growing memory, because we always need to put all the tokens into the transformer for it to see all the input.
Or, if we have limited memory, we can use a recurrent neural network, but recurrent networks are not parallelizable during training, and the second problem is that they have a fixed memory. The fixed memory also has another problem: because it's fixed, we cannot choose what is inside, so sometimes the language model will be able to see some piece of information and sometimes it will not. It's like taking one person and asking them to memorize 3,000 books; I don't think the person will be able to do it, because our brain is fixed in size, and the same problem affects recurrent neural networks.
Moreover, we have seen many architectures that try to improve this memorization capability of recurrent neural networks, for example Mamba, which builds on a particular structured matrix called the HiPPO matrix that allows information to be memorized in a more effective way; however, in practice they don't work as well as we would hope.
Now, before we can talk about this paper, how do we train language models at the architecture level? Usually we have some input tokens (let me draw it vertically, I think it's easier): we have the input tokens, we convert them into embeddings, and these embeddings are fed to a series of transformer layers, for example layer 1, layer 2, etc., until they produce some output, which is called the logits. Now, what happens with the transformer and with the recurrent neural network is as follows. With the transformer we have a growing memory: we have this thing called the KV cache that contains all the past tokens, so the transformer can always leverage all the past tokens to predict the next token.
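Just to give a feel for why this growing memory hurts, here is a back-of-the-envelope KV-cache size estimate in Python. The model dimensions below are made-up, roughly Llama-like assumptions, not numbers from the paper.

```python
# Rough KV-cache size: 2 tensors (K and V) per layer, one entry per past token.
# All default values are illustrative assumptions.
def kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=8, head_dim=128, bytes_per_value=2):
    return 2 * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_value

for seq_len in (4_096, 128_000, 1_000_000):
    gib = kv_cache_bytes(seq_len) / 2**30
    print(f"{seq_len:>9} tokens -> ~{gib:.1f} GiB of KV cache per sequence")
```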
In the recurrent neural network, we have a past memory that compresses all the past tokens into a fixed size memory, that however has its own problem because sometimes the information is lost because it's fixed and you're trying to squeeze in a lot of stuff, so we cannot decide what is inside, we just hope that the network learns to keep the most important information and forgets about the less important information.
The problem is that when we train a language model, we feed it a lot of data: we train it on the entire Wikipedia, on the entire web, on a lot of books, so the model has seen more or less all the kinds of data that exist in this world. Now imagine we have a hybrid model, a transformer that also includes a recurrent neural network. Suppose this layer here is an attention layer, a transformer layer, and this one is a recurrent neural network, and suppose it is one of the new fancy recurrent architectures that can actually be parallelized (there are new architectures that can be). Still, the RNN will produce a memory that is fixed in size: if you feed it 1,000 tokens, it will output a memory, to be leveraged by the attention, that is not 1,000 tokens, it will be less, because the goal of the RNN is to compress stuff into some fixed-size memory that the attention layer can use. The attention layer is very good at leveraging the data it is fed, but this data is not the whole sequence, because we compressed it with the recurrent neural network, and we hope the attention can leverage the information compressed by the recurrent neural network to do its job of predicting the next token.

If we build this hybrid architecture of attention plus recurrent neural network, then when you train it, because we do it with deep learning, we force the model to learn whatever target we have: the recurrent neural network is forced to learn to compress the information in such a way that the attention can use it, and the attention is forced to extract whatever information is in this compressed state made by the recurrent neural network. This is good: when you train it, the loss decreases and it performs quite well. However, when you use it in practice, the prompt you feed to the model may not be something the language model has seen in the past; we call this data out of distribution. The model may not know how to compress it well, what to keep and what not to keep, so the recurrent neural network will fail at its task of compressing the data, and because the information necessary to predict the next token was not compressed well, the attention layer will not be able to leverage it. So at training time we see that these hybrid architectures work really well, but at test time, when we actually use them, they don't work so well.

And this is one of the reasons: they learn to compress very well the data they have seen. They learn, okay, if I have a long piece of Python source code, I should not concentrate on, I don't know, some comments that are maybe repetitive, but on the code; or when I see some C# or C code, I should not concentrate on the parentheses, because they are somewhat redundant, but on the expressions, etc. So the model does learn to compress information, but only the kind of information it has seen at training time. Now, finally, we can talk about the paper.
So the paper's claim is this. These models need some kind of memory. In transformer models we have the KV cache, and the problem with the KV cache is that it grows, so it requires a lot of memory. Actually, the reason we cannot have a very big context window in current models is the inference cost: these models are really, really expensive to run, because we need to keep the KV cache, there is one KV cache per layer, and the bigger models have a lot of layers, so you need to keep all the past tokens for each layer of the model for every token that you want to predict. It's very expensive. The alternative to this infinite memory that keeps growing is a compressed memory, but as we saw, this compressed memory only works well on data similar to the training data. So the claim is: can we have a memory module that is trained at test time, and that's why the paper is called "Learning to Memorize at Test Time", and that is effective at retrieval? Because the goal of the memory is to retrieve the salient information the model needs, and we want it to be effective at retrieving the information it is being fed exactly at test time, not only the information it has seen at training time. This is the problem we are trying to solve with Titans.

Now, the way they do it is as follows. They say: imagine we have a module that we will call M, and let's think of it as a layer in a model. Okay, let me draw it, I think it's much easier if we draw it, so let's add a new page. Imagine we have a very long sequence. We have seen that with the recurrent neural network, the job of the RNN is to compress this very long sequence so that the transformer can use it. Let's now do it with Titans, see how it differs, and then we will check all the details. We have this input, we transform it into embeddings, and then, and I will draw it a little differently and explain why later, suppose we have a hybrid architecture again of transformer and recurrent layers, but I will not draw the recurrent layers. So this is the first layer with attention, let's call it L1, then the second layer with attention, the third layer with attention, and then we have the output, which is the logits. Okay, I think now it's more visible, right? Now imagine we have another module in this architecture that we will call the memory module, or neural memory, because this is how they call it here, and I will draw it as an external module.

I want to show you how it would work with the neural memory, and then we will check the details of how it is actually trained. Okay, let's take a step back: how would we train a hybrid model like the one before? We would feed it a sequence, imagine a very big one, say 1 million tokens. You convert this sequence of 1 million tokens into embeddings, you run these embeddings through the recurrent neural network, which will compress the 1 million tokens into, let's say, 1,000 tokens, because its goal is to compress stuff, right? The problem with attention is that it's quadratic, so having a smaller input results in a cheaper computation.
So we feed these 1,000 compressed tokens to the attention, and we force it to predict the next token leveraging only these 1,000 compressed tokens. We feed in 1 million tokens, but we force the attention layer to predict the next token using much less information, hoping that the recurrent neural network is good at choosing the right tokens to keep and discarding the rest. Okay, it's not really a token pruning mechanism, it's a token compression mechanism, but you can think of it as token pruning: it's fed 1 million tokens and it keeps only the 1,000 that are most important for predicting the next token. And this is done at training time: we feed the 1 million tokens, we compute the output, we know what the next token should be (because at training time we know the next token), we compute the loss with respect to it, we backpropagate to update the parameters of the model, and we keep doing this for all the sequences we have.

With Titans it works differently. Imagine you have 1 million tokens again, and you do two steps. We have this input, we convert it into embeddings, and then, inside the training loop (imagine we are training this Titans architecture), we first train this neural memory module to learn to memorize our 1 million tokens, and then we ask it to retrieve the information necessary for predicting the next token and feed it to the attention layers. So these are attention layers: this is an attention layer, this is an attention layer, and this is an attention layer. Look at the difference: before, we had an input, we predicted the output, we computed the loss, we backpropagated, and we updated all the parameters of the model. Here we do something different. We have an input, which is 1 million tokens, we convert it into embeddings, and we train this separate module, which the paper refers to as the inner loop of the training. We train this neural memory (later we will see how) with the sole purpose of learning everything about this data, so that it can easily retrieve this data when we need it. Then we take this neural memory, which has been trained to memorize the data, we ask it to retrieve whatever information is important from what it has seen, and we use that as input for the attention layers, so that the attention layers can leverage this compressed memory to produce the output and predict the next token. And this happens not only at training time but also at test time.

Compare this with the usual hybrid architectures, attention plus recurrent neural network, at test time, so at inference time. What we have is usually a prompt, and imagine this prompt is huge, because you are asking ChatGPT, for example, to analyze an entire, very big GitHub repository. What will happen is that these 1 million tokens will be fed to the recurrent neural network, which is now fixed, because we are just using the model and we are not changing its parameters anymore. The recurrent neural network's job is to compress data, so it will compress these tokens into a smaller sequence that is fed to the attention layers, and the model will produce the output logits.
However, maybe the information we are feeding to this recurrent neural network is out of distribution; the recurrent neural network has never seen anything like it, and it will probably do a very bad job at compressing this data. And because it does a bad job at compressing, because it doesn't know what to keep and what not to keep, the attention layer will not be able to leverage the most important information, it will not be able to predict the next token very well, and the result will be a bad output.

With Titans, even at test time, so even at inference time, we are actually training a model, and now I will show you how. Imagine again we have a GitHub repository, it's very big, and it results in 1 million tokens that we want the language model to analyze. We convert it into embeddings, then we take these 1 million tokens and train this neural memory on the fly, whose only job is to learn as much information as possible about these 1 million tokens and retrieve the most salient parts, because the neural memory's job is to compress information. After we have trained it in this inner loop, we retrieve this information and feed it to the attention layers, and the attention layers should be able to leverage the information retrieved by the neural memory. So with Titans we don't just have an RNN as our memory, trained at training time and never trained again, which goes crazy every time it sees something it has never seen; we have a neural memory that can be trained at inference time, on the fly, with the sole purpose of compressing stuff, and because we are training it at inference time, we hope it will perform better even on data it has never seen. According to the benchmarks published in the paper (and this happens in all papers, so never blindly trust benchmarks), it looks like it does a good job.

Now let's look at the details. I want to remind you of the problem we are solving: long context modeling. With the transformer, long context is very expensive at inference time. With RNNs, the problem is that we train them on some data, but when we use them on something they have never seen, they don't know how to compress it, what to keep and what not to keep, so they do this job poorly, the attention layers cannot leverage the information, and the result is a very bad output. With the neural memory, we want to train a memory on the fly, while inferencing the model, to do the job of compressing whatever data it is fed.

So let's look at the details. Here they do some preliminary overview of what memory is, what linear attention is, etc.; we don't care about that for now. They say: imagine we have a memory module that only has two operations, a write operation and a read operation, and we want to write and read to this memory at inference time and also at training time. First of all, this neural memory is a neural network by itself, meaning you can think of it as an external neural network, separate from the rest of the architecture that uses it; so you have something like a transformer model that is leveraging this neural memory. Now, how do we train this neural memory at inference time? Because that's our problem; at training time we know how to do it.
At training time we just feed the input, compute the output, backpropagate, and voilà. How to do it at inference time is what they show here.

Okay, another step back: what do we want this memory to do? We want this memory to learn to extract information from whatever it should memorize, and for that they use a very particular loss, which is essentially a reconstruction loss. Imagine we have an input x_t. We project it with two linear projections called W_K and W_V, which are basically the equivalents of the ones we use in the attention mechanism. The memory can only do its job well if it learns to recreate the data it has seen, and this is the loss you see here, which is just an L2 loss: the memory learns to memorize the mapping between a projection called the key and a projection called the value of the same data. So it kind of learns to recreate the data: if I put some stuff in, I should be able to retrieve the same stuff, I should be able to get back as much as possible of what I put inside.

How do we train it? They say: I have this memory, and I want to update it using a kind of gradient descent. How does gradient descent work in its basic version? We have a neural network with some parameters, let's call them theta, and the parameters theta at step i of the training are the parameters at the previous step minus a learning rate, which we will call gamma, multiplied by the gradient of the loss with respect to the parameters of the model. The gradient tells us how we should change the parameters to maximize the loss, but we move against the direction of this gradient, and that's why there is a minus sign: we update the parameters in the direction opposite to the one that would maximize the loss, so that we reduce the loss.

This is exactly what we do here: we want to update our memory in such a way that we minimize this loss, the memorization loss, the reconstruction loss we saw before, a loss that says: if I ask the memory to retrieve some information, given the key projection of the data, it should recreate that data. And this memory, in the paper, is modeled as a linear layer. A linear layer is just a matrix multiplication with a weight matrix, so this memory module M is nothing more than the weight matrix of a linear layer. The neural memory is just a matrix W, and we are modifying this W so that it reduces the reconstruction loss of the data, in exactly the way we train a neural network: we train a neural network's parameters to reduce a loss, and in the same way we are updating this W matrix, which will be our memory, so that it loses as little information as possible, because that is the loss we are optimizing it against, the reconstruction loss.
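Written out, the objective he is pointing at and the plain gradient-descent rule look roughly like this (my transcription of the paper's notation, with M the memory and W_K, W_V the key/value projections):

```latex
% Key/value projections of the current input x_t
k_t = x_t W_K, \qquad v_t = x_t W_V

% Reconstruction ("memorization") loss of the memory M on x_t
\ell(M_{t-1}; x_t) = \left\lVert M_{t-1}(k_t) - v_t \right\rVert_2^2

% Plain gradient descent on generic parameters \theta with learning rate \gamma
\theta_i = \theta_{i-1} - \gamma \, \nabla_{\theta} \mathcal{L}(\theta_{i-1})
```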
And they call this the surprise: the gradient of the loss with respect to the W matrix, which is our memory, is called the surprise, because the bigger the loss, the more difficulty the model had in reconstructing this data, which means the model is surprised to see it. That's why they call it surprise.
If you have ever studied how optimizers work, you will remember that in deep learning we have this thing called momentum, and usually we don't update the parameters of the model naively like this.
First of all, the loss is computed with mini-batch gradient descent, which means we don't compute it over the whole dataset but over small batches of data, so the direction of this gradient is actually stochastic: it is not the true direction of the gradient, which means that it oscillates.
So the gradient does not indicate the true direction at each step: imagine the true direction of the gradient is here, but on the first batch it may point in this direction, on the next batch in this direction, on the next batch in this other direction, etc. On average it points in the correct direction, but each step is noisy. Because we don't want to take each training step too confidently, we add this momentum term, and the momentum term basically creates an exponentially moving average of all the gradients, so that we keep some information about the past gradients we have computed and smooth out the change of the weights, instead of trusting any single step too much.

Their idea for introducing the momentum into the surprise is the same. They say: if I train my memory to recreate the data using only the current gradient, it can miss data that comes right after some novel data, because the gradient kind of dies down after a while, so there may be new data that the model should memorize but will miss. To avoid this, they use momentum, just like we do in normal model training, and they call it the past surprise. The past surprise is nothing more than the past-gradient term in the optimizers we use, for example the Adam optimizer, while the momentary surprise is the gradient with respect to the current input.

So, to rehearse what we have said so far: we have this memory, which is just a W matrix, and we want to change this W continuously, with every token we receive, so that it encapsulates all the information in the input. How do we know it captures all the information? Because we ask it to minimize the reconstruction loss of the input. Now, we don't want to do this training of the neural memory only during training; we also want to do it at inference time, because if we do it only during training, then at inference time, every time the memory sees information it has never seen, it will probably compress it badly and it will not work.

So how do we do that at inference time? Practically, it works as follows (let me paste the formulas here so we can refer to them, this one, and the loss, this one). Imagine, actually, that we want to generate a lot of tokens and we start with one token only, so the prompt is a single token. We have this one token, we feed it to the model as input, it gets converted into an embedding, a single embedding, and we want to train our neural memory on this one single token, so it should learn to recreate this one single token. How do we do that in practice? First of all, we take this one embedding and project it into a key and a value by multiplying this single token with a matrix called W_K and another called W_V. Then we compute the retrieval of the memory, and because the memory is modeled as just the W matrix of a linear layer, the retrieval of information from this memory is just W multiplied by the input.
And the input here, actually, they call q_t: it's another projection of the input, through the W_Q matrix. So k_t comes from W_K multiplied by x, v_t comes from W_V multiplied by x, and q_t comes from W_Q multiplied by x. This W here is the W of the memory: it is the parameters of the memory, but it is also the memory itself, and this is the W we want to update. So how do we do that? We project the information of the single token with W_V, we project it with W_K, we compute this term here, which is just W multiplied by the key, we compute this loss, and we compute its gradient. The gradient of this loss can be computed with the formula they give in the paper, and I can also show you how to derive it. They compute the gradient of this loss with respect to the parameters of the model, and the parameters of the model here are just W.
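Since the memory is a plain linear layer, the derivation is short. Treating k_t and v_t as column vectors and writing the memory as the matrix W, it is:

```latex
\ell(W; x_t) = \lVert W k_t - v_t \rVert_2^2
\quad\Longrightarrow\quad
\nabla_W \,\ell(W; x_t) = 2\,(W k_t - v_t)\,k_t^{\top}
```

In other words, the surprise is just the outer product of the reconstruction error with the key.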
So we have the gradient of this loss with respect to W, and then we need to update W. How do we update W? We need to compute this S_t term here. S_t combines the past surprise, which we don't have yet for the very first token, so let's suppose it is zero for now, scaled by its coefficient, with the gradient we have just computed, scaled by the learning rate theta_t. Then we update W using this term S_t, and with that we have updated our memory.
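Putting the pieces together, the update he is describing reads roughly as follows, as far as I can reconstruct the paper's notation (eta_t weights the past surprise, theta_t is the learning rate; I am keeping only the terms discussed here):

```latex
S_t = \eta_t \, S_{t-1} \;-\; \theta_t \, \nabla_W \,\ell(M_{t-1}; x_t)
\qquad\text{(past surprise + momentary surprise)}

M_t = M_{t-1} + S_t
\qquad\text{(apply the surprise to the memory } W\text{)}
```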
The next step is to retrieve information from this memory. How do we retrieve it? We take x, our single token, we project it with another matrix called W_Q so that it becomes q_t, and we multiply it by W: that is the retrieval. This retrieved information is then sent to the first layer of the model as compressed past information, then to the second, the third, etc., to predict the output. The model will produce the first output token, and then, as usual, we put this output token back into the prompt to generate the next token. But here, because we are not talking about just a transformer model, we are talking about a hybrid architecture with attention layers plus a neural memory, we need to update our neural memory with this new incoming token. So this new incoming token will again be used to update the memory: the memory will be updated with the information of the new token, not replaced by it, so we hope that the new memory will encapsulate information about both the first token we fed before and the current token.

Practically, we take this new token that was output by the model, we project it through W_V and it becomes v_t, we project it through W_K and it becomes k_t, we compute the loss term, we compute the gradient of this loss, and we update our neural memory like before, but this time we do have a past surprise, and we also have the previous memory. So we update this W, and hopefully it will now contain information about token number two and token number one that we fed before.

As you can see, because we are training the neural memory at test time, while we are inferencing the model, we hope it will perform better than a neural memory that has only been trained at training time, because at each update step the neural memory is trying to minimize the loss on this particular data, not just on the data it saw during training, but exactly on the data it is seeing at this moment. I know I fed you a lot of information, but I hope it is now a little clearer what it practically means to have an inner loop and an outer loop: when we train the model, we update the parameters of the big model to leverage whatever the memory produces, and the memory does not learn to compress information only at training time, but also at inference time, exactly on the data you feed it at inference time.
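Here is a small NumPy sketch of that inner loop as I understand it: a linear memory W updated token by token with a momentum-smoothed surprise, then queried for retrieval. The projection matrices, dimensions, and the constant eta and theta values are made-up placeholders (in the actual model they come from the outer loop and are not fixed constants), so treat this as an illustration of the mechanism, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16                                   # embedding dimension (illustrative)

# Projection matrices; in the real model these belong to the outer loop and are learned.
W_K, W_V, W_Q = (rng.normal(size=(d, d)) * 0.1 for _ in range(3))

W = np.zeros((d, d))                     # the neural memory itself: just a weight matrix
S = np.zeros((d, d))                     # past surprise (momentum of the inner loop)
eta, theta = 0.9, 0.1                    # placeholder coefficients

def memory_step(x):
    """Update the memory on one token embedding x, then return the retrieved value."""
    global W, S
    k, v = W_K @ x, W_V @ x              # key/value projections of the current token
    grad = 2.0 * np.outer(W @ k - v, k)  # gradient of ||W k - v||^2 w.r.t. W (the surprise)
    S = eta * S - theta * grad           # past surprise + momentary surprise
    W = W + S                            # write: apply the surprise to the memory
    q = W_Q @ x
    return W @ q                         # read: retrieve compressed past for the attention layers

for x_t in rng.normal(size=(5, d)):      # pretend these are 5 incoming token embeddings
    retrieved = memory_step(x_t)         # this is what gets fed to the attention layers
```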
Now let's talk about the problems of this memory. As you can see, we need to run a step of gradient descent for every single token, and that looks slow: you have a very long list of tokens, you want to process it as fast as possible, but if you have to update the memory one token at a time, it's very slow. Fortunately, in the paper they also propose an algorithm to parallelize this training. It cannot be parallelized over the full sequence, but it can be parallelized chunk by chunk, which is still much better than one token at a time. Imagine you have one million tokens: if we cannot parallelize at all, it means we first take the first token and update the memory, then the second token and update the memory, then the third, and so on, one million times, and we cannot exploit our GPUs because we do one operation at a time. What they propose in the paper is a hybrid algorithm: it's not fully parallelizable over the entire sequence, but it is within each chunk, which is a good compromise. If you have one million tokens and you choose a chunk size of, say, 1,000, you can process the first 1,000 tokens in parallel, then take the next 1,000 tokens and process those in parallel, so in total you do 1,000 sequential steps instead of one million.
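A very rough sketch of that chunk-by-chunk structure is below. The `update_memory_on_chunk` function is a hypothetical stand-in for the paper's batched inner-loop update; I am only showing the outer loop, which is where the one-step-per-chunk saving comes from.

```python
# Outer structure of chunk-wise processing: sequential over chunks,
# (potentially) parallel inside each chunk.
def process_sequence(embeddings, memory_state, update_memory_on_chunk, chunk_size=1000):
    for start in range(0, len(embeddings), chunk_size):
        chunk = embeddings[start:start + chunk_size]
        # One sequential "step" per chunk instead of one per token:
        # 1,000,000 tokens with chunk_size=1000 -> 1000 steps.
        memory_state = update_memory_on_chunk(memory_state, chunk)
    return memory_state
```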
They also describe how to use this neural memory module. One way is as a contextual memory: if you have a hybrid architecture with attention and this neural memory, like the one we drew before, you take the sequence input by the user, you retrieve whatever is in the memory (because the neural memory's job is just to compress information), and you prepend it to the sequence, along with some other persistent memory tokens. Actually, we can mostly skip the persistent memory tokens, because I believe they overdid that part; I think the system could work even without them. So we take our sequence, we prepend whatever information is in the memory, we feed it to the attention module, and we use the output of the attention to update the memory and to produce the output.

Going back to our architecture, it would work like this: imagine we have already fed 10 tokens to this memory and now we are trying to predict the 11th token. I take this 11th token, convert it into an embedding, and retrieve whatever is inside the neural memory. Because its job is compressing, even if I fed it 10 tokens, it doesn't have to return 10 tokens; it returns a compressed version of those 10 tokens, suppose 5 tokens. So I take these 5 tokens, prepend them to my single token so it becomes 6 tokens, feed that to the first attention layer, take the output of the attention, use it to update the memory, and combine the memory's output with the attention's output to get the output of this layer and feed it to the next one. This is the "memory as context" usage.

The other usage is "memory as gate", which is this architecture here. In this case we take our 11th token (again, don't worry about persistent memory, I believe it's overkill; you don't need it to make the mechanism work): they put it into the memory, so the memory is updated first, and they also feed it to the attention, and then they combine the output of the neural memory, which has seen all 11 tokens but, when retrieved, only gives us the compressed version, with the output of the attention, which was fed only the one token, and the combination produces the output.

Or you can use the memory as a module without any attention at all, which means you basically skip all of that part: you take your input, which could be 1 token or 1 million tokens, you update the memory continuously, you take the compressed version of the memory, and you feed it directly to the linear layer that produces the logits. This is what they refer to as "memory as layer". Honestly, you can create a million variants of this architecture; the point is not how you use it, the point is how it works, so I want to emphasize how it works: we are training a module at test time, which is different from what we do with recurrent neural networks. Recurrent neural networks are trained at training time, and their job is to compress data, but because they only do well at compressing the data they have seen, they may not work well during inference, when they see data they have never seen. By having a memory like this, which you can train at inference time with an algorithm that is supposedly parallelizable, we can hopefully avoid this problem, because the only job of the memory is to be able to retrieve.

I actually like this paper, because I believe it's a novel idea I hadn't thought about before. Actually, I've been researching this area for a while, it's called test-time training, and this particular architecture is a little bit innovative within that field. What else do we need to know to read this paper? I think you should now have the information to read it, because we have talked about how to update this memory and what this memory is: just a linear layer. In the paper they also say the memory doesn't have to be a linear layer; it can be a multilayer perceptron, for example two layers with an activation in between, and it works in the same way, and the parallelizable algorithm they devised also works with this multi-layer memory. We didn't really talk about persistent memory, but the persistent memory is just a set of tokens prepended to the input; they don't belong to the neural memory, they belong to what they call the outer loop, the outer loop being this big model here and this being the inner loop. But this system can work without persistent tokens, that's my claim. If you look at the benchmarks, it seems that compared to other architectures like Mamba and recurrent neural networks it performs better, if you check the average score over all these benchmarks. I believe this is a promising area of research, and I will be looking forward to the code, which has not been released yet. Thank you guys for spending time with me, I hope I gave you at least some intuition about how this works, and I'm really eager to look at the code, because I think the best way to learn about a new architecture is to look at the code. So have a good night!