Stanford CS224N NLP with Deep Learning | 2023 | Lecture 8 - Self-Attention and Transformers
00:00:09.280 |
We're about two minutes in, so let's get started. 00:00:12.880 |
So today, we've got what I think is quite an exciting lecture topic. 00:00:17.240 |
We're going to talk about self-attention and transformers. 00:00:21.760 |
So these are some ideas that are sort of the foundation 00:00:25.320 |
of most of the modern advances in natural language processing. 00:00:29.400 |
And actually, AI systems in a broad range of fields. 00:00:39.480 |
OK, before we get into that, we're going to have a couple of reminders. 00:01:02.460 |
with what I'll be talking about today, but go into considerably more detail. 00:01:16.540 |
Thankfully, our TAs especially have tested that this works on Colab, 00:01:16.540 |
and the amount of training is such that a Colab session will 00:01:29.040 |
allow you to train your machine translation system. 00:01:37.100 |
to more GPUs for assignment five in the final project. 00:01:41.900 |
We'll continue to update you as we're able to. 00:01:44.660 |
But the usual systems this year are no longer 00:01:49.100 |
holding because companies are changing their minds about things. 00:01:52.140 |
OK, so our final project proposal, you have a proposal 00:01:57.540 |
of what you want to work on for your final project. 00:02:00.460 |
We will give you feedback on whether we think it's a feasible idea 00:02:05.260 |
So this is very important because we want you to work on something 00:02:07.980 |
that we think has a good chance of success for the rest of the quarter. 00:02:12.460 |
We'll have an Ed announcement when it is out. 00:02:15.660 |
And we want to get you feedback on that pretty quickly 00:02:19.020 |
because you'll be working on this after assignment five is done. 00:02:22.140 |
Really, the major core component of the course after that 00:02:33.860 |
So let's take a look back into what we've done so far in this course 00:02:41.900 |
and see what we were doing in natural language processing. 00:02:48.080 |
If you had a natural language processing problem 00:02:50.080 |
and you wanted to take your best effort attempt at it 00:02:53.060 |
without doing anything too fancy, you would have said, OK, 00:02:56.040 |
I'm going to have a bidirectional LSTM instead of a simple RNN. 00:03:01.180 |
I'm going to use an LSTM to encode my sentences, 00:03:06.140 |
And if I have an output that I'm trying to generate, 00:03:09.300 |
I'll have a unidirectional LSTM that I'll use to generate the output one word at a time. 00:03:14.100 |
So you have a translation or a parse or whatever. 00:03:17.140 |
And so maybe I've encoded in a bidirectional LSTM the source sentence 00:03:20.500 |
and I'm sort of one by one decoding out the target 00:03:26.480 |
And then also, I was going to use something like attention 00:03:34.800 |
like I needed to do this sort of look back and see 00:03:39.040 |
And this was just working exceptionally well. 00:03:41.960 |
And we motivated attention through wanting to do machine translation. 00:03:48.140 |
where we didn't want to have to encode the whole source sentence in a single vector. 00:03:55.000 |
So we're going to be looking at a lot of the same problems 00:03:58.520 |
But we're going to use different building blocks. 00:04:00.560 |
We're going to say, if from 2014 to 2017-ish I was using recurrence, 00:04:07.480 |
then, through lots of trial and error, a few years later, 00:04:10.400 |
we ended up with these brand new building blocks that we'll use instead. 00:04:17.160 |
And they're going to allow for just a huge range of much more successful applications. 00:04:23.320 |
And so what are the issues with recurrent neural networks? 00:04:29.680 |
And what are the new systems that we're going to use from this point moving forward? 00:04:35.160 |
So one of the issues with a recurrent neural network 00:04:38.880 |
is what we're going to call linear interaction distance. 00:04:41.680 |
So as we know, RNNs are unrolled left to right or right to left, 00:04:49.840 |
But it encodes the notion of linear locality, which is useful. 00:04:53.080 |
Because if two words occur right next to each other, 00:04:59.680 |
And in the recurrent neural network, you encode tasty. 00:05:04.360 |
And then you walk one step, and you encode pizza. 00:05:08.720 |
So nearby words do often affect each other's meanings. 00:05:12.600 |
But you have this problem where very long distance dependencies take a long time to interact. 00:05:28.680 |
Like, the chef who went to the stores and picked up the ingredients was-- 00:05:40.440 |
between chef and was, you have application of the recurrent weight matrix 00:05:43.000 |
and some element-wise nonlinearities once, twice, three times. 00:05:47.320 |
As many times, potentially, as the length of the sequence between chef and was. 00:05:56.840 |
This should feel kind of related to the stuff that we did in dependency syntax. 00:06:01.120 |
But it's quite difficult to learn potentially 00:06:19.440 |
it can be difficult to learn the dependencies between them. 00:06:24.240 |
LSTMs do a lot better at modeling the gradients across long distances 00:06:36.680 |
isn't sort of the right way to think about sentences. 00:06:40.520 |
So if I wanted to learn that it's the chef who was, 00:06:46.920 |
then I might have a hard time doing it because the gradients have to propagate across long distances 00:06:56.440 |
between words that might be related in the sentence. 00:06:59.080 |
Or in a document even, if these are going to get much longer. 00:07:04.000 |
So this is this linear interaction distance problem. 00:07:08.400 |
to be able to interact with each other in the neural networks 00:07:11.000 |
computation graph more easily than being linearly far away 00:07:19.800 |
so that we can learn these long distance dependencies better. 00:07:23.000 |
And there's a related problem too that again comes back 00:07:25.560 |
to the recurrent neural networks dependence on the index. 00:07:28.640 |
On the index into the sequence, often called a dependence on time. 00:07:32.880 |
So in a recurrent neural network, the forward and backward passes have to walk the sequence one step at a time. 00:07:39.520 |
So that means, roughly, in this case, 00:07:41.600 |
just sequence-length many unparallelizable operations. 00:07:50.520 |
GPUs can do many computations at once, as long as there's no dependency between the operations in terms of ordering. 00:07:54.360 |
But here there is: you have to compute one and then compute the other. So you 00:07:59.800 |
can't actually compute the RNN hidden state for time step 5 00:08:03.800 |
before you compute the RNN hidden state for time step 4 00:08:08.560 |
And so you get this graph that looks very similar, 00:08:11.320 |
where if I want to compute this hidden state, 00:08:13.200 |
so I've got some word, I have zero operations 00:08:16.160 |
I need to do before I can compute this state. 00:08:18.600 |
I have one operation I can do before I can compute this state. 00:08:28.880 |
the state with the number 3, because I need to compute this 00:08:37.000 |
that I'm glomming all the matrix multiplies and stuff 00:08:42.480 |
And of course, this grows with the sequence length as well. 00:08:45.320 |
So down over here, as the sequence length grows, 00:08:53.600 |
with the matrix multiply to compute this state, 00:08:56.840 |
because I need to compute all the previous states beforehand. 00:09:08.800 |
Yeah, so I have a question on the linear interaction issues. 00:09:11.360 |
I thought that was the whole point of the attention network, 00:09:13.880 |
and then how maybe you want, during the training, 00:09:17.960 |
of the actual cells that depend more on each other. 00:09:26.200 |
So the question is, with the linear interaction distance, 00:09:31.720 |
Can't we use something with attention to help, 00:09:35.040 |
So it won't solve the parallelizability problem. 00:09:37.480 |
And in fact, everything we do in the rest of this lecture 00:10:00.040 |
And so we're going to get deep into attention today. 00:10:08.160 |
as a query to access and incorporate information 00:10:14.880 |
We were decoding out a translation of a sentence. 00:10:19.280 |
that we didn't have to store the entire representation 00:10:35.960 |
I'm writing out the number of unparallelizable operations 00:10:38.840 |
that you need to do before you can compute these. 00:10:43.600 |
compute its embedding without doing anything else previously, 00:10:46.920 |
because the embedding just depends on the word identity. 00:10:53.460 |
to build an attention representation of this word 00:10:55.580 |
by looking at all the other words in the sequence, 00:10:59.740 |
And I can do them in parallel for all the words. 00:11:06.100 |
I don't need to walk left to right like I did for an RNN. 00:11:16.220 |
problem and the non-parallelizability problem. 00:11:18.860 |
Because now, no matter how far away words are from each other, 00:11:23.980 |
I might just attend to you, even if you're very, very far away, 00:11:29.840 |
And I also don't need to sort of walk along the sequence 00:11:42.280 |
And it doesn't have this dependence on the sequence 00:11:44.440 |
index that keeps us from parallelizing operations. 00:12:00.180 |
One thing that you might think of with attention 00:12:02.660 |
is that it's sort of performing kind of a fuzzy lookup 00:12:07.180 |
So you have a bunch of keys, a bunch of values, 00:12:09.540 |
and it's going to help you sort of access that. 00:12:14.300 |
like a dictionary in Python, for example, very simple. 00:12:18.220 |
You have a table of keys that each key maps to a value. 00:12:39.500 |
And in attention, so just like we saw before, 00:13:00.020 |
The query, to different extents, is similar to each of the keys. 00:13:04.500 |
And you will sort of measure that similarity between 0 and 1 00:13:12.380 |
So you average them via the weights of the similarity 00:13:21.660 |
So it really is quite a bit like a lookup table, 00:13:24.780 |
but in this sort of soft vector space, mushy sort of sense. 00:13:32.140 |
into this information that's stored in the key value store. 00:13:35.820 |
But I'm sort of softly looking at all of the results. 00:13:48.620 |
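(As a rough illustration of that fuzzy-lookup idea, here is a minimal PyTorch sketch; the keys, values, and query are made-up toy numbers, not anything from the lecture.)

```python
import torch

# A hard lookup: the query matches exactly one key and returns exactly its value.
table = {"k1": 10.0, "k2": 20.0, "k3": 30.0}
value = table["k2"]                           # -> 20.0

# A soft "lookup": the query matches every key to some degree between 0 and 1,
# and the result is a weighted average of all the values.
query  = torch.tensor([0.1, 0.9])
keys   = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
values = torch.tensor([10.0, 20.0, 30.0])

weights = torch.softmax(keys @ query, dim=0)  # similarities, normalized to sum to 1
output  = weights @ values                    # softly mixes all three values
```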
So if I was trying to represent this sentence, 00:13:54.260 |
So I'm trying to build a representation of learned. 00:14:01.580 |
So this is this self-attention thing that we'll get into. 00:14:04.500 |
I have a key for each word, a value for each word. 00:14:08.380 |
And I've got these sort of tealish bars up top, 00:14:22.700 |
And then learned, maybe that's important to representing 00:14:25.900 |
So you sort of look across at the whole sentence 00:14:28.100 |
and build up this sort of soft accessing of information 00:14:31.020 |
across the sentence in order to represent learned in context. 00:14:40.460 |
So we're going to look at a sequence of words. 00:14:43.540 |
So that's w1 to n, a sequence of words in a vocabulary. 00:14:46.860 |
So this is like, you know, Zuko made his uncle tea. 00:14:57.780 |
I embed each word with an embedding matrix E that goes from the vocabulary size to the dimensionality d. 00:15:07.500 |
And now I'm going to transform each word with one of three learnable weight matrices. 00:15:12.500 |
So this is often called key query value self-attention. 00:15:19.980 |
So this maps xi, which is a vector of dimensionality d, 00:15:31.260 |
shuffles it around, stretches it, squishes it. 00:15:35.940 |
And now for a different learnable parameter, k-- 00:15:47.100 |
So I'm taking each of the non-contextual word embeddings, 00:15:49.640 |
each of these xi's, and I'm transforming each of them 00:15:53.300 |
to come up with my query for that word, my key for that word, and my value for that word. 00:16:03.700 |
Next, I'm going to compute all pairs of similarities 00:16:10.500 |
computing the similarity between a single query for the word 00:16:13.260 |
learned and all of the keys for the entire sentence. 00:16:17.380 |
In this context, I'm computing all pairs of similarities 00:16:20.380 |
between all queries and all keys, because I want to represent every word in the sentence. 00:16:27.620 |
I'm just going to take the dot product between the query and the key: eij equals qi dot kj. 00:17:06.020 |
Then I softmax those scores, so alpha ij is exp(eij) divided by the sum of exp(eij') over all of the possible j prime in the sequence. 00:17:08.220 |
And then my output is just the weighted sum of values. 00:17:18.420 |
So I'm representing word i as the sum over j of these weights alpha ij times the value vectors vj. 00:17:39.380 |
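(To make those equations concrete, here is a minimal PyTorch sketch of key-query-value self-attention for a single position i; the dimensions and random weight matrices are illustrative assumptions, not the course's reference code.)

```python
import torch

d, n = 8, 5                      # model dimensionality and sequence length (assumed)
x = torch.randn(n, d)            # non-contextual word embeddings x_1 ... x_n

Q = torch.randn(d, d)            # learnable query transformation (random stand-in)
K = torch.randn(d, d)            # learnable key transformation
V = torch.randn(d, d)            # learnable value transformation

q, k, v = x @ Q, x @ K, x @ V    # row i holds word i's query, key, and value

i = 2                                   # build the in-context representation of word i
e_i = k @ q[i]                          # e_ij = q_i . k_j for every j in the sequence
alpha_i = torch.softmax(e_i, dim=0)     # weights between 0 and 1 that sum to 1
output_i = alpha_i @ v                  # weighted sum of the value vectors
```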
Oh, Wi, you can either think of it as a symbol in vocab v. 00:17:44.900 |
So that's like, you could think of it as a one-hot vector. 00:17:47.660 |
And yeah, in this case, we are, I guess, thinking of it as-- 00:17:51.220 |
so one-hot vector in dimensionality size of vocab. 00:17:54.380 |
So in the matrix E, you see that it's in R^(d x |V|), so d by the vocabulary size. 00:18:05.100 |
taking E, which is d by v, multiplying it by w, which is v, 00:18:10.540 |
and returning a vector that's dimensionality d. 00:18:20.700 |
like a column for every word in that sentence. 00:18:25.820 |
Yeah, usually, I guess we think of it as having a-- 00:18:28.460 |
I mean, if I'm putting the sequence length index first, 00:18:31.780 |
you might think of it as having a row for each word. 00:18:33.980 |
But similarly, yeah, it's n, which is the sequence length. 00:18:40.740 |
And then that gets mapped to this thing, which 00:18:46.380 |
Why do we learn two different matrices, q and k, 00:18:51.420 |
qi transpose kj is really just one matrix in the middle? 00:18:59.500 |
You can think of it as being a low-rank approximation to that single matrix. 00:19:02.060 |
So it is for computational efficiency reasons. 00:19:11.300 |
What you're effectively doing is having a very low-rank approximation to Q K transpose. 00:19:27.620 |
This eii, so the query of the word dotted with the key of that same word-- 00:19:40.500 |
does eii, for j equal to i, so looking at itself, get special treatment? 00:20:07.460 |
If I just used the xi's, that would be sort of a dot product with yourself, which is going to be large. 00:20:18.460 |
But because the query and key come from different learned transformations, qi and ki can be sort of arbitrarily different from each other, 00:20:27.060 |
so the model can learn, for example, not to look at yourself, 00:20:35.860 |
or more generally whether you should be looking at yourself or not. 00:20:42.820 |
That's a flexibility that wouldn't be there if I just used xis everywhere. 00:21:10.260 |
And so what we're going to do for this portion of the lecture 00:21:14.740 |
is look at the barriers and fixes that we need in order to use self-attention as sort of a building block. 00:21:21.540 |
So we can't use it as it stands as I've presented it; there are a few problems. 00:21:29.340 |
One of them is that there's no notion of sequence order built into self-attention. 00:21:40.140 |
I'm going to move over here to the whiteboard briefly, 00:21:46.820 |
If I have a sentence like, Zuko made his uncle. 00:22:08.700 |
using its embedding matrix, the embedding matrix 00:22:23.900 |
and there's a lot more on this in the lecture notes that 00:22:32.260 |
will give you exactly the same representations 00:22:34.380 |
for this sequence, Zuko made his uncle, as for this sequence, 00:22:52.660 |
and nowhere does the exact position of the words enter the computation. 00:22:59.140 |
So we're going to encode the position of words 00:23:02.540 |
through the keys, queries, and values that we have. 00:23:06.100 |
So consider now representing each sequence index-- 00:23:13.380 |
So don't worry so far about how it's being made, 00:23:17.220 |
but you can imagine representing the number 1, 00:23:20.300 |
the position 1, the position 2, the position 3, 00:23:25.580 |
just like we're representing our keys, queries, and values. 00:23:33.060 |
If you were to want to incorporate the information 00:23:37.940 |
represented by these positions into our self-attention, you can just add these position vectors to the input embeddings. 00:23:48.020 |
So if I have this xi embedding of a word, which 00:23:53.340 |
is the word at position i, but really just represents, 00:23:56.060 |
oh, the word zuko is here, now I can say, oh, it's the word 00:23:59.100 |
zuko, and it's at position 5, because this vector now represents both the word and its position. 00:24:12.860 |
So we can do it once at the very input to the network, 00:24:26.060 |
is look at these sinusoidal position representations. 00:24:29.500 |
So this looks a little bit like this, where you have-- 00:24:32.100 |
so this is a vector p i, which is in dimensionality d. 00:24:35.980 |
And in each one of the dimensions, you take the position index i, 00:24:40.500 |
you modify it by some constant, and you pass it through a sine or a cosine, 00:24:49.900 |
with differing periods depending on the dimension. 00:24:54.340 |
So I've got this sort of a representation of a matrix, 00:25:01.500 |
And you can see that there's sort of like, oh, as I walk 00:25:05.580 |
along, you see the period of the sine function going up and down, 00:25:08.300 |
and each of the dimensions d has a different period. 00:25:11.220 |
And so together, you can represent a bunch of different 00:25:15.140 |
And it gives this intuition that, oh, maybe sort 00:25:20.020 |
of the absolute position of a word isn't as important. 00:25:26.220 |
And maybe that allows you to extrapolate to longer sequences. 00:25:34.220 |
So it's a sinusoidal baseline that is still sometimes used for how to represent position. 00:25:45.180 |
You might think it's a little bit complicated, though. 00:25:50.300 |
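(A small sketch of how those sinusoidal position vectors might be built, assuming the alternating sin/cos layout from the original transformer paper and an even dimensionality d.)

```python
import torch

def sinusoidal_positions(n: int, d: int) -> torch.Tensor:
    """Rows are positions 0..n-1; even columns are sines, odd columns are cosines."""
    positions = torch.arange(n, dtype=torch.float32).unsqueeze(1)   # (n, 1)
    dims = torch.arange(0, d, 2, dtype=torch.float32)               # even dimension indices
    periods = 10000.0 ** (dims / d)                                 # period grows with the dimension
    p = torch.zeros(n, d)
    p[:, 0::2] = torch.sin(positions / periods)
    p[:, 1::2] = torch.cos(positions / periods)
    return p

pos = sinusoidal_positions(n=50, d=128)   # one d-dimensional vector per position
```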
Here's something that feels a little bit more deep learning. 00:26:05.500 |
And I'm going to learn it as a parameter, just like I learn all my other parameters. 00:26:13.100 |
So you just sort of add this matrix to the xi's, 00:26:26.300 |
that's linear, sort of index-based that you want, 00:26:31.460 |
And the cons are that, well, you definitely now 00:26:33.900 |
can't represent anything that's longer than n words long, right? 00:26:37.980 |
No sequence longer than n you can handle because, well, 00:26:41.660 |
you only learned a matrix of this many positions. 00:26:47.660 |
if you pass a self-attention model something longer than n, 00:26:51.620 |
it will just sort of crash and say, I can't do this. 00:26:56.620 |
And so this is sort of what most systems nowadays use. 00:26:59.660 |
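(Here is a minimal sketch of that learned-position approach; the vocabulary size and maximum length are arbitrary assumptions, not values from the lecture.)

```python
import torch
import torch.nn as nn

max_len, d = 512, 128                       # n: the longest sequence we can handle

word_emb = nn.Embedding(num_embeddings=10000, embedding_dim=d)    # word embedding matrix E
pos_emb  = nn.Embedding(num_embeddings=max_len, embedding_dim=d)  # learned matrix of positions

tokens = torch.randint(0, 10000, (1, 37))                # a batch with one 37-word sequence
positions = torch.arange(tokens.shape[1]).unsqueeze(0)   # [[0, 1, ..., 36]]
x = word_emb(tokens) + pos_emb(positions)                # just add the position vectors
# A sequence longer than max_len would index past the position table and fail.
```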
There are more flexible representations of position, 00:27:04.940 |
You might want to look at the relative linear position, 00:27:16.660 |
Because, oh, maybe words that are close in the dependency 00:27:19.100 |
parse tree should be the things that are sort of close 00:27:28.340 |
In practice, do we typically just make n large enough 00:27:32.420 |
that we don't run into the issue of having something 00:27:39.060 |
So the question is, in practice, do we just make n long enough 00:27:41.900 |
that we don't run into the problem where we're going 00:27:46.660 |
No, in practice, it's actually quite a problem. 00:27:49.420 |
Even today, even in the largest, biggest language models, people worry about context length-- 00:27:52.540 |
can I fit this prompt into ChatGPT or whatever? 00:28:01.980 |
And part of it is because the self-attention operation-- 00:28:04.980 |
and we'll get into this later in the lecture-- 00:28:06.900 |
it's quadratic complexity in the sequence length. 00:28:10.060 |
So you're going to spend n squared memory budget in order 00:28:15.740 |
So in practice, this might be on a large model, say, 4,000 or so. 00:28:21.420 |
n is 4,000, so you can fit 4,000 words, which feels like a lot, 00:28:29.740 |
And there are models that do longer sequences, for sure. 00:28:49.700 |
So how do you know that the p that you've learned, 00:28:51.580 |
this matrix that you've learned, is representing position 00:28:55.540 |
And the reason is the only thing it correlates with is position. 00:28:58.700 |
So when I see these vectors, I'm adding this p matrix 00:29:21.260 |
This vector at index 1 is always at index 1 for every example, 00:29:34.900 |
But it definitely allows you to know, oh, this word is at this position. 00:29:42.700 |
Just quickly, when you say quadratic constant in space, 00:29:47.100 |
is a sequence right now defined as a sequence? 00:29:51.180 |
Or I'm trying to figure out what unit is using it. 00:29:57.800 |
So the question is, when this is quadratic in the sequence, 00:30:12.100 |
maybe for an entire paragraph, or an entire document, 00:30:30.540 |
Another is that, based on the presentation of self-attention 00:30:34.060 |
that we've done, there's really no nonlinearities 00:30:38.940 |
We're just computing weighted averages of stuff. 00:30:43.420 |
So if I apply self-attention, and then apply self-attention again on top of that, I'm just re-averaging. 00:30:58.240 |
So you're computing averages of value vectors, 00:31:00.940 |
and it ends up looking like one big self-attention. 00:31:17.460 |
in this case, I'm calling it a multilayer perceptron MLP. 00:31:20.380 |
So this is a vector in Rd that's going to be-- 00:31:26.420 |
And you do the usual multilayer perceptron thing, 00:31:30.020 |
where you have the output, and you multiply it by a matrix, 00:31:32.460 |
pass it through a nonlinearity, multiply it by another matrix. 00:31:36.020 |
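(That matrix, nonlinearity, matrix pattern might look like the following sketch; the hidden width d_ff is an assumption, though a few times d is a common choice.)

```python
import torch.nn as nn

d, d_ff = 512, 2048        # model dimensionality and MLP hidden size (assumed)

feed_forward = nn.Sequential(
    nn.Linear(d, d_ff),    # multiply by a matrix
    nn.ReLU(),             # pass it through a nonlinearity
    nn.Linear(d_ff, d),    # multiply by another matrix
)

# Applied position-wise: the same MLP processes each word's self-attention
# output independently, e.g. feed_forward(x) for x of shape (n, d).
```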
And so what this looks like in self-attention 00:31:38.300 |
is that I've got this sentence, the chef who-- 00:31:44.040 |
I pass it through this whole big self-attention block, which 00:31:55.300 |
So this embedding, that's the output of the self-attention 00:32:03.540 |
And you can think of it as combining together or processing 00:32:11.600 |
So there's a number of reasons why we do this. 00:32:16.020 |
stack a ton of computation into these feed-forward networks 00:32:27.280 |
can pass it through this position-wise feed-forward 00:32:34.000 |
by this feed-forward network to process the result. 00:32:40.360 |
So that's adding our classical deep learning nonlinearities 00:32:45.980 |
And that's an easy fix for this no-nonlinearities problem. 00:32:52.860 |
There's one more thing we need before we have our final minimal self-attention building block. 00:33:02.100 |
So far, I've been writing out all of these examples of self-attention where every word looks at the whole sequence. 00:33:07.700 |
And in practice, for some tasks, such as machine translation or language modeling, where you 00:33:13.940 |
want to define a probability distribution over a sequence, you can't look at the future. 00:33:18.820 |
So at every time step, I could define the set 00:33:24.860 |
of keys and queries and values to only include past words. 00:33:31.780 |
It's inefficient because you can't parallelize it so well. 00:33:34.660 |
So instead, we compute the entire n by n matrix, 00:33:37.940 |
just like I showed in the slide discussing self-attention. 00:33:45.500 |
and I computed eij for all n by n pairs of words-- eij 00:33:49.900 |
is equal to whatever it was before if the word that you're 00:33:54.820 |
looking at, index j, is an index that is less than or equal to i, the word you're representing. 00:34:01.740 |
And it's equal to negative infinity-ish otherwise. 00:34:06.220 |
And when you softmax the eij, negative infinity goes to 0, so those future words get no weight. 00:34:20.020 |
So in order to encode these words, the chef who-- 00:34:32.700 |
I negative infinity out the words I can't look at. 00:34:39.260 |
When encoding the, I can look at the start symbol and the. 00:34:43.220 |
When encoding chef, I can look at start the chef. 00:34:48.960 |
And so with this representation of chef that is only looking 00:34:53.780 |
at start the chef, I can define a probability distribution 00:34:57.940 |
using this vector that allows me to predict who 00:35:01.100 |
without having cheated by already looking ahead at it. 00:35:26.700 |
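(A small sketch of that masking trick: build the full n-by-n score matrix, set the future entries to negative infinity, and let the softmax zero them out. The tensors here are random placeholders.)

```python
import torch

n, d = 5, 16
q, k, v = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)

scores = q @ k.T                                           # (n, n): all-pairs e_ij
causal = torch.tril(torch.ones(n, n, dtype=torch.bool))    # True where j <= i
scores = scores.masked_fill(~causal, float("-inf"))        # e_ij = -inf for future words
alpha = torch.softmax(scores, dim=-1)                      # -inf entries become weight 0
output = alpha @ v                                         # word i only sees words 1..i
```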
So this is the distinction between a bidirectional LSTM 00:35:40.180 |
sentence of your machine translation problem, 00:35:48.800 |
because you have this autoregressive probability 00:35:51.620 |
of word one, probability of two given one, three given one and two, and so on. 00:35:56.340 |
So traditionally, yes, in decoders, you will use it. 00:36:10.780 |
by having some notion of the probability of future words 00:36:19.020 |
or before they choose the words that they are currently 00:36:27.520 |
So the question is, isn't looking ahead a little bit 00:36:30.460 |
and predicting or getting an idea of the words 00:36:32.900 |
that you might say in the future sort of how humans generate 00:36:35.460 |
language instead of the strict constraint of not 00:36:51.180 |
I can't-- if I'm teaching it to try to predict the next word, 00:37:01.680 |
maybe it would be a good idea to make some guesses 00:37:03.980 |
far into the future or have a high-level plan or something. 00:37:07.700 |
But in training the network, I can't encode that intuition 00:37:17.020 |
directly, at least, because then it's just too easy. 00:37:22.380 |
But there might be interesting ideas about maybe giving 00:37:32.220 |
So I understand why we want to mask the future for stuff 00:37:50.020 |
In machine translation, I have a sentence like, 00:38:16.380 |
And so I want self-attention without masking, 00:38:22.820 |
because I want "I" to look at "like" and "I" to look at 00:38:40.600 |
And we'll talk about encoder-decoder architectures 00:38:46.120 |
But I want to be able to look at myself, none of the future, 00:38:50.020 |
And so what I'm talking about right now in this masking case 00:38:52.780 |
is masking out with negative infinity all of these words. 00:38:59.380 |
So that attention score from "Je" to everything else 00:39:19.220 |
and this is my personal opinion-- a minimal self-attention 00:39:22.900 |
You have self-attention, the basis of the method. 00:39:29.260 |
And maybe we had the inputs to the sequence here. 00:39:31.980 |
And then you embed it with that embedding matrix E. 00:39:38.580 |
using the key, the value, and the query that's 00:39:54.980 |
because otherwise you'd have no idea what order the words 00:39:59.260 |
You have the nonlinearities in sort of the feed-forward 00:40:01.940 |
network there to sort of provide that sort of squashing 00:40:20.180 |
so you have this thing-- maybe you repeat this sort of 00:40:24.460 |
So self-attention, feedforward, self-attention, feedforward, 00:40:31.140 |
And then maybe at the end of it, you predict something. 00:40:38.500 |
or you predict the sentiment, or you predict whatever. 00:40:40.740 |
So this is like a self-attention architecture. 00:40:44.640 |
OK, we're going to move on to the transformer next. 00:40:56.140 |
where I want to decode out a sequence where I have 00:41:02.380 |
to represent this word properly, I cannot have 00:41:20.660 |
I call a minimal self-attention architecture. 00:41:30.820 |
that was just up on the slide, the previous slide. 00:41:38.660 |
that we'll talk about now that goes into the transformer. 00:41:41.700 |
What I would hope, though, to have you take away from this is that the transformer is not necessarily 00:41:51.420 |
the end point of our search for better and better ways 00:41:54.580 |
of representing language, even though it's now ubiquitous-- there are still open questions, 00:42:05.020 |
and maybe ways of fixing some of the issues with transformers. 00:42:08.940 |
OK, so a transformer decoder is how we'll build systems 00:42:15.740 |
It's like our decoder-style, self-attention-only sort of architecture 00:42:31.660 |
with masking, except we swap in masked multi-head self-attention. 00:42:39.820 |
between the transformer and this sort of minimal architecture 00:42:43.820 |
So let's come back to our toy example of attention, 00:42:46.740 |
where we've been trying to represent the word learned 00:42:49.060 |
in the context of the sequence, I went to Stanford CS224N 00:42:54.740 |
And I was sort of giving these teal bars to say, 00:42:57.580 |
oh, maybe intuitively you look at various things 00:43:09.660 |
to see varying sort of aspects of information 00:43:13.660 |
that I want to incorporate into my representation. 00:43:19.340 |
to look at Stanford CS224N, because, oh, it's like entities. 00:43:28.300 |
than you do at other courses or other universities or whatever. 00:43:31.940 |
And so maybe I want to look here for this reason. 00:43:43.380 |
And I want to see maybe syntactically relevant words. 00:43:55.020 |
ends up being maybe somewhat too difficult in a way 00:44:14.340 |
Yeah, so it should be an application of attention 00:44:19.140 |
So one head independently defines the keys, the queries, and the values and applies attention, 00:44:19.140 |
and then I do it again with different parameters, 00:44:27.460 |
being able to look at different things, et cetera. 00:44:36.220 |
How do we ensure that they look at different things? 00:44:40.060 |
if we have two separate sets of weights trying to learn, 00:44:41.980 |
say, to do this and to do that, how do we ensure that they 00:44:45.940 |
We do not ensure that they learn different things. 00:44:49.100 |
And in practice, they do, although not perfectly. 00:44:52.780 |
So it ends up being the case that you have some redundancy, 00:44:59.140 |
But we hope, just like we hope that different dimensions 00:45:02.300 |
in our feedforward layers will learn different things 00:45:06.740 |
so too we hope that the heads will start to specialize. 00:45:09.380 |
And that will mean they'll specialize even more. 00:45:16.340 |
All right, so in order to discuss multi-head self 00:45:18.620 |
attention well, we really need to talk about the matrices, 00:45:22.100 |
how we're going to implement this in GPUs efficiently. 00:45:25.220 |
We're going to talk about the sequence-stacked form 00:45:29.260 |
So we've been talking about each word sort of individually 00:45:31.660 |
as a vector in dimensionality D. But really, we're 00:45:35.140 |
going to be working on these as big matrices that are stacked. 00:45:38.900 |
So I take all of my word embeddings, x1 to xn, 00:45:43.900 |
And now I have a big matrix that is in dimensionality Rn by D. 00:45:58.220 |
So X is in R^(n x d), and K is in R^(d x d). So n by d times d by d gives me XK, which is n by d. 00:46:10.460 |
That lets me, on my whole sequence, multiply each one of the words 00:46:13.340 |
by each of my key, query, and value matrices very efficiently. 00:46:35.260 |
So first, we're going to take the key-query dot products in one matrix multiply: XQ times XK transpose, which is n by n. Then we softmax that and multiply by XV. 00:47:20.700 |
So this is an n by n matrix multiplied by an n by d matrix. 00:47:26.900 |
Well, this is just doing the weighted average 00:47:34.220 |
on the whole matrix, giving me my whole self-attention output 00:47:37.380 |
in R^(n x d): output = softmax(XQ (XK)^T) XV. So I've just restated identically what we had before 00:47:43.980 |
in terms of matrices so that you could do this efficiently on a GPU. 00:47:55.500 |
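(That whole sequence-stacked computation, softmax(XQ (XK)^T) XV, might look like this sketch; the shapes and random matrices are illustrative assumptions.)

```python
import torch

n, d = 10, 64
X = torch.randn(n, d)                   # all word embeddings stacked, R^(n x d)
Q, K, V = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)

XQ, XK, XV = X @ Q, X @ K, X @ V        # each is (n, d): one big multiply per matrix
scores = XQ @ XK.T                      # (n, n): all pairs of query-key dot products
alpha = torch.softmax(scores, dim=-1)   # softmax over each row
output = alpha @ XV                     # (n, n) times (n, d) -> (n, d) weighted averages
```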
and it's going to be important to compute this 00:48:01.000 |
to look in multiple places at once for different reasons. 00:48:04.220 |
So for self-attention looks where this dot product here 00:48:15.380 |
But maybe we want to look in different places 00:48:19.300 |
So we actually define multiple query, key, and value matrices. 00:48:30.260 |
And for each head, I'm going to define an independent query, key, and value matrix. 00:48:41.300 |
So each one of these is doing a projection down to a lower dimension. 00:48:45.340 |
This is going to be for computational efficiency. 00:48:53.300 |
So this equation here is identical to the one 00:48:58.540 |
except I've got these sort of l indices everywhere. 00:49:06.540 |
And then I do have my lower dimensional value vector 00:49:11.900 |
But really, you're doing exactly the same kind of operation. 00:49:19.700 |
So I've done sort of look in different places 00:49:22.140 |
with the different key, query, and value matrices. 00:49:33.540 |
And I concatenate them together and then sort of mix them 00:49:36.140 |
together with the final linear transformation. 00:49:39.820 |
And so each head gets to look at different things 00:49:43.040 |
and construct their value vectors differently. 00:49:45.420 |
And then I sort of combine the result all together at once. 00:49:55.820 |
It's actually not more costly to do this, really, 00:49:58.540 |
than it is to compute a single head of self-attention. 00:50:16.260 |
And then we can reshape it into R^n, that's sequence length, 00:50:24.500 |
times the number of heads, times the model dimensionality divided by the number of heads. 00:50:39.420 |
The third axis is this reduced model dimensionality, d over h. 00:50:45.940 |
And then I transpose so that I've got the head axis out front, almost like a batch axis. 00:50:50.740 |
And now I can compute all my other operations head by head in parallel. 00:51:05.100 |
So in the picture, instead of one XQ of full model dimensionality d, I've got, in this case, 00:51:08.180 |
three XQ matrices, each producing vectors of dimensionality d over 3. 00:51:29.060 |
And the cost is that, well, each of my attention heads 00:51:33.420 |
has only a d-over-h-dimensional vector to work with instead of the full d dimensions. 00:51:47.100 |
And then I have three value matrices there as well, 00:51:52.660 |
And then finally, I get my three different output vectors. 00:52:04.620 |
And the intuition is exactly what I gave in the toy example, which is that each head gets to 00:52:04.620 |
look at different parts of a sequence for different reasons. 00:52:10.240 |
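(A rough sketch of that reshape-and-transpose view of multi-head attention; the head count and dimensions are assumptions, and the output projection is shown as a random matrix rather than a learned parameter.)

```python
import torch

n, d, h = 10, 64, 4                     # sequence length, model dim, number of heads
X = torch.randn(n, d)
Q, K, V = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)

def split_heads(M):                     # (n, d) -> (h, n, d/h): head axis acts like a batch
    return M.view(n, h, d // h).transpose(0, 1)

q, k, v = split_heads(X @ Q), split_heads(X @ K), split_heads(X @ V)

scores = q @ k.transpose(-2, -1)        # (h, n, n): each head gets its own score matrix
alpha = torch.softmax(scores, dim=-1)
heads = alpha @ v                       # (h, n, d/h): each head's weighted averages

# Concatenate the heads back together and mix them with a final linear map.
out = heads.transpose(0, 1).reshape(n, d)
out = out @ torch.randn(d, d)           # output projection (learned in practice)
```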
All of these attention heads are for a given transformer block. 00:52:23.780 |
A next block could also have three attention heads. 00:52:26.980 |
The question is, are all of these for a given block? 00:52:31.800 |
But this block was this sort of pair of self-attention 00:52:41.340 |
And the question is, are the parameters shared 00:52:46.180 |
You'll have independent parameters at every block, 00:52:55.380 |
that you have the same number of heads at each block? 00:52:58.820 |
Or do you vary the number of heads across blocks? 00:53:01.380 |
You have this-- you definitely could vary it. 00:53:05.540 |
so the question is, do you have different numbers of heads 00:53:09.340 |
Or do you have the same number of heads across all blocks? 00:53:14.540 |
it be the same everywhere, which is what people have done. 00:53:16.940 |
I haven't yet found a good reason to vary it, 00:53:21.860 |
It's definitely the case that after training these networks, 00:53:30.900 |
And I'd be curious to know if you could remove more or less, 00:53:35.700 |
depending on the layer index, which might then say, 00:53:43.700 |
So people tend to instead set the number of heads 00:53:46.740 |
to be roughly so that you have a reasonable number of dimensions 00:53:51.180 |
per head, given the total model dimensionality d that you want. 00:53:55.460 |
So for example, I might want at least 64 dimensions per head, 00:54:00.260 |
which if d is 128, that tells me how many heads 00:54:09.620 |
Yeah, with that xq, by slicing it into different columns, 00:54:15.820 |
you're reducing the rank of the final matrix, right? 00:54:19.820 |
But that doesn't really have any effect on the results. 00:54:23.180 |
So the question is, by having these reduced xq and xk 00:54:29.300 |
matrices, this is a very low rank approximation. 00:54:35.340 |
defining this whole big matrix, it's very low rank. 00:54:42.700 |
That's part of why we limit the number of heads depending on the model 00:54:45.820 |
dimensionality, because you want, intuitively, at least a certain number of dimensions per head. 00:54:51.220 |
So 64 is sometimes done, 128, something like that. 00:54:55.820 |
But if you're not giving each head too much to do, 00:54:58.300 |
and it's got sort of a simple job, you've got a lot of heads, 00:55:16.300 |
to see if information in one of the sets of the attention 00:55:25.100 |
learns is consistent and related to each other, 00:55:32.380 |
So the question is, have there been studies to see 00:55:42.580 |
and interpretability and analysis of these models 00:55:44.980 |
to try to figure out what roles, what sort of mechanistic roles 00:55:55.140 |
learning to pick out the syntactic dependencies, 00:55:59.820 |
or maybe doing a global averaging of context. 00:56:09.620 |
it's unclear if you look at a word 10 layers deep 00:56:17.060 |
from everyone else, and it's a little bit unclear. 00:56:26.780 |
But yeah, if you want to talk more about it, I'm happy to. 00:56:31.420 |
So another sort of hack that I'm going to toss in here-- 00:56:36.300 |
but it's a nice little method to improve things. 00:56:42.060 |
So one of the issues with this sort of key query value 00:56:45.500 |
self-attention is that when the model dimensionality becomes 00:56:47.940 |
large, the dot products between vectors, even random vectors, tend to get large. 00:56:55.060 |
And when that happens, the inputs to the softmax function 00:56:58.060 |
can be very large, making the gradient small. 00:57:01.380 |
So intuitively, if you have two random vectors 00:57:03.300 |
and model dimensionality d, and you just dot product them 00:57:13.940 |
with everyone's attention being very uniform, very flat, 00:57:23.660 |
And so what you end up doing is you just sort of-- 00:57:26.300 |
for each of your heads, you just sort of divide all the scores by the square root of d over h, so that even as d grows, 00:57:35.660 |
their dot products don't, at least at initialization time. 00:57:40.500 |
So this is sort of like a nice little important, 00:57:52.260 |
And so that's called scaled dot product attention. 00:57:55.500 |
From here on out, we'll just assume that we do this. 00:57:59.540 |
You just do a little division in all of your computations. 00:58:12.660 |
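(So scaled dot product attention is just that one extra division; a minimal sketch, with all tensors as placeholders:)

```python
import math
import torch

def scaled_attention(q, k, v):
    d_head = q.shape[-1]                                   # dimensionality per head
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)   # divide before the softmax
    return torch.softmax(scores, dim=-1) @ v

out = scaled_attention(torch.randn(5, 16), torch.randn(5, 16), torch.randn(5, 16))
```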
We have two big optimization tricks, or optimization 00:58:17.060 |
are quite important, that end up being very important. 00:58:20.100 |
We've got residual connections and layer normalization. 00:58:22.940 |
And in transformer diagrams that you see sort of around the web, 00:58:26.980 |
they're often written together as this add and norm box. 00:58:38.620 |
and then do this sort of optimization add a norm. 00:58:47.660 |
So let's go over these two individual components. 00:58:53.300 |
I mean, I think we've talked about residual connections 00:58:58.140 |
But it's really a good trick to help models train better. 00:59:04.660 |
instead of having this sort of-- you have a layer, layer i 00:59:16.380 |
I'm going to add the result of layer i to its input here. 00:59:23.060 |
So now I'm saying I'm just going to compute the layer, 00:59:25.320 |
and I'm going to add in the input to the layer 00:59:39.860 |
And you should think that the gradient is just 00:59:41.740 |
really great through the residual connection. 00:59:43.820 |
Like, ah, if I've got vanishing or exploding gradient-- 00:59:49.500 |
well, I can at least learn everything behind it 00:59:51.740 |
because I've got this residual connection where 01:00:03.980 |
looks a little bit like the identity function now, right? 01:00:08.920 |
is somewhat small because all of your weights are small, 01:00:15.620 |
like the identity, which might be a good sort of place 01:00:28.000 |
trying to traverse the mountains of the loss landscape. 01:00:28.000 |
and you can't sort of find your way to get out. 01:01:02.860 |
So layer norm is another thing to help your model train 01:01:08.380 |
And the intuitions around layer normalization 01:01:14.760 |
and sort of the empiricism of it working very well 01:01:17.040 |
maybe aren't perfectly, let's say, connected. 01:01:25.700 |
that we want to say this variation within each layer. 01:01:33.140 |
That's not actually informative because of variations 01:01:39.940 |
Or I've got sort of weird things going on in my layers 01:01:45.140 |
I haven't been able to sort of make everything behave sort 01:01:47.740 |
of nicely where everything stays roughly the same norm. 01:01:54.660 |
And I want to cut down on sort of uninformative variation 01:02:00.940 |
So I'm going to let x and rd be an individual word 01:02:05.380 |
So this is like I have a single index, one vector. 01:02:09.100 |
And what I'm going to try to do is just normalize it. 01:02:12.660 |
Normalize it in the sense of it's got a bunch of variation. 01:02:17.700 |
I'm going to normalize it to zero mean and unit standard deviation. 01:02:21.020 |
So I'm going to estimate the mean here across-- 01:02:40.220 |
I'm going to have my estimate of the standard deviation. 01:02:45.020 |
This is my simple estimate of the standard deviation 01:02:53.700 |
and then possibly, I guess I can have learned parameters 01:02:58.060 |
to try to scale it back out multiplicatively and shift it additively. 01:03:08.380 |
I'm going to take my vector x, subtract out the mean, and divide by the standard deviation. 01:03:17.900 |
So I'm going to have this epsilon there that's close to 0, in case the standard deviation is tiny. 01:03:21.700 |
So this part here, x minus mu over sigma (the square root of the variance) 01:03:25.500 |
plus epsilon, is saying take all the variation 01:03:28.540 |
and normalize it to zero mean and unit standard deviation. 01:03:32.600 |
And then maybe I want to scale it, stretch it back out, 01:03:37.080 |
and then maybe add an offset beta that I've learned. 01:03:40.860 |
Although in practice, actually, this part-- and discuss this 01:03:44.580 |
in practice, this part maybe isn't actually that important. 01:03:47.940 |
But so layer normalization, yeah, you're sort of-- 01:03:51.220 |
you can think of this as when I get the output of layer 01:03:55.940 |
sort of look nice and look similar to the next layer 01:04:00.940 |
going to be zero mean and unit standard deviation. 01:04:06.260 |
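(A sketch of that per-word normalization, following the formula described above with gamma and beta as the learned gain and offset; PyTorch's built-in nn.LayerNorm does the same thing with epsilon placed inside the square root.)

```python
import torch

def layer_norm(x, gamma, beta, eps=1e-5):
    # x: (n, d); statistics are computed per word, across its d dimensions only
    mu = x.mean(dim=-1, keepdim=True)                    # scalar mean per word, broadcast back
    sigma = x.std(dim=-1, keepdim=True, unbiased=False)  # scalar standard deviation per word
    return gamma * (x - mu) / (sigma + eps) + beta       # normalize, then scale and shift

d = 16
x = torch.randn(5, d)
out = layer_norm(x, gamma=torch.ones(d), beta=torch.zeros(d))
```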
OK, any questions for residual or layer norm? 01:04:13.220 |
What would it mean to subtract the scalar mu from the vector x? 01:04:18.780 |
When I subtract the scalar mu from the vector x, 01:04:21.580 |
I broadcast mu to dimensionality d and remove mu from all d. 01:04:43.620 |
In the fourth bullet point when you're calculating the mean, 01:05:00.760 |
So if you have five words in a sentence by their norm, 01:05:04.700 |
do you normalize based on the statistics of these five words 01:05:11.700 |
So the question is, if I have five words in the sequence, 01:05:14.500 |
do I normalize by aggregating the statistics to estimate mu 01:05:21.060 |
share their statistics, or do it independently for each word? 01:05:25.700 |
think in all the papers that discuss transformers 01:05:30.140 |
You do not share across the five words, which 01:05:35.500 |
So each of the five words is done completely independently. 01:05:54.360 |
For example, per batch or per output of the same position? 01:06:02.840 |
The question is, if you have a batch of sequences, 01:06:06.760 |
so just like we were doing batch-based training, 01:06:11.840 |
now, we don't share across the sequence index 01:06:13.760 |
for sharing the statistics, but do you share across the batch? 01:06:22.160 |
invented as a replacement for batch normalization, which 01:06:27.900 |
is that now your forward pass sort of depends in a way 01:06:30.960 |
that you don't like on examples that should be unrelated to it. 01:06:35.400 |
And so, yeah, you don't share statistics across the batch. 01:06:44.240 |
OK, so now we have our full transformer decoder, 01:06:50.520 |
So in this sort of slightly grayed out thing here 01:06:52.960 |
that says repeat for a number of decoder blocks, 01:07:08.400 |
I've got the layer normalization there, and then 01:07:10.780 |
a feed-forward layer, and then another add and norm. 01:07:18.040 |
I apply for some number of times, number of blocks. 01:07:21.720 |
So that whole thing is called a single block. 01:07:29.040 |
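(Putting the pieces together, one decoder block might be sketched like this, using PyTorch's built-in multi-head attention for brevity; the post-layer-norm add-and-norm ordering follows the description above, and all sizes are assumptions.)

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Masked multi-head self-attention, add & norm, feed-forward, add & norm."""
    def __init__(self, d: int, h: int, d_ff: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, h, batch_first=True)
        self.norm1 = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, d_ff), nn.ReLU(), nn.Linear(d_ff, d))
        self.norm2 = nn.LayerNorm(d)

    def forward(self, x):                        # x: (batch, n, d)
        n = x.shape[1]
        future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
        attn_out, _ = self.attn(x, x, x, attn_mask=future)  # mask out future positions
        x = self.norm1(x + attn_out)             # residual connection, then layer norm
        x = self.norm2(x + self.ff(x))           # residual connection, then layer norm
        return x

block = DecoderBlock(d=64, h=4, d_ff=256)
y = block(torch.randn(2, 10, 64))                # (batch=2, n=10, d=64) -> same shape
```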
Cool, so that's a whole architecture right there. 01:07:34.040 |
We've solved things like needing to represent position. 01:07:41.760 |
We've solved a lot of different optimization problems. 01:07:52.800 |
With the dot product scaling with the square root of d over h. 01:08:20.820 |
So you're going to maybe pad to a constant length. 01:08:24.740 |
And in order to not look at the future, the stuff that's 01:08:28.580 |
happening in the future, you can mask out the pad tokens, 01:08:32.720 |
just like the masking that we showed for not looking at the future. 01:08:36.520 |
You can just say, set all of the attention weights to 0 for the pad tokens, 01:08:48.280 |
So you can set everything to this maximum length. 01:08:53.820 |
do you set this length that you have everything 01:08:56.980 |
I mean, yes, often, although you can save computation 01:09:03.420 |
And everything-- the math all still works out. 01:09:06.140 |
You just have to code it properly so it can handle-- 01:09:12.260 |
is shorter than length 5, and you save a lot of computation. 01:09:15.260 |
All of the self-attention operations just work. 01:09:21.980 |
How many layers are in the feedforward normally? 01:09:25.340 |
There's one hidden layer in the feedforward usually. 01:09:30.060 |
We've got a couple more things and not very much time. 01:09:35.620 |
So in the encoder-- so the transformer encoder is exactly the same, except we drop the masking. 01:09:48.620 |
And so it's that easy to make the model bidirectional. 01:09:58.100 |
And then finally, we've got the transformer encoder decoder, 01:10:01.900 |
which is actually how the transformer was originally 01:10:04.060 |
presented in this paper, "Attention is All You Need." 01:10:07.900 |
And this is when we want to have a bidirectional network. 01:10:17.580 |
And I have a decoder to decode out my sentence. 01:10:22.140 |
Now, but you'll see that this is slightly more complicated. 01:10:35.220 |
I am going to use my decoder vectors as my queries. 01:10:41.460 |
Then I'll take the output of the encoder as my keys and values. 01:10:50.580 |
in the output of all of the blocks of the encoder. 01:10:54.940 |
How do we get a key and value separated from the output? 01:11:04.340 |
Because didn't we collapse those into the single output? 01:11:12.780 |
Like, how do we-- because when we have the output, 01:11:15.020 |
didn't we collapse the keys and values into a single output? 01:11:27.900 |
is just this weighted average of the value vectors 01:11:33.020 |
And then from that output for the next layer, 01:11:35.780 |
we apply a new key, query, and value transformation 01:11:38.700 |
to each of them for the next layer of self-attention. 01:11:44.280 |
Yeah, you apply the key matrix, the query matrix, 01:11:59.820 |
I'm going to call them the output of the encoder. 01:12:03.420 |
And then I've got vectors that are the output of the decoder. 01:12:10.860 |
And then I simply define my keys and my values 01:12:19.460 |
So I take the h's, I apply a key matrix and a value matrix, 01:12:22.860 |
and then I define the queries from my decoder. 01:12:26.540 |
So my queries here come from the decoder-- so this is why two of the arrows come from the encoder and one from the decoder. 01:12:32.540 |
I've got my z's here giving my queries, and my keys and values coming from the encoder outputs h. 01:12:45.540 |
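(A minimal sketch of that cross-attention wiring: keys and values from the encoder outputs h, queries from the decoder states z; all tensors here are random placeholders.)

```python
import torch

n_src, n_tgt, d = 7, 5, 64
h_enc = torch.randn(n_src, d)          # outputs of the encoder (the h vectors)
z_dec = torch.randn(n_tgt, d)          # states of the decoder (the z vectors)

K, V, Q = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)

keys    = h_enc @ K                    # k_i = K h_i   (from the encoder)
values  = h_enc @ V                    # v_i = V h_i   (from the encoder)
queries = z_dec @ Q                    # q_i = Q z_i   (from the decoder)

alpha = torch.softmax(queries @ keys.T, dim=-1)   # (n_tgt, n_src) attention weights
output = alpha @ values                           # each decoder position reads the source
```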
I want to discuss some of the results of transformers, 01:12:53.060 |
So really, the original results of transformers, 01:12:58.580 |
you can do way more computation because of parallelization. 01:13:02.340 |
They got great results in machine translation. 01:13:04.780 |
So you had-- you had transformers doing quite well, 01:13:13.060 |
although not astoundingly better than existing machine translation systems. 01:13:20.020 |
But they were significantly more efficient to train. 01:13:22.260 |
Because you don't have this parallelization problem, 01:13:25.380 |
you could compute on much more data much faster, 01:13:27.620 |
and you could make use of faster GPUs much more. 01:13:32.100 |
After that, there were things like document generation, 01:13:35.060 |
where you had the old standard of sequence-to-sequence models 01:13:39.060 |
And eventually, everything became transformers 01:13:47.340 |
into pre-training, which we'll go over next lecture, 01:13:54.660 |
allows you to compute on tons and tons of data. 01:13:58.340 |
And so after a certain point, on standard large benchmarks, 01:14:04.740 |
This ability to make use of lots and lots of data, 01:14:07.540 |
lots and lots of compute, just put transformers head and shoulders above the alternatives, and they're now behind 01:14:13.420 |
almost every modern advancement in natural language processing. 01:14:19.900 |
There are many drawbacks and variants to transformers. 01:14:25.820 |
tried to work on quite a bit is this quadratic compute 01:14:31.780 |
means that our total computation for each block 01:14:34.420 |
grows quadratically with the sequence length. 01:14:36.260 |
And in a student's question, we heard that, well, 01:14:41.580 |
if I want to process a whole Wikipedia article, 01:14:44.220 |
a whole novel, that becomes quite unfeasible. 01:14:48.100 |
And actually, that's a step backwards in some sense, 01:14:52.740 |
because for recurrent neural networks, it only grew linearly with the sequence length. 01:15:02.060 |
is not really the best way maybe to represent 01:15:07.940 |
And just to give you an intuition of quadratic sequence 01:15:10.300 |
length, remember that we had this big matrix multiply here 01:15:24.220 |
And so if you think of the model dimensionality 01:15:26.260 |
as like 1,000, although today it gets much larger, 01:15:29.340 |
then for a short sequence of n is roughly 30, 01:15:32.460 |
maybe if you're computing n squared times d, 30 squared isn't so bad. But for a long document, where n might be in the tens of thousands, 01:15:42.020 |
then n squared becomes huge and sort of totally infeasible. 01:15:46.540 |
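(A quick back-of-the-envelope version of that arithmetic, with d = 1,000 as in the example above:)

```python
d = 1_000                  # model dimensionality, as in the example above
for n in (30, 4_000, 50_000):
    print(n, n * n * d)    # work for the n x n score matrix grows as n^2 * d
# n = 30 is under a million multiply-adds; n in the tens of thousands is trillions.
```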
So people have tried to sort of map things down to a lower-dimensional space to avoid the quadratic cost. 01:15:59.700 |
But in practice, as models get very large, most of the computation doesn't show up in the self-attention anyway. 01:16:05.500 |
So it's not clear it's really necessary to get rid of the quadratic self-attention operations 01:16:14.780 |
And then finally, there have been a ton of modifications 01:16:17.460 |
to the transformer over the last five, four-ish years. 01:16:21.940 |
And it turns out that the original transformer 01:16:27.820 |
is pretty much the best thing there is still. 01:16:33.940 |
Changing out the nonlinearities in the feed-forward network is one change that does seem to help a bit. 01:16:43.140 |
But I think it's ripe for people to come through and think 01:16:46.340 |
about how to sort of improve it in various ways. 01:16:53.260 |
And then we'll have the project proposal documents out tonight