Going beyond RAG: Extended Mind Transformers - Phoebe Klett

Chapters
0:00 Introduction
1:03 Long Context vs RAG
3:32 Extended Mind Attention
6:16 Evaluations
7:27 Results
8:58 Citations
10:27 Reduce hallucinations
11:52 Parameters
I'm Phoebe, a machine learning engineer at Normal Computing, and I'm really excited to tell you about some of our recent research, in particular extended mind transformers.
To briefly cover what we'll go over in today's talk: we'll introduce the problem, which I think will be quite familiar given the great talk that came before mine, and then dive right into the methods, that is, the retrieval mechanism that extended mind transformers implement. Then we'll look at some experiments that give us confidence these methods are actually performant. After that, we'll get to two of my favorite, and I think most compelling, features that extended mind transformers enable: a new kind of citation, and a new, active-learning-inspired generation paradigm. Finally, we'll go over the most important parameters to tune when implementing EMTs in your applications, and generally how to use them.
We pre-train language models so that they have general knowledge, but as we've been discussing all conference, that's not enough. We need a lot of application-specific information and a topical description of the world to make these models useful. I won't belabor the two most popular methods for loading that description into the language model, long context and RAG, since we've already heard a lot about both, but I'd like to point out that they solve the problem in different ways and thus suffer from different downsides. Long context seeks to extend the context window of the transformer: we train language models on sequences of a fixed length, and then we ask whether we can extend that length so we can include more in the prompt at inference time. Fine-tuning is usually how this is done, and that's awfully expensive.
More than that, including all of that context in your prompt can confuse the model with a lot of irrelevant information. And beyond that, conceptually speaking, it seems a little wasteful: if we're doing question answering over a big code base, the query usually doesn't need to reference all of those function definitions, just some subset of them to answer correctly. That subsetting is what RAG tries to do: narrow the information down and include only the most relevant context in the prompt.
So what are the issues here? These retrieval mechanisms are external to the transformer, and they're necessarily limited by being external to the model. We make the choice of what's relevant once, up front, before generation starts, and we make that choice using the least granular representation of the data, often one that's disjoint from the way the model will reason about that data. Conceptually, neither of these methods makes a distinction between things that should go into memory and things that should be included alongside the inference query. And that's more than aesthetics: making that distinction is what enables more granular, causal citations, and lets the model retrieve more information when we can tell it's uncertain, actively, within the generation.
So how do we do this? Extended mind attention is a very simple edit to the attention mechanism of the transformer. I won't get too deep into the math since we don't have much time today, but I'd love for anyone to check out the paper and let me know what you think. Qualitatively, here's how it works. Most of the transformers we use today are decoder-only, and within each decoder layer the model represents data as key-value pairs. So it already has a retrieval mechanism built into the transformer; all we have to do is take advantage of it. We pass all of the memory tokens through the model and save off those key-value representations. Then, during generation, we allow each query token, much like RAG, using cosine similarity, to retrieve a particular number of those memory tokens and attend to them. In this picture, the red highlighted tokens represent the retrieved tokens. Again, this ends up being a very simple change to the transformer model.
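To make that concrete, here's a minimal sketch of the retrieval step for a single head, assuming pre-computed memory keys and values for the layer and ignoring position handling, batching, and causal masking. The tensor names and the top_k parameter are illustrative, not the exact implementation from the released code:

```python
import torch
import torch.nn.functional as F

def extended_mind_attention(q, k, v, mem_k, mem_v, top_k=16):
    """Sketch of one head's attention with retrieved memory tokens.

    q:            (seq, d)   query vectors for the current context
    k, v:         (seq, d)   local key/value vectors
    mem_k, mem_v: (mem, d)   cached key/value vectors for the memory tokens
    """
    # Each query token retrieves its own top-k memory tokens by cosine similarity.
    sims = F.normalize(q, dim=-1) @ F.normalize(mem_k, dim=-1).T   # (seq, mem)
    top = sims.topk(top_k, dim=-1).indices                         # (seq, top_k)

    retrieved_k = mem_k[top]                                       # (seq, top_k, d)
    retrieved_v = mem_v[top]                                       # (seq, top_k, d)

    # Concatenate retrieved keys/values with the local ones and attend as usual.
    keys = torch.cat([retrieved_k, k.unsqueeze(0).expand(q.size(0), -1, -1)], dim=1)
    vals = torch.cat([retrieved_v, v.unsqueeze(0).expand(q.size(0), -1, -1)], dim=1)

    scores = torch.einsum("sd,std->st", q, keys) / q.size(-1) ** 0.5
    # A causal mask over the local positions would be applied here in practice.
    return torch.einsum("st,std->sd", scores.softmax(dim=-1), vals)
```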
What's difficult is figuring out how to assign position information to those retrieved tokens. This work builds on research from a couple of years ago, which needed to fine-tune the model to teach it how to leverage the retrieved tokens, in large part because of the absolute position embeddings that were popular at the time. Transformer models are position agnostic, so we have to tell them: this token is position zero, this one is position one, and so on. Today's softer, relative position embeddings are what let us use this method without any further fine-tuning; they really enable the model to generalize to the retrieved tokens. I'll mention the two schemes we've tested and implemented this on. The first, present in all of the Llama models, is rotary position embeddings (RoPE), which generalize the idea of using the angle between two vectors as a distance metric: we take the embedding and rotate it two dimensions at a time. The other is ALiBi (attention with linear biases), which all of the MosaicML MPT models are trained with. These aren't position embeddings at all; they just linearly damp down information that is further away.
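As a rough, simplified illustration of the two schemes (single head, not the exact formulations used in the Llama or MPT code):

```python
import torch

def rope_rotate(x, pos, base=10000.0):
    """Rotary embeddings: rotate pairs of dimensions by an angle that grows
    with position, so query-key dot products depend on relative distance."""
    d = x.size(-1)
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # (d/2,)
    angles = pos * theta
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

def alibi_bias(seq_len, slope=0.5):
    """ALiBi: no position embeddings at all, just a bias added to the attention
    logits that linearly damps down keys that are further from the query."""
    pos = torch.arange(seq_len)
    distance = (pos[None, :] - pos[:, None]).clamp(max=0)  # negative for past keys
    return slope * distance.float()                        # added to attention scores
```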
Now let's talk about evaluations. We also just open-sourced a new counterfactual retrieval benchmark, and I'll briefly describe what it looks like. It's a long-context benchmark: the inputs are query-answer pairs, and the contexts needed to answer those questions range from about 2,000 tokens all the way up to 16,000 tokens. For example, the question might be "Who wrote the song 'These Boots Are Made for Walkin''?", paired with the corresponding Wikipedia snippet. We wanted to control for facts memorized during pre-training, though, and for any fine-tuning as well. So, for instance, in this case the true answer is Lee Hazlewood. We did a little research and figured out that Terry Allen is a similar songwriter, a plausible but wrong answer. We went in, replaced all the instances of Lee Hazlewood with Terry Allen, and now we ask the model to retrieve the Terry Allen answer.
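As a minimal sketch, this is roughly how such a counterfactual example can be constructed (the field names here are illustrative, not the released dataset's schema):

```python
def make_counterfactual(example, original_answer, counterfactual_answer):
    """Swap the true answer for a plausible-but-wrong one everywhere it appears,
    so the model can only answer correctly by retrieving from the modified
    context rather than from facts memorized during pre-training."""
    return {
        "question": example["question"],   # e.g. who wrote "These Boots Are Made for Walkin'"
        "context": example["context"].replace(original_answer, counterfactual_answer),
        "answer": counterfactual_answer,   # e.g. "Terry Allen" instead of "Lee Hazlewood"
    }
```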
So how do extended mind transformers stack up? Here we're comparing against fine-tuned models, as well as the base Llama model with interpolated position embeddings. You can see in green that the base model does a pretty good job extrapolating, even to many times its training length: this model was trained on sequences of up to 2,048 tokens during pre-training, and it's still doing okay up to 8K. At 16K it really falls off; the position embeddings can't extrapolate that far.
The fine-tuned models actually perform worse than the extended mind model on these shorter inputs, which is another data point suggesting that fine-tuning on super-long context degrades the quality of attention on shorter inputs. And extended mind transformers remain competitive with those fine-tuned models all the way up to 16K, even though our models are not fine-tuned at all. In this particular experiment, all the extended mind model sees in context is the query, just "Who wrote the song 'These Boots Are Made for Walkin''?", and it relies heavily on that internal retrieval mechanism to look up the new information. In a second experiment, we seed it with a little more information in context using RAG, while still relying mostly on the internal mechanism. You can see that when we combine it with that extra in-context information, we outperform GPT-4.
Now let's talk about citations. I think this is a topic a lot of you can empathize with: as AI engineers, citations are one of the most important things to provide in an application so that people can learn to trust the model's outputs. In fact, you might use RAG just to get citations. With RAG, though, the citations you get are a bit of a post-hoc rationalization: if a date appears in the output and we know it was also in the input to the language model, we feel fairly confident that date isn't hallucinated, but that's not causally related to what information the model actually used during generation. With extended mind transformers, we can look up exactly which tokens were retrieved from the memories and used during generation. In this example, the memories at the top left are a snippet from Wikipedia about one of my favorite mathematicians, Alexander Grothendieck, and the query is: when did he get his French citizenship? At the bottom you can see the completion with the correct date, 1971. The blue highlighted tokens, importantly the 1971 token as well as some of the Alexander Grothendieck tokens, are the ones the model retrieved and attended to when generating that correct 1971 token. Being able to report that gives a lot of confidence, and real insight into how the model is using the retrieved tokens.
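As a hedged sketch of what reporting those citations might look like, assuming the model exposes, for each generated token, the indices of the memory tokens it retrieved; the argument names here are hypothetical, not the exact API of the released code:

```python
def cite(generated_ids, retrieved_indices, memory_ids, tokenizer):
    """For each generated token, report which memory tokens it retrieved and
    attended to, giving a causal rather than post-hoc citation."""
    citations = []
    for step, token_id in enumerate(generated_ids):
        evidence_ids = [memory_ids[i] for i in retrieved_indices[step]]
        citations.append({
            "generated": tokenizer.decode([token_id]),
            "evidence": tokenizer.decode(evidence_ids),
        })
    return citations
```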
We can also use extended mind transformers to reduce hallucinations. How do we do this? In the simplest case we have access to token-level entropy over the output distribution, and if you want to get fancier, we're also doing some Bayesian fine-tuning of language models at Normal, but you can use any uncertainty metric to determine how certain the model is about a generated token. If we detect that the model is uncertain about a token, we can regenerate that step using more information from the memories. In the top right here, we set a baseline default number of memories that each query token is allowed to retrieve and attend to, and you can see it wasn't quite enough information to get this query right: if you remember from the previous slide, the correct answer is 1971, and we got 1993. We didn't attend to the memories enough to answer correctly. In the bottom example, we allow the model to regenerate a subset of those tokens using more information from the cache whenever we can tell it was uncertain, and this time we got it right. It's a nice intuition: when the model is really uncertain, go use more information. It can also be more efficient, depending on how the math works out.
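Here's a minimal sketch of that loop, using token-level entropy as the uncertainty signal; the top_k keyword on the forward pass is hypothetical and stands in for however your model exposes the number of retrieved memories (greedy decoding, batch size 1):

```python
import torch

def generate_with_uncertainty(model, input_ids, max_new_tokens=32,
                              base_top_k=8, boosted_top_k=32, entropy_threshold=2.5):
    """Generate step by step; when the output distribution for a token is
    high-entropy (the model is uncertain), redo that step while letting each
    query token attend to more retrieved memories."""
    ids = input_ids
    for _ in range(max_new_tokens):
        logits = model(ids, top_k=base_top_k).logits[:, -1]   # hypothetical top_k kwarg
        probs = logits.softmax(-1)
        entropy = -(probs * probs.clamp_min(1e-9).log()).sum(-1)
        if entropy.item() > entropy_threshold:
            # Uncertain: regenerate this step with more memories retrieved.
            logits = model(ids, top_k=boosted_top_k).logits[:, -1]
        next_id = logits.argmax(-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
    return ids
```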
Now I'll tell you about the most important parameters to set when using extended mind transformers. You may have heard of stride length before; it's a parameter that comes up even in regular perplexity computations. When we compute the memories we're going to attend to, we pass them through the model and save off the key-value representations it produces internally. But the models we're using were trained on a fixed context length, so we need to pass over the memory with some stride, such that each token has an appropriate amount of context for generating its representation. A smaller stride gives you higher-quality representations but requires more computation. You can tune this, and there are graphs in the paper that illustrate the trade-off, but it's an important parameter to set when generating the memories themselves.
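A rough sketch of that strided pass, assuming a Hugging Face-style causal LM that returns its per-layer key/value cache in the legacy tuple format; the exact implementation lives in the released code, this is just to show the shape of the computation:

```python
import torch

def compute_memory_kv(model, memory_ids, max_len=2048, stride=512):
    """Pass a long memory through the model in strided windows so every token
    gets enough left-context for a good key/value representation.
    memory_ids: 1-D tensor of token ids.
    Smaller stride = better representations but more forward passes."""
    keys, values = None, None
    n = memory_ids.size(0)
    for begin in range(0, n, stride):
        end = min(begin + stride, n)
        window = memory_ids[max(0, end - max_len):end].unsqueeze(0)
        with torch.no_grad():
            out = model(window, use_cache=True)
        new = end - begin  # only these tokens are newly covered by this window
        k_new = [k[..., -new:, :] for k, _ in out.past_key_values]
        v_new = [v[..., -new:, :] for _, v in out.past_key_values]
        if keys is None:
            keys, values = k_new, v_new
        else:
            keys = [torch.cat([a, b], dim=-2) for a, b in zip(keys, k_new)]
            values = [torch.cat([a, b], dim=-2) for a, b in zip(values, v_new)]
    return keys, values  # per-layer (batch, heads, mem_len, head_dim) tensors
```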
Top-k is probably the most important parameter to think about. This is the number of key-value pairs, or memories, that each query token is allowed to retrieve and attend to. When your memory is quite long, more is generally better, but it should be set dynamically based on how long your memory is. Ultimately, we want to retrieve as much information as we can from the memory without confusing the model. It's analogous to putting everything into context: we don't want to just throw everything in, because that confuses the model.
So we implement two different regularization techniques that we've found to be especially effective. Recall that we retrieve tokens based on the similarity between the query token and the key being retrieved. We might say: retrieve a generous number of memories, but if a retrieved key doesn't hit some similarity threshold, say at least 0.25 similarity, throw it out. In other words, we retrieve and then mask out the ones that end up being less important.
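Continuing the earlier attention sketch, that masking step might look like this (0.25 is just the example threshold mentioned above):

```python
def similarity_mask(sims, top_indices, threshold=0.25):
    """sims: (seq, mem) cosine similarities; top_indices: (seq, top_k) retrieved
    memory indices. Returns a boolean mask keeping only retrieved memories whose
    similarity clears the threshold; the rest are masked out of the attention."""
    retrieved_sims = sims.gather(-1, top_indices)  # (seq, top_k)
    return retrieved_sims >= threshold
```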
Another important regularization technique, in particular for models trained with RoPE, is to eliminate tokens from the memory that correspond to unknown tokens. This matters especially if your data is messy, and a lot of the Wikipedia-based benchmarks are far messier than I realized before I started working on this: they contain a lot of unknown tokens. Because those tokens are unknown, they're often poorly represented by the model, so they end up matching many query tokens without actually carrying useful information. We simply eliminate them from the memory before we allow retrieval to start.
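That clean-up is as simple as something like the following, assuming a standard Hugging Face tokenizer that exposes an unk_token_id:

```python
def strip_unknown_tokens(memory_ids, tokenizer):
    """Drop unknown tokens from the memory before caching keys/values; they tend
    to match many query tokens while carrying no useful information."""
    return [tok for tok in memory_ids if tok != tokenizer.unk_token_id]
```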
We have a whole collection of these models on Hugging Face, all of the code is on GitHub along with the dataset, and I encourage you to read the paper if you're curious about more of the technical details. As I hope you can see here, these models are actually pretty easy to use: it's as simple as passing the memories in as tokens when you instantiate the model (you can also change them dynamically afterwards, but instantiation is the easiest way), and making sure your config is set up correctly.
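For example, usage looks roughly like the sketch below. The checkpoint name and the external_memories keyword are my best recollection of the released models, so please double-check the Hugging Face model cards for the exact API:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "normalcomputing/extended-mind-mpt-7b"  # check the collection for exact names
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Tokenize whatever you want the model to treat as memories.
memory_text = "Alexander Grothendieck became a French citizen in 1971. ..."
memory_ids = tokenizer(memory_text, return_tensors="pt").input_ids

# Pass the memories in at instantiation; the extended-mind attention is
# implemented in the repo's custom modeling code, hence trust_remote_code.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    external_memories=memory_ids,
    trust_remote_code=True,
)

prompt = "When did Alexander Grothendieck get his French citizenship?"
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```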
To conclude, I hope you'll take away that these new kinds of models achieve impressive performance on retrieval tasks. They enable a great new kind of citation, and a new hallucination-reduction technique inspired by active learning. They don't require fine-tuning, unlike long-context methods.
And they can be easily run using our open source models and code.