
Google Titans: Learning to Memorize at Test Time


Whisper Transcript | Transcript Only Page

00:00:00.000 | Welcome, everyone. Let me share my window here.
00:00:19.080 | So pretty, by the way.
00:00:21.080 | Yeah. I was playing around with this new presentation tool, gamma.ai. And I can't go back to PowerPoint
00:00:29.920 | ever again.
00:00:31.320 | That's huge praise. Yeah. A friend of mine actually quit to work for this company. And
00:00:37.760 | I was like, another AI slides company? You know, these never work. Shows what I know.
00:00:44.440 | It's really impressive. So I'm not paid by the company in any way, but I do recommend
00:00:49.600 | people try.
00:00:52.800 | So this week, we're going to talk about a new paper out of Google Research, Titans.
00:01:01.840 | The main thing that this new architecture presents is giving the model some memory,
00:01:11.240 | especially at inference time, so that it can leverage that to make better inferences and,
00:01:22.120 | you know, improve benchmark scores.
00:01:25.600 | And I'm generally not going to be looking at the chat while I present. There's a lot
00:01:31.920 | to go through, but I'm going to leave plenty of time at the end for discussion where we
00:01:36.560 | can go back over questions.
00:01:41.520 | So yeah. So do feel free to drop posts in the chat, and we will get to them.
00:01:55.200 | So let's see. Okay. I'm trying to figure out how to -- wait a minute. There we go. Okay.
00:02:15.420 | So what are some of the problems that this Titans paper is trying to address? The first
00:02:22.760 | of which is the quadratic complexity of attention.
00:02:30.600 | And this has been one of the things that typically was, up until, say, the last year or so, really
00:02:39.680 | restricting the context window size of models.
00:02:44.640 | So if I remember correctly, when GPT-4 came out, it first had a context window of 8K or 32K, something
00:02:54.080 | like that, nothing compared to the 1 million token context window of Gemini.
00:03:02.400 | And the reason for that was that for -- in typical attention, every token is compared
00:03:10.880 | to every other token, so that if you, say, have a context window size of 1,000, then
00:03:18.120 | the complexity is 1,000 squared. And if you have a, you know, window size of 10,000, then
00:03:28.800 | it's 10,000 squared.
00:03:30.040 | And so in order to increase the size of your context window, it becomes, like, very computationally
00:03:39.200 | expensive very quickly. And there are ways around this that are mentioned in the paper,
00:03:44.200 | but that's one of the problems with, let's say, classical transformers, which leads right
00:03:51.200 | into the limited context window, again, based on the quadratic aspect.
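Here is a minimal plain-Python sketch of why standard attention is quadratic. Every query token is scored against every key token, so a context of n tokens produces an n by n score table. The function names are illustrative. Real implementations use batched matrix kernels, not Python loops.

```python
import math
from typing import List, Tuple

Vec = List[float]


def naive_attention(q: List[Vec], k: List[Vec], v: List[Vec]) -> Tuple[List[Vec], List[List[float]]]:
    """Single-head softmax attention; returns (outputs, score table)."""
    d = len(q[0])
    # One score per (query, key) pair: len(q) * len(k) entries in total.
    scores = [[sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k] for qi in q]
    outputs = []
    for row in scores:
        top = max(row)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in row]
        total = sum(weights)
        outputs.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                        for c in range(len(v[0]))])
    return outputs, scores


def score_entries(context_len: int) -> int:
    """Number of query-key comparisons for a context of `context_len` tokens."""
    return context_len * context_len


print(score_entries(1_000))   # 1000000
print(score_entries(10_000))  # 100000000, i.e. 100x the work for 10x the context
```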
00:04:00.240 | And then the other thing is that the limited context window creates a bottleneck for
00:04:08.640 | very long sequences of data. So transformers are great, and, you know, there's been tons
00:04:18.800 | of advances in them since the "Attention Is All You Need" paper, but, you know, there are
00:04:27.400 | limitations in the architecture.
00:04:29.760 | And so what does Titans do to try to address that at a very high level? The first thing
00:04:38.240 | it does is to emulate human memory, meaning that human memory has different -- you can
00:04:48.200 | think of them as modules or types of memory. There's short-term memory and long-term memory.
00:04:56.560 | However, in the typical transformer model, as we'll see, those have, like, short-term
00:05:05.240 | memory because they have access to the tokens that they're currently processing, but anything
00:05:11.320 | that happened before that, they don't have memory of that.
00:05:17.240 | And as you do AI engineering, you probably are well aware that every time you make an
00:05:25.480 | API call to OpenAI or Claude, it's, like, starting all over again. And then the second part is
00:05:36.520 | the enhanced processing. So the memory allows very long, like, extremely long sequences
00:05:43.040 | to be handled, so that -- and we'll see that as the tokens are processed, they -- the model
00:05:52.400 | keeps a representation of just the most salient or important, let's say, facts. It's a representation
00:06:01.440 | of facts, but let's say facts from the tokens it's seen so far so that it can use those
00:06:07.280 | to improve predictions on future tokens.
00:06:14.240 | So there was some work by other groups, and this group as well. So this isn't, like, a
00:06:24.760 | sudden thing. This paper, Test-Time Training with Self-Supervision, goes back to 2020. So there's
00:06:36.120 | a few things that this paper introduced. The first is turning the unlabeled test data
00:06:45.760 | into a self-supervised learning problem so that the model can update its own parameters before
00:06:55.280 | making a prediction. And they noticed performance improvements with this. So that's one paper
00:07:03.960 | that previously came out. And I didn't -- I'm going to just mention two papers. I didn't
00:07:12.400 | read either of them in depth. I just kind of read the abstract and got the gist of them.
00:07:19.220 | But if you're very interested in this, both of these are good sources. And then the other
00:07:25.320 | one is Learning to (Learn at Test Time): RNNs with Expressive Hidden States.
00:07:34.720 | So this is -- and we'll see this in the Titans paper as well, that it does bring an RNN or
00:07:45.960 | recurrent neural network technique back into transformers where there is a state inside
00:07:56.960 | the -- well, it's the hidden state. And as the tokens are processed, that hidden state
00:08:04.360 | is updated and then allowed to affect future outputs. So this is the paper that brought
00:08:17.560 | that into the research community. You can also see the -- one of the main points from
00:08:25.440 | this paper is the linear complexity. So using this to get around the quadratic complexity
00:08:32.080 | of the typical transformer. So as I mentioned, Titans is inspired by human memory. There's
00:08:48.400 | core or short-term memory. So in the Titans models, this is the equivalent of just a normal
00:08:59.320 | attention mechanism where you have the queries, keys, and values, you know, all interacting
00:09:11.100 | in the transformer component. So this is the short-term memory. So they don't really introduce
00:09:21.240 | anything about short-term memory. They just reuse what is already out there. Long-term
00:09:26.920 | memory is where they learn to memorize historical context, encode that context into abstractions
00:09:37.600 | so that it can improve future tokens. This is the main thing that they talk about in
00:09:43.760 | the paper is this long-term memory module and how exactly it operates. And then the
00:09:52.040 | persistent memory is not a, like, input-dependent thing. So this is -- you can think of this
00:10:03.560 | as, like, a set of knowledge or a set of rules that always exists in the model. It's hard-coded
00:10:12.840 | in there. And it's always brought into the -- into the sequence. We'll see how that happens.
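In the paper, persistent memory is a small set of learnable but input-independent vectors that are prepended to the sequence. The sketch below is a minimal version of that idea, not the authors' code. It assumes token vectors are plain lists and uses a hypothetical `PersistentMemory` class. The prefix is fixed once training ends, so every input receives the same one.

```python
from typing import List

Vector = List[float]


class PersistentMemory:
    """Hypothetical sketch: N_p learnable, input-independent vectors.

    They would be learned during training like any other weights and
    then left untouched at test time.
    """

    def __init__(self, vectors: List[Vector]):
        self.vectors = [list(v) for v in vectors]

    def prepend(self, sequence: List[Vector]) -> List[Vector]:
        # The same prefix is added no matter what the input is.
        return [list(v) for v in self.vectors] + [list(x) for x in sequence]


persistent = PersistentMemory([[0.1, 0.2], [0.3, 0.4]])
a = persistent.prepend([[1.0, 0.0]])
b = persistent.prepend([[0.0, 1.0], [5.0, 5.0]])
print(a[:2] == b[:2])  # True: both inputs start with the same persistent tokens
```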
00:10:20.780 | They don't talk a whole lot about persistent memory. So it's -- at least for me, it was
00:10:26.880 | a little unclear where -- how exactly this is created. But it's also not the main point
00:10:34.520 | of it. It's just kind of a set of parameters that are always fed into the sequence that's
00:10:43.080 | going to be predicted against. Okay. So let's take some time digging into
00:10:53.940 | this long-term memory module. So this is the -- kind of the main breakthrough that this
00:11:05.240 | paper is presenting. There's a few different things you can see here as far as what this
00:11:12.960 | module does. The first is recurrent processing. So like I mentioned, it maintains a state
00:11:22.660 | as it's doing inference. And then that part of the output is then fed back into that state
00:11:33.220 | so that it can continue to capture new information as it goes along. This helps it to kind of
00:11:41.700 | learn the task that it's doing. And also to do those needle in a haystack things where
00:11:49.740 | maybe some information early on in the sequence is relevant to something much later. So it
00:11:56.720 | can keep that in memory. It's also a memory component. So this is the
00:12:03.300 | thing that is responsible for generating representations for, you know, the information that is coming
00:12:13.800 | through it at inference time. And then let's see. The last bullet point about sequential
00:12:23.220 | data handling. Yeah. So as I've been mentioning, like, we're assuming this is all happening
00:12:33.300 | in a sequence of tokens in the case of a large language model where it's processing one token
00:12:41.820 | after another. It's not all at once. That's pretty -- pretty taken for granted for any
00:12:51.680 | large language model. So let's keep going. So we'll dig into a few different aspects
00:13:02.240 | now of the long-term memory module. So here you can see some points. Let me pull out a
00:13:15.960 | couple of them. One is these weight adjustments, which is, I think, one of the most interesting
00:13:21.440 | things about the Titans architecture. So for, like, almost all LLMs, or at least the
00:13:29.500 | ones that I'm aware of, like GPT-4, Claude Sonnet, whatever, they do a lot of
00:13:39.040 | pre-training. They do, you know, instruction tuning, all of that. But once they're done
00:13:45.680 | with that, then the weights are just the weights. And even if you think of, like, the DeepSeek
00:13:51.320 | release last week, like, it was just a bunch of weights. So these don't -- these
00:14:00.360 | don't adjust in memory, no matter what you put through the model. Whereas this Titans
00:14:07.160 | model, at least in the context of a single inference, can adjust weights on the long-term
00:14:14.080 | memory module, which I think makes it, like, a very interesting new approach.
00:14:23.600 | And for the continuous adaptation, so that kind of plays off the weight adjustments.
00:14:32.320 | So that's something that, you know, as it's going through a long sequence of perhaps even
00:14:38.440 | millions of tokens, it can continue to actually learn about, like, the data that it's processing.
00:14:51.240 | And that's another thing that they really emphasize in the paper, is creating a model
00:14:57.160 | of -- when I say model, I don't mean, like, a machine learning model. I mean, like, an
00:15:01.960 | abstraction for learning. And what does that mean for a model to be able to learn?
00:15:12.400 | So we've talked about how the model updates its weights based on the data that comes through.
00:15:23.460 | And so how does the model know, like, what is interesting, like, what is worth keeping,
00:15:31.700 | and what isn't? Because it doesn't have enough parameters for long sequences to capture everything
00:15:39.040 | that comes into it. So it's going to have to compress whatever data is coming through.
00:15:46.540 | So how does it decide what to remember and what not to? This surprise mechanism
00:15:54.100 | is the main way it does that. For humans, information that's considered surprising is
00:16:06.520 | more memorable. And so they took the leap that for a model, it's also going to be more
00:16:13.700 | memorable, or more important to remember. I'm just looking through the slides. So, okay.
00:16:25.300 | So the point is, well, how do we know if information is important or not? The main way they decide
00:16:34.260 | that is the gradient of the neural network with respect to the input data. So you can
00:16:40.380 | see that -- hopefully you can see my mouse. But you can see that down here. The gradient
00:16:45.840 | of this incoming input data with respect to the memory from the previous time step gives
00:16:56.140 | us the amount of surprise. And it -- you can see it gives preferential treatment to this
00:17:04.620 | surprising information. The more surprising it is, the more likely it is to be stored
00:17:10.140 | in memory. So you can imagine for, like, a long sequence of maybe documents, maybe there's,
00:17:18.700 | I don't know, 100 documents that are all legal briefings. And then it comes to something
00:17:27.780 | that is maybe a bill of lading or, you know, a product description or something like that.
00:17:36.740 | So then it's going to be like, oh, this is completely different. I should remember something
00:17:40.700 | about this. So it's going to prioritize keeping what it considers surprising or information
00:17:50.640 | that it hasn't seen before. And then it stores this in key value pairs in that long-term
00:18:01.180 | memory module. So in addition to surprise, there's also forgetting.
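The surprise-driven update can be sketched with the simplest possible memory: a linear map M trained online on the associative loss ||M k - v||^2. Under that loss, the momentary surprise is the gradient 2 (M k - v) k^T, and the update is M <- M - theta * gradient. This is a simplified sketch under stated assumptions. The paper's memory is a deeper MLP, its keys and values come from learned projections of the input, and its theta_t is data-dependent rather than the constant used here. Momentum and forgetting are left out.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, x: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def surprise_gradient(m: Matrix, k: Vector, v: Vector) -> Matrix:
    """Gradient of ||M k - v||^2 with respect to M: 2 (M k - v) k^T."""
    err = [p - t for p, t in zip(matvec(m, k), v)]
    return [[2.0 * e * kj for kj in k] for e in err]


def surprise_update(m: Matrix, k: Vector, v: Vector, theta: float) -> Matrix:
    """One step M_t = M_{t-1} - theta * grad (no momentum or forgetting)."""
    g = surprise_gradient(m, k, v)
    return [[mij - theta * gij for mij, gij in zip(mrow, grow)] for mrow, grow in zip(m, g)]


def frobenius(g: Matrix) -> float:
    """Size of the surprise signal."""
    return sum(x * x for row in g for x in row) ** 0.5


identity = [[1.0, 0.0], [0.0, 1.0]]
key = [1.0, 0.0]
print(frobenius(surprise_gradient(identity, key, [1.0, 0.0])))  # 0.0: expected value, no surprise
print(surprise_update(identity, key, [0.0, 1.0], theta=0.25))   # [[0.5, 0.0], [0.5, 1.0]]
```

Take a key that the memory already maps to its value. It yields a zero gradient, so nothing is written. A mismatched value yields a non-zero gradient that pulls M k toward v.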
00:18:14.820 | So I should have mentioned in the last slide that this theta parameter here controls the
00:18:26.340 | amount of, like, kind of the impact of surprise. So it's a tunable parameter that can be increased
00:18:37.220 | or decreased to either remember more surprising information or forget it. The momentum and
00:18:46.860 | dynamic forgetting mean that it can forget things so that it discards irrelevant
00:18:55.220 | information as it goes along. So there's this, you can see this one minus alpha where it is
00:19:05.060 | slowly degrading the older memory. So as things get older and older, they fade out. I heard
00:19:16.860 | someone come off mute. Is there a question? >> Yeah. That M_t parameter, is that the output
00:19:25.020 | of the memory module or the model itself? >> That is essentially the state of the memory
00:19:33.960 | module. >> Okay. And so the surprise has the same dimensions,
00:19:41.940 | the same form as the memory itself? >> Why don't we look in the paper after the
00:19:52.260 | presentation and we can, or if you want to look at it. I can't answer that right off
00:20:00.220 | the bat. >> I would assume so because it's just a simple
00:20:04.060 | addition to it. And so kind of curious, I guess the origin of this is that, you know,
00:20:12.080 | they show different structures for how the memory was included and it's, like, added
00:20:17.440 | in different parts of the architecture stack, if you will. I'm kind of curious if that would
00:20:25.920 | stay the same across their forms of the architecture or if it changes depending on where the memory
00:20:35.200 | is included at. >> Yeah. That's an interesting question. I'm
00:20:42.640 | going to go through the different architectures they introduce. So hopefully, I don't know
00:20:50.860 | if that will answer the question, but at least it will give us enough context for discussion.
00:20:55.920 | I'm going to leave plenty of time for discussion, so this would be a good thing to talk about.
00:21:08.560 | And then it also has this momentum term, which combines past surprise with a decay factor.
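Combining surprise, momentum, and forgetting gives the paper's update rule:

- S_t = eta_t * S_{t-1} - theta_t * grad l(M_{t-1}; x_t)
- M_t = (1 - alpha_t) * M_{t-1} + S_t

Below is a minimal sketch on a flattened parameter list. It uses a hypothetical `titans_step` helper and holds eta, theta, and alpha constant, whereas the paper makes them data-dependent.

```python
from typing import List, Tuple


def titans_step(memory: List[float], surprise: List[float], grad: List[float],
                eta: float, theta: float, alpha: float) -> Tuple[List[float], List[float]]:
    """One update with momentum and forgetting, on flattened memory parameters.

    surprise_t = eta * surprise_{t-1} - theta * grad      # past + momentary surprise
    memory_t   = (1 - alpha) * memory_{t-1} + surprise_t  # alpha = forgetting gate
    """
    new_surprise = [eta * s - theta * g for s, g in zip(surprise, grad)]
    new_memory = [(1.0 - alpha) * m + s for m, s in zip(memory, new_surprise)]
    return new_memory, new_surprise


# One surprising token (non-zero gradient), then nothing surprising.
memory, surprise = [0.0], [0.0]
trace = []
for grad in [[-1.0]] + [[0.0]] * 5:
    memory, surprise = titans_step(memory, surprise, grad, eta=0.5, theta=1.0, alpha=0.1)
    trace.append(round(memory[0], 4))
print(trace)
```

The trace keeps rising for a few steps after the surprising token, because momentum carries the surprise forward. It then decays, because alpha keeps shrinking the stored value. Setting eta = 0 and alpha = 0 recovers the plain surprise update.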
00:21:17.300 | So that's this eta t and I guess the previous surprise. So somehow this provides a momentum
00:21:26.500 | mechanism to the model. Okay. So those are kind of the components that
00:21:39.920 | go into a Titans model. Now, let's look at the different ways they combine them because
00:21:47.080 | there was this open question of, like, okay, we have this module that we think would be
00:21:54.600 | useful for the model. Like, let's see what's the best way to then incorporate that. And
00:22:06.040 | they come up with three different options. Memory as context, memory as layer, and memory
00:22:14.380 | as gated branch. And each one of these has its strengths and weaknesses. So let's first
00:22:23.600 | look at the memory as context. And I'll just walk through this diagram a little bit. So
00:22:32.900 | we can start looking over here at the -- this is the input sequence. So this is essentially
00:22:39.060 | your prompt or what you are sending into the model. This is then used as a key or a query
00:22:48.200 | -- I guess a query into the memory module that then returns, like, the most relevant
00:23:01.360 | information from the memory as a sequence of tokens. And then at the same time step,
00:23:13.120 | you can see this is the persistent memory. So this is kind of the hard-coded information
00:23:17.860 | that always adds the same tokens to the sequence. So you can see that before
00:23:25.960 | we go to attention, we're adding these extra sequences. So in my previous example of, like,
00:23:36.800 | legal documents and maybe a product description, the retrieval could say, like, okay, this
00:23:46.920 | sequence looks like -- or this next part of the sequence looks like a legal document.
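The memory-as-context flow walked through here can be sketched end to end. This is a toy under explicit assumptions, not the authors' code:

- A hypothetical linear memory stands in for the paper's MLP memory.
- All query, key, and value projections are the identity.
- Attention is unmasked softmax attention over each augmented segment.
- The final combination uses an elementwise product as a stand-in for the paper's combination operator.

```python
import math
from typing import List

Vec = List[float]


def attend(tokens: List[Vec]) -> List[Vec]:
    """Unmasked softmax self-attention with identity projections (quadratic in len(tokens))."""
    d = len(tokens[0])
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        top = max(scores)
        w = [math.exp(s - top) for s in scores]
        z = sum(w)
        out.append([sum(wi * t[c] for wi, t in zip(w, tokens)) / z for c in range(d)])
    return out


class LinearMemory:
    """Hypothetical linear associative memory; the paper uses a deeper MLP."""

    def __init__(self, dim: int, theta: float = 0.1):
        self.m = [[0.0] * dim for _ in range(dim)]
        self.theta = theta

    def retrieve(self, x: Vec) -> Vec:
        # Reading is a forward pass only; the weights are not changed.
        return [sum(a * b for a, b in zip(row, x)) for row in self.m]

    def update(self, k: Vec, v: Vec) -> None:
        # One gradient step on ||M k - v||^2 (surprise only; no momentum/forgetting).
        err = [p - t for p, t in zip(self.retrieve(k), v)]
        for i, e in enumerate(err):
            for j, kj in enumerate(k):
                self.m[i][j] -= self.theta * 2.0 * e * kj


def memory_as_context(sequence: List[Vec], persistent: List[Vec],
                      memory: LinearMemory, segment_len: int) -> List[Vec]:
    outputs: List[Vec] = []
    for start in range(0, len(sequence), segment_len):
        segment = sequence[start:start + segment_len]
        # 1. Query long-term memory with the current segment.
        history = [memory.retrieve(x) for x in segment]
        # 2. Attend over [persistent || retrieved history || segment].
        y = attend(persistent + history + segment)[-len(segment):]
        # 3. Write the attention outputs back into memory (keys = values = y here).
        for t in y:
            memory.update(t, t)
        # 4. Combine y with a fresh memory read (elementwise product as a stand-in).
        outputs.extend([[a * b for a, b in zip(t, memory.retrieve(t))] for t in y])
    return outputs


mem = LinearMemory(dim=2)
seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5], [0.2, 0.8]]
out = memory_as_context(seq, persistent=[[0.1, 0.1]], memory=mem, segment_len=2)
print(len(out))  # 5: one output per input token
```

Each segment acts as its own small context window. What carries information across segments is the memory weights, which grow as segments are written in.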
00:23:52.400 | So let's grab any information we have about the legal documents, put it here. And then
00:23:59.160 | this persistent memory might just be about how to generally handle incoming requests,
00:24:05.960 | something like that. It goes through attention. So this is the quadratic thing where every
00:24:12.160 | token pays attention to every other token. It then, based on the output, sends an update
00:24:20.000 | to the memory, so updates the actual weights here with whatever relevant information came
00:24:27.640 | out. Then I'm not sure what this is. Some kind of -- maybe an XOR combines these and
00:24:38.520 | then gives you the final output. And so this is what the attention masks look
00:24:51.160 | like for memory as context. You can see that each one of these is treated as, like, a small
00:25:01.440 | context window. So it's, like, kind of little batches of tokens being paid attention to
00:25:14.720 | together. Then we add the long-term memory onto that. And then finally the persistent
00:25:22.840 | memory. And so this kind of gives you an idea of, like, inside the attention mechanism,
00:25:31.040 | what exactly it's seeing. And next, memory as gate. So this is a different
00:25:44.560 | approach. It does not, like, query the memory before sending it into the attention mechanism.
00:25:57.200 | So it sends a query to the memory at the same time as going through the attention mechanism
00:26:06.280 | and then combines them afterwards. And so the memory -- like, whatever is stored in
00:26:13.880 | the memory does not go through the attention mechanism at all.
00:26:24.080 | Yeah. So this is -- this uses the memory as a gating mechanism. I guess right here. Which
00:26:38.080 | I guess would allow -- I'm assuming allow certain tokens through and some tokens not.
00:26:45.960 | We can discuss this in the discussion about exactly how this works.
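Here is a sketch of the memory-as-gate layout. Sliding-window attention and the memory run as parallel branches over the same input, which has the persistent tokens prepended, and the memory output gates the attention output. The sketch rests on three assumptions:

- The window is causal: each token sees itself and the previous `window - 1` tokens.
- The memory branch is any callable, shown here as a hypothetical stand-in.
- The gate is a plain per-channel sigmoid, not the paper's normalized gating.

```python
import math
from typing import Callable, List

Vec = List[float]


def sliding_window_mask(n: int, window: int) -> List[List[bool]]:
    """mask[i][j] is True when token i may attend to token j (causal band)."""
    return [[0 <= i - j < window for j in range(n)] for i in range(n)]


def masked_attention(x: List[Vec], mask: List[List[bool]]) -> List[Vec]:
    d = len(x[0])
    out = []
    for i, q in enumerate(x):
        idx = [j for j in range(len(x)) if mask[i][j]]  # never empty: diagonal is allowed
        s = [sum(a * b for a, b in zip(q, x[j])) / math.sqrt(d) for j in idx]
        top = max(s)
        w = [math.exp(v - top) for v in s]
        z = sum(w)
        out.append([sum(wk * x[j][c] for wk, j in zip(w, idx)) / z for c in range(d)])
    return out


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def memory_as_gate(x: List[Vec], persistent: List[Vec],
                   memory_fn: Callable[[Vec], Vec], window: int) -> List[Vec]:
    seq = persistent + x  # persistent tokens prefix both branches
    y = masked_attention(seq, sliding_window_mask(len(seq), window))
    mem = [memory_fn(t) for t in seq]  # parallel branch; never passes through attention
    # Gate: the memory output decides, per channel, how much of y passes through.
    return [[yi * sigmoid(mi) for yi, mi in zip(yt, mt)] for yt, mt in zip(y, mem)]


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = memory_as_gate(x, persistent=[[0.5, 0.5]],
                     memory_fn=lambda t: [2.0 * v for v in t], window=2)
print(sliding_window_mask(4, 2)[3])  # [False, False, True, True]
```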
00:26:55.080 | And then here's the attention masks for the memory as gate. You can see it's quite different.
00:27:04.320 | Here this is a sliding window that, you know, each token pays attention to a certain number
00:27:12.980 | of tokens on each side of it. You can see the long-term memory, how next to the input
00:27:21.600 | sequence the memory is strong. And then it gradually fades out as the tokens are, you
00:27:33.720 | know, further in the future, if you will. And then finally, the persistent memory just,
00:27:40.800 | like, adds to the very front of the sequence. This is what the attention masks look like
00:27:48.080 | in the memory as gate. And then finally, memory as layer. So this
00:27:57.720 | is in some ways similar to what some transformer models do, where they have the attention mechanism
00:28:07.360 | and then there's a feedforward mechanism as well. So this is, like, the closest thing
00:28:14.120 | to that, where it just goes right into the memory, whatever comes out of the memory goes
00:28:22.880 | into the attention. Yeah, and it says that it can't take advantage
00:28:32.100 | of the complementary data processing of attention and the neural memory module, so there's no
00:28:39.760 | combining either before or after. It's just, like, one more step in the architecture.
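Memory as layer is plain sequential composition. The memory layer processes the persistent-prefixed sequence, and attention only sees the memory's output. That is why nothing is combined before or after. Below is a minimal sketch with hypothetical stand-in layers. A decayed running sum plays the role of a fixed-size recurrent memory. Uniform averaging over a short causal window plays the role of sliding-window attention.

```python
from typing import Callable, List

Vec = List[float]
Layer = Callable[[List[Vec]], List[Vec]]


def memory_as_layer(x: List[Vec], persistent: List[Vec],
                    memory_layer: Layer, attention_layer: Layer) -> List[Vec]:
    seq = persistent + x
    h = memory_layer(seq)       # 1. long-term memory layer runs first
    return attention_layer(h)   # 2. attention only ever sees the memory's output


def toy_memory_layer(seq: List[Vec], decay: float = 0.9) -> List[Vec]:
    # Hypothetical stand-in: a decayed running sum, i.e. a fixed-size recurrent state.
    state = [0.0] * len(seq[0])
    out = []
    for t in seq:
        state = [decay * s + v for s, v in zip(state, t)]
        out.append(list(state))
    return out


def toy_window_attention(seq: List[Vec], window: int = 2) -> List[Vec]:
    # Hypothetical stand-in: uniform attention over the last `window` positions.
    out = []
    for i in range(len(seq)):
        span = seq[max(0, i - window + 1): i + 1]
        out.append([sum(col) / len(span) for col in zip(*span)])
    return out


out = memory_as_layer([[1.0], [0.0], [0.0]], [], toy_memory_layer, toy_window_attention)
print(out)  # approximately [[1.0], [0.95], [0.855]]
```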
00:28:54.640 | And then finally, they also, in their experimentation, look at memory without attention.
00:29:02.480 | So this is essentially just a long-term memory module by itself. There's no attention mechanism.
00:29:11.000 | And they just look at, like, how does this perform? Just purely a long-term memory module
00:29:17.040 | without attention or anything. Okay. So now we're getting into the last part
00:29:26.920 | of the paper, the experimental setup and results. So they test all four of these variants and
00:29:36.680 | they test them at different sizes. And I thought one thing that was interesting is that these
00:29:42.400 | are pretty small sizes, at least compared to, you know, your modern LLM.
00:29:50.840 | So, like, a Llama 7B is considered a smaller model. And if you look at, like, a, you know,
00:30:04.480 | like a GPT-4o that probably has hundreds of billions of parameters. So this size is pretty
00:30:13.320 | tiny compared to, like, the models we use day to day.
00:30:24.400 | So they gave a big table of their results for language modeling. And they also threw
00:30:32.440 | some common sense reasoning benchmarks into here. So you can see all the benchmark names
00:30:39.240 | across the top. And then the best performing ones are highlighted. The tan highlights are
00:30:48.820 | for hybrids. And then the blue highlights are for pure or just, like, normal. Or yeah.
00:31:00.560 | So you can see that basically Titans wins at everything here versus Mamba, DeltaNet.
00:31:11.440 | This is TTT, test-time training -- or no. Anyway, this is one of the previous works on memory
00:31:19.000 | at test time. So you can see that it wins at everything. This is for the 340 million parameter model.
00:31:30.200 | And you can also see the number of tokens they train on is pretty small. Modern LLMs
00:31:38.360 | train on, like, low trillions of tokens. So this is not much data at all.
00:31:46.800 | Then -- so you can see, like, language modeling, it does quite well. And then here's their
00:31:58.240 | needle-in-a-haystack test. So for this test, they had some information early on in a sequence. And
00:32:06.200 | then a bunch of, like, filler tokens. And then some sequence that needed those very
00:32:14.760 | early tokens. Like, to understand what was happening. And so you can -- the Titans is
00:32:23.560 | the red stars here. So you can see that they are maintaining their performance quite well.
00:32:31.400 | Even out to -- here, if we go to the fine-tuning setup, 10 to the 7th. So even out to 10 million
00:32:39.000 | tokens. I mean, they did take a performance hit. But still doing much better than every
00:32:47.000 | other model. And so I think this is, for me, one of the most interesting charts that just
00:32:54.160 | shows that as this becomes more productionized, there's going to be the opportunity to have
00:33:01.600 | longer and longer context windows where maybe you can feed in, like, a bunch of YouTube
00:33:07.960 | videos, like, hundreds of pages of PDFs, and all this stuff, and still have it be -- give
00:33:16.520 | you, like, relevant output. So with that, that is the end of my presentation.
00:33:26.720 | So let me pop open the chat here. See what's going on. And if anyone wants to speak up,
00:33:41.960 | make comments, make any corrections to what I said, like, I'm not an expert in this, and
00:33:49.540 | I'm also not an AI researcher. So if you have insights, would love to hear those. And also,
00:33:57.880 | if anyone has answers to questions, feel free to chime in. Again, like, I'm not the expert
00:34:05.320 | on this paper. Like, I read it and understand it. But I know there's a bunch of very smart
00:34:11.060 | people on this call. So with that, I'll open up the floor.
00:34:18.860 | >> I can also help with questions, but I'm struggling with the chat window.
00:34:33.720 | >> Yeah, I mean, I want to validate Cosmin's frustration with this paper. Yeah, I mean,
00:34:40.680 | look, like, I think they did try to illustrate the memory mechanisms somewhat, but not super
00:34:47.640 | -- it's not super clear. And I always wish that these things came with code. I really
00:34:53.520 | like the name Papers with Code, because this one needed code. And, you know, maybe
00:34:59.360 | they released it, but I didn't see it inside the paper.
00:35:03.360 | >> They said at the very bottom of the paper that they're planning to release code soon.
00:35:07.840 | >> How Chinese of them? >> Whatever that means.
00:35:11.960 | >> Yeah, I didn't understand if the diagrams refer to one step or multiple steps. Like,
00:35:18.560 | I think -- and also, the diagrams are 3D, which makes it a bit more confusing. Like,
00:35:25.440 | they say when you update the memory, you do memory of query, and then you get, like, vector
00:35:31.200 | as output. What does that mean? Is that k-nearest neighbor lookup? Is it some attention? Maybe
00:35:38.520 | we can go to the first slide, where they update the memory equation. I don't know, Eugene,
00:35:44.320 | if you got time to -- got any time to look, but I would be interested in just one of those
00:35:50.920 | operations, how does it actually happen? Like, I understand at the high level, we have a
00:35:55.840 | memory module, we update, it's nice to forget, they somehow figured it out, but, like, what
00:36:01.600 | layers or what do they actually do? Like, even lower, you have some -- if you go a bit
00:36:08.940 | lower. Yeah, you see -- >> Retrieving memory.
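On the "retrieving memory" question: the paper describes retrieval as a forward pass through the memory network with the query, with no weight update. It is therefore not a nearest-neighbour lookup over stored entries. The sketch below assumes a two-layer MLP with a SiLU activation and hand-set weights. The class name and sizes are hypothetical.

```python
import math
from typing import List

Vec = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vec) -> Vec:
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


class MLPMemory:
    """Hypothetical 2-layer MLP memory: its weights ARE the long-term memory."""

    def __init__(self, w1: Matrix, w2: Matrix):
        self.w1, self.w2 = w1, w2

    def retrieve(self, q: Vec) -> Vec:
        # Plain forward pass: no nearest-neighbour search, no weight update.
        return matvec(self.w2, [silu(h) for h in matvec(self.w1, q)])


mem = MLPMemory(w1=[[1.0, 0.0], [0.0, 1.0]], w2=[[1.0, 1.0]])
print(mem.retrieve([1.0, 0.0]))
```

Any query, including one never written before, yields some output vector. The stored knowledge lives in the weights rather than in a table of entries.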
00:36:15.400 | >> Yeah, what does that -- maybe I didn't read or, like, even this figure one, I didn't
00:36:22.480 | understand at all what's going on. >> Yeah, I also skipped this, because I didn't
00:36:28.520 | understand what was going on. >> I can explain the in-context learning one
00:36:38.360 | with -- and I can -- with the parallels of what happens in RWKV-7 and Titans, based on
00:36:46.520 | what I understood from the paper. So, think of it as this way.
00:36:51.840 | >> Which figure? >> This whole segment, the long-term memory
00:36:57.040 | training and the surprise formula. It's actually a lot easier if you explain it using simplified
00:37:05.280 | -- >> Do you mind saying which? So, like, we read
00:37:07.440 | while you explain. No, this is great, like, having an expert explain it, but which figure
00:37:12.120 | is it or which formula? >> So, you scroll up. I'm trying to, like,
00:37:18.440 | figure out the page numbers as well. So, we are talking about, very specifically, the
00:37:24.520 | section 3.1 and the surprise metric there, downwards, that whole memory architecture.
00:37:32.280 | So, one way to view it, right, is that -- let's just say a standard problem. Let's say the
00:37:39.480 | quick brown fox, correct? And this is -- this is a piece of text that exists in so many
00:37:44.480 | training corpora that all LLMs will probably memorize this phrase itself, the quick brown
00:37:48.960 | fox. There is no surprise there. So, because the surprise score is zero, there is no -- there's
00:37:56.440 | no back propagation required, per se, to update the memories.
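Eugene's fox/deer intuition can be made concrete with next-token cross-entropy, which he is using as a simplified stand-in for the memory loss. The logits below are made-up illustrative numbers, not model outputs. The expected word gets a small but non-zero loss, since softmax never assigns a probability of exactly 1; that point comes up again later in this discussion. The unexpected word gets a large loss, and that larger error signal would drive a bigger memory update.

```python
import math
from typing import Dict


def cross_entropy(logits: Dict[str, float], target: str) -> float:
    """-log softmax(logits)[target], computed with the log-sum-exp trick."""
    top = max(logits.values())
    log_z = top + math.log(sum(math.exp(v - top) for v in logits.values()))
    return log_z - logits[target]


# Hypothetical logits for the word after "the quick brown".
logits = {"fox": 8.0, "deer": 1.0, "cat": 0.5}
fox_loss = cross_entropy(logits, "fox")    # small, but never exactly zero
deer_loss = cross_entropy(logits, "deer")  # large: the "surprising" continuation
print(f"fox: {fox_loss:.4f}  deer: {deer_loss:.4f}")
```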
00:38:01.560 | If you view the memories as, let's just say, a 4,000 -- I don't know what the dimensions
00:38:07.720 | are here, because they never disclose them, but let's just say a 4096 by 4096 matrix of floating
00:38:13.120 | point values; in our case, it's BF16. If, let's say, we said the quick brown deer, then the
00:38:24.480 | model is like, hey, I wasn't expecting deer, I'm expecting fox. There's a difference there.
00:38:30.000 | That difference there -- I'm oversimplifying the math, because this is not accurate -- can
00:38:34.480 | be converted to a loss score that you can backpropagate on. So it's when the
00:38:39.720 | model sees differences that you update this memory, this memory that's being shared
00:38:45.200 | between tokens. Now, where the Google paper,
00:38:50.520 | Titans, differs from RWKV, which we are testing as well, is that the --
00:38:54.200 | >> Eugene, so you have a separate key value store that you attend at the same time as
00:39:01.400 | you attend the current token. So, that's your budget, right? And you attend the whole thing,
00:39:08.980 | okay? And does that help or doesn't it help? And when it's surprising, you need to update.
00:39:15.920 | So I wonder how they send the updates to the key value store. Like, what actually happens?
00:39:21.640 | Like, what's the loss of the key value store and the current token? Go ahead.
00:39:26.040 | >> So if you want to view it as simplified code, which is not how they implement it,
00:39:31.320 | because -- is that you can view your model weights during the forward pass, right, as
00:39:37.200 | frozen, just view as frozen. And the -- so, the quick brown, let's say the quick brown,
00:39:44.200 | those tokens, right, generate a state. You take that state, and then let's say instead
00:39:50.040 | of -- let's say we see fox, and then we say deer instead, right? There's a difference
00:39:55.080 | in expectation. The model expected fox, it got deer instead. So because the model weights
00:40:02.760 | are frozen, if you do the forward pass and then you want to correct the model's thinking
00:40:06.880 | and you do the backwards pass, the only way to update it is this state value. So you do
00:40:12.200 | the backwards pass, you update the state values, then you take that state and you go -- you
00:40:17.440 | process the next token. So you continue your sentence completion. And that essentially
00:40:22.640 | is what the surprise mechanic is about. It's about, hey, it didn't give the output we required,
00:40:31.760 | and then we take that difference and then we convert it into a score. Where this differs
00:40:39.360 | from RWKV is that we don't use a surprise mechanism. We are currently using something closer
00:40:47.400 | to standard gradient descent. So the difference here is that in a surprise mechanism, so if
00:40:54.840 | let's say you expected fox and you get fox, for example, there's no -- the loss is essentially
00:41:01.920 | zero. There's no backprop. But in RWKB's case, right, if let's say that you -- it expected
00:41:10.640 | fox and it actually got the fox, and since the way logics work, there's always a zero
00:41:15.800 | point -- let's just say a zero point something percent difference, there's still a loss score
00:41:19.960 | being calculated there. So even though it was not a surprise, we still do the backpropagation
00:41:24.200 | process. In practice, is this better or worse? I have absolutely no idea. This is something
00:41:28.760 | we need to test and evaluate on. But that's the key difference on how we handle the memories
00:41:34.040 | segment. It's all about, like, every token you forward, you backprop and then you update
00:41:41.200 | the weights. This is -- >> You just -- instead of you propagate the
00:41:48.240 | signal through the frozen weights and just update the keys and values. One question that
00:41:53.200 | I got was how do they manage a fixed size key value store? Basically, how do they decide
00:42:02.080 | what to drop and stuff? And they have this -- I didn't understand their gating in here.
00:42:09.440 | But basically, yeah, over time, you'll see lots and lots and lots of things. So you kind
00:42:14.480 | of need to figure out, like, if you're efficiently using your memory, then you solve the problem,
00:42:21.160 | basically. >> So this goes to the AI black box. We set
00:42:26.080 | the key-value store to a specific size. That's part of the architecture design. It's the
00:42:30.600 | same as choosing the model dimensions. As for how the model decides what to keep and drop, right?
00:42:37.000 | That is decided by the model itself. So I think one -- there's another segment
00:42:44.840 | where it highlights the decay, right? Let me find that segment.
00:42:48.080 | >> Could you scroll a bit down? >> Sorry if I'm not looking at the screen.
00:42:52.800 | I'm looking at the paper to just find the decay. I think it was mentioned --
00:42:58.080 | >> Yeah, it's the forgetting mechanism. Sorry, it's equation 13.
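For readers following along, here is the memory update as we read it in the Titans paper, written out so the forgetting term is easy to spot. Notation: $\mathcal{M}$ is the memory, $S_t$ the surprise carried with momentum, and $\eta_t$, $\theta_t$, $\alpha_t$ the data-dependent momentum, step-size, and forget-gate values. Treat this as a summary and check the exact form and equation numbering against the paper itself.

```latex
\begin{aligned}
\mathbf{k}_t &= x_t W_K, \qquad \mathbf{v}_t = x_t W_V, \\
\ell(\mathcal{M}_{t-1}; x_t) &= \big\lVert \mathcal{M}_{t-1}(\mathbf{k}_t) - \mathbf{v}_t \big\rVert_2^2, \\
S_t &= \eta_t\, S_{t-1} - \theta_t\, \nabla \ell(\mathcal{M}_{t-1}; x_t), \\
\mathcal{M}_t &= (1 - \alpha_t)\, \mathcal{M}_{t-1} + S_t .
\end{aligned}
```

With $\alpha_t = 0$ nothing is forgotten and with $\alpha_t = 1$ the memory is wiped. When $\mathcal{M}_{t-1}$ already maps $\mathbf{k}_t$ to $\mathbf{v}_t$, the gradient term is zero and only decay and leftover momentum act.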
00:43:07.960 | >> Yeah, so the idea behind the decay mechanism, and this part is consistent with other
00:43:14.600 | theories, is that by default, with every token you move forward, you are slowly forgetting.
00:43:23.400 | And the forget rate is something that's trained in the model. So let's just say the value
00:43:29.520 | is such that, by default, you forget everything within 32K tokens. So if you move forward 32K tokens, you
00:43:36.880 | should have forgotten it. In that sense, it sounds similar to sliding window attention,
00:43:42.280 | but the decay mechanism is supposed to work together with, let's say, the
00:43:48.240 | surprise mechanism. This part is a bit more of a gray area, but basically,
00:43:56.080 | by default, you decay. You let the model compute against the state itself, and
00:44:01.720 | the model may just decide, hey, this is important to memorize, so I'm going to reinforce that
00:44:07.040 | number. So at every step it takes, in a way, the model internally
00:44:13.000 | keeps backpropagating against the state and thinks, hey, do I need to reinforce this
00:44:18.120 | floating-point value? If I stop reinforcing, it will eventually decay, but if I want to reinforce
00:44:23.640 | it, I can keep increasing the value while keeping it within bounds. And
00:44:29.080 | that's how it works: by default, it will slowly forget, but if it thinks something is
00:44:34.280 | important, it will keep trying to remember it over a longer context. To be clear,
00:44:38.400 | this part is really theorycrafting, because even for RWKV, the furthest we have ever
00:44:44.600 | pushed a model is 32K, and we are now experimenting at 64K. This theorycrafting is supposed to
00:44:50.400 | extend to something like 1 million tokens. So it's something that we definitely need
00:44:58.440 | to test. But the idea behind decay is that, by default, things expire, so the
00:45:04.320 | model can clear up space to memorize new things, while also being able to decide for
00:45:09.320 | itself to keep things in memory. Not much different from how they describe it in Titans.
00:45:14.800 | Yeah.
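The "decay by default, reinforce if important" behavior can be seen with a single scalar memory cell. This is a toy sketch, not the Titans or RWKV update: the rule `m <- (1 - alpha) * m + reinforce` drops gradients and momentum entirely, and the forget rate `alpha` (a decay time constant of roughly 32K steps) and the `reinforce` amount are hypothetical values chosen to echo the example above.

```python
def run_memory(initial: float, steps: int, alpha: float, reinforce: float = 0.0) -> float:
    """Toy scalar memory cell: decay every step, optionally reinforce.

    Each step applies m <- (1 - alpha) * m + reinforce. This is a
    simplified, hypothetical stand-in for a decay-plus-write rule
    (no gradients, no momentum).
    """
    m = initial
    for _ in range(steps):
        m = (1.0 - alpha) * m + reinforce
    return m


alpha = 1.0 / 32_000  # hypothetical forget rate: decay time constant of ~32K steps

# Never reinforced: the value fades geometrically, as (1 - alpha) ** steps.
faded = run_memory(1.0, steps=100_000, alpha=alpha)

# Reinforced a little every step: the value settles at reinforce / alpha
# (here 1.0) instead of growing without bound.
kept = run_memory(1.0, steps=100_000, alpha=alpha, reinforce=alpha)

print(f"faded={faded:.4f} kept={kept:.4f}")
```

After 100K steps, `faded` comes out near 0.044 (about $e^{-3.125}$), while `kept` stays at 1.0: reinforcement holds a value in place, and the decay term is what keeps it bounded.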
00:45:15.800 | >> Thanks a lot. This helped me quite a bit.
00:45:19.640 | >> Eugene, I have a question. When you say that by default it decays, then in terms
00:45:26.760 | of the surprise here, can we assume that the surprise by default is low for most of the
00:45:32.040 | tokens?
00:45:36.200 | >> Let's just view it as a RAG scenario. A RAG
00:45:43.840 | scenario. Let's just say your company is the most generic company on Earth, and then I
00:45:49.080 | put your company's documents in. There is no surprise, so it just moves on. But
00:45:54.960 | let's just say your company has some very proprietary, never-heard-before stuff. That
00:45:59.720 | surprise is then what gets stored. So it's about the difference in information between the
00:46:08.200 | fixed model weights and this floating-point state. Does that make sense?
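One way to make "surprise" concrete: in the Titans formulation as we read it, surprise is the gradient of a recall loss, so content the memory already predicts produces no update. The sketch below assumes a linear memory matrix `M` (the paper uses a deeper MLP memory) and random vectors as stand-ins for a token's key and value; `surprise` is a hypothetical helper name.

```python
import numpy as np


def surprise(M: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Gradient of the associative-memory loss ||M @ k - v||^2 with respect to M."""
    error = M @ k - v                 # how wrong the memory's recall is
    return 2.0 * np.outer(error, k)   # d/dM of the squared error


rng = np.random.default_rng(0)
d = 8
M = rng.normal(size=(d, d)) / np.sqrt(d)  # hypothetical current memory state
k = rng.normal(size=d)                    # key for the incoming token

generic_v = M @ k             # "generic" content: memory already recalls it
novel_v = rng.normal(size=d)  # "proprietary" content: memory gets it wrong

generic_surprise = np.linalg.norm(surprise(M, k, generic_v))
novel_surprise = np.linalg.norm(surprise(M, k, novel_v))
print(generic_surprise, novel_surprise)
```

The generic value yields a zero gradient (nothing new to store), while the novel value yields a large one, which is the part that would be written into memory.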
00:46:17.600 | >> Is there another -- I'm trying to read the questions from the chat.
00:46:34.480 | >> Yeah. There's another one from Cosmin earlier about whether needle-in-a-haystack tasks are
00:46:42.720 | interesting. I mean, I think they could be in real use cases, though there are probably a lot where they're
00:46:50.800 | not. But depending on your use case, I could see how, if you just want to throw a bunch
00:46:58.560 | of tokens into the model and not worry about the order or anything, it would be useful
00:47:06.560 | to have.
00:47:07.560 | >> I think it's also academically interesting to understand the maximum the model
00:47:15.840 | can memorize in the worst-case scenario. That's how I view needle in the haystack. In
00:47:21.320 | practical scenarios, that was one of the challenges in benchmarking RWKV as well, and
00:47:29.800 | the same goes for Titans: if, let's say, we train on the Harry Potter books and you
00:47:35.440 | put a whole Harry Potter book in as the RAG context, there's no surprise there.
00:47:43.640 | And essentially, it can pass the RAG test at, what, 300K context length. But that's
00:47:50.160 | not a valid test, per se. So needle in the haystack is meant to represent the worst-case
00:47:54.000 | scenario. That's how I feel. And a lot of companies are a lot more generic than they
00:47:59.960 | think they are. Also, at NeurIPS, in the joint presentation that we did with Dan
00:48:08.120 | Fu, we both agreed that, given techniques such as repeating the context twice, which works well for
00:48:18.020 | both Mamba and RWKV, we may want to re-evaluate how we benchmark these things
00:48:23.520 | for practical RAG situations. Because if, let's say, the inference
00:48:29.040 | cost for Mamba and RWKV is 100,000x lower, and their RAG performance triples or quadruples
00:48:39.280 | just by repeating the context twice, then we just repeat the context twice. That was
00:48:46.200 | one of the arguments. And I kind of agree, but it's also a very different test.
00:48:52.360 | I see there's another question that you'd be good at answering, Eugene: what is the
00:49:04.360 | gating mechanism in modern RNNs?
00:49:07.120 | What do you mean by modern RNNs?
00:49:11.800 | I don't know. Aditya, do you want to clarify?
00:49:15.840 | Yeah, this sentence is in the paper right in front of us. It cites Dao and Gu, 2024,
00:49:22.120 | and Orvieto et al., 2023, in a comparison with the weight decay mechanism. I just have
00:49:30.640 | no idea what modern RNNs normally do. That's relatable.
00:49:35.440 | Oh, okay. Later in the section, we show that the weight decay mechanism is closely related
00:49:52.800 | to the gating mechanism in modern RNNs, Dao and Gu, citing, I presume, Mamba-2? Yeah, okay,
00:50:02.240 | the state space model, yeah. It's, I think, the state space model. This would be similar
00:50:08.080 | to the weight decay that I explained earlier, which is basically forgetting by default.
00:50:19.080 | My understanding of Mamba is that they run these RNNs without gates, so that they can
00:50:25.400 | use something like a fast Fourier transform or a parallel scan on GPUs and run fast.
00:50:31.200 | And then they have some multiplicative process on each token. And when you do multiplications,
00:50:37.840 | basically, you can use them as gates on how much of the signal you propagate. So they
00:50:42.680 | add the nonlinearity and gating at the token level, while old-school LSTMs had a memory gate,
00:50:51.160 | forget gate, and input gate at each step, and ran them one step at a time. So it might
00:50:58.240 | be this token-level multiplicative gating that people have
00:51:05.960 | in state space models. Yeah, if it's about that, then it's really more about how we restructure.
00:51:13.020 | In Mamba, RWKV, and Titans alike, the way we restructure the formulation means
00:51:18.880 | we do not need to wait for one token before processing the next, unlike the old LSTM. So yeah,
00:51:27.160 | I think that makes sense. Yeah, actually I think your explanation makes more sense. Basically,
00:51:31.440 | if you contrast it with the old LSTM RNNs, all the newer gates, even though they are designed
00:51:36.200 | differently, are designed in a way that avoids this bottleneck, where you need to wait
00:51:41.880 | for one token after another. Whether it's through math tricks, which is what state space
00:51:50.040 | models did (really impressive math, honestly, but that's what they are
00:51:56.760 | good at), or, in our case, through how we optimize things in the CUDA
00:52:01.240 | forward pass, it achieves the same result: it can train in parallel, unlike old RNNs.
00:52:10.720 | And I'm quite sure Google has their own optimization when it comes to training as well.
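The point about token-level multiplicative gates and parallel training can be illustrated with the simplest gated linear recurrence, h_t = a_t * h_{t-1} + b_t. Composing two such steps gives another step of the same form, so the whole sequence can be computed with an associative (parallel) scan in about log2(T) rounds instead of one token at a time. This is a generic textbook sketch with scalar states, not Mamba's, RWKV's, or Titans' actual kernels; the function names are hypothetical.

```python
from typing import List, Tuple

Pair = Tuple[float, float]


def combine(first: Pair, second: Pair) -> Pair:
    """Compose two steps of h -> a * h + b (apply `first`, then `second`).

    Composition is associative, which is what lets a GPU evaluate the
    recurrence with a parallel scan instead of one token at a time.
    """
    a1, b1 = first
    a2, b2 = second
    return a1 * a2, a2 * b1 + b2


def sequential(a: List[float], b: List[float]) -> List[float]:
    """Reference LSTM-style loop: h_t = a_t * h_{t-1} + b_t, with h_0 = 0."""
    h, out = 0.0, []
    for at, bt in zip(a, b):
        h = at * h + bt
        out.append(h)
    return out


def parallel_scan(a: List[float], b: List[float]) -> List[float]:
    """Hillis-Steele inclusive scan: log2(T) rounds, each fully parallelizable."""
    elems = list(zip(a, b))
    step = 1
    while step < len(elems):
        # Every position in a round reads only the previous round's values,
        # so all positions could be updated simultaneously on a GPU.
        elems = [
            combine(elems[i - step], elems[i]) if i >= step else elems[i]
            for i in range(len(elems))
        ]
        step *= 2
    return [b_acc for _, b_acc in elems]  # with h_0 = 0, h_t is the b component


a = [0.9, 0.5, 0.8, 1.0, 0.3]   # per-token gates: how much of h to keep
b = [1.0, 2.0, -1.0, 0.5, 0.0]  # per-token inputs
print(sequential(a, b))
print(parallel_scan(a, b))
```

Both functions return the same hidden states (up to floating-point rounding); the gate values only ever enter through multiplication, which is why they fit the scan.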
00:52:17.680 | I have another question for you, Eugene. I don't know how to compare baselines with
00:52:24.840 | this model. Are you impressed with the one-point perplexity win on Wiki, or the other wins?
00:52:31.640 | Or does any new model usually show that kind of win? To me, the wins seem large.
00:52:40.440 | So is this something super strong, or is it OK but not great?
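To put a "one-point" perplexity gain in perspective: perplexity is the exponential of the average per-token negative log-likelihood, so a fixed absolute perplexity gap corresponds to a smaller loss gap as the baseline grows. The baseline values in the example below are illustrative, not numbers from the paper.

```latex
\mathrm{PPL} = \exp\!\Big(-\frac{1}{N}\sum_{i=1}^{N}\log p(x_i \mid x_{<i})\Big),
\qquad
\Delta\mathcal{L} = \ln\frac{\mathrm{PPL}_{\text{base}}}{\mathrm{PPL}_{\text{new}}}
```

For example, going from 20 to 19 is $\ln(20/19) \approx 0.051$ nats per token, while going from 30 to 29 is only $\ln(30/29) \approx 0.034$.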
00:52:48.480 | To be honest, I classify this as promising, but we need to test further, because even
00:52:54.120 | in our experience with RWKV, results below 1.5B do not necessarily hold at 7B.
00:53:09.420 | We have reverted experimental changes that we made on 0.5B models, which is roughly the scale
00:53:15.480 | being tested here (340 to 760 million parameters), where the perplexity loss was great.
00:53:21.680 | It dropped much lower. And then when we scaled it to 1.5B, it sucked.
00:53:29.940 | So it's promising, but I want to test to find out more, because I think the true test is
00:53:37.880 | testing on 1.5B and then 7B, which, to be honest, I'm quite sure the Google folks have
00:53:45.240 | done it. They are not compute bound. It takes more effort for them to write this paper than
00:53:49.440 | to train that 1.5B and 7B model.
00:53:52.000 | Yeah, I'm a bit skeptical, because there's some gossip that Googlers aren't allowed to
00:53:57.580 | publish super impactful stuff. So it's interesting.
00:54:03.520 | Yeah. I would also like to know what their results were for the larger models.
00:54:10.200 | Yeah, I mean, I guess that's one interesting thing: back in 2018, in the "Attention
00:54:17.480 | Is All You Need" days, Google researchers could publish anything they wanted, because there
00:54:22.640 | was really no competitive advantage they were giving away. But these days, you wonder, if
00:54:30.360 | they have a really big breakthrough, are they going to publish that, or are they just going
00:54:34.900 | to keep it to themselves?
00:54:38.600 | I have another question, if someone could explain. The memory module is described
00:54:44.140 | as a meta in-context model. In quasi-layman's terms, what would a meta in-context model
00:54:49.920 | mean? Is it like a small parallel model running alongside, as if it were trained
00:54:55.920 | very quickly?
00:54:58.100 | So that's the part I explained: every time the tokens look different from what you
00:55:03.840 | expected, it backpropagates into the memory modules. So you can think of it
00:55:08.360 | as-- I'm oversimplifying the math by calling it backpropagation. You can think of it as training
00:55:14.880 | the memory modules. Doing it literally that way would be inefficient, so we use matrix multiplication
00:55:21.680 | math dedicated to this. But in theory, you could implement it as standard backprop
00:55:28.520 | gradient descent, at least in RWKV's case. I'd need to double-check for Google's case.
00:55:35.780 | But the important thing is it runs at inference time. So it's like it always backpropagates
00:55:42.760 | something, which is very different from other models that are trained in a big pre-training
00:55:48.360 | run, and then nothing changes. You just run inference. So it's kind of meta-learning that
00:55:54.280 | at inference time, it still does some additional update.
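To make the "meta in-context" idea concrete, here is a hypothetical test-time memorization loop: the memory's weights take a gradient step on every incoming (key, value) pair during inference, and recall is then a plain forward pass. It is a sketch under stated assumptions, not the paper's implementation: a linear memory stands in for Titans' MLP memory, constant `lr`, `momentum`, and `forget` stand in for the data-dependent step size, momentum, and forget gate, the keys are orthonormal so facts do not interfere, and the toy stream replays four facts for several `passes`.

```python
import numpy as np


def memorize(pairs, d, lr=0.1, momentum=0.5, forget=0.001, passes=50):
    """Hypothetical test-time memorization with a linear memory M.

    For every (k, v) in the stream:
      grad = d/dM ||M @ k - v||^2         (surprise of this token)
      S    = momentum * S - lr * grad     (surprise carried with momentum)
      M    = (1 - forget) * M + S         (decay, then write)
    """
    M = np.zeros((d, d))
    S = np.zeros((d, d))
    for _ in range(passes):
        for k, v in pairs:
            grad = 2.0 * np.outer(M @ k - v, k)
            S = momentum * S - lr * grad
            M = (1.0 - forget) * M + S
    return M


rng = np.random.default_rng(0)
d = 8
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))     # orthonormal columns
keys = [Q[:, i] for i in range(4)]                # four non-interfering keys
values = [rng.normal(size=d) for _ in range(4)]   # hypothetical "facts"

M = memorize(list(zip(keys, values)), d)

# Recall is only a forward pass through the updated memory.
recalled = [M @ k for k in keys]
rel_errors = [np.linalg.norm(r - v) / np.linalg.norm(v) for r, v in zip(recalled, values)]
print([round(float(e), 3) for e in rel_errors])
```

The small residual error (roughly 1% with these settings) comes from the forget term pulling the memory back toward zero; setting `forget=0.0` removes that bias, while a larger value trades recall accuracy for room to store new content.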
00:55:59.280 | Yeah, the long-term hope and goal, if we can get this process to be stable, and the memory
00:56:08.320 | module is, let's just say, a gigabyte in size in memory, this is what will represent short-
00:56:15.840 | to mid-term memories for an AGI kind of model. The issue for any super long-context training
00:56:26.320 | we're talking about at AGI scale is that we don't really have the means to figure
00:56:35.380 | out how to train this memory module in a structured, guided way. And right now, the hope is that
00:56:43.600 | if we train it at, let's just say, 4K, 8K, or 32K, it generalizes to 128K. And if, let's
00:56:51.520 | say, we train at 1M, it generalizes to 10M. So if we train it at 10M, it generalizes
00:56:57.800 | to an even longer context length.
00:57:01.000 | The problem with this approach is, even for us right now, and this is something that maybe
00:57:06.480 | Titans may have more tests on, is that when we train on 512, it generalizes up to 4K,
00:57:12.760 | and then it dies out there. Then, if we train up to 4K, it generalizes up to 16K. So the
00:57:21.560 | generalization doesn't go on ad infinitum, unlike humans, arguably. But then again, maybe
00:57:28.700 | that's why we go senile at the age of 100. Maybe that's the reason.
00:57:36.320 | That's our context length.
00:57:38.880 | So I see we're at time. I don't know, swyx, I'll turn the floor back over to you. Any
00:57:51.040 | comments or thoughts about next week or announcements?
00:57:56.120 | Ishan is doing DQ2 on his spreadsheet. Is that Ishan? I don't know. He likes spreadsheets.
00:58:05.760 | That's all I know. So that will be next week. Okay, I think that
00:58:15.120 | was a yes. Cool. Yeah, we have run out of our context length for this session. Thank
00:58:21.700 | you, Eric, for a great presentation. Yeah, very topical paper. We are, I've just been
00:58:27.480 | chatting with Vibhu, and we're thinking about doing a second
00:58:32.880 | paper club, somewhat splitting between timeless papers or timeless survey
00:58:41.440 | papers on one side and hot individual papers on the other.
00:58:47.040 | And maybe also doing it at a different time, so it's not during
00:58:51.040 | the workday for most people in the US. So yeah, some people are interested in timeless papers
00:58:55.760 | and then also hot papers. So I think those are the two spiky things that maybe we can
00:58:59.840 | have a different vibe for that as well. Yeah, let's discuss in Discord, but otherwise have
00:59:05.000 | a wonderful day. Bye.