Stanford CS25: V1 | Transformer Circuits, Induction Heads, In-Context Learning
Chapters
0:00
0:26 People mean lots of different things by "interpretability". Mechanistic interpretability aims to map neural network parameters to human understandable algorithms.
14:13 What is going on???
44:14 The Induction Pattern
Thank you all for having me. It's exciting to be here. One of my favorite things is talking about what is going on inside neural networks, or at least what we're trying to figure out is going on inside neural networks, so it's always fun to chat about that. Oh gosh, I have to figure out how to do things. Okay, there we go, now we are advancing slides, that seems promising.

So I think "interpretability" means lots of different things to different people. It's a very broad term, and people mean all sorts of different things by it. So I wanted to talk briefly about the kind of interpretability that I spend my time thinking about, which is what I'd call mechanistic interpretability. Most of my work actually has not been on language models or on RNNs or transformers, but on understanding vision convnets and trying to understand how the parameters in those models actually map to algorithms. You can think of the parameters of a neural network as being like a compiled computer program, and the neurons are kind of like variables or registers. Somehow there are complex computer programs embedded in those weights, and we'd like to turn them back into computer programs that humans can understand. It's a kind of reverse engineering problem.

Here's a fun example that we found: there was a car neuron, and you could see that the car neuron is constructed from, among other things, a wheel neuron. It's looking for wheels on the bottom (those are positive weights) and it doesn't want to see them on top, so it has negative weights there. There's also a window neuron, looking for windows on the top and not on the bottom. What we're actually seeing there is an algorithm. It's saying: a car has wheels on the bottom, windows on the top, and chrome in the middle. And those are just the strongest neurons for that unit. So we're seeing a meaningful algorithm, and that's not an exception. That's the general story: if you're willing to look at neural network weights and invest a lot of energy in trying to reverse engineer them, there are meaningful algorithms written in the weights, waiting for you to find them.
There are a bunch of reasons I think this is an interesting thing to think about. One is that no one knows how to do the things that neural networks can do. No one knows how to write a computer program that can accurately classify ImageNet, let alone do the language modeling tasks we're doing. No one knows how to directly write a computer program that can do the things GPT-3 does, and yet somehow gradient descent is able to discover a way to do this. I want to know what's going on; I want to know what it has discovered in these systems.

There's another reason why I think this is important, which is safety. If we want to use these systems in places where they have a big effect on the world, I think a question we need to ask ourselves is: what happens when these models have unanticipated failure modes, failure modes we didn't know to test for, to look for, to check for? How can we discover those things, especially if they're really pathological, where the model is in some sense deliberately doing something we don't want? The only way I really see that we can do that is if we can get to a point where we really understand what's going on inside these systems. So that's another reason I'm interested in this.

Now, actually doing interpretability on language models and transformers is new to me. Before this year, I spent about eight years working on trying to reverse engineer convnets and vision models. So the ideas in this talk are new things that I've been thinking about with my collaborators, and we're still probably a month or two out, maybe longer, from publishing them. This is also the first public talk I've given on it, so the things I'm going to talk about are honestly still a little bit confused for me, and they'll definitely be confused in my articulation of them. If I say things that are confusing, please feel free to ask me questions. There might be some points where I go quickly because there's a lot of content, but at the end I'll be available for a while to chat about this stuff. And I apologize if I'm unfamiliar with Zoom and make mistakes.
So with that said, let's dive in. I wanted to start with a mystery. Before we try to dig into what's going on inside these models, I wanted to motivate it with a really strange behavior that we discovered and wanted to understand. And by the way, I should say all this work is done with my colleagues at Anthropic, especially my colleagues Catherine and Nelson.

Okay, so on to the mystery. I think probably the most interesting and exciting thing about transformers is their ability to do in-context learning, or what people sometimes call meta-learning. The GPT-3 paper describes things as "language models are few-shot learners": there are lots of impressive things about GPT-3, but that's what they chose to focus on. Now everyone's talking about prompt engineering, and Andrej Karpathy was joking about how Software 3.0 is designing the prompt. So the ability of these large transformers to respond to their context, to learn from their context and change their behavior in response to it, really seems like the most surprising and striking and remarkable thing about them.

Some of my colleagues previously published a paper with a trick in it that I really love. We're all used to looking at learning curves: you train your model, and as it trains, the loss goes down. Sometimes it's a little discontinuous, but it goes down. Another thing you can do is take a fully trained model and ask: as we go through the context, as we predict the first token, then the second, then the third, we get better at predicting each token, because we have more information to predict it with. The loss on the first token should be the entropy of the unigrams, the loss on the next token should be the entropy of the bigrams, and it keeps falling, it keeps getting better. In some sense, that is the model's ability to do in-context learning: the ability to be better at predicting later tokens than early tokens. That is, in some sense, a mathematical definition of what it means to be good at this magical in-context learning or meta-learning that these models do. And that's kind of cool, because it gives us a way to look at whether models are doing it.
Student: Chris, if I could just ask a clarification question: when you say learning, there are no actual parameter updates. That's the remarkable thing about in-context learning, right?

Chris: Yeah, indeed. We traditionally think about neural networks as learning over the course of training by modifying their parameters, but somehow models also appear to be able to learn, in some sense, from their context. If you give them a couple of examples in their context, they can then do that task later in the context, even though no parameters changed. So it's some kind of quite different notion of learning.

Student: Okay, I think that's making more sense. So could you also just describe in-context learning in this case as conditioning, as in conditioning on the first five tokens of a ten-token sentence to predict the next five tokens?

Chris: Yeah. I think the reason people sometimes call this in-context learning or meta-learning is that you can actually take a training set and embed it in your context, even just two or three examples, and then suddenly your model can do the task. So you can do few-shot learning by embedding things in the context. The formal setup is that you're just conditioning on the context. I guess the history here is that we started to get good at training neural networks: we could train vision models and language models that could do all these remarkable things. But then people started saying, well, these systems take so many more examples than humans do to learn, how can we fix that? And all these ideas of meta-learning developed, where we wanted to train models explicitly to learn from a few examples, and people developed all these complicated schemes. The truly absurd thing about transformer language models is that without any effort at all, we get this for free: you can just give them a couple of examples in their context and they can learn, in context, to do new things. That was in some sense the most striking thing about the GPT-3 paper. This ability for mere conditioning on a context to give you new abilities for free, and the ability to generalize to new things, is to me the most striking and shocking thing about transformer language models.
Student: That makes sense. I guess from my perspective, I'm trying to square the notion of learning in this case with the following: if you or I were given a prompt like "one plus one equals two, two plus three equals five" as the few-shot setup, and then somebody put "five plus three equals" and we had to fill it out, I wouldn't say that we've learned arithmetic, because we already knew it. Rather, we're just conditioning on the prompt to know what it is that we should generate.

Chris: I think that's on a spectrum, though, because you can also give completely nonsensical problems that the model would never have seen ("mimic this function", with a couple of examples of the function) and the model, which has never seen it before, can do it later in the context. And I think you did learn something in a lot of these cases. You might not have learned arithmetic, since you might have some innate faculty for arithmetic that you're using, but you might have learned "oh, okay, right now we're doing arithmetic problems." In any case, I agree there's an element of semantics here.

Student: Yeah, this is helpful, just to clarify exactly what you mean by in-context learning. Thank you for walking through it.

Chris: Of course.
So something I think is really striking about all of this: we've talked about how we can look at the learning curve, and we can also look at this in-context learning curve, but those are really just two slices of a two-dimensional space. In some sense, the more fundamental thing is how good we are at predicting the nth token at a given point in training. When we talk about the loss curve, we're just averaging over the token-position dimension and projecting onto the training step: that's the loss curve. And the thing we're calling the in-context learning curve is just the slice across token positions at the end of training.

Something that's kind of striking is that there's a discontinuity in it. There's a point where the model seems to get radically better, within a very short span of training, at predicting late tokens. It's not that different on early tokens, but on late tokens it suddenly gets better. A way to make this more striking is to take the difference between your ability to predict the 50th token and your ability to predict the 500th token: subtract the 50th-token loss from the 500th-token loss. What you see over the course of training is that at first you're not very good at this, you get a little bit better, then suddenly there's this cliff, and after that you never get better; the difference between the two, at least, never improves. The model keeps getting better at predicting things, but its ability to predict late tokens better than early tokens never improves after that point. So in the span of just a few hundred training steps, the model gets radically better at this kind of in-context learning. And so you might ask: what's going on at that point?
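(To make that metric concrete, here is a minimal sketch of how you might compute it. It assumes a PyTorch-style autoregressive model whose forward pass returns logits of shape [batch, n_ctx, vocab]; the function names and shapes are my own, not something from the talk.)

```python
import torch
import torch.nn.functional as F

def per_token_loss(model, tokens):
    """tokens: LongTensor [batch, n_ctx]. Returns the mean loss in nats at each
    position, averaged over the batch: shape [n_ctx - 1]."""
    with torch.no_grad():
        logits = model(tokens)                          # assumed shape [batch, n_ctx, vocab]
        losses = F.cross_entropy(                       # predict token t+1 from the prefix up to t
            logits[:, :-1].reshape(-1, logits.shape[-1]),
            tokens[:, 1:].reshape(-1),
            reduction="none",
        ).reshape(tokens.shape[0], -1)
    return losses.mean(dim=0)

def in_context_learning_score(loss_by_pos, early=50, late=500):
    """The talk's metric: loss at the 500th token minus loss at the 50th token.
    More negative means better use of the context."""
    return (loss_by_pos[late] - loss_by_pos[early]).item()

# The ordinary training-loss curve is just loss_by_pos.mean() at each checkpoint;
# the in-context learning curve is loss_by_pos itself for one (e.g. final) checkpoint.
```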
And this is just one model, but first of all it's worth noting that this isn't a small change. We don't think about this very often; usually we just look at loss curves and ask whether one model did better or worse than another. But you can think about this in terms of nats, which are just an information-theoretic quantity, and you can convert that into bits. One way you can interpret it: 0.4 nats is roughly half a bit, which is roughly like the model getting to sample twice every other token and pick the better completion. It's actually even stronger than that; that's an underestimate of how big a deal getting better by 0.4 nats is. So this is a really big difference in the model's ability to predict late tokens.
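(Rough arithmetic behind that claim, as I read it; this gloss is mine, not the talk's:

    0.4 nats = 0.4 / ln(2) ≈ 0.58 bits per token ≈ roughly 1 bit every other token,

and one bit is exactly the information needed to pick the better of two alternatives, which is where the "sample twice every other token" picture comes from.)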
And we can visualize this in different ways. We can ask how quickly we're getting better at predicting later tokens and look at the derivative with respect to training, and then we can see very clearly that there's some kind of discontinuity in that derivative at this point. And if we take the second derivative with respect to training, we see very clearly that there's this line here: something in the span of just a few hundred steps is causing a big change. We have some kind of phase change going on, and this is true across model sizes. You can actually see it a little bit in the loss curve: there's this little bump here, and that bump corresponds to the point where this change happens. We could actually have seen it in the loss curve earlier, too.

So we have this phase change going on, and there's a really tempting theory to have, which is that this change in the model's outputs and behavior, its outward-facing properties, presumably corresponds to some kind of change in the algorithms running inside the model. If we observe this big phase change in the model's behavior, especially in a very small window, presumably there's some change in the circuits inside the model that is driving it. At least, that's a natural hypothesis.

If we want to ask that, though, we need to be able to understand what algorithms are running inside the model: how can we turn the parameters of the model back into those algorithms? So that's going to be our goal. It's going to require us to cover a lot of ground in a relatively short amount of time, so I'm going to go a little quickly through the next section and highlight the key takeaways, and then I'll be very happy to explore any of this in as much depth as you want. I'm free for another hour after this call and happy to talk in as much depth as people want about the details.

So, it turns out the phase change doesn't happen in a one-layer attention-only transformer, and it does happen in a two-layer attention-only transformer. So if we could understand one-layer and two-layer attention-only transformers, that might give us a pretty big clue as to what's going on.
We're attention-only, and we're also going to leave out layer norm and biases to simplify things. One way you could describe an attention-only transformer is: we embed our tokens, then we apply a bunch of attention heads and add their outputs into the residual stream, and then we apply our unembedding, and that gives us our logits. We could write that out as equations if we want: multiply by an embedding matrix, apply attention heads, then compute the logits with the unembedding.

The part that's a little tricky is understanding the attention heads. The conventional way of describing attention heads actually obscures a lot of their structure; I think we often make attention heads seem more complex than they are and hide the interesting part. What is the standard description saying? For every token, compute a value vector, then mix the value vectors together according to the attention matrix, and then project them with the output matrix back into the residual stream.

There's another notation, which you can think of in terms of tensor products, or left and right multiplication; there are a few ways you can interpret it, but I'll just try to explain what it means. Our residual stream x has a vector for every single token. One side of the product means: independently, for each token, multiply that token's vector by W_V, so compute the value vector for every token. The other side, where A appears on the left, means: take linear combinations of those value vectors according to the attention pattern. Don't change the value vectors pointwise; mix them together across positions as a weighted sum. And then, again independently for every position, apply the output matrix.

You can apply the distributive property to this, and it reveals that it didn't actually matter that you did the attention in the middle: you could have done the attention at the beginning or at the end; it's independent. The thing that actually matters is this W_V W_O matrix. What the rewrite is really saying is that W_V W_O describes what information the attention head reads from each position and how it writes it to its destination, whereas A describes which positions it reads from and writes to. That gets at the more fundamental structure of attention: an attention head moves information from one position to another, and the choice of which positions information moves between is independent of what information gets moved.
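(Here is a tiny NumPy sketch of that point, with made-up dimensions: for a fixed attention pattern A, mixing positions before or after the per-token value and output maps gives the same answer, so the head is fully characterized by A, which says where information moves, and W_V W_O, which says what gets moved and how it is written.)

```python
import numpy as np

rng = np.random.default_rng(0)
n_ctx, d_model, d_head = 6, 16, 4

x   = rng.normal(size=(n_ctx, d_model))     # residual stream: one row per token
W_V = rng.normal(size=(d_model, d_head))
W_O = rng.normal(size=(d_head, d_model))

# a fixed causal attention pattern: rows are destinations, columns are sources
scores = rng.normal(size=(n_ctx, n_ctx))
scores = np.where(np.tril(np.ones((n_ctx, n_ctx))) == 1, scores, -np.inf)
A = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

head_conventional = (A @ (x @ W_V)) @ W_O   # values, mix with A, project out
head_factored     = A @ (x @ (W_V @ W_O))   # A mixes positions; W_V @ W_O acts per token

assert np.allclose(head_conventional, head_factored)
```

Which positions (A) and what information (W_V W_O) factor apart exactly, which is the point of the rewritten notation.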
If you rewrite your transformer that way, we can first write the sum of attention heads in this form, and then write the entire layer by adding in an identity term. And if we plug that all into our transformer and expand, multiplying everything through, we get this interesting equation. We get one term that corresponds to the path directly through the residual stream, and it's going to want to store bigram statistics: all it gets is the previous token, and it tries to predict the next token, so it gets to store bigram statistics. And then for every attention head, we get a term with the attention pattern, which describes which token looks at which token, and a matrix that describes, for every possible token you could attend to, how attending to it affects the logits. That's just a table you can look at: it says, for this attention head, if it attends to this token, it's going to increase the probability of these tokens. In a one-layer attention-only transformer, that's all there is.

Yeah, so this is just the interpretation I was describing. Another thing worth noting: according to this, the attention-only transformer is linear if you fix the attention pattern. Of course, the attention pattern isn't actually fixed, but whenever you have the opportunity to make something linear, linear functions are really easy to understand. So if you can fix a small number of things and make something linear, that's a lot of leverage.
And we can talk about how the attention pattern is computed as well. If you expand it out, you'll get an equation like this. I think the core story to take away from all of these is that we have these two matrices that actually look kind of similar. This one tells you: if you attend to a token, how are the logits affected? You can think of it as a giant matrix saying, for every possible input token, how the logits are affected by attending to that token, which tokens are made more or less likely. And this other one says how much every token wants to attend to every other token.

One way you can picture this is that there are really three tokens involved when we think about an attention head. We have the token we're going to move information to, which is attending backwards; we have the source token that gets attended to; and we have the output token whose logits are affected. And you can just trace through this. How does attending to this token affect the output? Well, first we embed the token, then we multiply by W_V to get the value vector, the information gets moved by the attention pattern, we multiply by W_O to add it back into the residual stream, then it gets hit by the unembedding and affects the logits. That's where that one matrix comes from. We can also ask what decides whether a token gets a high score when we're computing the attention pattern: embed one token and turn it into a query, embed the other token and turn it into a key, and dot product them. That's where those two matrices come from.

I know I'm going quite quickly, so maybe I'll briefly pause here, and if anyone wants to ask for clarifications, this would be a good time.
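(For concreteness, here is a hedged sketch of those two giant matrices for a single head, in a row-vector convention and with positional information ignored; the function names are mine, not from the talk.)

```python
import numpy as np

def ov_table(W_E, W_V, W_O, W_U):
    """[vocab, vocab] table: entry [source, out] is how much attending to token
    `source` pushes up the logit of token `out` (the embed -> value -> output ->
    unembed path traced above). Shapes, row-vector convention: W_E [vocab, d_model],
    W_V [d_model, d_head], W_O [d_head, d_model], W_U [d_model, vocab]."""
    return W_E @ W_V @ W_O @ W_U

def qk_table(W_E, W_Q, W_K):
    """[vocab, vocab] table: entry [dest, source] is the pre-softmax score for
    destination token `dest` attending back to source token `source`
    (embed -> query dotted with embed -> key), ignoring position."""
    return (W_E @ W_Q) @ (W_E @ W_K).T
```

Everything in the examples that follow amounts to reading rows and columns of these two tables.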
And then we'll actually reverse engineer one and say: everything that's going on in a one-layer attention-only transformer is now in the palm of our hands. It's a very toy model, and no one actually uses one-layer attention-only transformers, but we'll be able to understand them.

Student: So just to be clear, you're saying that the query-key circuit is learning the attention weights, and is essentially responsible for running the attention between different tokens?

Chris: Yeah. All three of those parts are learned, but that's what generates the attention patterns. It gets run for every pair of tokens, and you can think of the values in that matrix as just being how much every token wants to attend to every other token, if it were in the context. We also have positional embeddings here, so there's a little bit we're glossing over, but in a global sense: how much does every token want to attend to every other token?

Student: Right. And the other circuit, the output-value circuit, is using the attention that's been calculated to affect the final outputs?

Chris: It's saying: assume the attention head attends to some token, and set aside the question of how that gets computed. How would it affect the outputs if it attended to that token? You can just calculate that. It's a big table of values that says, for this token, it's going to make this token more likely and that token less likely.

Student: Okay. And they're completely independent? It's just two separate matrices?

Chris: They're separate. The formulas might make them seem entangled, but they're actually separate.

Student: Right. To me it seems like the supervision is coming from the output-value circuit, and the query-key circuit seems to be more of an unsupervised kind of thing, because there's no...

Chris: Hmm. In a model, every neuron is in some sense downstream of the ultimate training signal. The output-value circuit is perhaps getting more direct signal, but yeah. We'll be able to dig into this in as much detail as you want in a little bit, so maybe I'll push forward; I think an example of how to use this to reverse engineer a one-layer model will make it a bit more motivated.
Okay. So, just to emphasize this, there are three different tokens we can talk about: the token that gets attended to; the token that does the attending, call it the destination; and the token that gets affected, the next token, whose probabilities change. Something we can do is notice that the only token that connects to both of those matrices is the token that gets attended to; the other two are bridged by their interaction with the source token. So something that's natural is to ask, for a given source token, how it interacts with both.

Let's take, for instance, the token "perfect". One thing we can ask is which tokens want to attend to "perfect". Apparently the tokens that most want to attend to "perfect" are "are", "looks", "is", and "provides"; "are" is the most, "looks" the next most, and so on. And then when we attend to "perfect" (this is for one single attention head, so it'd be different for a different head), it wants to really increase the probability of "perfect", and then to a lesser extent "super" and "absolute" and "cure". And we can ask what sequences of tokens are made more likely by this particular set of things wanting to attend to each other and becoming more likely. Things of the form: we have the token we attended back to, then some skip over some number of tokens (they don't have to be adjacent), and then later on we see the token "are", it attends back to "perfect", and it increases the probability of "perfect". So you can think of this as changing the probability of what we might call skip-trigrams, where we skip over a bunch of tokens in the middle but we're really affecting the probability of trigrams: "perfect ... are -> perfect", "perfect ... looks -> super".

We can look at another one. We have the token "large": tokens like "contains", "using", and "specify" want to look back at it, and it increases the probability of "large" and "small". The skip-trigrams affected are things like "large ... using -> large" and "large ... contains -> small". If we see the number "2", we increase the probability of other numbers, and we affect skip-trigrams like "2 ... 1 -> 2" and "2 ... has -> 3".
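(Those per-token summaries are just rows and columns of the two tables from the earlier sketch. Here is a hedged helper, with hypothetical names, that prints them for one source token, given a list mapping token ids to strings.)

```python
import numpy as np

def skip_trigram_summary(source_id, OV, QK, vocab, k=5):
    """OV, QK: the [vocab, vocab] tables from ov_table / qk_table above.
    Prints the top-k destination tokens that want to attend back to `source_id`
    (a column of QK) and the top-k output tokens it boosts (a row of OV)."""
    dest_scores = QK[:, source_id]
    out_scores  = OV[source_id, :]
    top_dest = [vocab[i] for i in np.argsort(-dest_scores)[:k]]
    top_out  = [vocab[i] for i in np.argsort(-out_scores)[:k]]
    print(f"source {vocab[source_id]!r}")
    print("  wants to be attended to by:", top_dest)
    print("  boosts the logits of:      ", top_out)
    # any (destination, output) pair drawn from these two lists is a skip-trigram
    # "source ... destination -> output" that this head makes more likely
```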
Now, you're all in a technical field, so you'll probably recognize this one. We have "lambda", and then we see a backslash, and then we want to increase the probability of "lambda" and "sorted" and "operator". It's all LaTeX: if it sees "lambda", it thinks that maybe the next time it sees a backslash it should put in some LaTeX math symbol. Same thing for HTML: we see "nbsp" for non-breaking space, and then we see an ampersand, and we want to make that more likely.

The takeaway from all of this is that a one-layer attention-only transformer is acting entirely on these skip-trigrams. Everything it does (well, it also has that pathway by which it affects bigrams, but mostly) is affecting these skip-trigrams. And there are lots of them: it's just these giant tables of skip-trigrams that are made more or less likely.

There are lots of other fun things it does. Sometimes the tokenization will split up a word in multiple ways. So, we have "indie"... well, that's not a good example. We have the word "Pike", and then we see the token "P" and we predict "ike", and we predict "spikes" and things like that. Or these ones are kind of fun, maybe worth talking about for a second: we see the token "Lloyd", and then later we see an "L" and maybe we predict "loyd" (Lloyd), or an "R" and we predict "alph" (Ralph), or a "C" and we predict "atherine" (Catherine). But we'll see in a second that... well, we'll come back to that in a sec. So we increase the probability of things like "Lloyd ... L -> loyd" and "Lloyd ... C -> atherine". Or "Pixmap", if anyone's worked with Qt: we see "Pixmap", and then a "P" predicts "ixmap" again, but a "Q" predicts "Canvas".

Of course, there's a problem with this, which is that the head doesn't get to pick which one of these goes with which. If you want to make "Pixmap ... P -> ixmap" and "Pixmap ... Q -> Canvas" more probable, you also have to make "Pixmap ... P -> Canvas" more probable. And if you want to make "Lloyd ... L -> loyd" and "Lloyd ... C -> atherine" more probable, you also have to make "Lloyd ... C -> loyd" and "Lloyd ... L -> atherine" more probable. So there are actual bugs that transformers have, at least these really tiny one-layer attention-only transformers: bugs that seem weird until you realize it's this giant table of skip-trigrams that's operating. The nature of that is that if the head wants to do the right thing, it's forced to also make some weird predictions.
Student: Is there a reason why the source tokens here have a space before the first character?

Chris: Yes. I was giving examples where the tokenization breaks in a particular way, and because spaces get included in the tokenization, a word with a space in front of it and the same word without one can get tokenized in different ways.

Student: Got it. Cool, thanks.

Chris: Yeah, great question.
Okay. So, just to abstract away some common patterns that we're seeing: one pretty common thing is what you might describe as "[b] ... [a] -> [b]". You see some token, and then later you see another token that tends to precede it, and you say: ah, probably the token I saw earlier is going to occur again. Or sometimes you predict a slightly different token. An example of the first kind is "2 ... 1 -> 2", but you could also do "2 ... has -> 3"; 3 isn't the same as 2, but it's kind of similar. Another one is the example where something is tokenized as one token in one place and split apart in another: you see the whole token, then you see something that might be the first part of it, and you predict the second part.

I think the thing that's really striking about this is that these are all, in some way, a really crude kind of in-context learning. In particular, these models get about 0.1 nats of in-context learning rather than about 0.4 nats, and they never go through the phase change. So they're doing some kind of really crude in-context learning, and they're dedicating almost all of their attention heads to it. They're not very good at it, but they're dedicating their capacity to it.

I'm noticing that it's 10:37. I want to check how long I can go, because maybe I should super-accelerate.

Host: I think it's fine, because students are also asking questions in between, so you should be good.

Chris: Okay. So maybe my plan will be to talk until about 10:55 or 11:00, and then, if you want, I can answer questions for a while after that.
Host: Yeah, that works.

Chris: Fantastic. So you can see this as a very crude kind of in-context learning. Basically what it's saying is: I saw this token, so probably the same token or similar tokens are more likely to occur later; and look, this spot looks like a place where I could inject the token I saw earlier, so I'm going to inject it here and say it's more likely. That's basically what it's doing, and it's dedicating almost all of its capacity to that. It's sort of the opposite of what we thought with RNNs in the past. It used to be that everyone said, oh, with RNNs it's so hard to get the model to care about long-distance context, maybe we need to add something special for that. No: if you train a transformer and give it a long enough context, it dedicates almost all of its capacity to this kind of thing. Just kind of interesting.

There are some attention heads which are primarily positional. The one-layer model I've been training has twelve attention heads, and usually around two or three of those become these more positional, shorter-range heads that do something more like local trigram statistics, and then everything else becomes these skip-trigram heads.

So, some takeaways from this. You can understand one-layer attention-only transformers in terms of these OV and QK circuits. Transformers desperately want to do in-context learning: they desperately, desperately want to look at these long-distance contexts and predict things, because there's so much entropy they can reduce that way. The constraints of a one-layer attention-only transformer force it to have certain bugs if it wants to do the right thing. And if you freeze the attention patterns, these models are linear.
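(Putting those takeaways together, here is a hedged sketch of the fully expanded one-layer attention-only model: toy shapes, no layer norm, biases, or positional terms, and the attention patterns treated as fixed, which is exactly the setting in which the whole thing is linear.)

```python
import numpy as np

def one_layer_logits(tokens, W_E, W_U, heads, patterns):
    """tokens: [n_ctx] ints. heads: list of (W_V, W_O) pairs, one per attention head.
    patterns: list of [n_ctx, n_ctx] attention patterns, treated as fixed here.
    Expanded form:  logits = T (W_E W_U) + sum_h A_h T (W_E W_V_h W_O_h W_U)
    where T is the one-hot token matrix: a bigram table plus one skip-trigram
    table per head."""
    vocab = W_E.shape[0]
    T = np.eye(vocab)[tokens]                    # [n_ctx, vocab] one-hot tokens
    logits = T @ (W_E @ W_U)                     # direct path: bigram statistics
    for (W_V, W_O), A in zip(heads, patterns):
        skip_trigram = W_E @ W_V @ W_O @ W_U     # this head's vocab x vocab OV table
        logits += A @ (T @ skip_trigram)         # attend, then write into the logits
    return logits
```

With the patterns frozen, every step above is a matrix product, which is the linearity claim.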
Okay, a quick aside, because so far this type of work has required us to do a lot of very manual inspection, looking through these giant matrices, but there's a way we can escape that. We don't have to look at these giant matrices if we don't want to: we can use eigenvalues and eigenvectors. Recall that an eigenvalue and eigenvector just mean that if you multiply that vector by the matrix, it's equivalent to scaling it. In my experience this often hasn't been very useful for interpretability, because we're usually mapping between different spaces; but if you're mapping a space onto itself, eigenvalues and eigenvectors are a beautiful way to think about things. We're going to draw them on a radial plot, with a log radial scale, because their magnitudes vary over many orders of magnitude.

Okay. Our OV circuit maps from tokens to tokens; that's the same vector space on the input and the output. So we can ask what it means if we see eigenvalues of a particular kind. Positive eigenvalues, and this is really the most important part, mean copying. A positive eigenvalue means there's some set of tokens where, if you see them, you increase their probability. If you have a lot of positive eigenvalues, you're doing a lot of copying; if you only have positive eigenvalues, everything you do is copying. Imaginary eigenvalues mean that you see a token and then you want to increase the probability of unrelated tokens. And negative eigenvalues are anti-copying: if you see this token, you make it less probable in the future.
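(A hedged sketch of that analysis, including the "how copying is this head" summary statistic that comes up a moment later, the sum of the eigenvalues over the sum of their absolute values:)

```python
import numpy as np

def copying_score(M):
    """M: a square OV-circuit matrix (it maps the token space to itself).
    Returns a number in [-1, 1]: near +1 when the eigenvalues are almost all
    positive (nearly pure copying), negative when the head is mostly anti-copying."""
    eig = np.linalg.eigvals(M)
    return float(np.sum(eig).real / np.sum(np.abs(eig)))

# Practical note: nonzero eigenvalues are preserved under cyclic permutation of a
# matrix product, so instead of the huge [vocab, vocab] matrix W_E W_V W_O W_U you
# can score the small [d_head, d_head] matrix (W_O @ W_U) @ (W_E @ W_V).
```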
Well, that's really nice, because now we don't have to dig through these giant matrices that are vocab-size by vocab-size; we can just look at the eigenvalues. So these are the eigenvalues for our one-layer attention-only transformer, and we can see that for many of these heads they're almost entirely positive; these ones are essentially entirely positive, these ones almost entirely positive, and these are even almost entirely positive. There are only two heads with a significant number of imaginary and negative eigenvalues. What this is telling us, in one picture, is that ten out of twelve of these attention heads are just doing copying: this long-distance "I saw a token, it's probably going to occur again" type of stuff. That's kind of cool; we can summarize it really quickly.

In a second we're going to look at a two-layer model, and we'll see that a lot of its heads are also doing this kind of copying-ish stuff; they have large positive eigenvalues. You can also do a histogram. One thing that's nice is you can just add up the eigenvalues and divide them by the sum of their absolute values, and you get a number between negative one and one, which is something like how "copying" the head is. If you histogram that, you can see that almost all of the heads are doing lots of copying. It's nice to be able to summarize your model this way. And I think we've gone about this in a very bottom-up way: we didn't start with assumptions about what the model was doing; we tried to understand its structure, we were able to summarize it in useful ways, and now we're able to say something about it.

Now, another thing you might ask is what the eigenvalues of a QK circuit mean. In our examples so far they wouldn't have been that interesting, but in a minute they will be, so I'll briefly describe what they mean. A positive eigenvalue would mean you want to attend to the same token. An imaginary eigenvalue, which is what you would mostly see in the models we've looked at so far, means you want to attend to an unrelated or different token. And a negative eigenvalue would mean you want to avoid attending to the same token.
So that will be relevant in a second. Those are mostly going to be useful to think about in multi-layer attention-only transformers, where we can have chains of attention heads; we'll get to that in a second. And that's a table summarizing all of this. Unfortunately, this approach completely breaks down once you have MLP layers: with MLP layers you have nonlinearities, so you don't get this property where your model is mostly linear and you can just look at a matrix. But if you're working with attention-only transformers, this is a very nice way to think about things.

Okay, so recall that one-layer attention-only transformers don't undergo the phase change we talked about at the beginning. Right now we're on a hunt: we're trying to answer the mystery of what on earth is going on in that phase change, where models suddenly get good at in-context learning. We want to answer that. One-layer attention-only transformers don't undergo that phase change, but two-layer attention-only transformers do. So we'd like to know what's different about two-layer attention-only transformers.

When we were dealing with one-layer attention-only transformers, we were able to rewrite them in that expanded form, and it gave us a lot of ability to understand the model, because we could say: this term is bigrams, and each of these other terms is looking somewhere, and we have a matrix that describes how it affects things. We can also just write the model in the factored form, where we have the embedding, then the attention heads, then the unembedding.

Oh, and for simplicity we often write W_OV for W_O W_V, because they always come together. It's in some sense an illusion that W_O and W_V are different matrices; they're just one low-rank matrix, and they're always used together. Similarly for W_Q and W_K: it's sort of an illusion that they're different matrices, they're always used together, and keys and queries are just an artifact of these low-rank matrices. So in any case, it's useful to write those together.
Okay, great. So for a two-layer attention-only transformer, what we do is go through the embedding matrix, then through the layer 1 attention heads, then through the layer 2 attention heads, and then through the unembedding. And for the attention heads we always have this identity term as well, which corresponds to just going down the residual stream: we can go down the residual stream, or we can go through an attention head; and next we can again go down the residual stream, or go through an attention head.

There's a useful identity, the mixed product identity, which tensor products (or the other ways of interpreting this notation) obey: if you compose two attention heads, the attention patterns multiply together and the OV circuits multiply together, and everything behaves nicely. Okay, great. So we can just expand out that equation: take that big product we had at the beginning and expand it, and we get three different kinds of terms. We get the path that goes directly through the residual stream, where we embed and unembed, and that's going to want to represent some bigram statistics. Then we get terms that look like the individual attention head terms we had previously. And finally, we get terms that correspond to going through two attention heads.

Now, it's worth noting that these terms are not exactly the same as before, because the attention patterns in the second layer can be computed from the outputs of the first layer, so they're also going to be more expressive. But at a high level, you can think of there being these three kinds of terms. We sometimes call the two-head terms virtual attention heads, because they don't exist in the sense that they aren't explicitly represented in the model, but they do in fact have an attention pattern, they have an OV circuit, and in almost all functional ways they're like a tiny little attention head. And there are exponentially many of them. It turns out they're not going to be that important in this model, but in other models they can be important.
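(A small NumPy sketch of the mixed-product identity for the two-head path, with made-up shapes and the attention patterns held fixed; in a real two-layer model the second pattern also depends on the first layer's output, which is exactly the extra expressiveness just mentioned. The composed path behaves like a single "virtual" head whose pattern is A2 A1 and whose OV matrix is W_OV1 W_OV2.)

```python
import numpy as np

rng = np.random.default_rng(1)
n_ctx, d_model, d_head = 5, 8, 2

x = rng.normal(size=(n_ctx, d_model))                    # residual stream
W_OV1 = rng.normal(size=(d_model, d_head)) @ rng.normal(size=(d_head, d_model))
W_OV2 = rng.normal(size=(d_model, d_head)) @ rng.normal(size=(d_head, d_model))

def causal_pattern():
    s = rng.normal(size=(n_ctx, n_ctx))
    s = np.where(np.tril(np.ones((n_ctx, n_ctx))) == 1, s, -np.inf)
    return np.exp(s) / np.exp(s).sum(axis=-1, keepdims=True)

A1, A2 = causal_pattern(), causal_pattern()

two_hop = A2 @ ((A1 @ x @ W_OV1) @ W_OV2)    # head 2 reading what head 1 wrote
virtual = (A2 @ A1) @ x @ (W_OV1 @ W_OV2)    # one "virtual head": patterns and OVs compose

assert np.allclose(two_hop, virtual)
```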
Right, so one thing I said is that this allows us to think about attention heads in a really principled way. People look at attention patterns all the time, and I think a concern you might have is that there are multiple attention heads: the information being moved by one attention head might have been moved to that position by another attention head rather than originating there, and it might get moved somewhere else afterwards. This gives us a way to avoid all those concerns and think about things in a single, principled way.

Okay, in any case, an important question to ask is how important these different terms are. We could study all of them; how important are they? It turns out there's an algorithm you can use where you knock out these terms and ask how important they are, and by far the most important thing is the individual attention head terms. In this model, the virtual attention heads basically don't matter that much: they only account for something like 0.3 nats, and the bigram path is still pretty useful. So if we want to try to understand this model, the virtual attention heads are not going to be the best place to focus our attention, especially since there are a lot of them: 124 of them for 0.3 nats, so it's very little that you'd understand by studying any one of those terms. We know the direct path is bigram statistics, so what we really want to do is understand the individual attention head terms.

This is the knockout algorithm; I'm going to skip over it for time. We can ignore the virtual-head term because it's small. And it turns out that the layer 2 attention heads are doing way more than the layer 1 attention heads. That's not that surprising: layer 2 attention heads are more expressive, because they can use the layer 1 attention heads to construct their attention patterns. So if we could just understand the layer 2 attention heads, we'd probably understand a lot of what's going on in this model.
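(The talk skips the exact knockout algorithm. As a rough stand-in, here is one simple way you could measure a term's importance, assuming you can compute each term's additive contribution to the logits as in the expanded forms above; this is my own sketch, not necessarily the procedure used.)

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def mean_nll(logits, targets):
    """Average negative log likelihood in nats. logits [n_ctx, vocab], targets [n_ctx] ints."""
    return float(-log_softmax(logits)[np.arange(len(targets)), targets].mean())

def knockout_importance(term_logits, targets):
    """term_logits: list of [n_ctx, vocab] arrays whose sum is the model's logits
    (direct path, per-head terms, virtual-head terms, ...). For each term, returns
    how many nats the loss rises when that term's contribution is removed."""
    full = sum(term_logits)
    base = mean_nll(full, targets)
    return [mean_nll(full - t, targets) - base for t in term_logits]
```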
And the trick is that the attention heads are now constructed from the previous layer rather 00:43:18.200 |
than just from the tokens. So this is still the same, but the attention pattern is more complex. 00:43:24.120 |
And if you write it out, you get this complex equation that says, you embed the tokens, 00:43:28.440 |
and you're going to shuffle things around using the attention heads for the keys. Then you multiply 00:43:31.880 |
by WQK. Then you shuffle things around again for the queries. And then you multiply by the 00:43:37.080 |
embedding again, because the tokens were embedded to begin with, and you get back to the tokens. But let's 00:43:44.200 |
actually look at them. So remember that when we see positive eigenvalues 00:43:49.000 |
in the OV circuit, we're doing copying. So one thing we can say is that 7 out of 12 heads-- and in 00:43:53.800 |
fact, the ones with the largest eigenvalues-- are doing copying. So we still have a lot of attention 00:43:58.680 |
heads that are doing copying. And then for the QK circuit, one thing you could do is you could 00:44:07.480 |
try to understand things in terms of this more complex QK equation. You could also just try to 00:44:10.600 |
understand what the attention patterns are doing empirically. So let's look at one of these copying 00:44:14.520 |
ones. I've given it the first paragraph of Harry Potter, and we can just look at where it attends. 00:44:23.240 |
And something really interesting happens. So almost all the time, we just attend back to 00:44:27.880 |
the first token. We have this special token at the beginning of the sequence. 00:44:31.480 |
And we usually think of that as just being a null attention operation. It's a way for it to not do 00:44:36.440 |
anything. In fact, if you look, the value vector is basically 0. It's just not copying any information 00:44:40.600 |
from that. But whenever we see repeated text, something interesting happens. So when we get 00:44:47.800 |
to "Mr."-- tries to look at "and." It's a little bit weak. Then we get to "D," and it attends to 00:44:55.400 |
"ers." That's interesting. And then we get to "ers," and it attends to "ly." And so it's not 00:45:04.920 |
attending to the same token. It's attending to the same token, shifted one forward. Well, that's 00:45:12.840 |
really interesting. And there's actually a lot of attention heads that are doing this. So here we 00:45:16.760 |
have one where now we hit the "Pot" in "Potters," and we attend to "ters." Maybe that's the same attention 00:45:21.880 |
head-- I don't remember from when I was constructing this example. It turns out this is a super common 00:45:25.880 |
thing. So you go and you look for the previous example, you shift one forward, and you're like, 00:45:30.120 |
OK, well, last time I saw this, this is what happened. Probably the same thing is going to 00:45:33.080 |
happen. And we can go and look at the effect that the attention head has on the logits. Most of the 00:45:42.360 |
time, it's not affecting things. But in these cases, when it's 00:45:46.200 |
doing this thing of looking one token forward, it's able to go and predict the next token. 00:45:49.160 |
So we call this an induction head. An induction head looks for the previous copy, 00:45:54.920 |
looks forward, and says, ah, probably the same thing that happened last time is going to happen. 00:45:58.680 |
You can think of this as being a nearest neighbors. It's like an in-context nearest 00:46:02.280 |
neighbors algorithm. It's going and searching through your context, finding similar things, 00:46:06.120 |
and then predicting that's what's going to happen next. 00:46:11.480 |
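As a purely behavioral sketch (this describes what the head does, not how the weights do it, which comes next), the induction rule looks roughly like this toy Python function; the tokenization in the example is made up.

```python
def induction_prediction(tokens):
    """Toy "in-context nearest neighbors": find the most recent earlier occurrence
    of the current token, look one position ahead, and predict that token."""
    current = tokens[-1]
    for i in range(len(tokens) - 2, 0, -1):        # scan the earlier context backwards
        if tokens[i - 1] == current:
            return tokens[i]                       # whatever followed it last time
    return None                                    # no earlier copy: stay quiet

# "Mr. and Mrs. D urs ley ... Mr." -> predicts "and"
print(induction_prediction(["Mr.", "and", "Mrs.", "D", "urs", "ley", "of", "number", "four", "Mr."]))
```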
The way that these actually work is-- I mean, there's actually two ways. But in a model that 00:46:17.400 |
uses rotary attention or something like this, you only have one. You shift your key. First, 00:46:24.280 |
an earlier attention head shifts your key forward one. So you take the value of the previous token, 00:46:29.960 |
and you embed it in your present token. And then you have your query and your key 00:46:33.960 |
try to go and match. So you look for the same thing. 00:46:40.600 |
And then you go and you predict that whatever you saw is going to be the next token. So that's the 00:46:45.480 |
high-level algorithm. Sometimes you can do clever things where actually it'll care about multiple 00:46:50.200 |
earlier tokens, and it'll look for short phrases and so on. So induction heads can really vary in 00:46:54.680 |
how much of the previous context they care about or what aspects of the previous context they care 00:46:58.200 |
about. But this general trick of looking for the same thing, shift forward, predict that, 00:47:03.320 |
is what induction heads will do. There are lots of examples of this. And the cool thing is you can now 00:47:10.760 |
use the QK eigenvalues to characterize this. You can say, well, we're looking for the 00:47:15.560 |
same thing, shifted by one, but still looking for the same thing-- if you expand through the attention 00:47:19.160 |
heads in the right way, that'll work out. And we're copying. And so an induction head is one 00:47:23.640 |
which has both positive OV eigenvalues and also positive QK eigenvalues. 00:47:28.520 |
And so you can just put that on a plot, and you have your induction heads in the corner. 00:47:37.000 |
So one axis is the OV eigenvalues and the other is the QK eigenvalues-- I think OV is this 00:47:42.760 |
axis and QK is this one, but it doesn't matter. And in the corner, you have your induction heads. 00:47:49.400 |
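Here is a hedged sketch of how one might turn that plot into code: summarize a head's OV and QK circuits by how positive their eigenvalues are, and flag heads that score high on both. The matrices below are random stand-ins, not real model weights, and the particular summary statistic (sum of real parts over sum of magnitudes) is just one reasonable choice.

```python
import numpy as np

def eig_positivity(M):
    """Score how "positive" a circuit's eigenvalues are: sum of real parts over
    sum of magnitudes, which is +1 for a purely copying / purely matching matrix."""
    lam = np.linalg.eigvals(M)
    return lam.real.sum() / np.abs(lam).sum()

rng = np.random.default_rng(2)
d = 32
# Random stand-ins for one head's token-space OV and QK circuits.
ov_circuit = np.eye(d) + 0.1 * rng.normal(size=(d, d))   # mostly identity: "copy"
qk_circuit = np.eye(d) + 0.1 * rng.normal(size=(d, d))   # mostly identity: "match the same thing"

ov_score, qk_score = eig_positivity(ov_circuit), eig_positivity(qk_circuit)
# Scatter (qk_score, ov_score) for every head; induction-head candidates sit
# near the (+1, +1) corner, with both circuits strongly positive.
print(f"OV positivity {ov_score:.2f}, QK positivity {qk_score:.2f}")
```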
And so we now have an actual hypothesis. The hypothesis is 00:47:57.160 |
that the phase change we were seeing is the discovery of these induction heads. 00:48:01.240 |
That would be the hypothesis. And these are way more effective than the first algorithm we had, 00:48:08.120 |
which was just to blindly copy things to wherever they could plausibly fit. Now we can go and actually 00:48:12.520 |
recognize patterns and look at what happened and predict that similar things are going to 00:48:15.560 |
happen again. That's a way better algorithm. Yeah, so there's other attention heads that 00:48:22.280 |
are doing more local things. I'm going to go and skip over that and return to our mystery, 00:48:25.800 |
because I am running out of time. I have five more minutes. OK, so what is going on with this 00:48:30.200 |
in-context learning? Well, now we have a hypothesis. Let's check it. So we think it 00:48:34.040 |
might be induction heads. And there are a few reasons we believe this. So one thing is going 00:48:40.600 |
to be that induction heads-- well, OK, I'll just jump to the end. So one thing you can do is you 00:48:46.520 |
can just ablate the attention heads. And it turns out you can color-- here we have attention heads 00:48:52.200 |
colored by how much they are an induction head. And this is the start of the bump. This is the 00:48:58.040 |
end of the bump here. And we can see that they-- first of all, induction heads are forming. Like 00:49:02.600 |
previously, we didn't have induction heads here. Now they're just starting to form here. And then 00:49:06.920 |
we have really intense induction heads here and here. And these are the attention heads where, if you ablate 00:49:13.080 |
them, you get a loss increase. And here we're looking not at the loss, but at this meta-learning score-- 00:49:20.760 |
or in-context learning score-- the difference between the loss on the 500th token 00:49:25.000 |
and on the 50th token. And that's all explained by induction heads. Now, we actually have one 00:49:31.480 |
induction head that doesn't contribute to it. Actually, it does the opposite. So that's kind 00:49:34.840 |
of interesting. Maybe it's doing something shorter distance. And there's also this interesting thing 00:49:39.800 |
where they all rush to be induction heads, and then only a few of them win out in the end. 00:49:44.680 |
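For reference, the in-context learning score used in these plots is easy to compute from per-token losses; a minimal sketch, assuming you already have a `(sequences, positions)` array of next-token losses and that sequences are long enough to index position 500 (the exact token indices and the sign convention are assumptions on my part):

```python
import numpy as np

def in_context_learning_score(per_token_loss, early=50, late=500):
    """Difference in average next-token loss between an early and a late position
    in the context (in nats). Positive means the model does better later on.
    per_token_loss: array of shape (num_sequences, seq_len)."""
    per_token_loss = np.asarray(per_token_loss)
    return per_token_loss[:, early].mean() - per_token_loss[:, late].mean()

# To attribute the score to a head, re-run the model with that head ablated and
# compare: in_context_learning_score(losses) - in_context_learning_score(losses_ablated).
```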
So there are some interesting dynamics going on in how these heads form. But it really seems like in these small 00:49:48.280 |
models, all of in-context learning is explained by these induction heads. OK. What about large 00:49:55.880 |
models? Well, in large models, it's going to be harder to go and ask this. But one thing you can 00:49:59.560 |
do is you can ask, OK, we can look at our in-context learning score over time. We get this 00:50:06.600 |
sharp phase change. Oh, look. Induction heads form at exactly the same point in time. So that's only 00:50:13.400 |
correlational evidence. But it's pretty suggestive correlational evidence, especially given that we 00:50:17.320 |
have an obvious story here: the obvious effect that induction heads should have is exactly this. I guess it could be 00:50:22.920 |
that there's other mechanisms being discovered at the same time in large models. But it has to 00:50:26.360 |
be in a very small window. So I really suspect that the thing driving that change in in-context 00:50:32.760 |
learning is the induction heads. OK. So obviously, induction heads can go and copy text. But a question you might ask is, 00:50:41.080 |
can they do translation? There's all these amazing things that models can do that it's not obvious 00:50:47.000 |
in-context learning or this sort of copying mechanism could do. So I just want to very 00:50:51.240 |
quickly look at a few fun examples. So here we have an attention pattern. Oh, yeah. I guess 00:51:02.120 |
I need to open Lexiscope. Let me try doing that again. Sorry. I should have thought this through 00:51:12.200 |
a bit more before this talk. Chris, could you zoom in a little, please? Yeah, yeah. Thank you. 00:51:38.840 |
OK. I'm not-- my French isn't that great. But my name is Christopher. I'm from Canada. 00:51:44.360 |
What we can do here is we can look at where this attention head attends as we go and we do this. 00:51:50.680 |
And it'll become especially clear on the second sentence. So here, we're on the period, 00:51:55.400 |
and we attend to "je." Now we're on-- and "je" is "I" in French. OK. Now we're on the "I," 00:52:03.640 |
and we attend to "sui." Now we're on the "am," and we attend to "do," which is "from," 00:52:09.720 |
and then "from" to "Canada." And so we're doing a cross-lingual induction head, which we can use 00:52:16.040 |
for translation. And indeed, if you look at examples, this is where it seems to-- it seems 00:52:21.960 |
to be a major driving force in the model's ability to go and correctly do translation. 00:52:28.680 |
Another fun example is-- I think maybe the most impressive thing about in-context learning to me 00:52:35.240 |
has been the model's ability to go and learn arbitrary functions. Like, 00:52:38.040 |
you can just show the model a function. It can start mimicking that function. Well, OK. 00:52:44.680 |
So do these induction heads only do kind of a look-ahead copy? Or can they also do some sort 00:52:50.120 |
of complex structure recognition? Yeah, yeah. So they can both use a larger context-- previous 00:52:58.520 |
context-- and they can copy more abstract things. So the translation one is showing you that they 00:53:02.600 |
can copy, rather than the literal token, a translated version. So it's what we might call a 00:53:06.520 |
soft induction head. And yeah, you can have them copy similar words. You can have them 00:53:11.800 |
look at longer contexts. You can look for more structural things. The way that we usually 00:53:16.440 |
characterize them is whether-- in large models, just whether they empirically behave like an 00:53:20.600 |
induction head. So the definition gets a little bit blurry when you try to encompass these more-- 00:53:25.480 |
there's sort of a blurry boundary. But yeah, there seem to be a lot of attention heads that 00:53:29.480 |
are doing sort of more and more abstract versions. And yeah, my favorite version is this one that I'm 00:53:36.040 |
about to show you-- let's isolate a single one of these-- which can do pattern 00:53:42.200 |
recognition. So it can learn functions in the context. So I've just 00:53:46.200 |
made up a nonsense function here. We're going to encode one binary variable with the choice of 00:53:52.120 |
whether to do a color or a month as the first word. Then we're going to-- so we have green or 00:53:58.440 |
June here. Let's zoom in more. So we have color or month, and animal or fruit. And then we have 00:54:07.400 |
to map it to either true or false. So that's our goal. And it's going to be an XOR. So we have 00:54:11.480 |
the two binary variables represented in this way, and we take their XOR. I'm pretty confident this was never 00:54:16.840 |
in the training set, because I just made it up, and it seems like a nonsense problem. 00:54:21.080 |
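Here is a hedged sketch of how you might generate this kind of prompt. The word lists and the particular bit assignment are mine (chosen so that "April dog" comes out true, matching the example below); only the overall structure-- two binary choices, label equal to their XOR-- follows the talk.

```python
import random

colors, months = ["green", "red", "blue"], ["June", "April", "October"]
animals, fruits = ["dog", "cat", "horse"], ["apple", "pear", "mango"]

def make_example(rng):
    is_month = rng.random() < 0.5      # bit 1: month (True) vs color (False)
    is_fruit = rng.random() < 0.5      # bit 2: fruit (True) vs animal (False)
    first = rng.choice(months if is_month else colors)
    second = rng.choice(fruits if is_fruit else animals)
    label = "true" if is_month ^ is_fruit else "false"   # label = XOR of the two bits
    return f"{first} {second}: {label}"

rng = random.Random(0)
prompt = "\n".join(make_example(rng) for _ in range(10))
# Leave the final label off and let the model continue after the colon. An
# induction-like head can answer by matching earlier lines with the same
# (month/color, animal/fruit) structure. Under this encoding, month + animal
# is true, so "April dog:" should be completed with "true".
print(prompt + "\nApril dog:")
```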
OK, so then we can go and ask, can the model go and predict that? Well, it can, and it uses 00:54:25.880 |
induction heads to do it. And what we can do is we can look at the-- so we look at a colon where 00:54:30.120 |
it's going to go and try and predict the next word. And for instance here, we have April dog. 00:54:35.880 |
So it's a month and then an animal, and it should be true. And what it does is it looks for 00:54:40.920 |
previous cases where there was a month and then an animal, especially 00:54:45.400 |
ones where the month was the same-- and it looks at those and says that it's true. And so the model 00:54:50.120 |
can go and learn a completely arbitrary function by doing this kind 00:54:55.240 |
of pattern-recognition induction head trick. And so this, to me, made it a lot more plausible that 00:55:01.240 |
this really is how these models do in-context learning. Like, the generality of all these 00:55:08.600 |
amazing things we see these large language models do can be explained by induction heads. We don't 00:55:14.440 |
know that. It could be that there's other things going on. It's very possible that there's lots 00:55:17.800 |
of other things going on. But it seems a lot more plausible to me than it did when we started. 00:55:22.680 |
I'm conscious that I am actually, over time-- I'm going to just quickly go through these last 00:55:27.800 |
few slides. So I think thinking of this as an in-context nearest neighbors, I think, 00:55:31.400 |
is a really useful way to think about this. Other things could absolutely be contributing. 00:55:35.800 |
This might explain why transformers do in-context learning over long contexts better than LSTMs. 00:55:44.040 |
An LSTM can't do this, because the compute it would need isn't linear in the context length. It's, like, 00:55:48.200 |
quadratic, or n log n if it was really clever. So it's basically impossible for an LSTM to do this, 00:55:53.640 |
but transformers do do this. And actually, they diverge at the same point. But if you look-- well, 00:55:58.840 |
I can go into this in more detail after, if you want. There's a really nice paper by Marcus Hutter 00:56:05.000 |
trying to predict and explain why we observe scaling laws in models. It's worth noting that 00:56:09.080 |
the arguments in this paper carry over exactly to this theory. In fact, they work 00:56:15.160 |
better for the case of thinking about this in-context learning with, essentially, a nearest 00:56:20.120 |
neighbors algorithm than they do in the regular case. So yeah, I'm happy to answer questions. I 00:56:27.160 |
can go into as much detail as people want about any of this. And if you send me an 00:56:31.640 |
email, I can also send you more information about all this. And yeah, again, this work is not yet published. 00:56:38.440 |
You don't have to keep it secret. But just if you could be thoughtful about the fact that it's 00:56:42.680 |
unpublished work and probably is a month or two away from coming out, I'd be really grateful for 00:56:46.360 |
that. Thank you so much for your time. Yeah, thanks a lot, Chris. This was a great talk. 00:56:50.680 |
So I'll just open it to some general questions. And then we can do a round of questions from the 00:56:57.880 |
students. So I was very excited to know: what is the line of work that you're currently working 00:57:03.080 |
on? Is it extending this? What are the next things you'll try to do to make models more 00:57:08.360 |
interpretable? Yeah. I mean, I want to just reverse engineer language models. I 00:57:13.320 |
want to figure out the entirety of what's going on in these language models. And one thing that we 00:57:21.960 |
totally don't understand is MLP layers. We understand some things about them, but we don't 00:57:29.240 |
really understand MLP layers very well. There's a lot of stuff going on in large models that we 00:57:33.800 |
don't understand. I want to know how models do arithmetic. I want to know-- another thing that 00:57:38.360 |
I'm very interested in is what's going on when you have multiple speakers. The model can clearly 00:57:41.960 |
represent multiple speakers in a dialogue-- it has a basic theory of mind. I want to understand 00:57:46.200 |
what's going on with that. But honestly, there's just so much we don't understand. It's sort of 00:57:51.560 |
hard to answer the question because there's just so much to figure out. And we have a lot of 00:57:56.760 |
different threads of research in doing this. But yeah, the interpretability team at Anthropic 00:58:02.600 |
is just sort of-- has a bunch of threads trying to go and figure out what's going on inside these 00:58:06.760 |
models. And sort of a similar flavor to this of just trying to figure out, how do the parameters 00:58:11.240 |
actually encode algorithms? And can we reverse engineer those into meaningful computer programs 00:58:16.440 |
that we can understand? Got it. Another question I had is, so you were talking about how the 00:58:22.840 |
transformers are trying to do meta-learning inherently. So it's like-- and you spent a lot 00:58:27.320 |
of time talking about the induction heads, and that was very interesting. But can you formalize 00:58:31.720 |
the sort of meta-learning algorithm they might be learning? Is it possible to say, oh, maybe this is 00:58:35.880 |
the sort of internal algorithm that's making them good meta-learners, or something like 00:58:40.680 |
that? I don't know. Yeah, I mean, I think that there's roughly two algorithms. One is this 00:58:46.040 |
algorithm we saw in the one-layer model. And we see it in other models, too, especially early on, 00:58:49.480 |
which is just try to copy-- you saw a word, probably a similar word is going to happen 00:58:53.880 |
later. Look for places that it might fit in and increase the probability. So that's one thing that 00:58:59.000 |
we see. And the other thing we see is induction heads, which you can just summarize as in-context 00:59:05.000 |
nearest neighbors, basically. Possibly there are other things going on, but it seems like 00:59:09.640 |
those two algorithms, and the specific instantiations of them that we're looking at, are what's driving 00:59:15.640 |
in-context learning. That would be my present theory. Yeah, it sounds very interesting. Yeah. 00:59:21.160 |
OK, so let's do a round of questions. So yeah, feel free to go ahead with your questions.