Stanford XCS224U: NLU | Contextual Word Representations, Part 2: Transformer | Spring 2023
This is part two in our series on contextual representations. I propose that we just dive into the core model structure, and I'm going to introduce that by way of a simple example. I've got that example at the bottom of the slide here, and I've paired each one of those tokens with a token representing its position in the string.

The first thing that we do in this model is look up each one of those tokens in its own embedding space. For the word embeddings, we look those up and get things like x47, which is a vector corresponding to the word "the". That representation is a static word representation that's very similar conceptually to what we had in the previous era with models like word2vec and GloVe. We do the same lookup for these positional tokens here and get their vector representations. Then, to combine them, we simply add them together dimension-wise to get the representations that I have in green here, which you could think of as the first contextual representations in this model.

On the right here, I've depicted that calculation for the C position. That's a pattern that I'm going to continue all the way up as we build the model; the calculations are entirely parallel for A and for B. To get C input, we simply add together x34 with p3.
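To make that combination concrete, here is a minimal sketch (my own illustration, not the course code) of the lookup-and-add step in PyTorch; the vocabulary size, number of positions, and dimensionality are assumptions chosen to match the BERT-style sizes that come up later in this lecture.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, roughly matching the model we inspect later:
vocab_size, num_positions, d_k = 30000, 512, 768

word_emb = nn.Embedding(vocab_size, d_k)    # static word embeddings (vectors like x47)
pos_emb = nn.Embedding(num_positions, d_k)  # learned positional embeddings (vectors like p3)

token_ids = torch.tensor([[47, 12, 34]])    # made-up ids for a three-token sequence
position_ids = torch.arange(token_ids.shape[1]).unsqueeze(0)  # positions 0, 1, 2

# Dimension-wise addition gives the "input" (green) representations:
x_input = word_emb(token_ids) + pos_emb(position_ids)
print(x_input.shape)  # torch.Size([1, 3, 768])
```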
Next comes the attention layer. This is the part of the model that gives rise to that famous paper title, "Attention is all you need". The reason the paper has that title is that, in the previous era with recurrent neural networks, attention mechanisms were layered on top of those recurrences to further connect everything to everything else. The insight behind the title is that you can get rid of those recurrent connections and keep the attention. The transformer has many other pieces, as you'll see, but the claim was, I believe, that you could drop the recurrent mechanisms.
The attention mechanism that the transformer uses is actually relatively simple: it is a dot product-based approach to attention. Here, I've depicted each dot product as a dot, with lines showing the components that feed into that calculation. This dot here corresponds to A input combined with C input, and this one to B input combined with C input; the two dots that are depicted here correspond to the two dot products that are in this numerator. One new thing that they did in the transformer paper is normalize those dot products by the square root of d_k.
Here, d_k is the shared dimensionality of all the representations that we have talked about so far. That's a really important element of the transformer: essentially everything has that same dimensionality. There is one exception to that, which I will return to a bit later. What the transformer authors found is that they got better scaling for the dot products when they normalized them by the square root of d_k in this way.
Those normalized dot products give us a new vector, which you could think of as attention scores. We softmax those scores to get a vector of attention weights, alpha. Then we take each component of this vector alpha, use it to weight each of the representations that we're attending to, and then we sum those weighted values together to get C attention.
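Here is a small sketch of that single-position calculation, written as a generic illustration of scaled dot-product attention rather than taken from the lecture's slides: dot products of C's input with the inputs it attends to, scaling by the square root of d_k, a softmax, and a weighted sum.

```python
import torch
import torch.nn.functional as F

d_k = 768
a_input, b_input, c_input = (torch.randn(d_k) for _ in range(3))  # stand-in vectors

# Dot products of the position we're updating (C) with the positions it attends to:
scores = torch.stack([a_input @ c_input, b_input @ c_input])

# Normalize by sqrt(d_k) and softmax to get the attention weights alpha:
alpha = F.softmax(scores / d_k**0.5, dim=-1)

# Weighted sum of the attended representations gives C's raw attention representation:
c_attn = alpha[0] * a_input + alpha[1] * b_input
print(alpha, c_attn.shape)
```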
This attention step is what gives us dense connections across all of these different states. I'm just showing you the calculations for C attention. That's important because all those lines that are now on the slide are really the only place at which we knit together these columns of representations, which would otherwise be operating independently of each other. This really gives us all the dense connections that we think are so powerful for transformer learning.
At this point, the representations that I have in orange are attention representations, but they're really just raw materials, because they're really just recording the similarity relationships among the input representations. To get an actual attention representation in the transformer, we add the orange representations to the green ones and apply layer normalization, and that gives us the representations in yellow, which are the full-fledged attention-based representations. I've depicted the calculation over here, and that includes a nice reminder that we actually apply dropout to the sum of the orange and the green. Dropout is a simple regularization technique that will help the model learn more diverse representations as part of its training. The layer normalization step, in turn, is simply going to help us with scaling the values.
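As a rough sketch of that add-dropout-normalize step, using standard PyTorch layers (the dropout rate here is an assumption, not a number from the lecture):

```python
import torch
import torch.nn as nn

d_k = 768
dropout = nn.Dropout(p=0.1)     # dropout rate is an assumed value
layer_norm = nn.LayerNorm(d_k)

c_input = torch.randn(d_k)      # green: word embedding + positional embedding
c_attn = torch.randn(d_k)       # orange: raw attention output for C

# Add the orange to the green, apply dropout to that sum, then layer-normalize:
c_a_norm = layer_norm(dropout(c_input + c_attn))
print(c_a_norm.shape)           # torch.Size([768])
```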
The next pieces are the feedforward components in the transformer. I have depicted them as a single representation in blue, but it's really important to see that this is actually a little feedforward network. We take CA norm and feed that through a dense layer with parameters W1 and b1, apply a non-linearity, and then feed the result through a second dense layer with parameters W2 and b2. This is important because many of the parameters for the transformer live in these feedforward layers. In fact, this is the one place where we could change the dimensionality: CA norm here has dimensionality d_k by design, but W1 can take us up to some larger dimensionality if we want, as long as the output of that goes back down to d_k. It's common to have really wide internal layers in this feedforward step. Then, of course, you have to collapse back down, and we collapse back down to d_k for CFF here. Then we have another addition, of CA norm with CFF, and finally we have a layer normalization step, which again helps with scaling the values and therefore helps the model learn more effectively.
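Here is a minimal sketch of that feedforward block, assuming a ReLU non-linearity and an expansion from 768 up to 3072 and back down, which matches the widths we see later when inspecting an actual model:

```python
import torch
import torch.nn as nn

d_k, d_ff = 768, 3072   # 3072 matches the intermediate size we see later

feedforward = nn.Sequential(
    nn.Linear(d_k, d_ff),   # W1, b1: up to the wide internal dimensionality
    nn.ReLU(),              # non-linearity (the original paper uses ReLU)
    nn.Linear(d_ff, d_k),   # W2, b2: back down to d_k so blocks can stack
)
layer_norm = nn.LayerNorm(d_k)

c_a_norm = torch.randn(d_k)          # yellow: output of the attention sub-block
c_ff = feedforward(c_a_norm)         # blue: feedforward output
c_out = layer_norm(c_a_norm + c_ff)  # add and normalize to get the block output
print(c_out.shape)                   # torch.Size([768])
```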
That is the essence of the transformer architecture. We have the position-sensitive embeddings and the attention mechanism down here, and then we have the feedforward layers up here. There are lots of details, but the essence of it is position sensitivity, attention, and these feedforward components. We are going to stack these blocks on top of each other, but all the blocks will follow that same rhythm.
Since attention is so important for these models, it's worth lingering over how that calculation is usually presented. What I've shown you so far is the calculation that I've given at the top of the slide here, which shows piecewise how all of these dot products are formed and combined. In the paper, and in most presentations, that calculation is presented in this matrix format here. And if you're like me, you might not immediately see how these two calculations correspond to each other. And so what I've done is just offer you some simple code that you could get hands-on with to convince yourself that they come out the same. And that might help you bootstrap an understanding of the matrix version, and then you can go forth with that more efficient matrix-based formulation.
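I don't have the lecture's exact code here, but a sketch along the same lines might look like the following: compute the attention outputs piecewise, compute softmax(QK^T / sqrt(d_k))V in matrix form, and check that they agree. Note that this toy version lets every position attend to every position, including itself, which is slightly more general than the left-to-right picture drawn earlier.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_k = 3, 4                      # tiny example: 3 positions, dimensionality 4
X = torch.randn(n, d_k)            # the input representations (rows are A, B, C)

# Piecewise version, for one position i: dot products, scale, softmax, weighted sum.
def attention_for_position(i):
    scores = X @ X[i] / d_k**0.5   # dot product of position i with every position
    alpha = F.softmax(scores, dim=-1)
    return alpha @ X               # weighted sum of the representations

piecewise = torch.stack([attention_for_position(i) for i in range(n)])

# Matrix version: softmax(Q K^T / sqrt(d_k)) V, with Q = K = V = X here.
matrix_form = F.softmax(X @ X.T / d_k**0.5, dim=-1) @ X

print(torch.allclose(piecewise, matrix_form))  # True
```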
The other major piece that I have so far not introduced is multi-headed attention, so let's dive into what it means to be multi-headed. I'm going to show you a worked example with three heads, and let's try to do this by way of a simple example. I've got our usual sequence at the bottom here, with the three contextual representations given in green. Here is the calculation for the first attention head, and the thing to notice is that we have introduced a bunch of new learned parameters. Those are depicted in orange in this calculation, and I put them in orange to try to make it easy to see that if we simply remove all of those learned parameters, we get back the simpler attention calculation that I showed you before. We do the same thing for our second attention head, the same calculation but now augmented with the new learned parameters for that head, and then again for the third head, with parameters corresponding to that third head. So here is the attention representation for A: it pieces together the outputs of those three heads. And when you encounter attention representations in an actual transformer model, keep in mind that was probably a multi-headed attention process.
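Here is a rough sketch of that idea with three heads, each with its own hypothetical learned matrices W_q, W_k, and W_v, whose outputs are concatenated. The per-head dimensionality d_k / 3 is my assumption for this illustration, and real implementations typically also apply a further learned output projection to the concatenation.

```python
import torch
import torch.nn.functional as F

n, d_k, num_heads = 3, 768, 3
d_head = d_k // num_heads          # 256 per head, an assumption for this sketch
X = torch.randn(n, d_k)            # input representations for positions A, B, C

def one_head(X, W_q, W_k, W_v):
    # Each head has its own learned parameters, applied before the dot products:
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    alpha = F.softmax(Q @ K.T / d_head**0.5, dim=-1)
    return alpha @ V               # (n, d_head) attention output for this head

# Three heads, each with its own learned parameter matrices:
heads = [[torch.randn(d_k, d_head) for _ in range(3)] for _ in range(num_heads)]

# Piece the per-head outputs together to form the attention representation:
attn = torch.cat([one_head(X, *params) for params in heads], dim=-1)
print(attn.shape)  # torch.Size([3, 768])
```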
Maybe the one big idea that's worth repeating is that we typically stack transformer blocks on top of each other. I've only shown you the calculations up through C out, but C out could be the basis for a second transformer block with exactly the same structure, and then of course we could repeat that process.
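To illustrate stacking without rewriting every piece, here is a sketch using PyTorch's built-in encoder layer, which bundles multi-headed attention, add-and-normalize, and feedforward steps (its internal details differ a little from the hand-drawn version above, and the sizes are assumptions):

```python
import torch
import torch.nn as nn

d_k, n_heads, d_ff, n_layers = 768, 12, 3072, 2   # sizes are assumptions for this sketch

# One transformer block: attention, add & normalize, feedforward, add & normalize.
block = nn.TransformerEncoderLayer(d_model=d_k, nhead=n_heads,
                                   dim_feedforward=d_ff, batch_first=True)

X = torch.randn(1, 3, d_k)   # a batch of one three-token sequence of input representations

# Stacking just means the output of one block is the input to the next:
out = X
for _ in range(n_layers):
    out = block(out)          # in a real model each layer would have its own parameters
print(out.shape)              # torch.Size([1, 3, 768])
```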
And the other thing that's worth reminding yourself of is that these attention representations are probably not single-headed attention representations but multi-headed ones, where we piece together a bunch of component pieces that themselves correspond to a lot of learned parameters. In addition to the fact that attention is the one place where all of these columns of representations interact with each other, that helps explain why attention is so important and why it's good to have lots of heads in there.
With all of that in place, we are in a better position to understand the famous transformer diagram that appears in the "Attention is all you need" paper. I will confess to you that I myself, on first reading, found this diagram hard to follow. One thing to keep in mind is that they are dealing mainly with sequence-to-sequence problems, like machine translation. And so now we can see that, on the encoder side here, what they've depicted is repeated for every step in the sequence that we're processing. You can see the multi-headed attention, the feedforward layer, the normalization, and the adding together of representations: that's that same rhythm that I pointed out before.
On the decoder side, things get a little more complicated. The important thing is that now we need to do masked attention, because as we think about decoding, we produce tokens one at a time, and so we need to hide the future and look only into the past when we do those dot products. But otherwise the decoder has the same exact structure as the encoder. They do have additional output parameters on top here. If we're doing something like machine translation, we'll have those output heads on every single state in the decoder, but if we're doing something like classification, we might have an output head only on one of the output states, maybe the final one.
With that in mind, you can see the same pieces that I've discussed before, just presented in this encoder-decoder format. So I hope that helps a little bit with the famous diagram.
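As a small illustration of what "look only into the past" means for those dot products, here is a sketch of a causal mask (my own example, not from the lecture): positions to the right of the current one get a score of negative infinity before the softmax, so they receive zero attention weight.

```python
import torch
import torch.nn.functional as F

n, d_k = 3, 768
X = torch.randn(n, d_k)                         # decoder-side representations

scores = X @ X.T / d_k**0.5                     # all pairwise scaled dot products
causal_mask = torch.triu(torch.ones(n, n), diagonal=1).bool()
scores = scores.masked_fill(causal_mask, float("-inf"))  # hide the future

alpha = F.softmax(scores, dim=-1)               # each row attends only to itself and the past
print(alpha)                                    # the upper triangle is all zeros
```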
The final thing I wanted to say under this heading is that you can get a feel for how these models work by downloading them and inspecting how they are structured. You see a lot of the pieces that we've already discussed. It's got an embedding layer, which has word embeddings, and you can see that there are about 30,000 items in the embedding space, each one of dimensionality 768. For the positional embeddings, we have 512 positional embeddings. There is also a token type embedding, but that's kind of like a positional embedding. And the embedding block ends with layer normalization and dropout, so that's kind of regularization and rescaling of these values.
And what you can see on this slide is just the first layer; it's the same structure repeated for all subsequent layers. You see 768 all over the place because that's d_k, and I mentioned before that we need to have that same dimensionality almost everywhere. The one exception is the feedforward component, where we go up from 768 to 3072, but then we have to go from 3072 back to 768 for the output so that we can stack these components on top of each other. That wide internal layer gives the model a lot more parameters and therefore a lot more representational power. And as I said, this would continue for all the layers. And that's pretty much a summary of the architecture.
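If you want to try this yourself, one minimal way to do it is with the Hugging Face transformers library; I'm assuming bert-base-uncased here because its sizes (roughly 30,000 vocabulary items, 768 dimensions, 512 positions, 3072-wide feedforward layers) match the numbers just described.

```python
from transformers import AutoModel

# Download a pretrained encoder and print its structure to see the pieces
# discussed above: embeddings, attention, add & norm, and feedforward layers.
model = AutoModel.from_pretrained("bert-base-uncased")
print(model)

print(model.config.hidden_size)              # 768
print(model.config.max_position_embeddings)  # 512
print(model.config.intermediate_size)        # 3072
```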
And you can do this for lots of different models. They'll differ subtly in the details of their graphs, but I expect that you'll see a lot of the core pieces repeated in various flavors as you look at those models.