
Stanford CS25: V1 I Transformer Circuits, Induction Heads, In-Context Learning


Chapters

0:26 People mean lots of different things by "interpretability". Mechanistic interpretability aims to map neural network parameters to human understandable algorithms.
14:13 What is going on???
44:14 The Induction Pattern


00:00:00.000 | Thank you all for having me. It's exciting to be here. One of my favorite things is talking
00:00:10.120 | about what is going on inside neural networks, or at least what we're trying to figure out
00:00:15.120 | is going on inside neural networks. So it's always fun to chat about that. Oh, gosh, I
00:00:21.040 | have to figure out how to do things. Okay. There we go. Now we are
00:00:28.200 | advancing slides; that seems promising. So I think interpretability means lots of different
00:00:33.320 | things to different people. It's a very broad term and people mean all sorts of different
00:00:38.540 | things by it. And so I wanted to talk just briefly about the kind of interpretability
00:00:43.480 | that I spend my time thinking about, which is what I'd call mechanistic interpretability.
00:00:48.640 | So most of my work actually has not been on language models or on RNNs or transformers,
00:00:54.920 | but on understanding vision convnets and trying to understand how the parameters
00:01:00.720 | in those models actually map to algorithms. So you can think of the parameters of
00:01:06.360 | a neural network as being like a compiled computer program. And the neurons are kind
00:01:11.100 | of like variables or registers. And somehow there are these complex computer programs
00:01:16.720 | that are embedded in those weights. And we'd like to turn them back into computer programs
00:01:21.120 | that humans can understand. It's a kind of reverse engineering problem. And so this is
00:01:27.440 | kind of a fun example that we found where there was a car neuron, and you could actually
00:01:31.000 | see that the car neuron is constructed from, for example, a wheel neuron. And it looks for,
00:01:38.880 | in the case of the wheel neuron, the wheels on the bottom. Those are positive
00:01:42.880 | weights, and it doesn't want to see them on top, so it has negative weights there. And
00:01:46.160 | there's also a window neuron. It's looking for the windows on the top and not on the
00:01:49.800 | bottom. And so what we're actually seeing there is an algorithm. It's an
00:01:53.660 | algorithm that is saying, well,
00:01:59.000 | a car has wheels on the bottom and windows on the top and chrome in the middle. And those are
00:02:03.720 | actually just the strongest neurons for that. And so we're actually seeing a meaningful
00:02:07.840 | algorithm, and that's not an exception. That's sort of the general story: if you're willing
00:02:12.740 | to go and look at neural network weights, and you're willing to invest a lot of energy in
00:02:16.560 | trying to reverse engineer them, there are meaningful algorithms written in the weights waiting
00:02:21.040 | for you to find them. And there's a bunch of reasons I think that's an interesting
00:02:25.480 | thing to think about. One is that no one knows how to go and do the things that
00:02:29.160 | neural networks can do. No one knows how to write a computer program that can accurately
00:02:32.520 | classify ImageNet, let alone do the language modeling tasks that we're doing.
00:02:36.160 | No one knows how to directly write a computer program that can do the things that
00:02:39.400 | GPT-3 does. And yet somehow gradient descent is able to discover a way to do this.
00:02:43.620 | And I want to know what's going on. I want to know what it has discovered
00:02:48.300 | that lets it do these things. There's another reason why I think this is important, which
00:02:53.760 | is safety. If we want to go and use these systems in places
00:02:59.060 | where they have a big effect on the world, I think a question we need to ask ourselves
00:03:03.420 | is what happens when these models have unanticipated failure modes,
00:03:08.620 | failure modes we didn't know to test for, to look for, to check for. How can we
00:03:12.820 | discover those things, especially if they're really pathological
00:03:15.700 | failure modes?
00:03:16.700 | Say the model is, in some sense, deliberately doing something that we don't want. Well,
00:03:20.180 | the only way that I really see that we can catch that is if we can get to a point where
00:03:23.340 | we really understand what's going on inside these systems. So that's another reason
00:03:28.020 | that I'm interested in this. Now, actually doing interpretability on language models and
00:03:33.740 | transformers is new to me. Before this year, I spent about eight years working
00:03:38.020 | on trying to reverse engineer convnets and vision models. And so the ideas in this
00:03:42.980 | talk are new things that I've been thinking about with my collaborators,
00:03:48.140 | and we're still probably a month or two out, maybe longer, from publishing them.
00:03:51.820 | This is also the first public talk that I've given on it. So the
00:03:55.460 | things that I'm going to talk about are, I think, honestly,
00:03:58.540 | still a little bit confused for me, and I'm definitely going to be confused in my
00:04:01.820 | articulation of them. So if I say things that are confusing, please feel
00:04:05.420 | free to ask me questions. There might be some points where I go quickly because there's
00:04:08.180 | a lot of content, but definitely at the end I will be available for a while to chat
00:04:12.500 | about this stuff. And I also apologize if I'm unfamiliar
00:04:19.020 | with Zoom and make mistakes. So, with that said, let's
00:04:25.180 | dive in. I wanted to start with a mystery before we go and try to actually
00:04:33.980 | dig into what's going on inside these models. I wanted to motivate it
00:04:38.020 | with a really strange behavior that we discovered and wanted
00:04:43.020 | to understand. And by the way, I should say all this work is
00:04:49.860 | done with my colleagues at Anthropic, and especially my colleagues Catherine and
00:04:53.260 | Nelson. Okay, so on to the mystery. I think probably the most interesting
00:04:59.140 | and most exciting thing about transformers is their ability to do in-context learning,
00:05:06.140 | or, as people sometimes call it, meta-learning. The GPT-3 paper
00:05:10.580 | describes things as "language models are few-shot learners."
00:05:15.060 | There are lots of impressive things about GPT-3, but they chose to focus on that.
00:05:17.820 | And now everyone's talking about prompt engineering, and Andrej Karpathy
00:05:22.660 | was joking about how Software 3.0 is designing the prompt. And so the
00:05:27.100 | ability of language models, of these large transformers, to respond to their context
00:05:31.900 | and learn from their context and change their behavior in response to their context
00:05:35.820 | really seems like probably the most surprising and striking and remarkable
00:05:39.260 | thing about them. And some of my colleagues previously published a paper
00:05:46.420 | that has a trick in it that I really love. We're all used to looking
00:05:50.420 | at learning curves. You train your model, and as your model trains, the loss
00:05:53.740 | goes down. Sometimes it's a little bit discontinuous, but it goes down. Another thing that you can
00:06:01.300 | do is take a fully trained model and ask what happens as
00:06:04.940 | we go through the context, predicting the first token and then
00:06:08.260 | the second token and the third token. We get better at predicting each token because we
00:06:12.220 | have more information to go and predict it on.
00:06:14.380 | So for the first token, the loss should be the entropy of the unigrams,
00:06:19.180 | and for the next token it should be the entropy of the bigrams, and it falls, but it keeps
00:06:23.420 | falling and it keeps getting better. And in some sense, that's the model's
00:06:28.940 | ability to do in-context learning: the ability to
00:06:34.620 | be better at predicting later tokens than you are at predicting early
00:06:38.500 | tokens. That is, in some sense, a mathematical definition of what it means to
00:06:42.020 | be good at this magical in-context learning or meta-learning that these models can
00:06:46.020 | do. And so that's kind of cool, because that gives us a way to look at whether models
00:06:51.020 | are good at in-context learning.
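A minimal sketch of the per-position loss curve described above, assuming we already have, for each evaluation sequence, the probability the model assigned to the true token at every position (the input format and the toy numbers are hypothetical, not from the talk). Averaging over sequences at each position gives the in-context learning curve; averaging over positions as well would collapse it to one point on the ordinary training loss curve.

```python
import math


def per_position_loss(true_token_probs):
    """Mean negative log-likelihood, in nats, at each context position.

    true_token_probs[s][t] is the probability the model gave to the actual
    token at position t of sequence s (a hypothetical input format).
    """
    n_seqs = len(true_token_probs)
    n_pos = len(true_token_probs[0])
    return [
        sum(-math.log(seq[t]) for seq in true_token_probs) / n_seqs
        for t in range(n_pos)
    ]


# Toy probabilities: the model grows more confident later in the context.
probs = [
    [0.10, 0.20, 0.40, 0.50],
    [0.10, 0.30, 0.40, 0.60],
]
curve = per_position_loss(probs)
print([round(x, 3) for x in curve])  # a falling curve signals in-context learning
```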
00:06:52.740 | Chris, if I could just ask a clarification question: when you say
00:06:58.380 | learning, there are no actual parameter updates. That is the remarkable thing about in-context
00:07:04.060 | learning, right? Yeah, indeed. We traditionally think about neural networks as learning over
00:07:08.180 | the course of training by modifying their parameters, but somehow models appear
00:07:12.340 | to also be able to learn in some sense. If you give them a couple of examples in their
00:07:16.000 | context, they can then go and do that later in their context, even though no parameters
00:07:19.900 | changed. And so it's some kind of quite different notion of learning, as
00:07:23.780 | you're gesturing at.
00:07:25.540 | Okay, I think that's making more sense. So, could you also just describe
00:07:31.340 | in-context learning in this case as conditioning, as in conditioning on the first five
00:07:35.540 | tokens of a 10-token sentence when predicting the next five tokens?
00:07:39.100 | Yeah. I think the reason that people sometimes think about this as in-context learning or
00:07:42.420 | meta-learning is that you can do things where you actually take a training set and
00:07:47.060 | embed the training set in your context, even just two or three examples, and
00:07:50.900 | then suddenly your model can go and do this task. And so you can do few-shot learning
00:07:55.340 | by embedding things in the context. Yeah, the formal setup is that you're just
00:08:00.320 | conditioning on this context. It's just that somehow this ability...
00:08:05.460 | I mean, for a long time, people were...
00:08:08.980 | I guess really the history of this is that we started to get
00:08:14.280 | good at neural networks learning, right? We could go and train
00:08:18.180 | vision models and language models that could do all these remarkable things.
00:08:20.780 | But then people started to say, well, these systems take so
00:08:24.300 | many more examples than humans do to learn. How can we fix this? And we
00:08:28.780 | had all these ideas of meta-learning develop, where we wanted to train models
00:08:32.860 | explicitly to be able to learn from a few examples, and people developed all these complicated
00:08:37.060 | schemes. And then the truly absurd thing about transformer language models
00:08:41.180 | is that without any effort at all, we get this for free: you can just give them
00:08:45.820 | a couple of examples in their context, and they can learn in their context to go and
00:08:49.060 | do new things. I think that was, in some sense, the most
00:08:53.380 | striking thing about the GPT-3 paper. And so this ability
00:08:59.460 | to have just conditioning on a context give you new abilities
00:09:04.460 | for free, and the ability to generalize to new things, is, in some sense, the
00:09:08.300 | most... yeah, to me, the most striking and shocking thing about transformer language
00:09:12.220 | models.
00:09:13.220 | That makes sense. I guess from my perspective, I'm trying to square the
00:09:21.340 | notion of learning in this case with, you know, if you or I were given a prompt
00:09:26.340 | like "one plus one equals two, two plus three equals five" as the sort of few-shot setup,
00:09:32.500 | and then somebody else put, say, "five plus three equals" and we had to fill
00:09:36.740 | it out. In that case, I wouldn't say that we've learned arithmetic, because we already
00:09:40.540 | sort of knew it; rather, we're just conditioning on the prompt to know what
00:09:45.260 | it is that we should then generate, right? But it seems to me like that's... Yeah.
00:09:49.740 | I think that's on a spectrum, though, because you can also give completely
00:09:54.980 | nonsensical problems that the model would never have seen, like mimic this
00:10:00.020 | function, and give a couple of examples of the function, and the model's never seen it
00:10:02.580 | before. And it can go and do that later in the context. And I think what
00:10:06.620 | you did learn in a lot of these cases... so you might have,
00:10:11.500 | you might not have learned arithmetic, you might've had some innate faculty for arithmetic
00:10:14.540 | that you're using, but you might've learned, "Oh, okay, right now we're doing arithmetic
00:10:18.020 | problems." In any case, I agree that there's an element of semantics
00:10:22.900 | here. Yeah, no, this is helpful though, just to clarify exactly
00:10:27.060 | what you mean by in-context learning. Thank you for walking through it. Of course.
00:10:33.460 | So something that's, I think, really striking about all of this is... well, okay. So
00:10:38.380 | we've talked about how we can look at the learning curve, and we can also
00:10:41.820 | look at this in-context learning curve, but really those are just two slices of a
00:10:45.860 | two-dimensional space. In some sense, the more fundamental thing
00:10:50.460 | is how good we are at predicting the nth token at a given point in training, and
00:10:55.500 | there's something that you'll notice if you look at this. When we
00:10:58.940 | talk about the loss curve, we're just talking about averaging over this dimension:
00:11:02.860 | if you average like this and project onto the training step, that's
00:11:07.620 | your loss curve. And the thing that we are calling the in-context
00:11:12.100 | learning curve is just this line, this line down the end
00:11:18.340 | here. And something that's kind of striking is there's this discontinuity
00:11:25.220 | in it. There's this point where the model seems to get radically
00:11:30.460 | better, in a very, very short span of training, at predicting late tokens. So it's
00:11:35.820 | not that different at early time steps, but at late time steps, suddenly
00:11:39.300 | you get better. And a way that you can make this more striking is to take
00:11:46.980 | the difference between your ability to predict the 50th token and your ability to predict
00:11:51.580 | the 500th token. You can subtract the 50th token's loss from the 500th token's loss.
00:11:57.140 | And what you see is that over the course of training, you're
00:12:02.660 | not very good at this, and you get a little bit better. And then suddenly you have this
00:12:05.900 | cliff and then you never get better. The difference between these at least never gets better.
00:12:10.540 | So the model gets better at predicting things, but its ability to predict late tokens
00:12:15.220 | over early tokens never gets better after that. So in the span of just a few hundred
00:12:20.260 | steps in training, the model has gotten radically better at its ability to do this
00:12:25.900 | kind of in-context learning. And so you might ask, what's going on at that point?
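The "500th token minus 50th token" measure can be sketched as below. This is a minimal illustration, assuming per-position loss curves are available for each saved checkpoint; the helper names and the linearly falling toy curves are hypothetical, not data from the talk.

```python
def icl_score(loss_by_position, early=49, late=499):
    """Loss at a late token minus loss at an early token, in nats.

    Indices are 0-based, so the defaults compare the 500th token with the
    50th. More negative means the model benefits more from its context.
    """
    return loss_by_position[late] - loss_by_position[early]


def icl_scores_over_training(curves_by_checkpoint):
    # One score per saved checkpoint; plotting these against training step
    # is where a sudden drop would show up.
    return [icl_score(curve) for curve in curves_by_checkpoint]


def toy_curve(first_loss, last_loss, n_positions=500):
    # Hypothetical per-position loss that falls linearly across the context.
    step = (last_loss - first_loss) / (n_positions - 1)
    return [first_loss + step * t for t in range(n_positions)]


# Made-up checkpoints: late tokens barely help before, and help a lot after.
before = toy_curve(4.0, 3.9)
after = toy_curve(3.8, 3.3)
scores = icl_scores_over_training([before, after])
print([round(s, 3) for s in scores])
```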
00:12:31.580 | And this is just one model, but, well, first of all, it's worth noting this isn't
00:12:35.820 | a small change. We don't think about this very often,
00:12:41.940 | but often we just look at loss curves and we're like, did the model do better than
00:12:44.540 | another model or worse than another model? But you can think about this
00:12:47.580 | in terms of nats, which are just the information-theoretic
00:12:51.540 | quantity here, and you can convert that into bits. And so one way
00:12:56.540 | you can interpret this is something roughly like: 0.4 nats
00:13:01.220 | is a bit more than 0.5 bits, which is about like every other token the model gets to go and sample
00:13:05.380 | twice and pick the better one. It's actually even stronger than that, so that's
00:13:09.380 | an underestimate of how big a deal it is to get better by 0.4 nats. So this
00:13:13.980 | is a really big difference in the model's ability to predict late tokens.
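The unit conversion behind "0.4 nats" is simple arithmetic; here is a small sketch of it. The "better of two samples on every other token" framing is the talk's rough intuition, and the sketch only reproduces its arithmetic, treating a choice between two samples as worth one bit.

```python
import math


def nats_to_bits(nats):
    # 1 nat = 1 / ln(2) bits (about 1.443 bits).
    return nats / math.log(2)


gain_bits = nats_to_bits(0.4)

# The talk's intuition: choosing between two samples is worth log2(2) = 1 bit,
# so doing it on every other token is about 0.5 bits per token on average.
best_of_two_every_other_token = math.log2(2) / 2

print(round(gain_bits, 3), best_of_two_every_other_token)
```

Since 0.4 nats comes out slightly above 0.5 bits, the "every other token" picture slightly understates the improvement, matching the remark that it is "even stronger than that".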
00:13:20.980 | And we can visualize this in different ways. We can also ask
00:13:23.780 | how much better we are getting at predicting later tokens and look at the derivative.
00:13:28.460 | And then we can see very clearly that there's some kind of discontinuity
00:13:31.740 | in that derivative at this point. And we can take the second derivative then,
00:13:36.260 | well, the derivative with respect to training. And now we see that there's
00:13:40.780 | very, very clearly this line here. So something in just the span of
00:13:45.140 | a few steps, a few hundred steps, is causing some big change. We have some kind of
00:13:50.060 | phase change going on. And this is true across model sizes. You can
00:13:56.260 | actually see it a little bit in the loss curve: there's this little bump here, and that
00:13:59.820 | corresponds to the point where you have this change. We actually could
00:14:03.540 | have seen it in the loss curve earlier too. It's this bump here, excuse me. So we
00:14:09.940 | have this phase change going on, and there's, I think, a really tempting theory to have,
00:14:13.940 | which is that somehow this change
00:14:17.860 | in the model's output, its behaviors, and its outward-facing
00:14:21.820 | properties corresponds presumably to some kind of change in the algorithms that
00:14:26.460 | are running inside the model. So if we observe this big phase change, especially in a very
00:14:30.700 | small window, in the model's behavior, presumably there's some change in the circuits
00:14:35.620 | inside the model that is driving it. At least that's a natural hypothesis.
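The derivative plots mentioned above can be approximated with finite differences over checkpoints. A minimal sketch, assuming evenly spaced checkpoints (so each difference is per checkpoint interval rather than per training step) and a made-up score series with one sharp drop:

```python
def diff(values):
    # Forward difference between consecutive checkpoints; assumes the
    # checkpoints are evenly spaced in training steps.
    return [b - a for a, b in zip(values, values[1:])]


def steepest_drop(scores):
    # Index i such that the score falls fastest between checkpoints i and i + 1.
    d1 = diff(scores)
    return min(range(len(d1)), key=lambda i: d1[i])


# Hypothetical in-context learning scores at six evenly spaced checkpoints.
scores = [-0.25, -0.25, -0.25, -0.5, -1.0, -1.0]
d1 = diff(scores)  # first derivative: [0.0, 0.0, -0.25, -0.5, 0.0]
d2 = diff(d1)      # second derivative: [0.0, -0.25, -0.25, 0.5]
print(steepest_drop(scores), d1, d2)
```

The second difference is nonzero only around the drop, which is the kind of sharp feature the talk points to when locating the phase change.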
00:14:40.340 | So if we want to ask that, though, we need to be able to understand
00:14:44.340 | what algorithms are running inside the model. How can we turn the parameters
00:14:47.620 | in the model back into this algorithm? So that's going to be our goal. Now, it's
00:14:51.940 | going to require us to cover a lot of ground in a relatively short amount
00:14:55.940 | of time. So I'm going to go a little bit quickly through the next section, and I will highlight
00:15:00.380 | the key takeaways, and then I will be very happy to
00:15:05.980 | explore any of this in as much depth as you like. I'm free for another hour after this call,
00:15:10.220 | and I'm happy to talk in as much depth as people want about the details of this.
00:15:14.940 | So it turns out the phase change doesn't happen in a one-layer attention-only transformer,
00:15:20.620 | and it does happen in a two-layer attention-only transformer. So if we could understand a one-layer
00:15:24.460 | attention-only transformer and a two-layer attention-only transformer,
00:15:28.360 | that might give us a pretty big clue as to what's going on. So we're going to look at attention-only models.
00:15:35.780 | We're also going to leave out layer norm and biases to simplify things. So
00:15:39.580 | one way you could describe an attention-only transformer is: we're going to embed our tokens,
00:15:45.460 | and then we're going to apply a bunch of attention heads and add them into the residual stream,
00:15:49.580 | and then apply our unembedding, and that'll give us our logits. And we could
00:15:54.140 | write that out as equations if we want: multiply by an embedding matrix, apply attention
00:15:58.160 | heads, and then compute the logits with the unembedding. And the part here that's
00:16:05.960 | a little tricky is understanding the attention heads. And this might be a somewhat conventional
00:16:10.360 | way of describing attention heads, but it actually kind of obscures a lot of the structure
00:16:14.700 | of attention heads. I think that oftentimes we make attention heads more
00:16:19.400 | complex than they are; we sort of hide the interesting structure. So what is this saying?
00:16:23.480 | Well, it's saying: for every token, compute a value vector, and
00:16:27.480 | then mix the value vectors according to the attention matrix, and then project them
00:16:31.480 | with the output matrix back into the residual stream. So there's another notation,
00:16:37.960 | which you could think of as using tensor products, or, well,
00:16:44.040 | as left and right multiplying. There are a few ways you
00:16:46.680 | can interpret this, but I'll just try to explain what this notation
00:16:51.240 | means. X, our residual stream, has a vector
00:16:58.240 | for every single token, and this means: multiply independently the vector for
00:17:04.280 | each token by W_V. So compute the value vector for every token. This one, on the other hand,
00:17:10.440 | has A on the left-hand side. It means
00:17:14.600 | multiply by the attention matrix, that is, take linear combinations of the value
00:17:21.000 | vectors. So don't change the value vectors point-wise, but
00:17:24.280 | mix them together according to the attention pattern to create a weighted sum. And then again,
00:17:29.400 | independently for every position, apply the output matrix. And you can apply the distributive
00:17:34.600 | property to this, and it just reveals that it actually didn't matter that you did the attention
00:17:38.520 | sort of in the middle. You could have done the attention at the beginning, or you could
00:17:40.840 | have done it at the end; that's independent. And the thing that actually
00:17:44.920 | matters is that there's this W_O W_V matrix, and what it's really saying is that W_O W_V
00:17:51.560 | describes what information the attention head reads from each position and how it writes
00:17:55.160 | it to its destination, whereas A describes which tokens we read from and write to.
00:18:01.320 | And that's getting more at the fundamental structure of attention. An
00:18:04.360 | attention head moves information from one position to another, and the process
00:18:09.320 | of which position gets moved from and to is independent from what information gets moved.
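A small numerical check of the point above, that mixing positions with A and transforming with the value and output matrices can be done in either order. This is a sketch under stated assumptions: it uses a row-vector convention (each row of X is one position), so the fused matrix is W_V W_O here, the transpose-order counterpart of the talk's W_O W_V; the attention pattern A is given rather than computed; and all numbers are made up.

```python
def matmul(a, b):
    # Plain nested-list matrix product: (n x k) @ (k x m) -> (n x m).
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def head_stepwise(A, X, W_V, W_O):
    """One attention head with a given attention pattern A.

    Row-vector convention (an assumption of this sketch): X is n_ctx x d_model,
    W_V is d_model x d_head, W_O is d_head x d_model, and A is n_ctx x n_ctx
    with row = destination position and column = source position.
    """
    values = matmul(X, W_V)      # per position: read information into value vectors
    mixed = matmul(A, values)    # across positions: weighted sums of value vectors
    return matmul(mixed, W_O)    # per position: write back to the residual stream


def head_factored(A, X, W_V, W_O):
    # Same head, with W_V and W_O fused into one d_model x d_model "OV" matrix
    # and the position mixing by A done first.
    W_OV = matmul(W_V, W_O)
    return matmul(matmul(A, X), W_OV)


A = [[1.0, 0.0], [0.5, 0.5]]    # causal pattern over two positions
X = [[1, 0, 2], [0, 1, 1]]      # toy residual stream: 2 positions, d_model = 3
W_V = [[1, 0], [0, 1], [1, 1]]  # d_model = 3 -> d_head = 2
W_O = [[1, 0, 1], [0, 2, 0]]    # d_head = 2 -> d_model = 3

print(head_stepwise(A, X, W_V, W_O))
print(head_factored(A, X, W_V, W_O))
```

Both functions print the same matrix, which is the "it didn't matter where the attention went" observation in miniature.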
00:18:13.800 | And if you rewrite your transformer that way, well, first we can write the
00:18:22.760 | sum of attention heads just in this form. And then we can write
00:18:29.160 | that as the entire layer by adding in an identity. And if we go and plug
00:18:35.240 | that all into our transformer and expand, we have to multiply
00:18:43.640 | everything through. We get this interesting equation. We get this one term, which
00:18:48.200 | corresponds to just the path directly through the residual stream, and it's going to want to store
00:18:52.840 | bigram statistics. All it gets is the previous token, and it tries
00:18:56.520 | to predict the next token, so it gets to try to store bigram statistics.
00:19:01.720 | And then for every attention head, we get this matrix that says, okay, well,
00:19:05.240 | we have the attention pattern, which describes which token looks at which token.
00:19:09.160 | And we have this matrix here, which describes, for every possible token you could attend to,
00:19:13.240 | how it affects the logits. And that's just a table that you can look at. It just says,
00:19:17.960 | for this attention head, if it looks at this token, it's going to increase
00:19:21.080 | the probability of these tokens. In a one-layer attention-only transformer, that's all there is.
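The expansion just described (a direct path plus one term per head) can be checked numerically. A minimal sketch, assuming a row-vector convention, attention patterns passed in rather than computed from queries and keys, and made-up toy weights; the function names are hypothetical.

```python
def matmul(a, b):
    # Plain nested-list matrix product.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    # Elementwise sum of two equally shaped matrices.
    return [[x + y for x, y in zip(r, s)] for r, s in zip(a, b)]


def forward(tokens, W_E, W_U, heads):
    """One-layer attention-only transformer without layer norm or biases.

    heads is a list of (A, W_V, W_O); each attention pattern A is passed in
    directly instead of being computed from queries and keys (a simplification).
    Row-vector convention: W_E is vocab x d_model, W_U is d_model x vocab.
    """
    X = [W_E[t] for t in tokens]  # embed: n_ctx x d_model
    resid = X
    for A, W_V, W_O in heads:
        resid = add(resid, matmul(matmul(A, matmul(X, W_V)), W_O))
    return matmul(resid, W_U)     # unembed: n_ctx x vocab logits


def path_expansion(tokens, W_E, W_U, heads):
    # Direct path: W_E W_U is a vocab x vocab "bigram-like" table.
    direct = matmul(W_E, W_U)
    logits = [direct[t] for t in tokens]
    for A, W_V, W_O in heads:
        # OV circuit: row s says how attending to source token s shifts each logit.
        ov = matmul(matmul(matmul(W_E, W_V), W_O), W_U)
        logits = add(logits, matmul(A, [ov[t] for t in tokens]))
    return logits


# Toy sizes: vocab = 3, d_model = 2, d_head = 1 (all numbers made up).
W_E = [[1, 0], [0, 1], [1, 1]]
W_U = [[1, 0, 1], [0, 1, -1]]
head = ([[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.25, 0.25, 0.5]],  # causal A
        [[1], [2]],   # W_V: d_model x d_head
        [[1, -1]])    # W_O: d_head x d_model
tokens = [0, 2, 1]
print(forward(tokens, W_E, W_U, [head]))
print(path_expansion(tokens, W_E, W_U, [head]))
```

Both routes give the same logits; the expanded form just makes the bigram-like table and the per-head OV table visible as separate objects.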
00:19:28.920 | Yeah, so this is just the interpretation I was describing.
00:19:32.040 | And another thing that's worth noting is that, according to this,
00:19:38.280 | the attention-only transformer is linear if you fix the attention pattern. Now, of course,
00:19:43.560 | the attention pattern isn't fixed, but whenever you have the opportunity to make something
00:19:47.160 | linear, take it: linear functions are really easy to understand. And so if you can fix a small
00:19:50.600 | number of things and make something linear, that's a lot of leverage. Okay.
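A quick check of the "linear once the attention pattern is fixed" claim. This sketch tests only the scaling half of linearity (f(2X) versus 2 f(X)); the score function X Xᵀ, as if the query-key product were the identity, is a hypothetical stand-in for a real head's pattern computation, and all numbers are made up.

```python
import math


def matmul(a, b):
    # Plain nested-list matrix product.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def scale(c, M):
    return [[c * x for x in row] for row in M]


def close(M, N, tol=1e-12):
    return all(abs(x - y) <= tol for r, s in zip(M, N) for x, y in zip(r, s))


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_pattern(X):
    # Hypothetical pattern: scores X X^T (as if the query-key product were the
    # identity); each position attends only to itself and earlier positions.
    scores = matmul(X, [list(c) for c in zip(*X)])
    n = len(X)
    return [softmax(scores[i][: i + 1]) + [0.0] * (n - i - 1) for i in range(n)]


def head_frozen(X, A, W_OV):
    # Attention pattern A is held fixed, so X -> A X W_OV is linear in X.
    return matmul(matmul(A, X), W_OV)


def head_dynamic(X, W_OV):
    # Pattern recomputed from X through a softmax: no longer linear in X.
    return head_frozen(X, causal_pattern(X), W_OV)


X = [[1.0, 0.0], [0.0, 1.0]]
W_OV = [[1.0, 0.0], [0.0, 1.0]]
A = causal_pattern(X)

frozen_ok = close(head_frozen(scale(2, X), A, W_OV), scale(2, head_frozen(X, A, W_OV)))
dynamic_ok = close(head_dynamic(scale(2, X), W_OV), scale(2, head_dynamic(X, W_OV)))
print(frozen_ok, dynamic_ok)
```

With the pattern frozen, doubling the input doubles the output; once the pattern is recomputed from the input, the softmax breaks that, which is why fixing the pattern buys so much leverage.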
00:19:57.560 | And yeah, we can talk about how the attention pattern is computed as well.
00:20:00.520 | If you expand it out, you'll get an equation like this
00:20:04.040 | and, notice... well, I think it'll be easier... Okay.
00:20:10.200 | I think the core story to take away from all of this is that we have these two matrices
00:20:17.080 | that actually look kind of similar. This one here tells you, if you attend to a token,
00:20:22.680 | how the logits are affected. You can just think of it as a giant matrix of,
00:20:26.920 | for every possible input token, how the logits are affected
00:20:31.160 | by that token: are they made more likely or less likely? And we have this one, which sort of says,
00:20:36.360 | how much does every token want to attend to every other token?
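The two vocabulary-by-vocabulary tables described above can be formed directly from the weights. A minimal sketch, assuming a row-vector convention, ignoring positional embeddings, and using made-up toy weights and token names; the `inspect_source` helper mirrors the kind of question asked later in the talk about the token "perfect", but this toy does not reproduce those results.

```python
def matmul(a, b):
    # Plain nested-list matrix product.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(M):
    return [list(col) for col in zip(*M)]


def qk_circuit(W_E, W_Q, W_K):
    # vocab x vocab: entry [dst][src] is the pre-softmax score for a destination
    # token attending to a source token (positional embeddings ignored).
    return matmul(matmul(W_E, W_Q), transpose(matmul(W_E, W_K)))


def ov_circuit(W_E, W_V, W_O, W_U):
    # vocab x vocab: row [src] is how attending to src shifts every output logit.
    return matmul(matmul(matmul(W_E, W_V), W_O), W_U)


def top_k(values, k):
    # Indices of the k largest values (ties keep their original order).
    return sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]


def inspect_source(src, qk, ov, vocab, k=2):
    # Which destinations most want to attend to src, and which outputs it boosts.
    attenders = top_k([row[src] for row in qk], k)
    boosted = top_k(ov[src], k)
    return [vocab[i] for i in attenders], [vocab[i] for i in boosted]


# Toy sizes: vocab = 3, d_model = 2, d_head = 1 (all numbers made up).
vocab = ["tok0", "tok1", "tok2"]
W_E = [[1, 0], [0, 1], [1, 1]]
W_Q, W_K = [[1], [0]], [[0], [1]]
W_V, W_O = [[1], [1]], [[1, 0]]
W_U = [[1, 0, -1], [0, 1, 0]]

qk = qk_circuit(W_E, W_Q, W_K)
ov = ov_circuit(W_E, W_V, W_O, W_U)
print(inspect_source(2, qk, ov, vocab))
```

Reading a column of the QK table answers "which tokens want to attend to this one", and reading a row of the OV table answers "what does attending to it do to the output", which is why the source token is the natural pivot between the two.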
00:20:39.320 | One way that you can picture this is... okay, there are really three
00:20:48.280 | tokens involved when we're thinking about an attention head. We have the token that we're
00:20:54.440 | going to move information to, and that's attending backwards. We have the source token that's going
00:20:59.960 | to get attended to, and we have the output token whose logits are going to be affected.
00:21:03.720 | And you can just trace through this. So you can ask how
00:21:09.560 | attending to this token affects the output. Well, first we embed the token.
00:21:13.720 | Then we multiply by W_V to get the value vector. The information gets moved by the attention pattern.
00:21:20.920 | We multiply by W_O to add it back into the residual stream, then we get hit by the unembedding
00:21:24.920 | and we affect the logits. And that's where that one matrix comes from. And we can also ask
00:21:28.920 | what decides whether a token gets a high score when we're computing the
00:21:33.560 | attention pattern. And it just says: embed the token, turn it into a query,
00:21:40.280 | embed the other token, turn it into a key, and dot product them. And that's where those
00:21:45.160 | two matrices come from. So I know that I'm going quite quickly. Maybe I'll just
00:21:53.000 | briefly pause here, and if anyone wants to ask for clarifications, this would be a good time.
00:21:57.800 | And then we'll actually go and reverse engineer it and say that everything that's going on
00:22:01.880 | in a one-layer attention-only transformer is now in the palm of our hands. It's a very toy model;
00:22:06.840 | no one actually uses one-layer attention-only transformers, but we'll be able to understand
00:22:11.720 | the one-layer attention-only transformer. So just to be clear, you're saying that
00:22:17.880 | the query-key circuit is learning the attention weights and essentially is responsible for
00:22:24.440 | computing the attention between different tokens? Yeah. Yeah. So this
00:22:29.800 | matrix... yeah. And all three of those parts are learned, but
00:22:34.280 | that's what expresses the attention pattern. Yeah, that's what generates the attention
00:22:39.560 | patterns; it gets run for every pair of tokens. And you can think of values in that
00:22:43.720 | matrix as just being how much every token wants to attend to every other token if it was in the
00:22:48.040 | context. We're ignoring positional embeddings here, so there's a little bit that
00:22:51.480 | we're sort of eliding over there as well, but in a global sense, how much does
00:22:55.240 | every token want to attend to every other token? Right. Yeah. And the other circuit, the
00:22:58.840 | output-value circuit, is using the attention that's calculated to... Yes. ...to affect the final
00:23:06.920 | outputs? It's sort of saying: assume that the attention head
00:23:10.600 | attends to some token. Let's set aside the question of how that gets computed; just
00:23:14.040 | assume that it attends to some token. How would it affect the outputs if it attended to that token?
00:23:18.200 | And you can just calculate that. It's just a big table of values that says,
00:23:23.400 | for this token, it's going to make this token more likely, and this token will make this token
00:23:27.320 | less likely. Okay. Okay. And it's completely independent. It's just two separate
00:23:33.400 | matrices. The formulas might make them seem entangled,
00:23:38.200 | but they're actually separate. Right. To me, it seems like the supervision is
00:23:43.880 | coming from the output-value circuit, and the query-key circuit seems to be more like an unsupervised
00:23:47.960 | kind of thing, because there's no... Hmm. I mean, I think in the sense that every,
00:23:54.440 | in a model, every neuron is, in some sense,
00:24:00.760 | downstream from the ultimate signal. And so, you know, the
00:24:05.240 | output-value circuit is perhaps getting a more direct
00:24:09.720 | signal. Correct. But yeah, we will be able to dig into this in as much detail
00:24:19.800 | as you want in a little bit. So maybe I'll push forward. And I think also
00:24:24.920 | an example of how to use this to reverse engineer a one-layer model will maybe make it a little bit
00:24:28.840 | more motivated. Okay. So just to emphasize this, there are three different tokens
00:24:36.840 | that we can talk about. There's a token that gets attended to there's the token that does the
00:24:40.680 | attention to call it the destination. And then there's the token that gets affected, get it,
00:24:44.360 | get the next token, which it's probabilities are affected. Um, and so something we can do
00:24:50.840 | is notice that the, the only token that connects to both of these is the token that gets attended
00:24:55.400 | to. So these two are sort of, they're, they're bridged by their, their interaction with the
00:25:01.080 | source token. So something that's kind of natural is to ask for a given source token, you know,
00:25:06.120 | how does it interact with both of these? So let's take, for instance, the token "perfect".
00:25:11.880 | One thing we can ask is which tokens want to attend to "perfect". Well,
00:25:19.480 | apparently the tokens that most want to attend to "perfect" are "are", "looks", "is", and "provides".
00:25:26.200 | Um, so "are" is the most, "looks" is the next most, and so on. And then when we attend to "perfect",
00:25:30.760 | and this is with one single attention head, so, you know, it'd be different if we did a
00:25:33.960 | different attention head, it wants to really increase the probability of "perfect".
00:25:37.720 | And then, to a lesser extent, "super", "absolute", and "pure". And we can ask, you know, what
00:25:44.120 | sequences of tokens are made more likely by this particular, um, set of, you know,
00:25:51.080 | this particular set of things wanting to attend to each other and becoming more likely.
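A minimal sketch of the two vocab-by-vocab tables just described, assuming a toy model with tiny hand-written weights (all names, shapes, and values here are hypothetical). The QK table, W_E^T W_Q^T W_K W_E, scores how much each destination token wants to attend to each source token; the OV table, W_U W_O W_V W_E, says how attending to a source token shifts each output logit. Reading one column of each gives the "which tokens attend to X" and "what does X boost" lists.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    assert len(a[0]) == len(b), "inner dimensions must match"
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def qk_table(W_E: Matrix, W_Q: Matrix, W_K: Matrix) -> Matrix:
    """Entry [dest][src]: pre-softmax score for dest attending to src.

    W_E is d_model x vocab (one column per token); W_Q, W_K are d_head x d_model.
    """
    return matmul(transpose(matmul(W_Q, W_E)), matmul(W_K, W_E))


def ov_table(W_E: Matrix, W_V: Matrix, W_O: Matrix, W_U: Matrix) -> Matrix:
    """Entry [out][src]: change in the logit of `out` if the head attends to `src`.

    W_V is d_head x d_model, W_O is d_model x d_head, W_U is vocab x d_model.
    """
    return matmul(W_U, matmul(W_O, matmul(W_V, W_E)))


def summarize_source(qk: Matrix, ov: Matrix, src: int,
                     k: int = 2) -> Tuple[List[int], List[int]]:
    """Top-k destinations that attend to `src`, and top-k outputs it boosts."""
    dests = sorted(range(len(qk)), key=lambda d: qk[d][src], reverse=True)[:k]
    outs = sorted(range(len(ov)), key=lambda o: ov[o][src], reverse=True)[:k]
    return dests, outs


# Toy model: vocab = 2, d_model = 2, d_head = 1, identity embedding/unembedding.
I2 = [[1, 0], [0, 1]]
qk = qk_table(I2, W_Q=[[1, 2]], W_K=[[3, 4]])
ov = ov_table(I2, W_V=[[1, 1]], W_O=[[2], [0]], W_U=I2)
print(qk, ov, summarize_source(qk, ov, src=1))
```

With real weights the tables are vocab-size by vocab-size, but the reading is the same: a column of the QK table ranks destinations, a column of the OV table ranks outputs.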
00:25:54.440 | Well, things of the form: we have our token that we attended back to, and we have some
00:25:59.880 | skip of some number of tokens. They don't have to be adjacent. But then later on, we see
00:26:04.200 | the token "are", and it attends back to "perfect" and increases the probability of "perfect". So you
00:26:09.560 | can think of these as being like, we're sort of changing the probability
00:26:13.160 | of what we might call skip-trigrams, where we have, you know, we skip over a bunch of
00:26:17.320 | tokens in the middle, but we're affecting the probability really of trigrams. So "perfect...
00:26:22.760 | are perfect", "perfect... looks super". Um, we can look at another one. So we have the token "large".
00:26:28.120 | Um, these tokens, "contains", "using", "specify", want to go and look back to it and increase the
00:26:32.840 | probability of "large" and "small". And the skip-trigrams that are affected are things like
00:26:37.240 | "large... using large", "large... contains small", um, and things like this.
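Reading skip-trigrams off those tables can be sketched like this: for one source token, take every destination whose QK score clears a threshold and every output whose OV score clears a threshold, and each pair gives a skip-trigram "source ... destination -> output". The scores and thresholds below are invented for illustration; they are not values from the model in the talk.

```python
from itertools import product
from typing import Dict, List


def skip_trigrams(src: str, qk_col: Dict[str, float], ov_col: Dict[str, float],
                  qk_min: float = 1.0, ov_min: float = 1.0) -> List[str]:
    """Enumerate 'src ... dest -> out' skip-trigrams for a single source token.

    qk_col maps destination token -> score for attending to `src`.
    ov_col maps output token -> logit boost when `src` is attended to.
    """
    dests = [d for d, score in qk_col.items() if score >= qk_min]
    outs = [o for o, score in ov_col.items() if score >= ov_min]
    # Every qualifying destination pairs with every qualifying output.
    return [f"{src} ... {d} -> {o}" for d, o in product(dests, outs)]


# Hypothetical scores loosely shaped like the "perfect" example.
qk_perfect = {"are": 5.0, "looks": 4.0, "the": 0.1}
ov_perfect = {"perfect": 6.0, "super": 3.0, "bad": -1.0}
for trigram in skip_trigrams("perfect", qk_perfect, ov_perfect):
    print(trigram)
```

Note that the list is a full cross product of destinations and outputs; the head has no way to pair them selectively.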
00:26:42.680 | Um, if we see the number "two", um, we increase the probability of other numbers,
00:26:48.760 | and we affect skip-trigrams like "two... one two" and "two... has three".
00:26:55.320 | Um, now you're all in, uh, in a technical field, so you'll probably recognize this one. We
00:27:02.040 | have, uh, "lambda", and then we see a backslash, and then we want to increase the probability of
00:27:08.040 | lambda and sorted and lambda and operator. So it's all LaTeX. Um, it wants to,
00:27:13.000 | um, if it sees lambda, it thinks that, you know, maybe next time I use a backslash,
00:27:16.840 | I should go and put in some LaTeX math symbol. Um, also the same thing for HTML. We see
00:27:24.680 | "nbsp" for non-breaking space, and then we see an ampersand. We want to go and make that more
00:27:28.680 | likely. The takeaway from all of this is that a one-layer attention-only transformer is totally
00:27:33.240 | acting on these skip-trigrams. Um, everything that it does-- I mean, I guess it
00:27:38.920 | also has this pathway by which it affects bigrams, but mostly it's just affecting these
00:27:41.720 | skip-trigrams. Um, and there's lots of them. It's just like these giant tables of
00:27:46.280 | skip-trigrams that are made more or less likely. Um, there's lots of other fun things that it does.
00:27:52.520 | Sometimes the tokenization will split up a word in multiple ways. So, um, like we have "indie",
00:27:57.000 | uh, well, that's not a good example. We have, like, the word "Pike", and then we see the
00:28:02.280 | token "P", and then we predict "ike", um, and we predict "spikes" and stuff like that. Um, or, uh,
00:28:09.480 | these ones are kind of fun. Maybe they're actually worth talking about for a second. So
00:28:12.600 | we see the token "Lloyd", and then we see an "L", and maybe we predict "Lloyd", um, or "R" and we predict
00:28:20.600 | "Ralph", um, "C", "Catherine". Um, but we'll see in a second that, well, yeah, we'll come back
00:28:26.280 | to that in a sec. So we increase the probability of things like "Lloyd... Lloyd"
00:28:29.560 | and "Lloyd... Catherine". Or "Pixmap": um, if anyone's worked with Qt, we see "Pixmap", and we
00:28:35.720 | increase the probability of "Pixmap" again, but also "QCanvas". Um, yeah, but of course there's
00:28:47.880 | a problem with this, which is, um, it doesn't get to pick which one of these goes with which one.
00:28:52.280 | So if you want to go and make "Pixmap... Pixmap" and "Pixmap... QCanvas" more probable,
00:28:59.560 | you also have to make "Pixmap... PCanvas" more probable. And if you
00:29:04.920 | want to make "Lloyd... Lloyd" and "Lloyd... Catherine" more probable, you also have to make
00:29:10.120 | "Lloyd... Cloyd" and "Lloyd... Latherine" more probable. And so there's actually, like, bugs that transformers have,
00:29:16.840 | like weird-- at least in, you know, these really tiny one-layer attention-only
00:29:20.360 | transformers, there's these bugs that, you know, seem weird until you realize
00:29:24.120 | that it's this giant table of skip-trigrams that's operating. Um, and the
00:29:29.640 | nature of that is that, um, uh, yeah, it sort of forces you,
00:29:35.560 | if you want to go and do this, to also make some weird predictions.
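Why this kind of bug is forced can be checked with a little arithmetic. For a fixed source token, a single head changes an output's logit by (attention weight from the destination onto the source) times (the OV entry for that output), so the table of boosts over (destination, output) pairs is an outer product, i.e. rank one. The sketch below assumes the source token is the only thing attended to, and its attention weights and OV values are made up.

```python
from typing import Dict, Tuple


def boosts(attn_to_src: Dict[str, float],
           ov_col: Dict[str, float]) -> Dict[Tuple[str, str], float]:
    """Logit boost for each (destination, output) pair, for one source token.

    attn_to_src: destination token -> attention weight it puts on the source.
    ov_col:      output token -> OV entry for the source token.
    """
    return {(d, o): a * v for d, a in attn_to_src.items() for o, v in ov_col.items()}


# Hypothetical numbers for the source token "Lloyd".
attn = {"L": 0.5, "C": 0.25}           # both destinations attend back to "Lloyd"
ov = {"loyd": 4.0, "atherine": 2.0}    # attending to "Lloyd" boosts both continuations
table = boosts(attn, ov)
print(table[("L", "loyd")], table[("C", "atherine")])   # intended skip-trigrams
print(table[("C", "loyd")], table[("L", "atherine")])   # unavoidable "Cloyd", "Latherine"

# Rank one: the 2x2 determinant is zero, so the crossed entries cannot vanish
# while the intended ones stay positive.
det = (table[("L", "loyd")] * table[("C", "atherine")]
       - table[("L", "atherine")] * table[("C", "loyd")])
print(det)
```

The destination only controls how strongly the source is attended to, not which outputs get boosted.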
00:29:38.280 | Is there a reason why the source tokens here have a space before the first character?
00:29:46.200 | Yes. Um, I was giving examples where the tokenization breaks in a particular
00:29:52.440 | way, and because spaces get included in the tokenization, um, when there's a space in front
00:29:59.560 | of something and then there's an example where the space isn't in front of it, they can get
00:30:03.080 | tokenized in different ways. Got it. Cool. Thanks. Yeah. Great question.
00:30:07.400 | Um, okay. So, just to abstract away some common patterns that we're seeing, I think,
00:30:15.240 | um, one pretty common thing is what you might describe as, like, "B ... A → B".
00:30:19.640 | So you see some token, and then you'll see another token that might
00:30:24.360 | precede that token. And then you're like, ah, probably the token that I saw earlier is going
00:30:27.480 | to occur again. Um, or sometimes you predict a slightly different token. So, like, maybe an
00:30:33.880 | example of the first one is "two... one two", but you could also do "two... has three". And so three
00:30:40.440 | isn't the same as two, but it's kind of similar. So that's one thing. Another one is this
00:30:44.040 | example where you have something that's tokenized together one
00:30:46.680 | time and split apart another time. So you see the token, and then you see something that might
00:30:51.000 | be the first part of the token, and then you predict the second part. Um, I think the thing
00:30:57.560 | that's really striking about this is that these are all, in some ways, really crude kinds of
00:31:02.920 | in-context learning. Um, and in particular, these models get about 0.1 nats rather than about 0.4
00:31:10.040 | nats of in-context learning, and they never go through the phase change. So they're doing some
00:31:13.960 | kind of really crude in-context learning. And also they're dedicating almost all their attention
00:31:18.280 | heads to this kind of crude in-context learning. So they're not very good at it, but they're
00:31:21.880 | dedicating their, um, their capacity to it. Uh, I'm noticing that it's 10:37.
00:31:28.200 | Um, I want to just check how long I can go, 'cause maybe I should, like, super accelerate.
00:31:32.920 | Of course. Uh, I think it's fine, because students are also asking questions in between.
00:31:38.840 | So you should be good. Okay. So maybe my plan will be that I'll talk until, like, 10:55 or 11.
00:31:44.920 | And then if you want, I can go and answer questions for a while after that.
00:31:49.880 | Yeah, it works. Fantastic. So you can see this as a very crude kind of in-context learning.
00:31:56.200 | Like, basically what we're saying is it's sort of all of this labor of, okay, well,
00:31:59.480 | I saw this token; probably these other tokens, the same token or similar tokens, are more likely
00:32:03.560 | to go and occur later. And look, this is an opportunity that sort of looks like I could
00:32:07.160 | inject the token that I saw earlier. I'm going to inject it here and say that it's more likely.
00:32:10.680 | That's basically what it's doing. And it's dedicating almost all of its capacity
00:32:15.320 | to that. So, you know, it's sort of the opposite of what we thought with RNNs in the past.
00:32:18.840 | It used to be that everyone was like, oh, you know, with RNNs it's so hard to care about long-distance
00:32:23.720 | context. You know, maybe we need to go and, like, use dams or something. No: if you train a
00:32:28.360 | transformer and you give it a long enough context, it's dedicating almost all
00:32:32.280 | of its capacity to this type of stuff. Just kind of interesting. There are some attention heads which
00:32:40.120 | are more primarily positional. Usually, you know, in the model that I've been training, the
00:32:45.320 | one-layer model, which has 12 attention heads, around two or three of
00:32:49.640 | those will become these more positional, sort of shorter-term things that do something more like
00:32:53.640 | local trigram statistics, and then everything else becomes these skip-trigrams.
00:33:01.480 | Yeah, so some takeaways from this. Yeah, you can understand
00:33:05.800 | one-layer attention-only transformers in terms of these OV and QK circuits.
00:33:09.400 | Transformers desperately want to do in-context learning. They desperately, desperately, desperately
00:33:15.400 | want to go and look at these long-distance contexts and go and predict things. There's
00:33:19.480 | just so much entropy that they can go and reduce out of that. The constraints of a
00:33:24.760 | one-layer attention-only transformer force it to make certain bugs if it wants to do the right thing.
00:33:29.400 | And if you freeze the attention patterns, these models are linear.
00:33:32.440 | Okay. A quick aside, because so far this type of work has required us to do a lot of very manual
00:33:40.760 | inspection. Like, we're looking through these giant matrices, but there's a way that we can escape
00:33:44.440 | that. We don't have to look at these giant matrices if we don't want to. We can use
00:33:49.080 | eigenvalues and eigenvectors. So recall that an eigenvalue and an eigenvector just mean that if you
00:33:54.920 | multiply that vector by the matrix, it's equivalent to just scaling. And often, in my
00:34:02.600 | experience, this hasn't been very useful for interpretability because we're usually
00:34:05.240 | mapping between different spaces. But if you're mapping onto the same space, eigenvalues and
00:34:08.920 | eigenvectors are a beautiful way to think about this. So we're going to draw them on a radial
00:34:16.440 | plot. And we're going to have a log radial scale, because their magnitudes are
00:34:21.880 | going to vary by many orders of magnitude. Okay. So we can just go and, you know, our
00:34:27.720 | OV circuit maps from tokens to tokens. That's the same vector space on the input and the output.
00:34:32.280 | And we can ask, you know, what does it mean if we see eigenvalues of a particular kind? Well,
00:34:36.760 | positive eigenvalues, and this is really the most important part, mean copying. So if you have a
00:34:40.840 | positive eigenvalue, it means that there's some set of tokens where, if you see them,
00:34:45.880 | you increase their probability. And if you have a lot of positive eigenvalues, you're doing a lot
00:34:50.840 | of copying. If you only have positive eigenvalues, everything you do is copying. Now,
00:34:55.480 | imaginary eigenvalues mean that you see a token and then you want to go and increase the probability
00:34:59.160 | of unrelated tokens. And finally, negative eigenvalues are anti-copying. They're like,
00:35:03.000 | if you see this token, you make it less probable in the future.
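A toy illustration of reading eigenvalue signs, assuming 2x2 "OV circuits" over a two-token vocabulary so the eigenvalues have a closed form (for real vocab-by-vocab matrices one would use a numerical routine such as NumPy's `numpy.linalg.eigvals` instead). The matrices are invented: a positive diagonal copies, a rotation maps each token onto the other, and a negative diagonal anti-copies.

```python
import cmath
from typing import List, Tuple


def eigvals_2x2(m: List[List[float]]) -> Tuple[complex, complex]:
    """Closed-form eigenvalues of a 2x2 matrix from its trace and determinant."""
    (a, b), (c, d) = m
    tr, det = a + d, a * d - b * c
    disc = cmath.sqrt(tr * tr - 4 * det)
    return (tr + disc) / 2, (tr - disc) / 2


def describe(lam: complex, tol: float = 1e-9) -> str:
    """Map an eigenvalue to the interpretation given in the talk."""
    if abs(lam.imag) > tol:
        return "imaginary: maps onto other tokens"
    if lam.real > tol:
        return "positive: copying"
    if lam.real < -tol:
        return "negative: anti-copying"
    return "zero: no effect"


examples = {
    "copy": [[2.0, 0.0], [0.0, 3.0]],
    "rotate": [[0.0, -1.0], [1.0, 0.0]],
    "anti-copy": [[-2.0, 0.0], [0.0, -1.0]],
}
for name, m in examples.items():
    print(name, [describe(lam) for lam in eigvals_2x2(m)])
```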
00:35:05.240 | Well, that's really nice because now we don't have to go and dig through these giant matrices
00:35:10.440 | that are vocab size by vocab size. We can just look at the eigenvalues. And so these are the
00:35:15.560 | eigenvalues for our one-layer attention-only transformer. And we can see that, you know,
00:35:20.280 | for many of these, they're almost entirely positive. These ones are sort of entirely
00:35:26.680 | positive. These ones are almost entirely positive. And then really these ones are even almost
00:35:30.840 | entirely positive. And there's only two that have a significant number of imaginary and negative
00:35:36.120 | eigenvalues. And so what this is telling us is that, just in one picture, we can see,
00:35:40.680 | you know, OK, really 10 out of 12 of these attention heads are just
00:35:46.440 | doing copying. They're just doing this long-distance, you know, "well, I saw a token,
00:35:50.280 | probably it's going to occur again" type stuff. That's kind of cool. We can summarize it
00:35:54.280 | really quickly. OK. Now, the other thing that you can-- yeah, so this is for the one-layer model;
00:36:01.400 | we're going to look at a two-layer model in a second. And we'll see that also a lot of
00:36:04.520 | its heads are doing this kind of copying-ish stuff. They have large positive eigenvalues.
00:36:08.440 | You can do a histogram. Like, you know, one thing that's cool is you can just add up
00:36:13.960 | the eigenvalues and divide by the sum of their absolute values. And you've got a number between negative one and
00:36:17.800 | one, which is, like, how much copying the head is doing.
00:36:21.400 | And you can just do a histogram and you can see, oh yeah, almost
00:36:25.080 | all of the heads are doing lots of copying. You know, it's nice to be able to go and summarize
00:36:29.640 | your model in a-- and I think this is sort of like, we've gone in a very bottom-up way. And we didn't
00:36:35.160 | start with assumptions about what the model was doing. We tried to understand its structure. And
00:36:38.040 | then we were able to summarize it in useful ways. And now we're able to go and say something about
00:36:41.160 | it. Now, another thing you might ask is what do the eigenvalues of a QK circuit mean? And in our
00:36:48.440 | example so far, they wouldn't have been that interesting. But in a
00:36:52.280 | minute, they will be. And so I'll briefly describe what they mean. A positive eigenvalue would mean
00:36:55.640 | you want to attend to the same tokens. An imaginary eigenvalue, and this is what you would mostly see
00:37:00.760 | in the models that we've seen so far, means you want to go and attend to an unrelated or different
00:37:05.560 | token. And a negative eigenvalue would mean you want to avoid attending to the same token.
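The summary statistic mentioned a little earlier, the sum of the eigenvalues divided by the sum of their absolute values, can be sketched as a small function. It takes a list of (possibly complex) eigenvalues, however they were computed, and returns a number in [-1, 1]: +1 for pure copying, -1 for pure anti-copying, and near 0 when imaginary eigenvalues dominate. Applied to a QK circuit, per the talk, a high score would instead indicate attending to the same token. The example lists are invented.

```python
from typing import Sequence, Union

Number = Union[int, float, complex]


def copying_score(eigenvalues: Sequence[Number]) -> float:
    """sum(lambda) / sum(|lambda|), in [-1, 1]; 0.0 if all eigenvalues are zero."""
    norm = sum(abs(lam) for lam in eigenvalues)
    if norm == 0:
        return 0.0
    # Complex eigenvalues of a real matrix come in conjugate pairs, so the sum is real.
    return (sum(eigenvalues) / norm).real


print(copying_score([3, 2]))        # 1.0: all positive
print(copying_score([1j, -1j]))     # 0.0: purely imaginary
print(copying_score([3, -1]))       # 0.5: mixed
```

Binning this score over all heads gives the histogram described in the talk.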
00:37:09.960 | So that will be relevant in a second. Yeah, so those are mostly going to be useful to think
00:37:17.000 | about in multilayer attention-only transformers, when we can have chains of attention
00:37:20.280 | heads. And so we can ask, well, I'll get to that in a second. Yeah, so that's a table summarizing
00:37:25.160 | that. Unfortunately, this approach completely breaks down once you have MLP layers. With MLP layers,
00:37:30.840 | you know, now you have these non-linearities, so you don't get this property where your model is
00:37:34.120 | mostly linear and you can just look at a matrix. But if you're working with attention-only
00:37:38.280 | transformers, this is a very nice way to think about things. OK, so recall that one-layer
00:37:42.920 | attention-only transformers don't undergo this phase change that we talked about in the beginning.
00:37:45.880 | Like right now, we're on a hunt. We're trying to go and answer this mystery of what the hell is
00:37:50.200 | going on in that phase change where models suddenly get good at in-context learning.
00:37:53.800 | We want to answer that. And one-layer attention-only transformers don't undergo that phase change,
00:37:58.040 | but two-layer attention-only transformers do. So we'd like to know what's different about
00:38:01.640 | two-layer attention-only transformers. OK, well, so in our previous-- when we were dealing with
00:38:10.280 | one-layer attention-only transformers, we were able to go and rewrite them in this form. And it gave
00:38:14.680 | us a lot of ability to go and understand the model because we could go and say, well, you know,
00:38:19.080 | this is bigrams. And then each one of these is looking somewhere. And we had this matrix that
00:38:23.160 | describes how it affects things. And yeah, so that gave us a lot of ability to think about
00:38:29.160 | these things. And we can also just write in this factored form where we have the embedding,
00:38:33.880 | and then we have the attention heads, and then we have the unembedding.
00:38:36.040 | OK, well-- oh, and for simplicity, we often go and write W_OV for W_O W_V because they always
00:38:45.400 | come together. It's always the case. It's, in some sense, an illusion that W_O and W_V are
00:38:49.720 | different matrices. They're just one low-rank matrix. They're never-- they're always used
00:38:53.160 | together. And similarly, W_Q and W_K: it's sort of an illusion that they're different matrices.
00:38:57.400 | They're always just used together. And keys and queries are just sort of-- they're just an
00:39:01.960 | artifact of these low-rank matrices. So in any case, it's useful to go and write those together.
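The point that W_O and W_V only ever act as one low-rank product can be checked directly: W_OV = W_O W_V is d_model x d_model, but its rank can't exceed d_head. The sketch uses exact rational Gaussian elimination and tiny made-up matrices (d_model = 4, d_head = 2); the same argument applies to the product of W_Q and W_K.

```python
from fractions import Fraction
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def rank(m: Matrix) -> int:
    """Matrix rank via exact Gaussian elimination over the rationals."""
    rows = [[Fraction(x) for x in row] for row in m]
    r = 0
    for c in range(len(rows[0])):
        pivot = next((i for i in range(r, len(rows)) if rows[i][c] != 0), None)
        if pivot is None:
            continue
        rows[r], rows[pivot] = rows[pivot], rows[r]
        for i in range(len(rows)):
            if i != r and rows[i][c] != 0:
                f = rows[i][c] / rows[r][c]
                rows[i] = [x - f * y for x, y in zip(rows[i], rows[r])]
        r += 1
        if r == len(rows):
            break
    return r


W_O = [[1, 0], [0, 1], [1, 1], [2, 3]]     # d_model x d_head
W_V = [[1, 2, 0, 1], [0, 1, 1, 0]]         # d_head x d_model
W_OV = matmul(W_O, W_V)                     # d_model x d_model
print(len(W_OV), "x", len(W_OV[0]), "matrix of rank", rank(W_OV))
```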
00:39:06.920 | OK, great. So in a two-layer attention-only transformer, what we do is we go through the
00:39:12.760 | embedding matrix. Then we go through the layer 1 attention heads. Then we go through the layer 2
00:39:17.800 | attention heads. And then we go through the unembedding. And for the attention heads,
00:39:21.960 | we always have this identity as well, which corresponds to just going down the residual stream.
00:39:25.960 | So we can go down the residual stream, or we can go through an attention head.
00:39:29.400 | Next up, we can also go down the residual stream, or we can go through an attention head.
00:39:33.960 | And there's this useful identity, the mixed-product identity, that any tensor product or
00:39:44.040 | other ways of interpreting this obey, which is that if you have two attention heads, each with,
00:39:49.800 | say, an attention pattern and a W_OV matrix, then when you compose them,
00:39:54.360 | the attention patterns multiply together, and the OV circuits multiply together,
00:39:58.360 | and they behave nicely. OK, great. So we can just expand out that equation. We can just take that
00:40:05.000 | big product we had at the beginning, and we can just expand it out. And we get three different
00:40:08.040 | kinds of terms. So one thing we do is we get this path that just goes directly through the residual
00:40:12.840 | stream where we embed and un-embed, and that's going to want to represent some bigram statistics.
00:40:16.920 | Then we get things that look like the attention head terms that we had previously.
00:40:23.720 | And finally, we get these terms that correspond to going through two attention heads.
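Written out, the expansion just described looks roughly like this, in notation assumed here: A^h is head h's attention pattern (treated as fixed), W_OV^h its OV matrix, H_1 and H_2 the heads in layers 1 and 2, and the tensor product acts on positions (left factor) and on the residual stream (right factor). The mixed-product identity (A ⊗ B)(C ⊗ D) = (AC) ⊗ (BD) is what turns the factored first line into the three kinds of terms on the second.

```latex
T = (\mathrm{Id} \otimes W_U)
    \Big(\mathrm{Id} + \sum_{h \in H_2} A^h \otimes W_{OV}^h\Big)
    \Big(\mathrm{Id} + \sum_{h \in H_1} A^h \otimes W_{OV}^h\Big)
    (\mathrm{Id} \otimes W_E)

T = \underbrace{\mathrm{Id} \otimes W_U W_E}_{\text{direct path (bigrams)}}
  + \underbrace{\sum_{h \in H_1 \cup H_2} A^h \otimes \big(W_U W_{OV}^h W_E\big)}_{\text{individual heads}}
  + \underbrace{\sum_{h_1 \in H_1,\; h_2 \in H_2} \big(A^{h_2} A^{h_1}\big) \otimes \big(W_U W_{OV}^{h_2} W_{OV}^{h_1} W_E\big)}_{\text{virtual heads}}
```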
00:40:32.360 | Now, it's worth noting that these terms are not actually the same as--
00:40:40.520 | they're-- because the attention patterns in the second layer can be computed from
00:40:45.320 | the outputs of the first layer, those are also going to be more expressive. But at a high level,
00:40:49.480 | you can think of there as being these three different kinds of terms. And we sometimes
00:40:52.600 | call these terms virtual attention heads because they don't exist in the sense-- like, they aren't
00:40:57.000 | explicitly represented in the model, but, in fact, they have an attention pattern. They have
00:41:02.040 | an OV circuit. They're in almost all functional ways like a tiny little attention head, and there's
00:41:06.680 | exponentially many of them. Turns out they're not going to be that important in this model,
00:41:10.920 | but in other models, they can be important. Right, so one thing that I said is it allows
00:41:16.600 | us to think about attention heads in a really principled way. We don't have to go and think
00:41:19.880 | about-- I think people look at attention patterns all the time, and I think a concern you
00:41:27.240 | have is, well, there's multiple attention heads. The information that's being moved by one attention
00:41:31.880 | head might have been moved there by another attention head and not originated there. It might
00:41:35.560 | still be moved somewhere else. But in fact, this gives us a way to avoid all those concerns and
00:41:40.120 | just think about things in a single principled way. OK, in any case, an important question to ask is,
00:41:46.680 | how important are these different terms? Like, we could study all of them. How important are they?
00:41:50.680 | And it turns out you can just-- there's an algorithm you can use where you knock out
00:41:56.840 | attention-- knock out these terms, and you go and you ask, how important are they?
00:42:01.320 | And it turns out by far the most important thing is these individual attention head terms. In this
00:42:06.440 | model, by far the most important thing. The virtual attention heads basically don't matter
00:42:11.400 | that much. They only have an effect of 0.3 nats relative to the above ones, and the bigrams are still
00:42:16.680 | pretty useful. So if we want to try to understand this model, we should probably go and focus our
00:42:20.120 | attention on-- well, the virtual attention heads are not going to be the best way to go and focus our
00:42:26.680 | attention, especially since there's a lot of them. There's 124 of them for 0.3 nats. It's very little
00:42:31.720 | that you would understand from studying one of those terms. So the thing that we probably want
00:42:36.440 | to do-- we know that these are bigram statistics. So what we really want to do is we want to
00:42:39.800 | understand the individual attention head terms. This is the algorithm. I'm going to skip over it
00:42:48.520 | for time. We can ignore that term because it's small. And it turns out also that the layer 2
00:42:54.520 | attention heads are doing way more than layer 1 attention heads. And that's not that surprising.
00:42:58.840 | Layer 2 attention heads are more expressive because they can use the layer 1 attention
00:43:02.280 | heads to construct their attention patterns. So if we could just go and understand the layer
00:43:08.120 | 2 attention heads, we'd probably understand a lot of what's going on in this model.
00:43:11.720 | And the trick is that the attention patterns are now constructed from the previous layer rather
00:43:18.200 | than just from the tokens. So this is still the same, but the attention pattern is more complex.
00:43:24.120 | And if you write it out, you get this complex equation that says, you embed the tokens,
00:43:28.440 | and you're going to shuffle things around using the attention heads for the keys. Then you multiply
00:43:31.880 | by W_QK. Then you shuffle things around again for the queries. And then you go and multiply by the
00:43:37.080 | embedding again because they were embedded. And then you get back to the tokens. But let's
00:43:44.200 | actually look at them. So one thing that's-- remember that when we see positive eigenvalues
00:43:49.000 | in the OV circuit, we're doing copying. So one thing we can say is, well, 7 out of 12-- and in
00:43:53.800 | fact, the ones with the largest eigenvalues are doing copying. So we still have a lot of attention
00:43:58.680 | heads that are doing copying. And yeah, the QK circuit-- so one thing you could do is you could
00:44:07.480 | try to understand things in terms of this more complex QK equation. You could also just try to
00:44:10.600 | understand what the attention patterns are doing empirically. So let's look at one of these copying
00:44:14.520 | ones. I've given it the first paragraph of Harry Potter, and we can just look at where it attends.
00:44:23.240 | And something really interesting happens. So almost all the time, we just attend back to
00:44:27.880 | the first token. We have this special token at the beginning of the sequence.
00:44:31.480 | And we usually think of that as just being a null attention operation. It's a way for it to not do
00:44:36.440 | anything. In fact, if you look, the value vector is basically 0. It's just not copying any information
00:44:40.600 | from that. But whenever we see repeated text, something interesting happens. So when we get
00:44:47.800 | to "Mr."-- it tries to look at "and." It's a little bit weak. Then we get to "D," and it attends to
00:44:55.400 | "urs." That's interesting. And then we get to "urs," and it attends to "ley." And so it's not
00:45:04.920 | attending to the same token. It's attending to the same token, shifted one forward. Well, that's
00:45:12.840 | really interesting. And there's actually a lot of attention heads that are doing this. So here we
00:45:16.760 | have one where now we hit the "Pot" in "Potters", and we attend to "ters". Maybe that's the same attention
00:45:21.880 | head; I don't remember from when I was constructing this example. It turns out this is a super common
00:45:25.880 | thing. So you go and you look for the previous example, you shift one forward, and you're like,
00:45:30.120 | OK, well, last time I saw this, this is what happened. Probably the same thing is going to
00:45:33.080 | happen. And we can go and look at the effect that the attention head has on the logits. Most of the
00:45:42.360 | time, it's not affecting things. But in these cases, it's able to go and predict when it's
00:45:46.200 | doing this thing of going and looking one forward. It's able to go and predict the next token.
00:45:49.160 | So we call this an induction head. An induction head looks for the previous copy,
00:45:54.920 | looks forward, and says, ah, probably the same thing that happened last time is going to happen.
00:45:58.680 | You can think of this as being a nearest neighbors. It's like an in-context nearest
00:46:02.280 | neighbors algorithm. It's going and searching through your context, finding similar things,
00:46:06.120 | and then predicting that's what's going to happen next.
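As a rough, non-neural caricature of this "in-context nearest neighbors" behavior (an illustrative sketch, not how the head actually computes it): find the most recent earlier occurrence of the current token and predict whatever came right after it. The token pieces below are hypothetical.

```python
from typing import Hashable, Optional, Sequence


def induction_predict(tokens: Sequence[Hashable]) -> Optional[Hashable]:
    """Predict the next token by copying what followed the last earlier match."""
    if not tokens:
        return None
    current = tokens[-1]
    # Scan backwards over earlier positions (excluding the current one).
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None  # no earlier copy: nothing to induce


text = ["Mr", ".", " and", " Mrs", ".", " D", "urs", "ley", " ...", " Mr", ".", " D"]
print(induction_predict(text))   # urs
```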
00:46:11.480 | The way that these actually work is-- I mean, there's actually two ways. But in a model that
00:46:17.400 | uses rotary attention or something like this, you only have one. You shift your key. First,
00:46:24.280 | an earlier attention head shifts your key forward one. So you take the value of the previous token,
00:46:29.960 | and you embed it in your present token. And then you have your query and your key
00:46:33.960 | go and look at-- yeah, try to go and match. So you look for the same thing.
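The two-step mechanism can be simulated symbolically, assuming idealized heads: a layer-1 "previous-token head" writes each position's preceding token into that position, and the layer-2 induction head's query (the current token) is matched against those shifted keys. Hard attention on the latest match stands in for softmax, and the final copy stands in for the OV circuit; the function names are made up for this sketch.

```python
from typing import Hashable, List, Optional, Sequence, Tuple


def previous_token_head(tokens: Sequence[Hashable]) -> List[Optional[Hashable]]:
    """Layer 1: position j ends up carrying the token at position j - 1."""
    return [None] + list(tokens[:-1])


def induction_head(tokens: Sequence[Hashable]) -> Tuple[Optional[int], Optional[Hashable]]:
    """Layer 2: attend where the shifted key equals the current token, then copy.

    Returns (attended position, predicted next token), or (None, None).
    """
    if not tokens:
        return None, None
    keys = previous_token_head(tokens)   # keys built from the layer-1 head's output
    query = tokens[-1]
    matches = [j for j in range(len(tokens) - 1) if keys[j] == query]
    if not matches:
        return None, None
    j = matches[-1]                      # hard attention on the latest match
    return j, tokens[j]                  # OV step: copy the attended token


text = ["Mr", ".", " and", " Mrs", ".", " D", "urs", "ley", " ...", " Mr", ".", " D"]
print(induction_head(text))
```

The head attends to the token *after* the earlier " D" directly, because that position's key already contains " D"; no separate "shift forward" step is needed at prediction time.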
00:46:40.600 | And then you go and you predict that whatever you saw is going to be the next token. So that's the
00:46:45.480 | high-level algorithm. Sometimes you can do clever things where actually it'll care about multiple
00:46:50.200 | earlier tokens, and it'll look for short phrases and so on. So induction heads can really vary in
00:46:54.680 | how much of the previous context they care about or what aspects of the previous context they care
00:46:58.200 | about. But this general trick of looking for the same thing, shift forward, predict that,
00:47:03.320 | is what induction heads will do. Lots of examples of this. And the cool thing is you can now
00:47:10.760 | you can use the QK eigenvalues to characterize this. You can say, well, we're looking for the
00:47:15.560 | same thing, shifted by one, but looking for the same thing. If you expand through the attention
00:47:19.160 | heads in the right way, that'll work out. And we're copying. And so an induction head is one
00:47:23.640 | which has both positive OV eigenvalues and also positive QK eigenvalues.
00:47:28.520 | And so you can just put that on a plot, and you have your induction heads in the corner.
00:47:37.000 | So your OV eigenvalues, your QK eigenvalues, and I think actually OV is this axis, QK is this one
00:47:42.760 | axis, doesn't matter. And in the corner, you have your induction heads.
00:47:49.400 | And so this seems to be-- well, OK, we now have an actual hypothesis. The hypothesis is the way
00:47:57.160 | that that phase change we were seeing-- the phase change is the discovery of these induction heads.
00:48:01.240 | That would be the hypothesis. And these are way more effective than this first algorithm we had,
00:48:08.120 | which was just to blindly copy things wherever it could be plausible. Now we can go and actually
00:48:12.520 | recognize patterns and look at what happened and predict that similar things are going to
00:48:15.560 | happen again. That's a way better algorithm. Yeah, so there's other attention heads that
00:48:22.280 | are doing more local things. I'm going to go and skip over that and return to our mystery,
00:48:25.800 | because I am running out of time. I have five more minutes. OK, so what is going on with this
00:48:30.200 | in-context learning? Well, now we have a hypothesis. Let's check it. So we think it
00:48:34.040 | might be induction heads. And there's a few reasons we believe this. So one thing is going
00:48:40.600 | to be that induction heads-- well, OK, I'll just go over to the end. So one thing you can do is you
00:48:46.520 | can just ablate the attention heads. And it turns out you can color-- here we have attention heads
00:48:52.200 | colored by how much they are an induction head. And this is the start of the bump. This is the
00:48:58.040 | end of the bump here. And we can see that they-- first of all, induction heads are forming. Like
00:49:02.600 | previously, we didn't have induction heads here. Now they're just starting to form here. And then
00:49:06.920 | we have really intense induction heads here and here. And here are the attention heads where, if you ablate
00:49:13.080 | them, you get a drop in in-context learning. And we're looking not at the loss, but at this meta-learning score,
00:49:20.760 | or rather in-context learning score: the difference between the loss at the 500th token
00:49:25.000 | and the loss at the 50th token. And that's all explained by induction heads. Now, we actually have one
00:49:31.480 | induction head that doesn't contribute to it. Actually, it does the opposite. So that's kind
00:49:34.840 | of interesting. Maybe it's doing something shorter-distance. And there's also this interesting thing
00:49:39.800 | where they all rush to be induction heads, and then it turns out only a few win out in the end.
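The in-context learning score used in these ablation plots is a difference of per-token losses: the loss at the 500th token minus the loss at the 50th token. A more negative score means the model gains more from long context. The sketch below assumes you already have a `(num_sequences, seq_len)` array of per-token losses. The `ablation_effect` helper and the synthetic loss curves are hypothetical, for illustration only.

```python
import numpy as np


def in_context_learning_score(per_token_loss, early=50, late=500):
    """Mean loss at the `late`-th token minus mean loss at the `early`-th token.

    Positions are 1-indexed, so the 500th token lives at column 499.
    per_token_loss has shape (num_sequences, seq_len) with seq_len >= late.
    """
    losses = np.asarray(per_token_loss, dtype=float)
    return float(losses[:, late - 1].mean() - losses[:, early - 1].mean())


def ablation_effect(clean_loss, ablated_loss, **kwargs):
    """How much ablating a head hurts in-context learning (positive = it was helping)."""
    return (in_context_learning_score(ablated_loss, **kwargs)
            - in_context_learning_score(clean_loss, **kwargs))


# Synthetic curves: loss falls with position when context helps, stays flat otherwise.
clean = np.tile(np.linspace(5.0, 3.0, 600), (4, 1))
ablated = np.full((4, 600), 5.0)

print(round(in_context_learning_score(clean), 3))  # negative: later tokens are easier
print(round(ablation_effect(clean, ablated), 3))   # positive: ablation removed the benefit
```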
00:49:44.680 | So there's some interesting dynamics going on there. But it really seems like in these small
00:49:48.280 | models, all of in-context learning is explained by these induction heads. OK. What about large
00:49:55.880 | models? Well, in large models, it's going to be harder to go and ask this. But one thing you can
00:49:59.560 | do is you can ask, OK, we can look at our in-context learning score over time. We get this
00:50:06.600 | sharp phase change. Oh, look. Induction heads form at exactly the same point in time. So that's only
00:50:13.400 | correlational evidence. But it's pretty suggestive correlational evidence, especially given that we
00:50:17.320 | have an obvious-- the obvious effect that induction heads should have is this. I guess it could be
00:50:22.920 | that there are other mechanisms being discovered at the same time in large models. But it has to
00:50:26.360 | be in a very small window. So it really suggests that induction heads are what's driving that change in in-context
00:50:32.760 | learning. OK. So obviously, induction heads can copy text. But a question you might ask is,
00:50:41.080 | can they do translation? There are all these amazing things models can do that it's not obvious
00:50:47.000 | in-context learning or this sort of copying mechanism could do. So I just want to very
00:50:51.240 | quickly look at a few fun examples. So here we have an attention pattern. Oh, yeah. I guess
00:51:02.120 | I need to open Lexiscope. Let me try doing that again. Sorry. I should have thought this through
00:51:12.200 | a bit more before this talk. Chris, could you zoom in a little, please? Yeah, yeah. Thank you.
00:51:38.840 | OK. I'm not-- my French isn't that great. But my name is Christopher. I'm from Canada.
00:51:44.360 | What we can do here is look at where this attention head attends as we step through the text.
00:51:50.680 | And it'll become especially clear on the second sentence. So here, we're on the period,
00:51:55.400 | and we attend to "je." Now we're on-- and "je" is "I" in French. OK. Now we're on the "I,"
00:52:03.640 | and we attend to "suis." Now we're on the "am," and we attend to "du," which is "from,"
00:52:09.720 | and then from "from" to "Canada." And so this is a cross-lingual induction head, which the model can use
00:52:16.040 | for translation. And indeed, if you look at examples, it seems
00:52:21.960 | to be a major driving force in the model's ability to translate correctly.
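A cross-lingual induction head does the same lookup as an ordinary induction head, except the match is on meaning rather than on the literal token. The toy below is a sketch under loud assumptions. It fakes that fuzzy match with a tiny hand-written English-to-French dictionary (`TOY_EN_TO_FR`, hypothetical). A real head does this softly, through learned representations, not a lookup table.

```python
# Hypothetical toy dictionary standing in for the model's learned notion of "same meaning".
TOY_EN_TO_FR = {"i": "je", "am": "suis", "from": "du", "canada": "canada"}


def soft_induction_target(source_tokens, current_token, translate):
    """Return the source token right after the one whose meaning matches `current_token`.

    This mirrors an induction head: find an earlier position that "matches" the
    current token, then attend to (and copy from) the position after it.
    """
    wanted = translate.get(current_token.lower())
    if wanted is None:
        return None
    for i, tok in enumerate(source_tokens[:-1]):
        if tok.lower() == wanted:
            return source_tokens[i + 1]
    return None


french = ["Je", "suis", "du", "Canada", "."]
for english in ["I", "am", "from"]:
    print(english, "->", soft_induction_target(french, english, TOY_EN_TO_FR))
# I -> suis
# am -> du
# from -> Canada
```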
00:52:28.680 | Another fun example is-- I think maybe the most impressive thing about in-context learning to me
00:52:35.240 | has been the model's ability to go and learn arbitrary functions. Like,
00:52:38.040 | you can just show the model a function, and it can start mimicking that function. Well, OK.
00:52:41.800 | I have a question.
00:52:44.680 | So do these induction heads only do kind of a look-ahead copy? Or can they also do some sort
00:52:50.120 | of complex structure recognition? Yeah, yeah. So they can both use a larger context-- previous
00:52:58.520 | context-- and they can copy more abstract things. So the translation one is showing you that they
00:53:02.600 | can copy, rather than the literal token, a translated version. So it's what we might call
00:53:06.520 | a soft induction head. And yeah, you can have them copy similar words. You can have them
00:53:11.800 | look at longer contexts. You can look for more structural things. The way that we usually
00:53:16.440 | characterize them is whether-- in large models, just whether they empirically behave like an
00:53:20.600 | induction head. So the definition gets a little bit blurry when you try to encompass these more--
00:53:25.480 | there's sort of a blurry boundary. But yeah, there seem to be a lot of attention heads that
00:53:29.480 | are doing sort of more and more abstract versions. And yeah, my favorite version is this one that I'm
00:53:36.040 | about to show you, which is used-- let's isolate a single one of these-- which can do pattern
00:53:42.200 | recognition. So it can learn functions in the context and learn how to do it. So I've just
00:53:46.200 | made up a nonsense function here. We're going to encode one binary variable with the choice of
00:53:52.120 | whether to do a color or a month as the first word. Then we're going to-- so we have green or
00:53:58.440 | June here. Let's zoom in more. So we have color or month, and animal or fruit. And then we have
00:54:07.400 | to map it to either true or false. So that's our goal. And it's going to be an XOR. So we have
00:54:11.480 | two binary variables represented in this way, and we take their XOR. I'm pretty confident this was never
00:54:16.840 | in the training set, because I just made it up, and it seems like a nonsense problem.
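Here is one way to generate prompts for a made-up task of this shape. The word lists, the `word word: label` prompt format, and the bit assignment (month = 1, fruit = 1) are all assumptions. The talk only fixes that a month followed by an animal, as in "April dog", should come out true, and this encoding agrees with that.

```python
import random

# Hypothetical word lists; the talk only mentions green, June, April, and dog.
COLORS = ["green", "red", "blue", "yellow"]
MONTHS = ["June", "April", "March", "October"]
ANIMALS = ["dog", "cat", "horse", "bird"]
FRUITS = ["apple", "banana", "grape", "lemon"]


def xor_label(first, second):
    """First bit: color (0) or month (1). Second bit: animal (0) or fruit (1)."""
    first_bit = first in MONTHS
    second_bit = second in FRUITS
    return first_bit != second_bit  # XOR of the two binary variables


def make_prompt(n_examples, seed=0):
    """Build a few-shot prompt of 'word word: true/false' lines."""
    rng = random.Random(seed)
    lines = []
    for _ in range(n_examples):
        first = rng.choice(rng.choice([COLORS, MONTHS]))
        second = rng.choice(rng.choice([ANIMALS, FRUITS]))
        lines.append(f"{first} {second}: {str(xor_label(first, second)).lower()}")
    return "\n".join(lines)


print(make_prompt(6))
print(xor_label("April", "dog"))  # True: month (1) XOR animal (0)
```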
00:54:21.080 | OK, so then we can go and ask, can the model go and predict that? Well, it can, and it uses
00:54:25.880 | induction heads to do it. And what we can do is we can look at the-- so we look at a colon where
00:54:30.120 | it's going to go and try and predict the next word. And for instance here, we have April dog.
00:54:35.880 | So it's a month and then an animal, and it should be true. And what it does is it looks for
00:54:40.920 | previous cases where there was a month and then an animal, especially
00:54:45.400 | one where the month was the same, and sees that the answer was true. And so the model
00:54:50.120 | can learn a completely arbitrary function by using this kind
00:54:55.240 | of pattern-recognition induction head. And so this, to me, made it a lot more plausible that
00:55:01.240 | the in-context learning these models do, with all the amazing generality we see
00:55:08.600 | in large language models, can be explained by induction heads. We don't
00:55:14.440 | know that. It could be that there's other things going on. It's very possible that there's lots
00:55:17.800 | of other things going on. But it seems a lot more plausible to me than it did when we started.
00:55:22.680 | I'm conscious that I'm actually over time, so I'm going to just quickly go through these last
00:55:27.800 | few slides. So I think thinking of this as in-context nearest neighbors
00:55:31.400 | is a really useful way to think about it. Other things could absolutely be contributing.
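Read as "in-context nearest neighbors", the function-learning example above amounts to this: find the earlier example in the prompt that looks most like the current one and copy its answer. The toy below is a caricature that assumes hand-picked features (word category plus the first word itself) and a simple count-of-matching-features similarity. Neither is a claim about what the model actually represents.

```python
def in_context_nearest_neighbor(examples, query):
    """Predict by copying the label of the most similar earlier example.

    examples: list of (features, label) pairs already seen in the context.
    query: feature tuple for the current item.
    Similarity is the number of positions where the features agree; ties go to
    the most recent example (an arbitrary choice for this toy).
    """
    best_label, best_score = None, -1
    for features, label in examples:
        score = sum(a == b for a, b in zip(features, query))
        if score >= best_score:  # '>=' lets later examples win ties
            best_label, best_score = label, score
    return best_label


# Hypothetical features: (category of first word, category of second word, first word).
context = [
    (("month", "fruit", "June"), False),
    (("color", "animal", "green"), False),
    (("month", "animal", "June"), True),
    (("color", "fruit", "red"), True),
]
print(in_context_nearest_neighbor(context, ("month", "animal", "April")))  # True
print(in_context_nearest_neighbor(context, ("color", "animal", "blue")))   # False
```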
00:55:35.800 | This might explain why transformers do in-context learning over long contexts better than LSTMs.
00:55:44.040 | An LSTM can't do this, because this algorithm isn't linear in the amount of compute it needs. It's, like,
00:55:48.200 | quadratic, or n log n if you're really clever. So it's impossible for an LSTM to do this.
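The compute argument can be made concrete with a rough count. This assumes the lookup compares each new token against every earlier one, which is a simplification of what attention does:

```latex
\text{comparisons}(n) = \sum_{t=1}^{n} (t - 1) = \frac{n(n-1)}{2} = O(n^2)
```

An LSTM, by contrast, carries a fixed-size state and does a constant amount of work per token, so its total compute is O(n). It has no room to revisit every earlier position, which is the sense in which it can't run this algorithm.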
00:55:53.640 | Transformers do do this. And actually, they diverge at the same point. But if you look-- well,
00:55:58.840 | I can go into this in more detail after, if you want. There's a really nice paper by Marcus Hutter
00:56:05.000 | trying to predict and explain why we observe scaling laws in models. It's worth noting that
00:56:09.080 | the arguments in this paper carry over exactly to this theory. In fact, they work
00:56:15.160 | better for the case of thinking about this in-context learning with, essentially, a nearest
00:56:20.120 | neighbors algorithm than they do in the regular case. So yeah, I'm happy to answer questions. I
00:56:27.160 | can go into as much detail as people want about any of this. And if you send me an
00:56:31.640 | email, I can also send you more information about all this. And yeah, again, this work is not yet published.
00:56:38.440 | You don't have to keep it secret. But if you could just be thoughtful about the fact that it's
00:56:42.680 | unpublished work and is probably a month or two away from coming out, I'd be really grateful for
00:56:46.360 | that. Thank you so much for your time. Yeah, thanks a lot, Chris. This was a great talk.
00:56:50.680 | So I'll just open it to some general questions. And then we can do a round of questions from the
00:56:57.880 | students. So I'd be very excited to know: what is the line of work that you're currently working
00:57:03.080 | on? Is it extending this? What do you think are the next things you'll try, to make it more
00:57:08.360 | interpretable? Yeah. I mean, I want to just reverse engineer language models. I
00:57:13.320 | want to figure out the entirety of what's going on in these language models. And one thing that we
00:57:21.960 | totally don't understand is MLP layers. We understand some things about them, but we don't
00:57:29.240 | really understand MLP layers very well. There's a lot of stuff going on in large models that we
00:57:33.800 | don't understand. I want to know how models do arithmetic. I want to know-- another thing that
00:57:38.360 | I'm very interested in is what's going on when you have multiple speakers. The model can clearly
00:57:41.960 | represent-- it has a basic theory of mind, multiple speakers in a dialogue. I want to understand
00:57:46.200 | what's going on with that. But honestly, there's just so much we don't understand. It's sort of
00:57:51.560 | hard to answer the question because there's just so much to figure out. And we have a lot of
00:57:56.760 | different threads of research in doing this. But yeah, the interpretability team at Anthropic
00:58:02.600 | is just sort of-- has a bunch of threads trying to go and figure out what's going on inside these
00:58:06.760 | models. And sort of a similar flavor to this of just trying to figure out, how do the parameters
00:58:11.240 | actually encode algorithms? And can we reverse engineer those into meaningful computer programs
00:58:16.440 | that we can understand? Got it. Another question I had is, so you were talking about how the
00:58:22.840 | transformers are trying to do meta-learning inherently. So it's like-- and you spent a lot
00:58:27.320 | of time talking about the induction heads, and that was very interesting. But can you formalize
00:58:31.720 | the sort of meta-learning algorithm they might be learning? Is it possible to say, oh, maybe this is
00:58:35.880 | a sort of internal algorithm that's making them good meta-learners, or something like
00:58:40.680 | that? I don't know. Yeah, I mean, I think that there are roughly two algorithms. One is this
00:58:46.040 | algorithm we saw in the one-layer model. And we see it in other models, too, especially early on,
00:58:49.480 | which is just to copy: you saw a word, so probably a similar word is going to appear
00:58:53.880 | later. Look for places that it might fit in and increase the probability. So that's one thing that
00:58:59.000 | we see. And the other thing we see is induction heads, which you can just summarize as in-context
00:59:05.000 | nearest neighbors, basically. And it seems-- possibly there's other things, but it seems like
00:59:09.640 | those two algorithms and the specific instantiations that we are looking at seem to be what's driving
00:59:15.640 | in-context learning. That would be my present theory. Yeah, it sounds very interesting. Yeah.
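The two algorithms described in this answer can be caricatured at the token level. In this sketch the functions are toys, not the model's actual computation. The first boosts anything already in the context. The induction version looks up earlier occurrences of the current token and predicts what followed them. It uses exact token matches where a real model would also pick up "similar" words.

```python
from collections import Counter


def copy_prediction(tokens):
    """Cruder algorithm: any token already in the context becomes more likely."""
    counts = Counter(tokens)
    total = sum(counts.values())
    return {tok: c / total for tok, c in counts.items()}


def induction_prediction(tokens):
    """Induction: find earlier occurrences of the current token, predict what came next."""
    current = tokens[-1]
    followers = Counter(
        tokens[i + 1] for i in range(len(tokens) - 1) if tokens[i] == current
    )
    total = sum(followers.values())
    return {tok: c / total for tok, c in followers.items()} if total else {}


tokens = "the cat sat . the dog ran . the cat".split()
copy_probs = copy_prediction(tokens)
print(max(copy_probs, key=copy_probs.get))  # the  (most frequent so far)
print(induction_prediction(tokens))          # {'sat': 1.0}
```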
00:59:21.160 | OK, so let's open-- make a round of questions. So yeah, feel free to go ahead for questions.
00:59:28.200 | Thank you.