
Titans: Learning to Memorize at Test Time


Whisper Transcript

00:00:00.000 | Okay guys, so today we are going to talk about this paper, "Titans: Learning to Memorize
00:00:06.040 | at Test Time".
00:00:07.040 | In this video, we will first see what problem the paper is trying to solve,
00:00:11.680 | then what solution is proposed, and then we will comment on the pros
00:00:15.800 | and the cons.
00:00:18.000 | The way I like to talk about papers is to give you the tools to understand the paper
00:00:23.400 | yourself. I don't like to just read the paper word by word, because that's something
00:00:26.680 | you can do by yourself; I like to talk about the background knowledge that
00:00:29.760 | you need.
00:00:30.760 | We work a little bit on that, then we look at the problem, and then we see the solution.
00:00:34.360 | So let's talk about the problem.
00:00:35.920 | The problem we are talking about here is sequence modeling, and up to now there are two main
00:00:40.800 | ways in deep learning to do sequence modeling.
00:00:43.120 | One is the transformer, and the other is the recurrent neural network.
00:00:48.560 | There are also hybrid variants, which combine the attention mechanism with
00:00:52.080 | recurrent neural networks.
00:00:54.480 | So let's talk about how sequence modeling is done in these two ways.
00:00:59.040 | Let me open a new page. Imagine you have a very long sequence; let's talk
00:01:05.600 | about language modeling, which is something we are all familiar with.
00:01:09.180 | So imagine we want to train a language model.
00:01:11.280 | How does the training of a language model work?
00:01:13.640 | Usually we have a sequence, and we want to teach the language model to predict the next
00:01:17.840 | token.
00:01:19.040 | So we have a sequence of tokens; let's say this is our sequence of tokens.
00:01:24.560 | The first token is, let's say, "I", and the second token is "like". I always pretend
00:01:32.600 | that one token is a word and one word is a token, which is not actually the case, but
00:01:36.740 | for simplicity we will treat it that way: "I like to eat", let's just say, "pizza", okay.
00:01:48.600 | Imagine we want to train a language model to generate this exact phrase.
00:01:53.240 | What we need is some kind of model, which could be a transformer but
00:01:57.740 | could also be a recurrent neural network, and we force it to predict the next token.
00:02:02.120 | So sequence modeling means that we have some input sequence, which we will call
00:02:07.920 | the input, and we are trying to map it to some output, which is this one, this one,
00:02:19.680 | this one, this one.
00:02:23.360 | The model that we train for language modeling is usually called
00:02:29.040 | an autoregressive language model, which means that when it makes its prediction,
00:02:34.920 | it can use all the past words to predict the next word. Imagine
00:02:40.680 | the model should be able to generate
00:02:47.280 | exactly this sentence: whenever it's prompted with the word "I",
00:02:53.600 | it should output the word "like".
00:02:58.480 | Whenever it's prompted with "I like", it should predict the word "to".
00:03:03.160 | And whenever it's prompted with "I like to", it should predict "eat".
00:03:11.380 | And so on; as you can see, there is a pattern here. Whenever it's prompted with
00:03:17.320 | the whole sentence, it should output "end of sentence", a special token that
00:03:22.640 | says, okay, I'm done with the generation process.
00:03:26.120 | This is how we train language models.
00:03:27.540 | So we take some sentence, which could be a document, which could be a web page, anything,
00:03:31.880 | we shift the words by one position, and we force the language model to predict the next
00:03:36.040 | token.
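A minimal sketch of the shift-by-one step just described, treating each word as one token as the speaker does. The function name `make_next_token_pairs` and the `"<eos>"` marker are illustrative placeholders, not names from the paper or any library:

```python
def make_next_token_pairs(tokens, eos="<eos>"):
    """Pair each input token with the token the model must predict next.

    The targets are the same sequence shifted left by one position, with a
    hypothetical end-of-sentence marker appended at the end.
    """
    targets = tokens[1:] + [eos]
    return list(zip(tokens, targets))

sentence = ["I", "like", "to", "eat", "pizza"]   # one word = one token, for simplicity
for position, (_, target) in enumerate(make_next_token_pairs(sentence)):
    prefix = " ".join(sentence[: position + 1])
    print(f"{prefix:<20} -> {target}")
```

Each printed row is one training example: the prefix the model is allowed to see, and the single token it is forced to predict.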
00:03:37.040 | And there are two principal models to do that.
00:03:39.600 | One is called the transformer.
00:03:40.720 | So let's say that in between here we have something called the transformer.
00:03:47.800 | Through the attention mechanism, the transformer allows us to do this language modeling
00:03:52.360 | in such a way that the output of the language
00:03:56.720 | model, which is used to compute the loss on which the language model is trained, can be
00:04:01.280 | computed in parallel.
00:04:04.280 | This is also the reason most language models today are transformer-based:
00:04:09.840 | we want to leverage GPUs, so if we can parallelize some operations, it's better.
00:04:15.700 | On the other hand, we also have recurrent neural networks.
00:04:19.640 | Later we will see what the problems with transformers and recurrent neural networks are;
00:04:23.680 | for now we just look at this one.
00:04:25.340 | So the transformer can be parallelized.
00:04:27.920 | So this one is parallelizable.
00:04:36.280 | And then we have another paradigm.
00:04:37.920 | By the way, this is called the target sequence.
00:04:44.340 | When you train a model, this one is the target sequence.
00:04:48.640 | You compare the actual output of the transformer with what you want the transformer
00:04:53.320 | to output, which is the target, you compute the loss, and then, based
00:05:00.560 | on the gradient, you backpropagate to update the parameters of the model.
00:05:03.720 | So the model is forced to learn to generate the target given the input.
00:05:08.480 | This is how we train models.
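The loss in this training step is usually the cross-entropy between the model's scores (logits) and the target token at each position. A NumPy sketch, assuming logits of shape (sequence length, vocabulary size); the backpropagation itself would be done by an autodiff framework and is not shown:

```python
import numpy as np

def next_token_loss(logits, targets):
    """Mean cross-entropy between predicted logits and target token ids.

    logits:  (seq_len, vocab_size) raw scores, one row per input position
    targets: (seq_len,) integer id of the correct next token at each position
    """
    # Log-softmax, shifted by the row max for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability assigned to each target token and average.
    picked = log_probs[np.arange(len(targets)), targets]
    return -picked.mean()

rng = np.random.default_rng(0)
vocab_size, seq_len = 10, 5
logits = rng.normal(size=(seq_len, vocab_size))
targets = np.array([1, 2, 3, 4, 0])   # hypothetical ids for "like", "to", "eat", "pizza", <eos>
print(next_token_loss(logits, targets))
```

A model that spreads probability uniformly gets a loss of log(vocab_size); a model that is confidently right gets a loss near zero, which is what gradient descent pushes toward.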
00:05:12.100 | We can take this setup and replace the transformer with a recurrent neural network.
00:05:17.600 | And the problem with the recurrent neural network is that it's not parallelizable.
00:05:22.360 | At least not in its simple form.
00:05:25.000 | Recently there are recurrent neural networks that, by exploiting the parallel
00:05:31.520 | scan, can actually be parallelized, but so far they are not widely used in practice.
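The idea behind the parallel scan can be sketched on the simplest linear recurrence, h_t = a_t * h_{t-1} + b_t. Composing two such affine steps is associative, so all prefix results can be computed in about log2(n) rounds, and the updates inside each round are independent of each other. This is a generic illustration, not the kernel of any specific architecture:

```python
import numpy as np

def sequential_recurrence(a, b):
    """h_t = a_t * h_{t-1} + b_t with h_0 = 0, computed one step at a time."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.array(out)

def combine(left, right):
    """Compose two affine steps: apply `left` first, then `right`.
    This operator is associative, which is what makes a parallel scan possible."""
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2

def scan_recurrence(a, b):
    """Same result via a Hillis-Steele inclusive scan: log2(n) rounds, and the
    updates within one round are independent, so they could run in parallel."""
    A, B = np.array(a, dtype=float), np.array(b, dtype=float)
    step = 1
    while step < len(A):
        # Every position i >= step combines with position i - step (vectorized).
        newA, newB = combine((A[:-step], B[:-step]), (A[step:], B[step:]))
        A = np.concatenate([A[:step], newA])
        B = np.concatenate([B[:step], newB])
        step *= 2
    return B   # with h_0 = 0, h_t equals the accumulated offset B_t

rng = np.random.default_rng(0)
a, b = rng.uniform(0.5, 1.0, size=16), rng.normal(size=16)
print(np.allclose(sequential_recurrence(a, b), scan_recurrence(a, b)))  # True
```

This only works because the update is linear in h; a classic RNN that wraps the update in a tanh does not decompose this way and stays a sequential loop.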
00:05:39.640 | So recurrent neural networks.
00:05:44.780 | So how do recurrent neural networks work?
00:05:47.480 | I will not be talking about the transformer's attention mechanism; I suppose you already
00:05:51.480 | know it, and it's not even important here. You just need to remember that the transformer
00:05:55.040 | is parallelizable, and the recurrent neural network in its basic form is not parallelizable.
00:06:00.040 | So the recurrent neural networks work as follows.
00:06:03.440 | When you want to train them, or even when you want to run inference with them,
00:06:08.120 | because we are doing sequence modeling, imagine we want to train a language model
00:06:11.320 | to learn this exact sentence here, "I like to eat pizza". The way
00:06:18.600 | we train them is as follows: we take the first token, the word "I", and feed it to
00:06:25.600 | the recurrent neural network, let's call it RNN. The recurrent neural network
00:06:36.800 | will produce an output, which is something we don't know, but we force it to learn the
00:06:44.040 | target, and the target we want is: when it sees "I", it should predict "like", right?
00:06:52.640 | So it should predict "like". Based on what it actually produces and what the target is,
00:06:59.000 | we compute the loss, and we backpropagate.
00:07:01.640 | Then we take the last output: the recurrent neural network not only produces the output
00:07:08.160 | token, it also produces a state, which encapsulates information about all the input the model
00:07:14.920 | has seen so far.
00:07:16.040 | This is called the hidden state of the recurrent neural network, or also the memory of the
00:07:19.720 | recurrent neural network.
00:07:23.980 | So we use this state to feed another token to the RNN. Let me put the input
00:07:32.120 | below, actually; I think it's easier to visualize this way. So the input here was the word "I".
00:07:38.680 | This will produce a new hidden state; let's call it the hidden state
00:07:43.440 | at time step one.
00:07:45.440 | We feed it again to the recurrent neural network along with a new input, the next token,
00:07:52.440 | "like", and the recurrent neural network will predict something, but we force it to learn
00:07:58.240 | to predict the word "to" for the second time step.
00:08:04.260 | How can it predict the word "to" given just the word "like"? By leveraging the recurrent
00:08:10.000 | state from the previous time step, which encapsulates all the information about the word "I".
00:08:14.480 | That's how it can predict "to": the recurrent neural network is actually seeing "I like",
00:08:21.320 | "like" directly as the input, and "I" indirectly, because it's in its hidden state.
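A toy vanilla RNN forward pass makes this recurrence concrete: the hidden state h has the same size at every step, and the prediction made after "to" depends on "I" only through h. The sizes and random weights are arbitrary assumptions; a trained model would learn the weights by backpropagating the next-token loss:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ["I", "like", "to", "eat", "pizza"]
d_embed, d_hidden = 8, 16            # toy sizes, chosen arbitrarily

# Randomly initialised weights of a vanilla (Elman) RNN.
E   = rng.normal(scale=0.1, size=(len(vocab), d_embed))   # token embeddings
W_x = rng.normal(scale=0.1, size=(d_hidden, d_embed))     # input -> hidden
W_h = rng.normal(scale=0.1, size=(d_hidden, d_hidden))    # hidden -> hidden
W_o = rng.normal(scale=0.1, size=(len(vocab), d_hidden))  # hidden -> logits

def rnn_forward(token_ids):
    h = np.zeros(d_hidden)             # h_0: empty memory
    all_logits, states = [], []
    for t in token_ids:                # strictly one step at a time
        # The new state mixes the current token with everything seen so far.
        h = np.tanh(W_x @ E[t] + W_h @ h)
        all_logits.append(W_o @ h)     # scores for the next token
        states.append(h)
    return np.array(all_logits), states

logits, states = rnn_forward([0, 1, 2])  # "I like to"
print(logits.shape, [s.shape for s in states])
```

The loop is the whole point: step t cannot start until step t-1 has produced its hidden state.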
00:08:27.480 | Now we can do it also for the third token, "to", with hidden state h2; this is again
00:08:36.120 | the recurrent neural network, and we feed it the token "to", which will produce some output.
00:08:40.800 | We don't know what it is, but we force it to learn to predict the word "eat".
00:08:48.400 | And how can the model learn to predict "eat"?
00:08:51.000 | Because it can see that the input is "to", but it can also see the history of the input so
00:08:55.080 | far through the hidden state h2.
00:08:58.760 | Now, what is the problem here?
00:09:00.480 | When we use the transformer to predict a particular token, for example
00:09:06.280 | the token "pizza", it can leverage all the previous inputs, because they are all fed at the same time
00:09:13.200 | to the transformer.
00:09:15.360 | During training, the keys and values are computed from this whole input, and during inference,
00:09:21.160 | they are stored in what is called the KV cache.
00:09:25.720 | So the transformer, in order to predict a particular token, can always see the entire preceding
00:09:30.200 | sequence, and that's why it's parallelizable.
00:09:32.240 | We feed the entire sequence to the transformer, and because each position only needs the tokens
00:09:36.680 | before it (a causal mask hides the future ones), the transformer can compute the
00:09:41.640 | output at each position in parallel.
00:09:44.240 | However, with the recurrent neural network, we cannot compute the output at each position
00:09:49.960 | in parallel; we have to do it one step at a time.
00:09:53.360 | So it is not parallelizable.
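To see why the transformer's outputs can be computed in parallel, here is a small NumPy check with toy sizes, where random Q, K, V stand in for the projected tokens: one masked matrix computation gives the same per-position outputs as computing each position on its own, because no position's output depends on another position's output. An RNN has no such shortcut, since h_t needs h_{t-1}:

```python
import numpy as np

def causal_attention(Q, K, V):
    """All positions at once: one matrix product plus a causal mask."""
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                       # (n, n) similarities
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)    # True above diagonal = future
    scores = np.where(mask, -np.inf, scores)            # hide future tokens
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)       # row-wise softmax
    return weights @ V

def attention_one_position(Q, K, V, t):
    """The same output for position t alone, looking only at tokens 0..t."""
    d = Q.shape[1]
    s = K[: t + 1] @ Q[t] / np.sqrt(d)
    w = np.exp(s - s.max())
    w /= w.sum()
    return w @ V[: t + 1]

rng = np.random.default_rng(0)
n, d = 5, 4
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))
parallel = causal_attention(Q, K, V)
looped = np.stack([attention_one_position(Q, K, V, t) for t in range(n)])
print(np.allclose(parallel, looped))  # True
```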
00:09:56.480 | The advantage of the transformer is that it is parallelizable, so we can train massive
00:10:00.600 | models by just increasing the number of GPUs.
00:10:03.520 | The problem with the recurrent neural network is that, because it's not parallelizable, we are limited:
00:10:07.360 | we have to use a kind of for loop to train it. First we generate the first
00:10:12.760 | output, then the second one, then the third one, et cetera,
00:10:16.180 | and then we backpropagate.
00:10:18.920 | And then there are other problems, like vanishing gradients,
00:10:22.880 | but that's not the main point today.
00:10:25.600 | So the recurrent neural network actually has two problems.
00:10:29.160 | First of all, it's not parallelizable.
00:10:34.400 | And the second problem is that this recurrent state, also called the hidden
00:10:40.240 | state of the recurrent neural network, is fixed in size.
00:10:45.320 | It can be as big as you want, one megabyte, one gigabyte, whatever you like,
00:10:50.120 | but it's fixed.
00:10:51.120 | Once you have chosen your architecture, it's fixed.
00:10:53.680 | On the other hand, when we use a transformer model, the size of the input that the language
00:11:00.140 | model sees keeps growing. Why?
00:11:04.000 | Take, for example, a prompt on ChatGPT. Imagine
00:11:08.600 | ChatGPT was trained exactly on this sentence here, and suppose you only
00:11:13.680 | feed the first token, "I". ChatGPT will predict the next token using
00:11:19.200 | only "I", then it will take the word "like", put it back into the input, and feed everything again
00:11:24.000 | to the transformer, so "I like", and the transformer will predict the next token.
00:11:29.280 | Then it will take the word "to" and put it back into the input, so all
00:11:35.000 | three tokens, "I like to", go into the transformer model, and then it will be able
00:11:39.000 | to predict the next token, et cetera.
00:11:41.440 | So the memory of the transformer, the stuff that we feed to
00:11:46.200 | the transformer in order to predict the next token, is actually growing, and this is
00:11:52.240 | another problem.
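Some rough arithmetic for this growing memory: a KV cache stores one key vector and one value vector per token, per KV head, per layer, so it grows linearly with the context, while a recurrent state stays the same size. The model shape below (32 layers, 8 KV heads, head dimension 128, 2-byte values, a 4096-wide state per layer) is a hypothetical example, not any specific model:

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    """Memory for the KV cache of one sequence: a key and a value vector per
    token, per KV head, per layer (bytes_per_value=2 assumes fp16/bf16)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_value

def rnn_state_bytes(n_layers, d_state, bytes_per_value=2):
    """A recurrent state has the same size no matter how many tokens were read."""
    return n_layers * d_state * bytes_per_value

# Hypothetical model shape, chosen only for illustration.
cfg = dict(n_layers=32, n_kv_heads=8, head_dim=128)
for seq_len in (1_000, 100_000, 1_000_000):
    gib = kv_cache_bytes(seq_len, **cfg) / 2**30
    print(f"{seq_len:>9} tokens: KV cache ~ {gib:8.2f} GiB, "
          f"RNN state {rnn_state_bytes(32, 4096) / 2**20:.2f} MiB")
```

With this shape every extra token costs 128 KiB of cache, so a million-token context needs on the order of a hundred GiB per sequence, while the recurrent state never grows.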
00:11:53.240 | So when doing very long sequence modeling, we need two things.
00:11:57.360 | First of all, we would like the language model to be able to use all the input it
00:12:03.780 | has seen so far, and that's something we can easily do with a transformer.
00:12:09.040 | However, the problem with the transformer is that we have a growing memory, because we need
00:12:12.480 | to always put all the input, all the tokens, in the transformer for it
00:12:17.200 | to see all the input.
00:12:19.720 | Or, if we want a limited memory, we can use a recurrent neural network, but it is not
00:12:24.200 | parallelizable during training, and the second problem is that it has a fixed memory.
00:12:29.720 | The fixed memory has another problem: because it's fixed, we cannot choose what
00:12:34.480 | is inside, so sometimes the language model keeps some information, and sometimes it
00:12:39.720 | will not be able to keep it. It's like taking one person and asking
00:12:46.480 | them to memorize 3,000 books: I don't think the person will be able to do it, because
00:12:51.520 | our brain is fixed in size, and recurrent neural networks have the same problem.
00:12:58.480 | Moreover, we have seen many architectures that try to improve the memorization
00:13:06.840 | capability of recurrent neural networks, for example Mamba, which uses a particular
00:13:11.840 | form of the matrix called the HiPPO matrix, which allows it to memorize information more
00:13:17.440 | effectively. However, in practice they don't work as well as we would hope.
00:13:23.560 | Now, in this paper they say, imagine, okay, first of all, before we can talk about this paper:
00:13:31.000 | how do we train language models?
00:13:32.560 | We train them as follows; I mean, now let's talk about it at the architecture
00:13:37.040 | level.
00:13:39.140 | Usually we have some input tokens (let me draw them
00:13:46.680 | vertically, I think it's easier), and we convert them into embeddings.
00:13:59.860 | These embeddings are fed to a series of transformer layers, for example layer 1, layer
00:14:08.240 | 2, et cetera, until they produce some output, which is called the logits. Now, what
00:14:18.400 | happens with the transformer and with the recurrent neural network is as follows. With
00:14:23.520 | the transformer we have a growing memory: this thing called the KV cache
00:14:28.720 | contains all the past tokens, so the transformer can always leverage all the past tokens to
00:14:33.980 | predict the next token.
00:14:36.560 | In the recurrent neural network, we have a memory that compresses all the past tokens
00:14:42.840 | into a fixed-size memory, which however has its own problem: sometimes information
00:14:49.360 | is lost, because the memory is fixed and you're trying to squeeze a lot of stuff into it. We cannot
00:14:54.040 | decide what is inside; we just hope that the network learns to keep the most important
00:15:00.520 | information and forget the less important information.
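One concrete example of a fixed-size compressed memory is a linear-attention-style matrix memory, S = sum over i of v_i k_i^T (linear attention comes up again in the paper's preliminaries). The sketch below, with arbitrary sizes, illustrates the trade-off described here: the memory stays d x d no matter how many (key, value) pairs are written, recall is exact while the keys are few and orthogonal, and it degrades once far more pairs than dimensions are squeezed in:

```python
import numpy as np

def write_all(keys, values):
    """Compress any number of (key, value) pairs into one d_v x d_k matrix:
    S = sum_i v_i k_i^T  (a linear-attention-style memory)."""
    S = np.zeros((values.shape[1], keys.shape[1]))
    for k, v in zip(keys, values):
        S += np.outer(v, k)        # the memory size never changes
    return S

def read(S, query):
    return S @ query

rng = np.random.default_rng(0)
d = 8
# Case 1: as many pairs as dimensions, with orthonormal keys -> exact recall.
keys_few = np.linalg.qr(rng.normal(size=(d, d)))[0]   # orthogonal matrix, orthonormal rows
values_few = rng.normal(size=(d, d))
S_few = write_all(keys_few, values_few)
err_few = np.abs(read(S_few, keys_few[3]) - values_few[3]).max()

# Case 2: 64 pairs squeezed into the same 8x8 matrix -> interference between keys.
keys_many = rng.normal(size=(64, d))
keys_many /= np.linalg.norm(keys_many, axis=1, keepdims=True)
values_many = rng.normal(size=(64, d))
S_many = write_all(keys_many, values_many)
err_many = np.abs(read(S_many, keys_many[3]) - values_many[3]).max()

print(S_few.shape, S_many.shape)   # same size in both cases
print(f"recall error: {err_few:.2e} (8 pairs) vs {err_many:.2f} (64 pairs)")
```

Nothing in this toy memory decides what to keep: every pair is simply added in, so with many pairs the old ones blur together, which is the information loss the speaker describes.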
00:15:07.560 | The problem is that when we train a language model, we feed it a lot of data: for example, we
00:15:13.360 | train the language model on the entire Wikipedia, on the entire web, and on a lot of
00:15:18.960 | books, so the model has seen kind of all the possible data that exists in this world. Now
00:15:28.600 | imagine we have a hybrid model, so a transformer but
00:15:35.600 | also with a recurrent neural network. Suppose that this one here is an
00:15:42.400 | attention layer, so a transformer layer, let's call it attention, and this one is a recurrent
00:15:48.580 | neural network, and suppose that it is one of the new fancy recurrent networks that can
00:15:53.480 | actually be parallelized (there are new architectures that can be parallelized). Still,
00:16:00.480 | the problem is that the RNN will produce a memory that is fixed
00:16:06.280 | in size. So if you feed 1,000 tokens, it will output a memory, to
00:16:12.520 | be leveraged by the attention, that is not 1,000 tokens; it will be less, because
00:16:18.600 | the goal of the RNN is to compress stuff into some fixed-size memory that can be leveraged
00:16:26.140 | by the transformer, which is this attention layer here. The attention
00:16:31.720 | layer is very good at leveraging the data it is fed, but this data is
00:16:37.220 | not the whole sequence, because we have compressed it with the recurrent neural network, and
00:16:43.760 | we hope that the attention can leverage the information compressed by the recurrent
00:16:48.160 | neural network to do its job of predicting the next token. So imagine
00:16:55.040 | we have this hybrid architecture of attention plus recurrent neural
00:16:59.240 | network. When you train it, because we do it
00:17:06.840 | with deep learning, we force the model to learn whatever target we have: the recurrent neural network will be
00:17:12.000 | forced to learn to compress the information in such a way
00:17:17.920 | that the attention can use it, and the attention will be forced to extract whatever information
00:17:23.200 | is in this compressed state made by the recurrent neural network. This is good: when you train
00:17:30.080 | it, the loss decreases and you see that it performs quite well. However, when
00:17:34.520 | you use it in practice, the prompt that you feed to the model may not be something that
00:17:40.280 | the language model has seen in the past; we call this data out of distribution,
00:17:47.720 | so the model may not know how to compress it well, what to keep and what not to keep.
00:17:52.940 | In this case the recurrent neural network will fail at its task of compressing data,
00:17:58.640 | and because the data necessary to predict the next token was not compressed well, the
00:18:03.600 | attention layer will not be able to leverage it to predict the next token. So at
00:18:08.400 | training time we see that this hybrid architecture works really well, but at test time, when
00:18:13.480 | we use it, we actually see that it doesn't work as well, and this is one of the reasons:
00:18:18.480 | these models learn to compress very well the data they have seen. They know: okay, if I have
00:18:24.040 | a long Python source file, I should not concentrate on, I don't know, some comments
00:18:33.400 | that maybe are repetitive, but I should concentrate on the code; or, when I
00:18:37.480 | see some C# or C code, I should not concentrate on the parentheses, because they
00:18:45.780 | are, how to say, redundant, but I should concentrate on the expressions, et cetera.
00:18:53.420 | So the model actually learns to compress information, but only the kind of information
00:19:02.780 | that it has seen at training time. Now, finally, we can talk about the paper. The paper's
00:19:09.300 | claim is: we have these models that need some kind of memory. In transformer models
00:19:16.100 | we have the KV cache, and the problem with the KV cache is that it keeps growing, and a growing
00:19:21.700 | KV cache requires a lot of memory. In fact, what keeps
00:19:28.540 | the context window of current models from being very
00:19:35.660 | big is actually their inference cost: these models are really, really
00:19:41.900 | expensive to run, because we need to keep the KV cache, and there is one KV cache for
00:19:47.860 | each layer. Bigger models have a lot of layers, so you need to keep all the
00:19:51.460 | tokens for each layer of the model, for every token that you need to predict, so
00:19:58.060 | it's very expensive. The alternative to this memory that keeps growing without limit
00:20:03.300 | is a compressed memory, but this compressed memory only works well on the kind of data seen at training time.
00:20:08.100 | So the claim is: can we have a memory module that is trained at test time (that's why
00:20:15.860 | the title talks about learning to memorize at test time) and that is effective at retrieval?
00:20:22.900 | The goal of the memory is to retrieve the salient information that is needed
00:20:26.620 | by the model, so it should be effective at retrieving the information that is fed exactly
00:20:33.780 | at test time, not only the kind it has seen at training time. This is the problem
00:20:38.900 | that we are trying to solve with Titans. Now, the way they do it is as follows, so they
00:20:45.500 | say: okay, imagine we have a module that we will call M, and
00:20:53.180 | let's think of this module as a layer in a model. Okay, let me actually draw it; I think
00:21:00.540 | it's much easier if we draw it. Let's add a new page. So, imagine
00:21:09.540 | we have a very long sequence. We have seen that with the recurrent neural network, the
00:21:13.460 | job of the recurrent neural network is to compress this very long sequence so that the
00:21:17.500 | transformer can use it. Let's see now how Titans differs, and then we will
00:21:25.100 | check all the details. So we have this input again,
00:21:34.740 | and we transform it into embeddings. I will draw it a little differently, and later I
00:21:43.380 | will explain why. Suppose we have a hybrid architecture again of transformer
00:21:48.460 | and recurrent layers, but I will not draw the recurrent layers. So this is the first
00:21:53.100 | layer (too big, I think), let's call it L1: the first layer with attention,
00:22:00.020 | the second layer with attention, the third layer with attention, and then we have the
00:22:09.220 | output, which is the logits. Okay, I think now it's more visible, right? So imagine we
00:22:19.140 | have another module in this architecture that we will call the memory module. Let's call
00:22:26.380 | it neural memory, because this is what they call it here,
00:22:31.300 | and I will draw it as an external module. Now I want to show you how it would
00:22:45.520 | work with the neural memory, and then we will check the details of how it is actually trained.
00:22:51.580 | Okay, let's take a step back: how would we usually train
00:22:57.700 | this model? We would feed it a sequence, imagine 1 million tokens, so a very
00:23:04.020 | big sequence. You convert this sequence of 1 million
00:23:09.780 | tokens into embeddings, and you run these embeddings through the recurrent neural network,
00:23:17.620 | which will compress these 1 million tokens into, let's say, 1,000 tokens, because its
00:23:22.060 | goal is to compress stuff, right? We shrink the sequence fed to the attention because
00:23:27.860 | attention is quadratic in the sequence length, so a smaller input results
00:23:35.300 | in cheaper computation. So we feed these 1,000 compressed tokens to the attention, and then
00:23:41.900 | we force it to predict the next token only by leveraging these 1,000 compressed tokens. So
00:23:48.660 | we feed 1 million tokens, but we force the attention layer to predict the next token by leveraging
00:23:55.060 | much less information, and we hope that the recurrent neural network is good at choosing
00:23:59.860 | the right tokens to keep and discarding the rest. Actually, okay, it's
00:24:05.900 | not really a token pruning mechanism, it's a token compression mechanism, but you can
00:24:13.260 | think of it as token pruning: it's fed 1 million tokens and it just keeps
00:24:17.580 | the top 1,000 that are the most important for predicting the next token. And this is done
00:24:26.900 | at training time: we feed these 1 million tokens, we compute the output,
00:24:33.320 | and because at training time we know what the next
00:24:36.540 | token should be, we compute the loss with respect to it,
00:24:41.660 | and then we backpropagate to update the parameters of the model, and we keep doing
00:24:45.940 | it for all the sequences that we have. With Titans it would work differently. Imagine
00:24:52.580 | you have 1 million tokens again; what you do is two steps. The first thing that
00:25:02.260 | we do: okay, we have this input, we convert it into embeddings, and the first thing we do
00:25:09.260 | in the training loop (imagine we are training this Titans architecture) is to train this
00:25:15.180 | neural module to learn to memorize our 1 million tokens, and then we ask it to retrieve the
00:25:25.000 | information necessary for predicting the next token and feed it to the attention layer.
00:25:30.380 | So let's call this an attention layer: this is an attention layer, this is an
00:25:36.780 | attention layer, and this is an attention layer. Look at the difference here: before, we
00:25:44.060 | had an input, we predicted the output, we computed the loss, we backpropagated, and
00:25:50.660 | we updated all the parameters of the model. Here we will do something different: we have
00:25:54.980 | an input of 1 million tokens, we convert them into embeddings, and we train this module
00:26:02.380 | here, which is separate; in the paper they refer to this as the inner loop of the training.
00:26:09.340 | We train this neural memory (later we will see how) with the sole purpose
00:26:15.180 | of learning everything about this data, so that it can easily retrieve
00:26:23.260 | this data when we need it. So we take these 1 million tokens, convert them into
00:26:28.660 | embeddings, and train this neural memory in an inner loop; then we take this neural memory,
00:26:37.620 | which has been trained to memorize this data, we ask it to retrieve whatever information
00:26:45.620 | is important from what it has seen, and we use it as input for the attention layers here,
00:26:52.020 | so that the attention layers can leverage this compressed memory to produce the output
00:26:57.740 | and predict the next token. This happens not only at training time but also at test time. When
00:27:05.100 | we use hybrid architectures, for example attention plus
00:27:10.040 | recurrent neural networks, at test time, so at inference time, what we usually have is
00:27:14.380 | a prompt. Imagine this prompt is huge, because you are asking ChatGPT, for example, to analyze
00:27:19.420 | an entire, very big GitHub repository. What will happen is that these
00:27:26.820 | 1 million tokens will be fed to the recurrent neural network, which is now fixed: we are
00:27:32.020 | using the model, so we are not changing its parameters anymore. The recurrent neural network's
00:27:38.460 | job is to compress data, so it will compress these tokens into a smaller sequence that
00:27:43.740 | we feed to the attention layer, which will produce the output logits. However, maybe
00:27:50.500 | the information that we are feeding to this recurrent neural network is out
00:27:55.260 | of distribution, the recurrent neural network has never seen anything like it, and it
00:27:59.620 | will probably do a very bad job at compressing this data. Because it does a very bad
00:28:04.820 | job at compressing this data, not knowing what to keep and what not to keep, the
00:28:09.220 | attention layer will not be able to leverage the most important information, and then it
00:28:13.620 | will not be able to predict the next token very well, so it will result in a bad output.
00:28:20.220 | With Titans, instead, even at test time, so even at inference time, we are actually training
00:28:27.540 | a model, and now I'll show you how. Imagine again that we have a GitHub repository that's
00:28:34.460 | very big and results in 1 million tokens that we want the language model to analyze.
00:28:39.660 | We convert them into embeddings, then we take these 1 million tokens and train on the fly
00:28:46.980 | this neural memory, whose job is just to learn as much information as possible about
00:28:53.660 | these 1 million tokens and retrieve the most salient information, because the neural memory's job
00:28:58.860 | is to compress information. After we have trained it in this inner loop, we retrieve
00:29:04.260 | this information and feed it to the attention layers, and then the attention layers should be
00:29:09.260 | able to leverage the information retrieved by the neural memory. So with Titans, basically,
00:29:18.940 | we don't just have an RNN as our memory, trained at training time and then
00:29:27.660 | never trained again, that goes crazy every time it sees something it has never seen.
00:29:31.340 | We have a neural memory that can be trained at inference time, on the fly,
00:29:39.860 | with the sole purpose of compressing stuff, and because we are training it at inference
00:29:44.740 | time, we hope that it will perform better even on data it has never seen. According
00:29:52.620 | to the benchmarks they published in the paper, it looks like it is doing a good job, though this is true of all papers, so
00:29:56.420 | never blindly trust the benchmarks. Now let's look at
00:30:02.500 | the details. I want to remind you that the problem we are solving is long context modeling,
00:30:07.780 | and long context modeling has one issue: with the transformer, it is very expensive
00:30:12.860 | to run inference on long contexts, while with RNNs we have the problem that we train them on some
00:30:19.780 | data, but when we use them on something that they have never seen, they don't know how
00:30:23.260 | to compress it, what to keep and what to discard, so they go crazy, and because they go
00:30:28.400 | crazy they don't do this job very well, and the attention layers cannot leverage this information,
00:30:33.820 | so they just produce very bad output. With the neural memory, we want to train
00:30:39.660 | a memory on the fly, while running inference on the model, to just do the job of compressing
00:30:46.220 | stuff on whatever data it is fed. Now we can look at the details. OK, here they give a
00:30:54.860 | preliminary overview of what memory is, what linear attention is, etc., which we don't
00:31:02.220 | care about for now. They say: OK, imagine we have a memory module that only has two
00:31:07.860 | operations, a write operation and a read operation; we want to write to
00:31:14.700 | and read from this memory both at training time and at inference time. How do we train this
00:31:21.820 | memory? First of all, this neural memory is a neural network by itself, meaning that
00:31:28.920 | you can think of it as an external neural network, separate from the rest of
00:31:35.540 | the architecture that will use this neural memory, so you need to think that you have
00:31:44.140 | something like a transformer module that is leveraging this neural memory. Now, how do we train this
00:31:51.100 | neural memory at inference time? Because that's our problem: at training time we know how
00:31:55.540 | to do it, we just feed the input, compute the output, backpropagate, and voilà. How to
00:32:02.460 | do that at inference time is what they describe here. OK, imagine we have this
00:32:07.580 | memory; first of all, how do we want to update its information? OK,
00:32:15.820 | another step back: what do we want this memory to do? We want this memory to learn to extract
00:32:24.220 | information about whatever it should memorize, and for that they use a very particular loss,
00:32:30.980 | which is a kind of reconstruction loss. So imagine we have this memory and we ask it
00:32:37.780 | to memorize. OK, imagine we have an input sequence; let's call it x, this x_t here. We project
00:32:48.380 | it with two linear projections called W_K and W_V, which are basically the equivalent
00:32:54.040 | of the ones that we use in the attention mechanism. How can this memory do its job very well?
00:33:04.300 | Only if it learns to recreate the data it has seen, and this is the loss that you see
00:33:12.740 | here: it is just the L2 loss, with which it basically learns to memorize
00:33:20.180 | the mapping between a projection called the key and a projection called the value of the same data.
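A minimal sketch of this reconstruction loss, assuming a linear memory W, tiny 2-dimensional embeddings, and hand-picked projection matrices W_K and W_V (all illustrative, not values from the paper). It uses the column-vector convention k = W_K x, v = W_V x from the explanation above; the loss ||W k - v||^2 is small exactly when the memory maps the key back to the value.

```python
# Minimal sketch of the reconstruction ("associative memory") loss:
#   loss(W; x) = || W k - v ||^2, with k = W_K x and v = W_V x.
# Pure Python lists; dimensions and values are illustrative assumptions.

def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

def reconstruction_loss(W, k, v):
    """Squared L2 distance between what the memory returns for k and the target v."""
    pred = matvec(W, k)
    return sum((p - t) ** 2 for p, t in zip(pred, v))

# Hypothetical 2-d token embedding and projection matrices.
x = [1.0, 2.0]
W_K = [[1.0, 0.0], [0.0, 1.0]]   # key projection (identity, for readability)
W_V = [[0.0, 1.0], [1.0, 0.0]]   # value projection (swaps the two coordinates)

k = matvec(W_K, x)   # [1.0, 2.0]
v = matvec(W_V, x)   # [2.0, 1.0]

W_empty = [[0.0, 0.0], [0.0, 0.0]]   # a memory that has stored nothing
W_good = [[0.0, 1.0], [1.0, 0.0]]    # a memory that maps k exactly to v

print(reconstruction_loss(W_empty, k, v))
print(reconstruction_loss(W_good, k, v))
```

Here the empty memory gets a loss of 5.0, while the memory that maps k exactly to v gets 0.0.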
00:33:27.320 | So it kind of learns to recreate the same data. This is the job of the memory: if
00:33:34.460 | I put some stuff in, I should be able to retrieve the same stuff, so I should be able to get back
00:33:39.660 | as much as possible of the stuff that I put inside. How do we train it? How they train
00:33:47.420 | it is, they say: OK, I have this memory, and I want to update this memory by using a kind of gradient
00:33:56.260 | descent. So how does gradient descent work? Imagine we have a neural network; the basic version
00:34:03.780 | of gradient descent works as follows. We have a neural network with some parameters,
00:34:11.500 | let's call them theta, and the parameters theta at time i, that is, at
00:34:20.000 | step i of the training, are updated from the previous parameters of the model, the ones at the
00:34:26.940 | previous time step, minus a learning rate that we will call gamma, multiplied by the gradient
00:34:34.740 | of the loss with respect to the parameters of the model. The gradient tells us how we
00:34:43.180 | should change the parameters in order to increase the loss, but we move against the direction
00:34:51.940 | of this gradient, and that's why you see a minus sign: we update the parameters in
00:34:58.420 | the direction opposite to the one that would increase the loss, so we update the parameters
00:35:05.300 | to reduce the loss. And this is what we do here: we say we want to update our memory
00:35:12.260 | in such a way that we minimize this loss here, which is the memorization loss, that
00:35:19.940 | is, the reconstruction loss that we saw before: a loss that says that if I ask the memory to
00:35:26.580 | retrieve some information using the key projection of the data, it should recreate
00:35:32.540 | the value projection of that data. And this memory, in the simplest case in the paper, is modeled as a linear
00:35:43.660 | layer, and a linear layer is just a matrix multiplication with a weight matrix, so this memory module,
00:35:48.260 | M here, is nothing more than the weight matrix of a linear layer. So we are modifying
00:35:57.820 | this weight matrix: the neural memory is just a matrix, W, and we are modifying this W
00:36:06.900 | in such a way that it reduces the reconstruction loss of the data, just the way we train a
00:36:17.540 | neural network: we train a neural network's parameters to reduce a loss, and these
00:36:23.620 | parameters are updated in such a way that they result in the smallest loss possible;
00:36:29.980 | in the same way, we are updating this W matrix, which will be our memory, in such a way that
00:36:37.220 | it results in the minimum loss possible, because that's the loss against
00:36:42.780 | which we are optimizing it, which is the reconstruction loss.
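To make the gradient-descent analogy concrete, here is a sketch under the same illustrative assumptions (linear memory, a single hand-picked (k, v) pair, and a fixed learning rate `lr` standing in for the paper's learning rate). It uses the gradient of the squared error, 2 (W k - v) k^T, and repeats the update W <- W - lr * gradient, so the reconstruction loss shrinks step by step.

```python
# Sketch: plain gradient descent on a linear memory W to reduce the
# reconstruction loss || W k - v ||^2 for a single (k, v) pair.
# The learning rate and the data are illustrative assumptions.

def matvec(M, x):
    return [sum(a * b for a, b in zip(row, x)) for row in M]

def loss(W, k, v):
    return sum((p - t) ** 2 for p, t in zip(matvec(W, k), v))

def grad(W, k, v):
    """Gradient of ||W k - v||^2 w.r.t. W: the outer product 2 (W k - v) k^T."""
    err = [p - t for p, t in zip(matvec(W, k), v)]
    return [[2.0 * e * kj for kj in k] for e in err]

def gd_step(W, k, v, lr):
    """New parameters = old parameters - lr * gradient, entry by entry."""
    G = grad(W, k, v)
    return [[w - lr * g for w, g in zip(w_row, g_row)] for w_row, g_row in zip(W, G)]

k, v = [1.0, 2.0], [2.0, 1.0]
W = [[0.0, 0.0], [0.0, 0.0]]      # empty memory
history = [loss(W, k, v)]
for _ in range(20):
    W = gd_step(W, k, v, lr=0.05)
    history.append(loss(W, k, v))

print(history[0], history[-1])
```

With ||k||^2 = 5 and lr = 0.05, each step multiplies the error W k - v by 1 - 2 * 0.05 * 5 = 0.5, so the loss drops by a factor of 4 per step, from 5.0 to almost zero after 20 steps.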
00:36:48.860 | And they give this gradient a name: the gradient of the loss with
00:36:56.940 | respect to W, the weight matrix of our memory, is what they call
00:37:03.820 | the surprise, because the bigger the gradient, the more difficulty the model had in reconstructing
00:37:12.380 | the data, which means that the model is surprised to see this data; that's why they call
00:37:19.900 | it surprise.
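A small sketch of the surprise idea, assuming the same kind of linear memory: we compute the gradient for an input the memory already reconstructs and for one it does not, and measure its size with the Frobenius norm. The norm is our illustrative choice for comparing magnitudes; the update itself uses the gradient matrix.

```python
# Sketch: the "momentary surprise" is the gradient of the reconstruction loss
# w.r.t. the memory W. It is small for data the memory already reconstructs
# and large for data it does not. All values are illustrative assumptions.
import math

def matvec(M, x):
    return [sum(a * b for a, b in zip(row, x)) for row in M]

def grad(W, k, v):
    """2 (W k - v) k^T, the gradient of ||W k - v||^2 w.r.t. W."""
    err = [p - t for p, t in zip(matvec(W, k), v)]
    return [[2.0 * e * kj for kj in k] for e in err]

def frobenius(G):
    """Size of a gradient matrix (square root of the sum of squared entries)."""
    return math.sqrt(sum(g * g for row in G for g in row))

W = [[1.0, 0.0], [0.0, 1.0]]  # a memory that currently maps every key to itself

familiar = ([1.0, 2.0], [1.0, 2.0])   # (k, v) the memory already reconstructs
novel = ([1.0, 2.0], [-2.0, 3.0])     # same key, unexpected value

s_familiar = frobenius(grad(W, *familiar))
s_novel = frobenius(grad(W, *novel))
print(s_familiar, s_novel)
```

The familiar pair produces zero surprise, while the novel pair produces a surprise of sqrt(200), about 14.14.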
00:37:22.460 | If you have ever studied how optimizers work, you will remember that in deep learning we
00:37:30.420 | have this thing called momentum. Usually we don't update the parameters of the model
00:37:36.060 | naively like this because, first of all, the gradient
00:37:44.420 | is computed with mini-batch
00:37:49.780 | gradient descent, which means that we don't compute it over the whole input dataset, but
00:37:58.420 | over a few instances of data, a small batch, so the direction of this gradient
00:38:05.540 | is actually stochastic: it is not the true direction of the gradient,
00:38:11.380 | and it oscillates, so at each step it is not indicating the
00:38:18.940 | true direction. Imagine the true direction of the gradient is here; if we train
00:38:22.840 | on the first batch, maybe it points in this direction, on the next batch in this direction,
00:38:28.780 | on the next batch in this other direction, etc. On average it will point in the correct
00:38:32.500 | direction of the gradient, but it will be noisy at each step. Because we don't want
00:38:37.060 | to take steps too confidently at each step of training, we add this momentum term, and
00:38:45.240 | the momentum term basically creates an exponential moving average of all the
00:38:51.860 | gradients, so that we also keep some information about the past gradients that we have computed,
00:38:58.140 | to smooth out the change of the weights, so that a single noisy gradient
00:39:04.700 | does not decide each step on its own. And the idea behind introducing the surprise
00:39:13.740 | is as follows. They said: OK, if I train my memory to recreate the data using only the current gradient, then it can miss
00:39:28.100 | new data that arrives after it has seen some novel data: maybe there is some new data that the model
00:39:37.460 | should memorize, but the gradient becomes very small after a few surprising steps, so the model will
00:39:42.980 | miss it. To avoid this, they use momentum, just like we do when
00:39:49.060 | training models, and they call it the past surprise. This past surprise is nothing
00:39:57.020 | more than the momentum term, the accumulated past gradients, in the optimizers we use, for example the Adam optimizer,
00:40:08.420 | and then there is the momentary surprise, which is the gradient with respect to the current input.
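A sketch of how the past surprise and the momentary surprise combine, following the update written in the paper, S_t = eta_t * S_{t-1} - theta_t * gradient_t and W_t = W_{t-1} + S_t. Simplifying assumptions: a 1x1 memory, constant eta and theta (in the paper they are data-dependent), no forgetting term, and gradients supplied by hand instead of being computed from a loss.

```python
# Sketch of the momentum-style surprise update:
#   S_t = eta * S_{t-1} - theta * grad_t   (past surprise + momentary surprise)
#   W_t = W_{t-1} + S_t
# Assumptions: constant eta/theta, no forgetting gate, made-up gradients.

def surprise_update(W, S, G, eta, theta):
    """One write step on matrices given as lists of rows."""
    S_new = [[eta * s - theta * g for s, g in zip(s_row, g_row)]
             for s_row, g_row in zip(S, G)]
    W_new = [[w + s for w, s in zip(w_row, s_row)] for w_row, s_row in zip(W, S_new)]
    return W_new, S_new

def run(grads, eta, theta):
    """Apply the update for a sequence of gradients; return W after each step."""
    W = [[0.0]]
    S = [[0.0]]
    trace = []
    for G in grads:
        W, S = surprise_update(W, S, G, eta, theta)
        trace.append(W[0][0])
    return trace

# One surprising token (large gradient), then two tokens with zero gradient.
grads = [[[-1.0]], [[0.0]], [[0.0]]]
print(run(grads, eta=0.9, theta=0.1))  # with past surprise
print(run(grads, eta=0.0, theta=0.1))  # momentary surprise only
```

With eta = 0.9 the memory keeps moving after the surprising token (about 0.1, 0.19, 0.271) even though the momentary gradient is zero; with eta = 0 it stays at 0.1.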
00:40:13.420 | So, to rehearse what we have said so far: we have this memory, which is just a W matrix, that
00:40:19.500 | we want to optimize; we want to change this W continuously, with every
00:40:28.700 | token that we receive, in such a way that it encapsulates all the information that
00:40:34.380 | is in this input. And how do we know it captures all the information in this input?
00:40:42.780 | Because we ask it to minimize the loss, the reconstruction loss of the input. Now, the
00:40:51.180 | problem is that we don't want to train this neural memory only during training,
00:40:58.420 | but also at inference time, because if we do it only during training,
00:41:02.860 | what happens is that at inference time, every time it sees some new information
00:41:07.240 | that it has never seen, it will probably do a bad job at compressing, so it will not work.
00:41:11.420 | So how do we do that at inference time? What we will do in practice is as follows: at
00:41:18.180 | inference time, imagine we have inputs, so the first input... let me write all these formulas
00:41:26.100 | down, actually, so that we can refer to them: this one, I paste it here, and then we
00:41:39.460 | also copy the loss, this one. OK, let's see how it would work at inference time. Imagine
00:41:51.540 | we have 1 million tokens... actually no, imagine we want to generate a lot of tokens
00:41:58.140 | and we start with only one token, so the prompt is a single token. What will happen is we
00:42:03.100 | have this one token, let's call it token one, and we feed it to the model as input, where it
00:42:14.460 | will be converted into embeddings, which here is only one embedding, and we want to train
00:42:19.580 | our neural memory on this single token, so it should learn to recreate this single
00:42:25.580 | token. How do we do that in practice? We take the memory; first of all, we take this
00:42:34.120 | one embedding and project it into a key and a value by doing a matrix multiplication of
00:42:39.500 | this single token with a matrix called W_K and another called W_V. Then we compute this,
00:42:50.220 | which is called the retrieval from the memory, and because the memory is modeled
00:42:55.060 | only as the W matrix of a linear layer, retrieving information from this memory
00:43:00.740 | is just W multiplied by the input, and the input here is what they call q_t, so it's
00:43:07.900 | another projection of the input, with the W_Q matrix. So this k_t comes from W_K multiplied
00:43:15.700 | by x, this v_t comes from W_V multiplied by x, and this q_t comes from W_Q multiplied
00:43:26.700 | by x. This W here is the W of the memory, so these are the memory parameters, and this
00:43:33.780 | is the memory: these are the parameters of the memory, but they are also the memory itself,
00:43:39.340 | and we want to update this W. OK, so how do we do that? We project the single
00:43:45.820 | token with W_V, we project it with W_K, we compute this term here, which is just W multiplied
00:43:54.820 | by k_t, we compute this loss here, and we compute its gradient. The gradient
00:44:02.900 | of this loss can be computed with the following formula, which they actually specify, and I can show
00:44:08.980 | you how to derive it, actually. So there is a formula here for the gradient; this is
00:44:14.980 | how we compute the gradient of the loss. How do we derive this formula? Well,
00:44:23.020 | let's talk about it; but OK, one second: they compute the gradient of this loss
00:44:32.500 | with respect to the parameters of the model. And what are the parameters of the model? W.
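For a linear memory, the gradient follows from differentiating the squared error entry by entry. A short derivation, assuming the column-vector convention k_t = W_K x_t and v_t = W_V x_t used in this explanation; the formula shown in the paper may omit the constant factor 2, which can be absorbed into the learning rate theta_t.

```latex
\begin{aligned}
\ell(W) &= \lVert W k_t - v_t \rVert_2^2
         = \sum_i \Big( \sum_l W_{il}\, k_{t,l} - v_{t,i} \Big)^2 \\
\frac{\partial \ell}{\partial W_{ij}}
        &= 2 \Big( \sum_l W_{il}\, k_{t,l} - v_{t,i} \Big) k_{t,j}
         = 2\,(W k_t - v_t)_i \, k_{t,j} \\
\nabla_W \ell &= 2\,(W k_t - v_t)\, k_t^{\top}
\end{aligned}
```

So the gradient is an outer product: the reconstruction error (a vector) times the transposed key, which has the same shape as W.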
00:44:38.140 | So they compute the gradient of this loss with respect to W, and then we need
00:44:43.300 | to update W. How do we update W? We need to compute this S_t term here: S_t starts from
00:44:51.140 | the past surprise, weighted by eta_t, but we don't have any past surprise yet, so let's suppose it is 0 for now,
00:44:55.740 | minus a learning rate, this theta here (theta_t is the learning rate),
00:45:05.780 | multiplied by the gradient that we have computed, and then we update W by adding this term
00:45:11.980 | S_t. Now we have updated our memory, and then we retrieve information from this memory. How
00:45:17.860 | do we retrieve information from this memory? We just take this W and multiply it by
00:45:21.820 | the query: we take x, our single token, and project it with another matrix called W_Q, so that
00:45:29.700 | it becomes q_t; we multiply it by W, and now we have retrieved information. This information
00:45:35.940 | is then sent to the first layer of the model as compressed past information, and then to
00:45:43.460 | the second, to the third, etc., to predict the output, and the model will produce the first
00:45:49.140 | output token. Usually we then put this output token back into the prompt to generate the
00:45:55.620 | next token. Here, because we are not talking about just a transformer model, but
00:46:00.580 | about a hybrid architecture that has attention layers plus a neural memory, we also need to update
00:46:06.100 | our neural memory with this new incoming token, so this new incoming token will again be used
00:46:13.220 | to update the memory. The memory will be updated with the information of the new token; it
00:46:18.660 | will not be replaced with only this new token, so we hope that the new memory will encapsulate
00:46:24.740 | information about both the first token that we fed before and the current token. What we
00:46:30.660 | will do in practice is take this new token that was output by the model,
00:46:35.540 | project it through W_V so that it becomes v_t, project it through W_K so that it
00:46:40.860 | becomes k_t, compute this loss term, compute the gradient of this loss, and update our
00:46:49.300 | neural memory like before, but this time we have the past surprise, and
00:46:54.900 | we also have the previous memory, so we are updating this W, and hopefully this
00:47:00.500 | will contain information about token number 2 and token number 1 that we fed before,
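The per-token procedure just described (project to k, v and q, take one gradient step with the past surprise, then retrieve with the query) can be sketched as a small class. This is an illustrative reconstruction, not the authors' code. Assumptions: 2-dimensional embeddings, hand-picked projections (W_K and W_Q are the identity, W_V swaps coordinates), constant eta = 0.1 and theta = 0.5, and no forgetting gate. The class name `LinearNeuralMemory` is hypothetical.

```python
# Sketch of test-time updates of a linear neural memory, one token at a time.
# Assumptions (not from the paper's code): tiny 2-d embeddings, hand-picked
# projection matrices, constant eta/theta, and no forgetting gate.

def matvec(M, x):
    return [sum(a * b for a, b in zip(row, x)) for row in M]

def zeros(n):
    return [[0.0] * n for _ in range(n)]

class LinearNeuralMemory:  # hypothetical name, for illustration only
    def __init__(self, W_K, W_V, W_Q, eta=0.1, theta=0.5):
        n = len(W_K)
        self.W_K, self.W_V, self.W_Q = W_K, W_V, W_Q
        self.eta, self.theta = eta, theta
        self.W = zeros(n)  # the memory itself (its parameters)
        self.S = zeros(n)  # past surprise (momentum), starts at 0

    def update(self, x):
        """Write token x: one gradient step with momentum on ||W k - v||^2."""
        k, v = matvec(self.W_K, x), matvec(self.W_V, x)
        err = [p - t for p, t in zip(matvec(self.W, k), v)]
        G = [[2.0 * e * kj for kj in k] for e in err]  # 2 (W k - v) k^T
        self.S = [[self.eta * s - self.theta * g for s, g in zip(sr, gr)]
                  for sr, gr in zip(self.S, G)]
        self.W = [[w + s for w, s in zip(wr, sr)] for wr, sr in zip(self.W, self.S)]

    def retrieve(self, x):
        """Read: project x to a query q = W_Q x and return W q."""
        return matvec(self.W, matvec(self.W_Q, x))

I = [[1.0, 0.0], [0.0, 1.0]]
SWAP = [[0.0, 1.0], [1.0, 0.0]]
mem = LinearNeuralMemory(W_K=I, W_V=SWAP, W_Q=I)

tok1, tok2 = [1.0, 0.0], [0.0, 1.0]
mem.update(tok1)                 # the single prompt token
print(mem.retrieve(tok1))        # recalls the value of token 1
mem.update(tok2)                 # the token produced by the model
print(mem.retrieve(tok1), mem.retrieve(tok2))
```

After the first token, querying with token 1 returns [0.0, 1.0], exactly its value. After the second update, token 2 is recalled exactly as [1.0, 0.0] and token 1 comes back as about [0.0, 1.1]: the memory holds both tokens, and the past surprise slightly overshoots the earlier one.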
00:47:06.580 | Now, as you can see, because we are training the neural memory at test time, that is, while
00:47:11.540 | we are running inference on the model, we hope that it will perform better than a neural memory
00:47:16.740 | that has only been trained at training time, because at each step of this update the neural
00:47:28.100 | memory is actually trying to minimize the loss on this particular data: not only on
00:47:33.380 | the data that it has seen during training, but exactly on the particular data that
00:47:37.700 | it sees at this moment. I know that I fed you a lot of information, but I
00:47:43.500 | hope it is now a little clearer what it means in practice to have an inner
00:47:48.800 | loop and an outer loop: when we train the model, we update the parameters of this big
00:47:55.700 | model to leverage whatever the memory creates, and the memory learns to compress
00:48:04.420 | information not only at training time, but also at inference time, exactly on the data that
00:48:09.780 | you feed it at inference time. Now let's talk about the problems of this memory: the
00:48:15.180 | problem with this memory is that every time, as you can see, we need to run
00:48:19.060 | a gradient descent step on each single token. When you train
00:48:25.860 | the model, you have a very big list of tokens and you want to train on it as fast as possible,
00:48:32.900 | but if you need to update the memory one token at a time, it's very slow. Fortunately,
00:48:37.560 | in the paper they also propose an algorithm to parallelize this training, and this training
00:48:45.980 | can be parallelized not over the full sequence, but only chunk by chunk, which is
00:48:51.120 | still better than doing one token at a time. Imagine you have one million tokens: if
00:48:56.020 | we cannot parallelize, it means we first take the first token and update the memory, then
00:49:00.880 | take the second token and update the memory, then the third token and update the memory, so we
00:49:03.980 | need to do this one million times and we cannot exploit our GPUs, because we have to do one
00:49:10.100 | operation at a time. What they propose in the paper is a hybrid algorithm: it's not
00:49:17.020 | fully parallelizable over the entire sequence, but it works chunk by chunk, which is a good compromise.
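A sketch of the chunk-wise idea under simplifying assumptions: linear memory, no momentum, no forgetting gate, a fixed theta, and a plain sum of per-token gradients inside each chunk (in the paper each gradient is weighted by data-dependent rates and decay factors). As we understand the paper's parallel scheme, the key point is that within a chunk every gradient is taken with respect to the memory at the start of the chunk, so those gradients do not depend on each other and can be computed together; only the chunks run sequentially. The function names are hypothetical.

```python
# Sketch of chunk-wise ("mini-batch") memory training. Assumptions: linear
# memory, no momentum, no forgetting gate, fixed theta. Inside a chunk every
# gradient uses the memory at the START of the chunk, so the gradients are
# independent (parallelizable); only the chunks are processed one after another.
import math

def matvec(M, x):
    return [sum(a * b for a, b in zip(row, x)) for row in M]

def grad(W, k, v):
    """2 (W k - v) k^T."""
    err = [p - t for p, t in zip(matvec(W, k), v)]
    return [[2.0 * e * kj for kj in k] for e in err]

def train_chunked(W, pairs, chunk_size, theta):
    """pairs: list of (k, v). Returns the final memory and the number of sequential steps."""
    steps = 0
    for start in range(0, len(pairs), chunk_size):
        W0 = W  # memory at the start of the chunk
        # These gradients only read W0, so they could all be computed in parallel.
        grads = [grad(W0, k, v) for k, v in pairs[start:start + chunk_size]]
        total = [[sum(g[i][j] for g in grads) for j in range(len(W0[0]))]
                 for i in range(len(W0))]
        W = [[w - theta * g for w, g in zip(wr, gr)] for wr, gr in zip(W0, total)]
        steps += 1
    return W, steps

def sequential_steps(seq_len, chunk_size):
    """Number of chunks that must be processed one after another."""
    return math.ceil(seq_len / chunk_size)

pairs = [([1.0, 0.0], [0.0, 1.0]), ([0.0, 1.0], [1.0, 0.0])]
W_chunked, n_chunked = train_chunked([[0.0, 0.0], [0.0, 0.0]], pairs, 2, 0.5)
print(W_chunked, n_chunked)
print(sequential_steps(1_000_000, 1_000), sequential_steps(1_000_000, 1))
```

For one million tokens and a chunk size of 1,000, this gives 1,000 sequential steps instead of one million. With a chunk size of 1 the loop reduces to the token-by-token update; in this toy example both choices end at the same memory because the two keys are orthogonal.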
00:49:22.060 | It means that if you have one million tokens and you choose a chunk
00:49:27.260 | size of, let's say, 1,000, you can parallelize over the first 1,000 tokens, then you take the
00:49:37.380 | next 1,000 tokens and parallelize over those, so in total you will compute 1,000 sequential steps
00:49:42.720 | and not one million steps, if you choose a chunk size of 1,000 over a sequence length
00:49:48.060 | of one million. They also say, OK, how do we leverage this neural memory module? You can use it
00:49:54.860 | as a contextual memory, which means that if you have a hybrid architecture in which you have
00:49:58.700 | attention and this neural memory, like the one we drew before, what we can do
00:50:06.820 | is take the sequence that is input by the user and, because the job
00:50:12.700 | of the neural memory is just to compress information, retrieve whatever is in the memory and
00:50:18.660 | prepend it to the sequence, along with some other persistent...
00:50:25.340 | OK, we don't even need to talk about the persistent memory tokens, because I believe they just
00:50:29.260 | overdid all this stuff; I mean, this system could work even without the persistent memory
00:50:34.820 | tokens. So we take our sequence, we prepend whatever information is in the memory, we
00:50:43.780 | feed it to the attention module, and we use the output of the attention to update the
00:50:48.300 | memory and to produce the output. So let's go to our architecture: in this case, basically,
00:50:57.500 | imagine we have already fed 10 tokens to this memory and now we are trying
00:51:04.220 | to predict the 11th token. What it would mean is that I would take this 11th token,
00:51:11.980 | convert it into an embedding, and retrieve whatever is inside the neural memory.
00:51:19.660 | Imagine the neural memory gives me... because its job is compressing, right? Even if I fed
00:51:24.100 | it 10 tokens, it doesn't have to return 10 tokens, it has to return a compressed
00:51:27.660 | version of these 10 tokens. Suppose that the compressed state
00:51:33.620 | is 5 tokens: I would take these 5 tokens and prepend them to my single token, which would become
00:51:39.020 | 6 tokens, feed them to the first attention layer, take the output of the attention, use it to update
00:51:44.060 | the memory, combine it with what the memory returns to get the output of this layer, and feed that
00:51:50.020 | to the next one. This is the neural memory as context usage. The other usage is memory
00:51:58.460 | as gate, which is this architecture here. In this case we have our 11th token (don't
00:52:09.220 | think about the persistent memory; I believe it's just overdoing it, you don't have to
00:52:17.780 | use persistent memory to make this mechanism work). They take this 11th token and put
00:52:25.180 | it in the memory, so now we first update the memory, and they also feed it to the attention,
00:52:32.260 | and then they combine the output of the neural memory, which has seen 11 tokens, but when
00:52:37.620 | we retrieve from it gives us only 5 tokens, with the output of the attention, to which we
00:52:41.580 | only fed 1 token, to produce the output. Or you can use only the memory
00:52:47.260 | as a module without any attention, which means that basically you skip all this part:
00:52:53.460 | you take your input, which could be 1 token, 1 million tokens, whatever, you update the
00:52:58.300 | memory continuously, you take the compressed version from the memory, and you feed it directly
00:53:03.020 | to the linear layer that will produce the logits. This is the simplest form of what they refer to as memory
00:53:08.780 | as a layer. Honestly, you can create 1 million variants of this architecture; the point is
00:53:15.660 | not how you use it, the point is how it works, so I want to emphasize how it works:
00:53:22.300 | we are training a module at test time, which is different from what we do with recurrent
00:53:29.140 | neural networks. Recurrent neural networks are trained at training time, and their job
00:53:33.860 | is to compress data, but because they do the job of compressing the data they
00:53:41.500 | have seen very well, they may not function very well during inference, because they may see some
00:53:47.240 | data that they have never seen. However, by having a memory like this that you can train
00:53:51.900 | at inference time, with an algorithm that is supposedly parallelizable, we can
00:53:59.180 | hopefully avoid this problem, because the only job of the memory is to be able to retrieve. So
00:54:04.860 | I actually like this paper, because I believe it's a novel idea that I hadn't thought
00:54:10.580 | about before, and I think it's part of a bigger... actually, OK, I've been researching
00:54:18.140 | this area a little bit for a while; it's called test-time training, but this particular
00:54:25.020 | architecture was a little bit innovative in this field. What else do we need to know to
00:54:34.380 | read this paper? I think now you should have the information to read this paper, because
00:54:39.660 | we have talked about how to update this memory and what this memory is. This memory is just
00:54:44.180 | a linear layer, but in the paper they also say that this memory doesn't have to be just
00:54:49.220 | a linear layer: it can be a multilayer perceptron, for example two layers with an
00:54:54.500 | activation in between, and it will work in the same way, and the parallelizable algorithm that they
00:54:58.620 | have devised would also work with this multilayer memory. We
00:55:07.820 | didn't talk about persistent memory, but the persistent memory is just a set of tokens that
00:55:11.580 | are prepended to the input, and they don't belong to the neural memory; they belong to
00:55:16.980 | the outer loop, as they call it here. The outer loop is just this model here, and this is
00:55:25.820 | the inner loop. But OK, this system can work without persistent tokens; this is my claim.
00:55:34.260 | If you look at the benchmarks, it looks like, compared to the other architectures
00:55:38.260 | like Mamba and other recurrent neural networks, it performs better, if you check the average
00:55:46.140 | score over all these benchmarks. I believe this is a promising area of research,
00:55:52.100 | and I will probably be looking forward to the code, which has not been released yet. But
00:55:57.940 | thank you guys for spending time with me. I hope I gave you at least enough intuition
00:56:02.760 | into how it works, and I'm also really eager to look at the code, because I think
00:56:07.020 | the best way to learn about a new architecture is actually to look at the code. So have a
00:56:11.740 | good night!