Google Titans: Learning to Memorize at Test Time

00:00:00.000 |
Welcome, everyone. Let me share my window here. 00:00:21.080 |
Yeah. I was playing around with this new presentation tool, gamma.ai. And I can't go back to PowerPoint 00:00:31.320 |
That's huge praise. Yeah. A friend of mine actually quit to work for this company. And 00:00:37.760 |
I was like, another AI Slides company? You know, these never work. Shows what I know. 00:00:44.440 |
It's really impressive. So I'm not paid by the company in any way, but I do recommend it. 00:00:52.800 |
So this week, we're going to talk about a new paper out of Google Research, Titans. 00:01:01.840 |
The main thing that this new architecture presents is giving the model some memory, 00:01:11.240 |
especially at inference time, so that it can leverage that to make better inferences. 00:01:25.600 |
And I'm generally not going to be looking at the chat while I present. There's a lot 00:01:31.920 |
to go through, but I'm going to leave plenty of time at the end for discussion where we can get to questions. 00:01:41.520 |
So yeah. So do feel free to drop posts in the chat, and we will get to them. 00:01:55.200 |
So let's see. Okay. I'm trying to figure out how to -- wait a minute. There we go. Okay. 00:02:15.420 |
So what are some of the problems that this Titans paper is trying to address? The first 00:02:22.760 |
of which is the quadratic complexity nature of attention. 00:02:30.600 |
And this has been one of the things that typically was, up until, say, the last year or so, really 00:02:39.680 |
restricting the context window size of models. 00:02:44.640 |
So if I remember correctly, when GPT-4 came out, it first had a size of 8K or 32K, something 00:02:54.080 |
like that context window, but nothing compared to the 1 million context window of Gemini. 00:03:02.400 |
And the reason for that was that for -- in typical attention, every token is compared 00:03:10.880 |
to every other token, so that if you, say, have a context window size of 1,000, then 00:03:18.120 |
the complexity is 1,000 squared. And if you have a, you know, window size of 10,000, then it's 10,000 squared, so 100 times the work. 00:03:30.040 |
And so in order to increase the size of your context window, it becomes, like, very computationally 00:03:39.200 |
expensive very quickly. And there are ways around this that are mentioned in the paper, 00:03:44.200 |
but that's one of the problems with, let's say, classical transformers, which leads right 00:03:51.200 |
into the limited context window, again, based on the quadratic aspect. 00:04:00.240 |
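To make that quadratic point concrete, here is a small numpy sketch (not from the paper; the sizes are arbitrary) showing that the attention score matrix for a length-n sequence has n-squared entries, so 10x more context means roughly 100x more pairwise work.

```python
import numpy as np

def toy_attention(n, d=64):
    """Single-head attention over a length-n sequence. The score matrix
    is n x n, which is where the quadratic cost comes from."""
    Q, K, V = (np.random.randn(n, d) for _ in range(3))
    scores = Q @ K.T / np.sqrt(d)                      # shape (n, n)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)     # softmax over the keys
    return weights @ V

# 1,000 tokens -> 1e6 pairwise scores; 10,000 tokens -> 1e8.
for n in (1_000, 2_000, 4_000):
    toy_attention(n)
    print(f"n={n:>5}: score matrix has {n * n:,} entries")
```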
And then the other thing is that the long context window creates the bottleneck for 00:04:08.640 |
the very long sequences of data. So transformers are great, and, you know, there's been tons 00:04:18.800 |
of advances in them, since the Attention Is All You Need paper, but, you know, there are still these limitations. 00:04:29.760 |
And so what does Titans do to try to address that at a very high level? The first thing 00:04:38.240 |
it does is to emulate human memory, meaning that human memory has different -- you can 00:04:48.200 |
think of them as modules or types of memory. There's short-term memory and long-term memory. 00:04:56.560 |
However, in the typical transformer model, as we'll see, those have, like, short-term 00:05:05.240 |
memory because they have access to the tokens that they're currently processing, but anything 00:05:11.320 |
that happened before that, they don't have memory of that. 00:05:17.240 |
And as you do AI engineering, you probably are well aware that every time you make an 00:05:25.480 |
API call to OpenAI or Claude, it's, like, starting all over again. And then the second part is 00:05:36.520 |
the enhanced processing. So the memory allows very long, like, extremely long sequences 00:05:43.040 |
to be handled, so that -- and we'll see that as the tokens are processed, they -- the model 00:05:52.400 |
keeps a representation of just the most salient or important, let's say, facts. It's a representation 00:06:01.440 |
of facts, but let's say facts from the tokens it's seen so far so that it can use those for later predictions. 00:06:14.240 |
So there was some work by other groups, and this group as well. So this isn't, like, a 00:06:24.760 |
sudden thing. This paper, Test-Time Training with Self-Supervision, goes back to 2020. So there's 00:06:36.120 |
a few things that this paper presented. The first is turning that unlabeled test data 00:06:45.760 |
into a self-supervised learning problem so that it can update its own parameters before 00:06:55.280 |
making a prediction. And they noticed performance improvements with this. So that's one paper 00:07:03.960 |
that previously came out. And I didn't -- I'm going to just mention two papers. I didn't 00:07:12.400 |
read either of them in depth. I just kind of read the abstract and got the gist of them. 00:07:19.220 |
But if you're very interested in this, both of these are good sources. And then the other 00:07:25.320 |
one is this paper, Learning to (Learn at Test Time): RNNs with Expressive Hidden States. 00:07:34.720 |
So this is -- and we'll see this in the Titans paper as well, that it does bring an RNN or 00:07:45.960 |
recurrent neural network technique back into transformers, where there is a state inside 00:07:56.960 |
the -- well, it's the hidden state. And as the tokens are processed, that hidden state 00:08:04.360 |
is updated and then allowed to affect future outputs. So this is the paper that brought 00:08:17.560 |
that into the research community. You can also see the -- one of the main points from 00:08:25.440 |
this paper is the linear complexity. So using this to get around the quadratic complexity 00:08:32.080 |
of the typical transformer. So as I mentioned, Titans is inspired by human memory. There's 00:08:48.400 |
core or short-term memory. So in the Titans models, this is the equivalent of just a normal 00:08:59.320 |
attention mechanism where you have the queries, keys, and values, you know, all interacting 00:09:11.100 |
in the transformer component. So this is the short-term memory. So they don't really introduce 00:09:21.240 |
anything about short-term memory. They just reuse what is already out there. Long-term 00:09:26.920 |
memory is where they learn to memorize historical context, encode that context into abstractions 00:09:37.600 |
so that it can improve future tokens. This is the main thing that they talk about in 00:09:43.760 |
the paper is this long-term memory module and how exactly it operates. And then the 00:09:52.040 |
persistent memory is not a, like, inference-dependent thing. So this is -- you can think of this 00:10:03.560 |
as, like, a set of knowledge or a set of rules that always exists in the model. It's hard-coded 00:10:12.840 |
in there. And it's always brought into the -- into the sequence. We'll see how that happens. 00:10:20.780 |
They don't talk a whole lot about persistent memory. So it's -- at least for me, it was 00:10:26.880 |
a little unclear where -- how exactly this is created. But it's also not the main point 00:10:34.520 |
of it. It's just kind of a set of parameters that are always fed into the sequence that's 00:10:43.080 |
going to be predicted against. Okay. So let's take some time digging into 00:10:53.940 |
this long-term memory module. So this is the -- kind of the main breakthrough that this 00:11:05.240 |
paper is presenting. There's a few different things you can see here as far as what this 00:11:12.960 |
module does. The first is recurrent processing. So like I mentioned, it maintains a state 00:11:22.660 |
as it's doing inference. And then that part of the output is then fed back into that state 00:11:33.220 |
so that it can continue to capture new information as it goes along. This helps it to kind of 00:11:41.700 |
learn the task that it's doing. And also to do those needle in a haystack things where 00:11:49.740 |
maybe some information early on in the sequence is relevant to something much later. So it 00:11:56.720 |
can keep that in memory. It's also a memory component. So this is the 00:12:03.300 |
thing that is responsible for generating representations for, you know, the information that is coming 00:12:13.800 |
through it at inference time. And then let's see. The last bullet point about sequential 00:12:23.220 |
data handling. Yeah. So as I've been mentioning, like, we're assuming this is all happening 00:12:33.300 |
in a sequence of tokens in the case of a large language model where it's processing one token 00:12:41.820 |
after another. It's not all at once. That's pretty -- pretty taken for granted for any 00:12:51.680 |
large language model. So let's keep going. So we'll dig into a few different aspects 00:13:02.240 |
now of the long-term memory module. So here you can see some points. Let me pull out a 00:13:15.960 |
couple of them. One is this weight adjustments, which is, I think, one of the most interesting 00:13:21.440 |
things about the Titans architecture. So for, like, almost all LLMs, or at least the 00:13:29.500 |
ones that I'm aware of, like a GPT-4, a Sonnet, Claude Sonnet, whatever, they do a lot of 00:13:39.040 |
pre-training. They do, you know, instruction tuning, all of that. But once they're done 00:13:45.680 |
with that, then the weights are just the weights. And even if you think of, like, the 00:13:51.320 |
DeepSeek release last week, like, it was just a bunch of weights. So these don't -- these 00:14:00.360 |
don't adjust in memory, no matter what you put through the model. Whereas this Titans 00:14:07.160 |
model, at least in the context of a single inference, can adjust weights on the long-term 00:14:14.080 |
memory module, which I think makes it, like, a very interesting new approach. 00:14:23.600 |
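As a rough illustration of that idea (the paper's code had not been released at the time of this talk, so this is a hypothetical sketch with made-up module names, dimensions, and update objective): the pretrained backbone stays frozen, and only a small memory module gets a gradient step as each chunk of the input streams through at inference time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NeuralMemory(nn.Module):
    """Hypothetical stand-in for the long-term memory module: a small MLP
    whose weights get updated at inference time (dimensions are made up)."""
    def __init__(self, dim=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

@torch.no_grad()
def frozen_backbone(x):
    # Placeholder for the pretrained transformer block; its weights never change.
    return x

def inference_with_test_time_updates(chunks, lr=1e-2):
    memory = NeuralMemory()
    outputs = []
    for keys, values in chunks:                     # each chunk is a (keys, values) pair
        retrieved = memory(keys)                    # 1) read from memory, run frozen backbone
        outputs.append(frozen_backbone(retrieved + values))
        loss = F.mse_loss(memory(keys), values)     # 2) test-time update: push memory(key) toward value
        grads = torch.autograd.grad(loss, list(memory.parameters()))
        with torch.no_grad():
            for p, g in zip(memory.parameters(), grads):
                p -= lr * g                         # only the memory module's weights move
    return outputs

# toy usage: three chunks of 8 tokens with 256-dim key/value vectors
chunks = [(torch.randn(8, 256), torch.randn(8, 256)) for _ in range(3)]
_ = inference_with_test_time_updates(chunks)
```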
And for the continuous adaptation, so that kind of plays off the weight adjustments. 00:14:32.320 |
So that's something that, you know, as it's going through a long sequence of perhaps even 00:14:38.440 |
millions of tokens, it can continue to actually learn about, like, the data that it's processing. 00:14:51.240 |
And that's another thing that they really emphasize in the paper, is creating a model 00:14:57.160 |
of -- when I say model, I don't mean, like, a machine learning model. I mean, like, an 00:15:01.960 |
abstraction for learning. And what does that mean for a model to be able to learn? 00:15:12.400 |
So we've talked about how the model updates its weights based on the data that comes through. 00:15:23.460 |
And so how does the model know, like, what is interesting, like, what is worth keeping, 00:15:31.700 |
and what isn't? Because it doesn't have enough parameters for long sequences to capture everything 00:15:39.040 |
that comes into it. So it's going to have to compress whatever data is coming through. 00:15:46.540 |
So how does it decide how, like, what to remember and what not to? So this surprise mechanism 00:15:54.100 |
is the main way that it does. So information that's considered surprising is for humans 00:16:06.520 |
more memorable. And so they took the leap that for a model, it's also going to be more 00:16:13.700 |
memorable. Or more important to remember. I'm just looking through the slides. So, okay. 00:16:25.300 |
So the point is, well, how do we know if information is important or not? The main way they decide 00:16:34.260 |
that is the gradient of the loss on the input data with respect to the memory. So you can 00:16:40.380 |
see that -- hopefully you can see my mouse. But you can see that down here. The gradient 00:16:45.840 |
of the loss on this incoming input data with respect to the memory from the previous time step gives 00:16:56.140 |
us the amount of surprise. And it -- you can see it gives preferential treatment to this 00:17:04.620 |
surprising information. The more surprising it is, the more likely it is to be stored 00:17:10.140 |
in memory. So you can imagine for, like, a long sequence of maybe documents, maybe there's, 00:17:18.700 |
I don't know, 100 documents that are all legal briefings. And then it comes to something 00:17:27.780 |
that is maybe a bill of lading or, you know, a product description or something like that. 00:17:36.740 |
So then it's going to be like, oh, this is completely different. I should remember something 00:17:40.700 |
about this. So it's going to prioritize keeping what it considers surprising or information 00:17:50.640 |
that it hasn't seen before. And then it stores this in key value pairs in that long-term 00:18:01.180 |
memory module. So in addition to surprise, there's also forgetting. 00:18:14.820 |
So I should have mentioned in the last slide that this theta parameter here controls the 00:18:26.340 |
amount of, like, kind of the impact of surprise. So it's a tunable parameter that can be increased 00:18:37.220 |
or decreased to either remember more surprising information or forget it. The momentum and 00:18:46.860 |
dynamic forgetting means that it can forget things so that it does discard irrelevant 00:18:55.220 |
information as it goes along. So there's this, you can see this one minus alpha, where it is 00:19:05.060 |
slowly degrading the older memory. So as things get older and older, they fade out. I heard 00:19:16.860 |
someone come off mute. Is there a question? >> Yeah. That MT parameter, is that the output 00:19:25.020 |
of the memory module or the model itself? >> That is essentially the state of the memory 00:19:33.960 |
module. >> Okay. And so the surprise has the same dimensions, 00:19:41.940 |
the same form as the memory itself? >> Why don't we look in the paper after the 00:19:52.260 |
presentation and we can, or if you want to look at it. I can't answer that right off 00:20:00.220 |
the bat. >> I would assume so because it's just a simple 00:20:04.060 |
addition to it. And so kind of curious, I guess the origin of this is that, you know, 00:20:12.080 |
they show different structures for how the memory was included and it's, like, added 00:20:17.440 |
in different parts of the architecture stack, if you will. I'm kind of curious if that would 00:20:25.920 |
stay the same across their form of the architecture or if it changes depending on where the memory 00:20:35.200 |
is included at. >> Yeah. That's an interesting question. I'm 00:20:42.640 |
going to go through the different architectures they introduce. So hopefully, I don't know 00:20:50.860 |
if that will answer the question, but at least it will give us enough context for discussion. 00:20:55.920 |
I'm going to leave plenty of time for discussion, so this would be a good thing to talk about. 00:21:08.560 |
And then it also has this momentum, where it combines past surprise with a decay factor. 00:21:17.300 |
So that's this eta t multiplied by the previous surprise. So this provides a momentum 00:21:26.500 |
mechanism to the model. 00:21:39.920 |
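Putting those pieces together, here is a loose reading of the surprise, momentum, and forgetting equations from section 3.1 as code. The memory is simplified to a single matrix acting as a linear key-to-value map, and the projection matrices, dimensions, and hyperparameter values are all assumptions for illustration, not the paper's exact recipe.

```python
import torch

def titan_memory_step(M, S_prev, x_t, W_K, W_V,
                      eta_t=0.9, theta_t=0.1, alpha_t=0.01):
    """One long-term-memory update, loosely following section 3.1:
        surprise  S_t = eta_t * S_{t-1} - theta_t * grad of loss w.r.t. M
        memory    M_t = (1 - alpha_t) * M_{t-1} + S_t
    The memory here is just a matrix acting as a linear map."""
    k_t = x_t @ W_K                            # key projection of the incoming token
    v_t = x_t @ W_V                            # value projection
    M = M.detach().requires_grad_(True)
    loss = ((k_t @ M - v_t) ** 2).sum()        # associative loss: does M map key -> value?
    (grad,) = torch.autograd.grad(loss, M)
    S_t = eta_t * S_prev - theta_t * grad      # momentum over past surprise
    M_t = (1.0 - alpha_t) * M.detach() + S_t   # decay (forgetting) plus the new surprise
    return M_t, S_t

dim = 64
M = torch.zeros(dim, dim)                      # memory state
S = torch.zeros(dim, dim)                      # running surprise
W_K, W_V = torch.randn(dim, dim), torch.randn(dim, dim)
for _ in range(5):                             # process a few tokens
    x = torch.randn(1, dim)
    M, S = titan_memory_step(M, S, x, W_K, W_V)
```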
Okay. So those are kind of the components that go into a Titans model. Now, let's look at the different ways they combine them, because 00:21:47.080 |
there was this open question of, like, okay, we have this module that we think would be 00:21:54.600 |
useful for the model. Like, let's see what's the best way to then incorporate that. And 00:22:06.040 |
they come up with three different options: memory as context, memory as gate (a gated 00:22:14.380 |
branch), and memory as layer. And each one of these has their strengths and weaknesses. So let's first 00:22:23.600 |
look at the memory as context. And I'll just walk through this diagram a little bit. So 00:22:32.900 |
we can start looking over here at the -- this is the input sequence. So this is essentially 00:22:39.060 |
your prompt or what you are sending into the model. This is then used as a key or a query 00:22:48.200 |
-- I guess a query into the memory module that then returns, like, the most relevant 00:23:01.360 |
information from the memory as a sequence of tokens. And then at the same time step, 00:23:13.120 |
you can see this is the persistent memory. So this is kind of the hard-coded information 00:23:17.860 |
that just always has the same token that adds to the sequence. So you can see that before 00:23:25.960 |
we go to attention, we're adding these extra sequences. So in my previous example of, like, 00:23:36.800 |
legal documents and maybe a product description, the retrieval could say, like, okay, this 00:23:46.920 |
sequence looks like -- or this next part of the sequence looks like a legal document. 00:23:52.400 |
So let's grab any information we have about the legal documents, put it here. And then 00:23:59.160 |
this persistent memory might just be about how to generally handle incoming requests, 00:24:05.960 |
something like that. It goes through attention. So this is the quadratic thing where every 00:24:12.160 |
token pays attention to every other token. It then, based on the output, sends an update 00:24:20.000 |
to the memory, so updates the actual weights here with whatever relevant information came 00:24:27.640 |
out. Then I'm not sure what this is -- some kind of combining operator, maybe a gate or element-wise product, combines these and 00:24:38.520 |
then gives you the final output. 00:24:51.160 |
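Here is a hypothetical sketch of that memory-as-context flow: read from the memory with the incoming segment, prepend the persistent tokens and the retrieved representations, run attention over the extended sequence, and then push an update back into the memory. The module shapes, the linear-layer stand-in for the memory, and the write-back objective are guesses for illustration, not the paper's construction.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MemoryAsContext(nn.Module):
    def __init__(self, dim=256, n_persistent=4, n_heads=4):
        super().__init__()
        self.persistent = nn.Parameter(torch.randn(n_persistent, dim))  # always-prepended "rules"
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.memory = nn.Linear(dim, dim, bias=False)   # stand-in for the long-term memory

    def forward(self, segment):                         # segment: (1, seg_len, dim)
        retrieved = self.memory(segment)                # read: query memory with the segment
        persistent = self.persistent.unsqueeze(0)       # (1, n_persistent, dim)
        ctx = torch.cat([persistent, retrieved, segment], dim=1)
        out, _ = self.attn(ctx, ctx, ctx)               # attention over the extended sequence
        seg_out = out[:, -segment.size(1):]             # keep only the segment positions
        self.write_to_memory(seg_out)                   # update memory weights with what came out
        return seg_out

    def write_to_memory(self, y, lr=1e-2):
        # placeholder test-time update: nudge memory(y) toward y itself
        y = y.detach()
        loss = F.mse_loss(self.memory(y), y)
        (g,) = torch.autograd.grad(loss, [self.memory.weight])
        with torch.no_grad():
            self.memory.weight -= lr * g

block = MemoryAsContext()
_ = block(torch.randn(1, 16, 256))                      # one 16-token segment
```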
And so this is what the attention masks look like for memory as context. You can see that each one of these is treated as, like, a small 00:25:01.440 |
context window. So it's, like, kind of little batches of tokens being paid attention to 00:25:14.720 |
together. Then we add the long-term memory onto that. And then finally the persistent 00:25:22.840 |
memory. And so this kind of gives you an idea of, like, inside the attention mechanism, 00:25:31.040 |
what exactly it's seeing. And next, memory as gate. So this is a different 00:25:44.560 |
approach. It does not, like, query the memory before sending it into the attention mechanism. 00:25:57.200 |
So it sends a query to the memory at the same time as going through the attention mechanism 00:26:06.280 |
and then combines them afterwards. And so the memory -- like, whatever is stored in 00:26:13.880 |
the memory does not go through the attention mechanism at all. 00:26:24.080 |
Yeah. So this is -- this uses the memory as a gating mechanism. I guess right here. Which 00:26:38.080 |
I guess would allow -- I'm assuming allow certain tokens through and some tokens not. 00:26:45.960 |
We can discuss this in the discussion about exactly how this works. 00:26:55.080 |
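A similarly rough sketch of memory as gate: a sliding-window attention branch and the memory branch run in parallel on the same input, and a learned gate mixes the two outputs, so the memory's contents never pass through attention at all. The sigmoid gate, window size, and linear-layer memory stand-in are assumptions for illustration.

```python
import torch
import torch.nn as nn

class MemoryAsGate(nn.Module):
    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.memory = nn.Linear(dim, dim, bias=False)   # stand-in for the long-term memory
        self.gate = nn.Linear(2 * dim, dim)

    def forward(self, x, window=8):                     # x: (1, seq_len, dim)
        L = x.size(1)
        idx = torch.arange(L)
        # sliding-window mask: each token sees itself and up to `window - 1` tokens behind it
        mask = (idx[None, :] > idx[:, None]) | (idx[:, None] - idx[None, :] >= window)
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        mem_out = self.memory(x)                        # memory branch skips attention entirely
        g = torch.sigmoid(self.gate(torch.cat([attn_out, mem_out], dim=-1)))
        return g * attn_out + (1 - g) * mem_out         # gated combination of the two branches

_ = MemoryAsGate()(torch.randn(1, 32, 256))
```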
And then here's the attention masks for the memory as gate. You can see it's quite different. 00:27:04.320 |
Here this is a sliding window that, you know, each token pays attention to a certain number 00:27:12.980 |
of tokens on each side of it. You can see the long-term memory, how next to the input 00:27:21.600 |
sequence the memory is strong. And then it gradually fades out as the tokens are, you 00:27:33.720 |
know, further in the future, if you will. And then finally, the persistent memory just, 00:27:40.800 |
like, adds to the very front of the sequence. This is what the attention masks look like 00:27:48.080 |
in the memory as gate. And then finally, memory as layer. So this 00:27:57.720 |
is in some ways similar to what some transformer models do, where they have the attention mechanism 00:28:07.360 |
and then there's a feedforward mechanism as well. So this is, like, the closest thing 00:28:14.120 |
to that, where it just goes right into the memory, whatever comes out of the memory goes 00:28:22.880 |
into the attention. Yeah, and it says it can take advantage 00:28:32.100 |
of the complementary data processing of attention and the neural memory module, so there's no 00:28:39.760 |
combining step, either before or after. It's just, like, one more step in the architecture. 00:28:54.640 |
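And memory as layer, which is the simplest one to sketch: tokens go through the memory module first, and whatever comes out goes straight into attention, with no gating or concatenation. Same caveats as above about the stand-in modules and made-up sizes.

```python
import torch
import torch.nn as nn

class MemoryAsLayer(nn.Module):
    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        self.memory = nn.Linear(dim, dim, bias=False)   # stand-in for the long-term memory
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)

    def forward(self, x):                               # x: (1, seq_len, dim)
        h = self.memory(x)                              # memory output feeds attention directly
        out, _ = self.attn(h, h, h)
        return out

_ = MemoryAsLayer()(torch.randn(1, 16, 256))
```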
And then finally, they also, in their experimentation, look at memory without attention. 00:29:02.480 |
So this is essentially just a long-term memory module by itself. There's no attention mechanism. 00:29:11.000 |
And they just look at, like, how does this perform? Just purely a long-term memory module 00:29:17.040 |
without attention or anything. Okay. So now we're getting into the last part 00:29:26.920 |
of the paper, the experimental setup and results. So they test all four of these variants and 00:29:36.680 |
they test them at different sizes. And I thought one thing that was interesting is that these 00:29:42.400 |
are pretty small sizes, at least compared to, you know, your current, your modern LLM. 00:29:50.840 |
So, like, a Llama 7B is considered a smaller model. And if you look at, like, a, you know, 00:30:04.480 |
like a GPT-4o that has probably hundreds of billions of parameters. So this size is pretty 00:30:13.320 |
tiny compared to, like, the models we use day to day. 00:30:24.400 |
So they gave a big table of their results for language modeling. And they also threw 00:30:32.440 |
some common sense reasoning benchmarks into here. So you can see all the benchmark names 00:30:39.240 |
across the top. And then the best performing ones are highlighted. The tan highlights are 00:30:48.820 |
for hybrids. And then the blue highlights are for pure or just, like, normal. Or yeah. 00:31:00.560 |
So you can see that basically Titans wins at everything here versus Mamba, DeltaNet, 00:31:11.440 |
TTT -- that's the test-time training line, one of the previous works on memorizing at 00:31:19.000 |
test time. So you can see that it wins at everything. This is for the 340 million parameter model. 00:31:30.200 |
And you can also see the number of tokens they train on is pretty small. Modern LLMs 00:31:38.360 |
train on, like, low trillions of tokens. So this is not much data at all. 00:31:46.800 |
Then -- so you can see, like, language modeling, it does quite well. And then here's their 00:31:58.240 |
needle-in-a-haystack results. So for this test, they had some information early on in a sequence. And 00:32:06.200 |
then a bunch of, like, filler tokens. And then some sequence that needed those very 00:32:14.760 |
early tokens. Like, to understand what was happening. And so you can -- the Titans is 00:32:23.560 |
the red stars here. So you can see that they are maintaining their performance quite well. 00:32:31.400 |
Even out to -- here, if we go to the fine-tuning setup, 10 to the 7th. So even out to 10 million 00:32:39.000 |
tokens. I mean, they did take a performance hit. But still doing much better than every 00:32:47.000 |
other model. And so I think this is, for me, one of the most interesting charts that just 00:32:54.160 |
shows that as this becomes more productionalized, there's going to be the opportunity to have 00:33:01.600 |
longer and longer context windows where maybe you can feed in, like, a bunch of YouTube 00:33:07.960 |
videos, like, hundreds of pages of PDFs, and all this stuff, and still have it be -- give 00:33:16.520 |
you, like, relevant output. So with that, that is the end of my presentation. 00:33:26.720 |
So let me pop open the chat here. See what's going on. And if anyone wants to speak up, 00:33:41.960 |
make comments, make any corrections to what I said, like, I'm not an expert in this, and 00:33:49.540 |
I'm also not an AI researcher. So if you have insights, would love to hear those. And also, 00:33:57.880 |
if anyone has answers to questions, feel free to chime in. Again, like, I'm not the expert 00:34:05.320 |
on this paper. Like, I read it and understand it. But I know there's a bunch of very smart 00:34:11.060 |
people on this call. So with that, I'll open up the floor. 00:34:18.860 |
>> I can also help with questions, but I'm struggling with the chat window. 00:34:33.720 |
>> Yeah, I mean, I want to validate Cosmin's frustration with this paper. Yeah, I mean, 00:34:40.680 |
look, like, I think they did try to illustrate the memory mechanisms somewhat, but not super 00:34:47.640 |
-- it's not super clear. And I always wish that these things came with code. I really 00:34:53.520 |
like the name of papers with code, because this one needed code. And, you know, maybe 00:34:59.360 |
they released it, but I didn't see it in the inside of the paper. 00:35:03.360 |
>> They said at the very bottom of the paper that they're planning to release code soon. 00:35:07.840 |
>> How Chinese of them? >> Whatever that means. 00:35:11.960 |
>> Yeah, I didn't understand if the diagrams refer to one step or multiple steps. Like, 00:35:18.560 |
I think -- and also, the diagrams are 3D, which makes it a bit more confusing. Like, 00:35:25.440 |
they say when you update the memory, you do memory of query, and then you get, like, vector 00:35:31.200 |
as output. What does that mean? Is that k-nearest neighbor lookup? Is it some attention? Maybe 00:35:38.520 |
we can go to the first slide, where they update the memory equation. I don't know, Eugene, 00:35:44.320 |
if you got time to -- got any time to look, but I would be interested in just one of those 00:35:50.920 |
operations, how does it actually happen? Like, I understand at the high level, we have a 00:35:55.840 |
memory module, we update, it's nice to forget, they somehow figured it out, but, like, what 00:36:01.600 |
layers or what do they actually do? Like, even lower, you have some -- if you go a bit 00:36:08.940 |
lower. Yeah, you see -- >> Retrieving memory. 00:36:15.400 |
>> Yeah, what does that -- maybe I didn't read or, like, even this figure one, I didn't 00:36:22.480 |
understand at all what's going on. >> Yeah, I also skipped this, because I didn't 00:36:28.520 |
understand what was going on. >> I can explain the in-context learning one 00:36:38.360 |
with -- and I can -- with the parallels of what happens on RWKV-7 and Titans, based on 00:36:46.520 |
what I understood from the paper. So, think of it as this way. 00:36:51.840 |
>> Which figure? >> This whole segment, the long-term memory 00:36:57.040 |
training and the surprise formula. It's actually a lot easier if you explain it using simplified 00:37:05.280 |
-- >> Do you mind saying which? So, like, we read 00:37:07.440 |
while you explain. No, this is great, like, having an expert explain it, but which figure 00:37:12.120 |
is it or which formula? >> So, you scroll up. I'm trying to, like, 00:37:18.440 |
figure out the page numbers as well. So, we are talking about, very specifically, the 00:37:24.520 |
segment, section 3.1, and the surprise metric there, downwards, that whole memory architecture. 00:37:32.280 |
So, one way to view it, right, is that -- let's just say a standard problem. Let's say the 00:37:39.480 |
quick brown fox, correct? And this is -- this is a piece of text that exists in so many 00:37:44.480 |
training corpus that all LLMs will probably memorize this phrase itself, the quick brown 00:37:48.960 |
fox. There is no surprise there. So, because the surprise score is zero, there is no -- there's 00:37:56.440 |
no back propagation required, per se, to update the memories. 00:38:01.560 |
If you view the memories as, let's just say, a 4,000 -- I don't know what the dimensions 00:38:07.720 |
are here, because they never disclose, but let's just say a 4,000 by 4,000, 4096 floating 00:38:13.120 |
point value, in our case, it's BF16. If, let's say, we said the quick brown deer, then the 00:38:24.480 |
model is like, hey, I wasn't expecting deer, I'm expecting fox. There's a difference there. 00:38:30.000 |
That difference there -- I'm oversimplifying the math, because this is not accurate -- can 00:38:34.480 |
be converted to a loss score that you can back propagate on. So, it's when you see that 00:38:39.720 |
when the models see differences, do you update this memory, this memory that's being shared 00:38:45.200 |
between tokens? Now, where it differs for the Google paper, 00:38:50.520 |
Titans, and RWKV, which we are testing as well, is that the -- 00:38:54.200 |
>> Eugene, so you have a separate key value store that you attend at the same time as 00:39:01.400 |
you attend the current token. So, that's your budget, right? And you attend the whole thing, 00:39:08.980 |
okay? And does that help or doesn't it help? And when it's surprising, you need to update. 00:39:15.920 |
So I wonder how they send the updates to the key value store. Like, what actually happens? 00:39:21.640 |
Like, what's the loss of the key value store and the current token? Go ahead. 00:39:26.040 |
>> So if you want to view it as simplified code, which is not how they implement it, 00:39:31.320 |
because -- is that you can view your model weights during the forward pass, right, as 00:39:37.200 |
frozen, just view as frozen. And the -- so, the quick brown, let's say the quick brown, 00:39:44.200 |
those tokens, right, generate a state. You take that state, and then let's say instead 00:39:50.040 |
of -- let's say we see fox, and then we say deer instead, right? There's a difference 00:39:55.080 |
in expectation. The model expected fox, it got deer instead. So because the model weights 00:40:02.760 |
are frozen, if you do the forward pass and then you want to correct the model's thinking 00:40:06.880 |
and you do the backwards pass, the only way to update it is this state value. So you do 00:40:12.200 |
the backwards pass, you update the state values, then you take that state and you go -- you 00:40:17.440 |
process the next token. So you continue your sentence completion. And that essentially 00:40:22.640 |
is what the surprise mechanic is about. It's about, hey, it didn't give the output we required, 00:40:31.760 |
and then we take that difference and then we convert it into a score. Where this differs 00:40:39.360 |
from RWKV is that we don't use a surprise mechanism. We are currently using something closer 00:40:47.400 |
to the standard gradient descent. So the difference here is that in a surprise mechanism, so if 00:40:54.840 |
let's say you expected fox and you get fox, for example, there's no -- the loss is essentially 00:41:01.920 |
zero. There's no backprop. But in RWKV's case, right, if let's say that you -- it expected 00:41:10.640 |
fox and it actually got the fox, and since the way logits work, there's always a zero 00:41:15.800 |
point -- let's just say a zero point something percent difference, there's still a loss score 00:41:19.960 |
being calculated there. So even though it was not a surprise, we still do the backpropagation 00:41:24.200 |
process. In practice, is this better or worse? I have absolutely no idea. This is something 00:41:28.760 |
we need to test and evaluate on. But that's the key difference on how we handle the memories 00:41:34.040 |
segment. It's all about, like, every token you forward, you backprop and then you update 00:41:41.200 |
the weights. This is -- >> You just -- instead of you propagate the 00:41:48.240 |
signal through the frozen weights and just update the keys and values. One question that 00:41:53.200 |
I got was how do they manage a fixed size key value store? Basically, how do they decide 00:42:02.080 |
what to drop and stuff? And they have this -- I didn't understand their gating in here. 00:42:09.440 |
But basically, yeah, over time, you'll see lots and lots and lots of things. So you kind 00:42:14.480 |
of need to figure out, like, if you're efficiently using your memory, then you solve the problem, 00:42:21.160 |
basically. >> So this goes to the AI black box. We decide 00:42:26.080 |
the key value store to a specific size. That's part of the architecture design. This is the 00:42:30.600 |
same thing as the model dimensions as to how the model decides what to keep and drop, right? 00:42:37.000 |
That is specifically decided by the model. So I think one -- there's another segment 00:42:44.840 |
where it highlights the decay, right? Let me find that segment. 00:42:48.080 |
>> Could you scroll a bit down? >> Sorry if I'm not looking at the screen. 00:42:52.800 |
I'm looking at the paper to just find the decay. I think it was mentioned -- 00:42:58.080 |
>> Yeah, it's a forgetting mechanism. Sorry. It's equation, yeah, 13. 00:43:07.960 |
>> Yeah, so the idea behind the decay mechanism, right, and this part is consistent with other 00:43:14.600 |
theories, is that by default, every token you move forward, you are slowly forgetting. 00:43:23.400 |
And the forget rate is something that's trained in the model. So let's just say the value 00:43:29.520 |
is by default, everything you will forget in 32K tokens. So if you forward 32K, you 00:43:36.880 |
should forget it. This makes it sound similar to sliding window attention in that sense, 00:43:42.280 |
but the decay mechanism is supposed to work together with the -- with, let's say, the 00:43:48.240 |
surprise mechanism or basically the model's -- this is a bit more gray, but basically 00:43:56.080 |
as the -- by default, you decay. You let the model compute against the state itself. And 00:44:01.720 |
the model may just decide, hey, this is important to memorize, so I'm going to reinforce that 00:44:07.040 |
number. So as every step it takes, right, in a way, internally the model just, like, 00:44:13.000 |
keeps backpropagating against the state and think, hey, do I need to reinforce this floating 00:44:18.120 |
point value? If I stop reinforcing, it will eventually decay, but if I want to reinforce 00:44:23.640 |
it, I can just, like, keep increasing the value and then keeping it within bounds. And 00:44:29.080 |
that's how it keeps -- kind of, like, by default, it will slowly forget, but if it thinks it's 00:44:34.280 |
important, it will keep trying to remember it over larger context time. To be clear, 00:44:38.400 |
this is -- this part is really theorycrafting, because even for RWKV, the highest we ever 00:44:44.600 |
push a model to is 32K, and we are now experimenting at 64K. This theorycrafting is supposed to, 00:44:50.400 |
like, extend to, like, 1 million token per se. So it's something that we definitely need 00:44:58.440 |
to test. But the idea behind decay is so that by default, things will expire, and so the 00:45:04.320 |
model is able to clear up space to memorize new things, and it's able to also decide for 00:45:09.320 |
itself to keep things in memory. Not much different from how they describe it in Titans. 00:45:19.640 |
>> Eugene, I have a question. So when you're referring to by default it decays, so in terms 00:45:26.760 |
of the surprise here, can we assume that the surprise by default is low for most of the tokens? 00:45:36.200 |
>> I think view it as -- let's just say -- let's just view it as a RAG scenario. A RAG 00:45:43.840 |
scenario. Let's just say your company is the most generic company on Earth, and then I 00:45:49.080 |
just put your company document there. There is no surprise. Then it just moves on. But 00:45:54.960 |
let's just say your company has some very proprietary, never-heard-before stuff. That 00:45:59.720 |
surprise will then be what is stored. So it's about the difference in information in the 00:46:08.200 |
fixed model weights compared to this floating-point state. Does that make sense? 00:46:17.600 |
>> Is there another -- I'm trying to read the questions from the chat. 00:46:34.480 |
>> Yeah. There's another one from Cosmin earlier about our needle in the haystack's interesting 00:46:42.720 |
tasks. I mean, I think they could be in real use cases. There's probably a lot where they're 00:46:50.800 |
not. But depending on your use case, I could see how if you just want to throw a bunch 00:46:58.560 |
of tokens into the model and not worry about the order or anything, that it would be useful. 00:47:07.560 |
>> I think it's interesting also academically to just understand the maximum limit the model 00:47:15.840 |
can memorize in worst-case scenario. That's the way I view needle in the haystack. In 00:47:21.320 |
practical scenarios, that was one of the challenges with RWKV benchmarking this as well, and 00:47:29.800 |
the same thing for Titans, is that if, let's say, we train on the Harry Potter books and you 00:47:35.440 |
just put the whole Harry Potter book as the RAG context, there's no surprise there. 00:47:43.640 |
And essentially, it can pass the RAG test with, what, 300k context length. But that's 00:47:50.160 |
not a correct test, per se. So needle in the haystack is meant to represent the worst-case 00:47:54.000 |
scenario. That's how I feel. And a lot of companies are a lot more generic than they 00:47:59.960 |
think they are. Also, I think in NeurIPS, the joint presentation that we did with Dan 00:48:08.120 |
Fu, both him and I agree that with other techniques, such as repeating twice, that works well for 00:48:18.020 |
both Mamba and RWKV, is that we may want to re-evaluate how we benchmark these things, 00:48:23.520 |
per se, when it comes to practical RAG situations. Because if, let's say, the inference 00:48:29.040 |
cost for Mamba and RWKV is 100,000x cheaper, and its RAG performance triples or quadruples 00:48:39.280 |
just by repeating the context twice, so then we just repeat the context twice. That was 00:48:46.200 |
one of the arguments. And I kind of agree, but it's also a very different test. 00:48:52.360 |
I see there's another question, Eugene, that you'd be good at answering, is, what is the gating mechanism in RNNs that they're referring to? 00:49:11.800 |
I don't know. Aditya, do you want to clarify? 00:49:15.840 |
Yeah, this sentence is in the paper right in front of us. It says Dao and Gu, 2024, 00:49:22.120 |
or Orvieto et al., 2023. It's a comparison to the weight decay mechanism. So I just have 00:49:30.640 |
no context on what RNNs normally do. That's relatable. 00:49:35.440 |
Oh, okay. Later in the section, we show that the weight decay mechanism is closely related 00:49:52.800 |
to the gating mechanism in RNNs, Dao and Gu, citing, I presume, the Mamba authors? Yeah, okay, 00:50:02.240 |
the state space model, yeah. It's, I think, the state space model. This would be similar 00:50:08.080 |
to the weight decay that I explained earlier, which is basically, by default, it's forgetting. 00:50:19.080 |
My understanding of Mamba is that they run these RNNs without gates, so that they can 00:50:25.400 |
run something like Fast Fourier Transform or some parallel scan on GPUs and run it fast. 00:50:31.200 |
And then they have some multiplicative process on each token. And when you do multiplications, 00:50:37.840 |
basically, you can use them as gates to how much of the signal you propagate. So they 00:50:42.680 |
add the nonlinearity and gating at token level, while old school LSTM, they had memory gate 00:50:51.160 |
for get gate, input gate at each step, and they were running it one by one. So it might 00:50:58.240 |
be, that might be one thing, like this token level multiplicative things that people have 00:51:05.960 |
in state space models. Yeah, if it's about that, then it's really more about how we restructure. 00:51:13.020 |
So both Mamba and RWKV and Titans, is that with the way we restructure the formulation, 00:51:18.880 |
we do not need to wait for one token compared to another, unlike the old LSTM. So yeah, 00:51:27.160 |
I think that makes sense. Yeah, actually I think your explanation makes more sense. Basically, 00:51:31.440 |
if you contrast it with the old LSTM RNNs, all the newer gates, even though they are designed 00:51:36.200 |
differently, are designed in a way that doesn't have this bottleneck, where you need to wait 00:51:41.880 |
for one token after another. Whether it is through math hacks, which is what the state space 00:51:50.040 |
model did, which is really impressive, honestly, that kind of math. But that's what they are 00:51:56.760 |
good at. And in our case, it's really more of like how we optimize things in the CUDA 00:52:01.240 |
forward. It achieves the same result. It's able to train in parallel, unlike old RNNs. 00:52:10.720 |
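To make that point about token-level multiplicative gating concrete, here is a toy sketch of the kind of recurrence these models use: the hidden state is updated by an elementwise gate per token with no per-step nonlinearity on the state, which is what lets the whole sequence be computed with parallel scans rather than strictly one token after another. This is a generic illustration, not Mamba's, RWKV's, or Titans' exact formulation.

```python
import numpy as np

def gated_linear_recurrence(a, b):
    """h_t = a_t * h_{t-1} + b_t, with per-token gates a_t in (0, 1).
    Written as a plain loop here; because the recurrence is linear in h,
    it can also be computed with an associative scan, which is what makes
    it parallelizable, unlike an LSTM's per-step nonlinear gates."""
    h = np.zeros_like(b[0])
    out = []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t            # elementwise gate, no nonlinearity applied to h
        out.append(h)
    return np.stack(out)

T, d = 16, 4
a = 1 / (1 + np.exp(-np.random.randn(T, d)))   # sigmoid gates per token and channel
b = np.random.randn(T, d)
states = gated_linear_recurrence(a, b)
print(states.shape)                            # (16, 4)
```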
And I'm quite sure Google has their own optimization when it comes to training as well. 00:52:17.680 |
I have another question, Eugene, for you. I don't know how to compare baselines with 00:52:24.840 |
this model. Are you impressed with one point perplexity win on wiki or the other wins? 00:52:31.640 |
Or it's kind of any new model usually shows that kind of win. So to me, the win seem large. 00:52:40.440 |
So I think, is this something super strong or it's OK, not great? 00:52:48.480 |
To be honest, I classify this in promising, but we need to test further, because even 00:52:54.120 |
in our experience for RWKV, anything below 1.5B may not necessarily hold at the 7B scale. 00:53:09.420 |
So we had reverted experimental changes that we made on 0.5B models, which is kind of what's 00:53:15.480 |
being tested here. Here it's 340 to 760 million parameters, where the perplexity loss was great. 00:53:21.680 |
It dropped much lower. And then when we scale it to 1.5B, it sucked. 00:53:29.940 |
So it's promising, but I want to test to find out more, because I think the true test is 00:53:37.880 |
testing on 1.5B and then 7B, which, to be honest, I'm quite sure the Google folks have 00:53:45.240 |
done it. They are not compute bound. It takes more effort for them to write this paper than to run those experiments. 00:53:52.000 |
Yeah, I'm a bit skeptical, because there's some gossip that Googlers aren't allowed to 00:53:57.580 |
publish super impactful stuff. So it's interesting. 00:54:03.520 |
Yeah. I also would like to know what was your results for the larger models? 00:54:10.200 |
Yeah, I mean, I guess that's one interesting thing, is that back in 2018, the attention 00:54:17.480 |
is all you need days, Google researchers could publish anything they wanted, because there 00:54:22.640 |
was really no competitive advantage they were giving away. But these days, you wonder, if 00:54:30.360 |
they have a really big breakthrough, are they going to publish that, or are they just going to keep it in-house? 00:54:38.600 |
I have another question. Sorry, if someone could explain. The memory module is described 00:54:44.140 |
as a meta-in-context model. In quasi-layman's terms, what would a meta-in-context model 00:54:49.920 |
mean? There's like a sort of a parallel small model running on that, as if it was trained on the fly? 00:54:58.100 |
So that's the part where I explained every time the tokens look different from what you 00:55:03.840 |
expected, it does the back propagation to the memory modules. So you can think of it 00:55:08.360 |
as a-- I'm oversimplifying the math calling back propagation. You can think of it as training 00:55:14.880 |
the memory modules. It's inefficient to do it that way, because we use matrix multiplication 00:55:21.680 |
math dedicated for this. But in theory, you could implement it as your standard backprop 00:55:28.520 |
gradient descent, at least in RWKV's case. I'd have to double-check for Google's case. 00:55:35.780 |
But the important thing is it runs at inference time. So it's like it always backpropagates 00:55:42.760 |
something, which is very different from other models that are trained in a big pre-training 00:55:48.360 |
run, and then nothing changes. You just run inference. So it's kind of meta-learning that 00:55:54.280 |
at inference time, it still does some additional update. 00:55:59.280 |
Yeah, the long-term hope and goal, if we can get this process to be stable, and the memory 00:56:08.320 |
module is, let's just say, a gigabyte in size in memory, this is what will represent short- 00:56:15.840 |
to mid-term memories for an AGI kind of model. The issue for any super long-context training 00:56:26.320 |
we're talking about in AGI scale is that we don't really have the means to really figure 00:56:35.380 |
out how to train this memory module in a structured, guided way. And right now, the hope is that 00:56:43.600 |
if we train it, let's just say, at 4, 8k or 32k, it generalizes to 128k. And if, let's 00:56:51.520 |
say, we train at 1 mil, it generalizes to 10 mil. So if we train it at 10 mil, it generalizes 00:56:57.800 |
to a larger context length and a much longer context length. 00:57:01.000 |
The problem with this approach is, even for us right now, and this is something that maybe 00:57:06.480 |
Titan may have more tests on, is that when we train on 512, it generalizes up to 4k, 00:57:12.760 |
and then it dies out there. Then, if we train up to 4k, it generalizes up to 16k. So the 00:57:21.560 |
generalization doesn't go on to infinitum, unlike humans, arguably. But then again, maybe 00:57:28.700 |
that's why we go senile at the age of 100. Maybe that's the reason. 00:57:38.880 |
So I see we're at time. I don't know, Swix, I'll turn the floor back over to you. Any 00:57:51.040 |
comments or thoughts about next week or announcements? 00:57:56.120 |
Ishan is doing DQ2 on his spreadsheet. Is that Ishan? I don't know. He likes spreadsheets. 00:58:05.760 |
That's all I know. So that will be next week. Okay, well, yeah, I mean. Okay, I think that 00:58:15.120 |
was a yes. Cool. Yeah, we have run out of our context length for this session. Thank 00:58:21.700 |
you, Eric, for a great presentation. Yeah, very topical paper. We are, I've just been 00:58:27.480 |
chatting with Vibhu, and we're basically kind of thinking about, you know, doing the second 00:58:32.880 |
paper club and sometimes somewhat splitting between timeless papers or timeless survey 00:58:41.440 |
papers and then hot individual papers is kind of like the split that we're thinking about. 00:58:47.040 |
And then also maybe doing it at a different time. So it's not like during the day, during 00:58:51.040 |
the workday for most people in the US. So yeah, people are interested in timeless papers 00:58:55.760 |
and then also hot papers. So I think those are the two spiky things that maybe we can 00:58:59.840 |
have a different vibe for that as well. Yeah, let's discuss in Discord, but otherwise have