
Stanford CS224N NLP with Deep Learning | 2023 | Lecture 8 - Self-Attention and Transformers


Whisper Transcript | Transcript Only Page

00:00:00.000 | Hi, everyone.
00:00:07.280 | Welcome to CS224N.
00:00:09.280 | We're about two minutes in, so let's get started.
00:00:12.880 | So today, we've got what I think is quite an exciting lecture topic.
00:00:17.240 | We're going to talk about self-attention and transformers.
00:00:21.760 | So these are some ideas that are sort of the foundation
00:00:25.320 | of most of the modern advances in natural language processing.
00:00:29.400 | And actually, AI systems in a broad range of fields.
00:00:34.640 | So it's a very, very fun topic.
00:00:37.600 | Before we get into that--
00:00:39.480 | OK, before we get into that, we're going to have a couple of reminders.
00:00:49.240 | So there are brand new lecture notes.
00:00:51.620 | [CHEERING]
00:00:53.560 | Nice, thank you.
00:00:54.640 | Yeah.
00:00:57.140 | I'm very excited about them.
00:00:59.280 | They go into-- they pretty much follow along
00:01:02.460 | with what I'll be talking about today, but go into considerably more detail.
00:01:07.700 | Assignment four is due a week from today.
00:01:11.780 | Yeah, so the issues with Azure continue.
00:01:15.380 | Thankfully-- woo!
00:01:16.540 | Thankfully, our TAs have tested that this works on Colab,
00:01:25.100 | and the amount of training is such that a Colab session will
00:01:29.040 | allow you to train your machine translation system.
00:01:33.340 | So if you don't have a GPU, use Colab.
00:01:35.180 | We're continuing to work on getting access
00:01:37.100 | to more GPUs for assignment five in the final project.
00:01:41.900 | We'll continue to update you as we're able to.
00:01:44.660 | But the usual systems aren't holding this year
00:01:49.100 | because companies are changing their minds about things.
00:01:52.140 | OK, so the final project proposal: you propose
00:01:57.540 | what you want to work on for your final project.
00:02:00.460 | We will give you feedback on whether we think it's a feasible idea
00:02:04.400 | or how to change it.
00:02:05.260 | So this is very important because we want you to work on something
00:02:07.980 | that we think has a good chance of success for the rest of the quarter.
00:02:11.160 | That's going to be out tonight.
00:02:12.460 | We'll have an Ed announcement when it is out.
00:02:15.660 | And we want to get you feedback on that pretty quickly
00:02:19.020 | because you'll be working on this after assignment five is done.
00:02:22.140 | Really, the major core component of the course after that
00:02:26.340 | is the final project.
00:02:29.380 | OK, any questions?
00:02:32.860 | Cool.
00:02:33.860 | So let's take a look back into what we've done so far in this course
00:02:41.900 | and see what we were doing in natural language processing.
00:02:47.120 | What was our strategy?
00:02:48.080 | If you had a natural language processing problem
00:02:50.080 | and you wanted to take your best effort attempt at it
00:02:53.060 | without doing anything too fancy, you would have said, OK,
00:02:56.040 | I'm going to have a bidirectional LSTM instead of a simple RNN.
00:03:01.180 | I'm going to use an LSTM to encode my sentences,
00:03:04.220 | so I get bidirectional context.
00:03:06.140 | And if I have an output that I'm trying to generate,
00:03:09.300 | I'll have a unidirectional LSTM that generates it one by one.
00:03:14.100 | So you have a translation or a parse or whatever.
00:03:17.140 | And so maybe I've encoded in a bidirectional LSTM the source sentence
00:03:20.500 | and I'm sort of one by one decoding out the target
00:03:24.260 | with my unidirectional LSTM.
00:03:26.480 | And then also, I was going to use something like attention
00:03:30.680 | to give flexible access to memory if I felt
00:03:34.800 | like I needed to do this sort of look back and see
00:03:37.080 | where I want to translate from.
00:03:39.040 | And this was just working exceptionally well.
00:03:41.960 | And we motivated attention through wanting to do machine translation.
00:03:46.200 | And you have this bottleneck where you don't
00:03:48.140 | want to have to encode the whole source sentence in a single vector.
00:03:52.840 | And in this lecture, we have the same goal.
00:03:55.000 | So we're going to be looking at a lot of the same problems
00:03:57.380 | that we did previously.
00:03:58.520 | But we're going to use different building blocks.
00:04:00.560 | We're going to say, whereas from 2014 to 2017-ish we were using recurrence,
00:04:07.480 | now, through lots of trial and error over the years,
00:04:10.400 | we have these brand new building blocks that we
00:04:12.720 | can plug in as a direct replacement for LSTMs.
00:04:17.160 | And they're going to allow for just a huge range of much more successful
00:04:22.120 | applications.
00:04:23.320 | And so what are the issues with the recurrent neural networks
00:04:28.680 | we used to use?
00:04:29.680 | And what are the new systems that we're going to use from this point moving
00:04:32.720 | forward?
00:04:35.160 | So one of the issues with a recurrent neural network
00:04:38.880 | is what we're going to call linear interaction distance.
00:04:41.680 | So as we know, RNNs are unrolled left to right or right to left,
00:04:47.200 | depending on the language and the direction.
00:04:49.840 | But it encodes the notion of linear locality, which is useful.
00:04:53.080 | Because if two words occur right next to each other,
00:04:55.600 | sometimes they're actually quite related.
00:04:57.320 | So tasty pizza.
00:04:58.640 | They're nearby.
00:04:59.680 | And in the recurrent neural network, you encode tasty.
00:05:04.360 | And then you walk one step, and you encode pizza.
00:05:08.720 | So nearby words do often affect each other's meanings.
00:05:12.600 | But you have this problem where very long distance dependencies
00:05:17.200 | can take a very long time to interact.
00:05:18.960 | So if I have the sentence, the chef--
00:05:21.400 | so those are nearby.
00:05:22.600 | Those interact with each other.
00:05:25.120 | And then who, and then a bunch of stuff.
00:05:28.680 | Like the chef who went to the stores and picked up the ingredients
00:05:32.520 | and loves garlic.
00:05:35.320 | And then was.
00:05:37.160 | Like I actually have an RNN step, this sort
00:05:40.440 | of application of the recurrent weight matrix
00:05:43.000 | and some element-wise nonlinearities once, twice, three times.
00:05:47.320 | Potentially as many times as the length of the sequence between chef
00:05:52.520 | and was.
00:05:53.760 | And it's the chef who was.
00:05:54.960 | So this is a long distance dependency.
00:05:56.840 | Should feel kind of related to the stuff that we did in dependency syntax.
00:06:01.120 | But it's quite difficult to learn potentially
00:06:06.440 | that these words should be related.
00:06:09.080 | So if you have a lot of steps between words,
00:06:19.440 | it can be difficult to learn the dependencies between them.
00:06:22.400 | We talked about all these gradient problems.
00:06:24.240 | LSTMs do a lot better at modeling the gradients across long distances
00:06:29.680 | than simple recurrent neural networks.
00:06:31.280 | But it's not perfect.
00:06:33.960 | And we already know that this linear order
00:06:36.680 | isn't sort of the right way to think about sentences.
00:06:40.520 | So if I wanted to learn that it's the chef who was,
00:06:46.920 | then I might have a hard time doing it because the gradients have
00:06:51.720 | to propagate from was to chef.
00:06:53.160 | And really, I'd like more direct connection
00:06:56.440 | between words that might be related in the sentence.
00:06:59.080 | Or in a document even, if these are going to get much longer.
00:07:04.000 | So this is this linear interaction distance problem.
00:07:06.160 | We would like words that might be related
00:07:08.400 | to be able to interact with each other in the neural networks
00:07:11.000 | computation graph more easily than being linearly far away
00:07:19.800 | so that we can learn these long distance dependencies better.
00:07:23.000 | And there's a related problem too that again comes back
00:07:25.560 | to the recurrent neural networks dependence on the index.
00:07:28.640 | On the index into the sequence, often called a dependence on time.
00:07:32.880 | So in a recurrent neural network, the forward and backward passes
00:07:36.800 | have O(sequence length) many unparallelizable operations.
00:07:39.520 | So that means, roughly, sequence length many operations
00:07:41.600 | that you can't parallelize.
00:07:45.000 | So we know GPUs are great.
00:07:47.240 | They can do a lot of operations at once,
00:07:50.520 | as long as there's no dependency between the operations in terms
00:07:53.840 | of time.
00:07:54.360 | You have to compute one and then compute the other.
00:07:57.680 | But in a recurrent neural network, you
00:07:59.800 | can't actually compute the RNN hidden state for time step 5
00:08:03.800 | before you compute the RNN hidden state for time step 4
00:08:06.920 | or time step 3.
00:08:08.560 | And so you get this graph that looks very similar,
00:08:11.320 | where if I want to compute this hidden state,
00:08:13.200 | so I've got some word, I have zero operations
00:08:16.160 | I need to do before I can compute this state.
00:08:18.600 | I have one operation I can do before I can compute this state.
00:08:22.560 | And as my sequence length grows, I've got--
00:08:25.280 | OK, here I've got three operations
00:08:27.040 | I need to do before I can compute
00:08:28.880 | the state with the number 3, because I need to compute this
00:08:32.080 | and this and that.
00:08:33.880 | So there's three unparallelizable operations
00:08:37.000 | that I'm glomming all the matrix multiplies and stuff
00:08:39.640 | into a single one.
00:08:40.880 | So 1, 2, 3.
00:08:42.480 | And of course, this grows with the sequence length as well.
00:08:45.320 | So down over here, as the sequence length grows,
00:08:48.720 | I can't parallelize--
00:08:50.520 | I can't just have a big GPU just go kachunk
00:08:53.600 | with the matrix multiply to compute this state,
00:08:56.840 | because I need to compute all the previous states beforehand.
00:08:59.600 | OK, any questions about that?
00:09:03.520 | So these are these two related problems,
00:09:06.040 | both with the dependence on time.
00:09:07.960 | Yeah.
00:09:08.800 | Yeah, so I have a question on the linear interaction issue.
00:09:11.360 | I thought that was the whole point of the attention mechanism,
00:09:13.880 | that during training you can attend
00:09:17.960 | to the cells that depend more on each other.
00:09:21.080 | Can't we do something like attention
00:09:22.840 | and then work our way around that?
00:09:26.200 | So the question is, with the linear interaction distance,
00:09:28.760 | wasn't this the point of attention
00:09:30.440 | that gets around that?
00:09:31.720 | Can't we use something with attention to help,
00:09:33.880 | or does that just help?
00:09:35.040 | So it won't solve the parallelizability problem.
00:09:37.480 | And in fact, everything we do in the rest of this lecture
00:09:39.840 | will be attention-based.
00:09:41.160 | But we'll get rid of the recurrence
00:09:42.620 | and just do attention, more or less.
00:09:44.360 | So well, yeah, it's a great intuition.
00:09:48.720 | Any other questions?
00:09:49.920 | OK, cool.
00:09:54.480 | So if not recurrence, what about attention?
00:09:57.920 | See, I'm just a slide back.
00:10:00.040 | And so we're going to get deep into attention today.
00:10:04.440 | But just for the second, attention
00:10:06.520 | treats each word's representation
00:10:08.160 | as a query to access and incorporate information
00:10:11.480 | from a set of values.
00:10:12.840 | So previously, we were in a decoder.
00:10:14.880 | We were decoding out a translation of a sentence.
00:10:17.360 | And we attended to the encoder so
00:10:19.280 | that we didn't have to store the entire representation
00:10:21.520 | of the source sentence into a single vector.
00:10:24.020 | And here, today, we'll think about attention
00:10:26.120 | within a single sentence.
00:10:27.520 | So I've got this sentence written out here
00:10:29.840 | with a word 1 through word t, in this case.
00:10:32.680 | And right on these integers in the boxes,
00:10:35.960 | I'm writing out the number of unparallelizable operations
00:10:38.840 | that you need to do before you can compute these.
00:10:41.880 | So for each word, you can independently
00:10:43.600 | compute its embedding without doing anything else previously,
00:10:46.920 | because the embedding just depends on the word identity.
00:10:50.320 | And then with attention, if I wanted
00:10:53.460 | to build an attention representation of this word
00:10:55.580 | by looking at all the other words in the sequence,
00:10:57.780 | that's one big operation.
00:10:59.740 | And I can do them in parallel for all the words.
00:11:02.460 | So the attention for this word, I
00:11:04.460 | can do for the attention for this word.
00:11:06.100 | I don't need to walk left to right like I did for an RNN.
00:11:09.100 | Again, we'll get much deeper into this.
00:11:10.940 | But you should have the intuition
00:11:13.940 | that it solves the linear interaction
00:11:16.220 | problem and the non-parallelizability problem.
00:11:18.860 | Because now, no matter how far away words are from each other,
00:11:22.120 | I am potentially interacting.
00:11:23.980 | I might just attend to you, even if you're very, very far away,
00:11:27.600 | sort of independent of how far away you are.
00:11:29.840 | And I also don't need to sort of walk along the sequence
00:11:33.120 | linearly.
00:11:34.320 | So I'm treating the whole sequence at once.
00:11:36.840 | All right.
00:11:38.320 | So the intuition is that attention
00:11:40.520 | allows you to look very far away at once.
00:11:42.280 | And it doesn't have this dependence on the sequence
00:11:44.440 | index that keeps us from parallelizing operations.
00:11:47.120 | And so now, the rest of the lecture
00:11:48.620 | will talk in great depth about attention.
00:11:51.780 | So maybe let's just move on.
00:11:56.300 | So let's think more deeply about attention.
00:12:00.180 | One thing that you might think of with attention
00:12:02.660 | is that it's sort of performing kind of a fuzzy lookup
00:12:05.540 | in a key value store.
00:12:07.180 | So you have a bunch of keys, a bunch of values,
00:12:09.540 | and it's going to help you sort of access that.
00:12:12.140 | So in an actual lookup table, just
00:12:14.300 | like a dictionary in Python, for example, very simple.
00:12:18.220 | You have a table of keys that each key maps to a value.
00:12:22.100 | And then you give it a query.
00:12:23.500 | And the query matches one of the keys.
00:12:26.300 | And then you return the value.
00:12:27.940 | So I've got a bunch of keys here.
00:12:31.420 | And my query matches the key.
00:12:33.220 | So I return the value.
00:12:34.660 | Simple, fair, easy.
00:12:37.740 | Good.
00:12:39.500 | And in attention, so just like we saw before,
00:12:44.060 | the query matches all keys softly.
00:12:46.660 | There's no exact match.
00:12:48.940 | You sort of compute some sort of similarity
00:12:50.780 | between the key and all of the--
00:12:52.660 | sorry, the query and all of the keys.
00:12:54.620 | And then you sort of weight the results.
00:12:56.260 | So you've got a query again.
00:12:57.780 | You've got a bunch of keys.
00:13:00.020 | The query, to different extents, is similar to each of the keys.
00:13:04.500 | And you will sort of measure that similarity between 0 and 1
00:13:08.460 | through a softmax.
00:13:10.140 | And then you get the values out.
00:13:12.380 | So you average them via those weights, the similarity
00:13:15.660 | between the query and the keys.
00:13:18.780 | You do a weighted sum with those weights.
00:13:20.580 | And you get an output.
00:13:21.660 | So it really is quite a bit like a lookup table,
00:13:24.780 | but in this sort of soft vector space, mushy sort of sense.
00:13:29.940 | So I'm really doing some kind of accessing
00:13:32.140 | into this information that's stored in the key value store.
00:13:35.820 | But I'm sort of softly looking at all of the results.
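As a rough sketch of that contrast in code (PyTorch here; the toy keys, values, and shapes are purely illustrative, not anything from the lecture):

```python
import torch

# A hard lookup: the query has to match a key exactly.
table = {"tasty": 1.0, "pizza": 2.0}
value = table["pizza"]                  # returns exactly one value

# A soft lookup: compare the query to every key, then take a
# softmax-weighted average of all the values.
query = torch.randn(4)                  # query vector, dimensionality 4
keys = torch.randn(3, 4)                # three keys
values = torch.randn(3, 4)              # three values, one per key

scores = keys @ query                   # similarity of the query to each key
weights = torch.softmax(scores, dim=0)  # weights between 0 and 1 that sum to 1
output = weights @ values               # weighted average of the values
```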
00:13:41.220 | OK, any questions there?
00:13:42.300 | Cool.
00:13:46.940 | So what might this look like?
00:13:48.620 | So if I was trying to represent this sentence,
00:13:50.980 | I went to Stanford CS224n and learned.
00:13:54.260 | So I'm trying to build a representation of learned.
00:13:56.260 | I have a key for each word.
00:14:01.580 | So this is this self-attention thing that we'll get into.
00:14:04.500 | I have a key for each word, a value for each word.
00:14:06.740 | I've got the query for learned.
00:14:08.380 | And I've got these sort of tealish bars up top,
00:14:11.620 | which sort of might say how much you're
00:14:13.620 | going to try to access each of the words.
00:14:15.500 | Like, oh, maybe 224n is not that important.
00:14:18.300 | CS, maybe that determines what I learned.
00:14:20.500 | You know, Stanford.
00:14:22.700 | And then learned, maybe that's important to representing
00:14:25.180 | itself.
00:14:25.900 | So you sort of look across at the whole sentence
00:14:28.100 | and build up this sort of soft accessing of information
00:14:31.020 | across the sentence in order to represent learned in context.
00:14:35.860 | So this is just a toy diagram.
00:14:38.860 | So let's get into the math.
00:14:40.460 | So we're going to look at a sequence of words.
00:14:43.540 | So that's w1 to n, a sequence of words in a vocabulary.
00:14:46.860 | So this is like, you know, Zuko made his uncle tea.
00:14:49.140 | That's a good sequence.
00:14:50.340 | And for each word, we're going to embed it
00:14:52.580 | with this embedding matrix, just like we've
00:14:54.620 | been doing in this class.
00:14:56.180 | So I have this embedding matrix that
00:14:57.780 | goes from the vocabulary size to the dimensionality d.
00:15:02.460 | So each word has a non-contextual,
00:15:04.660 | only dependent on itself, word embedding.
00:15:07.500 | And now I'm going to transform each word with one of three
00:15:11.340 | different weight matrices.
00:15:12.500 | So this is often called key query value self-attention.
00:16:16.820 | So I have a matrix Q, which is in Rd by d.
00:15:19.980 | So this maps xi, which is a vector of dimensionality d,
00:15:23.420 | to another vector of dimensionality d.
00:15:25.780 | And that's going to be a query vector.
00:15:28.500 | So it takes an xi and it sort of rotates it,
00:15:31.260 | shuffles it around, stretches it, squishes it.
00:15:33.940 | Makes it different.
00:15:35.060 | And now it's a query.
00:15:35.940 | And now for a different learnable parameter, k--
00:15:38.300 | so that's another matrix.
00:15:39.500 | I'm going to come up with my keys.
00:15:41.940 | And with a different learnable parameter, v,
00:15:45.220 | I'm going to come up with my values.
00:15:47.100 | So I'm taking each of the non-contextual word embeddings,
00:15:49.640 | each of these xi's, and I'm transforming each of them
00:15:53.300 | to come up with my query for that word, my key for that word,
00:15:57.140 | and my value for that word.
00:16:00.220 | So every word is doing each of these roles.
00:16:03.700 | Next, I'm going to compute all pairs of similarities
00:16:06.660 | between the keys and queries.
00:16:08.220 | So in the toy example we saw, I was
00:16:10.500 | computing the similarity between a single query for the word
00:16:13.260 | learned and all of the keys for the entire sentence.
00:16:17.380 | In this context, I'm computing all pairs of similarities
00:17:20.380 | between all keys and all queries because I want to represent
00:17:24.300 | all of these words.
00:16:25.140 | So I've got this sort of dot--
00:16:27.620 | I'm just going to take the dot product between these two
00:16:29.780 | vectors.
00:16:30.420 | So I've got qi.
00:16:31.820 | So this is saying the query for word i
00:16:34.140 | dotted with the key for word j.
00:16:36.100 | And I get this score, which is a real value.
00:16:40.980 | Might be very large negative, might be zero,
00:16:42.820 | might be very large and positive.
00:16:44.660 | And so that's like, how much should I
00:16:46.580 | look at j in this lookup table?
00:16:50.140 | And then I do the softmax.
00:16:51.340 | So I softmax.
00:16:52.580 | So I say that the actual weight that I'm
00:16:55.140 | going to look at j from i is softmax of this
00:16:58.660 | over all of the possible indices.
00:17:00.900 | So it's like the affinity between i and j
00:17:03.780 | normalized by the affinity between i
00:17:06.020 | and all of the possible j prime in the sequence.
00:17:08.220 | And then my output is just the weighted sum of values.
00:17:13.940 | So I've got this output for word i.
00:17:16.060 | So maybe i is like 1 for Zuko.
00:17:18.420 | And I'm representing it as the sum of these weights
00:17:22.140 | for all j.
00:17:23.140 | So Zuko and made and his and uncle and tea.
00:17:26.180 | And the value vector for that word j.
00:17:30.140 | I'm looking from i to j as much as alpha ij.
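Here is a minimal sketch of that key-query-value computation for a single word i, following the equations just described (the shapes and variable names are assumptions for illustration):

```python
import torch

n, d = 5, 8                       # sequence length, model dimensionality
x = torch.randn(n, d)             # non-contextual word embeddings x_1 ... x_n

Q = torch.randn(d, d)             # query matrix
K = torch.randn(d, d)             # key matrix
V = torch.randn(d, d)             # value matrix

queries = x @ Q                   # one query vector per word (row i is q_i)
keys = x @ K                      # one key vector per word
values = x @ V                    # one value vector per word

i = 0                             # representing word i (say, "Zuko")
e = keys @ queries[i]             # e_ij = q_i . k_j for every j
alpha = torch.softmax(e, dim=0)   # alpha_ij, normalized over all j
output_i = alpha @ values         # weighted sum of the value vectors
```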
00:17:34.940 | What's the dimension of Wi?
00:17:37.380 | [INAUDIBLE]
00:17:39.380 | Oh, Wi, you can either think of it as a symbol in vocab v.
00:17:44.900 | So that's like, you could think of it as a one-hot vector.
00:17:47.660 | And yeah, in this case, we are, I guess, thinking of it as--
00:17:51.220 | so one-hot vector in dimensionality size of vocab.
00:17:54.380 | So in the matrix E, you see that it's Rd by |V|, bars around V.
00:17:59.660 | That's size of the vocabulary.
00:18:01.700 | So when I do E multiplied by Wi, that's
00:18:05.100 | taking E, which is d by v, multiplying it by w, which is v,
00:18:10.540 | and returning a vector that's dimensionality d.
00:18:13.100 | So w in that first line, like w1n,
00:18:16.940 | that's a matrix where it has maybe
00:18:20.700 | like a column for every word in that sentence.
00:18:23.500 | And each column is a length v.
00:18:25.820 | Yeah, usually, I guess we think of it as having a--
00:18:28.460 | I mean, if I'm putting the sequence length index first,
00:18:31.780 | you might think of it as having a row for each word.
00:18:33.980 | But similarly, yeah, it's n, which is the sequence length.
00:18:37.020 | And then the second dimension would be v,
00:18:39.020 | which is the vocabulary size.
00:18:40.740 | And then that gets mapped to this thing, which
00:18:42.740 | is sequence length by d.
00:18:46.380 | Why do we learn two different matrices, q and k,
00:18:49.460 | when q transpose--
00:18:51.420 | qi transpose kj is really just one matrix in the middle?
00:18:56.100 | That's a great question.
00:18:57.100 | It ends up being because this will end up
00:18:59.500 | being a low-rank approximation to that matrix.
00:19:02.060 | So it is for computational efficiency reasons.
00:19:05.500 | Although it also, I think, feels kind
00:19:07.660 | of nice in the presentation.
00:19:09.940 | But yeah, what we'll end up doing
00:19:11.300 | is having a very low-rank approximation to qk transpose.
00:19:14.860 | And so you actually do do it like this.
00:19:17.380 | It's a good question.
00:19:19.780 | Is e_ii, so the query with its own key, anything specific?
00:19:26.140 | Sorry, could you repeat that for me?
00:19:27.620 | This eii, so the query of the word dotted with the key
00:19:32.620 | by itself, does it look like an identity,
00:19:34.980 | or does it look like anything in particular?
00:19:37.420 | That's a good question.
00:19:38.340 | OK, let me remember to repeat questions.
00:19:40.500 | So does eii, for j equal to i, so looking at itself,
00:19:44.660 | look like anything in particular?
00:19:46.080 | Does it look like the identity?
00:19:47.660 | Is that the question?
00:19:48.820 | OK, so right, it's unclear, actually.
00:19:53.020 | This question of should you look at yourself
00:19:54.940 | for representing yourself, well, it's
00:19:57.020 | going to be encoded by the matrices q and k.
00:20:00.780 | If I didn't have q and k in there,
00:20:02.940 | if those were the identity matrices,
00:20:04.940 | if q is identity, k is identity, then this
00:20:07.460 | would be sort of dot product with yourself, which is going
00:20:09.860 | to be high on average, like you're pointing
00:20:12.180 | in the same direction as yourself.
00:20:13.860 | But it could be that qxi and kxi might
00:20:18.460 | be sort of arbitrarily different from each other,
00:20:21.060 | because q could be the identity, and k
00:20:24.140 | could map you to the negative of yourself,
00:20:27.060 | for example, so that you don't look at yourself.
00:20:29.060 | So this is all learned in practice.
00:20:30.940 | So you end up--
00:20:32.380 | it can sort of decide by learning
00:20:35.860 | whether you should be looking at yourself or not.
00:20:38.180 | And that's some of the flexibility
00:20:39.580 | that parametrizing it as q and k gives you
00:20:42.820 | that wouldn't be there if I just used xis everywhere
00:20:46.180 | in this equation.
00:20:49.820 | I'm going to try to move on, I'm afraid,
00:20:51.860 | because there's a lot to get on.
00:20:53.340 | But we'll keep talking about self-attention.
00:20:55.540 | And so as more questions come up,
00:20:57.660 | I can also potentially return back.
00:21:01.460 | OK, so this is our basic building block.
00:21:05.140 | But there are a bunch of barriers
00:21:06.860 | to using it as a replacement for LSTMs.
00:21:10.260 | And so what we're going to do for this portion of the lecture
00:21:13.020 | is talk about the minimal components
00:21:14.740 | that we need in order to use self-attention as sort
00:21:18.180 | of this very fundamental building block.
00:21:21.540 | So we can't use it as it stands as I've presented it,
00:21:25.260 | because there are a couple of things
00:21:26.740 | that we need to sort of solve or fix.
00:21:29.340 | One of them is that there's no notion of sequence order
00:21:32.180 | in self-attention.
00:21:33.500 | So what does this mean?
00:21:37.580 | If I have a sentence like--
00:21:40.140 | I'm going to move over here to the whiteboard briefly,
00:21:42.380 | and hopefully I'll write quite large.
00:21:46.820 | If I have a sentence like, Zuko made his uncle.
00:21:55.700 | And let's say, his uncle made Zuko.
00:22:05.780 | If I were to embed each of these words
00:22:08.700 | using its embedding matrix, the embedding matrix
00:22:10.700 | isn't dependent on the index of the word.
00:22:14.980 | So this is the word index 1, 2, 3, 4,
00:22:18.060 | versus now his is over here, and uncle.
00:22:21.900 | And so when I compute the self-attention--
00:22:23.900 | and there's a lot more on this in the lecture notes that
00:22:26.240 | goes through a full example--
00:22:29.860 | the actual self-attention operation
00:22:32.260 | will give you exactly the same representations
00:22:34.380 | for this sequence, Zuko made his uncle, as for this sequence,
00:22:38.060 | his uncle made Zuko.
00:22:40.020 | And that's bad, because they're sentences
00:22:41.680 | that mean different things.
00:22:43.900 | And so it's this idea that self-attention
00:22:46.820 | is an operation on sets.
00:22:48.580 | You have a set of vectors that you're
00:22:50.780 | going to perform self-attention on,
00:22:52.660 | and nowhere does the exact position of the words
00:22:55.740 | come into play directly.
00:22:59.140 | So we're going to encode the position of words
00:23:02.540 | through the keys, queries, and values that we have.
00:23:06.100 | So consider now representing each sequence index--
00:23:10.060 | our sequences are going from 1 to n--
00:23:12.100 | as a vector.
00:23:13.380 | So don't worry so far about how it's being made,
00:23:17.220 | but you can imagine representing the number 1,
00:23:20.300 | the position 1, the position 2, the position 3,
00:23:23.540 | as a vector in the dimensionality d,
00:23:25.580 | just like we're representing our keys, queries, and values.
00:23:29.260 | And so these are position vectors.
00:23:33.060 | If you were to want to incorporate the information
00:23:37.940 | represented by these positions into our self-attention,
00:23:42.380 | you could just add these vectors, these p i
00:23:45.140 | vectors, to the inputs.
00:23:48.020 | So if I have this xi embedding of a word, which
00:23:53.340 | is the word at position i, but really just represents,
00:23:56.060 | oh, the word zuko is here, now I can say, oh, it's the word
00:23:59.100 | zuko, and it's at position 5, because this vector represents
00:24:03.860 | position 5.
00:24:04.660 | So how do we do this?
00:24:11.260 | And we might only have to do this once.
00:24:12.860 | So we can do it once at the very input to the network,
00:24:16.420 | and then that is sufficient.
00:24:18.260 | We don't have to do it at every layer,
00:24:19.840 | because it knows from the input.
00:24:23.660 | So one way in which people have done this
00:24:26.060 | is look at these sinusoidal position representations.
00:24:29.500 | So this looks a little bit like this, where you have--
00:24:32.100 | so this is a vector p i, which is in dimensionality d.
00:24:35.980 | And each one of the dimensions, you take the value i,
00:24:40.500 | you modify it by some constant, and you pass it
00:24:45.740 | to the sine or cosine function, and you
00:24:47.780 | get these sort of values that vary
00:24:49.900 | according to the period, differing periods depending
00:24:53.260 | on the dimensionality's d.
00:24:54.340 | So I've got this sort of a representation of a matrix,
00:24:57.500 | where d is the vertical dimension,
00:24:59.580 | and then n is the horizontal.
00:25:01.500 | And you can see that there's sort of like, oh, as I walk
00:25:05.580 | along, you see the period of the sine function going up and down,
00:25:08.300 | and each of the dimensions d has a different period.
00:25:11.220 | And so together, you can represent a bunch of different
00:25:13.620 | sort of position indices.
00:25:15.140 | And it gives this intuition that, oh, maybe sort
00:25:20.020 | of the absolute position of a word isn't as important.
00:25:22.780 | You've got the sort of periodicity
00:25:24.220 | of the sines and cosines.
00:25:26.220 | And maybe that allows you to extrapolate to longer sequences.
00:25:29.500 | But in practice, that doesn't work.
00:25:32.140 | But this is sort of like an early notion
00:25:34.220 | that is still sometimes used for how to represent position
00:25:37.140 | in transformers and self-attention networks
00:25:40.700 | in general.
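A sketch of how those sinusoidal position vectors can be constructed (the 10000 constant is the choice from the original Transformer paper; treat the details as illustrative):

```python
import torch

def sinusoidal_positions(n, d):
    """Build an (n, d) matrix whose row i is the position vector p_i (d assumed even)."""
    positions = torch.arange(n, dtype=torch.float).unsqueeze(1)   # (n, 1)
    dims = torch.arange(0, d, 2, dtype=torch.float)               # even dimensions
    freqs = 1.0 / (10000 ** (dims / d))                           # a different period per dimension
    p = torch.zeros(n, d)
    p[:, 0::2] = torch.sin(positions * freqs)                     # sines on even dims
    p[:, 1::2] = torch.cos(positions * freqs)                     # cosines on odd dims
    return p

# The position vectors are added to the word embeddings once, at the input.
x = torch.randn(10, 16)
x = x + sinusoidal_positions(10, 16)
```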
00:25:43.260 | So that's one idea.
00:25:45.180 | You might think it's a little bit complicated,
00:25:48.580 | a little bit unintuitive.
00:25:50.300 | Here's something that feels a little bit more deep learning.
00:25:54.380 | So we're just going to say, oh, I've
00:25:57.460 | got a maximum sequence length of n.
00:25:59.860 | And I'm just going to learn a matrix that's
00:26:01.940 | dimensionality d by n.
00:26:03.700 | And that's going to represent my positions.
00:26:05.500 | And I'm going to learn it as a parameter, just like I
00:26:07.780 | learn every other parameter.
00:26:09.100 | And what do they mean?
00:26:10.020 | Oh, I have no idea.
00:26:10.860 | But it represents position.
00:26:13.100 | So you just sort of add this matrix to the xi's,
00:26:19.420 | your input embeddings.
00:26:22.140 | And it learns to fit to data.
00:26:24.180 | So whatever representation of position
00:26:26.300 | that's linear, sort of index-based that you want,
00:26:30.300 | you can learn.
00:26:31.460 | And the cons are that, well, you definitely now
00:26:33.900 | can't represent anything that's longer than n words long, right?
00:26:37.980 | No sequence longer than n you can handle because, well,
00:26:41.660 | you only learned a matrix of this many positions.
00:26:44.420 | And so in practice, you'll get a model error
00:26:47.660 | if you pass a self-attention model, something longer
00:26:50.500 | than length n.
00:26:51.620 | It will just sort of crash and say, I can't do this.
00:26:56.620 | And so this is sort of what most systems nowadays use.
00:26:59.660 | There are more flexible representations of position,
00:27:02.220 | including a couple in the lecture notes.
00:27:04.940 | You might want to look at the relative linear position,
00:27:08.020 | or words before or after each other,
00:27:09.900 | but not their absolute position.
00:27:11.500 | There's also some sort of representations
00:27:13.300 | that harken back to our dependency syntax.
00:27:16.660 | Because, oh, maybe words that are close in the dependency
00:27:19.100 | parse tree should be the things that are sort of close
00:27:21.620 | in the self-attention operation.
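Going back to the learned-absolute-position option, a minimal sketch of what that looks like (the parameter name and sizes are assumed):

```python
import torch
import torch.nn as nn

max_len, d = 512, 8                          # maximum sequence length n, dimensionality d
pos = nn.Parameter(torch.zeros(max_len, d))  # learned position matrix, trained like any other parameter

x = torch.randn(10, d)                       # embeddings for a 10-word input
x = x + pos[:10]                             # add one position vector per index;
                                             # inputs longer than max_len simply can't be handled
```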
00:27:25.060 | OK, questions?
00:27:28.340 | In practice, do we typically just make n large enough
00:27:32.420 | that we don't run into the issue of having something
00:27:36.500 | that can be input longer than n?
00:27:39.060 | So the question is, in practice, do we just make n long enough
00:27:41.900 | that we don't run into the problem where we're going
00:27:44.100 | to look at a text longer than n?
00:27:46.660 | No, in practice, it's actually quite a problem.
00:27:49.420 | Even today, even in the largest, biggest language models,
00:27:52.540 | and can I fit this prompt into chat GPT or whatever?
00:27:58.020 | It's a thing that you might see on Twitter.
00:27:59.820 | These continue to be issues.
00:28:01.980 | And part of it is because the self-attention operation--
00:28:04.980 | and we'll get into this later in the lecture--
00:28:06.900 | it's quadratic complexity in the sequence length.
00:28:10.060 | So you're going to spend n squared memory budget in order
00:28:14.420 | to make sequence lengths longer.
00:28:15.740 | So in practice, this might be on a large model, say, 4,000 or so.
00:28:21.420 | n is 4,000, so you can fit 4,000 words, which feels like a lot,
00:28:24.620 | but it's not going to fit a novel.
00:28:26.220 | It's not going to fit a Wikipedia page.
00:28:29.740 | And there are models that do longer sequences, for sure.
00:28:33.740 | And again, we'll talk a bit about it,
00:28:35.280 | but no, this actually is an issue.
00:28:36.700 | How do you know that the p you learned
00:28:43.140 | is the position, which is not any other?
00:28:47.580 | I don't.
00:28:48.220 | It's yours.
00:28:49.100 | Yeah.
00:28:49.700 | So how do you know that the p that you've learned,
00:28:51.580 | this matrix that you've learned, is representing position
00:28:53.960 | as opposed to anything else?
00:28:55.540 | And the reason is the only thing it correlates is position.
00:28:58.700 | So when I see these vectors, I'm adding this p matrix
00:29:02.180 | to my x matrix, the word embeddings.
00:29:05.460 | I'm adding them together.
00:29:06.720 | And the words that show up at each index
00:29:08.420 | will vary depending on what word actually
00:29:10.900 | showed up there in the example.
00:29:12.380 | But the p matrix never differs.
00:29:13.820 | It's always exactly the same at every index.
00:29:16.300 | And so it's the only thing in the data
00:29:18.300 | that it correlates with.
00:29:19.260 | So you're learning it implicitly.
00:29:21.260 | This vector at index 1 is always at index 1 for every example,
00:29:24.540 | for every gradient update.
00:29:26.300 | And nothing else co-occurs like that.
00:29:31.900 | Yeah.
00:29:32.380 | So what do you end up learning?
00:29:33.820 | I don't know.
00:29:34.340 | It's unclear.
00:29:34.900 | But it definitely allows you to know, oh, this word
00:29:37.820 | is with this index.
00:29:39.020 | Yeah.
00:29:42.200 | Yeah.
00:29:42.700 | Just quickly, when you say quadratic in the sequence length,
00:29:47.100 | what is a sequence right now defined as?
00:29:49.580 | Is it a sequence of words?
00:29:51.180 | I'm trying to figure out what unit it's using.
00:29:57.800 | So the question is, when this is quadratic in the sequence,
00:30:00.300 | is that a sequence of words?
00:30:01.420 | Yeah.
00:30:01.920 | Think of it as a sequence of words.
00:30:03.700 | Sometimes there'll be pieces that
00:30:05.120 | are smaller than words, which we'll
00:30:06.620 | go into in the next lecture.
00:30:08.180 | But yeah.
00:30:08.700 | Think of this as a sequence of words,
00:30:10.220 | but not necessarily just for a sentence,
00:30:12.100 | maybe for an entire paragraph, or an entire document,
00:30:15.380 | or something like that.
00:30:17.480 | But the attention is word to word?
00:30:19.820 | Yeah, the attention is from words to words.
00:30:24.260 | Cool.
00:30:25.060 | I'm going to move on.
00:30:28.700 | Right.
00:30:29.200 | So we have another problem.
00:30:30.540 | Another is that, based on the presentation of self-attention
00:30:34.060 | that we've done, there are really no nonlinearities
00:30:36.900 | for deep learning magic.
00:30:38.940 | We're just computing weighted averages of stuff.
00:30:43.420 | So if I apply self-attention, and then apply self-attention
00:30:47.820 | again, and then again, and again, and again,
00:30:50.660 | you should look at the next lecture notes
00:30:52.820 | if you're interested in this.
00:30:53.540 | It's actually quite cool.
00:30:54.580 | But what you end up doing is you're just
00:30:56.280 | re-averaging value vectors together.
00:30:58.240 | So you're computing averages of value vectors,
00:31:00.940 | and it ends up looking like one big self-attention.
00:31:03.700 | But there's an easy fix to this if you
00:31:05.400 | want the traditional deep learning magic.
00:31:07.980 | And you can just add a feed-forward network
00:31:10.180 | to post-process each output vector.
00:31:11.940 | So I've got a word here.
00:31:13.500 | That's the output of self-attention.
00:31:15.460 | And I'm going to pass it through--
00:31:17.460 | in this case, I'm calling it a multilayer perceptron MLP.
00:31:20.380 | So this is a vector in Rd that's going to be--
00:31:23.820 | and it's taking in as input a vector in Rd.
00:31:26.420 | And you do the usual multilayer perceptron thing,
00:31:30.020 | where you have the output, and you multiply it by a matrix,
00:31:32.460 | pass it through a nonlinearity, multiply it by another matrix.
00:31:36.020 | And so what this looks like in self-attention
00:31:38.300 | is that I've got this sentence, the chef who--
00:31:40.580 | da, da, da, da, da-- food.
00:31:42.140 | And I've got my embeddings for it.
00:31:44.040 | I pass it through this whole big self-attention block, which
00:31:46.780 | looks at the whole sequence and incorporates
00:31:49.140 | context and all that.
00:31:50.540 | And then I pass each one individually
00:31:52.660 | through a feed-forward layer.
00:31:55.300 | So this embedding, that's the output of the self-attention
00:31:58.420 | for the word "the," is passed independently
00:32:01.060 | through a multilayer perceptron here.
00:32:03.540 | And you can think of it as combining together or processing
00:32:09.540 | the result of attention.
00:32:11.600 | So there's a number of reasons why we do this.
00:32:14.360 | One of them also is that you can actually
00:32:16.020 | stack a ton of computation into these feed-forward networks
00:32:20.120 | very, very efficiently, very parallelizable,
00:32:22.600 | very good for GPUs.
00:32:23.880 | But this is what's done in practice.
00:32:25.600 | So you do self-attention, and then you
00:32:27.280 | can pass it through this position-wise feed-forward
00:32:31.320 | layer.
00:32:31.760 | Every word is processed independently
00:32:34.000 | by this feed-forward network to process the result.
00:32:40.360 | So that's adding our classical deep learning nonlinearities
00:32:43.740 | for self-attention.
00:32:45.980 | And that's an easy fix for this no nonlinearities
00:32:49.260 | problem in self-attention.
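A minimal sketch of that position-wise feed-forward step, applied independently to every word's self-attention output (the 4x hidden size is a common convention, assumed here rather than stated in the lecture):

```python
import torch
import torch.nn as nn

d = 8
ffn = nn.Sequential(
    nn.Linear(d, 4 * d),   # first matrix (a wider hidden size like 4*d is a common choice)
    nn.ReLU(),             # the nonlinearity
    nn.Linear(4 * d, d),   # second matrix, mapping back down to d
)

attn_out = torch.randn(5, d)   # self-attention outputs for a 5-word sequence
h = ffn(attn_out)              # the same MLP is applied to each position independently
```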
00:32:50.920 | And then we have a last issue before we
00:32:52.860 | have our final minimal self-attention building block
00:32:56.300 | with which we can replace RNNs.
00:32:59.740 | And that's that-- well, when I've
00:33:02.100 | been writing out all of these examples of self-attention,
00:33:04.940 | you can look at the entire sequence.
00:33:07.700 | And in practice, for some tasks, such as machine translation
00:33:12.400 | or language modeling, whenever you
00:33:13.940 | want to define a probability distribution over a sequence,
00:33:16.700 | you can't cheat and look at the future.
00:33:18.820 | So at every time step, I could define the set
00:33:24.860 | of keys and queries and values to only include past words.
00:33:29.260 | But this is inefficient.
00:33:31.140 | Bear with me.
00:33:31.780 | It's inefficient because you can't parallelize it so well.
00:33:34.660 | So instead, we compute the entire n by n matrix,
00:33:37.940 | just like I showed in the slide discussing self-attention.
00:33:41.220 | And then I mask out words in the future.
00:33:42.940 | So for this score, eij--
00:33:45.500 | and I computed eij for all n by n pairs of words--
00:33:49.900 | is equal to whatever it was before if the word that you're
00:33:54.820 | looking at, index j, is an index that is less than or equal to
00:33:59.380 | where you are, index i.
00:34:01.740 | And it's equal to negative infinity-ish otherwise,
00:34:04.820 | if it's in the future.
00:34:06.220 | And when you softmax the eij, negative infinity
00:34:08.640 | gets mapped to 0.
00:34:11.220 | So now my attention is weighted 0.
00:34:13.980 | My weighted average is 0 on the future.
00:34:16.500 | So I can't look at it.
00:34:18.740 | What does this look like?
00:34:20.020 | So in order to encode these words, the chef who--
00:34:24.140 | maybe the start symbol there--
00:34:27.940 | I can look at these words.
00:34:29.940 | That's all pairs of words.
00:34:31.620 | And then I just gray out--
00:34:32.700 | I negative infinity out the words I can't look at.
00:34:35.860 | So when encoding the start symbol,
00:34:37.300 | I can just look at the start symbol.
00:34:39.260 | When encoding the, I can look at the start symbol and the.
00:34:43.220 | When encoding chef, I can look at start the chef.
00:34:46.220 | But I can't look at who.
00:34:48.960 | And so with this representation of chef that is only looking
00:34:53.780 | at start the chef, I can define a probability distribution
00:34:57.940 | using this vector that allows me to predict who
00:35:01.100 | without having cheated by already looking ahead
00:35:03.180 | and seeing that, well, who is the next word.
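A sketch of that masking trick on the full n-by-n score matrix (using a float negative infinity to stand in for the "negative infinity-ish" value):

```python
import torch

n = 4
scores = torch.randn(n, n)          # e_ij for all pairs, computed in one go

# Position i may only look at positions j <= i, so mask out the strictly
# upper-triangular part (the future) with negative infinity.
future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(future, float("-inf"))

weights = torch.softmax(scores, dim=-1)   # the future now gets weight exactly 0
```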
00:35:05.660 | Questions?
00:35:11.900 | So it says for using it in decoders.
00:35:15.020 | Do we do this for both the encoding layer
00:35:17.020 | and the decoding layer?
00:35:18.140 | Or for the encoding layer, are we
00:35:19.700 | allowing ourselves to look for--
00:35:21.700 | The question is, it says here that we're
00:35:23.580 | using this in a decoder.
00:35:24.540 | Do we also use it in the encoder?
00:35:26.700 | So this is the distinction between a bidirectional LSTM
00:35:31.060 | and a unidirectional LSTM.
00:35:33.140 | So wherever you don't need this constraint,
00:35:37.100 | you probably don't use it.
00:35:38.180 | So if you're using an encoder on the source
00:35:40.180 | sentence of your machine translation problem,
00:35:42.460 | you probably don't do this masking
00:35:44.380 | because it's probably good to let everything
00:35:46.220 | look at each other.
00:35:47.140 | And then whenever you do need to use it
00:35:48.800 | because you have this autoregressive probability
00:35:51.620 | of word one, probability of two given one, three given two
00:35:54.980 | and one, then you would use this.
00:35:56.340 | So traditionally, yes, in decoders, you will use it.
00:35:58.740 | In encoders, you will not.
00:36:04.380 | My question is a little bit philosophical.
00:36:07.500 | Don't humans actually generate sentences
00:36:10.780 | by having some notion of the probability of future words
00:36:14.980 | before they say the words that--
00:36:19.020 | or before they choose the words that they are currently
00:36:24.620 | speaking or writing, generating?
00:36:26.940 | Good question.
00:36:27.520 | So the question is, isn't looking ahead a little bit
00:36:30.460 | and predicting or getting an idea of the words
00:36:32.900 | that you might say in the future sort of how humans generate
00:36:35.460 | language instead of the strict constraint of not
00:36:38.660 | seeing it into the future?
00:36:39.940 | Is that what you're--
00:36:41.580 | So right.
00:36:43.180 | Trying to plan ahead to see what I should do
00:36:46.420 | is definitely an interesting idea.
00:36:48.820 | But when I am training the network,
00:36:51.180 | I can't-- if I'm teaching it to try to predict the next word,
00:36:55.460 | and if I give it the answer, it's
00:36:57.100 | not going to learn anything useful.
00:36:59.820 | So in practice, when I'm generating text,
00:37:01.680 | maybe it would be a good idea to make some guesses
00:37:03.980 | far into the future or have a high-level plan or something.
00:37:07.700 | But in training the network, I can't encode that intuition
00:37:11.180 | about how humans build--
00:37:13.580 | like, generate sequences of language
00:37:15.260 | by just giving it the answer of the future
00:37:17.020 | directly, at least, because then it's just too easy.
00:37:19.820 | There's nothing to learn.
00:37:21.900 | Yeah.
00:37:22.380 | But there might be interesting ideas about maybe giving
00:37:24.460 | the network a hint as to what kind of thing
00:37:26.540 | could come next, for example.
00:37:28.220 | But that's out of scope for this.
00:37:29.660 | Yeah.
00:37:31.180 | Yeah, question over here.
00:37:32.220 | So I understand why we want to mask the future for stuff
00:37:35.820 | like language models, but how does it
00:37:37.460 | apply to machine translation?
00:37:39.260 | Like, why would we use it there?
00:37:40.500 | Yeah.
00:37:41.000 | So in machine translation--
00:37:43.500 | I'm going to come over to this board
00:37:46.020 | and hopefully get a better marker.
00:37:49.380 | Nice.
00:37:50.020 | In machine translation, I have a sentence like,
00:37:54.980 | "I like pizza."
00:37:59.500 | And I want to be able to translate it--
00:38:04.820 | "Je me pizza."
00:38:08.820 | Nice.
00:38:09.940 | And so when I'm looking at "I like pizza,"
00:38:14.980 | I get this as the input.
00:38:16.380 | And so I want self-attention without masking,
00:38:22.820 | because I want "I" to look at "like" and "I" to look at
00:38:26.660 | "pizza" and "like" to look at "pizza."
00:38:29.100 | And then when I'm generating this,
00:38:31.100 | if my tokens are like "Je m la pizza,"
00:38:35.180 | I want to, in encoding this word,
00:38:37.740 | I want to be able to look only at myself.
00:38:40.600 | And we'll talk about encoder-decoder architectures
00:38:43.100 | in this later in the lecture.
00:38:46.120 | But I want to be able to look at myself, none of the future,
00:38:48.620 | and all of this.
00:38:50.020 | And so what I'm talking about right now in this masking case
00:38:52.780 | is masking out with negative infinity all of these words.
00:38:59.380 | So that attention score from "Je" to everything else
00:39:02.540 | should be negative infinity.
00:39:05.820 | Yeah.
00:39:06.300 | Does that answer your question?
00:39:07.780 | Great.
00:39:09.380 | OK, let's move ahead.
00:39:11.380 | OK, so that was our last big building block
00:39:16.500 | issue with self-attention.
00:39:17.660 | So this is what I would call--
00:39:19.220 | and this is my personal opinion-- a minimal self-attention
00:39:22.220 | building block.
00:39:22.900 | You have self-attention, the basis of the method.
00:39:25.620 | So that's sort of here in the red.
00:39:29.260 | And maybe we had the inputs to the sequence here.
00:39:31.980 | And then you embed it with that embedding matrix E.
00:39:34.620 | And then you add position embeddings.
00:39:36.780 | And then these three arrows represent
00:39:38.580 | using the key, the value, and the query that's
00:39:42.980 | sort of stylized there.
00:39:44.140 | This is often how you see these diagrams.
00:39:47.300 | And so you pass it to self-attention
00:39:50.420 | with the position representation.
00:39:53.100 | So that specifies the sequence order,
00:39:54.980 | because otherwise you'd have no idea what order the words
00:39:57.460 | showed up in.
00:39:59.260 | You have the nonlinearities in sort of the feed-forward
00:40:01.940 | network there to sort of provide that sort of squashing
00:40:05.820 | and sort of deep learning expressivity.
00:40:08.900 | And then you have masking in order
00:40:10.700 | to have parallelizable operations that
00:40:13.500 | don't look at the future.
00:40:15.460 | So this is sort of our minimal architecture.
00:40:18.100 | And then up at the top above here,
00:40:20.180 | so you have this thing-- maybe you repeat this sort of
00:40:22.380 | self-attention and feedforward many times.
00:40:24.460 | So self-attention, feedforward, self-attention, feedforward,
00:40:27.380 | self-attention, feedforward.
00:40:28.940 | That's what I'm calling this block.
00:40:31.140 | And then maybe at the end of it, you predict something.
00:40:33.540 | I don't know.
00:40:33.860 | We haven't really talked about that.
00:40:35.340 | But you have these representations.
00:40:36.940 | And then you predict the next word,
00:40:38.500 | or you predict the sentiment, or you predict whatever.
00:40:40.740 | So this is like a self-attention architecture.
00:40:44.640 | OK, we're going to move on to the transformer next.
00:40:46.760 | So if there are any questions-- yeah?
00:40:48.260 | [INAUDIBLE]
00:40:52.180 | Other way around.
00:40:53.380 | We will use masking for decoders,
00:40:56.140 | where I want to decode out a sequence where I have
00:41:00.420 | an informational constraint, where
00:41:02.380 | to represent this word properly, I cannot have
00:41:05.140 | the information of the future.
00:41:06.460 | [INAUDIBLE]
00:41:08.580 | Yeah, OK.
00:41:09.080 | OK, great.
00:41:16.220 | So now let's talk about the transformer.
00:41:17.860 | So what I've pitched to you is what
00:41:20.660 | I call a minimal self-attention architecture.
00:41:24.740 | And I quite like pitching it that way.
00:41:28.980 | But really, no one uses the architecture
00:41:30.820 | that was just up on the slide, the previous slide.
00:41:34.100 | It doesn't work quite as well as it could.
00:41:36.060 | And there's a bunch of important details
00:41:38.660 | that we'll talk about now that goes into the transformer.
00:41:41.700 | What I would hope, though, to have you take away from that
00:41:46.220 | is that the transformer architecture,
00:41:48.020 | as I'll present it now, is not necessarily
00:41:51.420 | the end point of our search for better and better ways
00:41:54.580 | of representing language, even though it's now ubiquitous
00:41:57.940 | and has been for a couple of years.
00:42:00.060 | So think about these sort of ideas
00:42:01.500 | of the problems of using self-attention
00:42:05.020 | and maybe ways of fixing some of the issues with transformers.
00:42:08.940 | OK, so a transformer decoder is how we'll build systems
00:42:13.500 | like language models.
00:42:14.460 | And so we've discussed this.
00:42:15.740 | It's like our decoder with our self-attention-only sort
00:42:18.940 | of minimal architecture.
00:42:20.300 | It's got a couple of extra components,
00:42:21.940 | some of which I've grayed out here,
00:42:23.400 | that we'll go over one by one.
00:42:25.220 | The first that's actually different
00:42:28.900 | is that we'll replace our self-attention
00:42:31.660 | with masking with masked multi-head self-attention.
00:42:35.820 | This ends up being crucial.
00:42:36.940 | It's probably the most important distinction
00:42:39.820 | between the transformer and this sort of minimal architecture
00:42:42.360 | that I've presented.
00:42:43.820 | So let's come back to our toy example of attention,
00:42:46.740 | where we've been trying to represent the word learned
00:42:49.060 | in the context of the sequence, I went to Stanford CS224N
00:42:52.660 | and learned.
00:42:54.740 | And I was sort of giving these teal bars to say,
00:42:57.580 | oh, maybe intuitively you look at various things
00:43:01.020 | to build up your representation of learned.
00:43:04.340 | But really, there are varying ways
00:43:06.540 | in which I want to look back at the sequence
00:43:09.660 | to see varying sort of aspects of information
00:43:13.660 | that I want to incorporate into my representation.
00:43:16.300 | So maybe in this way, I sort of want
00:43:19.340 | to look at Stanford CS224N, because, oh, it's like entities.
00:43:25.900 | You learn different stuff at Stanford CS224N
00:43:28.300 | than you do at other courses or other universities or whatever.
00:43:31.940 | And so maybe I want to look here for this reason.
00:43:35.180 | And maybe in another sense, I actually
00:43:37.700 | want to look at the word learned.
00:43:39.380 | And I want to look at I. I went and learned.
00:43:43.380 | And I want to see maybe syntactically relevant words.
00:43:46.300 | It's very different reasons for which
00:43:47.940 | I might want to look at different things
00:43:49.640 | in the sequence.
00:43:50.900 | And so trying to average it all out
00:43:52.740 | with a single operation of self-attention
00:43:55.020 | ends up being maybe somewhat too difficult in a way
00:43:58.340 | that will make precise in assignment 5.
00:44:00.180 | Nice, we'll do a little bit more math.
00:44:03.940 | OK, so any questions about this intuition?
00:44:11.580 | [INAUDIBLE]
00:44:14.340 | Yeah, so it should be an application of attention
00:44:17.140 | just as I've presented it.
00:44:19.140 | So one independent define the keys, define the queries,
00:44:22.300 | define the values.
00:44:23.020 | I'll define it more precisely here.
00:44:24.940 | But think of it as I do attention once,
00:44:27.460 | and then I do it again with different parameters,
00:44:31.620 | being able to look at different things, et cetera.
00:44:33.740 | [INAUDIBLE]
00:44:36.220 | How do we ensure that they look at different things?
00:44:38.660 | We do not-- OK, so the question is,
00:44:40.060 | if we have two separate sets of weights trying to learn,
00:44:41.980 | say, to do this and to do that, how do we ensure that they
00:44:44.700 | learn different things?
00:44:45.940 | We do not ensure that they learn different things.
00:44:49.100 | And in practice, they do, although not perfectly.
00:44:52.780 | So it ends up being the case that you have some redundancy,
00:44:55.660 | and you can cut out some of these.
00:44:57.420 | But that's out of scope for this.
00:44:59.140 | But we hope, just like we hope that different dimensions
00:45:02.300 | in our feedforward layers will learn different things
00:45:04.500 | because of lack of symmetry and whatever,
00:45:06.740 | that we hope that the heads will start to specialize.
00:45:09.380 | And that will mean they'll specialize even more.
00:45:11.460 | And yeah.
00:45:16.340 | All right, so in order to discuss multi-head self
00:45:18.620 | attention well, we really need to talk about the matrices,
00:45:22.100 | how we're going to implement this in GPUs efficiently.
00:45:25.220 | We're going to talk about the sequence-stacked form
00:45:27.780 | of attention.
00:45:29.260 | So we've been talking about each word sort of individually
00:45:31.660 | as a vector in dimensionality D. But really, we're
00:45:35.140 | going to be working on these as big matrices that are stacked.
00:45:38.900 | So I take all of my word embeddings, x1 to xn,
00:45:42.340 | and I stack them together.
00:45:43.900 | And now I have a big matrix that is in dimensionality Rn by D.
00:45:49.860 | OK, and now with my matrices K, Q, and V,
00:45:55.180 | I can just multiply them on this side of x.
00:45:58.220 | So x is Rn by D. K is Rd by D. So n by D times d by D
00:46:04.300 | gives you n by D again.
00:46:07.060 | So I can just compute a big matrix multiply
00:46:10.460 | on my whole sequence to multiply each one of the words
00:46:13.340 | of my key query and value matrices very efficiently.
00:46:17.100 | So this is sort of this vectorization idea.
00:46:18.860 | I don't want to for loop over the sequence.
00:46:20.780 | I represent the sequence as a big matrix,
00:46:23.540 | and I just do one big matrix multiply.
00:46:27.580 | Then the output is defined as this sort
00:46:29.440 | of inscrutable bit of math, which
00:46:31.660 | I'm going to go over visually.
00:46:35.260 | So first, we're going to take the key query dot
00:46:37.820 | products in one matrix.
00:46:39.460 | So we've got xq, which is Rn by d.
00:46:47.100 | And I've got xk transpose, which is Rd by n.
00:46:50.660 | So n by d, d by n.
00:46:53.140 | This is computing all of the eij's,
00:46:55.400 | these scores for self-attention.
00:46:58.100 | So this is all pairs of attention scores
00:47:00.620 | computed in one big matrix multiply.
00:47:04.580 | So this is this big matrix here.
00:47:06.180 | Next, I use the softmax.
00:47:09.620 | So I softmax this over the second dimension,
00:47:13.860 | the second n dimension.
00:47:15.980 | And I get my sort of normalized scores,
00:47:19.060 | and then I multiply with xv.
00:47:20.700 | So this is an n by n matrix multiplied by an n by D matrix.
00:47:26.180 | And what do I get?
00:47:26.900 | Well, this is just doing the weighted average.
00:47:29.540 | So this is one big weighted average.
00:47:32.340 | Big weighted average contribution
00:47:34.220 | on the whole matrix, giving me my whole self-attention output
00:47:37.380 | in Rn by D. So I've just restated identically
00:47:41.620 | the self-attention operations, but computed
00:47:43.980 | in terms of matrices so that you could do this efficiently
00:47:47.220 | on a GPU.
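To make the matrix form concrete, here is a minimal NumPy sketch of the computation just described, softmax((XQ)(XK)^T)(XV); the variable names are mine, not from the lecture slides.

    import numpy as np

    def softmax(scores, axis=-1):
        # subtract the max for numerical stability before exponentiating
        scores = scores - scores.max(axis=axis, keepdims=True)
        exp = np.exp(scores)
        return exp / exp.sum(axis=axis, keepdims=True)

    def self_attention(X, Q, K, V):
        # X: (n, d) stacked word embeddings; Q, K, V: (d, d) weight matrices.
        scores = (X @ Q) @ (X @ K).T      # (n, n): all pairs of attention scores
        alpha = softmax(scores, axis=-1)  # normalize over the second n dimension
        return alpha @ (X @ V)            # (n, d): weighted averages of the value vectors

    # toy usage
    n, d = 5, 8
    rng = np.random.default_rng(0)
    X = rng.normal(size=(n, d))
    Q, K, V = [rng.normal(size=(d, d)) for _ in range(3)]
    out = self_attention(X, Q, K, V)      # shape (n, d)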
00:47:52.060 | So multi-headed attention.
00:47:54.180 | This is going to give us--
00:47:55.500 | and it's going to be important to compute this
00:47:57.460 | in terms of the matrices, which we'll see.
00:47:59.580 | This is going to give us the ability
00:48:01.000 | to look in multiple places at once for different reasons.
00:48:04.220 | So self-attention looks where this dot product here
00:48:09.060 | is high--
00:48:10.580 | the x_i with the query matrix against the x_j with the key matrix.
00:48:15.380 | But maybe we want to look in different places
00:48:18.020 | for different reasons.
00:48:19.300 | So we actually define multiple query, key, and value matrices.
00:48:24.540 | So I'm going to have a bunch of heads.
00:48:26.740 | I'm going to have h self-attention heads.
00:48:30.260 | And for each head, I'm going to define an independent query,
00:48:33.060 | key, and value matrix.
00:48:34.860 | And I'm going to say that its shape is
00:48:37.220 | going to map from the model dimensionality
00:48:39.340 | to the model dimensionality over h.
00:48:41.300 | So each one of these is doing projection down
00:48:43.260 | to a lower dimensional space.
00:48:45.340 | This is going to be for computational efficiency.
00:48:47.700 | And I'll just apply self-attention
00:48:51.260 | independently for each output.
00:48:53.300 | So this equation here is identical to the one
00:48:56.100 | we saw for single-headed self-attention,
00:48:58.540 | except I've got these sort of l indices everywhere.
00:49:02.860 | So I've got this lower dimensional thing.
00:49:04.540 | I'm mapping to a lower dimensional space.
00:49:06.540 | And then I do have my lower dimensional value vector
00:49:09.100 | there.
00:49:09.700 | So my output is in R d over h.
00:49:11.900 | But really, you're doing exactly the same kind of operation.
00:49:14.860 | I'm just doing it h different times.
00:49:17.620 | And then you combine the outputs.
00:49:19.700 | So I've done sort of look in different places
00:49:22.140 | with the different key, query, and value matrices.
00:49:24.900 | And then I get each of their outputs.
00:49:28.340 | And then I concatenate them together.
00:49:31.020 | So each one is dimensionality d over h.
00:49:33.540 | And I concatenate them together and then sort of mix them
00:49:36.140 | together with the final linear transformation.
00:49:39.820 | And so each head gets to look at different things
00:49:43.040 | and construct their value vectors differently.
00:49:45.420 | And then I sort of combine the result all together at once.
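Written out (in my own notation, following the description above), the per-head computation is roughly:

    \mathrm{output}_\ell = \mathrm{softmax}\big( X Q_\ell (X K_\ell)^\top \big)\, X V_\ell \in \mathbb{R}^{n \times d/h}
    \mathrm{output} = \big[ \mathrm{output}_1 ; \dots ; \mathrm{output}_h \big]\, Y

where each Q_l, K_l, V_l is a d by d/h matrix and Y is the final d by d linear transformation that mixes the concatenated heads.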
00:49:49.500 | Let's go through this visually, because it's
00:49:51.660 | at least helpful for me.
00:49:55.820 | It's actually not more costly to do this, really,
00:49:58.540 | than it is to compute a single head of self-attention.
00:50:01.060 | And we'll see through the pictures.
00:50:02.520 | So in single-headed self-attention,
00:50:07.940 | we computed xq.
00:50:09.460 | And in multi-headed self-attention,
00:50:11.100 | we'll also compute xq the same way.
00:50:13.860 | So xq is rn by d.
00:50:16.260 | And then we can reshape it into rn, that's sequence length,
00:50:24.500 | times the number of heads, times the model dimensionality
00:50:29.420 | over the number of heads.
00:50:30.500 | So I've just reshaped it to say, now I've
00:50:32.580 | got a big three-axis tensor.
00:50:35.540 | The first axis is the sequence length.
00:50:37.820 | The second one is the number of heads.
00:50:39.420 | The third is this reduced model dimensionality.
00:50:42.020 | And that costs nothing.
00:50:43.780 | And do the same thing for xk and xv.
00:50:45.940 | And then I transpose so that I've got the head
00:50:48.460 | axis as the first axis.
00:50:50.740 | And now I can compute all my other operations
00:50:53.460 | with the head axis, kind of like a batch.
00:50:58.020 | So what does this look like in practice?
00:51:01.780 | Instead of having one big xq matrix that's
00:51:05.100 | model dimensionality d, I've got, in this case,
00:51:08.180 | three xq matrices of model dimensionality d over 3, d over 3,
00:51:12.780 | d over 3.
00:51:13.860 | Same thing with the key matrix here.
00:51:16.500 | So everything looks almost identical.
00:51:18.540 | It's just the reshaping of the tensors.
00:51:21.000 | And now, at the output of this, I've
00:51:23.340 | got three sets of attention scores
00:51:26.860 | just by doing this reshape.
00:51:29.060 | And the cost is that, well, each of my attention heads
00:51:33.420 | has only a d over h dimensional vector to work with instead
00:51:35.900 | of a d-dimensional vector to work with.
00:51:38.020 | So I get the output.
00:51:38.860 | I get these three sets of pairs of scores.
00:51:43.340 | I compute the softmax independently
00:51:45.660 | for each of the three.
00:51:47.100 | And then I have three value matrices there as well,
00:51:50.680 | each of them lower dimensional.
00:51:52.660 | And then finally, I get my three different output vectors.
00:51:56.700 | And I have a final linear transformation
00:51:58.460 | to mush them together.
00:52:01.020 | And I get an output.
00:52:02.660 | And in summary, what this allows you to do
00:52:04.620 | is exactly what I gave in the toy example, which
00:52:08.220 | was I can have each of these heads
00:52:10.240 | look at different parts of a sequence for different reasons.
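As a sketch of the reshaping trick described above (my own code, assuming d is divisible by h; the projection and output matrices stand in for learned parameters):

    import numpy as np

    def multi_head_self_attention(X, Q, K, V, Y, h):
        # X: (n, d) embeddings; Q, K, V: (d, d) projections; Y: (d, d) output mix; h: number of heads.
        n, d = X.shape
        dh = d // h
        # Project once, split the feature dimension into h heads of size d/h,
        # and move the head axis to the front so it acts like a batch axis.
        xq = (X @ Q).reshape(n, h, dh).transpose(1, 0, 2)   # (h, n, d/h)
        xk = (X @ K).reshape(n, h, dh).transpose(1, 0, 2)   # (h, n, d/h)
        xv = (X @ V).reshape(n, h, dh).transpose(1, 0, 2)   # (h, n, d/h)

        scores = xq @ xk.transpose(0, 2, 1)                 # (h, n, n) pairwise scores per head
        scores = scores - scores.max(axis=-1, keepdims=True)
        alpha = np.exp(scores)
        alpha = alpha / alpha.sum(axis=-1, keepdims=True)   # softmax per head

        out = alpha @ xv                                    # (h, n, d/h) per-head weighted averages
        out = out.transpose(1, 0, 2).reshape(n, d)          # concatenate the heads back to (n, d)
        return out @ Y                                      # final linear mix of the heads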
00:52:12.820 | So this is at a given block, right?
00:52:21.020 | All of these attention heads are for a given transformer block.
00:52:23.780 | A next block could also have three attention heads.
00:52:26.980 | The question is, are all of these for a given block?
00:52:30.340 | And we'll talk about a block again.
00:52:31.800 | But this block was this sort of pair of self-attention
00:52:35.060 | and feed-forward network.
00:52:36.300 | So you do self-attention, feed-forward.
00:52:37.920 | That's one block.
00:52:38.780 | Another block is another self-attention,
00:52:40.120 | another feed-forward.
00:52:41.340 | And the question is, are the parameters shared
00:52:43.260 | between the blocks or not?
00:52:44.900 | Generally, they are not shared.
00:52:46.180 | You'll have independent parameters at every block,
00:52:48.660 | although there are some exceptions.
00:52:52.700 | Following up on that, is it typically the case
00:52:55.380 | that you have the same number of heads at each block?
00:52:58.820 | Or do you vary the number of heads across blocks?
00:53:01.380 | You have this-- you definitely could vary it.
00:53:04.020 | People haven't found reason to vary--
00:53:05.540 | so the question is, do you have different numbers of heads
00:53:07.960 | across the different blocks?
00:53:09.340 | Or do you have the same number of heads across all blocks?
00:53:12.780 | The simplest thing is to just have
00:53:14.540 | it be the same everywhere, which is what people have done.
00:53:16.940 | I haven't yet found a good reason to vary it,
00:53:19.100 | but it could be interesting.
00:53:21.860 | It's definitely the case that after training these networks,
00:53:25.540 | you can actually just totally zero out,
00:53:27.900 | remove some of the attention heads.
00:53:30.900 | And I'd be curious to know if you could remove more or less,
00:53:35.700 | depending on the layer index, which might then say,
00:53:39.380 | oh, we should just have fewer.
00:53:40.580 | But again, it's not actually more expensive
00:53:42.420 | to have a bunch.
00:53:43.700 | So people tend to instead set the number of heads
00:53:46.740 | to be roughly so that you have a reasonable number of dimensions
00:53:51.180 | per head, given the total model dimensionality d that you want.
00:53:55.460 | So for example, I might want at least 64 dimensions per head,
00:54:00.260 | which if d is 128, that tells me how many heads
00:54:04.340 | I'm going to have, roughly.
00:54:06.020 | So people tend to scale the number of heads
00:54:07.860 | up with the model dimensionality.
00:54:09.620 | Yeah, with that xq, by slicing it into different columns,
00:54:15.820 | you're reducing the rank of the final matrix, right?
00:54:19.020 | Yeah.
00:54:19.820 | But that doesn't really have any effect on the results.
00:54:23.180 | So the question is, by having these reduced xq and xk
00:54:29.300 | matrices, this is a very low rank approximation.
00:54:32.940 | This little sliver and this little sliver
00:54:35.340 | defining this whole big matrix, it's very low rank.
00:54:38.020 | Is that not bad?
00:54:39.780 | In practice, no.
00:54:40.940 | I mean, again, it's the reason why
00:54:42.700 | we limit the number of heads depending on the model
00:54:45.820 | dimensionality, because you want intuitively at least
00:54:49.620 | some number of dimensions.
00:54:51.220 | So 64 is sometimes done, 128, something like that.
00:54:55.820 | But if you're not giving each head too much to do,
00:54:58.300 | and it's got sort of a simple job, you've got a lot of heads,
00:55:01.020 | it ends up sort of being OK.
00:55:04.260 | All we really know is that empirically,
00:55:05.980 | it's way better to have more heads than one.
00:55:14.140 | I'm wondering, have there been studies
00:55:16.300 | to see if information in one of the sets of the attention
00:55:22.140 | scores, like information that one of them
00:55:25.100 | learns is consistent and related to each other,
00:55:29.540 | or how are they related?
00:55:32.380 | So the question is, have there been studies to see
00:55:34.580 | if there's consistent information encoded
00:55:37.140 | by the attention heads?
00:55:38.780 | And yes.
00:55:40.740 | Actually, there's been quite a lot of study
00:55:42.580 | and interpretability and analysis of these models
00:55:44.980 | to try to figure out what roles, what sort of mechanistic roles
00:55:48.420 | each of these heads takes on.
00:55:50.180 | And there's quite a bit of exciting results
00:55:52.740 | there around some attention heads
00:55:55.140 | learning to pick out the syntactic dependencies,
00:55:59.820 | or maybe doing a global averaging of context.
00:56:03.780 | The question is quite nuanced, though,
00:56:05.420 | because in a deep network, it's unclear--
00:56:07.780 | and we should talk about this more offline--
00:56:09.620 | it's unclear if you look at a word 10 layers deep
00:56:12.580 | in a network what you're really looking at,
00:56:14.900 | because it's already incorporated context
00:56:17.060 | from everyone else, and it's a little bit unclear.
00:56:20.140 | Active area of research.
00:56:21.300 | But I think I should move on now to keep
00:56:25.100 | discussing transformers.
00:56:26.780 | But yeah, if you want to talk more about it, I'm happy to.
00:56:31.420 | So another sort of hack that I'm going to toss in here--
00:56:34.580 | I mean, maybe they wouldn't call it hack,
00:56:36.300 | but it's a nice little method to improve things.
00:56:39.620 | It's called scaled dot product attention.
00:56:42.060 | So one of the issues with this sort of key query value
00:56:45.500 | self-attention is that when the model dimensionality becomes
00:56:47.940 | large, the dot products between vectors, even random vectors,
00:56:51.700 | tend to become large.
00:56:55.060 | And when that happens, the inputs to the softmax function
00:56:58.060 | can be very large, making the gradient small.
00:57:01.380 | So intuitively, if you have two random vectors
00:57:03.300 | and model dimensionality d, and you just dot product them
00:57:06.540 | together, as d grows, their dot product
00:57:09.140 | grows in expectation to be very large.
00:57:11.540 | And so you sort of want to start out
00:57:13.940 | with everyone's attention being very uniform, very flat,
00:57:16.900 | sort of look everywhere.
00:57:18.780 | But if some dot products are very large,
00:57:20.620 | then learning will be inhibited.
00:57:23.660 | And so what you end up doing is you just sort of--
00:57:26.300 | for each of your heads, you just sort of divide all the scores
00:57:29.700 | by this constant that's determined
00:57:31.380 | by the model dimensionality.
00:57:33.060 | So as the vectors grow very large,
00:57:35.660 | their dot products don't, at least at initialization time.
00:57:40.500 | So this is sort of a nice little detail--
00:57:45.020 | maybe not fundamental--
00:57:46.020 | but it's important to know.
00:57:52.260 | And so that's called scaled dot product attention.
00:57:55.500 | From here on out, we'll just assume that we do this.
00:57:58.340 | It's quite easy to implement.
00:57:59.540 | You just do a little division in all of your computations.
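In symbols, with the per-head notation from before, the only change is the division inside the softmax (again, notation mine):

    \mathrm{output}_\ell = \mathrm{softmax}\!\left( \frac{X Q_\ell (X K_\ell)^\top}{\sqrt{d/h}} \right) X V_\ell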
00:58:05.060 | OK, so now in the transformer decoder,
00:58:07.260 | we've got a couple of other things
00:58:08.780 | that I have unfaded out here.
00:58:12.660 | We have two big optimization tricks, or optimization
00:58:15.220 | methods, I should say, really, because these
00:58:17.060 | are quite important, that end up being very important.
00:58:20.100 | We've got residual connections and layer normalization.
00:58:22.940 | And in transformer diagrams that you see sort of around the web,
00:58:26.980 | they're often written together as this add and norm box.
00:58:32.380 | And in practice, in the transformer decoder,
00:58:34.540 | I'm going to apply mask multi-head attention
00:58:38.620 | and then do this sort of optimization add a norm.
00:58:41.340 | Then I'll do a feed forward application
00:58:43.340 | and then add a norm.
00:58:44.660 | So this is quite important.
00:58:47.660 | So let's go over these two individual components.
00:58:51.820 | The first is residual connections.
00:58:53.300 | I mean, I think we've talked about residual connections
00:58:55.460 | before, right?
00:58:56.100 | So it's worth doing it again.
00:58:58.140 | But it's really a good trick to help models train better.
00:59:01.820 | So just to recap, we're going to take--
00:59:04.660 | instead of having this sort of-- you have a layer, layer i
00:59:08.220 | minus 1, and you pass it through a thing.
00:59:10.540 | Maybe it's self-attention.
00:59:11.620 | Maybe it's a feed forward network.
00:59:13.060 | Now you've got layer i.
00:59:16.380 | I'm going to add the result of layer i to its input here.
00:59:23.060 | So now I'm saying I'm just going to compute the layer,
00:59:25.320 | and I'm going to add in the input to the layer
00:59:27.740 | so that I only have to learn the residual
00:59:30.780 | from the previous layer.
00:59:32.020 | So I've got this sort of connection here.
00:59:33.720 | It's often written as this.
00:59:34.860 | It's sort of like, boop, connection.
00:59:38.940 | It goes around.
00:59:39.860 | And you should think that the gradient is just
00:59:41.740 | really great through the residual connection.
00:59:43.820 | Like, ah, if I've got vanishing or exploding gradient--
00:59:47.700 | vanishing gradients through this layer,
00:59:49.500 | well, I can at least learn everything behind it
00:59:51.740 | because I've got this residual connection where
00:59:54.160 | the gradient is 1 because it's the identity.
00:59:58.060 | This is really nice.
00:59:59.180 | And it also maybe is like a--
01:00:01.460 | at least at initialization, everything
01:00:03.980 | looks a little bit like the identity function now, right?
01:00:06.660 | Because if the contribution of the layer
01:00:08.920 | is somewhat small because all of your weights are small,
01:00:11.680 | and I have the addition from the input,
01:00:13.980 | maybe the whole thing looks a little bit
01:00:15.620 | like the identity, which might be a good sort of place
01:00:18.380 | to start.
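A minimal sketch of the residual connection idea, where `layer` stands in for any sub-layer (self-attention or the feed-forward network):

    # We only have to learn the residual on top of the identity, and the
    # gradient flows through the identity path with weight 1.
    def with_residual(layer, x):
        return x + layer(x)

    # toy usage with a stand-in "layer" that just scales its input
    y = with_residual(lambda v: 0.1 * v, 2.0)   # 2.0 + 0.2 = 2.2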
01:00:20.340 | And there are really nice visualizations.
01:00:22.100 | I just love this visualization.
01:00:24.800 | So this is your loss landscape.
01:00:26.420 | So you're doing gradient descent, and you're
01:00:28.000 | trying to traverse the mountains of the loss landscape.
01:00:30.860 | This is like the parameter space.
01:00:32.580 | And down is better in your loss function.
01:00:34.620 | And it's really hard.
01:00:35.500 | So you get stuck in some local optima,
01:00:38.060 | and you can't sort of find your way to get out.
01:00:41.180 | And then this is with residual connections.
01:00:43.140 | I mean, come on.
01:00:44.300 | You just sort of walk down.
01:00:47.060 | I mean, that's not actually, I guess,
01:00:48.860 | really how it works all the time.
01:00:50.380 | But I really love this.
01:00:52.140 | It's great.
01:00:58.540 | So yeah, we've seen residual connections.
01:01:00.260 | We should move on to layer normalization.
01:01:02.860 | So layer norm is another thing to help your model train
01:01:06.500 | faster.
01:01:08.380 | And the intuitions around layer normalization
01:01:14.760 | and sort of the empiricism of it working very well
01:01:17.040 | maybe aren't perfectly, let's say, connected.
01:01:21.020 | But you should imagine, I suppose,
01:01:25.700 | that there's a lot of variation within each layer.
01:01:29.860 | Things can get very big.
01:01:31.180 | Things can get very small.
01:01:33.140 | And that variation isn't actually informative; it comes from,
01:01:36.700 | say, variation in the gradients.
01:01:39.940 | Or I've got sort of weird things going on in my layers
01:01:43.860 | that I can't totally control.
01:01:45.140 | I haven't been able to sort of make everything behave sort
01:01:47.740 | of nicely where everything stays roughly the same norm.
01:01:50.460 | Maybe some things explode.
01:01:51.660 | Maybe some things shrink.
01:01:54.660 | And I want to cut down on sort of uninformative variation
01:01:59.580 | between layers.
01:02:00.940 | So I'm going to let x in Rd be an individual word
01:02:03.740 | vector in the model.
01:02:05.380 | So this is like I have a single index, one vector.
01:02:09.100 | And what I'm going to try to do is just normalize it.
01:02:12.660 | Normalize it in the sense that it's got a bunch of variation,
01:02:15.540 | and I'm going to cut down on that.
01:02:17.700 | I'm going to normalize it to zero mean and unit standard
01:02:20.340 | deviation.
01:02:21.020 | So I'm going to estimate the mean here across--
01:02:26.700 | so for all of the dimensions in the vector,
01:02:30.180 | so j equals 1 to the model dimensionality,
01:02:32.660 | I'm going to sum up the value.
01:02:33.900 | So I've got this one big word vector.
01:02:35.860 | And I sum up all the values.
01:02:37.540 | Division by d here, that's the mean.
01:02:40.220 | I'm going to have my estimate of the standard deviation.
01:02:43.540 | Again, these should say estimates.
01:02:45.020 | This is my simple estimate of the standard deviation
01:02:47.300 | of the values within this one vector.
01:02:50.500 | And I'm just going to--
01:02:53.700 | and then possibly, I guess I can have learned parameters
01:02:58.060 | to try to scale things back out multiplicatively
01:03:02.780 | and shift them additively here.
01:03:04.500 | That's optional.
01:03:05.540 | We're going to compute this standardization.
01:03:08.380 | I'm going to take my vector x, subtract out the mean,
01:03:11.340 | divide by the standard deviation,
01:03:12.780 | plus this epsilon constant.
01:03:14.820 | If there's not a lot of variation,
01:03:16.300 | I don't want things to explode.
01:03:17.900 | So I'm going to have this epsilon there that's close to 0.
01:03:21.700 | So this part here, x minus mu over square root sigma
01:03:25.500 | plus epsilon, is saying take all the variation
01:03:28.540 | and normalize it to zero mean and unit standard deviation.
01:03:32.600 | And then maybe I want to scale it, stretch it back out,
01:03:37.080 | and then maybe add an offset beta that I've learned.
01:03:40.860 | Although in practice, actually, this part-- and discuss this
01:03:43.300 | in the lecture notes--
01:03:44.580 | in practice, this part maybe isn't actually that important.
01:03:47.940 | But so layer normalization, yeah, you're sort of--
01:03:51.220 | you can think of this as when I get the output of layer
01:03:54.000 | normalization, it's going to be--
01:03:55.940 | sort of look nice and look similar to the next layer
01:03:58.940 | independent of what's gone on, because it's
01:04:00.940 | going to have zero mean and unit standard deviation.
01:04:02.780 | So maybe that makes for a better thing
01:04:04.660 | to learn off of for the next layer.
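A minimal NumPy sketch of layer normalization as described, applied to a single word vector (statistics are per vector, not shared across positions or across the batch; gamma and beta are the optional learned scale and offset, and the exact placement of epsilon varies slightly between write-ups):

    import numpy as np

    def layer_norm(x, gamma=None, beta=None, eps=1e-5):
        # x: a single word vector of shape (d,)
        mu = x.mean()                            # mean over the d dimensions
        var = x.var()                            # variance estimate over the d dimensions
        x_hat = (x - mu) / np.sqrt(var + eps)    # normalize to zero mean, unit std
        if gamma is not None:
            x_hat = gamma * x_hat                # learned elementwise stretch
        if beta is not None:
            x_hat = x_hat + beta                 # learned elementwise offset
        return x_hat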
01:04:06.260 | OK, any questions for residual or layer norm?
01:04:13.220 | What would it mean to subtract the scalar mu from the vector x?
01:04:17.340 | Yeah, it's a good question.
01:04:18.780 | When I subtract the scalar mu from the vector x,
01:04:21.580 | I broadcast mu to dimensionality d and subtract it from all d components.
01:04:27.980 | Yeah, good point.
01:04:29.300 | Thank you.
01:04:29.900 | That was unclear.
01:04:30.580 | Yeah.
01:04:31.080 | In the fourth bullet, maybe I'm confused.
01:04:37.420 | Is it divided?
01:04:38.300 | Should it be divided by d or from mean?
01:04:42.500 | Sorry, can you repeat that?
01:04:43.620 | In the fourth bullet point when you're calculating the mean,
01:04:47.220 | is it divided by d or is it--
01:04:49.660 | or maybe I'm just confused.
01:04:51.180 | I think it is divided by d.
01:04:52.340 | Yeah.
01:04:55.180 | These are-- so this is the average deviation
01:04:57.300 | from the mean of all of the-- yeah.
01:05:00.760 | So if you have five words in a sentence, for layer norm,
01:05:04.700 | do you normalize based on the statistics of these five words
01:05:09.460 | or do you do it one word at a time?
01:05:11.700 | So the question is, if I have five words in the sequence,
01:05:14.500 | do I normalize by aggregating the statistics to estimate mu
01:05:18.820 | and sigma across all the five words,
01:05:21.060 | share their statistics, or do it independently for each word?
01:05:24.140 | This is a great question, which I
01:05:25.700 | think in all the papers that discuss transformers
01:05:28.260 | is under specified.
01:05:30.140 | You do not share across the five words, which
01:05:33.060 | is somewhat confusing to me.
01:05:35.500 | So each of the five words is done completely independently.
01:05:39.180 | You could have shared across the five words
01:05:41.380 | and said that my estimate of the statistics
01:05:43.300 | are just based on all five, but you do not.
01:05:49.740 | I can't pretend I understand totally why.
01:05:51.380 | [INAUDIBLE]
01:05:51.880 | [INAUDIBLE]
01:05:54.360 | For example, per batch or per output of the same position?
01:06:01.400 | So similar question.
01:06:02.840 | The question is, if you have a batch of sequences,
01:06:06.760 | so just like we were doing batch-based training,
01:06:10.040 | do you for a single word--
01:06:11.840 | now, we don't share across the sequence index
01:06:13.760 | for sharing the statistics, but do you share across the batch?
01:06:16.680 | And the answer is no.
01:06:17.680 | You also do not share across the batch.
01:06:19.480 | In fact, layer normalization was sort of
01:06:22.160 | invented as a replacement for batch normalization, which
01:06:25.360 | did just that.
01:06:26.320 | And the issue with batch normalization
01:06:27.900 | is that now your forward pass sort of depends in a way
01:06:30.960 | that you don't like on examples that should be not
01:06:34.080 | related to your example.
01:06:35.400 | And so, yeah, you don't share statistics across the batch.
01:06:41.640 | Cool.
01:06:44.240 | OK, so now we have our full transformer decoder,
01:06:48.600 | and we have our blocks.
01:06:50.520 | So in this sort of slightly grayed out thing here
01:06:52.960 | that says repeat for a number of decoder blocks,
01:06:58.640 | each block consists of--
01:07:00.520 | I pass it through self-attention,
01:07:02.400 | and then my add and norm.
01:07:04.720 | So I've got this residual connection here
01:07:06.480 | that goes around, add.
01:07:08.400 | I've got the layer normalization there, and then
01:07:10.780 | a feed-forward layer, and then another add and norm.
01:07:15.240 | And so that sort of set of four operations,
01:07:18.040 | I apply some number of times--once per block.
01:07:21.720 | That whole set of four operations is called a single block.
01:07:24.000 | And that's it.
01:07:24.960 | That's the transformer decoder as it is.
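As a rough sketch of one decoder block, with `self_attn`, `ff`, and `layer_norm` standing in for the components discussed above (applied row-wise to the (n, d) matrix of word representations):

    def decoder_block(x, self_attn, ff, layer_norm):
        x = layer_norm(x + self_attn(x))   # masked multi-head attention, then add & norm
        x = layer_norm(x + ff(x))          # position-wise feed-forward, then add & norm
        return x
    # The full decoder applies this block some number of times,
    # each block with its own (unshared) parameters.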
01:07:29.040 | Cool, so that's a whole architecture right there.
01:07:34.040 | We've solved things like needing to represent position.
01:07:36.800 | We've solved things like not being
01:07:39.960 | able to look into the future.
01:07:41.760 | We've solved a lot of different optimization problems.
01:07:44.120 | You had a question?
01:07:45.180 | [INAUDIBLE]
01:07:49.680 | Yes, masked multi-head attention, yeah.
01:07:52.800 | With the dot product scaling with the square root
01:07:55.880 | d over h as well, yeah.
01:07:57.120 | So the question is, how do these models
01:08:05.880 | handle variable length inputs?
01:08:08.560 | Yeah, so if you have--
01:08:13.680 | so the input to the GPU forward pass
01:08:18.180 | is going to be a constant length.
01:08:20.820 | So you're going to maybe pad to a constant length.
01:08:24.740 | And in order to not look at the future, the stuff that's
01:08:28.580 | happening in the future, you can mask out the pad tokens,
01:08:32.720 | just like the masking that we showed for not looking
01:08:35.480 | at the future in general.
01:08:36.520 | You can just say, set all of the attention weights to 0,
01:08:40.380 | or the scores to negative infinity
01:08:42.300 | for all of the pad tokens.
01:08:43.520 | [INAUDIBLE]
01:08:47.660 | Yeah, exactly.
01:08:48.280 | So you can set everything to this maximum length.
01:08:52.200 | Now, in practice-- so the question was,
01:08:53.820 | do you set this length so that everything
01:08:55.740 | is at that maximum length?
01:08:56.980 | I mean, yes, often, although you can save computation
01:09:00.780 | by setting it to something smaller.
01:09:03.420 | And everything-- the math all still works out.
01:09:06.140 | You just have to code it properly so it can handle--
01:09:08.700 | you set everything instead of to n.
01:09:10.220 | You set it all to 5 if everything
01:09:12.260 | is shorter than length 5, and you save a lot of computation.
01:09:15.260 | All of the self-attention operations just work.
01:09:19.340 | So yeah.
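A minimal sketch of what that masking can look like on the (n, n) score matrix, assuming `is_pad` is a boolean vector marking pad positions (my own code, not from the lecture):

    import numpy as np

    def mask_scores(scores, is_pad, causal=True):
        # scores: (n, n) attention scores; is_pad: boolean (n,) marking pad tokens.
        n = scores.shape[0]
        masked = scores.copy()
        masked[:, is_pad] = -np.inf                        # never attend to pad tokens
        if causal:
            future = np.triu(np.ones((n, n), dtype=bool), k=1)
            masked[future] = -np.inf                       # never attend to future positions
        # Rows at pad positions end up all -inf; their outputs are ignored downstream.
        return masked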
01:09:21.980 | How many layers are in the feedforward normally?
01:09:25.340 | There's one hidden layer in the feedforward usually.
01:09:27.500 | Oh, just one?
01:09:28.060 | Yeah.
01:09:28.980 | OK, I should move on.
01:09:30.060 | We've got a couple more things and not very much time.
01:09:33.820 | But I'll be here after the class as well.
01:09:35.620 | So in the encoder-- so the transformer encoder
01:09:38.420 | is almost identical.
01:09:39.540 | But again, we want bidirectional context.
01:09:41.900 | And so we just don't do the masking.
01:09:44.500 | So I've got in my multi-head attention here,
01:09:46.740 | I've got no masking.
01:09:48.620 | And so it's that easy to make the model bidirectional.
01:09:53.380 | So that's easy.
01:09:54.100 | So that's called the transformer encoder.
01:09:55.800 | It's almost identical but no masking.
01:09:58.100 | And then finally, we've got the transformer encoder decoder,
01:10:01.900 | which is actually how the transformer was originally
01:10:04.060 | presented in this paper, "Attention is All You Need."
01:10:07.900 | And this is when we want to have a bidirectional network.
01:10:10.700 | Here's the encoder.
01:10:11.500 | It takes in, say, my source sentence
01:10:13.420 | for machine translation.
01:10:15.060 | Its multi-headed attention is not masked.
01:10:17.580 | And I have a decoder to decode out my sentence.
01:10:22.140 | Now, but you'll see that this is slightly more complicated.
01:10:24.740 | I have my masked multi-head self-attention,
01:10:27.500 | just like I had before in my decoder.
01:10:29.940 | But now I have an extra operation,
01:10:32.980 | which is called cross-attention, where
01:10:35.220 | I am going to use my decoder vectors as my queries.
01:10:41.460 | Then I'll take the output of the encoder as my keys and values.
01:10:45.860 | So now for every word in the decoder,
01:10:48.100 | I'm looking at all the possible words
01:10:50.580 | in the output of the last block of the encoder.
01:10:54.460 | [INAUDIBLE]
01:10:54.940 | How do we get a key and value separated from the output?
01:11:04.340 | Because didn't we collapse those into the single output?
01:11:07.900 | So we-- well, how-- sorry.
01:11:10.700 | How will we get the keys and values out?
01:11:12.780 | Like, how do we-- because when we have the output,
01:11:15.020 | didn't we collapse the keys and values into a single output?
01:11:19.220 | So the output--
01:11:20.100 | [INAUDIBLE]
01:11:21.020 | Yeah, the question is, how do you
01:11:22.440 | get the keys and values and queries out
01:11:24.220 | of this single collapsed output?
01:11:25.700 | Now, remember, the output for each word
01:11:27.900 | is just this weighted average of the value vectors
01:11:30.620 | for the previous words.
01:11:33.020 | And then from that output for the next layer,
01:11:35.780 | we apply a new key, query, and value transformation
01:11:38.700 | to each of them for the next layer of self-attention.
01:11:42.420 | So it's not actually that you're--
01:11:43.780 | [INAUDIBLE]
01:11:44.280 | Yeah, you apply the key matrix, the query matrix,
01:11:50.780 | to the output of whatever came before it.
01:11:52.580 | Yeah.
01:11:53.820 | And so just in a little bit of math,
01:11:56.260 | we have these vectors, h1 through hn,
01:11:59.820 | I'm going to call them the output of the encoder.
01:12:03.420 | And then I've got vectors that are the output of the decoder.
01:12:08.020 | So I've got these z's I'm calling
01:12:09.500 | the output of the decoder.
01:12:10.860 | And then I simply define my keys and my values
01:12:15.340 | from the encoder vectors, these h's.
01:12:19.460 | So I take the h's, I apply a key matrix and a value matrix,
01:12:22.860 | and then I define the queries from my decoder.
01:12:26.540 | So my queries here-- so this is why two of the arrows
01:12:29.060 | come from the encoder, and one of the arrows
01:12:31.300 | comes from the decoder.
01:12:32.540 | I've got my z's here, my queries, my keys and values
01:12:36.140 | from the encoder.
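In code, a minimal sketch of cross-attention as just described (notation mine): keys and values come from the encoder outputs H, queries from the decoder outputs Z.

    import numpy as np

    def cross_attention(H, Z, K, V, Q):
        # H: (n_src, d) encoder outputs; Z: (n_tgt, d) decoder outputs; K, V, Q: (d, d) weights.
        k = H @ K                                    # keys from the encoder
        v = H @ V                                    # values from the encoder
        q = Z @ Q                                    # queries from the decoder
        scores = q @ k.T                             # (n_tgt, n_src)
        scores = scores - scores.max(axis=-1, keepdims=True)
        alpha = np.exp(scores)
        alpha = alpha / alpha.sum(axis=-1, keepdims=True)
        return alpha @ v                             # (n_tgt, d)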
01:12:37.260 | So that is it.
01:12:44.320 | I've got a couple of minutes.
01:12:45.540 | I want to discuss some of the results of transformers,
01:12:48.340 | and I'm happy to answer more questions
01:12:49.920 | about transformers after class.
01:12:53.060 | So really, the original results of transformers,
01:12:56.340 | they had this big pitch for, oh, look,
01:12:58.580 | you can do way more computation because of parallelization.
01:13:02.340 | They got great results in machine translation.
01:13:04.780 | So you had-- you had transformers doing quite well,
01:13:13.060 | although not astoundingly better than existing machine
01:13:16.980 | translation systems.
01:13:20.020 | But they were significantly more efficient to train.
01:13:22.260 | Because you don't have this parallelization problem,
01:13:25.380 | you could compute on much more data much faster,
01:13:27.620 | and you could make use of faster GPUs much more.
01:13:32.100 | After that, there were things like document generation,
01:13:35.060 | where you had the old standard of sequence-to-sequence models
01:13:37.940 | with LSTMs.
01:13:39.060 | And eventually, everything became transformers
01:13:42.420 | all the way down.
01:13:45.140 | Transformers also enabled this revolution
01:13:47.340 | into pre-training, which we'll go over in the
01:13:50.420 | next class.
01:13:52.060 | And the efficiency, the parallelizability
01:13:54.660 | allows you to compute on tons and tons of data.
01:13:58.340 | And so after a certain point, on standard large benchmarks,
01:14:02.580 | everything became transformer-based.
01:14:04.740 | This ability to make use of lots and lots of data,
01:14:07.540 | lots and lots of compute, just put transformers head
01:14:10.380 | and shoulders above LSTMs in, let's say,
01:14:13.420 | almost every modern advancement in natural language processing.
01:14:19.900 | There are many drawbacks and variants to transformers.
01:14:24.620 | The clearest one that people have
01:14:25.820 | tried to work on quite a bit is this quadratic compute
01:14:28.420 | problem.
01:14:29.260 | So this all pairs of interactions
01:14:31.780 | means that our total computation for each block
01:14:34.420 | grows quadratically with the sequence length.
01:14:36.260 | And in a student's question, we heard that, well,
01:14:39.780 | as the sequence length becomes long,
01:14:41.580 | if I want to process a whole Wikipedia article,
01:14:44.220 | a whole novel, that becomes quite unfeasible.
01:14:48.100 | And actually, that's a step backwards in some sense,
01:14:50.780 | because for recurrent neural networks,
01:14:52.740 | it only grew linearly with the sequence length.
01:14:55.540 | Other things people have tried to work on
01:14:57.380 | are better position representations,
01:14:59.860 | because the absolute index of a word
01:15:02.060 | is not really the best way maybe to represent
01:15:05.260 | its position in a sequence.
01:15:07.940 | And just to give you an intuition of quadratic sequence
01:15:10.300 | length, remember that we had this big matrix multiply here
01:15:13.620 | that resulted in this matrix of n by n.
01:15:16.860 | And computing this is a big cost.
01:15:20.380 | It costs a lot of memory.
01:15:22.420 | And so there's been work--
01:15:23.700 | oh, yeah.
01:15:24.220 | And so if you think of the model dimensionality
01:15:26.260 | as like 1,000, although today it gets much larger,
01:15:29.340 | then for a short sequence where n is roughly 30,
01:15:32.460 | if you're computing n squared times d, 30 isn't so bad.
01:15:38.740 | But if n is something like 50,000,
01:15:42.020 | then n squared becomes huge and sort of totally infeasible.
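Roughly, counting just the n squared times d multiplies for the score matrix, with d taken to be 1,000 as above:

    d = 1_000
    for n in (30, 50_000):
        print(n, n ** 2 * d)   # 30 -> 900,000; 50,000 -> 2,500,000,000,000 (2.5e12)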
01:15:46.540 | So people have tried to sort of map things down
01:15:49.180 | to a lower dimensional space to get rid
01:15:51.060 | of the sort of quadratic computation.
01:15:54.540 | But in practice, I mean, as people
01:15:56.820 | have gone to things like GPT-3, Chat-GPT,
01:15:59.700 | most of the computation doesn't show up in the self-attention.
01:16:03.260 | So people are wondering sort of is it even
01:16:05.500 | necessary to get rid of the self-attention operations
01:16:08.900 | quadratic constraint?
01:16:10.300 | It's an open area of research whether this
01:16:12.900 | is sort of necessary.
01:16:14.780 | And then finally, there have been a ton of modifications
01:16:17.460 | to the transformer over the last five, four-ish years.
01:16:21.940 | And it turns out that the original transformer
01:16:25.100 | plus maybe a couple of modifications
01:16:27.820 | is pretty much the best thing there is still.
01:16:31.140 | There have been a couple of things
01:16:32.580 | that end up being important.
01:16:33.940 | Changing out the nonlinearities in the feedforward network
01:16:37.380 | ends up being important.
01:16:38.740 | But it's had lasting power so far.
01:16:43.140 | But I think it's ripe for people to come through and think
01:16:46.340 | about how to sort of improve it in various ways.
01:16:49.220 | So pre-training is on Tuesday.
01:16:52.020 | Good luck on assignment four.
01:16:53.260 | And then we'll have the project proposal documents out tonight
01:16:57.060 | for you to talk about.