
Stanford XCS224U: NLU | Fantastic Language Models and How to Build Them, Part 1 | Spring 2023


Chapters

0:00
0:57 Addressing the known limitations with BERT
2:00 Core model structure (Clark et al. 2019)
5:49 Generator/Discriminator relationships
8:50 ELECTRA efficiency analyses
12:22 ELECTRA model releases
14:54 From the RNN era
15:45 Transformer-based options
21:09 Trends in model size
22:01 Distillation objectives
24:24 Distillation performance
28:16 Pretraining data
28:35 Current trends

Whisper Transcript

00:00:00.000 | All right.
00:00:06.080 | Welcome everyone. Welcome back.
00:00:09.480 | Let's get started. We have another action-packed day for you.
00:00:13.520 | Time's a wasting.
00:00:15.280 | To start here, I'm gonna finish up our big,
00:00:19.120 | uh, slide deck on contextual word representations.
00:00:22.020 | There are just a few more small things to cover.
00:00:25.320 | And then Sid is gonna get- get us- help us get hands
00:00:28.680 | on with training really big models.
00:00:32.360 | So there's the link as usual, uh,
00:00:36.080 | from the website if you wanna follow along and we're gonna
00:00:38.400 | skip right to this section called Electra.
00:00:41.360 | Electra is a model that came from Stanford from Kevin Clark and collaborators.
00:00:45.920 | Uh, and I think it's really exciting.
00:00:48.520 | It shows you the kind of design space we're in,
00:00:51.840 | a really creative example of, you know,
00:00:54.160 | doing something that was different from what had come before in the space of transformers.
00:00:58.640 | Last time we talked about some known limitations of the BERT model.
00:01:03.120 | Most of them identified in the BERT paper itself.
00:01:05.880 | I covered that first one.
00:01:07.520 | You know, we- we just wanted more ablation studies,
00:01:09.880 | more exploration of the BERT architecture.
00:01:12.600 | The RoBERTa team kicked that off.
00:01:14.480 | I think they did a great job.
00:01:16.080 | With Electra, we're gonna address known limitations two and three.
00:01:20.640 | The first is that we have a mismatch between
00:01:23.880 | the trained vocabulary and the fine-tuned vocabulary
00:01:27.240 | because of the role of the mask token in training BERT models.
00:01:31.600 | And the second one which might feel more pressing to you,
00:01:35.040 | is that BERT is pretty inefficient when it comes to learning from data
00:01:39.000 | because we mask out or replace about 15% of the tokens.
00:01:43.720 | And as you recall from the BERT learning objective,
00:01:46.840 | those are the only tokens that contribute to the learning objective itself.
00:01:50.960 | All of the other work is kind of redundant.
00:01:53.320 | And so we might hope that we could make
00:01:55.240 | more efficient use of all these sequences that we're processing.
00:01:58.640 | Electra is gonna make some progress on that too.
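To make the inefficiency concrete, here is a minimal PyTorch-style sketch (not BERT's actual training code; the tensor names and sizes are illustrative) of how only the roughly 15 percent of masked positions contribute to the MLM loss:

```python
import torch
import torch.nn.functional as F

# Toy illustration: only the ~15% of masked positions carry a learning signal
# in BERT's MLM loss. All tensor names and sizes here are made up.
vocab_size, seq_len, batch = 30522, 128, 8
input_ids = torch.randint(0, vocab_size, (batch, seq_len))

# Pick ~15% of positions to mask or replace.
masked = torch.rand(batch, seq_len) < 0.15

# Labels hold the original token at masked positions and -100 everywhere else.
# -100 is PyTorch's ignore_index, so unmasked positions add nothing to the loss.
labels = input_ids.clone()
labels[~masked] = -100

logits = torch.randn(batch, seq_len, vocab_size)  # stand-in for BERT's output
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100)
```

Every other position still costs compute in the forward pass but contributes nothing to the gradient, which is the redundancy being described here.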
00:02:01.640 | So let's focus on the core model structure
00:02:04.560 | and then we'll look at all the other things they did in the paper.
00:02:07.440 | We'll start with our input sequence X.
00:02:09.960 | This is the chef cooked the meal.
00:02:12.200 | And the first thing we do is mask out some of those tokens and that could be
00:02:16.080 | a random sample of 15% of the tokens just like in most work with BERT.
00:02:21.320 | Then we have what could be literally a BERT model.
00:02:24.240 | We're gonna call it the generator.
00:02:25.880 | Typically a small one that has a masked language modeling objective.
00:02:29.760 | And that can produce output sequences as usual.
00:02:33.200 | However, the twist here is that we're gonna replace some of
00:02:37.880 | the tokens that came from the input with kind of randomly sampled ones from the MLM.
00:02:43.480 | You can see here that we've copied over 'the' and copied over 'chef',
00:02:47.640 | but now 'cooked' has been replaced by 'ate'.
00:02:49.980 | That might not have been the most probable output for the model,
00:02:52.880 | but we're gonna do that replacement step there.
00:02:55.520 | So what we've created- created here is a sequence that we can call X corrupt,
00:03:00.280 | a corrupted version of the input.
00:03:02.720 | And that is the primary job of this generator model.
00:03:06.360 | At this point, the heart of Electra takes over.
00:03:09.360 | This is called the discriminator,
00:03:11.080 | but we can also talk about it as the Electra model itself in essence.
00:03:14.880 | The job of the discriminator is to figure out which of
00:03:18.780 | those tokens were originals and which ones were replacements.
00:03:22.800 | So that's a kind of contrastive learning objective.
00:03:25.360 | You can see here that the actual label it's gonna learn from for 'ate' is that it was
00:03:29.560 | replaced, and for 'the', that it was original even though it was a sampled token.
00:03:34.760 | And the actual loss for the model is the generator loss,
00:03:37.800 | that is, the typical BERT MLM loss, together with this Electra loss with a weighting.
00:03:44.160 | That's how the model is trained,
00:03:45.860 | but there is as I said an asymmetry here in the sense that once we've done
00:03:49.640 | the pre-training phase we can let the generator fall away entirely and focus just on
00:03:55.360 | the discriminator as the model that we're gonna use for downstream fine-tuning tasks.
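As a rough sketch of that combined objective (a hand-written approximation rather than the authors' code; the weight of 50 follows the value reported in the Electra paper, and the tensor names are illustrative):

```python
import torch.nn.functional as F

def electra_pretraining_loss(gen_logits, mlm_labels, disc_logits, replaced_labels,
                             disc_weight=50.0):
    """Generator MLM loss plus a weighted discriminator loss (padding handling omitted).

    gen_logits:      (batch, seq, vocab) generator predictions
    mlm_labels:      (batch, seq) original token ids at masked positions, -100 elsewhere
    disc_logits:     (batch, seq) one score per token from the discriminator
    replaced_labels: (batch, seq) 1.0 if the token was replaced, 0.0 if it is original
    """
    mlm_loss = F.cross_entropy(
        gen_logits.view(-1, gen_logits.size(-1)), mlm_labels.view(-1), ignore_index=-100
    )
    disc_loss = F.binary_cross_entropy_with_logits(disc_logits, replaced_labels)
    return mlm_loss + disc_weight * disc_loss
```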
00:04:00.680 | And so you can see already that we've in a way
00:04:04.040 | solved the problem of having this weird mask token that comes from
00:04:07.440 | the pre-training phase because the discriminator never sees mask tokens.
00:04:12.000 | All it sees are these corrupted inputs and it learns to figure out which ones are
00:04:16.960 | the corrupted versions and which ones are the originals.
00:04:20.740 | Which is a different capability intuitively than the one we were imbuing the core BERT model with.
00:04:27.960 | Right? So for BERT it's kind of like the objective is to figure
00:04:30.720 | out what was missing from the surrounding context.
00:04:33.720 | And here it's like trying to figure out which of the words in
00:04:37.120 | the sequence doesn't belong and which of them do belong.
00:04:40.520 | A kind of more discriminating objective.
00:04:43.920 | So that is Electra.
00:04:46.760 | Before we dive into the experiments and stuff,
00:04:50.000 | questions about how that model works. Yeah.
00:04:53.920 | Yes, I'm wondering what the uses of this model are.
00:04:56.880 | So it tries to predict which ones have been replaced.
00:05:00.040 | Like do you- like what applications do you use Electra for?
00:05:03.600 | For pre-training. Yeah, that's what you got to get your head around.
00:05:06.560 | This is a great subtlety to bring out.
00:05:08.360 | So the discriminator is now gonna be our pre-trained artifact.
00:05:12.200 | So just the way you download BERT,
00:05:14.000 | and when you do that you're downloading some MLM-trained object- uh,
00:05:17.320 | thing, now you download Electra which is the discriminator.
00:05:21.360 | And it's been trained to do this distinguishing thing,
00:05:25.520 | as opposed to the filling in the blank or
00:05:27.600 | continuing thing from the models we've seen so far.
00:05:30.240 | The eye-opening thing is that that contrastive objective
00:05:33.500 | leads to a really good pre-trained state for fine-tuning.
00:05:37.680 | And we might hope that it's doing it much more efficiently,
00:05:42.080 | but that's what we can dive into now here.
00:05:43.840 | So first, generator-discriminator relationships.
00:05:47.760 | They observe in the paper that when the generator and the discriminator are the same size,
00:05:52.720 | they can share all their transformer parameters,
00:05:55.560 | and more sharing is better.
00:05:57.280 | So already we have an efficiency gain,
00:05:58.900 | and that's kind of intriguing that one and the same set of weights would be playing
00:06:02.560 | the role of the MLM generator and the discriminator.
00:06:06.760 | But they observe that they get
00:06:09.640 | the best results from having a generator that is small compared to the discriminator.
00:06:15.400 | And this plot kind of teases that out.
00:06:17.360 | So we've got our glue score along the y-axis.
00:06:19.880 | This will just be a measure of system quality for them.
00:06:23.520 | And along the x-axis here,
00:06:25.160 | we have generator size.
00:06:27.080 | And what they mean by that is the dimensionality of the model in the BERT sense.
00:06:32.040 | So essentially the size of each one of the layers.
00:06:35.080 | And if we zoom in, for example,
00:06:36.960 | on this blue line, the best performing model,
00:06:38.840 | this is where we have 768 as our dimensionality for the discriminator,
00:06:43.960 | and 768 for the generator.
00:06:47.520 | As we make the generator smaller,
00:06:50.600 | all the way down to 256,
00:06:52.800 | performance improves.
00:06:54.640 | And that's what we mean when we say better to have a small generator and a large discriminator.
00:07:00.000 | And that kind of U-shaped pattern is repeated across all the different discriminator sizes.
00:07:05.800 | And that's probably an insight about how this model is working,
00:07:09.080 | which is the sense that you kind of want the generator to be a little bit of
00:07:12.560 | a noisy process so that the discriminator has some interesting work to do.
00:07:17.760 | And by making the discriminator more powerful,
00:07:20.480 | I guess you're creating that kind of opportunity.
00:07:24.000 | They also do a lot of work looking at efficiency because one of
00:07:29.400 | the side goals of the Electra paper was to end up with models that were
00:07:33.560 | overall more efficient in terms of the pre-training compute and in terms of the model size.
00:07:38.520 | Here's another way to quantify that.
00:07:40.200 | Again, along the y-axis,
00:07:41.600 | we have the glue score,
00:07:42.920 | and along the x-axis now we have pre-trained flops,
00:07:45.800 | which you could just think of as a very low level measure of how
00:07:49.040 | much compute resources we need to do the pre-training part.
00:07:53.520 | The blue line at the top is Electra,
00:07:55.960 | and it's the best no matter what your computational budget is along the x-axis.
00:08:01.040 | They also explore adversarial Electra.
00:08:04.240 | This is very intuitive to me.
00:08:05.960 | That is a slightly different objective where the generator is trying to fool
00:08:10.560 | the discriminator by creating corrupted sequences that are hard for the discriminator to distinguish.
00:08:16.840 | That's a really good model,
00:08:18.400 | but it's less good than the kind of more cooperative,
00:08:21.120 | um, joint objective that I showed you before.
00:08:24.040 | And then the green line is cool too.
00:08:25.840 | So the green line is where I start training with BERT,
00:08:29.760 | and then at a certain point I switch to having also the discriminator loss.
00:08:34.640 | And at that point, the BERT model is less good for any compute budget,
00:08:38.640 | whereas the Electra variant starts to do its Electra thing and get better and better.
00:08:43.640 | So a bunch of perspectives on Electra,
00:08:46.200 | all pointing to it being a good and efficient model.
00:08:50.040 | And then finally, they do a bunch more efficiency analyses.
00:08:53.760 | So this is that picture that I showed you of the full Electra model before,
00:08:57.160 | where I have the generator creating corrupted sequences,
00:09:00.480 | and then the discriminator doing its discriminating part there.
00:09:04.640 | You could also explore Electra 15 percent,
00:09:07.400 | and this is different from full Electra in the sense that on the right for default Electra,
00:09:12.760 | we make predictions about all of these tokens,
00:09:15.240 | whether they were original or replaced.
00:09:17.600 | For the 15 percent version,
00:09:19.360 | we kind of do a BERT-like thing where we're going to assume that the ones that weren't
00:09:22.720 | part of those corrupted chains there, the sampled part,
00:09:26.920 | are just not part of the objective.
00:09:28.880 | There'll be fewer tokens there.
00:09:31.040 | Replace MLM. This is an ablation where actually we drop away the Electra part,
00:09:36.200 | and we're just looking at the MLM here,
00:09:38.800 | and we're going to not have the mask token at all.
00:09:42.160 | Because remember for BERT,
00:09:43.760 | there are a few ways that they do this learning.
00:09:45.680 | They do the mask token,
00:09:46.960 | and they also do the one where they just replace it with some random tokens here,
00:09:50.620 | like cook to run, and then the model has to reproduce a new token.
00:09:55.200 | Oh, that should say cooked I guess,
00:09:57.440 | because it's pure BERT.
00:10:00.080 | This is a kind of look at what happens if we don't introduce
00:10:04.720 | that mask token addressing that question about whether that was disrupting learning.
00:10:08.960 | Then finally, all tokens MLM.
00:10:10.840 | This is again just a BERT-based objective over here where instead of turning off
00:10:14.920 | the objective for these ones here that weren't part of the corrupted sequence,
00:10:18.760 | we do learning from all of them.
00:10:20.720 | That's a way of saying for BERT,
00:10:23.040 | if we were making more efficient use of the data,
00:10:25.120 | could we learn more quickly?
00:10:27.200 | Here are the results. So Electra is at the top,
00:10:30.720 | but just below it is all tokens MLM.
00:10:34.120 | So that's just BERT learning from all of the tokens,
00:10:36.400 | and I think that does show that BERT could have been a little better if they had not
00:10:40.640 | turned off the objective for every single token that wasn't
00:10:43.600 | part of the masking or corruption for that learning process.
00:10:47.400 | Replace MLM is just below that,
00:10:49.520 | and that's where we don't have any mask token.
00:10:51.340 | So there's no pre-training/fine-tuning mismatch.
00:10:53.960 | Electra 15 below that,
00:10:55.720 | and then BERT at the bottom.
00:10:56.920 | So overall, you're seeing these ablations are showing us that every piece of
00:11:01.240 | Electra is contributing something to the overall performance of the model,
00:11:05.760 | and that's quite nice as well. Yeah.
00:11:08.400 | How is the efficiency of all tokens MLM?
00:11:11.960 | The efficiency?
00:11:13.840 | Yeah.
00:11:14.640 | Well, we're making more efficient use of the data because
00:11:17.480 | we're getting a learning signal from every token,
00:11:20.000 | and I guess that would be the important dimension because a funny thing about BERT where
00:11:24.220 | we turn off the learning for the ones that weren't masked or corrupted,
00:11:28.980 | is that we still have to do the work of computing them.
00:11:31.420 | It's just that then they don't become part of the objective,
00:11:34.360 | and here we're just kind of bringing that in.
00:11:36.640 | So for free or close to it.
00:11:39.420 | My question was, how is the glue score calculated?
00:11:44.540 | What does it represent?
00:11:45.580 | Some accuracy in language generation afterwards,
00:11:48.940 | or is it the classifier that's being scored?
00:11:51.580 | Oh, yeah. So glue is a big multitask classification benchmark.
00:11:56.220 | It's a pretty diverse set of tasks,
00:11:58.540 | maybe biased toward natural language inference.
00:12:01.620 | The reason they're using it in the paper is just that it has been,
00:12:06.100 | it has been adopted as a kind of general purpose measure of performance,
00:12:11.020 | and it's driven a lot of reasoning about what's good and what's bad in the field.
00:12:17.060 | Then here are some model releases.
00:12:19.860 | Base and large, kind of align with BERT,
00:12:22.580 | and then we have this small model here,
00:12:24.460 | and that was designed to be quickly trained on a single GPU,
00:12:27.340 | again, as a nod toward efficiency,
00:12:28.940 | and all three are really good models. Yeah.
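If you want to poke at one of those released discriminators yourself, here is a small usage sketch assuming the Hugging Face transformers library and the google/electra-small-discriminator checkpoint:

```python
import torch
from transformers import ElectraForPreTraining, ElectraTokenizerFast

name = "google/electra-small-discriminator"
discriminator = ElectraForPreTraining.from_pretrained(name)
tokenizer = ElectraTokenizerFast.from_pretrained(name)

# A corrupted version of "the chef cooked the meal".
inputs = tokenizer("the chef ate the meal", return_tensors="pt")
with torch.no_grad():
    scores = discriminator(**inputs).logits[0]  # one score per token

# Positive scores mean the discriminator thinks the token was replaced.
for tok, score in zip(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]), scores):
    print(f"{tok:>10}  {'replaced' if score > 0 else 'original'}")
```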
00:12:34.100 | The thing that we've observed with Electra,
00:12:37.420 | so like putting our text into some kind of representation space,
00:12:42.140 | is that better than BERT across the board,
00:12:44.700 | or is it just like, generally glue-wise, it's better?
00:12:50.060 | Oh, I like that question.
00:12:51.860 | That could kind of queue up some analysis work that you could do for a final project.
00:12:55.980 | Because I think the insight behind your question is that a lot of the time,
00:12:59.540 | we reason about these models just based on their performance on something like glue.
00:13:03.740 | You could ask a deeper question,
00:13:05.520 | what are their internal representations like,
00:13:07.940 | and are there places where they're transformatively better or worse?
00:13:11.500 | That you could tie that back to the fact that the learning objective is different.
00:13:15.700 | We're doing this discrimination thing as opposed to filling in the blanks in some sense.
00:13:20.140 | Maybe there are some underlying differences. I love that.
00:13:23.940 | All right. Couple more topics here,
00:13:32.500 | just quickly because we're going to do more work with seq2seq models later.
00:13:36.780 | We're going to train some of our own from scratch,
00:13:38.740 | and you all might use some fine-tuned ones.
00:13:40.420 | So I thought it would be good to just get them on the table as well.
00:13:44.060 | Seq2seq, here are some natural tasks that fall into the seq2seq structure.
00:13:49.060 | Machine translation, right?
00:13:50.460 | Source language to target language.
00:13:52.620 | Summarization, big text to hopefully smaller text.
00:13:56.720 | Freeform question answering, where you go from a question and then maybe you're
00:13:59.980 | generating as opposed to just extracting an answer.
00:14:03.640 | Dialogue, of course.
00:14:05.420 | Semantic parsing, this is the one we're going to tackle,
00:14:07.980 | where you go from a sentence to some kind of logical form representing its meaning.
00:14:12.740 | Code generation, of course,
00:14:14.540 | that's similar, and on and on.
00:14:16.380 | I think there are lots of problems that are pretty naturally cast as seq2seq problems,
00:14:20.980 | especially when you've got different stuff on the input and the output side.
00:14:27.100 | Yeah, and the more general class of things we could be talking about would be encoder-decoder,
00:14:32.380 | where that's just more general in the sense that at that point,
00:14:35.380 | the input could be a picture you're encoding and the output a text.
00:14:38.980 | Picture-to-picture, video-to-picture, in principle,
00:14:42.500 | anything could be happening on the two sides.
00:14:44.260 | Seq2seq would, for me,
00:14:45.740 | just be the special case where we're looking at sequential data,
00:14:48.620 | typically language data or computer code or something.
00:14:53.060 | From the RNN era,
00:14:56.980 | this is just nice if you hearken back to that era, if you lived through it.
00:15:01.940 | This is a paper from Thang Luong.
00:15:03.820 | Doing seq2seq on the left in
00:15:06.140 | the traditional way with a recurrent neural network, an RNN.
00:15:09.020 | Pretty simple, right? We've got A, B, C,
00:15:10.900 | D coming in, and then it transitions into maybe other parameters,
00:15:14.340 | and it's trying to produce this new sequence left to right coming out.
00:15:17.820 | Just to remind you, this is part of the journey the field went on.
00:15:21.260 | What Thang did, very influentially,
00:15:23.540 | is think a lot about how you would add attention layers in to that RNN.
00:15:29.660 | That's what you see depicted here.
00:15:31.100 | This is a schematic diagram hinting at the fact that we were moving into an era,
00:15:36.340 | the Vaswani et al era,
00:15:38.020 | attention is all you need,
00:15:39.340 | where basically that attention layer would do all the work.
00:15:43.380 | That's where we're at now.
00:15:46.300 | For seq2seq problems in general,
00:15:49.580 | this is a nice framework from the T5 paper.
00:15:52.320 | There are a few different ways you could think about them.
00:15:54.800 | On the left is the one I'm nudging you toward,
00:15:57.340 | where we have an encoder and a decoder.
00:15:59.620 | If we're talking about transformer models,
00:16:02.100 | what essentially that means is that when we do encoding,
00:16:05.180 | we can connect everything to everything else.
00:16:07.140 | You can think of that as a process of simultaneously
00:16:10.220 | encoding the entire input with all its connections.
00:16:14.060 | But as we do decoding for many of these problems,
00:16:16.780 | we need to do some sequential generation.
00:16:19.700 | The result of that is we can look back to the decoder,
00:16:23.420 | sorry, the encoder all we want,
00:16:25.180 | but for the decoder, we have to do that masking that I
00:16:27.860 | described with the autoregressive loss last time,
00:16:30.900 | so that we don't look into the future.
00:16:33.220 | But with that constraint,
00:16:34.900 | we can do this encoder-decoder thing with
00:16:37.340 | the decoding step being truly sequential.
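Here is a tiny illustration of the masking difference just described (illustrative only, using boolean attention masks):

```python
import torch

seq_len = 5
# Encoder self-attention: every position may attend to every other position.
encoder_mask = torch.ones(seq_len, seq_len).bool()
# Decoder self-attention: lower-triangular mask, so position t only sees positions <= t.
decoder_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
print(decoder_mask)
```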
00:16:40.740 | But that's not the only way to think about these problems.
00:16:43.540 | In fact, I don't want to presuppose that for a sequence-to-sequence thing,
00:16:47.220 | you'll use a sequence-to-sequence model.
00:16:49.180 | You could, for example,
00:16:50.340 | use a language model and the way you might do
00:16:52.260 | that is to say I'm just going to encode everything left to right.
00:16:55.660 | That's the version in the middle.
00:16:57.620 | Then a kind of compromise position would be that you would take your language model,
00:17:01.540 | which might be autoregressive,
00:17:03.060 | but simultaneously encode the entire input,
00:17:06.620 | and then begin your process of decoding without
00:17:09.820 | explicitly having an encoder part and a decoder part.
00:17:13.980 | I think all of them are on the table and people are solving
00:17:16.740 | seq2seq tasks right now using all of these variants.
00:17:21.300 | T5, I'm going to show you two.
00:17:25.180 | There are lots out there but these are very prominent ones that you might download.
00:17:28.820 | So T5, this is a wonderful,
00:17:30.700 | very rich paper that does a lot of
00:17:32.420 | exploration of which of these architectures are effective.
00:17:35.420 | T5 ended up on an encoder-decoder variant,
00:17:38.460 | and what they did is an impressive amount of multitask training,
00:17:42.860 | unsupervised and supervised objectives.
00:17:45.700 | An innovative thing that they did is have these task prefixes,
00:17:50.620 | like translate English to German and then an English sentence,
00:17:53.940 | or 'this is a CoLA sentence',
00:17:55.900 | that's just a dataset people use,
00:17:57.660 | or 'an STS-B sentence',
00:18:00.020 | and that's the model's cue to take that input and condition it very informally on
00:18:05.500 | that task so that the output behavior is the expected behavior.
00:18:10.180 | There are lots of T5 models that you can download.
00:18:13.860 | This is nice for development because some of them are
00:18:15.860 | very small and some of them are very, very large.
00:18:18.540 | More recently, there are these FLAN models which took
00:18:21.460 | a T5 architecture and did a lot of reinforcement learning with
00:18:24.460 | human feedback to even further
00:18:27.060 | specialize them to different tasks in interesting ways.
00:18:30.660 | So that's T5, and then the other one that you often hear about that's very effective is BART.
00:18:35.900 | BART is interestingly different yet again.
00:18:38.740 | So BART is an encoder-decoder framework,
00:18:41.700 | and it's really got a BERT style thing on the left,
00:18:45.420 | and then a GPT style thing on the right,
00:18:47.740 | that is, joint encoding of everything,
00:18:50.100 | and then that autoregressive part if you want to do sequential generation.
00:18:53.940 | The innovative thing about BART is that the training
00:18:57.500 | involves a lot of corrupting of that input sequence.
00:19:01.220 | They tried to like do text infilling of pieces,
00:19:04.220 | they shuffled sentences around,
00:19:06.020 | they did some masking,
00:19:07.340 | and they did some token deletion,
00:19:08.980 | rotating of documents,
00:19:10.360 | all of this corrupting of the input,
00:19:12.520 | and then the model's objective is to learn how to
00:19:15.100 | essentially uncorrupt what it got as the input.
00:19:18.100 | They found that the joint process of this text infilling thing and
00:19:22.260 | sentence shuffling was the most effective for training BART.
00:19:26.220 | So that was for the pre-training phase,
00:19:28.220 | and when then you- then when you fine-tune with BART,
00:19:30.820 | for classification tasks,
00:19:32.460 | you just put in two uncorrupted copies of your sentence,
00:19:35.600 | and then you could fit your task specific labels on like
00:19:39.140 | the class token or the final token of the GPT output,
00:19:42.420 | and for seq2seq, you just use it as a standard encoder-decoder.
00:19:46.380 | And again, the intuition is that the pre-training phase which did all this corruption,
00:19:50.580 | has helped the model learn what sequences are like.
00:19:53.980 | And that blends together for me a lot of the intuitions we've seen from MLM,
00:19:58.700 | and from what we just talked about with Electra. Yeah.
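Here is a small usage sketch for the seq2seq case, assuming the Hugging Face transformers library and the facebook/bart-large-cnn summarization fine-tune; the input text is just an example:

```python
from transformers import BartForConditionalGeneration, BartTokenizerFast

tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

article = ("Electra replaces BERT's masked language modeling with a replaced-token "
           "detection objective, which lets every token in the input contribute to "
           "the pretraining loss.")
inputs = tokenizer(article, return_tensors="pt", truncation=True)
summary_ids = model.generate(**inputs, max_new_tokens=30, num_beams=4)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```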
00:20:03.460 | Um, kind of a question that I asked last week.
00:20:05.300 | Have any of these models worked with like spelling mistakes?
00:20:09.180 | Yeah. So BART is a really good option if you want to do spelling correction.
00:20:13.660 | And I actually think that that might be because spelling correction is kind of as a task,
00:20:19.180 | a corrupting of the input where you're
00:20:21.100 | trying to learn the uncorrupted version of the output.
00:20:23.580 | So I think if you want to do grammar correction,
00:20:25.540 | spelling correction, things like that,
00:20:27.280 | it's outstanding to use BART,
00:20:28.900 | and you might just think of training from scratch on a model that you know is
00:20:32.900 | going to be aware of characters for these character level things. Yeah.
00:20:37.580 | Sorry, what is text infilling?
00:20:39.820 | That was where they like removed parts of the text essentially,
00:20:44.580 | and added other pieces to corrupt it.
00:20:46.820 | Different from masking where you just hide.
00:20:49.140 | Yeah, where you just, that's more like the BERT style thing where you hide some.
00:20:52.340 | Yeah. And okay, final quick topic.
00:20:59.260 | I just want you all to know about distillation.
00:21:02.220 | Again, because a theme of this course could be how can we do more with less?
00:21:05.980 | And distillation is a vision for how we could do more with less.
00:21:09.340 | Right. We saw this trend in model sizes here where they're getting bigger and bigger,
00:21:13.460 | and there is some hope that they might now be getting smaller.
00:21:16.980 | But we should all be pushing to make them ever smaller.
00:21:20.300 | And one way to think about doing that is distillation.
00:21:22.700 | And the metaphor here is that we're going to have two models.
00:21:25.460 | Maybe a really big teacher model that was trained in
00:21:28.540 | a very expensive way and might run only on a supercomputer,
00:21:31.840 | and then a much smaller student.
00:21:34.140 | And we're going to train that student to mimic the behavior of the teacher.
00:21:38.220 | And we could do that by just observing the output behavior of the teacher,
00:21:42.380 | and then trying to get the student to align at the level of the output.
00:21:46.260 | And that would basically just be treating this teacher as a kind of input output device.
00:21:50.980 | We could also though think about aligning the internal representations of these two,
00:21:56.020 | to get a deeper alignment between teacher and student.
00:22:00.100 | Here's some objectives in fact,
00:22:02.220 | and this is from least to most heavy duty,
00:22:04.620 | and you could combine them.
00:22:06.060 | So we could just use our gold data for the task.
00:22:08.940 | I put that as step zero because you might want it in the mix here,
00:22:12.080 | even as you use your teacher.
00:22:13.940 | We could also learn just from the teacher's output labels.
00:22:17.440 | That's a bit of a funny idea,
00:22:19.180 | but I think the intuition is that the teacher might be doing
00:22:22.260 | some very complicated regularization that helps the student learn more efficiently.
00:22:27.420 | So even if there are mistakes in the teacher's behavior,
00:22:30.260 | the student actually benefits.
00:22:32.740 | You could also think about going one level deeper and using
00:22:35.620 | the full output scores like the logits,
00:22:37.620 | so not just the discrete outputs but the whole distribution that the model predicts.
00:22:42.140 | And that's what they did in one of the original distillation papers.
00:22:45.900 | You could also tie together the final output states.
00:22:49.180 | If the two models have the same layer-wise dimensionality,
00:22:52.260 | then for example in this distilbert paper,
00:22:54.740 | they enforce as part of the objective,
00:22:56.720 | a cosine similarity between the output states of teacher and student.
00:23:00.820 | And now you need to have access to the model itself.
00:23:03.540 | And this will be much more expensive because you need to
00:23:05.540 | run the teacher as part of distillation.
00:23:08.700 | You could also think about doing this with lots of other hidden states.
00:23:12.300 | People have explored lots of other things.
00:23:14.500 | And you could even, this is a paper that we did,
00:23:16.860 | try to mimic them under different counterfactuals where you kind of
00:23:20.260 | change around the input representations of the teacher,
00:23:23.180 | observe the output, and then try to get the student to do
00:23:25.560 | that to mimic very strange behavior from the teacher.
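Here is a minimal PyTorch-style sketch combining a few of those objectives: the gold-label loss, a soft-target loss on the teacher's logits, and a DistilBERT-style cosine loss on final hidden states. The weights and temperature are illustrative values, not numbers from any particular paper.

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, gold_labels,
                      student_hidden, teacher_hidden,
                      T=2.0, alpha=0.5, beta=0.3, gamma=0.2):
    # Gold-label loss on the task data (step zero above).
    ce = F.cross_entropy(student_logits, gold_labels)

    # Soft-target loss on the teacher's full output distribution, with temperature T.
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)

    # DistilBERT-style alignment of final hidden states via cosine similarity.
    cos = 1.0 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()

    return alpha * ce + beta * kl + gamma * cos
```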
00:23:30.980 | And then there are a bunch of other things you can do.
00:23:33.620 | So standard distillation is where you have
00:23:35.940 | your big model frozen and the student is being updated by the process.
00:23:39.980 | If you have multi-teacher,
00:23:41.300 | that's where there are lots of big models maybe doing
00:23:43.180 | multiple tasks and you try to distill them all at once down into a student.
00:23:47.100 | That's a very exciting new frontier.
00:23:49.860 | Co-distillation is where they're trained jointly,
00:23:53.220 | sometimes also called online distillation.
00:23:55.500 | That's where both the teacher and the student are learning together simultaneously.
00:23:59.300 | Might be unnerving in the classroom but effective for a model.
00:24:02.580 | And then self-distillation is actually where you try to get like usually
00:24:06.580 | lower parts of the model to be like other parts of the model by
00:24:10.300 | having them mimic themselves as part of the core model training.
00:24:14.060 | So that's a special case,
00:24:15.560 | I guess, of co-distillation where there's only one model and
00:24:18.600 | you're trying to distill parts of it into other parts.
00:24:21.420 | That's kind of wild to think about.
00:24:23.860 | And this has been applied in many domains.
00:24:26.860 | And the reason I can be encouraging about this is that as we get better and better at
00:24:30.980 | distillation we're finding that distilled models are as good or better than the teacher models.
00:24:36.580 | Maybe for a fraction of the cost and this is especially
00:24:40.020 | relevant if the model is being used in production on a small device or something.
00:24:44.580 | So here are just some glue performance numbers that show across a bunch of
00:24:48.140 | these different papers that with distillation you can still
00:24:51.300 | get glue performance like the teacher with a tiny model.
00:24:56.420 | Yeah.
00:24:59.100 | Something really puzzling to me is how can a smaller simple model be able to mimic a teacher
00:25:04.620 | when the training set is fixed?
00:25:06.500 | Couldn't you have just trained the simple model?
00:25:09.100 | Or is it just that the teacher has access to some point of learning that is easier
00:25:14.180 | for a student to navigate to, but not for a student to get to on its own?
00:25:19.860 | I think something like what you just said has to be right.
00:25:24.260 | I actually don't- so you're asking about the special case where the teacher just does
00:25:28.900 | its input output thing and produces a dataset that we train the student on, right?
00:25:33.460 | And you're asking why is that better than just training the student on your original data?
00:25:37.660 | It's very mysterious to me.
00:25:38.940 | [LAUGHTER]
00:25:40.660 | I- the best metaphor I can give you is that it is a kind of regularizer.
00:25:44.740 | So the teacher is doing something very complicated
00:25:47.500 | and even its mistakes are useful for the student.
00:25:50.740 | I guess this may be a simple way that I'm understanding it.
00:25:54.460 | It's okay to make certain mistakes and the teacher has figured out which mistakes you can-
00:25:58.420 | [NOISE]
00:25:59.660 | Not worry about.
00:26:00.420 | I like that. I like- that's a beautiful opening line of a paper.
00:26:03.900 | We need to make it substantive by actually explaining what that means.
00:26:07.180 | But it's a- I like it as a vision for sure.
00:26:10.380 | I want to be a little careful of time.
00:26:13.220 | One more question and then I'll just wrap up.
00:26:15.060 | Do we have some comparisons where the student is,
00:26:17.780 | I mean, less- less general versus- versus the teacher?
00:26:21.420 | I mean, does it- does it overfit the data in a sense,
00:26:25.100 | more than the teacher?
00:26:27.700 | The student?
00:26:29.300 | Yeah.
00:26:30.260 | I don't know. I mean, you would guess less if it has a tiny capacity.
00:26:33.820 | It won't have as much of a capacity to overfit than the teacher.
00:26:37.780 | And maybe that's why in some situations the students outperform the teachers.
00:26:43.180 | I hope that's inspiring to you all.
00:26:45.740 | Let me wrap up here.
00:26:47.340 | So you can go on to outperform me.
00:26:49.540 | Architectures I didn't mention: Transformer-XL,
00:26:52.340 | wonderful creative attempt to model long sequences by essentially creating
00:26:56.940 | a recurrent process across cached versions of earlier parts of the long document you're processing.
00:27:04.140 | XLNet, this is a beautiful and creative attempt to use masked language modeling, sorry,
00:27:10.700 | an autoregressive language modeling objective but still have
00:27:14.620 | bidirectional context and they do this by creating all these permutation orders of
00:27:19.580 | the original sequence so that you can effectively condition on the left and the right,
00:27:24.020 | even though you can't look into the future.
00:27:27.220 | And then DeBERTa, this is really cool.
00:27:29.460 | I regret not fitting this in.
00:27:30.980 | DeBERTa is an attempt to separate out the word and positional encodings for
00:27:35.900 | these models and kind of make the word embeddings more like first-class citizens.
00:27:41.100 | And that's very intuitive for me because that's like showing that we want
00:27:44.460 | the model to learn some semantics for these things that's separate from their position.
00:27:49.340 | And they did that by reorganizing the attention mechanisms.
00:27:53.300 | The known limitations, we did a good job on these except for this final one.
00:27:57.940 | BERT assumes that the predicted tokens are all independent of each other given the unmasked tokens.
00:28:02.820 | I gave you that example of masking 'New' and 'York' and it thinking that both of
00:28:07.060 | those are independent of the other given the surrounding context.
00:28:10.460 | XLNet again addresses that and that might be something that you want to meditate on.
00:28:16.140 | Pre-training data, here's a whole mess of resources.
00:28:19.700 | If you did want to pre-train your own model,
00:28:21.420 | maybe Sid will talk more about this.
00:28:23.260 | I'm offering these primarily because I think you might want to audit them as
00:28:26.900 | you observe strange behavior from your large models.
00:28:29.860 | The data might be the key to figuring out where that behavior came from.
00:28:34.500 | And then finally, current trends, right?
00:28:36.700 | Autoregressive architectures seem to have taken over,
00:28:39.340 | but that could be just because everyone is so focused on generation.
00:28:42.940 | I have an intuition that models like BERT are still
00:28:46.540 | better if you just want to represent examples as opposed to doing generation.
00:28:50.900 | Seq2seq is still a dominant choice for tasks with that structure.
00:28:54.780 | Although again, point 1 might be pushing everyone to just use GPT-3 or 4,
00:29:00.500 | even for problems with seq2seq structure.
00:29:02.840 | We'll see how that plays out.
00:29:04.380 | And then people are still obsessed with scaling up.
00:29:06.860 | But we might be seeing a counter movement towards smaller models,
00:29:10.380 | especially with reinforcement learning with human feedback.
00:29:13.340 | And that's something that we're going to talk about next week and the week after.
00:29:17.660 | So I kind of restructured a little bit of my talk.
00:29:22.460 | So like we're only going to get through like part one today,
00:29:24.580 | which is actually back to basics,
00:29:26.700 | how transformers work, and then we're going to talk about the other stuff.
00:29:30.180 | I should maybe introduce myself. I'm Sid.
00:29:33.140 | I am a fourth year PhD.
00:29:34.900 | I actually work primarily on language for
00:29:37.700 | robotics and kind of channeling one of the core concepts of the class.
00:29:40.820 | It's all about doing a whole lot with very, very little.
00:29:44.420 | Like I'm really just working on how we get robots to follow instructions,
00:29:49.740 | given just like one example of a human,
00:29:51.740 | you know, opening a fridge or pouring coffee, things like that.
00:29:56.620 | But in kind of doing that research,
00:29:59.060 | it became really, really clear that we needed
00:30:00.860 | better raw materials, better starting points.
00:30:03.140 | So I started working on pre-training, first in language,
00:30:06.220 | and then more recently in vision, video,
00:30:08.980 | and robotics with language, kind of as the central theme.
00:30:12.380 | So I want to talk today about fantastic language models and how to build them.
00:30:17.100 | So Richard Feynman is probably not only one of the greatest physicists of all time,
00:30:21.740 | but he's one of the greatest educators of all time,
00:30:24.500 | one of the greatest science educators.
00:30:26.420 | And he has this quote, "What I cannot create, I do not understand."
00:30:31.620 | And for him, it was really just kind of about building blocks.
00:30:34.340 | How do I understand what is going on at the lowest level,
00:30:37.060 | so I can compose them together and figure out what to do next?
00:30:40.300 | Where is the next innovation?
00:30:41.500 | Where does the next discovery come from?
00:30:43.340 | And so kind of with that in mind,
00:30:45.620 | I actually just want to spend the next 12 or so minutes
00:30:48.740 | talking about building language models,
00:30:51.140 | building transformers, and how that all happened.
00:30:53.420 | So it's a practical take on these large-scale language models.
00:30:57.380 | We're really not going to get to the large-scale bit.
00:30:59.500 | And we're going to get to the full pipeline.
00:31:01.300 | Again, we're only going to focus on the model architecture,
00:31:02.860 | but today, the evolution of the transformer, how we got there.
00:31:06.660 | Training at scale, we'll probably cover some other time.
00:31:10.340 | And then we'll talk about very,
00:31:12.540 | very briefly efficient fine-tuning and inference.
00:31:14.340 | And we have some other great CAs who might actually be talking about this more in depth.
00:31:18.300 | But the punchline is the last few years,
00:31:21.860 | like I started my PhD in 2019.
00:31:24.380 | I trained my first deep learning MNIST model in 2018.
00:31:29.580 | The field's changed a lot since then.
00:31:31.380 | And with every new model, with every new GPT,
00:31:34.780 | one, two, three, four, the five that's training right now,
00:31:39.260 | there's been more and more folk knowledge,
00:31:41.260 | things that are hidden from plain sight that we never get to see.
00:31:45.380 | And it's been the job of people like us, students,
00:31:49.020 | people in academia to kind of rediscover,
00:31:51.580 | find the insights, find the intuition behind these ideas.
00:31:54.460 | And in kind of rediscovering those pipelines,
00:31:56.860 | it's actually our comparative advantage and figuring out,
00:32:00.780 | okay, so these are how these pieces came to be.
00:32:03.540 | What do I do next? And so I don't really care about time.
00:32:08.060 | If we get through five slides,
00:32:09.740 | that is a success for me.
00:32:11.300 | But be selfish, like this is your class.
00:32:13.860 | So if you have any questions,
00:32:15.340 | if I say anything you don't understand, that's the contract.
00:32:17.420 | Call me out, ask a question,
00:32:19.260 | and we're just going to kind of go step by step.
00:32:21.300 | So how did we get to the transformer?
00:32:25.580 | How did this become the bedrock of language modeling?
00:32:28.940 | And now, vision, also video,
00:32:31.900 | also robotics for some reason as of late.
00:32:34.780 | How did we get here? So what is the recipe for a good language model?
00:32:39.460 | We've talked a bit about contextual representations.
00:32:41.740 | Chris was talking through kind of the various,
00:32:43.420 | you know, different phase changes in language modeling history.
00:32:45.900 | Lisa was talking about diffusion language models,
00:32:47.300 | which is a completely different perspective.
00:32:50.020 | I'm going to kind of simplify things,
00:32:52.140 | oversimplify things, like two steps.
00:32:54.260 | I need massive amounts of cheap, easy to acquire data.
00:32:57.300 | We're training a language model and we're building these contextual representations
00:33:00.260 | because we want to learn patterns,
00:33:01.500 | we want to learn truths about the world from data at scale.
00:33:05.540 | And to do that at scale, we need data at scale.
00:33:08.580 | So that's one component.
00:33:10.580 | And the other is we need a simple and high throughput way to consume it.
00:33:14.940 | So what does that mean?
00:33:18.060 | We need to be able to chew through all of this data as fast as we possibly can
00:33:22.660 | in the least opinionated way to figure out, you know,
00:33:26.540 | all of the possible patterns, all of the possible things that could be useful
00:33:29.900 | for people fine-tuning, generating, using these models for arbitrary things downstream.
00:33:35.380 | This isn't just applicable to language, it's applicable to pretty much everything.
00:33:38.580 | So vision does this, video does this, video and language,
00:33:41.820 | vision and language, language and robotics, all of them follow a similar strategy.
00:33:48.100 | So, right, so simple in that it's natural to scale the approach with data as we get,
00:33:53.300 | you know, go from 300 billion to 600 billion tokens,
00:33:56.340 | maybe make the model bigger to handle that in a pretty simple way.
00:33:59.340 | The model should be composable in general.
00:34:02.460 | The training, the way that we actually ingest this data should be fast and parallelizable
00:34:07.940 | and we should be, you know, making the most of our hardware.
00:34:10.060 | If we're going to run a data center with, I don't know, 512 GPUs
00:34:14.340 | with each 8 GPU box costing $120,000, I'd better be getting my money's worth at the end of the day.
00:34:20.740 | And the consumption part, right, like this minimal assumptions on relationships,
00:34:25.940 | the less opinionated I am about how different parts of my data are connected,
00:34:30.540 | the more I can learn given the first thing, massive amounts of data at scale.
00:34:36.980 | So, kind of like figure out how we got to the transformer, I want to kind of wind time back
00:34:41.700 | to kind of what Chris was alluding to earlier with RNNs, right.
00:34:46.020 | So, this is an RNN model, kind of complicated, but it's from 224M.
00:34:51.420 | I took it literally from their slides.
00:34:53.540 | I hope John doesn't get mad at me.
00:34:56.500 | And it's this very powerful class of model in theory, right.
00:35:00.820 | I am ingesting arbitrary length sequences left to right
00:35:04.460 | and I'm learning arbitrary patterns around them.
00:35:07.820 | People decided later on to, you know, add these attention mechanisms on top of the sequence
00:35:13.340 | to sequence RNN models to figure out like how to kind of sharpen their focus
00:35:17.420 | as they were decoding token by token, right.
00:35:19.940 | So, the strengths are like I get to handle arbitrary long context and like we see kind
00:35:23.700 | of the first semblance of attention appear kind of very motivated
00:35:27.980 | by like the way we do language translation, right.
00:35:30.020 | Like when I'm translating word by word, there are certain words in the input
00:35:32.580 | that are going to matter, I'm going to sharpen my focus too.
00:35:35.900 | But there are issues with RNNs.
00:35:37.540 | They're not the most scalable, producing the next token requires me
00:35:41.380 | to produce every single token beforehand.
00:35:44.260 | I can't really make them deeper without training stability going to pieces.
00:35:49.180 | So, that's rough.
00:35:51.140 | And so, chewing through a large amount of data with an RNN is hard.
00:35:54.940 | Some people refuse to believe that and they've actually done immense work in trying to scale
00:35:59.740 | up RNNs, make them more parallelizable using lots and lots
00:36:01.900 | of really cool linear algebra tricks.
00:36:03.540 | I'll post some links.
00:36:04.940 | And then separately, kind of from the vision community that kind of bled
00:36:09.820 | into the language community, we have convolutional neural networks.
00:36:13.100 | And this is from another course, from Lena Voita, about using CNNs for language modeling.
00:36:21.500 | And the idea here is we have this ability to do immense, deep, parallelizable training
00:36:29.060 | by kind of taking these like little windows, right.
00:36:32.180 | I'm going to look at, you know, each layer is only going to look at like the, you know,
00:36:34.540 | three word contexts at a time and going to give me representation.
00:36:37.460 | But if I stack this enough times and I have these residual connections that combine, you know,
00:36:42.220 | earlier inputs with later inputs, by the time I'm like 10 layers deep,
00:36:45.700 | I've seen everything in the window.
00:36:47.740 | But I need that depth and that's kind of a drawback.
00:36:52.260 | But there are like really cool, powerful ideas here.
00:36:54.180 | And I'd actually say that the transformers have way more to do with CNNs and the way
00:36:57.540 | that they behave than the way RNNs behave, right.
00:37:00.940 | So we have this idea of a CNN layer kind of having multiple filters, multiple kernels,
00:37:06.380 | different ways of looking at and extracting features from an image or features from text.
00:37:12.340 | You have, you know, this ability to kind of scale depth using these residual connections.
00:37:17.380 | The deepest networks that we had, you know, from 2012, 2015,
00:37:21.460 | even now are still vision models, right.
00:37:23.940 | ResNet-152 isn't called 152 because it's, you know, the 152nd edition of the ResNet.
00:37:29.540 | It's 152 because it's 152 layers deep.
00:37:32.100 | It's actually 152 blocks deep.
00:37:34.020 | Layers, it's actually like probably 4x that.
00:37:37.980 | And it's parallelizable, right.
00:37:39.700 | Every little window that I see at every layer can be computed completely independently
00:37:44.380 | of every other layer, which is really, really great for modern hardware, modern GPUs.
00:37:51.460 | So looking at this, seems like CNNs are cool.
00:37:55.140 | Seems like RNNs are cool.
00:37:57.060 | There's a natural question, which is like how do you do better?
00:38:00.700 | This is the picture from Chris's slides that Lisa also used.
00:38:03.700 | This is a very scary looking picture, right.
00:38:07.060 | Like what does self-attention mean in a transformer?
00:38:10.340 | Where do those ideas come from?
00:38:12.500 | So one idea, like one key component, like one missing component for how you get from a CNN
00:38:19.180 | to an RNN and an RNN to a transformer is the idea
00:38:23.140 | that each individual token is its own query key and value.
00:38:27.220 | It's its own entity that can be used to shape the representations of all of the other tokens.
00:38:33.500 | Right. So I'm going to turn this word "the" into its own query key and value.
00:38:36.500 | I'm going to use the attention from the RNNs,
00:38:39.540 | and I'm going to use the depth parallelizability scaling from the CNNs.
00:38:44.940 | And then the multi-headed part of self-attention is exactly like what the, you know,
00:38:49.380 | different convolutional filters, the different kernels are doing in a CNN.
00:38:53.340 | It's giving you different perspectives on that same token, different ways to come up with queries,
00:38:57.780 | different insights into like how we can use, you know, a single token and come
00:39:01.900 | up with multiple different representations and fuse them together.
00:39:05.940 | As for the code, the code is semi-unimportant.
00:39:09.580 | It's a kind of very terse description of like what multi-headed self-attention looks like.
00:39:13.900 | Key parts here are this, you know, little bit where we kind of project an input sequence
00:39:19.020 | of tokens to the queries, keys, and values.
00:39:21.020 | And then we're just going to rearrange them kind of in some way.
00:39:24.940 | And the important part here is that like we are rearranging them in a way that splits them
00:39:29.380 | into these different views, these different heads, where each head has some fixed dimension,
00:39:34.060 | which is like the key dimension of the transformer.
00:39:36.620 | And then we have these query keys and values that we're then going to use
00:39:40.580 | for this attention operation which comes directly from the RNN literature.
00:39:44.740 | Right? It is really just this dot product between queries and keys.
00:39:48.820 | That is a very complicated way of saying that's a matrix multiply.
00:39:53.300 | And then we're going to just project them and combine them back into our tensor of, you know,
00:39:59.700 | batch size by sequence length by embedding dimension.
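The slide code itself isn't reproduced in this transcript, so here is a hedged PyTorch sketch of the steps being described: one big projection to queries, keys, and values, a rearrange into heads, scaled dot-product attention, and a merge back to batch size by sequence length by embedding dimension. The class and variable names are my own, not the slide's.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Minimal sketch: project -> split into heads -> attend -> merge."""

    def __init__(self, embed_dim: int, n_heads: int):
        super().__init__()
        assert embed_dim % n_heads == 0, "heads must evenly divide the hidden dimension"
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        # One big weight producing queries, keys, and values in a single matmul.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):                        # x: (batch, seq, embed_dim)
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)   # each (B, T, C)
        # Rearrange into heads: (B, n_heads, T, head_dim)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention: queries against keys, weights applied to values.
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        att = F.softmax(att, dim=-1)
        out = att @ v                             # (B, n_heads, T, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)                     # back to (batch, seq, embed_dim)
```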
00:40:03.340 | >> There's a nice subtlety here I think that caught me off guard at one point.
00:40:08.620 | When you look at these models, like you download BERT and it says multi-headed attention
00:40:12.860 | or whatever, there's only one set of weights even though it's multi-headed.
00:40:17.780 | Do you want to unpack that for us a little bit?
00:40:19.980 | >> Yeah. So the convolutional kernel has kind of this really nice way of expressing
00:40:26.940 | like I have multiple resolutions of an image.
00:40:29.300 | Right? It's like that depth channel of a convolutional filter.
00:40:32.300 | And if you can unpack like the Conv2d layer in PyTorch, you kind of see that come out.
00:40:38.140 | You don't really see that here.
00:40:38.860 | You see just like this one big weight matrix that is literally like this, you know, dimensionality,
00:40:43.900 | you know, embed dimension by three times embed dimension is usually what it is in BERT or GPT.
00:40:49.340 | That three is the way you split it into queries, keys, and values.
00:40:55.060 | But we're actually going to kind of like take that vector that is like, you know, embed dim size
00:41:00.300 | and just chunk it up into each of these different filters.
00:41:03.300 | Right? So rather than make those filters explicit as by providing them as a parameter
00:41:07.620 | that defines some weight layer, for efficiency purposes,
00:41:10.540 | we're actually just going to treat it all as one matrix and then just chunk it
00:41:13.580 | up as we're doing the linear algebra operations.
00:41:17.140 | Does that make sense?
00:41:18.620 | >> So you're chunking it twice, the times three is queries, keys, and values.
00:41:22.420 | And then number of heads is further chunking each one of those.
00:41:25.260 | >> Yeah. So one of the kind of like the key rules that like no one ever tells you
00:41:28.820 | about transformers is that your number of heads has
00:41:31.940 | to evenly divide your transformer hidden dimension.
00:41:35.860 | That's usually a check that is explicitly done in the code for training like BERT or GPT.
00:41:40.580 | And usually the code doesn't work if that doesn't happen.
00:41:43.220 | And that's kind of how you get away with a lot of these efficiency tricks.
00:41:46.140 | It's a great question.
00:41:46.780 | >> So we should do a whole course on broadcasting.
00:41:49.380 | >> Yeah.
00:41:49.700 | >> Before you even start this so that you can do this mess of things.
00:41:52.740 | >> Yeah. And there's a great professor at Cornell Tech, Sasha Rush, who kind of has like a bunch
00:41:57.380 | of like tutorials on just like basic like broadcasting tensor operations.
00:42:00.340 | It's fantastic.
00:42:01.100 | They should check out.
00:42:02.180 | There's a question?
00:42:03.020 | >> Yeah. Can you just clarify what you mean by chunking?
00:42:05.220 | >> Yeah. So if I have a vector of let's say length 1024, right?
00:42:10.500 | That is my embedding dimensions, the hidden dimension for my transformer.
00:42:13.700 | And if I have some number of heads, let's say two, to make it easy.
00:42:19.980 | Chunking just means that I'm going to split that vector of 1024
00:42:23.020 | into two heads each of dimension 512.
00:42:26.020 | Right? So I'm literally just going to like reshape that vector and chunk them
00:42:29.700 | up into like two views of the same input.
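In tensor terms, that chunking is just a reshape; a tiny illustrative example using the numbers from the answer above:

```python
import torch

embed_dim, n_heads = 1024, 2
x = torch.randn(embed_dim)                     # one token's hidden vector
heads = x.view(n_heads, embed_dim // n_heads)  # two views, each of dimension 512
print(heads.shape)                             # torch.Size([2, 512])
```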
00:42:34.580 | Cool. Is this actually better?
00:42:38.060 | Is this alone enough to define a transformer?
00:42:42.580 | Maybe. Maybe not.
00:42:44.340 | All right.
00:42:44.660 | The answer is no.
00:42:45.940 | So it's good because like we get all of the parallelization advantages and all of kind
00:42:50.700 | of the attention advantages that I talked about on the previous slide.
00:42:53.060 | This is a slide from like a Justin Johnson and Dante Zhu who are both ex-Stanford alums now teaching
00:42:58.700 | courses about transformers and deep learning at various colleges.
00:43:02.820 | But you're missing kind of like one key component, right?
00:43:08.460 | So if you just look at this and squint at this for like a little bit of time,
00:43:14.180 | what you're missing is like, okay, so I am just taking different weighted averages
00:43:18.260 | of the same underlying values over and over again.
00:43:24.940 | Relative to the things that are coming out of my transformer,
00:43:27.300 | there actually is no non-linearity, but just self-attention.
00:43:31.500 | This is basically a glorified linear network.
00:43:36.180 | So we need some way to fix that because that's where the expressivity,
00:43:40.740 | the kind of the magic of deep learning happens.
00:43:42.420 | It's kind of when we stack these non-linearities, go deeper and learn new patterns at scale.
00:43:48.780 | So this is how we do it.
00:43:52.100 | We add an MLP.
00:43:53.220 | We add an MLP to the very end of the transformer block and it's very simple, right?
00:43:57.180 | All it does is it kind of takes the embedding dimension that comes
00:43:59.820 | out of the self-attention block that we just defined,
00:44:02.340 | projects it to a higher dimensional space,
00:44:06.060 | adds a nonlinearity,
00:44:08.180 | and then down projects it back to the embedding dimension.
00:44:11.020 | Usually what you're going to see is a factor of four.
00:44:12.660 | Why is it a factor of four?
00:44:14.300 | No one knows is the honest answer.
00:44:17.100 | Two didn't seem to work well enough.
00:44:18.900 | Eight seemed to be too big.
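Here is a minimal sketch of that block. It is illustrative rather than any particular model's code; BERT and GPT use a GELU here, but the exact nonlinearity varies.

```python
import torch.nn as nn

class TransformerMLP(nn.Module):
    """Position-wise feed-forward block: project up by 4x, nonlinearity, project back down."""

    def __init__(self, embed_dim: int, expansion: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, expansion * embed_dim),  # project up
            nn.GELU(),                                    # the nonlinearity
            nn.Linear(expansion * embed_dim, embed_dim),  # project back down
        )

    def forward(self, x):
        return self.net(x)
```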
00:44:20.340 | But here's some like soft intuition for kind of why this might work.
00:44:24.660 | This kind of is a throwback to like 229, you know, OGML days.
00:44:30.060 | So you want your network as a whole to be able to kind of
00:44:34.660 | both forget the things that are unimportant,
00:44:37.900 | but also remember the things that are important, right?
00:44:39.820 | That's kind of the role. So the sharpening,
00:44:41.500 | the remembering that are important are these residual connections,
00:44:43.980 | like the fact that I'm adding X to some transform of X.
00:44:47.260 | The forgetting is what this MLP is doing.
00:44:49.860 | It's basically saying like what stuff can I throw away and should I
00:44:53.140 | basically forget because it's not really relevant to what I care about,
00:44:56.460 | at the end of the day, which are good contextual representations.
00:44:59.100 | And so the role of the MLP is very similar to the role of the kernel
00:45:03.620 | in, you know, the good old support vector machine literature, right?
00:45:07.900 | So if I have two classes that are kind of like this in a plane,
00:45:11.740 | and I want to draw a line that partitions them, how do I do it?
00:45:15.540 | Well, it's hard if I'm only working in 2D.
00:45:18.460 | But with just a very simple learn transform,
00:45:21.300 | if I just implicitly lift these things up to like 3D,
00:45:25.220 | I can turn this into a surface in 3D that can just cut in half,
00:45:28.900 | separates my stuff. So projecting up with this MLP is basically
00:45:32.700 | this way of kind of like aligning or crystallizing the structure of our features,
00:45:36.420 | learning a good decision boundary in space,
00:45:38.780 | and compressing from there.
00:45:40.780 | I think we're out of time, so we're going to go through like
00:45:42.700 | the rest of the transformer evolution in a bit.
00:45:45.020 | But all the slides are up.
00:45:46.540 | I have office hours tomorrow and I will be back. Thanks.
00:45:51.260 | >> Thank you.
00:45:52.260 | [ Applause ]