
Stanford XCS224U: NLU | Contextual Word Representations, Part 2: Transformer | Spring 2023



00:00:00.000 | Welcome back everyone.
00:00:05.840 | This is part two in our series on contextual representations.
00:00:08.960 | We've come to the heart of it,
00:00:10.160 | the transformer architecture.
00:00:11.960 | While we're still feeling fresh,
00:00:13.840 | I propose that we just dive into the core model structure.
00:00:17.040 | I'm going to introduce that by way of a simple example.
00:00:20.200 | I've got that at the bottom of the slide here.
00:00:22.040 | Our sentence is "the rock rules,"
00:00:24.220 | and I've paired each one of those tokens with
00:00:26.440 | a token representing its position in the string.
00:00:30.040 | The first thing that we do in this model is look up
00:00:32.820 | each one of those tokens in its own embedding space.
00:00:36.820 | For word embeddings, we look those up and get things like x47,
00:00:40.360 | which is a vector corresponding to the word "the."
00:00:43.380 | That representation is a static word representation
00:00:47.040 | that's very similar conceptually to what we had in
00:00:49.600 | the previous era with models like word2vec and GloVe.
00:00:53.420 | We do something similar for
00:00:55.220 | these positional tokens here and get their vector representations.
00:00:59.300 | Then to combine them, we simply add them together dimension-wise
00:01:03.460 | to get the representations that I have in green here,
00:01:06.360 | which you could think of as the first contextual
00:01:09.220 | representations that we have in this model.
00:01:12.380 | On the right here, I've depicted that calculation
00:01:15.380 | for the C input part of the sequence.
00:01:18.860 | That's a pattern that I'm going to continue all the way up as we
00:01:21.900 | build this transformer block just showing
00:01:24.420 | the calculations for the C position because
00:01:27.140 | the calculations are entirely parallel for A and for B.
00:01:31.340 | To get C input, we simply add together x34 with P3,
00:01:35.540 | and that gives us C input.
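As a rough sketch of that lookup-and-add step in code (the toy vocabulary size, the index 34, and the tiny dimensionality below are illustrative choices, not the lecture's actual numbers):

```python
import numpy as np

d_k = 4                                       # tiny model dimensionality, for illustration
rng = np.random.default_rng(0)

word_embeddings = rng.normal(size=(50, d_k))  # toy vocabulary of 50 word embeddings
pos_embeddings = rng.normal(size=(10, d_k))   # toy maximum sequence length of 10

# Pretend "rules" is token 34 in this toy vocabulary and sits at position 3:
x34 = word_embeddings[34]    # static word embedding
p3 = pos_embeddings[3]       # positional embedding

# The first contextual representation is just their dimension-wise sum:
c_input = x34 + p3
print(c_input.shape)         # (4,) -- the same dimensionality as everything else
```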
00:01:38.260 | The next layer is the attention layer.
00:01:41.360 | This is the part of the model that gives rise to
00:01:43.420 | that famous paper title, "Attention Is All You Need."
00:01:46.860 | The reason the paper has the title "Attention Is All You Need" is
00:01:50.060 | that the authors saw what was happening in
00:01:52.860 | the previous era with recurrent neural networks
00:01:55.500 | where people had recurrent mechanisms,
00:01:57.580 | and then they added a bunch of
00:01:59.260 | attention mechanisms on top of
00:02:01.280 | those recurrences to further connect everything to everything else.
00:02:04.980 | What the paper title is saying is,
00:02:06.940 | you can get rid of those recurrent connections
00:02:09.620 | and rely entirely on attention.
00:02:12.180 | Hence, "Attention Is All You Need."
00:02:13.940 | That's an important historical note because
00:02:15.720 | the transformer has many other pieces as you'll see,
00:02:19.060 | but they were saying in particular,
00:02:20.860 | I believe, that you could drop the recurrent mechanisms.
00:02:24.780 | The attention mechanism that the transformer uses is
00:02:28.420 | essentially the same one that I
00:02:30.220 | introduced in the previous lecture
00:02:32.200 | coming from the pre-transformer era.
00:02:34.440 | It is a dot product-based approach to attention.
00:02:38.300 | I've summarized that here.
00:02:39.740 | You can see in the numerator,
00:02:40.940 | we have C input dot product with A input,
00:02:43.640 | and C input dot product with B input.
00:02:46.340 | Let me show you what those look like.
00:02:48.260 | Here, I've depicted each dot product as a dot,
00:02:53.280 | and the arrows going into it correspond to
00:02:55.720 | the components that feed into that calculation.
00:02:58.500 | This dot here corresponds to A input combined with
00:03:01.880 | C input and this one, A input with B input.
00:03:05.420 | We do that same thing for the B step,
00:03:08.020 | and then we do the same thing for the C step.
00:03:10.500 | The two dots that are depicted here correspond
00:03:13.580 | to the two dot products that are in this numerator.
00:03:16.780 | One new thing that they did in the transformer paper is
00:03:19.780 | normalize those dot products by the square root of DK.
00:03:24.300 | DK is the dimensionality of the model.
00:03:26.860 | It is the dimensionality of
00:03:28.220 | all the representations that we have talked about so far.
00:03:31.380 | That's a really important element of the transformer.
00:03:34.540 | We're going to do a lot of
00:03:35.660 | additive combinations in this model,
00:03:37.780 | and that means that essentially,
00:03:39.660 | every representation has to have
00:03:41.880 | the same dimensionality and that is DK.
00:03:44.220 | There is one exception to that which I will return to,
00:03:47.160 | but all the states that I depict on
00:03:49.220 | this slide need to have dimensionality DK.
00:03:52.220 | What the transformer authors found is that they got
00:03:55.740 | better scaling for the dot products when they
00:03:57.860 | normalized by the square root of
00:04:00.140 | that model dimensionality as a heuristic.
00:04:03.500 | Those normalized dot products give us a new vector,
00:04:06.780 | alpha with a tilde on top.
00:04:08.720 | We softmax normalize that,
00:04:10.820 | and that gives us alpha,
00:04:12.340 | which you could think of as attention scores.
00:04:15.700 | To get the actual attention representation
00:04:18.580 | corresponding to this block here,
00:04:20.640 | we take each component of this vector alpha and
00:04:24.440 | multiply it by each one
00:04:26.260 | of the representations that we're attending to.
00:04:28.580 | Alpha 1 by A input,
00:04:30.500 | alpha 2 by B input,
00:04:32.300 | and then we sum those values together to get C attention.
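Here is a minimal sketch of that whole calculation for the C position, following the piecewise description above; the random vectors stand in for A input, B input, and C input:

```python
import numpy as np

d_k = 4
rng = np.random.default_rng(1)
a_input, b_input, c_input = rng.normal(size=(3, d_k))

# Scaled dot products between c_input and the states it attends to:
alpha_tilde = np.array([c_input @ a_input, c_input @ b_input]) / np.sqrt(d_k)

# Softmax-normalize those scores to get the attention weights alpha:
alpha = np.exp(alpha_tilde) / np.exp(alpha_tilde).sum()

# Weight each attended state by its score and sum to get c_attn:
c_attn = alpha[0] * a_input + alpha[1] * b_input
print(c_attn.shape)   # (4,)
```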
00:04:36.660 | As a reminder, we have all these
00:04:39.100 | dense connections for all of these different states.
00:04:41.460 | I'm just showing you the calculations for C attention.
00:04:45.500 | That's important because all those lines that are now on
00:04:48.600 | the slide are really the only place at which we knit
00:04:51.980 | together all of these columns which will
00:04:54.380 | otherwise be operating independently of each other.
00:04:57.320 | This really gives us all the dense connections that we think
00:05:00.480 | are so powerful for helping the transformer learn
00:05:03.360 | what sequences are like.
00:05:05.360 | Now, I do think that
00:05:07.140 | the representations that I have in orange are
00:05:09.620 | attention representations but they're raw materials
00:05:12.940 | because they're really just recording the similarity
00:05:16.260 | between our target representation
00:05:18.400 | and the representations around it.
00:05:20.580 | To get an actual attention representation in the transformer,
00:05:23.900 | what we do is add together
00:05:26.340 | these contextual representations down
00:05:28.420 | here with these attention values,
00:05:30.940 | and that gives us the representations in yellow,
00:05:33.500 | CA layer, and those are
00:05:36.180 | full-fledged attention-based representations.
00:05:40.020 | I've depicted the calculation over here and that
00:05:42.260 | includes a nice reminder that we actually apply
00:05:44.300 | dropout to the sum of the orange and the green.
00:05:48.980 | Dropout is a simple regularization technique that will help
00:05:52.220 | the model to learn
00:05:53.340 | diverse representations as part of its training.
00:05:56.700 | The next step is layer normalization,
00:05:59.700 | and this is simply going to help us with scaling the values.
00:06:02.440 | We're going to adjust them so that they have
00:06:04.300 | zero mean and unit variance,
00:06:08.740 | and that's just a happy place
00:06:10.640 | for machine learning models in general.
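A small sketch of those two steps, following the lecture's description that dropout is applied to the sum of the green and orange representations (implementations vary in exactly where dropout is placed):

```python
import torch
import torch.nn as nn

d_k = 4
dropout = nn.Dropout(p=0.1)
layer_norm = nn.LayerNorm(d_k)

c_input = torch.randn(d_k)   # green contextual representation
c_attn = torch.randn(d_k)    # orange attention representation

# Residual addition with dropout, then layer normalization:
c_a_layer = dropout(c_input + c_attn)   # the yellow representation
c_a_norm = layer_norm(c_a_layer)        # the purple representation, fed forward
```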
00:06:13.700 | The next step is really crucially important.
00:06:17.260 | These are the feedforward components in the transformer.
00:06:20.820 | I have depicted them as a single representation in blue,
00:06:24.220 | but it's really important to see that this is actually
00:06:26.620 | hiding two feedforward layers.
00:06:29.480 | We take CA norm in purple here as the input,
00:06:32.920 | and we feed that through a dense layer with parameters W1 and
00:06:36.980 | B1 and we apply a ReLU activation to that.
00:06:41.100 | That is fed into a second dense layer with
00:06:44.180 | parameters W2 and bias term B2,
00:06:47.280 | and that gives us CFF.
00:06:49.860 | This is important because many of the parameters for
00:06:53.300 | the transformer are actually hidden
00:06:55.020 | away in these feedforward layers.
00:06:57.260 | In fact, this is the one place where we could
00:07:00.340 | depart from this dimensionality DK
00:07:03.280 | because CA norm here has dimensionality DK by design.
00:07:08.760 | But since we have two feedforward layers,
00:07:10.920 | we have the opportunity to expand out to
00:07:13.660 | some larger dimensionality if we want as long
00:07:16.620 | as the output of that goes back down to DK.
00:07:20.360 | As we'll see for some of these very large
00:07:22.520 | deployed transformer architectures,
00:07:24.820 | people have seized this opportunity to have
00:07:27.420 | really wide internal layers in this feedforward step.
00:07:31.620 | Then of course, you have to collapse back down,
00:07:33.840 | and that might be giving these models a lot
00:07:36.100 | of their representational power.
00:07:38.640 | But we collapse back down to DK for CFF here.
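A minimal sketch of those two dense layers; the wider inner dimensionality d_ff below is an illustrative choice, and the only real constraint is that the output come back down to DK:

```python
import torch
import torch.nn as nn

d_k, d_ff = 4, 16             # d_ff can be much wider than d_k
ff1 = nn.Linear(d_k, d_ff)    # first dense layer: parameters W1 and b1
ff2 = nn.Linear(d_ff, d_k)    # second dense layer: parameters W2 and b2

c_a_norm = torch.randn(d_k)   # the purple representation from the previous step

# Expand out to d_ff with a ReLU, then collapse back down to d_k:
c_ff = ff2(torch.relu(ff1(c_a_norm)))
print(c_ff.shape)             # torch.Size([4])
```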
00:07:42.620 | Then we have another addition of CA norm with CFF,
00:07:48.300 | to get CFF layer here in yellow,
00:07:50.500 | and we have dropout applied to CFF,
00:07:52.620 | that's that regularization step.
00:07:54.660 | Then finally, we have a layer normalization step,
00:07:57.660 | just as we had down here,
00:07:58.900 | which will help us with rescaling
00:08:01.000 | the values that we've produced thus far,
00:08:03.120 | and therefore help the model learn more effectively.
00:08:06.780 | That is the essence of the transformer architecture.
00:08:10.780 | There are a few more details to add on,
00:08:13.140 | but I feel like this gives you
00:08:14.460 | a good conceptual understanding.
00:08:16.380 | We began with position-sensitive versions
00:08:19.660 | of static word embeddings.
00:08:21.820 | We had these attention layers down here,
00:08:24.780 | and then we have the feedforward layers up here.
00:08:27.880 | In between, we have some regularization
00:08:30.620 | and some normalization of the values,
00:08:32.600 | but the essence of it is position sensitivity,
00:08:35.840 | attention, feedforward.
00:08:37.900 | We are going to stack these blocks on top of each other,
00:08:41.060 | and that's going to lead to
00:08:42.100 | lots more representational power,
00:08:43.740 | but all the blocks will follow that same rhythm.
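Pulling these pieces together, here is a hedged PyTorch sketch of one such block, following the rhythm just described (attention, add and normalize, feedforward, add and normalize); it is a simplification for intuition, not the exact course or paper implementation:

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_k=768, d_ff=3072, num_heads=12, dropout=0.1):
        super().__init__()
        # Multi-headed attention is the step that knits the positions together:
        self.attn = nn.MultiheadAttention(d_k, num_heads, dropout=dropout,
                                          batch_first=True)
        self.norm1 = nn.LayerNorm(d_k)
        self.norm2 = nn.LayerNorm(d_k)
        self.dropout = nn.Dropout(dropout)
        # The two feedforward layers, expanding to d_ff and back to d_k:
        self.ffn = nn.Sequential(
            nn.Linear(d_k, d_ff), nn.ReLU(), nn.Linear(d_ff, d_k))

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(self.dropout(x + attn_out))   # add, dropout, layer norm
        x = self.norm2(self.dropout(x + self.ffn(x)))
        return x

# Blocks are stacked on top of each other; the output of one block
# is the input to the next:
blocks = nn.Sequential(*[TransformerBlock() for _ in range(12)])
out = blocks(torch.randn(1, 3, 768))   # batch of 1, sequence "the rock rules"
```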
00:08:47.420 | Since attention is so important for these models,
00:08:50.420 | I thought I would linger a little bit
00:08:52.020 | over the attention calculation.
00:08:54.740 | What I've shown you so far is the calculation
00:08:57.520 | that I've given at the top of the slide here,
00:08:59.500 | which shows piecewise how all of these dot products
00:09:02.040 | come together and get rescaled and added in
00:09:05.020 | to form C-attention in this case.
00:09:07.940 | In the "Attention Is All You Need" paper,
00:09:09.860 | and in a lot of the subsequent literature,
00:09:11.980 | that calculation is presented in this matrix format here.
00:09:16.040 | And if you're like me, you might not immediately see
00:09:18.720 | how these two calculations correspond to each other.
00:09:22.460 | And so what I've done is just offer you some simple code
00:09:25.660 | that you could get hands-on with to convince yourself
00:09:28.620 | that those two calculations are the same.
00:09:31.220 | And that might help you bootstrap an understanding
00:09:34.180 | of what you typically see in the literature,
00:09:36.100 | and then you can go forth with that more efficient
00:09:38.900 | matrix version of the calculation,
00:09:40.980 | secure in the knowledge that it corresponds
00:09:44.040 | to the more piecewise thing
00:09:45.680 | that I've been depicting thus far.
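In that spirit, here is a small self-contained comparison (my own sketch, not the course's code file): it computes the matrix version, softmax(XX^T / sqrt(d_k)) X, and then the same quantity position by position, and checks that the two agree. Note that in this simple version every position attends over all positions, itself included, which is what the matrix form computes:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

d_k, n = 4, 3
rng = np.random.default_rng(2)
X = rng.normal(size=(n, d_k))       # rows stand in for a_input, b_input, c_input

# Matrix version, with X playing the roles of queries, keys, and values:
matrix_attn = softmax(X @ X.T / np.sqrt(d_k)) @ X

# Piecewise version: one position at a time -- dot products, softmax, weighted sum.
piecewise_attn = np.zeros_like(X)
for i in range(n):
    alpha = softmax(X[i] @ X.T / np.sqrt(d_k))
    piecewise_attn[i] = alpha @ X

print(np.allclose(matrix_attn, piecewise_attn))   # True
```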
00:09:48.440 | The other major piece that I have so far not introduced
00:09:52.540 | is multi-headed attention.
00:09:54.140 | So far, I have shown you
00:09:56.060 | effectively single-headed attention.
00:09:58.960 | So let's dive into what it means to be multi-headed.
00:10:01.540 | I'm gonna show you a worked example with three heads.
00:10:05.180 | The idea is actually very simple,
00:10:06.900 | but there are a lot of moving pieces.
00:10:08.980 | So let's try to do this by way of a simple example.
00:10:12.260 | I've got our usual sequence at the bottom here,
00:10:14.740 | the rock rules, and I've got our usual
00:10:17.220 | three contextual representations given in green.
00:10:21.100 | We are gonna do three parallel calculations
00:10:25.080 | corresponding to our three heads.
00:10:27.080 | Here's the first head.
00:10:28.860 | We do our same dot products as before,
00:10:32.000 | and it is effectively the same calculation
00:10:34.400 | that leads to them with the small twist
00:10:36.940 | that we have introduced a bunch of new parameters
00:10:39.800 | into the calculation.
00:10:41.040 | Those are WQ1 for queries,
00:10:45.400 | WK1 for keys,
00:10:47.840 | and WV1 for values.
00:10:49.960 | Those are depicted in orange in this calculation,
00:10:52.160 | and I put them in orange to try to make it easy to see
00:10:55.120 | that if we simply remove all of those learned parameters,
00:10:58.280 | we get back to the dot product calculation
00:11:01.280 | that I was showing you before.
00:11:02.980 | We've introduced these new matrices
00:11:05.600 | to provide more representational power
00:11:08.080 | inside this attention block.
00:11:11.080 | And the subscripts 1 indicate
00:11:12.900 | that we are dealing with parameters
00:11:14.280 | for the first attention head.
00:11:16.100 | We do the same thing for our second attention head,
00:11:20.400 | all of those dot products,
00:11:21.740 | but now augmented with those new learned parameters.
00:11:24.840 | Same thing, queries, keys, and values,
00:11:27.760 | but now two for the second attention head.
00:11:31.680 | And we repeat exactly the same thing
00:11:33.600 | for the third attention head,
00:11:35.400 | again with parameters corresponding to that third head.
00:11:39.680 | And then to actually get back
00:11:41.320 | to the attention representations
00:11:42.920 | that I was showing you before,
00:11:44.560 | we kind of reassemble the pieces.
00:11:46.760 | So here is the attention representation for A,
00:11:50.640 | here it is for B, and here it is for C.
00:11:53.900 | We've pieced together from all the things
00:11:55.820 | that we did down here,
00:11:57.240 | these three separate representations.
00:11:59.840 | And those are what was depicted in orange
00:12:02.400 | on the previous slides.
00:12:03.800 | But now you can see that implicitly
00:12:06.120 | that was probably a multi-headed attention process.
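A rough sketch of that three-headed calculation; the per-head dimensionality and the reassembly by concatenation below are reasonable assumptions about details the lecture doesn't spell out:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

d_k, num_heads = 6, 3
d_head = d_k // num_heads          # each head works in a smaller subspace
rng = np.random.default_rng(3)
X = rng.normal(size=(3, d_k))      # rows are a_input, b_input, c_input

head_outputs = []
for h in range(num_heads):
    # Learned parameters for this head: queries, keys, and values.
    W_Q = rng.normal(size=(d_k, d_head))
    W_K = rng.normal(size=(d_k, d_head))
    W_V = rng.normal(size=(d_k, d_head))
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    head_outputs.append(softmax(Q @ K.T / np.sqrt(d_head)) @ V)

# Reassemble the per-head pieces into one representation per position:
attn = np.concatenate(head_outputs, axis=-1)
print(attn.shape)   # (3, 6) -- back to dimensionality d_k for each of a, b, c
```

The original paper also applies one more learned output projection after the heads are put back together, which this sketch omits.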
00:12:09.740 | So now I think we can summarize.
00:12:13.680 | Maybe the one big idea that's worth repeating
00:12:16.720 | is that we typically stack transformer blocks
00:12:19.440 | on top of each other.
00:12:20.640 | So this is the first block,
00:12:22.340 | I've got C input coming in and C out here,
00:12:25.040 | but C out could be the basis for a second transformer block
00:12:29.600 | where those were the inputs.
00:12:30.920 | And then of course we could repeat that process.
00:12:33.200 | And that is very typical to have 12, 24,
00:12:36.720 | maybe even hundreds of transformer blocks
00:12:39.220 | stacked on top of each other.
00:12:41.640 | And the other thing that's worth reminding yourself of
00:12:43.880 | is that these representations in orange here
00:12:47.020 | are probably not single-headed attention representations,
00:12:49.920 | but rather multi-headed ones
00:12:52.000 | where we piece together a bunch of component pieces
00:12:55.360 | that themselves correspond to a lot of learned parameters.
00:13:00.040 | And that is again,
00:13:01.240 | why this attention layer is so much a part
00:13:04.400 | of the transformer architecture,
00:13:06.060 | in addition to the fact that that's the one place
00:13:08.760 | where all of these columns of representations
00:13:11.320 | interact with each other.
00:13:12.800 | So that probably further emphasizes
00:13:15.080 | why the attention layer is so important
00:13:17.080 | and why it's good to have lots of heads in there
00:13:19.440 | offering lots of diversity
00:13:20.820 | for this crucial interactional layer
00:13:23.680 | across the different parts of the sequence.
00:13:29.260 | So that is the essence of it.
00:13:31.420 | And I hope that you are now in a position
00:13:34.160 | to better understand the famous transformer diagram
00:13:41.380 | that appears in the "Attention Is All You Need" paper.
00:13:41.380 | I will confess to you that I myself on first reading
00:13:43.780 | did not understand this diagram,
00:13:45.780 | but now I feel that I do understand it.
00:13:48.820 | Reminder that in that paper,
00:13:50.620 | they are dealing mainly with sequence to sequence problems
00:13:53.460 | so that they have an encoder and a decoder.
00:13:56.980 | And so now we can see that on the encoder side here,
00:13:59.820 | what they've depicted is repeated for every step
00:14:04.340 | on the encoder side,
00:14:05.540 | so for every step in the sequence that we're processing.
00:14:08.740 | And once you see that, you can see, okay,
00:14:10.540 | I've used the same colors
00:14:11.620 | that they did.
00:14:13.220 | So red for the embeddings,
00:14:15.780 | we have multi-headed attention,
00:14:17.540 | additive and layer norm steps.
00:14:20.220 | Then we have the feed forward part,
00:14:22.260 | more normalization and kind of adding together
00:14:25.180 | of different representations.
00:14:27.020 | That's that same rhythm that I pointed out before.
00:14:30.060 | That's on the encoder side.
00:14:31.820 | On the decoder side, things get a little more complicated.
00:14:34.620 | We're gonna return to some of these details,
00:14:36.900 | but the important thing is that now we need to do
00:14:38.940 | masked attention because as we think about decoding,
00:14:42.180 | we need to be sure that our attention layer
00:14:44.860 | doesn't look into the future.
00:14:46.460 | We need to mask out future states
00:14:48.740 | and look only into the past when we do those dot products.
00:14:52.300 | So that's the masking down here,
00:14:53.740 | but otherwise the decoder has the same exact structure
00:14:57.100 | as the encoder.
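A small sketch of that masking idea: before the softmax, the dot products pointing at future positions are set to a very large negative number, so they end up with essentially zero attention weight:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

d_k, n = 4, 3
rng = np.random.default_rng(4)
X = rng.normal(size=(n, d_k))

scores = X @ X.T / np.sqrt(d_k)

# Mask out the future: position i may only attend to positions <= i.
future = np.triu(np.ones((n, n), dtype=bool), k=1)
scores[future] = -1e9

alpha = softmax(scores)      # upper triangle is effectively zero
masked_attn = alpha @ X
```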
00:14:58.780 | They do have additional parameters on top here
00:15:00.900 | corresponding to output probabilities.
00:15:03.300 | If we're doing something like machine translation
00:15:05.540 | or language modeling,
00:15:06.540 | we'll have those heads on every single state in the decoder.
00:15:10.700 | But if we're doing something like classification,
00:15:12.900 | we might have those task specific parameters
00:15:15.860 | only on one of the output states, maybe the final one.
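As a rough illustration of that difference (the head shapes below are assumptions for the sketch, not any particular model's actual heads):

```python
import torch
import torch.nn as nn

d_k, vocab_size, num_classes = 768, 30522, 2
hidden_states = torch.randn(1, 5, d_k)      # one sequence of 5 output states

lm_head = nn.Linear(d_k, vocab_size)        # language-modeling style head
clf_head = nn.Linear(d_k, num_classes)      # classification style head

token_logits = lm_head(hidden_states)       # one distribution per state
clf_logits = clf_head(hidden_states[:, -1]) # only the final state
```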
00:15:19.260 | But other than that,
00:15:21.460 | you can see the same pieces that I've discussed before
00:15:24.340 | just presented in this encoder-decoder setup.
00:15:27.740 | So I hope that helps a little bit with the famous diagram.
00:15:31.540 | The final thing I wanted to say under this heading
00:15:34.060 | is just that you can get an even deeper feel
00:15:36.460 | for how these models work by downloading them
00:15:39.220 | and using Hugging Face code
00:15:40.820 | to kind of inspect their structure.
00:15:43.100 | I've done that on this slide with BERT base,
00:15:46.300 | and this is really illuminating.
00:15:47.700 | You see a lot of the pieces that we've already discussed.
00:15:50.180 | This is the BERT model.
00:15:52.020 | It's got an embedding layer, which has word embeddings.
00:15:55.060 | And you can see that there are about 30,000 items
00:15:57.540 | in the embedding space, each with dimensionality 768.
00:16:01.180 | That's the DK that I've emphasized so much.
00:16:04.580 | The positional embeddings, we have 512 positional embeddings.
00:16:08.100 | So that will be our maximum sequence length.
00:16:10.660 | And those by definition
00:16:12.100 | have to have dimensionality 768 as well.
00:16:15.460 | We'll return to these token type embeddings
00:16:17.500 | when we talk about BERT in particular,
00:16:20.340 | but that's kind of like a positional embedding.
00:16:23.140 | Then we have layer norm and dropout.
00:16:24.740 | So that's kind of regularization of these values.
00:16:27.460 | And then we have the layers.
00:16:29.420 | And what you can see on this slide is just the first layer.
00:16:31.860 | It's the same structure repeated for all subsequent layers.
00:16:36.060 | Down here, we have the attention layer.
00:16:37.660 | You see 768 all over the place because that's DK.
00:16:41.740 | And the model pretty much defines for us
00:16:43.580 | that we need to have that same dimensionality everywhere.
00:16:46.660 | The one exception is that when we get down
00:16:48.620 | into the feed forward layers,
00:16:50.660 | we go from 768 out to 3072.
00:16:54.500 | That's that intermediate part.
00:16:56.460 | But then we have to go from 3072 back to 768 for the output
00:17:00.900 | so that we can stack these components on top of each other.
00:17:03.940 | But you can see that opportunity there
00:17:06.100 | to add a lot more parameters
00:17:07.900 | and therefore a lot more representational power.
00:17:11.980 | And as I said, this would continue for all the layers.
00:17:14.820 | And that's pretty much a summary of the architecture.
00:17:17.860 | And you can do this for lots of different models
00:17:20.020 | with Hugging Face.
00:17:20.860 | You can check out GPT and BERT and RoBERTa
00:17:23.420 | and all the other models we talk about.
00:17:25.260 | They'll differ subtly in their computation graphs,
00:17:28.180 | but I expect that you'll see a lot of the core pieces
00:17:31.220 | repeated in various flavors as you look at those models.
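If you want to reproduce that kind of inspection yourself, something along these lines with the Hugging Face transformers library will print the module structure; model names like "gpt2" or "roberta-base" can be swapped in the same way:

```python
from transformers import AutoModel

# Download the pretrained weights and print the module structure.
model = AutoModel.from_pretrained("bert-base-uncased")
print(model)

# Count the parameters, many of which live in the wide feedforward layers.
print(sum(p.numel() for p in model.parameters()))
```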
00:17:35.300 | (upbeat music)