Stanford XCS224U: NLU I Contextual Word Representations, Part 3: Positional Encoding I Spring 2023

00:00:00.000 | Welcome back everyone.
00:00:06.160 | This is part 3 in our series on contextual representations.
00:00:09.600 | We have a bunch of famous transformer-based
00:00:12.040 | architectures that we're going to talk about a bit later.
00:00:14.480 | But before doing that,
00:00:15.800 | I thought it would be good to pause and just reflect
00:00:18.600 | a little bit on this important notion of positional encoding.
00:00:21.480 | This is an idea that I feel
00:00:23.200 | the field took for granted for too long.
00:00:25.320 | I certainly took it for granted for too long.
00:00:27.440 | I think we now see that this is a crucial factor
00:00:31.000 | in shaping the performance of transformer-based models.
00:00:34.360 | Let's start by reflecting on the role of
00:00:36.840 | positional encoding in the context of the transformer.
00:00:39.520 | I think the central observation is that
00:00:41.760 | the transformer itself has only a very limited capacity
00:00:45.460 | to keep track of word order.
00:00:47.380 | The attention mechanisms are themselves not directional,
00:00:50.720 | it's just a bunch of dot products.
00:00:52.720 | There are no other interactions between the columns.
00:00:56.560 | We are in grave danger of losing track of the fact that
00:01:00.680 | the input sequence ABC is different from the input sequence CBA.
00:01:06.120 | Positional encodings will ensure that we retain
00:01:09.280 | a difference between those two sequences no matter what we
00:01:12.040 | do with the representations that come from the model.
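The order-blindness described here can be checked directly. The following is a minimal sketch, assuming plain dot-product self-attention with no learned parameters and no positional information; the 2-d vectors for A, B, and C are hypothetical stand-ins, not values from the lecture.

```python
import math

def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def self_attention(vectors):
    """Dot-product self-attention with no learned parameters and no positions."""
    dim = len(vectors[0])
    outputs = []
    for query in vectors:
        weights = softmax([dot(query, key) / math.sqrt(dim) for key in vectors])
        outputs.append(
            [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(dim)]
        )
    return outputs

# Hypothetical 2-d word vectors for the tokens A, B, and C.
emb = {"A": [1.0, 0.0], "B": [0.0, 1.0], "C": [1.0, 1.0]}

abc = self_attention([emb[t] for t in "ABC"])
cba = self_attention([emb[t] for t in "CBA"])

# Each token gets the same output vector in both orders; only the
# order of the columns changes, so reversing CBA's outputs recovers ABC's.
same_up_to_order = all(
    math.isclose(x, y, abs_tol=1e-12)
    for row_abc, row_cba in zip(abc, reversed(cba))
    for x, y in zip(row_abc, row_cba)
)
print(same_up_to_order)
```

Any order-insensitive summary of these outputs, such as their sum, would therefore be identical for ABC and CBA, which is the danger that positional encodings are meant to remove.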
00:01:15.880 | Secondarily, there's another purpose that
00:01:18.400 | positional encodings play which is hierarchical.
00:01:21.300 | They've been used to keep track of things like
00:01:23.280 | premise and hypothesis in natural language inference.
00:01:26.380 | That was an important feature of the BERT model
00:01:28.740 | that we'll talk about a bit later in the series.
00:01:31.960 | I think there are a lot of perspectives that you
00:01:34.400 | could take on positional encoding.
00:01:36.580 | To keep things simple, I thought I would center
00:01:38.680 | our discussion around two crucial questions.
00:01:41.900 | The first is, does the set of
00:01:44.360 | positions need to be decided ahead of time?
00:01:47.860 | The second is, does the positional encoding scheme
00:01:51.640 | hinder generalization to new positions?
00:01:55.200 | I think those are good questions to guide us.
00:01:57.960 | One other rule that I wanted to introduce is the following.
00:02:02.080 | Modern transformer architectures might impose
00:02:04.880 | a max length on sequences for
00:02:06.900 | many reasons related to how they were designed and optimized.
00:02:10.520 | I would like to set all of that aside and just ask whether
00:02:14.220 | the positional encoding scheme itself is imposing
00:02:18.060 | anything about length generalization separately
00:02:20.360 | from all that other stuff that might be happening.
00:02:23.280 | Let's start with absolute positional encoding.
00:02:26.300 | This is the scheme that we have talked about so far.
00:02:29.520 | On this scheme, we have word representations,
00:02:32.760 | and we also have positional representations that we have
00:02:35.560 | learned corresponding to some fixed number of dimensions.
00:02:39.100 | To get our position-sensitive word representation,
00:02:43.220 | we simply add together the word vector with the position vector.
00:02:47.780 | How is this scheme doing for our two crucial questions?
00:02:51.580 | Well, not so well.
00:02:53.160 | First, obviously, the set of
00:02:55.280 | positions needs to be decided ahead of time.
00:02:58.160 | When we set up our model,
00:02:59.700 | we will have some embedding space,
00:03:01.640 | maybe up to 512.
00:03:04.180 | If we picked 512,
00:03:06.000 | when we hit position 513,
00:03:08.440 | we will not have a positional representation for that position.
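As a rough illustration of the scheme and of what goes wrong at position 513, here is a minimal sketch. It assumes a lookup table of 512 position vectors; the random values stand in for learned parameters, and the names `position_table` and `position_sensitive` are hypothetical.

```python
import random

MAX_POSITIONS = 512  # assumed maximum, matching the 512 in the example
DIM = 4              # tiny dimensionality, chosen only for illustration

rng = random.Random(0)
# One vector per position; random stand-ins for learned parameters.
position_table = [
    [rng.gauss(0, 1) for _ in range(DIM)] for _ in range(MAX_POSITIONS)
]

def position_sensitive(word_vector, position):
    """Add the vector for a 1-indexed position to the word vector."""
    if not 1 <= position <= MAX_POSITIONS:
        raise IndexError(f"no positional representation for position {position}")
    pos_vector = position_table[position - 1]
    return [w + p for w, p in zip(word_vector, pos_vector)]

rock = [0.5, -1.0, 0.25, 2.0]  # hypothetical word vector

early = position_sensitive(rock, 2)
late = position_sensitive(rock, 400)
print(early != late)  # the same word gets a different representation

try:
    position_sensitive(rock, 513)
except IndexError as err:
    print(err)
```

The table has to be sized when the model is set up, so there is simply no row to look up once the input runs past it.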
00:03:13.860 | I also think it's clear that this scheme can
00:03:17.520 | hinder generalization to new positions
00:03:20.400 | even for familiar phenomena.
00:03:22.900 | Just consider the fact that the rock as a phrase,
00:03:26.400 | if it occurs early in the sequence,
00:03:28.620 | is simply a different representation than
00:03:31.240 | the rock if it appears later in the sequence.
00:03:34.520 | There will be some shared features across these two as a result of
00:03:38.440 | the fact that we have the same two word vectors involved in both places.
00:03:42.520 | But we add in those positional representations
00:03:45.480 | as equal partners in this representation,
00:03:48.000 | and I think the result is very heavy-handed when it comes to
00:03:51.880 | learning representations that are heavily position-dependent.
00:03:55.600 | That could make it hard for the model to see that in some sense,
00:03:59.040 | the rock is the same phrase whether it's at
00:04:01.840 | the start of the sequence or the middle or the end.
00:04:05.920 | Another scheme we could consider actually
00:04:08.680 | goes all the way back to the Transformer paper.
00:04:11.000 | I've called this frequency-based positional encoding.
00:04:14.220 | There are lots of ways we could set this up,
00:04:17.040 | but the essential idea here is that we'll define
00:04:19.880 | a mathematical function that given a position,
00:04:23.240 | will give us back a vector that encodes
00:04:25.960 | information about that position semantically in its structure.
00:04:30.160 | In the Transformer paper,
00:04:31.840 | they picked a scheme that's based in frequency oscillation.
00:04:35.400 | Essentially based in sine and cosine frequencies for
00:04:38.920 | these vectors where higher positions oscillate more frequently,
00:04:43.560 | and that information is encoded in
00:04:45.840 | the position vector that we create.
00:04:48.020 | I think there are lots of other schemes that we could use.
00:04:50.480 | The essential feature of this is this argument pos here.
00:04:54.320 | If you give this function position 1,
00:04:56.960 | it gives you a vector.
00:04:58.240 | If you give it 513,
00:04:59.920 | it gives you a vector.
00:05:00.840 | If you give it a million,
00:05:02.100 | it gives you a vector.
00:05:03.400 | All of those vectors manifestly do encode
00:05:07.160 | information about the relative position of that input.
00:05:12.320 | We have definitely overcome the first limitation,
00:05:15.600 | the set of positions does not need to be decided ahead of time in
00:05:18.880 | this scheme because we can fire off
00:05:20.480 | a new vector for any position that you give us.
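One way to write such a function is sketched below. It follows the sine and cosine formulation from the original Transformer paper, where each pair of dimensions uses a different frequency; the function name, the small `dim`, and the choice of test positions are assumptions made for illustration.

```python
import math

def sinusoidal_position(pos, dim=8, base=10000.0):
    """Sine/cosine position vector for any non-negative position.

    Even indices hold sines and odd indices hold cosines of the same
    angle; `dim` is assumed to be even.
    """
    vec = []
    for i in range(0, dim, 2):
        angle = pos / (base ** (i / dim))
        vec.append(math.sin(angle))
        vec.append(math.cos(angle))
    return vec

# No table of positions is fixed ahead of time: any position yields a vector.
for pos in (1, 513, 1_000_000):
    vec = sinusoidal_position(pos)
    print(pos, [round(x, 3) for x in vec])
```

Nothing is learned here, so the function is defined for position 513 or position one million exactly as it is for position 1.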
00:05:23.800 | But I think our second question remains pressing.
00:05:26.880 | Just as before, this scheme can hinder generalization to
00:05:30.600 | new positions even for familiar phenomena in virtue of
00:05:34.160 | the fact that we are taking those word representations and adding
00:05:37.440 | in these positional ones for different positions as equal partners,
00:05:42.200 | as I said, and I think that makes it hard for models to
00:05:45.200 | see that the same phrase could appear in multiple places.
00:05:49.680 | The third scheme is the most
00:05:51.840 | promising of the three that we're going to discuss.
00:05:54.080 | This is relative positional encoding.
00:05:56.440 | We're going to take a few steps to build up
00:05:58.460 | an understanding of how the scheme works.
00:06:01.000 | Let's start with a reminder.
00:06:02.560 | This is a picture of the attention layer of the transformer.
00:06:06.440 | We have our three position sensitive inputs here,
00:06:09.600 | A input, B input, and C input.
00:06:11.840 | Remember, it's crucial that they be position sensitive because of
00:06:16.080 | how much symmetry there is in these dot product attention mechanisms.
00:06:21.000 | Here's a reminder about how that calculation
00:06:23.800 | works with respect to position C over here.
00:06:27.040 | For relative positional encoding,
00:06:29.180 | we really just add in some new parameters.
00:06:31.360 | What I've depicted at the bottom of the slide here is
00:06:34.160 | the same calculation that's at the top,
00:06:36.600 | except now in two crucial places,
00:06:38.980 | I have added in some new vectors
00:06:41.760 | that we're going to learn representations for.
00:06:44.100 | Down in blue here,
00:06:45.560 | we have key representations,
00:06:47.720 | which get added into this dot product.
00:06:50.300 | Then up here in the final step,
00:06:52.320 | we have value representations,
00:06:54.120 | which get added in to
00:06:55.640 | this multiplied attention mechanism plus the thing we're attending to.
00:07:00.200 | Those are the new crucial parameters that we're adding in here.
00:07:05.520 | The essential idea is that having done this
00:07:08.640 | with all the position sensitivity that's going to be encoded in these vectors,
00:07:12.640 | we don't need these green representations here anymore to have
00:07:16.520 | positional information in them because that positional information is
00:07:20.100 | now being introduced in the attention layer because we're going to have
00:07:24.320 | potentially new vectors for every combination of
00:07:27.160 | position as indicated by these subscripts.
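In symbols, the modified calculation might be written as follows. This is a sketch under stated assumptions rather than a transcription of the slide: x_i is the input at position i, d is its dimensionality, a^K_{ij} and a^V_{ij} are the new key and value vectors for attending from position i to position j, and the scaling by the square root of d is assumed from standard scaled dot-product attention. Learned projection matrices are left out.

```latex
\tilde{\alpha}_{ij} = \frac{x_i \cdot \left(x_j + a^{K}_{ij}\right)}{\sqrt{d}},
\qquad
\alpha_{ij} = \frac{\exp\left(\tilde{\alpha}_{ij}\right)}{\sum_{k} \exp\left(\tilde{\alpha}_{ik}\right)},
\qquad
z_i = \sum_{j} \alpha_{ij} \left(x_j + a^{V}_{ij}\right)
```

Setting every a^K_{ij} and a^V_{ij} to zero recovers the position-free calculation, which is why the inputs x_i no longer need to carry positional information themselves.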
00:07:30.680 | But that's only part of the story.
00:07:32.920 | I think the really powerful thing about this method is
00:07:36.920 | the notion of having a positional encoding window.
00:07:40.440 | To illustrate that, I've repeated
00:07:42.720 | the core calculation at the top here as a reminder.
00:07:45.840 | Now for my illustration,
00:07:47.440 | I'm going to set the window size to two.
00:07:50.080 | Here's the input sequence that we'll use as an example.
00:07:54.280 | Above that, I'm going to show you
00:07:56.160 | just integers corresponding to the positions.
00:07:58.760 | Those aren't directly ingredients into the model,
00:08:01.600 | but they will help us keep track of where we are in the calculations.
00:08:06.200 | To start the illustration,
00:08:07.800 | let's zoom in on position 4.
00:08:10.640 | If we follow the letter of the definitions that I've offered so far for the key values here,
00:08:16.680 | we're going to have a vector A_44 corresponding to us attending from position 4 to position 4.
00:08:24.280 | As part of creating this more limited window-based version of the model,
00:08:29.320 | we're actually going to map that into a single vector W_0 for the keys.
00:08:35.200 | Now we travel to the position 1 to the left.
00:08:38.400 | In this case, we would have a vector A_43 for the keys.
00:08:42.960 | But what we're going to do is map that into a single vector W_-1,
00:08:47.480 | corresponding to taking 3 minus 4.
00:08:51.120 | When we travel one more to the left,
00:08:53.400 | we get A_42,
00:08:55.960 | but now we're going to map that to vector W_-2,
00:08:59.240 | again for the keys.
00:09:00.880 | Then because we set our window size to 2,
00:09:04.080 | when we get all the way to that leftmost position,
00:09:07.040 | that's also just W_-2 again.
00:09:10.000 | 1 minus 4, given the window size,
00:09:12.560 | takes us just to the edge of this window,
00:09:14.960 | in this case, minus 2.
00:09:17.160 | Then a parallel thing happens when we travel to the right.
00:09:20.200 | We go from 4 to 5,
00:09:21.960 | that gives us vector W_1 for the keys.
00:09:25.000 | Then 4, 6 gives us W_2.
00:09:27.640 | Then when we get to the third position from our starting point,
00:09:31.240 | that again just flattens out to W_2 because of our window size.
00:09:36.480 | Actually represented in blue here,
00:09:39.240 | we have just a few vectors,
00:09:41.520 | the 0 vector,
00:09:43.320 | the minus 1 and minus 2 vectors,
00:09:44.800 | and then the 1 and 2 vectors,
00:09:48.080 | as opposed to all the distinctions that are made with those A_43,
00:09:51.880 | A_42, and so forth.
00:09:55.160 | We're collapsing those down into a smaller number of vectors
00:09:59.000 | corresponding to the window size.
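The collapsing just described amounts to clipping the offset between two positions to the window. Here is a minimal sketch, assuming a seven-position input numbered 1 through 7 (consistent with the walkthrough from position 4) and a hypothetical helper named `relative_index`.

```python
def relative_index(query_pos, key_pos, window=2):
    """Clip the offset key_pos - query_pos into [-window, window]."""
    return max(-window, min(window, key_pos - query_pos))

positions = range(1, 8)  # positions 1..7, assumed from the walkthrough

# Attending from position 4 to every position in the sequence.
from_4 = [relative_index(4, j) for j in positions]
print(from_4)  # [-2, -2, -1, 0, 1, 2, 2]

# Attending from position 3 reuses the same small set of indices.
from_3 = [relative_index(3, j) for j in positions]
print(from_3)  # [-2, -1, 0, 1, 2, 2, 2]

# 49 (query, key) pairs collapse to 2 * window + 1 distinct vectors.
distinct = {relative_index(i, j) for i in positions for j in positions}
print(len(distinct))  # 5
```

Each index would select one learned key vector (and, in parallel, one value vector), so the number of positional parameters depends on the window size and not on the sequence length.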
00:10:01.320 | Then to continue the illustration,
00:10:03.480 | if we zoom in on position 3,
00:10:05.680 | that would be vector A_33 for the keys,
00:10:09.120 | but now that gets mapped to W_0 for the keys,
00:10:11.520 | which is the same vector that we have up here in that 4, 4 position.
00:10:16.680 | A similar collapsing is going to happen down here.
00:10:19.240 | When we move one to the left of that,
00:10:20.920 | we get minus 1,
00:10:21.960 | which is the same vector as we had up here just to the right.
00:10:27.000 | Then we have the same thing over here,
00:10:29.240 | minus 2 corresponding to the same vector that we had above.
00:10:33.560 | That would continue and we have a parallel calculation for
00:10:37.360 | the value parameters that you see in purple up here,
00:10:40.440 | the same notions of relative position and window size.
00:10:44.360 | We actually learn a relatively small number of position vectors.
00:10:49.040 | What we're doing is essentially giving
00:10:52.200 | a small-window, relative notion of position that's going to
00:10:56.240 | slide around and give us a lot of ability to generalize to
00:11:00.040 | new positions based on combinations that we've seen before,
00:11:03.440 | possibly in other parts of these inputs.
00:11:07.280 | A final thing I'll say is that this is actually
00:11:10.800 | embedded in that full theory of attention that might have
00:11:13.400 | a lot of learned parameters and might even be multi-headed.
00:11:16.320 | What I've depicted here is just
00:11:18.200 | the full calculation just to really give you all the details.
00:11:21.760 | But again, the cognitive shortcut is that it's
00:11:24.880 | the previous attention calculation
00:11:27.740 | with these new positional elements added in.
00:11:30.480 | Again, a reminder, in this new mode,
00:11:32.960 | we introduce position relativity in
00:11:34.800 | the attention layer, not in the embedding layer.
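To make that concrete, here is a minimal sketch of attention with relative key and value vectors. It assumes no learned projection matrices and a single head, uses random stand-ins for the learned vectors, and all function and variable names are hypothetical.

```python
import math
import random

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def clip(offset, window):
    return max(-window, min(window, offset))

def relative_attention(inputs, key_vecs, value_vecs, window):
    """Dot-product attention with relative key/value vectors.

    `inputs` carry no positional information. `key_vecs` and `value_vecs`
    map a clipped offset in [-window, window] to a vector.
    """
    dim = len(inputs[0])
    outputs = []
    for i, query in enumerate(inputs):
        scores = []
        for j, x in enumerate(inputs):
            rel_k = key_vecs[clip(j - i, window)]
            keyed = [a + b for a, b in zip(x, rel_k)]  # key plus relative key vector
            scores.append(sum(q * k for q, k in zip(query, keyed)) / math.sqrt(dim))
        weights = softmax(scores)
        out = [0.0] * dim
        for j, (w, x) in enumerate(zip(weights, inputs)):
            rel_v = value_vecs[clip(j - i, window)]
            for d in range(dim):
                out[d] += w * (x[d] + rel_v[d])  # value plus relative value vector
        outputs.append(out)
    return outputs

rng = random.Random(0)
DIM, WINDOW = 3, 2

def rand_vec():
    return [rng.gauss(0, 1) for _ in range(DIM)]

# Random stand-ins for the learned relative vectors, one per clipped offset.
key_vecs = {o: rand_vec() for o in range(-WINDOW, WINDOW + 1)}
value_vecs = {o: rand_vec() for o in range(-WINDOW, WINDOW + 1)}

emb = {"A": rand_vec(), "B": rand_vec(), "C": rand_vec()}
abc = relative_attention([emb[t] for t in "ABC"], key_vecs, value_vecs, WINDOW)
cba = relative_attention([emb[t] for t in "CBA"], key_vecs, value_vecs, WINDOW)
print(abc[0] != cba[2])  # token A's output now depends on word order

# A much longer input still needs only 2 * WINDOW + 1 relative vectors.
long_out = relative_attention(
    [rand_vec() for _ in range(50)], key_vecs, value_vecs, WINDOW
)
print(len(long_out), len(key_vecs))
```

The word vectors going in are identical in both orders, so any sensitivity to order in the outputs comes entirely from the vectors added inside the attention calculation.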
00:11:38.360 | Let's think about our two crucial questions.
00:11:41.080 | First, we don't need to decide the set of positions ahead of time,
00:11:44.240 | we just need to decide on the window.
00:11:46.480 | Then for a potentially extremely long string,
00:11:50.000 | we're just sliding it around using
00:11:52.320 | a relatively small number of
00:11:54.520 | positional vectors to keep track of relative position.
00:11:58.280 | I think we have also largely overcome the concern that
00:12:02.200 | positional embeddings might hinder generalization to new positions.
00:12:06.420 | After all, if you consider a phrase like the rock,
00:12:10.220 | the core position vectors that are involved there are 0,
00:12:15.160 | 1, and minus 1,
00:12:16.700 | no matter where this appears in the string.
00:12:19.280 | Now, depending on where it appears,
00:12:21.240 | there will be other positional things that are
00:12:23.440 | happening and other information will be
00:12:25.280 | brought in as part of the calculation.
00:12:27.400 | But we do have this sense of constancy that will allow the model to
00:12:31.180 | see that the rock is the same
00:12:33.720 | essentially wherever it appears in the string.
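This constancy can be seen by listing the clipped offsets among the two tokens of the phrase at different places in a string. The following is a small sketch, assuming a window size of 2 and a hypothetical helper `clipped_offsets`; the start positions are arbitrary.

```python
def clipped_offsets(start, length, window=2):
    """Clipped offsets among `length` adjacent tokens beginning at `start`."""
    span = range(start, start + length)
    return [[max(-window, min(window, j - i)) for j in span] for i in span]

# "the rock" near the start of a string and far into it.
early = clipped_offsets(start=1, length=2)
late = clipped_offsets(start=900, length=2)

print(early)          # [[0, 1], [-1, 0]]
print(early == late)  # True
```

Under absolute positional encoding the two occurrences would instead be paired with entirely different position vectors.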
00:12:36.920 | My hypothesis is that because we have overcome these two crucial limitations,
00:12:42.440 | relative positional encoding is a very good bet for how to
00:12:46.160 | do positional encoding in general in the transformer.
00:12:48.880 | I believe that that is now well-supported
00:12:51.920 | by results across the field for the transformer.